File size: 5,380 Bytes
f54f126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import shutil
import sys
import tarfile
import zipfile
from pathlib import Path

from huggingface_hub import snapshot_download

# ---------------------------------------------------------------------------
# Optional: pull checkpoints and auxiliary assets at startup.  Set the
# HF_GESTURELSM_WEIGHTS_REPO environment variable in the Space settings to the
# Hub repo that hosts the pre-trained weights (e.g. "username/gesturelsm-assets").
# The repo is treated as a dataset repo unless HF_GESTURELSM_WEIGHTS_REPO_TYPE
# names another repo type (e.g. "model").  Matching files are downloaded into
# asset_cache/ and afterwards copied into the repository root by
# _sync_downloaded_assets, so the existing config paths (ckpt/, mean_std/, ...)
# keep working.
# ---------------------------------------------------------------------------
BASE_DIR = Path(__file__).parent.resolve()
ROOT_DIR = BASE_DIR.parent.resolve()

# Ensure project root is on sys.path so intra-repo imports (e.g. `optimizers`) work.
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Quick sanity check in Space logs to confirm the repo root is visible at runtime.
print("[GestureLSM] sys.path:")
for entry in sys.path:
    print("  ", entry)

ASSET_CACHE_DIR = BASE_DIR / "asset_cache"
ASSET_CACHE_DIR.mkdir(parents=True, exist_ok=True)

weights_repo = os.environ.get("HF_GESTURELSM_WEIGHTS_REPO", "").strip()
# An unset or blank value keeps the historical "dataset" default, so existing
# Space configurations behave exactly as before.
weights_repo_type = (
    os.environ.get("HF_GESTURELSM_WEIGHTS_REPO_TYPE", "").strip().lower() or "dataset"
)
if weights_repo:
    downloaded_root = Path(
        snapshot_download(
            repo_id=weights_repo,
            repo_type=weights_repo_type,
            local_dir=ASSET_CACHE_DIR,
            # NOTE(review): newer huggingface_hub releases deprecate this flag
            # (local_dir downloads are plain files); confirm against the pinned
            # version before dropping it.
            local_dir_use_symlinks=False,
            allow_patterns=[
                "*.pth",
                "*.pt",
                "*.bin",
                "*.npz",
                "*.npy",
                "*.tar",
                "*.tar.gz",
                "*.zip",
            ],
        )
    )
else:
    # No remote repo configured: sync whatever already sits in the local cache.
    downloaded_root = ASSET_CACHE_DIR


def _sync_downloaded_assets(download_root: Path) -> None:
    """Copy all downloaded asset files into the repo root."""

    allowed_suffixes = {
        ".pth",
        ".pt",
        ".bin",
        ".npz",
        ".npy",
        ".tar",
        ".gz",
        ".zip",
    }

    anchor_dirs = {"ckpt", "mean_std", "datasets", "weights", "hf_assets"}

    copied = 0
    skipped = 0
    for file_path in download_root.rglob("*"):
        if not file_path.is_file():
            continue
        if file_path.suffix.lower() not in allowed_suffixes:
            continue
        relative = file_path.relative_to(download_root)
        parts = relative.parts
        anchor_index = next((idx for idx, part in enumerate(parts) if part in anchor_dirs), None)
        if anchor_index is None:
            skipped += 1
            continue
        destination = ROOT_DIR / Path(*parts[anchor_index:])
        destination.parent.mkdir(parents=True, exist_ok=True)
        try:
            shutil.copy2(file_path, destination)
            copied += 1
        except Exception as exc:  # pragma: no cover - defensive
            print(f"[GestureLSM] Failed to copy {file_path} -> {destination}: {exc}")

    if copied:
        print(f"[GestureLSM] Synced {copied} asset files from {download_root} to repository root")
    elif skipped:
        print(f"[GestureLSM] Skipped {skipped} files with no anchor directory match in {download_root}")



# Bring the downloaded assets into the repo tree before anything reads them.
_sync_downloaded_assets(downloaded_root)

# The demo writes into these directories, so create them up front.
datasets_hub_dir = ROOT_DIR.joinpath("datasets", "hub")
for relative in ("outputs/audio2pose", "datasets/BEAT_SMPL", "datasets/hub"):
    ROOT_DIR.joinpath(relative).mkdir(exist_ok=True, parents=True)

