# GestureLSM/trainer/base_trainer.py
# Repository page residue kept for provenance: "Tharun156's picture", "Upload 149 files", "f7400bf verified"
# from system_utils import get_gpt_id
# dev = get_gpt_id()
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import signal
import time
import csv
import sys
import warnings
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
import time
import pprint
from loguru import logger
import smplx
from torch.utils.tensorboard import SummaryWriter
import wandb
import matplotlib.pyplot as plt
from utils import logger_tools, other_tools, metric
import shutil
import argparse
from omegaconf import OmegaConf
from datetime import datetime
import importlib
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate
from dataloaders.build_vocab import Vocab
class BaseTrainer(object):
    """Common trainer scaffolding: data loaders, evaluation helpers, metric logging, checkpoints.

    This class does not create ``self.model``, ``self.opt``, ``self.opt_s`` or
    ``self.tracker``; the recording and checkpoint methods read them, so a
    subclass has to provide them before those methods are called.
    """

    def __init__(self, cfg, args):
        """Build the data loaders, the wandb run (train mode) and the evaluation helpers.

        Args:
            cfg: Experiment configuration (attribute-style access, e.g. an OmegaConf node).
            args: Parsed command-line arguments; only ``args.mode`` is read here.
        """
        self.cfg = cfg
        self.args = args
        self.rank = 0
        self.checkpoint_path = os.path.join(cfg.output_dir, cfg.exp_name)

        # Best value seen so far for each tracked metric.
        # fgd / test_clip_fgd: lower is better; l1div / bc: higher is better.
        self.val_best = {
            "fgd": {"value": float('inf'), "epoch": 0},
            "l1div": {"value": float('-inf'), "epoch": 0},
            "bc": {"value": float('-inf'), "epoch": 0},
            "test_clip_fgd": {"value": float('inf'), "epoch": 0},
        }

        data_cfg = cfg.data
        self.train_data = init_class(
            data_cfg.name_pyfile, data_cfg.class_name, data_cfg, loader_type='train'
        )
        self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_data)
        self.train_loader = DataLoader(
            self.train_data,
            batch_size=data_cfg.train_bs,
            sampler=self.train_sampler,
            drop_last=True,
            num_workers=4,
        )

        if data_cfg.test_clip:
            # Clip-level test set, only used for the test_clip_fgd metric.
            self.test_clip_data = init_class(
                data_cfg.name_pyfile, data_cfg.class_name, data_cfg, loader_type='test'
            )
            self.test_clip_loader = DataLoader(self.test_clip_data, batch_size=64, drop_last=False)

        # Test set for fgd, l1div and bc: same dataset class with test_clip disabled.
        test_data_cfg = data_cfg.copy()
        test_data_cfg.test_clip = False
        self.test_data = init_class(
            data_cfg.name_pyfile, data_cfg.class_name, test_data_cfg, loader_type='test'
        )
        self.test_loader = DataLoader(self.test_data, batch_size=1, drop_last=False)

        self.train_length = len(self.train_loader)
        logger.info("Init train andtest dataloader successfully")

        if args.mode == "train" and self.rank == 0:
            # Experiment tracking is only started for training runs.
            run_name = cfg.exp_name + "_" + datetime.now().strftime("%Y%m%d-%H%M")
            if hasattr(cfg, 'resume_from_checkpoint') and cfg.resume_from_checkpoint:
                run_name += "_resumed"
            wandb.init(
                project=cfg.wandb_project,
                name=run_name,
                entity=cfg.wandb_entity,
                dir=cfg.wandb_log_dir,
                config=OmegaConf.to_container(cfg)
            )

        # Evaluation autoencoder, built for every mode from fixed hyper-parameters
        # and a fixed checkpoint path.
        eval_model_module = importlib.import_module("models.motion_representation")
        eval_args = type('Args', (), {})()
        eval_args.vae_layer = 4
        eval_args.vae_length = 240
        eval_args.vae_test_dim = 330
        eval_args.variational = False
        eval_args.data_path_1 = "./datasets/hub/"
        eval_args.vae_grow = [1, 1, 2, 1]
        eval_copy = getattr(eval_model_module, 'VAESKConv')(eval_args)
        other_tools.load_checkpoints(
            eval_copy,
            './datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/weights/AESKConv_240_100.bin',
            'VAESKConv'
        )
        self.eval_copy = eval_copy

        self.smplx = smplx.create(
            self.cfg.data.data_path_1+"smplx_models/",
            model_type='smplx',
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False,
        ).eval()

        # Metric helpers are only instantiated on rank 0.
        on_main_rank = self.rank == 0
        self.alignmenter = metric.alignment(
            0.3, 7, self.train_data.avg_vel,
            upper_body=[3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
        ) if on_main_rank else None
        self.align_mask = 60
        self.l1_calculator = metric.L1div() if on_main_rank else None

    def train_recording(self, epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=None):
        """Log the running training metrics to wandb and to the console.

        Args:
            epoch (int): Current epoch.
            its (int): Iteration index inside the epoch.
            t_data (float): Data-loading time in seconds (logged in ms).
            t_train (float): Compute time in seconds (logged in ms).
            mem_cost: Accepted for interface compatibility; not used here.
            lr_g: Learning rate logged as ``train/learning_rate``.
            lr_d: Accepted for interface compatibility; not used here.
        """
        metrics = {}
        # `meter` (not `metric`) so the imported `metric` module is not shadowed.
        for name, states in self.tracker.loss_meters.items():
            meter = states['train']
            if meter.count > 0:
                value = meter.avg
                # Un-prefixed key feeds the console line, prefixed key the wandb panel.
                metrics[name] = value
                metrics[f"train/{name}"] = value

        metrics.update({
            "train/learning_rate": lr_g,
            "train/data_time_ms": t_data*1000,
            "train/train_time_ms": t_train*1000,
        })

        wandb.log(metrics, step=epoch*self.train_length+its)

        pstr = f"[{epoch:03d}][{its:03d}/{self.train_length:03d}] "
        pstr += " ".join([f"{k}: {v:.3f}" for k, v in metrics.items() if "train/" not in k])
        logger.info(pstr)

    def val_recording(self, epoch):
        """Collect validation metrics, update the best-so-far table and save checkpoints.

        For every meter with at least one sample the average is stored under
        ``val/<name>``. Metrics listed in ``self.val_best`` are compared with the
        stored best; an improvement updates the table and writes a
        ``best_<name>`` checkpoint. A regular checkpoint save is then requested
        and, on rank 0, the results are logged.

        Args:
            epoch (int): Epoch the validation belongs to.
        """
        metrics = {}
        global_step = epoch * len(self.train_loader)

        for name, states in self.tracker.loss_meters.items():
            meter = states['val']
            if meter.count <= 0:
                continue
            value = float(meter.avg)
            metrics[f"val/{name}"] = value

            if name not in self.val_best:
                continue
            current_best = self.val_best[name]["value"]
            # l1div and bc improve upwards; everything else (fgd, ...) downwards.
            if name in ("l1div", "bc"):
                is_better = value > current_best
            else:
                is_better = value < current_best

            if is_better:
                self.val_best[name] = {"value": float(value), "epoch": int(epoch)}
                self.save_checkpoint(
                    epoch=epoch,
                    iteration=global_step,
                    is_best=True,
                    best_metric_name=name
                )

            metrics[f"best_{name}"] = float(self.val_best[name]["value"])
            metrics[f"best_{name}_epoch"] = int(self.val_best[name]["epoch"])

        # Regular checkpoint request; save_checkpoint decides whether this epoch is written.
        self.save_checkpoint(
            epoch=epoch,
            iteration=global_step,
            is_best=False,
            best_metric_name=None
        )

        if self.rank == 0:
            try:
                wandb.log(metrics, step=global_step)
            except Exception:
                # wandb is only initialised in train mode; keep validation usable without it.
                logger.info("WANDB not initialized ! Probably doing the testing now")

            pstr = "Validation Results >>>> "
            pstr += " ".join([
                f"{k.split('/')[-1]}: {v:.3f}"
                for k, v in metrics.items()
                if k.startswith("val/")
            ])
            logger.info(pstr)

            pstr = "Best Results >>>> "
            pstr += " ".join([
                f"{k}: {v['value']:.3f} (epoch {v['epoch']})"
                for k, v in self.val_best.items()
            ])
            logger.info(pstr)

    def test_recording(self, dict_name, value, epoch):
        """Push a test value into the tracker and flush it for ``epoch``."""
        self.tracker.update_meter(dict_name, "test", value)
        _ = self.tracker.update_values(dict_name, 'test', epoch)

    def save_checkpoint(self, epoch, iteration, is_best=False, best_metric_name=None):
        """Save training checkpoint

        A regular checkpoint (``checkpoint_<epoch>/ckpt.pth``) is written when
        ``epoch`` is a multiple of 20; a best checkpoint
        (``best_<metric>/ckpt.pth``) is written when ``is_best`` is set together
        with a metric name. When neither applies the method returns without
        collecting any state dicts.

        Args:
            epoch (int): Current epoch number
            iteration (int): Current iteration number
            is_best (bool): Whether this is the best model so far
            best_metric_name (str, optional): Name of the metric if this is a best checkpoint
        """
        save_regular = epoch % 20 == 0
        save_best = bool(is_best and best_metric_name)
        if not (save_regular or save_best):
            # Nothing to write: skip building the state dicts.
            return

        scheduler = getattr(self, 'opt_s', None)
        checkpoint = {
            'epoch': epoch,
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.opt.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'val_best': self.val_best,
        }

        targets = []
        if save_regular:
            targets.append(os.path.join(self.checkpoint_path, f"checkpoint_{epoch}"))
        if save_best:
            targets.append(os.path.join(self.checkpoint_path, f"best_{best_metric_name}"))

        for target_dir in targets:
            os.makedirs(target_dir, exist_ok=True)
            torch.save(checkpoint, os.path.join(target_dir, "ckpt.pth"))
def _parse_override_value(raw):
    """Turn an override string into a Python literal when it is one, else return it unchanged."""
    import ast

    # NOTE: the value comes straight from the command line. It used to be passed
    # to eval(), which executes arbitrary expressions and silently resolves bare
    # words such as "max" or "cfg" to live objects. literal_eval only accepts
    # literals (numbers, strings, lists, dicts, True/False/None); anything else
    # (including arithmetic expressions) stays a plain string.
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        return raw


def _is_within(path, parent):
    """Return True when ``path`` is ``parent`` itself or lies below it (``parent`` must be absolute)."""
    path = os.path.abspath(path)
    # Compare against "parent/" so that a sibling such as "outputs_old" is not
    # mistaken for something inside "outputs".
    return path == parent or path.startswith(os.path.join(parent, ""))


def prepare_all():
    """
    Parse command line arguments and prepare configuration

    Loads the YAML config, applies ``--resume`` / ``--debug`` and ``key=value``
    overrides, exports the wandb key if the config has one, creates the
    experiment directory and snapshots the config plus the ``.py`` files found
    under this file's directory into ``<output_dir>/<exp_name>/sanity_check``.

    Returns:
        tuple: ``(cfg, args)`` - the loaded configuration and the parsed arguments.

    Raises:
        ValueError: If the config file is not a ``.yaml`` file, or an override
            key cannot be resolved inside the config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/intention_w_distill.yaml")
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from")
    parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
    parser.add_argument("--mode", type=str, choices=['train', 'test', 'render'], default='train',
                        help="Choose between 'train' or 'test' or 'render' mode")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Checkpoint path for testing or resuming training")
    parser.add_argument('overrides', nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Load config
    if not args.config.endswith(".yaml"):
        raise ValueError("Unsupported config file format. Only .yaml files are allowed.")
    cfg = OmegaConf.load(args.config)
    # Experiment name = config file name without its ".yaml" suffix.
    cfg.exp_name = os.path.splitext(os.path.basename(args.config))[0]

    # Handle resume from checkpoint
    if args.resume:
        cfg.resume_from_checkpoint = args.resume

    # Debug mode settings
    if args.debug:
        cfg.wandb_project = "debug"
        cfg.exp_name = "debug"
        cfg.solver.max_train_steps = 4

    # Process override arguments of the form key=value or a.b.c=value
    for arg in args.overrides or []:
        if '=' not in arg:
            continue
        # Split on the first '=' only, so values may themselves contain '='.
        key, _, raw_value = arg.partition('=')
        value = _parse_override_value(raw_value)
        if key in cfg:
            cfg[key] = value
            continue
        try:
            # Handle nested config with dot notation
            *parents, leaf = key.split('.')
            cfg_node = cfg
            for part in parents:
                cfg_node = cfg_node[part]
            cfg_node[leaf] = value
        except Exception as err:
            raise ValueError(f"Key {key} not found in config.") from err

    # Set up wandb
    if hasattr(cfg, 'wandb_key'):
        os.environ["WANDB_API_KEY"] = cfg.wandb_key

    # Create output directories
    save_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(sanity_check_dir, exist_ok=True)

    # Save config
    config_path = os.path.join(sanity_check_dir, f'{cfg.exp_name}.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(cfg, f)

    # Copy source files for reproducibility
    current_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.abspath(cfg.output_dir)

    for root, dirs, files in os.walk(current_dir):
        # Never descend into the output directory: it holds earlier snapshots.
        dirs[:] = [d for d in dirs if not _is_within(os.path.join(root, d), output_dir)]
        if _is_within(root, output_dir):
            continue
        for file in files:
            if not file.endswith(".py"):
                continue
            full_file_path = os.path.join(root, file)
            if '__pycache__' in full_file_path:
                continue
            relative_path = os.path.relpath(full_file_path, current_dir)
            dest_path = os.path.join(sanity_check_dir, relative_path)
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            try:
                shutil.copy(full_file_path, dest_path)
            except OSError as e:
                # Best effort: a failed snapshot copy must not abort the run.
                print(f"Warning: Could not copy {full_file_path}: {str(e)}")

    return cfg, args
def init_class(module_name, class_name, config, **kwargs):
    """
    Import ``module_name``, look up ``class_name`` in it and call it.

    The looked-up object is called as ``cls(config, **kwargs)`` and the result
    is returned.
    """
    factory = getattr(importlib.import_module(module_name), class_name)
    return factory(config, **kwargs)
def seed_everything(seed):
    """
    Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as always: python -> numpy -> torch CPU -> torch CUDA.
    for apply_seed in seeders:
        apply_seed(seed)
def _set_sampler_epoch(loader, epoch):
    """Call ``set_epoch(epoch)`` on the loader's sampler when the loader and that method exist."""
    sampler = getattr(loader, "sampler", None)
    set_epoch = getattr(sampler, "set_epoch", None)
    if callable(set_epoch):
        set_epoch(epoch)


@logger.catch
def main_worker(rank, world_size, cfg, args):
    """Per-process entry point: set up the process group and run train / test / render.

    Args:
        rank (int): Rank of this process; also used to gate validation and final logging.
        world_size (int): Number of processes in the group.
        cfg: Experiment configuration.
        args: Parsed command-line arguments (``mode`` and ``resume`` are read here).

    Raises:
        FileNotFoundError: If ``args.resume`` does not point to an existing checkpoint.
            (The ``logger.catch`` decorator intercepts exceptions raised in here.)
    """
    if not sys.warnoptions:
        warnings.simplefilter("ignore")

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    logger_tools.set_args_and_logger(cfg, rank)
    seed_everything(cfg.seed)
    other_tools.print_exp_info(cfg)

    # Initialize trainer
    trainer = importlib.import_module("shortcut_rvqvae_trainer").CustomTrainer(cfg, args)

    # Resume logic
    resume_epoch = 0
    if args.resume:
        # A directory is expected to contain "ckpt.pth"; anything else is used as the file itself.
        if os.path.isdir(args.resume):
            ckpt_path = os.path.join(args.resume, "ckpt.pth")
        else:
            ckpt_path = args.resume
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        trainer.load_checkpoint(checkpoint)
        resume_epoch = checkpoint.get('epoch', 0) + 1  # Start from next epoch
        logger.info(f"Resumed from checkpoint {ckpt_path}, starting at epoch {resume_epoch}")

    if args.mode == "train" and not args.resume:
        logger.info("Training from scratch ...")
    elif args.mode == "train" and args.resume:
        logger.info(f"Resuming training from checkpoint {args.resume} ...")
    elif args.mode == "test":
        logger.info("Testing ...")
    elif args.mode == "render":
        logger.info("Rendering ...")

    if args.mode == "train":
        start_time = time.time()
        total_epochs = cfg.solver.epochs
        for epoch in range(resume_epoch, total_epochs+1):
            if cfg.ddp:
                # The trainer may not expose a val_loader (or its sampler may not be
                # epoch-aware); in that case there is nothing to reshuffle.
                _set_sampler_epoch(getattr(trainer, "val_loader", None), epoch)

            if epoch % cfg.val_period == 0 and epoch > 0:
                if rank == 0:
                    if cfg.data.test_clip:
                        trainer.test_clip(epoch)
                    else:
                        trainer.val(epoch)

            epoch_time = time.time()-start_time
            if trainer.rank == 0:
                # Extrapolate from the epochs finished in THIS run, so the estimate
                # stays meaningful after a resume; unknown until one epoch is done.
                epochs_done = epoch - resume_epoch
                if epochs_done > 0:
                    remain_time = (total_epochs - epoch) * epoch_time / epochs_done
                else:
                    remain_time = float("nan")
                logger.info(f"Time info >>>> elapsed: {epoch_time/60:.2f} mins\t" +
                            f"remain: {remain_time/60:.2f} mins")

            # The last loop iteration only validates; there is no epoch left to train.
            if epoch != total_epochs:
                if cfg.ddp:
                    _set_sampler_epoch(trainer.train_loader, epoch)
                trainer.tracker.reset()
                trainer.train(epoch)

            if cfg.debug:
                trainer.test(epoch)

        # Final cleanup and logging
        if rank == 0:
            for k, v in trainer.val_best.items():
                logger.info(f"Best {k}: {v['value']:.6f} at epoch {v['epoch']}")
            wandb.finish()

    elif args.mode == "test":
        trainer.test_clip(999)
        trainer.test(999)

    elif args.mode == "render":
        trainer.test_render(999)
if __name__ == "__main__":
    import socket

    def is_port_in_use(port, host='127.0.0.1'):
        """Return True when a TCP socket cannot be bound to ``(host, port)``."""
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            try:
                probe.bind((host, port))
            except OSError:
                # Bind failed: something else already holds the port.
                return True
            return False

    # Rendezvous address for the process group; start at 29500 and walk upwards
    # until a port can be bound.
    master_addr = '127.0.0.1'
    master_port = 29500
    while is_port_in_use(master_port):
        print(f"Port {master_port} is in use, trying next port...")
        master_port += 1

    os.environ.update({
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
    })

    cfg, args = prepare_all()

    if not cfg.ddp:
        # Single process: rank 0 in a world of size 1.
        main_worker(0, 1, cfg, args)
    else:
        # One worker process per configured GPU.
        mp.set_start_method("spawn", force=True)
        world_size = len(cfg.gpus)
        mp.spawn(
            main_worker,
            args=(world_size, cfg, args),
            nprocs=world_size,
        )