Tharun156 commited on
Commit
aed272e
Β·
verified Β·
1 Parent(s): f58a55b

Update dataloaders/beat_sep_single.py

Browse files
Files changed (1) hide show
  1. dataloaders/beat_sep_single.py +32 -22
dataloaders/beat_sep_single.py CHANGED
@@ -340,31 +340,41 @@ class CustomDataset(Dataset):
340
  if self.args.word_rep is not None:
341
  logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
342
  word_file = self.textgrid_file_path
343
- if not os.path.exists(word_file):
344
- logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
345
- self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
346
  word_save_path = f"{self.data_dir}{self.args.t_pre_encoder}/{id_pose}.npy"
347
 
348
- tgrid = tg.TextGrid.fromFile(word_file)
 
 
349
 
350
- for i in range(pose_each_file.shape[0]):
351
- found_flag = False
352
- current_time = i/self.args.pose_fps + time_offset
353
- j_last = 0
354
- for j, word in enumerate(tgrid[0]):
355
- word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
356
- if word_s<=current_time and current_time<=word_e:
357
- if word_n == " ":
358
- word_each_file.append(self.lang_model.PAD_token)
359
- else:
360
- word_each_file.append(self.lang_model.get_word_index(word_n))
361
- found_flag = True
362
- j_last = j
363
- break
364
- else: continue
365
- if not found_flag:
366
- word_each_file.append(self.lang_model.UNK_token)
367
- word_each_file = np.array(word_each_file)
 
 
 
 
 
 
 
 
 
 
 
368
 
369
 
370
 
 
340
  if self.args.word_rep is not None:
341
  logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
342
  word_file = self.textgrid_file_path
 
 
 
343
  word_save_path = f"{self.data_dir}{self.args.t_pre_encoder}/{id_pose}.npy"
344
 
345
+ def _fallback_word_tokens(length: int) -> np.ndarray:
346
+ token = getattr(self.lang_model, "PAD_token", 0) if self.lang_model else 0
347
+ return np.full((length,), token, dtype=np.int64)
348
 
349
+ if not os.path.exists(word_file):
350
+ logger.warning(
351
+ f"# ---- TextGrid not found for Word {id_pose}; using fallback tokens ---- #"
352
+ )
353
+ word_each_file = _fallback_word_tokens(pose_each_file.shape[0])
354
+ else:
355
+ try:
356
+ tgrid = tg.TextGrid.fromFile(word_file)
357
+ except Exception as exc:
358
+ logger.warning(
359
+ f"# ---- Failed to load TextGrid for Word {id_pose}: {exc}; using fallback tokens ---- #"
360
+ )
361
+ word_each_file = _fallback_word_tokens(pose_each_file.shape[0])
362
+ else:
363
+ for i in range(pose_each_file.shape[0]):
364
+ found_flag = False
365
+ current_time = i/self.args.pose_fps + time_offset
366
+ for word in tgrid[0]:
367
+ word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
368
+ if word_s <= current_time <= word_e:
369
+ if word_n == " ":
370
+ word_each_file.append(self.lang_model.PAD_token)
371
+ else:
372
+ word_each_file.append(self.lang_model.get_word_index(word_n))
373
+ found_flag = True
374
+ break
375
+ if not found_flag:
376
+ word_each_file.append(self.lang_model.UNK_token)
377
+ word_each_file = np.array(word_each_file)
378
 
379
 
380