smplx_dest = datasets_hub_dir / "smplx_models"
smplx_dest.mkdir(parents=True, exist_ok=True)


def _is_in_hidden_dir(path: Path) -> bool:
    """Return True if *path* sits below a dot-directory inside ASSET_CACHE_DIR.

    NOTE(review): huggingface_hub may keep download bookkeeping under
    ``local_dir/.cache/`` that mirrors the real file layout; such mirrors must
    not be mistaken for model sources.
    """
    return any(part.startswith(".") for part in path.relative_to(ASSET_CACHE_DIR).parts[:-1])


# Fallback 1: copy an smplx_models directory found anywhere in the asset cache.
if not any(smplx_dest.iterdir()):
    smplx_sources = sorted(
        (
            candidate
            for candidate in ASSET_CACHE_DIR.glob("**/smplx_models")
            if candidate.is_dir() and not _is_in_hidden_dir(candidate)
        ),
        # Deterministic choice: prefer the shallowest match, then by path.
        key=lambda candidate: (len(candidate.parts), str(candidate)),
    )
    if smplx_sources:
        smplx_source = smplx_sources[0]
        for item in smplx_source.iterdir():
            target = smplx_dest / item.name
            if item.is_dir():
                shutil.copytree(item, target, dirs_exist_ok=True)
            else:
                shutil.copy2(item, target)

# Fallback 2: unpack an smplx_models.* archive from the asset cache.
if not any(smplx_dest.iterdir()):
    # The "data" filter (Python 3.12+, backported to recent patch releases)
    # rejects absolute paths, ".." traversal and special files in archives.
    _tar_extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    archives = sorted(
        candidate
        for candidate in ASSET_CACHE_DIR.glob("**/smplx_models.*")
        if candidate.is_file() and not _is_in_hidden_dir(candidate)
    )
    for archive in archives:
        try:
            if zipfile.is_zipfile(archive):
                with zipfile.ZipFile(archive) as zf:
                    zf.extractall(smplx_dest)
            elif tarfile.is_tarfile(archive):
                with tarfile.open(archive) as tf:
                    tf.extractall(smplx_dest, **_tar_extract_kwargs)
        except Exception as exc:  # best-effort: log and try the next archive
            print(f"[GestureLSM] Failed to extract {archive}: {exc}")

if not any(smplx_dest.iterdir()):
    print("[GestureLSM] WARNING: smplx_models directory missing; ensure the weights repo contains it.")

# If the models ended up one level too deep (smplx_models/smplx/...), hoist
# everything in that nested folder up into smplx_models/ itself.
# NOTE(review): presumably some archives or repos add this extra level; confirm.
nested_smplx = smplx_dest / "smplx"
if nested_smplx.is_dir():
    for child in nested_smplx.iterdir():
        destination = smplx_dest / child.name
        if not child.is_dir():
            shutil.move(str(child), str(destination))
            continue
        shutil.copytree(child, destination, dirs_exist_ok=True)
        shutil.rmtree(child)
    try:
        nested_smplx.rmdir()
    except OSError:
        # rmdir fails (e.g. directory not empty); leave it in place.
        pass

print("[GestureLSM] smplx_models contents:")
for entry_name in (entry.name for entry in smplx_dest.glob("*")):
    print("  -", entry_name)

# Reuse the existing Gradio interface defined in demo.py.
from demo import demo as gesture_demo  # noqa: E402


def _queue_single_worker(app):
    """Enable the Gradio queue with a single concurrent worker and return *app*.

    Gradio 4+ configures this via ``default_concurrency_limit`` and rejects the
    old ``concurrency_count`` argument, while Gradio 3.x only understands
    ``concurrency_count``. Try the modern keyword first and fall back.
    """
    try:
        return app.queue(default_concurrency_limit=1)
    except TypeError:
        # Gradio 3.x: unknown keyword, so use the legacy parameter instead.
        return app.queue(concurrency_count=1)


if __name__ == "__main__":
    _queue_single_worker(gesture_demo).launch(server_name="0.0.0.0", share=False)