Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| import sys | |
| import tarfile | |
| import zipfile | |
| from pathlib import Path | |
| from huggingface_hub import snapshot_download | |
# ---------------------------------------------------------------------------
# Optional: pull checkpoints and auxiliary assets at startup. Set the
# HF_GESTURELSM_WEIGHTS_REPO environment variable in the Space settings to the
# Hub repo that hosts the pre-trained weights (e.g. "username/gesturelsm-assets").
# The repo is treated as a *dataset* repo unless HF_GESTURELSM_WEIGHTS_REPO_TYPE
# is set (e.g. to "model"). Files are downloaded into asset_cache/ and then
# copied into the project root by _sync_downloaded_assets(), keyed on the first
# ckpt/, mean_std/, datasets/, weights/ or hf_assets/ directory in their path,
# so the existing config paths keep working.
# ---------------------------------------------------------------------------
BASE_DIR = Path(__file__).parent.resolve()
ROOT_DIR = BASE_DIR.parent.resolve()

# Ensure project root is on sys.path so intra-repo imports (e.g. `optimizers`) work.
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Quick sanity check in Space logs to confirm the repo root is visible at runtime.
print("[GestureLSM] sys.path:")
for entry in sys.path:
    print(" ", entry)

ASSET_CACHE_DIR = BASE_DIR / "asset_cache"
ASSET_CACHE_DIR.mkdir(parents=True, exist_ok=True)

weights_repo = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO", "").strip()
# Defaults to "dataset" (the previous hard-coded value) so existing Space
# configurations keep working; set it to "model" for weights in a model repo.
weights_repo_type = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO_TYPE", "").strip() or "dataset"

if weights_repo:
    downloaded_root = Path(
        snapshot_download(
            repo_id=weights_repo,
            repo_type=weights_repo_type,
            local_dir=ASSET_CACHE_DIR,
            local_dir_use_symlinks=False,
            # Only fetch asset payloads; everything else in the repo is ignored.
            allow_patterns=[
                "*.pth",
                "*.pt",
                "*.bin",
                "*.npz",
                "*.npy",
                "*.tar",
                "*.tar.gz",
                "*.zip",
            ],
        )
    )
else:
    # Nothing to download; still scan the cache in case assets were pre-seeded.
    downloaded_root = ASSET_CACHE_DIR
def _sync_downloaded_assets(download_root: Path, dest_root: "Path | None" = None) -> None:
    """Copy downloaded asset files from *download_root* into the project tree.

    A file is copied when its suffix is one of the asset extensions listed
    below and its path (relative to *download_root*) contains one of the
    anchor directories ``ckpt``, ``mean_std``, ``datasets``, ``weights`` or
    ``hf_assets``. The destination keeps the path from the first anchor
    onward, e.g. ``<download_root>/repo/ckpt/net.pth`` becomes
    ``<dest_root>/ckpt/net.pth``. Files without an anchor are skipped.
    Filesystem errors on individual files are logged and counted rather than
    aborting the sync, because this runs at application startup.

    Args:
        download_root: Directory tree to scan (typically the snapshot download).
        dest_root: Root to copy into; defaults to the module-level ``ROOT_DIR``.
    """
    if dest_root is None:
        dest_root = ROOT_DIR
    # ".gz" covers "*.tar.gz" because Path.suffix only returns the last extension.
    allowed_suffixes = {
        ".pth",
        ".pt",
        ".bin",
        ".npz",
        ".npy",
        ".tar",
        ".gz",
        ".zip",
    }
    anchor_dirs = {"ckpt", "mean_std", "datasets", "weights", "hf_assets"}
    copied = 0
    skipped = 0
    failed = 0
    for file_path in download_root.rglob("*"):
        if not file_path.is_file() or file_path.suffix.lower() not in allowed_suffixes:
            continue
        parts = file_path.relative_to(download_root).parts
        anchor_index = next((idx for idx, part in enumerate(parts) if part in anchor_dirs), None)
        if anchor_index is None:
            skipped += 1
            continue
        destination = dest_root.joinpath(*parts[anchor_index:])
        try:
            # mkdir is inside the try: a regular file blocking the directory
            # path must not abort the whole sync.
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(file_path, destination)
        except OSError as exc:  # includes shutil.Error / SameFileError
            failed += 1
            print(f"[GestureLSM] Failed to copy {file_path} -> {destination}: {exc}")
        else:
            copied += 1
    if copied:
        print(f"[GestureLSM] Synced {copied} asset files from {download_root} to repository root")
    if skipped:
        print(f"[GestureLSM] Skipped {skipped} files with no anchor directory match in {download_root}")
    if failed:
        print(f"[GestureLSM] Failed to copy {failed} asset files from {download_root}")
