AbstractPhil committed on
Commit
896c974
·
verified ·
1 Parent(s): 6534798

trainer to finish the next 10 epochs barring major errors

Browse files
Files changed (1) hide show
  1. trainer_v2.py +32 -23
trainer_v2.py CHANGED
@@ -51,7 +51,7 @@ class BaseConfig:
51
  pooling: str = "mean"
52
 
53
  # Flow training
54
- epochs: int = 10
55
  lr: float = 1e-4
56
  weight_decay: float = 1e-3
57
  grad_clip: float = 1.0
@@ -964,35 +964,44 @@ class FlowMatchDavidTrainer:
964
  steps = steps or self.cfg.sample_steps
965
  guidance = guidance if guidance is not None else self.cfg.guidance_scale
966
 
967
- # Get model dtype from student
968
- model_dtype = next(self.student.unet.parameters()).dtype
 
969
 
970
- cond_e = self.teacher.encode(prompts)
971
- uncond_e = self.teacher.encode([""]*len(prompts))
972
- sched = self.teacher.sched
973
- sched.set_timesteps(steps, device=self.device)
974
-
975
- # Create latents with correct dtype
976
- x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=model_dtype)
 
 
 
977
 
978
- for t_scalar in sched.timesteps:
979
- t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
980
- v_u, _ = self.student(x_t, t, uncond_e)
981
- v_c, _ = self.student(x_t, t, cond_e)
982
- v_hat = v_u + guidance*(v_c - v_u)
983
 
984
- alpha, sigma = self.teacher.alpha_sigma(t)
985
- denom = (alpha**2 + sigma**2)
986
- x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
987
- eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)
988
 
989
- step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
990
- x_t = step.prev_sample
991
 
992
- imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample
 
 
 
 
 
 
993
  return imgs.clamp(-1,1)
994
 
995
-
996
  # =====================================================================================
997
  # 9) MAIN
998
  # =====================================================================================
 
51
  pooling: str = "mean"
52
 
53
  # Flow training
54
+ epochs: int = 20
55
  lr: float = 1e-4
56
  weight_decay: float = 1e-3
57
  grad_clip: float = 1.0
 
964
  steps = steps or self.cfg.sample_steps
965
  guidance = guidance if guidance is not None else self.cfg.guidance_scale
966
 
967
+ # Ensure student is in eval mode
968
+ was_training = self.student.training
969
+ self.student.eval()
970
 
971
+ # Use autocast to handle dtype conversions automatically
972
+ with torch.cuda.amp.autocast(enabled=self.cfg.amp):
973
+ cond_e = self.teacher.encode(prompts)
974
+ uncond_e = self.teacher.encode([""]*len(prompts))
975
+
976
+ sched = self.teacher.sched
977
+ sched.set_timesteps(steps, device=self.device)
978
+
979
+ # Create latents (autocast will handle dtype)
980
+ x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)
981
 
982
+ for t_scalar in sched.timesteps:
983
+ t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
984
+ v_u, _ = self.student(x_t, t, uncond_e)
985
+ v_c, _ = self.student(x_t, t, cond_e)
986
+ v_hat = v_u + guidance*(v_c - v_u)
987
 
988
+ alpha, sigma = self.teacher.alpha_sigma(t)
989
+ denom = (alpha**2 + sigma**2)
990
+ x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
991
+ eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)
992
 
993
+ step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
994
+ x_t = step.prev_sample
995
 
996
+ # Decode (keep x_t at current dtype for VAE)
997
+ imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample
998
+
999
+ # Restore training mode
1000
+ if was_training:
1001
+ self.student.train()
1002
+
1003
  return imgs.clamp(-1,1)
1004
 
 
1005
  # =====================================================================================
1006
  # 9) MAIN
1007
  # =====================================================================================