Tharun156 commited on
Commit
fcb87c7
Β·
verified Β·
1 Parent(s): 6a8f745

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +16 -5
demo.py CHANGED
@@ -222,6 +222,8 @@ class BaseTrainer(object):
222
 
223
  return vq_models
224
 
 
 
225
  def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
226
  """Create a single RVQVAE model with specified configuration."""
227
  args = self.args
@@ -230,12 +232,21 @@ class BaseTrainer(object):
230
  args.down_t, args.stride_t, args.width, args.depth,
231
  args.dilation_growth_rate, args.vq_act, args.vq_norm
232
  )
233
-
234
- # Load pretrained weights
235
- checkpoint_path = getattr(args, f'vqvae_{body_part}_path')
236
- state = torch.load(checkpoint_path, map_location='cpu')
237
- model.load_state_dict(state['net'])
 
 
 
 
 
 
 
 
238
  return model
 
239
 
240
 
241
  def inverse_selection(self, filtered_t, selection_array, n):
 
222
 
223
  return vq_models
224
 
225
+
226
+
227
  def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
228
  """Create a single RVQVAE model with specified configuration."""
229
  args = self.args
 
232
  args.down_t, args.stride_t, args.width, args.depth,
233
  args.dilation_growth_rate, args.vq_act, args.vq_norm
234
  )
235
+
236
+ # Base directory = folder where demo.py lives
237
+ base_dir = Path(__file__).resolve().parent
238
+ checkpoint_path = base_dir / "ckpt" / f"net_300000_{body_part}.pth"
239
+
240
+ if not checkpoint_path.exists():
241
+ raise FileNotFoundError(
242
+ f"RVQVAE checkpoint for '{body_part}' not found at '{checkpoint_path}'.\n"
243
+ f"CWD is {Path.cwd()}."
244
+ )
245
+
246
+ state = torch.load(str(checkpoint_path), map_location="cpu")
247
+ model.load_state_dict(state["net"])
248
  return model
249
+
250
 
251
 
252
  def inverse_selection(self, filtered_t, selection_array, n):