import os
import shutil
import sys
import tarfile
import zipfile
from pathlib import Path

from huggingface_hub import snapshot_download

# ---------------------------------------------------------------------------
# Optional: pull checkpoints and auxiliary assets at startup. Set the
# HF_GESTURELSM_WEIGHTS_REPO environment variable in the Space settings to the
# Hugging Face repo that hosts the pre-trained weights
# (e.g. "username/gesturelsm-assets"). The repo is treated as a dataset repo
# unless HF_GESTURELSM_WEIGHTS_REPO_TYPE is set (e.g. to "model").
# Files are downloaded into asset_cache/ next to this script and then copied
# into the repository root starting at the first anchor directory in their
# path (ckpt/, mean_std/, datasets/, weights/, hf_assets/), so the existing
# config paths keep working.
# ---------------------------------------------------------------------------

BASE_DIR = Path(__file__).parent.resolve()
ROOT_DIR = BASE_DIR.parent.resolve()

# Ensure project root is on sys.path so intra-repo imports (e.g. `optimizers`) work.
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Quick sanity check in Space logs to confirm the repo root is visible at runtime.
print("[GestureLSM] sys.path:")
for entry in sys.path:
    print(" ", entry)

ASSET_CACHE_DIR = BASE_DIR / "asset_cache"
ASSET_CACHE_DIR.mkdir(parents=True, exist_ok=True)

weights_repo = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO", "").strip()
# Defaults to "dataset" (the historical behaviour); set to "model" when the
# weights live in a model repo instead.
weights_repo_type = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO_TYPE", "").strip() or "dataset"

if weights_repo:
    downloaded_root = Path(
        snapshot_download(
            repo_id=weights_repo,
            repo_type=weights_repo_type,
            local_dir=ASSET_CACHE_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=[
                "*.pth",
                "*.pt",
                "*.bin",
                "*.npz",
                "*.npy",
                "*.tar",
                "*.tar.gz",
                "*.zip",
            ],
        )
    )
else:
    downloaded_root = ASSET_CACHE_DIR


def _sync_downloaded_assets(download_root: Path) -> None:
    """Copy downloaded asset files from *download_root* into the repository root.

    Only files with an allowed asset suffix are considered. Each file is copied
    to ``ROOT_DIR`` starting at the first path component that is an anchor
    directory, e.g. ``<download_root>/foo/ckpt/x.pth`` -> ``ROOT_DIR/ckpt/x.pth``.
    Files whose path contains no anchor directory are skipped. Copy failures are
    logged and do not abort the sync.
    """
    # `.tar.gz` archives have `.gz` as their final suffix.
    allowed_suffixes = {
        ".pth",
        ".pt",
        ".bin",
        ".npz",
        ".npy",
        ".tar",
        ".gz",
        ".zip",
    }
    anchor_dirs = {"ckpt", "mean_std", "datasets", "weights", "hf_assets"}
    copied = 0
    skipped = 0

    for file_path in download_root.rglob("*"):
        if not file_path.is_file():
            continue
        if file_path.suffix.lower() not in allowed_suffixes:
            continue

        parts = file_path.relative_to(download_root).parts
        anchor_index = next((idx for idx, part in enumerate(parts) if part in anchor_dirs), None)
        if anchor_index is None:
            skipped += 1
            continue

        destination = ROOT_DIR / Path(*parts[anchor_index:])
        destination.parent.mkdir(parents=True, exist_ok=True)
        try:
            shutil.copy2(file_path, destination)
            copied += 1
        except Exception as exc:  # pragma: no cover - defensive
            print(f"[GestureLSM] Failed to copy {file_path} -> {destination}: {exc}")

    # Report both counts independently so skipped files are visible even when
    # some files were copied.
    if copied:
        print(f"[GestureLSM] Synced {copied} asset files from {download_root} to repository root")
    if skipped:
        print(f"[GestureLSM] Skipped {skipped} files with no anchor directory match in {download_root}")


def _is_hidden_below(path: Path, root: Path) -> bool:
    """Return True if any path component of *path* below *root* starts with a dot."""
    return any(part.startswith(".") for part in path.relative_to(root).parts)


