Spaces:
Runtime error
Runtime error
Update demo.py
Browse files
demo.py
CHANGED
|
@@ -222,6 +222,8 @@ class BaseTrainer(object):
|
|
| 222 |
|
| 223 |
return vq_models
|
| 224 |
|
|
|
|
|
|
|
| 225 |
def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
|
| 226 |
"""Create a single RVQVAE model with specified configuration."""
|
| 227 |
args = self.args
|
|
@@ -230,12 +232,21 @@ class BaseTrainer(object):
|
|
| 230 |
args.down_t, args.stride_t, args.width, args.depth,
|
| 231 |
args.dilation_growth_rate, args.vq_act, args.vq_norm
|
| 232 |
)
|
| 233 |
-
|
| 234 |
-
#
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
return model
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
def inverse_selection(self, filtered_t, selection_array, n):
|
|
|
|
| 222 |
|
| 223 |
return vq_models
|
| 224 |
|
| 225 |
+
|
| 226 |
+
|
| 227 |
def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
|
| 228 |
"""Create a single RVQVAE model with specified configuration."""
|
| 229 |
args = self.args
|
|
|
|
| 232 |
args.down_t, args.stride_t, args.width, args.depth,
|
| 233 |
args.dilation_growth_rate, args.vq_act, args.vq_norm
|
| 234 |
)
|
| 235 |
+
|
| 236 |
+
# Base directory = folder where demo.py lives
|
| 237 |
+
base_dir = Path(__file__).resolve().parent
|
| 238 |
+
checkpoint_path = base_dir / "ckpt" / f"net_300000_{body_part}.pth"
|
| 239 |
+
|
| 240 |
+
if not checkpoint_path.exists():
|
| 241 |
+
raise FileNotFoundError(
|
| 242 |
+
f"RVQVAE checkpoint for '{body_part}' not found at '{checkpoint_path}'.\n"
|
| 243 |
+
f"CWD is {Path.cwd()}."
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
state = torch.load(str(checkpoint_path), map_location="cpu")
|
| 247 |
+
model.load_state_dict(state["net"])
|
| 248 |
return model
|
| 249 |
+
|
| 250 |
|
| 251 |
|
| 252 |
def inverse_selection(self, filtered_t, selection_array, n):
|