GestureLSM / app.py
Tharun156's picture
Update app.py
f54f126 verified
import os
import shutil
import sys
import tarfile
import zipfile
from pathlib import Path
from huggingface_hub import snapshot_download
# ---------------------------------------------------------------------------
# Optional: pull checkpoints and auxiliary assets at startup. Set the
# HF_GESTURELSM_WEIGHTS_REPO environment variable in the Space settings to the
# Hugging Face repo that hosts the pre-trained weights (e.g. "username/gesturelsm-assets").
# Files are downloaded into asset_cache/ next to this script and then copied into
# the project root (the parent of this script's directory), keeping each file's
# path from its first ckpt/, mean_std/, datasets/, weights/ or hf_assets/
# component onward, so the existing config paths keep working.
# ---------------------------------------------------------------------------
# Directory holding this script, and its parent, which is treated as the project root.
BASE_DIR = Path(__file__).parent.resolve()
ROOT_DIR = BASE_DIR.parent.resolve()

# Put the project root first on the import path so intra-repo packages
# (e.g. `optimizers`) can be imported from here.
_root_str = str(ROOT_DIR)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)

# Dump the effective import path to the Space logs as a startup sanity check.
print("[GestureLSM] sys.path:")
for path_entry in sys.path:
    print(" ", path_entry)
# Staging directory for downloaded assets; _sync_downloaded_assets() below
# copies the relevant files from here into the project tree.
ASSET_CACHE_DIR = BASE_DIR / "asset_cache"
ASSET_CACHE_DIR.mkdir(parents=True, exist_ok=True)
weights_repo = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO", "").strip()
# Type of the weights repo as understood by snapshot_download ("dataset",
# "model", ...). Defaults to "dataset"; set HF_GESTURELSM_WEIGHTS_REPO_TYPE=model
# to pull the weights from a model repo instead.
weights_repo_type = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO_TYPE", "").strip() or "dataset"
if weights_repo:
    downloaded_root = Path(
        snapshot_download(
            repo_id=weights_repo,
            repo_type=weights_repo_type,
            local_dir=ASSET_CACHE_DIR,
            # Deprecated (and ignored) by recent huggingface_hub releases; kept
            # so older releases write real files rather than cache symlinks.
            local_dir_use_symlinks=False,
            allow_patterns=[
                "*.pth",
                "*.pt",
                "*.bin",
                "*.npz",
                "*.npy",
                "*.tar",
                "*.tar.gz",
                "*.zip",
            ],
        )
    )
else:
    # No repo configured: work with whatever already sits in the cache dir.
    downloaded_root = ASSET_CACHE_DIR
def _sync_downloaded_assets(download_root: Path, dest_root: "Path | None" = None) -> None:
    """Copy downloaded asset files into the repository tree.

    Walks *download_root* recursively. Every file whose lower-cased suffix is
    in ``allowed_suffixes`` and whose relative path contains one of the
    ``anchor_dirs`` names is copied to *dest_root*. The copy keeps the path
    from the first anchor component onward, e.g.
    ``<download_root>/bundle/ckpt/net.pth`` -> ``<dest_root>/ckpt/net.pth``.

    Files with an allowed suffix but no anchor component are counted as
    skipped. Copy failures are logged and counted but never abort the sync.
    One summary line is printed for each non-zero count.

    Args:
        download_root: Directory populated by the snapshot download.
        dest_root: Destination root. Defaults to the module-level ``ROOT_DIR``
            (looked up at call time).
    """
    dest_root = ROOT_DIR if dest_root is None else Path(dest_root)
    allowed_suffixes = {
        ".pth",
        ".pt",
        ".bin",
        ".npz",
        ".npy",
        ".tar",
        ".gz",
        ".zip",
    }
    anchor_dirs = {"ckpt", "mean_std", "datasets", "weights", "hf_assets"}
    copied = skipped = failed = 0
    # Sorted so the copy order (and hence which file wins when two downloads
    # map to the same destination) does not depend on directory listing order.
    for file_path in sorted(download_root.rglob("*")):
        if not file_path.is_file() or file_path.suffix.lower() not in allowed_suffixes:
            continue
        parts = file_path.relative_to(download_root).parts
        anchor_index = next((idx for idx, part in enumerate(parts) if part in anchor_dirs), None)
        if anchor_index is None:
            skipped += 1
            continue
        destination = dest_root.joinpath(*parts[anchor_index:])
        try:
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(file_path, destination)
        except OSError as exc:
            # Best effort: one bad file must not prevent the rest from syncing.
            failed += 1
            print(f"[GestureLSM] Failed to copy {file_path} -> {destination}: {exc}")
        else:
            copied += 1
    # Report every non-zero count, so skipped or failed files stay visible
    # even when other files were copied successfully.
    if copied:
        print(f"[GestureLSM] Synced {copied} asset files from {download_root} to repository root")
    if skipped:
        print(f"[GestureLSM] Skipped {skipped} files with no anchor directory match in {download_root}")
    if failed:
        print(f"[GestureLSM] Failed to copy {failed} asset files from {download_root}")
