import os
import pprint
import random
import sys
import time
import warnings
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from dataloaders import data_tools
from dataloaders.data_tools import joints_list
from loguru import logger
from models.vq.model import RVQVAE
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from trainer.base_trainer import BaseTrainer
from utils import (
    data_transfer,
    logger_tools,
    metric,
    other_tools,
    other_tools_hf,
    rotation_conversions as rc,
)
from utils.joints import hands_body_mask, lower_body_mask, upper_body_mask


def convert_15d_to_6d(motion):
    """Convert 15-D per-joint motion features to 6-D rotations.

    The generated motion carries 15 features per joint, while the evaluation
    model expects only the 6-D rotation component (channels 6:12 of each
    per-joint block).
    """
    bs = motion.shape[0]
    motion_6d = motion.reshape(bs, -1, 55, 15)[:, :, :, 6:12]
    motion_6d = motion_6d.reshape(bs, -1, 55 * 6)
    return motion_6d


class CustomTrainer(BaseTrainer):
    """Generative trainer supporting various generative models."""

    def __init__(self, cfg, args):
        super().__init__(cfg, args)
        self.cfg = cfg
        self.args = args
        self.joints = 55
        self.ori_joint_list = joints_list["beat_smplx_joints"]
        self.tar_joint_list_face = joints_list["beat_smplx_face"]
        self.tar_joint_list_upper = joints_list["beat_smplx_upper"]
        self.tar_joint_list_hands = joints_list["beat_smplx_hands"]
        self.tar_joint_list_lower = joints_list["beat_smplx_lower"]

        # Build per-part masks over the 165-D (55 joints x 3) axis-angle vector.
        self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_face:
            self.joint_mask_face[
                self.ori_joint_list[joint_name][1]
                - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][1]
            ] = 1
        self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_upper:
            self.joint_mask_upper[
                self.ori_joint_list[joint_name][1]
                - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][1]
            ] = 1
        self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_hands:
            self.joint_mask_hands[
                self.ori_joint_list[joint_name][1]
                - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][1]
            ] = 1
        self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_lower:
            self.joint_mask_lower[
                self.ori_joint_list[joint_name][1]
                - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][1]
            ] = 1

        self.tracker = other_tools.EpochTracker(
            ["fgd", "bc", "l1div", "predict_x0_loss", "test_clip_fgd"],
            [True, True, True, True, True],
        )

        ##### Model #####
        model_module = __import__(
            f"models.{cfg.model.model_name}", fromlist=["something"]
        )
        if self.cfg.ddp:
            self.model = getattr(model_module, cfg.model.g_name)(cfg).to(self.rank)
            process_group = torch.distributed.new_group()
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model, process_group
            )
            self.model = DDP(
                self.model,
                device_ids=[self.rank],
                output_device=self.rank,
                broadcast_buffers=False,
                find_unused_parameters=False,
            )
        else:
            self.model = getattr(model_module, cfg.model.g_name)(cfg)

        if self.args.mode == "train":
            if self.rank == 0:
                logger.info(self.model)
                logger.info(f"init {self.cfg.model.g_name} success")
                wandb.watch(self.model)

        ##### Optimizer and Scheduler #####
        self.opt = create_optimizer(self.cfg.solver, self.model)
        self.opt_s = create_scheduler(self.cfg.solver, self.opt)
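
        # Note on the latent pipeline (summary of the code below; dims follow
        # _create_body_vq_models and _load_data): three pretrained RVQVAE models,
        # kept in eval mode, encode the upper body (13 joints x 6-D = 78-D), the
        # hands (30 x 6 = 180-D) and the lower body plus root velocity
        # (9 x 6 + 3 = 57-D) into 128-D latents each. The trainer concatenates
        # them into a 384-D sequence (scaled by 1/5) that the generative model
        # predicts; _g_test later splits and rescales that sequence for decoding.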
        ##### VQ-VAE models #####
        # Initialize and load VQ-VAE models for the different body parts.
        self.vq_models = self._create_body_vq_models()
        # Set all VQ models to eval mode
        for model in self.vq_models.values():
            model.eval()
        self.vq_model_upper, self.vq_model_hands, self.vq_model_lower = (
            self.vq_models.values()
        )

        ##### Loss functions #####
        self.reclatent_loss = nn.MSELoss()
        self.vel_loss = torch.nn.L1Loss(reduction="mean")

        ##### Normalization #####
        self.mean = np.load("./mean_std/beatx_2_330_mean.npy")
        self.std = np.load("./mean_std/beatx_2_330_std.npy")
        # Extract body-part-specific normalizations
        for part in ["upper", "hands", "lower"]:
            mask = globals()[f"{part}_body_mask"]
            setattr(self, f"mean_{part}", torch.from_numpy(self.mean[mask]))
            setattr(self, f"std_{part}", torch.from_numpy(self.std[mask]))
        self.trans_mean = torch.from_numpy(
            np.load("./mean_std/beatx_2_trans_mean.npy")
        )
        self.trans_std = torch.from_numpy(
            np.load("./mean_std/beatx_2_trans_std.npy")
        )

        if self.args.checkpoint:
            try:
                ckpt_state_dict = torch.load(self.args.checkpoint, weights_only=False)[
                    "model_state_dict"
                ]
            except KeyError:
                ckpt_state_dict = torch.load(self.args.checkpoint, weights_only=False)[
                    "model_state"
                ]
            # remove 'audioEncoder' from the state_dict due to legacy issues
            ckpt_state_dict = {
                k: v
                for k, v in ckpt_state_dict.items()
                if "modality_encoder.audio_encoder." not in k
            }
            self.model.load_state_dict(ckpt_state_dict, strict=False)
            logger.info(f"Loaded checkpoint from {self.args.checkpoint}")

    def _create_body_vq_models(self) -> Dict[str, RVQVAE]:
        """Create VQ-VAE models for body parts."""
        vq_configs = {
            "upper": {"dim_pose": 78},
            "hands": {"dim_pose": 180},
            "lower": {"dim_pose": 57},
        }
        vq_models = {}
        for part, config in vq_configs.items():
            model = self._create_rvqvae_model(config["dim_pose"], part)
            vq_models[part] = model
        return vq_models

    def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
        """Create a single RVQVAE model with the specified configuration."""
        vq_args = self.args
        vq_args.num_quantizers = 6
        vq_args.shared_codebook = False
        vq_args.quantize_dropout_prob = 0.2
        vq_args.quantize_dropout_cutoff_index = 0
        vq_args.mu = 0.99
        vq_args.beta = 1.0

        model = RVQVAE(
            vq_args,
            input_width=dim_pose,
            nb_code=1024,
            code_dim=128,
            output_emb_width=128,
            down_t=2,
            stride_t=2,
            width=512,
            depth=3,
            dilation_growth_rate=3,
            activation="relu",
            norm=None,
        )

        # Load pretrained weights
        checkpoint_path = getattr(self.cfg, f"vqvae_{body_part}_path")
        model.load_state_dict(torch.load(checkpoint_path)["net"])
        return model

    def inverse_selection(self, filtered_t, selection_array, n):
        original_shape_t = np.zeros((n, selection_array.size))
        selected_indices = np.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def inverse_selection_tensor(self, filtered_t, selection_array, n):
        selection_array = torch.from_numpy(selection_array)
        original_shape_t = torch.zeros((n, 165))
        selected_indices = torch.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def _load_data(self, dict_data):
        facial_rep = dict_data["facial"]
        beta = dict_data["beta"]
        tar_trans = dict_data["trans"]
        tar_id = dict_data["id"]

        # process the pose data
        tar_pose = dict_data["pose"][:, :, :165]
        tar_trans_v = dict_data["trans_v"]
        tar_trans = dict_data["trans"]
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
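
        # Layout note (derived from the slices below): tar_pose is axis-angle,
        # 165 = 55 joints x 3. Hands are joints 25..54 (30 joints), the upper-body
        # mask selects 13 joints, and the lower-body mask 9 joints; each part is
        # converted to a 6-D rotation representation before normalization.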
        tar_pose_hands = tar_pose[:, :, 25 * 3 : 55 * 3]
        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30 * 6)

        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13 * 6)

        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9 * 6)
        tar_pose_lower = tar_pose_leg

        tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
        tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
        tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower
        tar_trans_v = (tar_trans_v - self.trans_mean) / self.trans_std
        tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1)

        latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
        latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
        latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)

        latent_lengths = [
            latent_upper_top.shape[1],
            latent_hands_top.shape[1],
            latent_lower_top.shape[1],
        ]
        if len(set(latent_lengths)) != 1:
            min_len = min(latent_lengths)
            logger.warning(
                f"Latent length mismatch detected (upper={latent_upper_top.shape[1]}, "
                f"hands={latent_hands_top.shape[1]}, lower={latent_lower_top.shape[1]}); "
                f"truncating to {min_len}"
            )
            latent_upper_top = latent_upper_top[:, :min_len, :]
            latent_hands_top = latent_hands_top[:, :min_len, :]
            latent_lower_top = latent_lower_top[:, :min_len, :]

        ## TODO: Whether the latent scale is needed here?
        # latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)
        latent_in = (
            torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)
            / 5
        )

        word = dict_data.get("word", None)
        # style feature is always None (without annotation, we never know what it is)
        style_feature = None
        audio_onset = None
        if self.cfg.data.onset_rep:
            audio_onset = dict_data["audio_onset"]

        return {
            "audio_onset": audio_onset,
            "word": word,
            "latent_in": latent_in,
            "tar_id": tar_id,
            "facial_rep": facial_rep,
            "beta": beta,
            "tar_pose": tar_pose,
            "trans": tar_trans,
            "style_feature": style_feature,
        }

    def _g_training(self, loaded_data, mode="train", epoch=0):
        self.model.train()
        cond_ = {"y": {}}
        cond_["y"]["audio_onset"] = loaded_data["audio_onset"]
        cond_["y"]["word"] = loaded_data["word"]
        cond_["y"]["id"] = loaded_data["tar_id"]
        cond_["y"]["seed"] = loaded_data["latent_in"][:, : self.cfg.pre_frames]
        cond_["y"]["style_feature"] = loaded_data["style_feature"]
        x0 = loaded_data["latent_in"]
        x0 = x0.permute(0, 2, 1).unsqueeze(2)
        # Note: assumes the model is DDP-wrapped (accessed via .module); see __init__.
        g_loss_final = self.model.module.train_forward(cond_, x0)["loss"]
        self.tracker.update_meter("predict_x0_loss", "train", g_loss_final.item())
        if mode == "train":
            return g_loss_final

    def _g_test(self, loaded_data):
        self.model.eval()
        tar_beta = loaded_data["beta"]
        tar_pose = loaded_data["tar_pose"]
        tar_exps = loaded_data["facial_rep"]
        tar_trans = loaded_data["trans"]
        audio_onset = loaded_data["audio_onset"]
        in_word = loaded_data["word"]
        in_x0 = loaded_data["latent_in"]
        in_seed = loaded_data["latent_in"]
        bs, n, j = (
            loaded_data["tar_pose"].shape[0],
            loaded_data["tar_pose"].shape[1],
            self.joints,
        )
        remain = n % 8
        if remain != 0:
            tar_pose = tar_pose[:, :-remain, :]
            tar_beta = tar_beta[:, :-remain, :]
            tar_exps = tar_exps[:, :-remain, :]
            in_x0 = in_x0[
                :, : in_x0.shape[1] - (remain // self.cfg.vqvae_squeeze_scale), :
            ]
            in_seed = in_seed[
                :, : in_x0.shape[1] - (remain // self.cfg.vqvae_squeeze_scale), :
            ]
            in_word = in_word[:, :-remain]
            n = n - remain
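
        # Sketch of the windowing below: generation proceeds in overlapping
        # windows of cfg.data.pose_length frames; each window after the first is
        # seeded with the last cfg.pre_frames latent frames of the previous
        # window's output, while ground-truth latents seed the very first window.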
        rec_all_upper = []
        rec_all_lower = []
        rec_all_hands = []
        vqvae_squeeze_scale = self.cfg.vqvae_squeeze_scale
        pre_frames_scaled = self.cfg.pre_frames * vqvae_squeeze_scale
        roundt = (n - pre_frames_scaled) // (
            self.cfg.data.pose_length - pre_frames_scaled
        )
        remain = (n - pre_frames_scaled) % (
            self.cfg.data.pose_length - pre_frames_scaled
        )
        round_l = self.cfg.pose_length - pre_frames_scaled
        round_audio = int(round_l / 3 * 5)

        in_audio_onset_tmp = None
        in_word_tmp = None
        for i in range(0, roundt):
            if audio_onset is not None:
                in_audio_onset_tmp = audio_onset[
                    :,
                    i * (16000 // 30 * round_l) : (i + 1) * (16000 // 30 * round_l)
                    + 16000 // 30 * self.cfg.pre_frames * vqvae_squeeze_scale,
                ]
            if in_word is not None:
                in_word_tmp = in_word[
                    :,
                    i * (round_l) : (i + 1) * (round_l)
                    + self.cfg.pre_frames * vqvae_squeeze_scale,
                ]
            in_id_tmp = loaded_data["tar_id"][
                :, i * (round_l) : (i + 1) * (round_l) + self.cfg.pre_frames
            ]
            in_seed_tmp = in_seed[
                :,
                i * (round_l) // vqvae_squeeze_scale : (i + 1) * (round_l)
                // vqvae_squeeze_scale
                + self.cfg.pre_frames,
            ]
            if i == 0:
                in_seed_tmp = in_seed_tmp[:, : self.cfg.pre_frames, :]
            else:
                in_seed_tmp = last_sample[:, -self.cfg.pre_frames :, :]

            cond_ = {"y": {}}
            cond_["y"]["audio_onset"] = in_audio_onset_tmp
            cond_["y"]["word"] = in_word_tmp
            cond_["y"]["id"] = in_id_tmp
            cond_["y"]["seed"] = in_seed_tmp
            cond_["y"]["style_feature"] = torch.zeros([bs, 512])

            sample = self.model(cond_)["latents"]
            sample = sample.squeeze(2).permute(0, 2, 1)
            last_sample = sample.clone()

            code_dim = self.vq_model_upper.code_dim
            rec_latent_upper = sample[..., :code_dim]
            rec_latent_hands = sample[..., code_dim : code_dim * 2]
            rec_latent_lower = sample[..., code_dim * 2 : code_dim * 3]

            if i == 0:
                rec_all_upper.append(rec_latent_upper)
                rec_all_hands.append(rec_latent_hands)
                rec_all_lower.append(rec_latent_lower)
            else:
                rec_all_upper.append(rec_latent_upper[:, self.cfg.pre_frames :])
                rec_all_hands.append(rec_latent_hands[:, self.cfg.pre_frames :])
                rec_all_lower.append(rec_latent_lower[:, self.cfg.pre_frames :])

        try:
            rec_all_upper = torch.cat(rec_all_upper, dim=1) * 5
            rec_all_hands = torch.cat(rec_all_hands, dim=1) * 5
            rec_all_lower = torch.cat(rec_all_lower, dim=1) * 5
        except RuntimeError as exc:
            shape_summary = {
                "upper": [tuple(t.shape) for t in rec_all_upper],
                "hands": [tuple(t.shape) for t in rec_all_hands],
                "lower": [tuple(t.shape) for t in rec_all_lower],
            }
            logger.error(
                f"Failed to concatenate latent segments: {exc} | shapes={shape_summary}"
            )
            raise

        rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
        rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
        rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]

        # Integrate the predicted root velocity into a translation trajectory
        # (the y channel is kept unintegrated).
        rec_trans_v = rec_lower[..., -3:]
        rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean
        rec_trans = torch.zeros_like(rec_trans_v)
        rec_trans = torch.cumsum(rec_trans_v, dim=-2)
        rec_trans[..., 1] = rec_trans_v[..., 1]
        rec_lower = rec_lower[..., :-3]

        rec_upper = rec_upper * self.std_upper + self.mean_upper
        rec_hands = rec_hands * self.std_hands + self.mean_hands
        rec_lower = rec_lower * self.std_lower + self.mean_lower

        n = n - remain
        tar_pose = tar_pose[:, :n, :]
        tar_exps = tar_exps[:, :n, :]
        tar_trans = tar_trans[:, :n, :]
        tar_beta = tar_beta[:, :n, :]

        if hasattr(self.cfg.model, "use_exp") and self.cfg.model.use_exp:
            rec_exps = tar_exps  # fallback to tar_exps since rec_face is not defined
        else:
            rec_exps = tar_exps
        rec_trans = tar_trans

        rec_pose_legs = rec_lower[:, :, :54]
        bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]

        rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
        rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)
        rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs * n, 13 * 3)
        rec_pose_upper_recover = self.inverse_selection_tensor(
            rec_pose_upper, self.joint_mask_upper, bs * n
        )
        rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
        rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
        rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs * n, 9 * 3)
        rec_pose_lower_recover = self.inverse_selection_tensor(
            rec_pose_lower, self.joint_mask_lower, bs * n
        )
        rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
        rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
        rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs * n, 30 * 3)
        rec_pose_hands_recover = self.inverse_selection_tensor(
            rec_pose_hands, self.joint_mask_hands, bs * n
        )

        # The three recovered parts occupy disjoint slices of the 165-D vector,
        # so summing them reassembles the full-body pose.
        rec_pose = (
            rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
        )
        # Copy the ground-truth jaw rotation (channels 66:69).
        rec_pose[:, 66:69] = tar_pose.reshape(bs * n, 55 * 3)[:, 66:69]

        rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs * n, j, 3))
        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs * n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)

        return {
            "rec_pose": rec_pose,
            "rec_exps": rec_exps,
            "rec_trans": rec_trans,
            "tar_pose": tar_pose,
            "tar_exps": tar_exps,
            "tar_beta": tar_beta,
            "tar_trans": tar_trans,
        }

    def train(self, epoch):
        self.model.train()
        t_start = time.time()
        self.tracker.reset()
        for its, batch_data in enumerate(self.train_loader):
            loaded_data = self._load_data(batch_data)
            t_data = time.time() - t_start

            self.opt.zero_grad()
            g_loss_final = 0
            g_loss_final += self._g_training(loaded_data, "train", epoch)
            g_loss_final.backward()
            if self.cfg.solver.grad_norm != 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.solver.grad_norm
                )
            self.opt.step()

            mem_cost = torch.cuda.memory_cached() / 1e9
            lr_g = self.opt.param_groups[0]["lr"]
            t_train = time.time() - t_start - t_data
            t_start = time.time()
            if its % self.cfg.log_period == 0:
                self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g)
            if self.cfg.debug:
                if its == 1:
                    break
        self.opt_s.step(epoch)

    @torch.no_grad()
    def _common_test_inference(
        self, data_loader, epoch, mode="val", max_iterations=None, save_results=False
    ):
        """
        Common inference logic shared by the val, test, test_clip, and test_render methods.

        Args:
            data_loader: The data loader to iterate over
            epoch: Current epoch number
            mode: Mode string for logging ("val", "test", "test_clip", "test_render")
            max_iterations: Maximum number of iterations (None for no limit)
            save_results: Whether to save result files

        Returns:
            Dictionary containing computed metrics and results
        """
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        align = 0
        latent_out = []
        latent_ori = []
        l2_all = 0
        lvel = 0
        results = []

        # Setup save path for test mode
        results_save_path = None
        if save_results:
            results_save_path = self.checkpoint_path + f"/{epoch}/"
            if mode == "test_render":
                if os.path.exists(results_save_path):
                    import shutil

                    shutil.rmtree(results_save_path)
            os.makedirs(results_save_path, exist_ok=True)

        self.model.eval()
        self.smplx.eval()
        if hasattr(self, "eval_copy"):
            self.eval_copy.eval()

        with torch.no_grad():
            iterator = enumerate(data_loader)
            if mode in ["test_clip", "test"]:
                iterator = enumerate(
                    tqdm(data_loader, desc=f"Testing {mode}", leave=True)
                )

            for its, batch_data in iterator:
                if max_iterations is not None and its > max_iterations:
                    break

                loaded_data = self._load_data(batch_data)
                net_out = self._g_test(loaded_data)

                tar_pose = net_out["tar_pose"]
                rec_pose = net_out["rec_pose"]
                tar_exps = net_out["tar_exps"]
                tar_beta = net_out["tar_beta"]
                rec_trans = net_out["rec_trans"]
                tar_trans = net_out.get("tar_trans", rec_trans)
                rec_exps = net_out.get("rec_exps", tar_exps)
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

                # Handle frame rate conversion
                if (30 / self.cfg.data.pose_fps) != 1:
                    assert 30 % self.cfg.data.pose_fps == 0
                    n *= int(30 / self.cfg.data.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(
                        tar_pose.permute(0, 2, 1),
                        scale_factor=30 / self.cfg.data.pose_fps,
                        mode="linear",
                    ).permute(0, 2, 1)
                    scale_factor = (
                        30 / self.cfg.data.pose_fps
                        if mode != "test"
                        else 30 / self.cfg.pose_fps
                    )
                    rec_pose = torch.nn.functional.interpolate(
                        rec_pose.permute(0, 2, 1),
                        scale_factor=scale_factor,
                        mode="linear",
                    ).permute(0, 2, 1)

                # Calculate latent representations for evaluation
                if hasattr(self, "eval_copy") and mode != "test_render":
                    remain = n % self.cfg.vae_test_len
                    latent_out.append(
                        self.eval_copy.map2latent(rec_pose[:, : n - remain])
                        .reshape(-1, self.cfg.vae_length)
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    latent_ori.append(
                        self.eval_copy.map2latent(tar_pose[:, : n - remain])
                        .reshape(-1, self.cfg.vae_length)
                        .detach()
                        .cpu()
                        .numpy()
                    )

                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs * n, j * 3)

                # Generate SMPLX vertices and joints
                vertices_rec = self.smplx(
                    betas=tar_beta.reshape(bs * n, 300),
                    transl=rec_trans.reshape(bs * n, 3)
                    - rec_trans.reshape(bs * n, 3),
                    expression=tar_exps.reshape(bs * n, 100)
                    - tar_exps.reshape(bs * n, 100),
                    jaw_pose=rec_pose[:, 66:69],
                    global_orient=rec_pose[:, :3],
                    body_pose=rec_pose[:, 3 : 21 * 3 + 3],
                    left_hand_pose=rec_pose[:, 25 * 3 : 40 * 3],
                    right_hand_pose=rec_pose[:, 40 * 3 : 55 * 3],
                    return_joints=True,
                    leye_pose=rec_pose[:, 69:72],
                    reye_pose=rec_pose[:, 72:75],
                )
                joints_rec = (
                    vertices_rec["joints"]
                    .detach()
                    .cpu()
                    .numpy()
                    .reshape(bs, n, 127 * 3)[0, :n, : 55 * 3]
                )

                # Calculate L1 diversity
                if hasattr(self, "l1_calculator"):
                    _ = self.l1_calculator.run(joints_rec)
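
                # Context for the next block (BC metric): audio onsets from the
                # wave16k recordings are aligned against motion beat velocities;
                # the first and last `align_mask` frames are excluded, which is
                # why val()/test() normalize the accumulated score by
                # (total_length - 2 * len(loader) * align_mask).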
                # Calculate alignment for single batch
                if (
                    hasattr(self, "alignmenter")
                    and self.alignmenter is not None
                    and bs == 1
                    and mode != "test_render"
                ):
                    in_audio_eval, sr = librosa.load(
                        self.cfg.data.data_path
                        + "wave16k/"
                        + test_seq_list.iloc[its]["id"]
                        + ".wav"
                    )
                    in_audio_eval = librosa.resample(
                        in_audio_eval, orig_sr=sr, target_sr=self.cfg.data.audio_sr
                    )
                    a_offset = int(
                        self.align_mask
                        * (self.cfg.data.audio_sr / self.cfg.data.pose_fps)
                    )
                    onset_bt = self.alignmenter.load_audio(
                        in_audio_eval[
                            : int(self.cfg.data.audio_sr / self.cfg.data.pose_fps * n)
                        ],
                        a_offset,
                        len(in_audio_eval) - a_offset,
                        True,
                    )
                    beat_vel = self.alignmenter.load_pose(
                        joints_rec, self.align_mask, n - self.align_mask, 30, True
                    )
                    align += self.alignmenter.calculate_align(
                        onset_bt, beat_vel, 30
                    ) * (n - 2 * self.align_mask)

                # Mode-specific processing
                if mode == "test" and save_results:
                    # Calculate facial losses for test mode
                    vertices_rec_face = self.smplx(
                        betas=tar_beta.reshape(bs * n, 300),
                        transl=rec_trans.reshape(bs * n, 3)
                        - rec_trans.reshape(bs * n, 3),
                        expression=rec_exps.reshape(bs * n, 100),
                        jaw_pose=rec_pose[:, 66:69],
                        global_orient=rec_pose[:, :3] - rec_pose[:, :3],
                        body_pose=rec_pose[:, 3 : 21 * 3 + 3]
                        - rec_pose[:, 3 : 21 * 3 + 3],
                        left_hand_pose=rec_pose[:, 25 * 3 : 40 * 3]
                        - rec_pose[:, 25 * 3 : 40 * 3],
                        right_hand_pose=rec_pose[:, 40 * 3 : 55 * 3]
                        - rec_pose[:, 40 * 3 : 55 * 3],
                        return_verts=True,
                        return_joints=True,
                        leye_pose=rec_pose[:, 69:72] - rec_pose[:, 69:72],
                        reye_pose=rec_pose[:, 72:75] - rec_pose[:, 72:75],
                    )
                    vertices_tar_face = self.smplx(
                        betas=tar_beta.reshape(bs * n, 300),
                        transl=tar_trans.reshape(bs * n, 3)
                        - tar_trans.reshape(bs * n, 3),
                        expression=tar_exps.reshape(bs * n, 100),
                        jaw_pose=tar_pose[:, 66:69],
                        global_orient=tar_pose[:, :3] - tar_pose[:, :3],
                        body_pose=tar_pose[:, 3 : 21 * 3 + 3]
                        - tar_pose[:, 3 : 21 * 3 + 3],
                        left_hand_pose=tar_pose[:, 25 * 3 : 40 * 3]
                        - tar_pose[:, 25 * 3 : 40 * 3],
                        right_hand_pose=tar_pose[:, 40 * 3 : 55 * 3]
                        - tar_pose[:, 40 * 3 : 55 * 3],
                        return_verts=True,
                        return_joints=True,
                        leye_pose=tar_pose[:, 69:72] - tar_pose[:, 69:72],
                        reye_pose=tar_pose[:, 72:75] - tar_pose[:, 72:75],
                    )
                    facial_rec = (
                        vertices_rec_face["vertices"].reshape(1, n, -1)[0, :n].cpu()
                    )
                    facial_tar = (
                        vertices_tar_face["vertices"].reshape(1, n, -1)[0, :n].cpu()
                    )
                    face_vel_loss = self.vel_loss(
                        facial_rec[1:, :] - facial_tar[:-1, :],
                        facial_tar[1:, :] - facial_tar[:-1, :],
                    )
                    l2 = self.reclatent_loss(facial_rec, facial_tar)
                    l2_all += l2.item() * n
                    lvel += face_vel_loss.item() * n

                # Save results if needed
                if save_results:
                    if mode == "test":
                        # Save NPZ files for test mode
                        tar_pose_np = tar_pose.detach().cpu().numpy()
                        rec_pose_np = rec_pose.detach().cpu().numpy()
                        rec_trans_np = (
                            rec_trans.detach().cpu().numpy().reshape(bs * n, 3)
                        )
                        rec_exp_np = (
                            rec_exps.detach().cpu().numpy().reshape(bs * n, 100)
                        )
                        tar_exp_np = (
                            tar_exps.detach().cpu().numpy().reshape(bs * n, 100)
                        )
                        tar_trans_np = (
                            tar_trans.detach().cpu().numpy().reshape(bs * n, 3)
                        )
                        gt_npz = np.load(
                            self.cfg.data.data_path
                            + self.cfg.data.pose_rep
                            + "/"
                            + test_seq_list.iloc[its]["id"]
                            + ".npz",
                            allow_pickle=True,
                        )
                        np.savez(
                            results_save_path
                            + "gt_"
                            + test_seq_list.iloc[its]["id"]
                            + ".npz",
                            betas=gt_npz["betas"],
                            poses=tar_pose_np,
                            expressions=tar_exp_np,
                            trans=tar_trans_np,
                            model="smplx2020",
                            gender="neutral",
                            mocap_frame_rate=30,
                        )
                        np.savez(
                            results_save_path
                            + "res_"
                            + test_seq_list.iloc[its]["id"]
                            + ".npz",
                            betas=gt_npz["betas"],
                            poses=rec_pose_np,
                            expressions=rec_exp_np,
                            trans=rec_trans_np,
                            model="smplx2020",
                            gender="neutral",
                            mocap_frame_rate=30,
                        )
loaded_data["audio_name"][0] rec_pose_np = rec_pose.detach().cpu().numpy() rec_trans_np = ( rec_trans.detach().cpu().numpy().reshape(bs * n, 3) ) rec_exp_np = ( rec_exps.detach().cpu().numpy().reshape(bs * n, 100) ) gt_npz = np.load( "./demo/examples/2_scott_0_1_1.npz", allow_pickle=True ) file_name = audio_name.split("/")[-1].split(".")[0] results_npz_file_save_path = ( results_save_path + f"result_{file_name}.npz" ) np.savez( results_npz_file_save_path, betas=gt_npz["betas"], poses=rec_pose_np, expressions=rec_exp_np, trans=rec_trans_np, model="smplx2020", gender="neutral", mocap_frame_rate=30, ) render_vid_path = other_tools_hf.render_one_sequence_no_gt( results_npz_file_save_path, results_save_path, audio_name, self.cfg.data_path_1 + "smplx_models/", use_matplotlib=False, args=self.cfg, ) total_length += n return { "total_length": total_length, "align": align, "latent_out": latent_out, "latent_ori": latent_ori, "l2_all": l2_all, "lvel": lvel, "start_time": start_time, } def val(self, epoch): self.tracker.reset() results = self._common_test_inference( self.test_loader, epoch, mode="val", max_iterations=15 ) total_length = results["total_length"] align = results["align"] latent_out = results["latent_out"] latent_ori = results["latent_ori"] l2_all = results["l2_all"] lvel = results["lvel"] start_time = results["start_time"] logger.info(f"l2 loss: {l2_all/total_length:.10f}") logger.info(f"lvel loss: {lvel/total_length:.10f}") latent_out_all = np.concatenate(latent_out, axis=0) latent_ori_all = np.concatenate(latent_ori, axis=0) fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) logger.info(f"fgd score: {fgd}") self.tracker.update_meter("fgd", "val", fgd) align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask) logger.info(f"align score: {align_avg}") self.tracker.update_meter("bc", "val", align_avg) l1div = self.l1_calculator.avg() logger.info(f"l1div score: {l1div}") self.tracker.update_meter("l1div", "val", l1div) self.val_recording(epoch) end_time = time.time() - start_time logger.info( f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion" ) def test_clip(self, epoch): self.tracker.reset() # Test on CLIP dataset results_clip = self._common_test_inference( self.test_clip_loader, epoch, mode="test_clip" ) total_length_clip = results_clip["total_length"] latent_out_clip = results_clip["latent_out"] latent_ori_clip = results_clip["latent_ori"] start_time = results_clip["start_time"] latent_out_all_clip = np.concatenate(latent_out_clip, axis=0) latent_ori_all_clip = np.concatenate(latent_ori_clip, axis=0) fgd_clip = data_tools.FIDCalculator.frechet_distance( latent_out_all_clip, latent_ori_all_clip ) logger.info(f"test_clip fgd score: {fgd_clip}") self.tracker.update_meter("test_clip_fgd", "val", fgd_clip) current_time = time.time() test_clip_time = current_time - start_time logger.info( f"total test_clip inference time: {int(test_clip_time)} s for {int(total_length_clip/self.cfg.data.pose_fps)} s motion" ) # Test on regular test dataset for recording results_test = self._common_test_inference( self.test_loader, epoch, mode="test_clip" ) total_length = results_test["total_length"] align = results_test["align"] latent_out = results_test["latent_out"] latent_ori = results_test["latent_ori"] latent_out_all = np.concatenate(latent_out, axis=0) latent_ori_all = np.concatenate(latent_ori, axis=0) fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) logger.info(f"fgd score: 
{fgd}") self.tracker.update_meter("fgd", "val", fgd) align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask) logger.info(f"align score: {align_avg}") self.tracker.update_meter("bc", "val", align_avg) l1div = self.l1_calculator.avg() logger.info(f"l1div score: {l1div}") self.tracker.update_meter("l1div", "val", l1div) self.val_recording(epoch) end_time = time.time() - current_time logger.info( f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion" ) def test(self, epoch): results_save_path = self.checkpoint_path + f"/{epoch}/" os.makedirs(results_save_path, exist_ok=True) results = self._common_test_inference( self.test_loader, epoch, mode="test", save_results=True ) total_length = results["total_length"] align = results["align"] latent_out = results["latent_out"] latent_ori = results["latent_ori"] l2_all = results["l2_all"] lvel = results["lvel"] start_time = results["start_time"] logger.info(f"l2 loss: {l2_all/total_length:.10f}") logger.info(f"lvel loss: {lvel/total_length:.10f}") latent_out_all = np.concatenate(latent_out, axis=0) latent_ori_all = np.concatenate(latent_ori, axis=0) fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) logger.info(f"fgd score: {fgd}") self.test_recording("fgd", fgd, epoch) align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask) logger.info(f"align score: {align_avg}") self.test_recording("bc", align_avg, epoch) l1div = self.l1_calculator.avg() logger.info(f"l1div score: {l1div}") self.test_recording("l1div", l1div, epoch) end_time = time.time() - start_time logger.info( f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion" ) def test_render(self, epoch): import platform if platform.system() == "Linux": os.environ["PYOPENGL_PLATFORM"] = "egl" """ input audio and text, output motion do not calculate loss and metric save video """ results = self._common_test_inference( self.test_loader, epoch, mode="test_render", save_results=True ) def load_checkpoint(self, checkpoint): # checkpoint is already a dict, do NOT call torch.load again! try: ckpt_state_dict = checkpoint["model_state_dict"] except: ckpt_state_dict = checkpoint["model_state"] # remove 'audioEncoder' from the state_dict due to legacy issues ckpt_state_dict = { k: v for k, v in ckpt_state_dict.items() if "modality_encoder.audio_encoder." not in k } self.model.load_state_dict(ckpt_state_dict, strict=False) try: self.opt.load_state_dict(checkpoint["optimizer_state_dict"]) except: print("No optimizer loaded!") if ( "scheduler_state_dict" in checkpoint and checkpoint["scheduler_state_dict"] is not None ): self.opt_s.load_state_dict(checkpoint["scheduler_state_dict"]) if "val_best" in checkpoint: self.val_best = checkpoint["val_best"] logger.info("Checkpoint loaded successfully.")