trainer to finish the next 10 epochs barring major errors
trainer_v2.py  +32 -23
CHANGED
@@ -51,7 +51,7 @@ class BaseConfig:
     pooling: str = "mean"
 
     # Flow training
-    epochs: int =
+    epochs: int = 20
     lr: float = 1e-4
     weight_decay: float = 1e-3
     grad_clip: float = 1.0
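The hunk above sets epochs to 20. For orientation, here is a minimal, self-contained sketch of how these BaseConfig fields are typically consumed in a training loop; the model, data, and objective below are stand-ins for illustration, not code from trainer_v2.py.

from dataclasses import dataclass
import torch

@dataclass
class BaseConfig:
    pooling: str = "mean"
    # Flow training
    epochs: int = 20
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0

cfg = BaseConfig()
model = torch.nn.Linear(8, 8)  # stand-in for the student network
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

for epoch in range(cfg.epochs):  # 20 passes after this change
    loss = model(torch.randn(4, 8)).pow(2).mean()  # dummy objective
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
    opt.step()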
@@ -964,35 +964,44 @@ class FlowMatchDavidTrainer:
         steps = steps or self.cfg.sample_steps
         guidance = guidance if guidance is not None else self.cfg.guidance_scale
 
-        #
-
+        # Ensure student is in eval mode
+        was_training = self.student.training
+        self.student.eval()
 
-
-
-
-
-
-
-
+        # Use autocast to handle dtype conversions automatically
+        with torch.cuda.amp.autocast(enabled=self.cfg.amp):
+            cond_e = self.teacher.encode(prompts)
+            uncond_e = self.teacher.encode([""]*len(prompts))
+
+            sched = self.teacher.sched
+            sched.set_timesteps(steps, device=self.device)
+
+            # Create latents (autocast will handle dtype)
+            x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)
 
-
-
-
-
-
+            for t_scalar in sched.timesteps:
+                t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
+                v_u, _ = self.student(x_t, t, uncond_e)
+                v_c, _ = self.student(x_t, t, cond_e)
+                v_hat = v_u + guidance*(v_c - v_u)
 
-
-
-
-
+                alpha, sigma = self.teacher.alpha_sigma(t)
+                denom = (alpha**2 + sigma**2)
+                x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
+                eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)
 
-
-
+                step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
+                x_t = step.prev_sample
 
-
+            # Decode (keep x_t at current dtype for VAE)
+            imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample
+
+        # Restore training mode
+        if was_training:
+            self.student.train()
+
         return imgs.clamp(-1,1)
 
-
 # =====================================================================================
 # 9) MAIN
 # =====================================================================================
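The new sampling loop applies classifier-free guidance to the student's v-prediction and then converts v into the epsilon estimate the scheduler consumes. Here is a self-contained sketch of that conversion under the standard v-parameterization (x_t = alpha*x0 + sigma*eps, v = alpha*eps - sigma*x0); the fixed alpha/sigma pair below is an assumption standing in for teacher.alpha_sigma(t).

import torch

def v_to_x0_eps(x_t, v_hat, alpha, sigma):
    # Invert x_t = alpha*x0 + sigma*eps together with v = alpha*eps - sigma*x0:
    # alpha*x_t - sigma*v = (alpha**2 + sigma**2) * x0.
    # On a variance-preserving schedule alpha**2 + sigma**2 == 1, so the diff's
    # explicit denominator (plus 1e-8) is numerical safety rather than necessity.
    denom = alpha**2 + sigma**2
    x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
    eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)
    return x0_hat, eps_hat

def guide(v_uncond, v_cond, scale):
    # Classifier-free guidance exactly as in the diff: v_u + scale*(v_c - v_u).
    return v_uncond + scale * (v_cond - v_uncond)

# Round-trip check with a toy variance-preserving pair (0.8**2 + 0.6**2 == 1):
alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)
x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
x_t = alpha * x0 + sigma * eps
v = alpha * eps - sigma * x0
x0_hat, eps_hat = v_to_x0_eps(x_t, v, alpha, sigma)
assert torch.allclose(x0_hat, x0, atol=1e-4)
assert torch.allclose(eps_hat, eps, atol=1e-4)
assert torch.allclose(guide(v, v, 7.5), v)  # no-op when cond == uncond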
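The loop drives the scheduler through a two-call protocol: set_timesteps(steps, device=...) once, then step(model_output, timestep, sample).prev_sample at each iteration. A minimal sketch of that protocol follows, using diffusers.DDIMScheduler as a stand-in; the diff never names the concrete class behind teacher.sched, so treat that choice as an assumption. It also shows why the v-to-eps conversion above matters: with the default prediction_type="epsilon", this scheduler family expects an epsilon prediction, not v.

import torch
from diffusers import DDIMScheduler

# Stand-in for teacher.sched; the real scheduler class is not shown in the diff.
sched = DDIMScheduler(num_train_timesteps=1000)
sched.set_timesteps(50)  # trainer_v2.py also passes device=self.device

x_t = torch.randn(1, 4, 64, 64)  # SD-style latent, matching the diff's shape
for t in sched.timesteps:
    eps_hat = torch.zeros_like(x_t)  # placeholder for the model's epsilon estimate
    x_t = sched.step(model_output=eps_hat, timestep=t, sample=x_t).prev_sample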
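Finally, the was_training/eval()/train() bookkeeping that brackets the sampler is a pattern worth factoring out. Below is a hypothetical helper, not part of trainer_v2.py, that packages it as a context manager; it additionally wraps the body in torch.no_grad(), which the added lines do not appear to use and which would keep a mid-training sampling call from building an autograd graph across every denoising step.

import contextlib
import torch

@contextlib.contextmanager
def eval_mode(module: torch.nn.Module):
    # Mirror the diff's bookkeeping: remember the mode, switch to eval,
    # restore on exit. The no_grad() wrapper is an addition, not in the diff.
    was_training = module.training
    module.eval()
    try:
        with torch.no_grad():
            yield module
    finally:
        if was_training:
            module.train()

# Usage sketch:
net = torch.nn.Linear(4, 4)
net.train()
with eval_mode(net):
    _ = net(torch.randn(1, 4))
assert net.training  # mode restored even if sampling raises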