_sync_downloaded_assets(download_root=downloaded_root)

# Make sure the directories the demo expects at runtime exist (outputs are
# written under outputs/audio2pose).
datasets_hub_dir = ROOT_DIR / "datasets" / "hub"
runtime_dirs = ("outputs/audio2pose", "datasets/BEAT_SMPL", "datasets/hub")
for runtime_dir in map(ROOT_DIR.joinpath, runtime_dirs):
    runtime_dir.mkdir(parents=True, exist_ok=True)
# datasets/hub/smplx_models must be populated before the demo runs. If the
# asset sync did not already fill it, fall back to an unpacked "smplx_models"
# folder, then to a "smplx_models.*" archive, found in the asset cache.
smplx_dest = datasets_hub_dir / "smplx_models"
smplx_dest.mkdir(parents=True, exist_ok=True)


def _asset_cache_glob(pattern: str) -> "list[Path]":
    """Return matches of *pattern* under ASSET_CACHE_DIR, skipping ``.cache``.

    Recent huggingface_hub versions keep per-file download metadata under
    ``<local_dir>/.cache/huggingface/download/``, mirroring the repo layout.
    A bare glob can therefore return e.g. a ``smplx_models`` folder that
    contains only ``*.metadata`` files. Results are ordered shallowest-first,
    then by path, so the pick is deterministic.
    """
    matches = [
        path
        for path in ASSET_CACHE_DIR.glob(pattern)
        if ".cache" not in path.relative_to(ASSET_CACHE_DIR).parts
    ]
    return sorted(matches, key=lambda path: (len(path.parts), str(path)))


if not any(smplx_dest.iterdir()):
    # Only real directories qualify; iterdir() on a file named "smplx_models"
    # would raise.
    smplx_sources = [path for path in _asset_cache_glob("**/smplx_models") if path.is_dir()]
    if smplx_sources:
        smplx_source = smplx_sources[0]
        for item in smplx_source.iterdir():
            target = smplx_dest / item.name
            if item.is_dir():
                shutil.copytree(item, target, dirs_exist_ok=True)
            else:
                shutil.copy2(item, target)
if not any(smplx_dest.iterdir()):
    archives = [path for path in _asset_cache_glob("**/smplx_models.*") if path.is_file()]
    for archive in archives:
        try:
            if zipfile.is_zipfile(archive):
                with zipfile.ZipFile(archive) as zf:
                    # ZipFile.extractall already drops absolute paths and "..".
                    zf.extractall(smplx_dest)
            elif tarfile.is_tarfile(archive):
                with tarfile.open(archive) as tf:
                    # The "data" filter (Python 3.12+ and security backports)
                    # rejects members that would escape smplx_dest.
                    if hasattr(tarfile, "data_filter"):
                        tf.extractall(smplx_dest, filter="data")
                    else:
                        tf.extractall(smplx_dest)
        except Exception as exc:
            # Best effort: report and try the next candidate archive.
            print(f"[GestureLSM] Failed to extract {archive}: {exc}")
if not any(smplx_dest.iterdir()):
    print("[GestureLSM] WARNING: smplx_models directory missing; ensure the weights repo contains it.")
# If the models arrived with an extra "smplx/" level, hoist that level's
# contents up into smplx_models/ and then drop the emptied folder.
nested_smplx = smplx_dest / "smplx"
if nested_smplx.is_dir():
    for child in nested_smplx.iterdir():
        hoisted = smplx_dest / child.name
        if not child.is_dir():
            shutil.move(str(child), str(hoisted))
            continue
        shutil.copytree(child, hoisted, dirs_exist_ok=True)
        shutil.rmtree(child)
    try:
        nested_smplx.rmdir()
    except OSError:
        # Not empty (something could not be moved out) -- leave it in place.
        pass

# Log what ended up in smplx_models/ for debugging from the Space logs.
print("[GestureLSM] smplx_models contents:")
for listed_name in (entry.name for entry in smplx_dest.glob("*")):
    print(" -", listed_name)
# Reuse the existing Gradio interface defined in demo.py.
from demo import demo as gesture_demo  # noqa: E402

if __name__ == "__main__":
    import inspect

    # Gradio 4 replaced queue(concurrency_count=...) with
    # queue(default_concurrency_limit=...), and passing the old keyword to
    # Gradio 4+ raises. Use whichever keyword the installed version accepts,
    # keeping one concurrent job either way.
    if "default_concurrency_limit" in inspect.signature(gesture_demo.queue).parameters:
        gesture_demo.queue(default_concurrency_limit=1)
    else:
        gesture_demo.queue(concurrency_count=1)
    gesture_demo.launch(server_name="0.0.0.0", share=False)