# Mirror any downloaded checkpoints/assets into the project tree first.
_sync_downloaded_assets(downloaded_root)

# Create the directories the demo reads from / writes to; safe to re-run.
datasets_hub_dir = ROOT_DIR / "datasets" / "hub"
_RUNTIME_DIRS = ("outputs/audio2pose", "datasets/BEAT_SMPL", "datasets/hub")
for runtime_dir in _RUNTIME_DIRS:
    Path(ROOT_DIR, runtime_dir).mkdir(exist_ok=True, parents=True)
# Populate datasets/hub/smplx_models/ from the asset cache. Sources, in order:
#   1. an already-extracted `smplx_models/` directory anywhere in the cache;
#   2. an `smplx_models.*` archive (zip or tar) anywhere in the cache.
smplx_dest = datasets_hub_dir / "smplx_models"
smplx_dest.mkdir(parents=True, exist_ok=True)


def _in_hidden_dir(path: Path) -> bool:
    """Return True if any component of *path* below ASSET_CACHE_DIR starts with '.'.

    huggingface_hub's local_dir downloads may keep per-file metadata under
    `<local_dir>/.cache/huggingface/`, mirroring the repo layout (so it can
    contain its own `smplx_models/` folder of *.metadata files). Those hidden
    paths must never be picked as model sources.
    """
    return any(part.startswith(".") for part in path.relative_to(ASSET_CACHE_DIR).parts)


if not any(smplx_dest.iterdir()):
    # Only real directories outside hidden folders; sorted for a deterministic pick.
    smplx_sources = sorted(
        candidate
        for candidate in ASSET_CACHE_DIR.glob("**/smplx_models")
        if candidate.is_dir() and not _in_hidden_dir(candidate)
    )
    if smplx_sources:
        smplx_source = smplx_sources[0]
        for item in smplx_source.iterdir():
            target = smplx_dest / item.name
            if item.is_dir():
                shutil.copytree(item, target, dirs_exist_ok=True)
            else:
                shutil.copy2(item, target)

if not any(smplx_dest.iterdir()):
    archives = sorted(
        candidate
        for candidate in ASSET_CACHE_DIR.glob("**/smplx_models.*")
        if candidate.is_file() and not _in_hidden_dir(candidate)
    )
    for archive in archives:
        try:
            if zipfile.is_zipfile(archive):
                with zipfile.ZipFile(archive) as zf:
                    zf.extractall(smplx_dest)
            elif tarfile.is_tarfile(archive):
                with tarfile.open(archive) as tf:
                    # The archive comes from a remote repo: use the "data"
                    # extraction filter (rejects absolute paths, "..", links
                    # escaping the destination, device files) when this Python
                    # provides it (3.12+ and security backports).
                    if hasattr(tarfile, "data_filter"):
                        tf.extractall(smplx_dest, filter="data")
                    else:
                        tf.extractall(smplx_dest)
        except Exception as exc:  # best-effort: log and try the next archive
            print(f"[GestureLSM] Failed to extract {archive}: {exc}")

if not any(smplx_dest.iterdir()):
    print("[GestureLSM] WARNING: smplx_models directory missing; ensure the weights repo contains it.")

# Flatten smplx_models/smplx/* into smplx_models/.
# NOTE(review): confirm the downstream model loader expects this flattened
# layout rather than the nested `smplx/` subfolder.
nested_smplx = smplx_dest / "smplx"
if nested_smplx.is_dir():
    for item in nested_smplx.iterdir():
        target = smplx_dest / item.name
        if item.is_dir():
            shutil.copytree(item, target, dirs_exist_ok=True)
            shutil.rmtree(item)
        else:
            shutil.move(str(item), str(target))
    try:
        nested_smplx.rmdir()
    except OSError:
        # Leave the folder in place if anything could not be moved out.
        pass

print("[GestureLSM] smplx_models contents:")
for item in smplx_dest.glob("*"):
    print(" -", item.name)
# Reuse the existing Gradio interface defined in demo.py.
from demo import demo as gesture_demo  # noqa: E402

if __name__ == "__main__":
    # Serve one request at a time. Gradio 4+ raises on
    # queue(concurrency_count=...) and expects default_concurrency_limit;
    # Gradio 3.x only knows concurrency_count (unknown keyword -> TypeError),
    # so fall back to it for older installs.
    try:
        queued_demo = gesture_demo.queue(default_concurrency_limit=1)
    except TypeError:
        queued_demo = gesture_demo.queue(concurrency_count=1)
    queued_demo.launch(server_name="0.0.0.0", share=False)