def _copy_smplx_from_directory(dest: Path) -> None:
    """Copy the contents of the first ``smplx_models`` directory in the asset cache into *dest*."""
    # Only real, non-hidden directories: huggingface_hub keeps a `.cache/huggingface/`
    # metadata mirror inside `local_dir` that can contain a same-named folder.
    # Sorting makes the chosen source deterministic across runs.
    sources = sorted(
        candidate
        for candidate in ASSET_CACHE_DIR.glob("**/smplx_models")
        if candidate.is_dir() and not _is_hidden_below(candidate, ASSET_CACHE_DIR)
    )
    if not sources:
        return
    for item in sources[0].iterdir():
        target = dest / item.name
        if item.is_dir():
            shutil.copytree(item, target, dirs_exist_ok=True)
        else:
            shutil.copy2(item, target)


def _safe_tar_extractall(tf: tarfile.TarFile, dest: Path) -> None:
    """Extract *tf* into *dest*, refusing members that would land outside *dest*."""
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ (and security-patched 3.8-3.11): rejects absolute paths,
        # `..` components, links escaping dest, and special files.
        tf.extractall(dest, filter="data")
        return
    dest_resolved = dest.resolve()
    for member in tf.getmembers():
        member_path = (dest_resolved / member.name).resolve()
        if member_path != dest_resolved and dest_resolved not in member_path.parents:
            raise ValueError(f"Refusing to extract {member.name!r} outside {dest}")
    tf.extractall(dest)


def _extract_smplx_archives(dest: Path) -> None:
    """Extract every ``smplx_models.*`` zip/tar archive found in the asset cache into *dest*."""
    for archive in sorted(ASSET_CACHE_DIR.glob("**/smplx_models.*")):
        # Skip directories and hidden huggingface_hub metadata files.
        if not archive.is_file() or _is_hidden_below(archive, ASSET_CACHE_DIR):
            continue
        try:
            if zipfile.is_zipfile(archive):
                # ZipFile.extractall already strips absolute paths and `..`.
                with zipfile.ZipFile(archive) as zf:
                    zf.extractall(dest)
            elif tarfile.is_tarfile(archive):
                with tarfile.open(archive) as tf:
                    _safe_tar_extractall(tf, dest)
        except Exception as exc:
            print(f"[GestureLSM] Failed to extract {archive}: {exc}")


def _flatten_nested_smplx(dest: Path) -> None:
    """Move the contents of ``dest/smplx`` up into *dest* and drop the folder if it ends up empty."""
    nested_smplx = dest / "smplx"
    if not nested_smplx.is_dir():
        return
    # Materialize the listing first: entries are moved out while we iterate.
    for item in list(nested_smplx.iterdir()):
        target = dest / item.name
        if item.is_dir():
            shutil.copytree(item, target, dirs_exist_ok=True)
            shutil.rmtree(item)
        else:
            shutil.move(str(item), str(target))
    try:
        nested_smplx.rmdir()
    except OSError:
        # Leave the folder in place if anything could not be moved out of it.
        pass


def _launch(demo_app) -> None:
    """Launch the Gradio app with a single-worker queue across Gradio versions."""
    try:
        queued = demo_app.queue(concurrency_count=1)
    except (TypeError, DeprecationWarning):
        # Gradio 4+ rejects `concurrency_count` (raising DeprecationWarning, or
        # TypeError once the parameter is gone); `default_concurrency_limit`
        # replaces it.
        queued = demo_app.queue(default_concurrency_limit=1)
    queued.launch(server_name="0.0.0.0", share=False)


_sync_downloaded_assets(downloaded_root)

# Ensure expected runtime directories exist so the demo can write outputs.
datasets_hub_dir = ROOT_DIR / "datasets" / "hub"
for relative in ["outputs/audio2pose", "datasets/BEAT_SMPL", "datasets/hub"]:
    (ROOT_DIR / relative).mkdir(parents=True, exist_ok=True)

smplx_dest = datasets_hub_dir / "smplx_models"
smplx_dest.mkdir(parents=True, exist_ok=True)

# Populate smplx_models if the sync above left it empty: first from an
# smplx_models/ directory in the cache, then from smplx_models.* archives.
if not any(smplx_dest.iterdir()):
    _copy_smplx_from_directory(smplx_dest)
if not any(smplx_dest.iterdir()):
    _extract_smplx_archives(smplx_dest)
if not any(smplx_dest.iterdir()):
    print("[GestureLSM] WARNING: smplx_models directory missing; ensure the weights repo contains it.")

_flatten_nested_smplx(smplx_dest)

print("[GestureLSM] smplx_models contents:")
for item in smplx_dest.glob("*"):
    print(" -", item.name)

# Reuse the existing Gradio interface defined in demo.py.
from demo import demo as gesture_demo  # noqa: E402

if __name__ == "__main__":
    _launch(gesture_demo)