AbstractPhil committed
Commit 6534798 · verified · 1 Parent(s): 96e03fa

Trainer v2 update, necessary elements included

Files changed (1):
  1. trainer_v2.py +296 -147
trainer_v2.py CHANGED
@@ -1,7 +1,7 @@
 # =====================================================================================
 # SD1.5 Flow-Matching Trainer — David-Driven Adaptive Timestep Sampling
 # Quartermaster: Mirel
-# NEW: David-weighted timesteps + SD3 shift + adaptive chaos
+# FIXED: David nested output handling + reliability filtering + clean checkpoint loading
 # =====================================================================================
 from __future__ import annotations
 import os, json, math, random, re, shutil
@@ -86,7 +86,8 @@ class BaseConfig:
     timestep_shift: float = 3.0   # SD3-style shift (higher = bias toward clean)
     base_jitter: int = 5          # Base ±jitter around bin center
     adaptive_chaos: bool = True   # Scale jitter by pattern difficulty
-    profile_samples: int = 500    # Samples to profile David's difficulty
+    profile_samples: int = 2500   # Samples to profile David's difficulty
+    reliability_threshold: float = 0.15  # Minimum accuracy to trust David's guidance

     # Scheduler
     num_train_timesteps: int = 1000
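
The two new config fields feed the sampler below: timestep_shift biases which timesteps get drawn, and reliability_threshold gates which bins are trusted at all. The shift formula itself is not shown in this hunk; a minimal sketch of the common SD3-style mapping, where apply_shift is an illustrative name rather than a function from this file:

import torch

def apply_shift(u: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    # SD3-style monotone remap of uniform draws in [0, 1]; shift > 1
    # concentrates samples toward one end of the schedule (which end
    # counts as "clean" depends on the trainer's timestep convention).
    return shift * u / (1.0 + (shift - 1.0) * u)

u = torch.rand(8)
t = (apply_shift(u, shift=3.0) * 999).long()  # shifted timestep indices
print(t.tolist())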
@@ -114,13 +115,16 @@ class BaseConfig:
 class DavidWeightedTimestepSampler:
     """
     Samples timesteps weighted by David's inherent difficulty + SD3 shift + adaptive chaos.
+    FIXED: Properly handles nested GeoDavidCollective output structure.
+    FIXED: Filters out unreliable bins (accuracy < threshold).
     """
-    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True):
+    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True, reliability_threshold=0.15):
        self.num_timesteps = num_timesteps
        self.num_bins = num_bins
        self.shift = shift
        self.base_jitter = base_jitter
        self.adaptive_chaos = adaptive_chaos
+       self.reliability_threshold = reliability_threshold

        self.difficulty_weights = None   # Timestep difficulty
        self.pattern_difficulty = None   # Pattern confusion per bin
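
For orientation, a toy version of what the sample() method further down describes (difficulty weighting plus ±jitter around bin centers): draw bins from a weight distribution, then perturb within each bin. All values and names here are illustrative stand-ins, not the class's actual internals:

import torch

num_bins, bin_size, base_jitter = 100, 10, 5
weights = torch.rand(num_bins)             # stand-in difficulty weights
weights = weights / weights.sum()

bins = torch.multinomial(weights, num_samples=16, replacement=True)
centers = bins * bin_size + bin_size // 2  # bin centers in [0, 999]
jitter = torch.randint(-base_jitter, base_jitter + 1, (16,))
timesteps = (centers + jitter).clamp(0, 999)
print(timesteps.tolist())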
@@ -162,38 +166,38 @@ class DavidWeightedTimestepSampler:
         # Pool features
         pooled = {name: f.mean(dim=(2, 3)) for name, f in feats.items()}

-        # Get David's outputs
+        # Get David's outputs (NESTED STRUCTURE!)
         outputs = david(pooled, t.float())

+        # ================================================================
+        # FIXED: Aggregate across blocks
+        # ================================================================
+
         # 1. Timestep difficulty (from classification error)
-        ts_key = None
-        for key in ["timestep_logits", "logits_timestep", "timestep_head_logits"]:
-            if key in outputs:
-                ts_key = key
-                break
+        timestep_logits_list = []
+        for block_name, block_out in outputs.items():
+            if 'timestep_logits' in block_out:
+                timestep_logits_list.append(block_out['timestep_logits'])

-        if ts_key:
-            ts_logits = outputs[ts_key]
-            if isinstance(ts_logits, dict):
-                ts_logits = torch.stack(list(ts_logits.values())).mean(0)
-
+        if timestep_logits_list:
+            # Average predictions across blocks
+            ts_logits = torch.stack(timestep_logits_list).mean(0)
+
             preds = ts_logits.argmax(dim=-1)
+
             for pred, true_bin in zip(preds, t_bins):
                 bin_idx = true_bin.item()
                 correct_per_bin[bin_idx] += (pred == true_bin).float().item()
                 total_per_bin[bin_idx] += 1

         # 2. Pattern difficulty (from entropy)
-        pt_key = None
-        for key in ["pattern_logits", "logits_pattern", "pattern_head_logits"]:
-            if key in outputs:
-                pt_key = key
-                break
+        pattern_logits_list = []
+        for block_name, block_out in outputs.items():
+            if 'pattern_logits' in block_out:
+                pattern_logits_list.append(block_out['pattern_logits'])

-        if pt_key:
-            pt_logits = outputs[pt_key]
-            if isinstance(pt_logits, dict):
-                pt_logits = torch.stack(list(pt_logits.values())).mean(0)
+        if pattern_logits_list:
+            # Average predictions across blocks
+            pt_logits = torch.stack(pattern_logits_list).mean(0)

             P = pt_logits.softmax(-1)
             ent = -(P * P.clamp_min(1e-9).log()).sum(-1)
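
The rewrite above assumes GeoDavidCollective returns Dict[block_name, Dict[str, Tensor]] rather than a flat dict, which is why the old flat-key lookup never matched. A self-contained toy of the per-block aggregation:

import torch

# Toy stand-in for the nested output structure.
outputs = {
    "down_0": {"timestep_logits": torch.randn(4, 100)},
    "mid":    {"timestep_logits": torch.randn(4, 100)},
    "up_1":   {},  # a block may omit a head entirely
}

logits = [o["timestep_logits"] for o in outputs.values() if "timestep_logits" in o]
if logits:
    ts_logits = torch.stack(logits).mean(0)  # average over blocks -> (4, 100)
    preds = ts_logits.argmax(dim=-1)
    print(preds.shape)  # torch.Size([4])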
@@ -204,23 +208,78 @@ class DavidWeightedTimestepSampler:
                 entropy_per_bin[bin_idx] += norm_ent[i].item()
                 entropy_count_per_bin[bin_idx] += 1

-        # Compute timestep difficulty (inverse of accuracy)
+        # Compute accuracy per bin
         accuracy_per_bin = correct_per_bin / (total_per_bin.clamp(min=1))
-        timestep_difficulty = (1.0 - accuracy_per_bin) + 0.1  # Higher = harder
+
+        # ========================================================================
+        # RELIABILITY FILTERING: Disable bins with accuracy < threshold
+        # ========================================================================
+        reliable_mask = accuracy_per_bin >= self.reliability_threshold
+        num_reliable = reliable_mask.sum().item()
+        num_disabled = self.num_bins - num_reliable
+
+        print(f"\n🎯 Reliability Analysis:")
+        print(f"   Threshold: {self.reliability_threshold:.0%}")
+        print(f"   Reliable bins: {num_reliable}/{self.num_bins}")
+        print(f"   Disabled bins: {num_disabled}/{self.num_bins}")
+
+        if num_disabled > 0:
+            disabled_bins = torch.where(~reliable_mask)[0].tolist()
+            disabled_accs = [accuracy_per_bin[i].item() for i in disabled_bins]
+            print(f"   Disabled: {disabled_bins[:10]}{'...' if len(disabled_bins) > 10 else ''}")
+            print(f"   (accuracies: {[f'{a:.1%}' for a in disabled_accs[:10]]})")
+
+        # Create difficulty weights ONLY for reliable bins
+        if num_reliable == 0:
+            print("\n⚠️ WARNING: No reliable bins found! Falling back to uniform sampling.")
+            self.difficulty_weights = torch.ones(self.num_bins) / self.num_bins
+            self.pattern_difficulty = torch.ones(self.num_bins) * 0.5
+            return self.difficulty_weights
+
+        # Compute difficulty (inverse accuracy) for reliable bins
+        timestep_difficulty = torch.zeros(self.num_bins)
+        timestep_difficulty[reliable_mask] = (1.0 - accuracy_per_bin[reliable_mask]) + 0.1
+
+        # Zero out unreliable bins (won't be sampled)
+        timestep_difficulty[~reliable_mask] = 0.0
+
+        # Normalize weights over reliable bins only
         self.difficulty_weights = timestep_difficulty / timestep_difficulty.sum()

         # Compute pattern difficulty (average entropy per bin)
         self.pattern_difficulty = entropy_per_bin / (entropy_count_per_bin.clamp(min=1))
         self.pattern_difficulty = self.pattern_difficulty.clamp(min=0.1, max=1.0)

-        print(f"✓ David difficulty map computed:")
-        print(f"   Avg timestep accuracy: {accuracy_per_bin.mean():.2%}")
-        print(f"   Hardest timestep bin: {accuracy_per_bin.argmin().item()} ({accuracy_per_bin.min():.2%} acc)")
-        print(f"   Easiest timestep bin: {accuracy_per_bin.argmax().item()} ({accuracy_per_bin.max():.2%} acc)")
-        print(f"   Avg pattern entropy: {self.pattern_difficulty.mean():.3f}")
+        # Set entropy to 0.5 (neutral) for disabled bins
+        self.pattern_difficulty[~reliable_mask] = 0.5
+
+        # ========================================================================
+        # REPORT
+        # ========================================================================
+        print(f"\n✓ David difficulty map computed (filtered):")
+        print(f"   Avg timestep accuracy (all bins): {accuracy_per_bin.mean():.2%}")
+        print(f"   Avg timestep accuracy (reliable): {accuracy_per_bin[reliable_mask].mean():.2%}")
+
+        # Find hardest/easiest among reliable bins
+        reliable_indices = torch.where(reliable_mask)[0]
+        if len(reliable_indices) > 0:
+            hardest_idx = reliable_indices[accuracy_per_bin[reliable_mask].argmin()].item()
+            easiest_idx = reliable_indices[accuracy_per_bin[reliable_mask].argmax()].item()
+
+            print(f"   Hardest reliable bin: {hardest_idx} ({accuracy_per_bin[hardest_idx]:.2%} acc)")
+            print(f"   Easiest reliable bin: {easiest_idx} ({accuracy_per_bin[easiest_idx]:.2%} acc)")
+
+        print(f"   Avg pattern entropy (reliable): {self.pattern_difficulty[reliable_mask].mean():.3f}")
+
+        # Show sampling distribution (top 10 weighted bins)
+        top_weights, top_bins = self.difficulty_weights.topk(10)
+        print(f"\n📊 Top 10 sampled bins (by difficulty weight):")
+        for i, (bin_idx, weight) in enumerate(zip(top_bins.tolist(), top_weights.tolist())):
+            acc = accuracy_per_bin[bin_idx].item()
+            print(f"   {i+1}. Bin {bin_idx:2d}: weight={weight:.3f} (acc={acc:.1%})")

         return self.difficulty_weights
-
+
     def sample(self, batch_size: int) -> List[int]:
         """Sample timesteps with David weighting + shift + adaptive chaos."""
         if self.difficulty_weights is None:
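
The reliability filter in isolation: bins under the threshold get exactly zero sampling mass, and the surviving difficulty weights are renormalized. The accuracies below are made up; 0.15 mirrors the default reliability_threshold:

import torch

accuracy = torch.tensor([0.05, 0.40, 0.22, 0.10, 0.60])
mask = accuracy >= 0.15

difficulty = torch.zeros_like(accuracy)
difficulty[mask] = (1.0 - accuracy[mask]) + 0.1  # harder = heavier
weights = difficulty / difficulty.sum()          # zeros stay zero

samples = torch.multinomial(weights, 100, replacement=True)
print(weights.tolist())
print(sorted(samples.unique().tolist()))  # only bins 1, 2, 4 can appear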
@@ -424,6 +483,12 @@ class DavidLoader:
         cfg.block_weights = self.hf_config["block_weights"]

 class DavidAssessor(nn.Module):
+    """
+    CORRECTED: Properly handles GeoDavidCollective's nested multi-block output structure.
+
+    GeoDavidCollective returns: Dict[block_name, Dict[str, Tensor]]
+    Not a flat Dict[str, Tensor]!
+    """
     def __init__(self, gdc: GeoDavidCollective, pooling: str):
         super().__init__()
         self.gdc = gdc
@@ -435,56 +500,88 @@ class DavidAssessor(nn.Module):
     @torch.no_grad()
     def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                 ) -> Tuple[Dict[str,float], Dict[str,float], Dict[str,float]]:
+        """
+        Assess student features using David's geometric knowledge.
+
+        Returns:
+            e_t: Dict[block_name, timestep_error] - classification error per block
+            e_p: Dict[block_name, pattern_entropy] - normalized entropy per block
+            coh: Dict[block_name, coherence] - geometric coherence per block
+        """
+        # Pool spatial features
         Zs = self._pool(feats_student)
+
+        # Forward through GeoDavidCollective
+        # Returns: Dict[block_name, Dict[str, Tensor]]
         outs = self.gdc(Zs, t.float())
+
+        # Initialize output dicts
         e_t, e_p, coh = {}, {}, {}
-
-        ts_key = None
-        for key in ["timestep_logits", "logits_timestep", "timestep_head_logits"]:
-            if key in outs: ts_key = key; break

-        pt_key = None
-        for key in ["pattern_logits", "logits_pattern", "pattern_head_logits"]:
-            if key in outs: pt_key = key; break
-
+        # Compute timestep bins for targets
         t_bins = (t // 10).to(next(self.gdc.parameters()).device)
-        if ts_key is not None:
-            ts_logits = outs[ts_key]
-            if isinstance(ts_logits, dict):
-                for name, L in ts_logits.items():
-                    ce = F.cross_entropy(L, t_bins, reduction="mean")
-                    e_t[name] = float(ce.item())
-            else:
+
+        # ====================================================================
+        # TIMESTEP ERROR - Per-block
+        # ====================================================================
+        for block_name, block_out in outs.items():
+            if 'timestep_logits' in block_out:
+                ts_logits = block_out['timestep_logits']
                 ce = F.cross_entropy(ts_logits, t_bins, reduction="mean")
-                for name in Zs.keys():
-                    e_t[name] = float(ce.item())
-        else:
-            for name in Zs.keys(): e_t[name] = 0.0
-
-        if pt_key is not None:
-            pt_logits = outs[pt_key]
-            if isinstance(pt_logits, dict):
-                for name, L in pt_logits.items():
-                    P = L.softmax(-1)
-                    ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
-                    e_p[name] = float(ent.item() / math.log(P.shape[-1]))
-            else:
+                e_t[block_name] = float(ce.item())
+
+        # If no timestep predictions, set all errors to 0
+        if not e_t:
+            for name in Zs.keys():
+                e_t[name] = 0.0
+
+        # ====================================================================
+        # PATTERN ENTROPY - Per-block
+        # ====================================================================
+        for block_name, block_out in outs.items():
+            if 'pattern_logits' in block_out:
+                pt_logits = block_out['pattern_logits']
+
+                # Compute normalized entropy
                 P = pt_logits.softmax(-1)
                 ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
-                for name in Zs.keys():
-                    e_p[name] = float(ent.item() / math.log(P.shape[-1]))
-        else:
-            for name in Zs.keys(): e_p[name] = 0.0
-
-        alphas = {}
+                norm_ent = ent / math.log(P.shape[-1])  # Normalize by max entropy
+
+                e_p[block_name] = float(norm_ent.item())
+
+        # If no pattern predictions, set all entropies to 0
+        if not e_p:
+            for name in Zs.keys():
+                e_p[name] = 0.0
+
+        # ====================================================================
+        # COHERENCE (from Cantor alphas)
+        # ====================================================================
         try:
             alphas = self.gdc.get_cantor_alphas()
+            # Alphas should be close to 0.5 for good coherence
+            # Map to coherence: 1.0 = perfect (alpha=0.5), lower = worse
+            for name, alpha in alphas.items():
+                # Coherence = 1 - 2*|alpha - 0.5|
+                # When alpha=0.5: coherence=1.0
+                # When alpha=0 or 1: coherence=0.0
+                coherence = 1.0 - 2.0 * abs(alpha - 0.5)
+                coh[name] = max(0.0, min(1.0, coherence))
         except Exception:
-            alphas = {}
-        avg_alpha = float(sum(alphas.values())/max(len(alphas),1)) if alphas else 1.0
+            # Fallback: assume perfect coherence
+            for name in Zs.keys():
+                coh[name] = 1.0
+
+        # Ensure all input blocks have values (fill missing with block averages)
         for name in Zs.keys():
-            coh[name] = avg_alpha
-
+            if name not in e_t:
+                # Use average of available blocks
+                e_t[name] = sum(e_t.values()) / max(len(e_t), 1) if e_t else 0.0
+            if name not in e_p:
+                e_p[name] = sum(e_p.values()) / max(len(e_p), 1) if e_p else 0.0
+            if name not in coh:
+                coh[name] = sum(coh.values()) / max(len(coh), 1) if coh else 1.0
+
         return e_t, e_p, coh

 class BlockPenaltyFusion:
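
Two small formulas in the hunk above are worth isolating: entropy divided by log(K) lies in [0, 1] for any K-way head, so per-block uncertainties are comparable across head sizes, and the alpha-to-coherence map peaks at alpha = 0.5. A standalone sketch:

import math
import torch

logits = torch.randn(8, 32)
P = logits.softmax(-1)
ent = -(P * P.clamp_min(1e-9).log()).sum(-1).mean()
norm_ent = float(ent) / math.log(P.shape[-1])  # in [0, 1]

def coherence(alpha: float) -> float:
    # 1.0 at alpha = 0.5, falling linearly to 0.0 at alpha = 0 or 1
    return max(0.0, min(1.0, 1.0 - 2.0 * abs(alpha - 0.5)))

print(round(norm_ent, 3), coherence(0.5), round(coherence(0.9), 3))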
@@ -528,13 +625,15 @@ class FlowMatchDavidTrainer:
         print(f"   SD3 shift: {cfg.timestep_shift}")
         print(f"   Base jitter: ±{cfg.base_jitter}")
         print(f"   Adaptive chaos: {cfg.adaptive_chaos}")
+        print(f"   Reliability threshold: {cfg.reliability_threshold:.0%}")

         self.timestep_sampler = DavidWeightedTimestepSampler(
             num_timesteps=cfg.num_train_timesteps,
             num_bins=100,
             shift=cfg.timestep_shift if cfg.use_david_weights else 0.0,
             base_jitter=cfg.base_jitter,
-            adaptive_chaos=cfg.adaptive_chaos
+            adaptive_chaos=cfg.adaptive_chaos,
+            reliability_threshold=cfg.reliability_threshold
         )

         if cfg.use_david_weights:
@@ -558,95 +657,43 @@ class FlowMatchDavidTrainer:
         self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
         self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

-        # Load checkpoints
-        emergency_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
-        if not emergency_path.exists():
-            print("\n🔍 Emergency checkpoint not found locally, checking HuggingFace...")
-            emergency_path = self._download_emergency_checkpoint()
-
-        if emergency_path and emergency_path.exists():
-            self._load_emergency_checkpoint(emergency_path)
-        elif cfg.continue_training:
+        # Load latest checkpoint from HuggingFace if continuing training
+        if cfg.continue_training:
             self._load_latest_from_hf()

         self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

-    def _download_emergency_checkpoint(self) -> Optional[Path]:
-        """Download emergency checkpoint from HuggingFace backup repo."""
-        emergency_repo = "AbstractPhil/sd15-flow-emergency-backup"
-        emergency_file = "EMERGENCY_SAVE_SUCCESS.pt"
-
-        try:
-            print(f"📥 Downloading emergency checkpoint from {emergency_repo}...")
-            local_path = hf_hub_download(
-                repo_id=emergency_repo,
-                filename=emergency_file,
-                repo_type="model",
-                cache_dir="./_emergency_cache"
-            )
-
-            target_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
-            shutil.copy(local_path, target_path)
-
-            size_mb = target_path.stat().st_size / 1e6
-            print(f"✅ Downloaded emergency checkpoint ({size_mb:.1f} MB)")
-            return target_path
-
-        except Exception as e:
-            print(f"⚠️ Could not download emergency checkpoint: {e}")
-            return None
-
-    def _load_emergency_checkpoint(self, path: Path):
-        """Load emergency checkpoint with student_unet structure."""
-        try:
-            print(f"\n🚨 Found emergency checkpoint: {path}")
-            checkpoint = torch.load(path, map_location='cpu')
-
-            if 'student_unet' in checkpoint:
-                print("📦 Loading emergency checkpoint format...")
-                missing, unexpected = self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
-                print(f"✓ Loaded student UNet")
-
-            if 'opt' in checkpoint:
-                self.opt.load_state_dict(checkpoint['opt'])
-                print("✓ Loaded optimizer state")
-
-            if 'sched' in checkpoint:
-                self.sched.load_state_dict(checkpoint['sched'])
-                print("✓ Loaded scheduler state")
-
-            if 'gstep' in checkpoint:
-                self.start_gstep = checkpoint['gstep']
-                self.start_epoch = self.start_gstep // len(self.loader)
-                print(f"✓ Resuming from global step {self.start_gstep} (epoch ~{self.start_epoch})")
-
-            print("✅ Emergency checkpoint loaded successfully!")
-
-        except Exception as e:
-            print(f"⚠️ Failed to load emergency checkpoint: {e}")
-
     def _load_latest_from_hf(self):
+        """Load the most recent checkpoint from HuggingFace repo."""
         if not self.cfg.hf_repo_id:
+            print("ℹ️ No HuggingFace repo specified, starting from scratch\n")
             return

         try:
             api = HfApi()
             print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

+            # List all files in the repo
             files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
+
+            # Find all epoch checkpoints (format: {run_name}_e{epoch}.pt)
             epochs = []
             for f in files:
-                if f.endswith('.pt'):
+                if f.endswith('.pt') and 'final' not in f.lower():
                     match = re.search(r'_e(\d+)\.pt$', f)
                     if match:
-                        epochs.append((int(match.group(1)), f))
+                        epoch_num = int(match.group(1))
+                        epochs.append((epoch_num, f))

             if not epochs:
+                print("ℹ️ No previous checkpoints found, starting from scratch\n")
                 return

+            # Get the latest epoch
             latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
-            print(f"📥 Downloading: {latest_file}")
+            print(f"📥 Found latest checkpoint: {latest_file} (epoch {latest_epoch})")

+            # Download checkpoint
             local_path = hf_hub_download(
                 repo_id=self.cfg.hf_repo_id,
                 filename=latest_file,
@@ -654,27 +701,58 @@ class FlowMatchDavidTrainer:
                 cache_dir=self.cfg.ckpt_dir
             )

+            # Load checkpoint
+            print(f"📦 Loading checkpoint...")
             checkpoint = torch.load(local_path, map_location='cpu')

-            if 'student_unet' in checkpoint:
-                self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
-            elif 'student' in checkpoint:
-                self.student.load_state_dict(checkpoint['student'], strict=False)
+            # Load student state dict
+            if 'student' in checkpoint:
+                missing, unexpected = self.student.load_state_dict(checkpoint['student'], strict=False)
+                if missing:
+                    print(f"   ⚠️ Missing keys: {len(missing)}")
+                if unexpected:
+                    print(f"   ⚠️ Unexpected keys: {len(unexpected)}")
+                print(f"   ✓ Loaded student model")
+            else:
+                print(f"   ⚠️ Warning: 'student' key not found in checkpoint")
+                return

+            # Load optimizer state
             if 'opt' in checkpoint:
-                self.opt.load_state_dict(checkpoint['opt'])
+                try:
+                    self.opt.load_state_dict(checkpoint['opt'])
+                    print("   ✓ Loaded optimizer state")
+                except Exception as e:
+                    print(f"   ⚠️ Failed to load optimizer state: {e}")
+
+            # Load scheduler state
             if 'sched' in checkpoint:
-                self.sched.load_state_dict(checkpoint['sched'])
+                try:
+                    self.sched.load_state_dict(checkpoint['sched'])
+                    print("   ✓ Loaded scheduler state")
+                except Exception as e:
+                    print(f"   ⚠️ Failed to load scheduler state: {e}")

-            self.start_epoch = latest_epoch
-            self.start_gstep = latest_epoch * len(self.loader)
-
-            print(f"✅ Resuming from epoch {self.start_epoch + 1}")
+            # Set starting epoch and global step
+            if 'gstep' in checkpoint:
+                self.start_gstep = checkpoint['gstep']
+                self.start_epoch = latest_epoch
+                print(f"   ✓ Resuming from epoch {self.start_epoch + 1}, global step {self.start_gstep}")
+            else:
+                # Fallback: estimate from epoch number
+                self.start_epoch = latest_epoch
+                self.start_gstep = latest_epoch * len(self.loader)
+                print(f"   ✓ Resuming from epoch {self.start_epoch + 1} (estimated step {self.start_gstep})")

+            # Cleanup
             del checkpoint
             torch.cuda.empty_cache()

+            print(f"✅ Successfully resumed from checkpoint!\n")
+
         except Exception as e:
-            print(f"⚠️ Failed to load from HF: {e}")
+            print(f"⚠️ Failed to load checkpoint: {e}")
+            print("   Starting training from scratch...\n")

     def _v_star(self, x_t, t, eps_hat):
         alpha, sigma = self.teacher.alpha_sigma(t)
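
The checkpoint-selection logic above reduces to a regex over repo filenames; a runnable sketch with made-up names:

import re

files = ["run_e3.pt", "run_e12.pt", "run_final.pt", "notes.txt"]
epochs = []
for f in files:
    if f.endswith(".pt") and "final" not in f.lower():
        m = re.search(r"_e(\d+)\.pt$", f)
        if m:
            epochs.append((int(m.group(1)), f))

latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
print(latest_epoch, latest_file)  # 12 run_e12.pt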
@@ -692,7 +770,45 @@ class FlowMatchDavidTrainer:
         cfg = self.cfg
         gstep = self.start_gstep

+        # Test prompts for monitoring progress
+        test_prompts = [
+            "a castle at sunset",
+            "a mountain landscape with trees",
+            "a city street at night"
+        ]
+
         for ep in range(self.start_epoch, cfg.epochs):
+            # Sample before epoch to monitor progress
+            if ep > 0 or self.start_epoch > 0:  # Skip first ever epoch
+                print(f"\n🎨 Sampling test images before epoch {ep+1}...")
+                try:
+                    test_imgs = self.sample(test_prompts, steps=30, guidance=7.5)
+
+                    # Save individual images
+                    sample_dir = Path(cfg.out_dir) / "samples"
+                    sample_dir.mkdir(exist_ok=True, parents=True)
+
+                    for i, (img, prompt) in enumerate(zip(test_imgs, test_prompts)):
+                        # Convert to PIL
+                        img_np = ((img.cpu().permute(1,2,0).numpy() + 1) / 2 * 255).astype('uint8')
+                        from PIL import Image
+                        pil_img = Image.fromarray(img_np)
+
+                        # Save with epoch number
+                        safe_prompt = prompt.replace(" ", "_")[:30]
+                        img_path = sample_dir / f"e{ep}_p{i}_{safe_prompt}.png"
+                        pil_img.save(img_path)
+
+                        # Log to tensorboard
+                        self.writer.add_image(f"samples/{safe_prompt}",
+                                              (img + 1) / 2,  # Normalize to [0,1]
+                                              global_step=ep)
+
+                    print(f"✓ Saved {len(test_imgs)} test images to {sample_dir}")
+
+                except Exception as e:
+                    print(f"⚠️ Sampling failed: {e}")
+
             self.student.train()
             pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                         dynamic_ncols=True, leave=True, position=0)
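
The save loop above converts a CHW tensor in [-1, 1] to HWC uint8 before handing it to PIL. The conversion in isolation, with a random stand-in for the decoded image:

import torch
from PIL import Image

img = torch.rand(3, 64, 64) * 2 - 1  # stand-in decode output, CHW in [-1, 1]
img_np = ((img.permute(1, 2, 0).numpy() + 1) / 2 * 255).astype("uint8")
Image.fromarray(img_np).save("sample.png")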
@@ -778,6 +894,33 @@ class FlowMatchDavidTrainer:
             self._save(ep+1, gstep)

         self._save("final", gstep)
+
+        # Final comprehensive sampling
+        print("\n🎨 Generating final test samples...")
+        final_prompts = [
+            "a castle at sunset",
+            "a mountain landscape with trees",
+            "a city street at night",
+            "a portrait of a person",
+            "abstract geometric shapes"
+        ]
+        try:
+            final_imgs = self.sample(final_prompts, steps=30, guidance=7.5)
+
+            sample_dir = Path(cfg.out_dir) / "samples"
+            sample_dir.mkdir(exist_ok=True, parents=True)
+
+            for i, (img, prompt) in enumerate(zip(final_imgs, final_prompts)):
+                from PIL import Image
+                img_np = ((img.cpu().permute(1,2,0).numpy() + 1) / 2 * 255).astype('uint8')
+                pil_img = Image.fromarray(img_np)
+                safe_prompt = prompt.replace(" ", "_")[:30]
+                pil_img.save(sample_dir / f"final_{safe_prompt}.png")
+
+            print(f"✓ Saved {len(final_imgs)} final images to {sample_dir}")
+        except Exception as e:
+            print(f"⚠️ Final sampling failed: {e}")
+
         self.writer.close()

     def _save(self, tag, gstep):
@@ -820,11 +963,17 @@ class FlowMatchDavidTrainer:
     def sample(self, prompts: List[str], steps: Optional[int]=None, guidance: Optional[float]=None) -> torch.Tensor:
         steps = steps or self.cfg.sample_steps
         guidance = guidance if guidance is not None else self.cfg.guidance_scale
+
+        # Get model dtype from student
+        model_dtype = next(self.student.unet.parameters()).dtype
+
         cond_e = self.teacher.encode(prompts)
         uncond_e = self.teacher.encode([""]*len(prompts))
         sched = self.teacher.sched
         sched.set_timesteps(steps, device=self.device)
-        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)
+
+        # Create latents with correct dtype
+        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=model_dtype)

         for t_scalar in sched.timesteps:
             t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
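
The dtype fix exists because torch.randn defaults to float32, which a half-precision UNet rejects. A minimal demonstration of deriving the latent dtype from the model itself (a Conv2d stands in for the UNet; no forward pass is needed to show the point):

import torch

unet = torch.nn.Conv2d(4, 4, 3, padding=1).half()  # stand-in "UNet"
model_dtype = next(unet.parameters()).dtype        # torch.float16
x_t = torch.randn(1, 4, 64, 64, dtype=model_dtype)
print(x_t.dtype)  # torch.float16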
 