import os
import signal
import time
import csv
import sys
import warnings
import random
import platform
import subprocess
import shutil
import pprint
from pathlib import Path
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
from loguru import logger
import smplx
import matplotlib.pyplot as plt
import soundfile as sf
import librosa
from transformers import pipeline
from utils import config, logger_tools, other_tools_hf, metric, data_transfer, other_tools
from utils.joints import upper_body_mask, hands_body_mask, lower_body_mask
from dataloaders import data_tools
from dataloaders.build_vocab import Vocab
from dataloaders.data_tools import joints_list
from utils import rotation_conversions as rc
from models.vq.model import RVQVAE

device = "cuda:0" if torch.cuda.is_available() else "cpu"

if platform.system() == "Linux":
    os.environ['PYOPENGL_PLATFORM'] = 'egl'

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    chunk_length_s=30,
    device=device,
)
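# Note: with `chunk_length_s=30`, the transformers ASR pipeline transcribes
# long-form audio in 30-second chunks and stitches the results, so inputs
# longer than Whisper's 30 s window are handled transparently. A minimal usage
# sketch (the file name is illustrative; Whisper expects 16 kHz input):
#
#   audio, sr = librosa.load("speech.wav", sr=16000)
#   text = pipe(audio, batch_size=8)["text"]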
debug = False


class BaseTrainer(object):
    def __init__(self, args, cfg, ap):
        hf_dir = "hf"
        time_local = time.localtime()
        time_name_expend = "%02d%02d_%02d%02d%02d_" % (time_local[1], time_local[2], time_local[3], time_local[4], time_local[5])
        self.time_name_expend = time_name_expend
        tmp_dir = args.out_path + "custom/" + time_name_expend + hf_dir
        if not os.path.exists(tmp_dir + "/"):
            os.makedirs(tmp_dir + "/")
        self.audio_path = tmp_dir + "/tmp.wav"
        sf.write(self.audio_path, ap[1], ap[0])
        audio, ssr = librosa.load(self.audio_path, sr=args.audio_sr)
        # Use the ASR model to get the corresponding text transcript.
        file_path = tmp_dir + "/tmp.lab"
        self.textgrid_path = tmp_dir + "/tmp.TextGrid"
        if not debug:
            text = pipe(audio, batch_size=8)["text"]
            with open(file_path, "w", encoding="utf-8") as file:
                file.write(text)
            # Use the Montreal Forced Aligner to get a TextGrid.
            mfa_override = os.environ.get("MFA_BINARY")
            mfa_path = mfa_override or shutil.which("mfa")
            if not mfa_path:
                raise FileNotFoundError(
                    "Montreal Forced Aligner binary not found. Install it or set MFA_BINARY."
                )
            env = os.environ.copy()
            command = [mfa_path, "align", tmp_dir, "english_us_arpa", "english_us_arpa", tmp_dir]
            result = subprocess.run(command, capture_output=True, text=True, env=env)
            print(f"MFA result: {result}")
            if result.returncode != 0:
                print(f"MFA stderr: {result.stderr}")
        self.args = args
        self.rank = 0  # dist.get_rank()
        args.textgrid_file_path = self.textgrid_path
        args.audio_file_path = self.audio_path
        self.checkpoint_path = tmp_dir
        args.tmp_dir = tmp_dir
        if self.rank == 0:
            self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test")
            self.test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=1,
                shuffle=False,
                num_workers=args.loader_workers,
                drop_last=False,
            )
        logger.info("Init test dataloader success")
        model_module = __import__(f"models.{cfg.model.model_name}", fromlist=["something"])
        self.model = getattr(model_module, cfg.model.g_name)(cfg)
        if self.rank == 0:
            logger.info(self.model)
            logger.info(f"init {cfg.model.g_name} success")
        smplx_path = Path(self.args.data_path_1) / "smplx_models"
        if not smplx_path.exists():
            raise FileNotFoundError(
                "SMPL-X model directory missing at {}. Ensure assets are downloaded or"
                " set HF_GESTURELSM_WEIGHTS_REPO with smplx_models.".format(smplx_path)
            )
        self.smplx = smplx.SMPLX(
            model_path=str(smplx_path),
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False,
        ).eval()
        self.joints = 55
        self.ori_joint_list = joints_list[self.args.ori_joints]
        self.tar_joint_list_face = joints_list["beat_smplx_face"]
        self.tar_joint_list_upper = joints_list["beat_smplx_upper"]
        self.tar_joint_list_hands = joints_list["beat_smplx_hands"]
        self.tar_joint_list_lower = joints_list["beat_smplx_lower"]
        self.joint_mask_face = np.zeros(len(self.ori_joint_list) * 3)
        for joint_name in self.tar_joint_list_face:
            self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        self.joint_mask_upper = np.zeros(len(self.ori_joint_list) * 3)
        for joint_name in self.tar_joint_list_upper:
            self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        self.joint_mask_hands = np.zeros(len(self.ori_joint_list) * 3)
        for joint_name in self.tar_joint_list_hands:
            self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        self.joint_mask_lower = np.zeros(len(self.ori_joint_list) * 3)
        for joint_name in self.tar_joint_list_lower:
            self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
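        # Mask semantics (as the indexing above implies): `joints_list` maps each
        # joint name to (channel_count, end_index) over the flattened 165-channel
        # pose vector (55 joints x 3 axis-angle values), so each loop marks the
        # slice [end - count, end) with 1. For example, a 3-channel joint ending
        # at index 69 marks channels 66:69, which is the jaw slice used later in
        # `_g_test`.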
        self.tracker = other_tools.EpochTracker(
            ["fid", "l1div", "bc", "rec", "trans", "vel", "transv", "dis", "gen", "acc",
             "transa", "exp", "lvd", "mse", "cls", "rec_face", "latent", "cls_full",
             "cls_self", "cls_word", "latent_word", "latent_self", "predict_x0_loss"],
            [False, True, True, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False],
        )

        ##### VQ-VAE models #####
        # Initialize and load VQ-VAE models for the face and body parts.
        vq_model_module = __import__("models.motion_representation", fromlist=["something"])
        self.vq_model_face = self._create_face_vq_model(vq_model_module)
        # Body part VQ models
        self.vq_models = self._create_body_vq_models()
        # Set all VQ models to eval mode
        self.vq_model_face.eval()
        for model in self.vq_models.values():
            model.eval()
        self.vq_model_upper, self.vq_model_hands, self.vq_model_lower = self.vq_models.values()
        self.vqvae_latent_scale = self.args.vqvae_latent_scale
        self.args.vae_length = 240

        ##### Loss functions #####
        self.reclatent_loss = nn.MSELoss()
        self.vel_loss = torch.nn.L1Loss(reduction='mean')

        ##### Normalization #####
        self.use_trans = self.args.use_trans
        self.mean = np.load(args.mean_pose_path)
        self.std = np.load(args.std_pose_path)
        # Extract body-part-specific normalization statistics; an explicit mapping
        # replaces the fragile globals() lookup.
        part_masks = {'upper': upper_body_mask, 'hands': hands_body_mask, 'lower': lower_body_mask}
        for part, mask in part_masks.items():
            setattr(self, f'mean_{part}', torch.from_numpy(self.mean[mask]))
            setattr(self, f'std_{part}', torch.from_numpy(self.std[mask]))
        # Translation normalization if needed
        if self.args.use_trans:
            self.trans_mean = torch.from_numpy(np.load(self.args.mean_trans_path))
            self.trans_std = torch.from_numpy(np.load(self.args.std_trans_path))
    def _create_face_vq_model(self, module):
        """Create and initialize the face VQ model."""
        self.args.vae_layer = 2
        self.args.vae_length = 256
        self.args.vae_test_dim = 106
        model = getattr(module, "VQVAEConvZero")(self.args)
        other_tools.load_checkpoints(model, "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin",
                                     self.args.e_name)
        return model

    def _create_body_vq_models(self):
        """Create VQ-VAE models for the body parts."""
        vq_configs = {
            'upper': {'dim_pose': 78},
            'hands': {'dim_pose': 180},
            'lower': {'dim_pose': 54 if not self.args.use_trans else 57},
        }
        # The loop variable is named `part_cfg` to avoid shadowing the `config`
        # module imported from `utils`.
        vq_models = {}
        for part, part_cfg in vq_configs.items():
            vq_models[part] = self._create_rvqvae_model(part_cfg['dim_pose'], part)
        return vq_models
    def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
        """Create a single RVQVAE model with the specified configuration."""
        args = self.args
        model = RVQVAE(
            args, dim_pose, args.nb_code, args.code_dim, args.code_dim,
            args.down_t, args.stride_t, args.width, args.depth,
            args.dilation_growth_rate, args.vq_act, args.vq_norm
        )
        # Base directory = folder where demo.py lives
        base_dir = Path(__file__).resolve().parent
        checkpoint_path = base_dir / "ckpt" / f"net_300000_{body_part}.pth"
        if not checkpoint_path.exists():
            raise FileNotFoundError(
                f"RVQVAE checkpoint for '{body_part}' not found at '{checkpoint_path}'.\n"
                f"CWD is {Path.cwd()}."
            )
        state = torch.load(str(checkpoint_path), map_location="cpu")
        model.load_state_dict(state["net"])
        return model

    def inverse_selection(self, filtered_t, selection_array, n):
        original_shape_t = np.zeros((n, selection_array.size))
        selected_indices = np.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def inverse_selection_tensor(self, filtered_t, selection_array, n):
        selection_array = torch.from_numpy(selection_array)
        original_shape_t = torch.zeros((n, 165))
        selected_indices = torch.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t
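    # The two inverse_selection helpers scatter per-part channels back into the
    # full 165-channel pose vector. A toy sketch of the same idea (values are
    # illustrative, not taken from the real masks):
    #
    #   selection_array = np.array([0, 1, 1, 0])   # 2 of 4 channels selected
    #   filtered_t      = np.array([[7.0, 8.0]])   # n=1 frame of selected channels
    #   # result: [[0.0, 7.0, 8.0, 0.0]] -- unselected channels stay zero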
    def _load_data(self, dict_data):
        tar_pose_raw = dict_data["pose"]
        tar_pose = tar_pose_raw[:, :, :165]
        tar_contact = tar_pose_raw[:, :, 165:169]
        tar_trans = dict_data["trans"]
        tar_trans_v = dict_data["trans_v"]
        tar_exps = dict_data["facial"]
        in_audio = dict_data["audio"]
        audio_onset = dict_data.get("audio_onset")
        if audio_onset is None:
            audio_onset = in_audio
        wavlm = dict_data.get("wavlm")
        in_word = dict_data["word"]
        tar_beta = dict_data["beta"]
        tar_id = dict_data["id"].long()
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
        tar_pose_hands = tar_pose[:, :, 25*3:55*3]
        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
        tar_pose_lower = tar_pose_leg
        if self.args.pose_norm:
            tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
            tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
            tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower
        if self.use_trans:
            tar_trans_v = (tar_trans_v - self.trans_mean) / self.trans_std
            tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1)
        latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
        latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
        latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
        latent_lengths = [latent_upper_top.shape[1], latent_hands_top.shape[1], latent_lower_top.shape[1]]
        if len(set(latent_lengths)) != 1:
            min_len = min(latent_lengths)
            # loguru interpolates {}-style placeholders, not printf-style %d, so
            # format the message explicitly.
            logger.warning(
                f"Latent length mismatch detected (upper={latent_lengths[0]}, "
                f"hands={latent_lengths[1]}, lower={latent_lengths[2]}); truncating to {min_len}"
            )
            latent_upper_top = latent_upper_top[:, :min_len, :]
            latent_hands_top = latent_hands_top[:, :min_len, :]
            latent_lower_top = latent_lower_top[:, :min_len, :]
        latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) / self.args.vqvae_latent_scale
        style_feature = None
        return {
            "in_audio": in_audio,
            "wavlm": wavlm,
            "in_word": in_word,
            "tar_trans": tar_trans,
            "tar_exps": tar_exps,
            "tar_beta": tar_beta,
            "tar_pose": tar_pose,
            "latent_in": latent_in,
            "audio_onset": audio_onset,
            "tar_id": tar_id,
            "tar_contact": tar_contact,
            "style_feature": style_feature,
        }
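    # `latent_in` stacks the upper/hands/lower latents along the channel axis, and
    # dividing by `vqvae_latent_scale` brings them to roughly unit scale for the
    # generator; the same factor is multiplied back in after sampling in `_g_test`.
    # With a per-part code_dim of 128 (the value the slicing in `_g_test` assumes),
    # the stacked latent has 3 * 128 = 384 channels.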
    def _g_test(self, loaded_data):
        mode = 'test'
        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
        tar_pose = loaded_data["tar_pose"]
        tar_beta = loaded_data["tar_beta"]
        tar_exps = loaded_data["tar_exps"]
        tar_contact = loaded_data["tar_contact"]
        tar_trans = loaded_data["tar_trans"]
        in_word = loaded_data["in_word"]
        in_audio = loaded_data["in_audio"]
        audio_onset = loaded_data.get("audio_onset")
        in_x0 = loaded_data['latent_in']
        in_seed = loaded_data['latent_in']
        # Trim the sequence to a multiple of 8 frames.
        remain = n % 8
        if remain != 0:
            tar_pose = tar_pose[:, :-remain, :]
            tar_beta = tar_beta[:, :-remain, :]
            tar_trans = tar_trans[:, :-remain, :]
            in_word = in_word[:, :-remain]
            tar_exps = tar_exps[:, :-remain, :]
            tar_contact = tar_contact[:, :-remain, :]
            # Trim both latent streams by the same amount; computing the new
            # length once avoids truncating `in_seed` twice.
            latent_keep = in_x0.shape[1] - (remain // self.args.vqvae_squeeze_scale)
            in_x0 = in_x0[:, :latent_keep, :]
            in_seed = in_seed[:, :latent_keep, :]
            n = n - remain
        tar_pose_jaw = tar_pose[:, :, 66:69]
        tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
        tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
        tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
        tar_pose_hands = tar_pose[:, :, 25*3:55*3]
        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
        tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
        tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
        tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
        latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
        rec_all_face = []
        rec_all_upper = []
        rec_all_lower = []
        rec_all_hands = []
        vqvae_squeeze_scale = self.args.vqvae_squeeze_scale
        roundt = (n - self.args.pre_frames * vqvae_squeeze_scale) // (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
        remain = (n - self.args.pre_frames * vqvae_squeeze_scale) % (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
        round_l = self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale
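        # Worked example of the windowing arithmetic (numbers are illustrative,
        # assuming pose_length=128, pre_frames=4, vqvae_squeeze_scale=4): each
        # window covers pose_length frames, of which the first
        # pre_frames * squeeze_scale = 16 frames overlap the previous window as
        # seed context, so the stride is round_l = 128 - 16 = 112 frames. `roundt`
        # is the number of full windows; `remain` counts the leftover frames that
        # are dropped at the end.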
        for i in range(0, roundt):
            in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames * vqvae_squeeze_scale]
            # Audio is sampled at 16 kHz and motion at 30 fps, so 16000//30
            # audio samples correspond to one motion frame.
            in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale]
            if audio_onset is not None:
                in_audio_onset_tmp = audio_onset[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale]
            else:
                in_audio_onset_tmp = in_audio_tmp
            in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames]
            in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
            in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
            if i == 0:
                in_seed_tmp = in_seed_tmp[:, :self.args.pre_frames, :]
            else:
                in_seed_tmp = last_sample[:, -self.args.pre_frames:, :]
            cond_ = {'y': {}}
            cond_['y']['audio'] = in_audio_tmp
            cond_['y']['audio_onset'] = in_audio_onset_tmp
            cond_['y']['word'] = in_word_tmp
            cond_['y']['id'] = in_id_tmp
            cond_['y']['seed'] = in_seed_tmp
            # Use the actual batch size `bs` (1 at test time) rather than
            # args.batch_size, which is the training batch size.
            cond_['y']['mask'] = (torch.zeros([bs, 1, 1, self.args.pose_length]) < 1)
            cond_['y']['style_feature'] = torch.zeros([bs, 512])
            sample = self.model(cond_)['latents']
            sample = sample.squeeze().permute(1, 0).unsqueeze(0)
            last_sample = sample.clone()
            rec_latent_upper = sample[..., :128]
            rec_latent_hands = sample[..., 128:2*128]
            rec_latent_lower = sample[..., 2*128:]
            if i == 0:
                rec_all_upper.append(rec_latent_upper)
                rec_all_hands.append(rec_latent_hands)
                rec_all_lower.append(rec_latent_lower)
            else:
                rec_all_upper.append(rec_latent_upper[:, self.args.pre_frames:])
                rec_all_hands.append(rec_latent_hands[:, self.args.pre_frames:])
                rec_all_lower.append(rec_latent_lower[:, self.args.pre_frames:])
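        # Stitching: the first window contributes all of its latent frames; every
        # later window drops its first `pre_frames` latents, since those merely
        # repeat the seed overlap already emitted by the previous window.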
        try:
            rec_all_upper = torch.cat(rec_all_upper, dim=1) * self.vqvae_latent_scale
            rec_all_hands = torch.cat(rec_all_hands, dim=1) * self.vqvae_latent_scale
            rec_all_lower = torch.cat(rec_all_lower, dim=1) * self.vqvae_latent_scale
        except RuntimeError as exc:
            shape_summary = {
                "upper": [tuple(t.shape) for t in rec_all_upper],
                "hands": [tuple(t.shape) for t in rec_all_hands],
                "lower": [tuple(t.shape) for t in rec_all_lower],
            }
            logger.error(f"Failed to concatenate latent segments: {exc} | shapes={shape_summary}")
            raise
        rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
        rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
        rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]
        if self.use_trans:
            # The last 3 channels of the lower-body stream are the root
            # translation velocity; de-normalize it, then integrate over time.
            rec_trans_v = rec_lower[..., -3:]
            rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean
            rec_trans = torch.cumsum(rec_trans_v, dim=-2)
            rec_trans[..., 1] = rec_trans_v[..., 1]
            rec_lower = rec_lower[..., :-3]
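        # Velocity integration in one line: position[t] = sum_{k<=t} velocity[k]
        # along the time axis (dim=-2), i.e. a discrete integral with unit frame
        # step. Overwriting the y channel with the raw prediction keeps the
        # character's height absolute rather than letting integration drift
        # accumulate vertically.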
        if self.args.pose_norm:
            rec_upper = rec_upper * self.std_upper + self.mean_upper
            rec_hands = rec_hands * self.std_hands + self.mean_hands
            rec_lower = rec_lower * self.std_lower + self.mean_lower
        n = n - remain
        tar_pose = tar_pose[:, :n, :]
        tar_exps = tar_exps[:, :n, :]
        tar_trans = tar_trans[:, :n, :]
        tar_beta = tar_beta[:, :n, :]
        if not self.use_trans:
            # Guard so `rec_trans` is always defined; without modeled translation,
            # reuse the target translation.
            rec_trans = tar_trans
        # The face is not generated; reuse the target expressions.
        rec_exps = tar_exps
        # rec_pose_jaw = rec_face[:, :, :6]
        rec_pose_legs = rec_lower[:, :, :54]
        bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
        rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
        rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)
        rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
        rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n)
        rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
        rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
        rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
        rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
        rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n)
        rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
        rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
        rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
        rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n)
        # The three recovered tensors populate disjoint channels, so summing
        # them merges the body parts into one full-body pose.
        rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
        # Copy the jaw rotation (channels 66:69) from the target pose.
        rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
        rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
        return {
            'rec_pose': rec_pose,
            'rec_trans': rec_trans,
            'tar_pose': tar_pose,
            'tar_exps': tar_exps,
            'tar_beta': tar_beta,
            'tar_trans': tar_trans,
            'rec_exps': rec_exps,
        }
    def test_demo(self, epoch):
        """
        Input audio and text, output motion.
        Does not calculate losses or metrics; saves the rendered video.
        """
        print("=== Starting test_demo ===")
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            shutil.rmtree(results_save_path)
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        print("Setting models to eval mode...")
        self.model.eval()
        self.smplx.eval()
        # self.eval_copy.eval()
        print("Starting inference loop...")
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
                print(f"Processing batch {its}...")
                print("Loading data...")
                loaded_data = self._load_data(batch_data)
                print("Running model inference (this may take several minutes on CPU)...")
                net_out = self._g_test(loaded_data)
                print("Model inference complete!")
                tar_pose = net_out['tar_pose']
                rec_pose = net_out['rec_pose']
                tar_exps = net_out['tar_exps']
                tar_beta = net_out['tar_beta']
                rec_trans = net_out['rec_trans']
                tar_trans = net_out['tar_trans']
                rec_exps = net_out['rec_exps']
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
                if (30 / self.args.pose_fps) != 1:
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0, 2, 1)
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
                rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
                tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
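                # The 6D -> matrix -> 6D round trip after interpolation is not a
                # no-op: linearly interpolated 6D vectors are generally not valid
                # rotations, and the conversion through rotation matrices
                # re-orthonormalizes them before the final axis-angle conversion.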
                tar_pose_np = tar_pose.detach().cpu().numpy()
                rec_pose_np = rec_pose.detach().cpu().numpy()
                rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3)
                rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100)
                tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100)
                tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3)
                gt_npz = np.load("./demo/examples/2_scott_0_1_1.npz", allow_pickle=True)
                print("Saving results to npz file...")
                results_npz_file_save_path = results_save_path + f"result_{self.time_name_expend}" + '.npz'
                np.savez(results_npz_file_save_path,
                         betas=gt_npz["betas"],
                         poses=rec_pose_np,
                         expressions=rec_exp_np,
                         trans=rec_trans_np,
                         model='smplx2020',
                         gender='neutral',
                         mocap_frame_rate=30,
                         )
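                # The saved npz follows the SMPL-X motion layout used by the
                # example clips in this repo: per-frame `poses` (55 joints x 3
                # axis-angle values), `expressions` (100 coefficients), root
                # `trans`, plus body shape `betas` borrowed from the example clip,
                # so the file can be re-rendered or imported into Blender with an
                # SMPL-X add-on.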
                total_length += n
                print("Rendering video (this may take 1-2 minutes)...")
                render_vid_path = other_tools_hf.render_one_sequence_no_gt(
                    results_npz_file_save_path,
                    # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz',
                    results_save_path,
                    self.audio_path,
                    self.args.data_path_1 + "smplx_models/",
                    use_matplotlib=False,
                    args=self.args,
                )
                print(f"Video rendered successfully: {render_vid_path}")
                result = (
                    render_vid_path,
                    results_npz_file_save_path,
                )
        end_time = time.time() - start_time
        print(f"=== Complete! Total time: {int(end_time)} seconds ===")
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
        return result
def gesturelsm(audio_path, sample_strategy=None):
    print("\n" + "="*60)
    print("STARTING GESTURE GENERATION")
    print("="*60)
    # Set the config path for the demo; config.parse_args() reads sys.argv,
    # so override it before parsing.
    sys.argv = ['demo.py', '--config', 'configs/shortcut_rvqvae_128_hf.yaml']
    args, cfg = config.parse_args()
    print(f"Sample strategy: {sample_strategy}")
    # os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/"
    if not sys.warnoptions:
        warnings.simplefilter("ignore")
    # dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    # logger_tools.set_args_and_logger(args, rank)
    other_tools_hf.set_random_seed(args)
    other_tools_hf.print_exp_info(args)
    # Return one instance of the trainer.
    try:
        print("Creating trainer instance...")
        trainer = BaseTrainer(args, cfg, ap=audio_path)
        print("Loading model checkpoint...")
        other_tools.load_checkpoints(trainer.model, args.test_ckpt, args.g_name)
        print("Checkpoint loaded successfully!")
        result = trainer.test_demo(999)
        if isinstance(result, tuple) and len(result) == 2:
            return result
        # If a single path or None is returned, expand to two outputs.
        return (result, None)
    except Exception:
        logger.exception("GestureLSM demo inference failed")
        # Return two Nones to satisfy the Gradio output schema.
        return (None, None)
examples = [
    ["demo/examples/2_scott_0_1_1.wav"],
    ["demo/examples/2_scott_0_2_2.wav"],
    ["demo/examples/2_scott_0_3_3.wav"],
    ["demo/examples/2_scott_0_4_4.wav"],
    ["demo/examples/2_scott_0_5_5.wav"],
]
demo = gr.Interface(
    gesturelsm,  # function
    inputs=[
        gr.Audio(),
    ],  # input type
    outputs=[
        gr.Video(format="mp4", visible=True),
        gr.File(label="download motion and visualize in blender"),
    ],
    title='GestureLSM: Latent Shortcut based Co-Speech Gesture Generation with Spatial-Temporal Modeling',
    description="1. Upload your audio. <br/>\
    2. Then, sit back and wait for the rendering to happen! This may take a while (e.g. 1-4 minutes). <br/>\
    3. Afterwards, you can view the video. <br/>\
    4. Note that we use a fixed face animation; our method only produces body motion. <br/>\
    5. Using the DDPM sampling strategy generates better results, but it takes more inference time. \
    ",
    article="Project links: [GestureLSM](https://github.com/andypinxinliu/GestureLSM). <br/>\
    Reference links: [EMAGE](https://pantomatrix.github.io/EMAGE/). ",
    examples=examples,
)
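# The two entries in `outputs` must line up with the 2-tuple returned by
# `gesturelsm`: the rendered mp4 path feeds gr.Video and the result npz path
# feeds gr.File (hence the (None, None) fallback in the error handler above).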
if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = '127.0.0.3'
    os.environ["MASTER_PORT"] = '8678'
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    demo.launch(server_name="0.0.0.0", share=True)