Trainer v2 update: necessary elements included

trainer_v2.py  CHANGED  (+296 −147)
@@ -1,7 +1,7 @@
 # =====================================================================================
 # SD1.5 Flow-Matching Trainer — David-Driven Adaptive Timestep Sampling
 # Quartermaster: Mirel
-#
+# FIXED: David nested output handling + reliability filtering + clean checkpoint loading
 # =====================================================================================
 from __future__ import annotations
 import os, json, math, random, re, shutil
@@ -86,7 +86,8 @@ class BaseConfig:
     timestep_shift: float = 3.0          # SD3-style shift (higher = bias toward clean)
     base_jitter: int = 5                 # Base ±jitter around bin center
     adaptive_chaos: bool = True          # Scale jitter by pattern difficulty
-    profile_samples: int =
+    profile_samples: int = 2500          # Samples to profile David's difficulty
+    reliability_threshold: float = 0.15  # Minimum accuracy to trust David's guidance

     # Scheduler
     num_train_timesteps: int = 1000
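A note on the `timestep_shift` knob: the shift formula itself never appears in this diff, so the sketch below uses the standard SD3 mapping as an assumption; which end of the range counts as "clean" depends on the trainer's timestep convention, and the function name is illustrative.

# Sketch of an SD3-style timestep shift (ASSUMPTION: this trainer uses the
# standard SD3 form; the actual mapping is not visible in this diff).
import torch

def sd3_shift(u: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    # u is a uniform draw in [0, 1]; shift > 1 biases samples toward u = 1
    return (shift * u) / (1.0 + (shift - 1.0) * u)

u = torch.rand(8)
t = (sd3_shift(u, shift=3.0) * 999).round().long()  # biased timesteps in [0, 999]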
@@ -114,13 +115,16 @@ class BaseConfig:
 class DavidWeightedTimestepSampler:
     """
     Samples timesteps weighted by David's inherent difficulty + SD3 shift + adaptive chaos.
+    FIXED: Properly handles nested GeoDavidCollective output structure.
+    FIXED: Filters out unreliable bins (accuracy < threshold).
     """
-    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True):
+    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True, reliability_threshold=0.15):
        self.num_timesteps = num_timesteps
        self.num_bins = num_bins
        self.shift = shift
        self.base_jitter = base_jitter
        self.adaptive_chaos = adaptive_chaos
+       self.reliability_threshold = reliability_threshold

        self.difficulty_weights = None   # Timestep difficulty
        self.pattern_difficulty = None   # Pattern confusion per bin
@@ -162,38 +166,38 @@ class DavidWeightedTimestepSampler:
         # Pool features
         pooled = {name: f.mean(dim=(2, 3)) for name, f in feats.items()}

-        # Get David's outputs
+        # Get David's outputs (NESTED STRUCTURE!)
         outputs = david(pooled, t.float())

+        # ================================================================
+        # FIXED: Aggregate across blocks
+        # ================================================================
+
         # 1. Timestep difficulty (from classification error)
-
-        for
-        if
-
-        break
+        timestep_logits_list = []
+        for block_name, block_out in outputs.items():
+            if 'timestep_logits' in block_out:
+                timestep_logits_list.append(block_out['timestep_logits'])

-        if
-
-
-        ts_logits = torch.stack(list(ts_logits.values())).mean(0)
-
+        if timestep_logits_list:
+            # Average predictions across blocks
+            ts_logits = torch.stack(timestep_logits_list).mean(0)
             preds = ts_logits.argmax(dim=-1)
+
             for pred, true_bin in zip(preds, t_bins):
                 bin_idx = true_bin.item()
                 correct_per_bin[bin_idx] += (pred == true_bin).float().item()
                 total_per_bin[bin_idx] += 1

         # 2. Pattern difficulty (from entropy)
-
-        for
-        if
-
-        break
+        pattern_logits_list = []
+        for block_name, block_out in outputs.items():
+            if 'pattern_logits' in block_out:
+                pattern_logits_list.append(block_out['pattern_logits'])

-        if
-
-
-        pt_logits = torch.stack(list(pt_logits.values())).mean(0)
+        if pattern_logits_list:
+            # Average predictions across blocks
+            pt_logits = torch.stack(pattern_logits_list).mean(0)

             P = pt_logits.softmax(-1)
             ent = -(P * P.clamp_min(1e-9).log()).sum(-1)
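The shape of this fix is easiest to see on a toy nested output. Block names and head sizes below are hypothetical, not taken from GeoDavidCollective:

# Toy nested output in the Dict[block_name, Dict[str, Tensor]] layout the
# loop above expects (block names and sizes are made up for illustration).
import torch

outputs = {
    "down_0": {"timestep_logits": torch.randn(4, 100), "pattern_logits": torch.randn(4, 16)},
    "mid":    {"timestep_logits": torch.randn(4, 100)},   # no pattern head here
}

ts_list = [o["timestep_logits"] for o in outputs.values() if "timestep_logits" in o]
ts_logits = torch.stack(ts_list).mean(0)   # average over blocks -> (4, 100)
preds = ts_logits.argmax(dim=-1)           # one predicted timestep bin per sample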
@@ -204,23 +208,78 @@ class DavidWeightedTimestepSampler:
             entropy_per_bin[bin_idx] += norm_ent[i].item()
             entropy_count_per_bin[bin_idx] += 1

-        # Compute
+        # Compute accuracy per bin
         accuracy_per_bin = correct_per_bin / (total_per_bin.clamp(min=1))
-
+
+        # ========================================================================
+        # RELIABILITY FILTERING: Disable bins with accuracy < threshold
+        # ========================================================================
+        reliable_mask = accuracy_per_bin >= self.reliability_threshold
+        num_reliable = reliable_mask.sum().item()
+        num_disabled = self.num_bins - num_reliable
+
+        print(f"\n🎯 Reliability Analysis:")
+        print(f"   Threshold: {self.reliability_threshold:.0%}")
+        print(f"   Reliable bins: {num_reliable}/{self.num_bins}")
+        print(f"   Disabled bins: {num_disabled}/{self.num_bins}")
+
+        if num_disabled > 0:
+            disabled_bins = torch.where(~reliable_mask)[0].tolist()
+            disabled_accs = [accuracy_per_bin[i].item() for i in disabled_bins]
+            print(f"   Disabled: {disabled_bins[:10]}{'...' if len(disabled_bins) > 10 else ''}")
+            print(f"   (accuracies: {[f'{a:.1%}' for a in disabled_accs[:10]]})")
+
+        # Create difficulty weights ONLY for reliable bins
+        if num_reliable == 0:
+            print("\n⚠️ WARNING: No reliable bins found! Falling back to uniform sampling.")
+            self.difficulty_weights = torch.ones(self.num_bins) / self.num_bins
+            self.pattern_difficulty = torch.ones(self.num_bins) * 0.5
+            return self.difficulty_weights
+
+        # Compute difficulty (inverse accuracy) for reliable bins
+        timestep_difficulty = torch.zeros(self.num_bins)
+        timestep_difficulty[reliable_mask] = (1.0 - accuracy_per_bin[reliable_mask]) + 0.1
+
+        # Zero out unreliable bins (won't be sampled)
+        timestep_difficulty[~reliable_mask] = 0.0
+
+        # Normalize weights over reliable bins only
         self.difficulty_weights = timestep_difficulty / timestep_difficulty.sum()

         # Compute pattern difficulty (average entropy per bin)
         self.pattern_difficulty = entropy_per_bin / (entropy_count_per_bin.clamp(min=1))
         self.pattern_difficulty = self.pattern_difficulty.clamp(min=0.1, max=1.0)

-
-
-
-
-
+        # Set entropy to 0.5 (neutral) for disabled bins
+        self.pattern_difficulty[~reliable_mask] = 0.5
+
+        # ========================================================================
+        # REPORT
+        # ========================================================================
+        print(f"\n✓ David difficulty map computed (filtered):")
+        print(f"   Avg timestep accuracy (all bins): {accuracy_per_bin.mean():.2%}")
+        print(f"   Avg timestep accuracy (reliable): {accuracy_per_bin[reliable_mask].mean():.2%}")
+
+        # Find hardest/easiest among reliable bins
+        reliable_indices = torch.where(reliable_mask)[0]
+        if len(reliable_indices) > 0:
+            hardest_idx = reliable_indices[accuracy_per_bin[reliable_mask].argmin()].item()
+            easiest_idx = reliable_indices[accuracy_per_bin[reliable_mask].argmax()].item()
+
+            print(f"   Hardest reliable bin: {hardest_idx} ({accuracy_per_bin[hardest_idx]:.2%} acc)")
+            print(f"   Easiest reliable bin: {easiest_idx} ({accuracy_per_bin[easiest_idx]:.2%} acc)")
+
+        print(f"   Avg pattern entropy (reliable): {self.pattern_difficulty[reliable_mask].mean():.3f}")
+
+        # Show sampling distribution (top 10 weighted bins)
+        top_weights, top_bins = self.difficulty_weights.topk(10)
+        print(f"\n📊 Top 10 sampled bins (by difficulty weight):")
+        for i, (bin_idx, weight) in enumerate(zip(top_bins.tolist(), top_weights.tolist())):
+            acc = accuracy_per_bin[bin_idx].item()
+            print(f"   {i+1}. Bin {bin_idx:2d}: weight={weight:.3f} (acc={acc:.1%})")

         return self.difficulty_weights
-
+
     def sample(self, batch_size: int) -> List[int]:
         """Sample timesteps with David weighting + shift + adaptive chaos."""
         if self.difficulty_weights is None:
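On a toy 10-bin accuracy vector, the filtering above reduces to a few lines. This is a minimal numeric check, not trainer code:

# Minimal numeric check of the reliability filter (threshold 0.15, 10 bins).
import torch

acc = torch.tensor([0.02, 0.30, 0.55, 0.10, 0.80, 0.20, 0.05, 0.45, 0.90, 0.15])
mask = acc >= 0.15                         # bins 0, 3, 6 fall below -> disabled
w = torch.zeros_like(acc)
w[mask] = (1.0 - acc[mask]) + 0.1          # harder (less accurate) bins weigh more
w = w / w.sum()                            # renormalized over reliable bins only
assert w[~mask].eq(0).all()                # disabled bins can never be drawn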
@@ -424,6 +483,12 @@ class DavidLoader:
         cfg.block_weights = self.hf_config["block_weights"]

 class DavidAssessor(nn.Module):
+    """
+    CORRECTED: Properly handles GeoDavidCollective's nested multi-block output structure.
+
+    GeoDavidCollective returns: Dict[block_name, Dict[str, Tensor]]
+    Not a flat Dict[str, Tensor]!
+    """
     def __init__(self, gdc: GeoDavidCollective, pooling: str):
         super().__init__()
         self.gdc = gdc
@@ -435,56 +500,88 @@ class DavidAssessor(nn.Module):
     @torch.no_grad()
     def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                 ) -> Tuple[Dict[str,float], Dict[str,float], Dict[str,float]]:
+        """
+        Assess student features using David's geometric knowledge.
+
+        Returns:
+            e_t: Dict[block_name, timestep_error] - classification error per block
+            e_p: Dict[block_name, pattern_entropy] - normalized entropy per block
+            coh: Dict[block_name, coherence] - geometric coherence per block
+        """
+        # Pool spatial features
         Zs = self._pool(feats_student)
+
+        # Forward through GeoDavidCollective
+        # Returns: Dict[block_name, Dict[str, Tensor]]
         outs = self.gdc(Zs, t.float())
+
+        # Initialize output dicts
         e_t, e_p, coh = {}, {}, {}
-
-        ts_key = None
-        for key in ["timestep_logits", "logits_timestep", "timestep_head_logits"]:
-            if key in outs: ts_key = key; break

-
-        for key in ["pattern_logits", "logits_pattern", "pattern_head_logits"]:
-            if key in outs: pt_key = key; break
-
+        # Compute timestep bins for targets
         t_bins = (t // 10).to(next(self.gdc.parameters()).device)
-
-
-
+
+        # ====================================================================
+        # TIMESTEP ERROR - Per-block
+        # ====================================================================
+        for block_name, block_out in outs.items():
+            if 'timestep_logits' in block_out:
+                ts_logits = block_out['timestep_logits']
                 ce = F.cross_entropy(ts_logits, t_bins, reduction="mean")
-
-
-
+                e_t[block_name] = float(ce.item())
+
+        # If no timestep predictions, set all errors to 0
+        if not e_t:
+            for name in Zs.keys():
+                e_t[name] = 0.0
+
+        # ====================================================================
+        # PATTERN ENTROPY - Per-block
+        # ====================================================================
+        for block_name, block_out in outs.items():
+            if 'pattern_logits' in block_out:
+                pt_logits = block_out['pattern_logits']
+
+                # Compute normalized entropy
                 P = pt_logits.softmax(-1)
                 ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
-
-
+                norm_ent = ent / math.log(P.shape[-1])  # Normalize by max entropy
+
+                e_p[block_name] = float(norm_ent.item())
+
+        # If no pattern predictions, set all entropies to 0
+        if not e_p:
+            for name in Zs.keys():
+                e_p[name] = 0.0
+
+        # ====================================================================
+        # COHERENCE (from Cantor alphas)
+        # ====================================================================
         try:
             alphas = self.gdc.get_cantor_alphas()
+            # Alphas should be close to 0.5 for good coherence
+            # Map to coherence: 1.0 = perfect (alpha=0.5), lower = worse
+            for name, alpha in alphas.items():
+                # Coherence = 1 - 2*|alpha - 0.5|
+                # When alpha=0.5: coherence=1.0
+                # When alpha=0 or 1: coherence=0.0
+                coherence = 1.0 - 2.0 * abs(alpha - 0.5)
+                coh[name] = max(0.0, min(1.0, coherence))
         except Exception:
-
-
+            # Fallback: assume perfect coherence
+            for name in Zs.keys():
+                coh[name] = 1.0
+
+        # Ensure all input blocks have values (fill missing with block averages)
         for name in Zs.keys():
-
-
+            if name not in e_t:
+                # Use average of available blocks
+                e_t[name] = sum(e_t.values()) / max(len(e_t), 1) if e_t else 0.0
+            if name not in e_p:
+                e_p[name] = sum(e_p.values()) / max(len(e_p), 1) if e_p else 0.0
+            if name not in coh:
+                coh[name] = sum(coh.values()) / max(len(coh), 1) if coh else 1.0
+
         return e_t, e_p, coh

 class BlockPenaltyFusion:
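Both per-block signals are bounded by construction, which a quick standalone check makes visible. The 16-class pattern head below is a made-up size:

# Standalone check of the normalized-entropy and alpha -> coherence mappings
# used above (toy logits; 16 pattern classes is hypothetical).
import math
import torch

logits = torch.randn(4, 16)
P = logits.softmax(-1)
ent = -(P * P.clamp_min(1e-9).log()).sum(-1).mean()
norm_ent = (ent / math.log(16)).item()     # 0 = fully confident, 1 = uniform

for alpha in (0.5, 0.25, 1.0):
    coherence = 1.0 - 2.0 * abs(alpha - 0.5)
    print(f"alpha={alpha:.2f} -> coherence={coherence:.2f}")  # 1.00, 0.50, 0.00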
@@ -528,13 +625,15 @@ class FlowMatchDavidTrainer:
         print(f"   SD3 shift: {cfg.timestep_shift}")
         print(f"   Base jitter: ±{cfg.base_jitter}")
         print(f"   Adaptive chaos: {cfg.adaptive_chaos}")
+        print(f"   Reliability threshold: {cfg.reliability_threshold:.0%}")

         self.timestep_sampler = DavidWeightedTimestepSampler(
             num_timesteps=cfg.num_train_timesteps,
             num_bins=100,
             shift=cfg.timestep_shift if cfg.use_david_weights else 0.0,
             base_jitter=cfg.base_jitter,
-            adaptive_chaos=cfg.adaptive_chaos
+            adaptive_chaos=cfg.adaptive_chaos,
+            reliability_threshold=cfg.reliability_threshold
         )

         if cfg.use_david_weights:
@@ -558,95 +657,43 @@ class FlowMatchDavidTrainer:
         self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
         self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

-        # Load
-
-        if not emergency_path.exists():
-            print("\n🔍 Emergency checkpoint not found locally, checking HuggingFace...")
-            emergency_path = self._download_emergency_checkpoint()
-
-        if emergency_path and emergency_path.exists():
-            self._load_emergency_checkpoint(emergency_path)
-        elif cfg.continue_training:
+        # Load latest checkpoint from HuggingFace if continuing training
+        if cfg.continue_training:
             self._load_latest_from_hf()

         self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

-    def _download_emergency_checkpoint(self) -> Optional[Path]:
-        """Download emergency checkpoint from HuggingFace backup repo."""
-        emergency_repo = "AbstractPhil/sd15-flow-emergency-backup"
-        emergency_file = "EMERGENCY_SAVE_SUCCESS.pt"
-
-        try:
-            print(f"📥 Downloading emergency checkpoint from {emergency_repo}...")
-            local_path = hf_hub_download(
-                repo_id=emergency_repo,
-                filename=emergency_file,
-                repo_type="model",
-                cache_dir="./_emergency_cache"
-            )
-
-            target_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
-            shutil.copy(local_path, target_path)
-
-            size_mb = target_path.stat().st_size / 1e6
-            print(f"✅ Downloaded emergency checkpoint ({size_mb:.1f} MB)")
-            return target_path
-
-        except Exception as e:
-            print(f"⚠️ Could not download emergency checkpoint: {e}")
-            return None
-
-    def _load_emergency_checkpoint(self, path: Path):
-        """Load emergency checkpoint with student_unet structure."""
-        try:
-            print(f"\n🚨 Found emergency checkpoint: {path}")
-            checkpoint = torch.load(path, map_location='cpu')
-
-            if 'student_unet' in checkpoint:
-                print("📦 Loading emergency checkpoint format...")
-                missing, unexpected = self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
-                print(f"✓ Loaded student UNet")
-
-            if 'opt' in checkpoint:
-                self.opt.load_state_dict(checkpoint['opt'])
-                print("✓ Loaded optimizer state")
-
-            if 'sched' in checkpoint:
-                self.sched.load_state_dict(checkpoint['sched'])
-                print("✓ Loaded scheduler state")
-
-            if 'gstep' in checkpoint:
-                self.start_gstep = checkpoint['gstep']
-                self.start_epoch = self.start_gstep // len(self.loader)
-                print(f"✓ Resuming from global step {self.start_gstep} (epoch ~{self.start_epoch})")
-
-            print("✅ Emergency checkpoint loaded successfully!")
-
-        except Exception as e:
-            print(f"⚠️ Failed to load emergency checkpoint: {e}")
-
     def _load_latest_from_hf(self):
+        """Load the most recent checkpoint from HuggingFace repo."""
         if not self.cfg.hf_repo_id:
+            print("ℹ️ No HuggingFace repo specified, starting from scratch\n")
             return

         try:
             api = HfApi()
             print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

+            # List all files in the repo
             files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
+
+            # Find all epoch checkpoints (format: {run_name}_e{epoch}.pt)
             epochs = []
             for f in files:
-                if f.endswith('.pt'):
+                if f.endswith('.pt') and 'final' not in f.lower():
                     match = re.search(r'_e(\d+)\.pt$', f)
                     if match:
-
+                        epoch_num = int(match.group(1))
+                        epochs.append((epoch_num, f))

             if not epochs:
+                print("ℹ️ No previous checkpoints found, starting from scratch\n")
                 return

+            # Get the latest epoch
             latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
-            print(f"📥
+            print(f"📥 Found latest checkpoint: {latest_file} (epoch {latest_epoch})")

+            # Download checkpoint
             local_path = hf_hub_download(
                 repo_id=self.cfg.hf_repo_id,
                 filename=latest_file,
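The checkpoint scan resolves the newest epoch purely from filenames. A self-contained run of the same logic, with hypothetical names:

# How the `_e{epoch}.pt` scan above picks the newest checkpoint
# (filenames are hypothetical).
import re

files = ["run_e3.pt", "run_e12.pt", "run_final.pt", "notes.md"]
epochs = []
for f in files:
    if f.endswith(".pt") and "final" not in f.lower():
        m = re.search(r"_e(\d+)\.pt$", f)
        if m:
            epochs.append((int(m.group(1)), f))

latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
print(latest_epoch, latest_file)           # 12 run_e12.pt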
@@ -654,27 +701,58 @@ class FlowMatchDavidTrainer:
                 cache_dir=self.cfg.ckpt_dir
             )

+            # Load checkpoint
+            print(f"📦 Loading checkpoint...")
             checkpoint = torch.load(local_path, map_location='cpu')

-
-
-
-
+            # Load student state dict
+            if 'student' in checkpoint:
+                missing, unexpected = self.student.load_state_dict(checkpoint['student'], strict=False)
+                if missing:
+                    print(f"   ⚠️ Missing keys: {len(missing)}")
+                if unexpected:
+                    print(f"   ⚠️ Unexpected keys: {len(unexpected)}")
+                print(f"   ✓ Loaded student model")
+            else:
+                print(f"   ⚠️ Warning: 'student' key not found in checkpoint")
+                return

+            # Load optimizer state
             if 'opt' in checkpoint:
-
+                try:
+                    self.opt.load_state_dict(checkpoint['opt'])
+                    print("   ✓ Loaded optimizer state")
+                except Exception as e:
+                    print(f"   ⚠️ Failed to load optimizer state: {e}")
+
+            # Load scheduler state
             if 'sched' in checkpoint:
-
+                try:
+                    self.sched.load_state_dict(checkpoint['sched'])
+                    print("   ✓ Loaded scheduler state")
+                except Exception as e:
+                    print(f"   ⚠️ Failed to load scheduler state: {e}")

-
-
+            # Set starting epoch and global step
+            if 'gstep' in checkpoint:
+                self.start_gstep = checkpoint['gstep']
+                self.start_epoch = latest_epoch
+                print(f"   ✓ Resuming from epoch {self.start_epoch + 1}, global step {self.start_gstep}")
+            else:
+                # Fallback: estimate from epoch number
+                self.start_epoch = latest_epoch
+                self.start_gstep = latest_epoch * len(self.loader)
+                print(f"   ✓ Resuming from epoch {self.start_epoch + 1} (estimated step {self.start_gstep})")

-
+            # Cleanup
             del checkpoint
             torch.cuda.empty_cache()

+            print(f"✅ Successfully resumed from checkpoint!\n")
+
         except Exception as e:
-            print(f"⚠️ Failed to load
+            print(f"⚠️ Failed to load checkpoint: {e}")
+            print("   Starting training from scratch...\n")

     def _v_star(self, x_t, t, eps_hat):
         alpha, sigma = self.teacher.alpha_sigma(t)
@@ -692,7 +770,45 @@ class FlowMatchDavidTrainer:
         cfg = self.cfg
         gstep = self.start_gstep

+        # Test prompts for monitoring progress
+        test_prompts = [
+            "a castle at sunset",
+            "a mountain landscape with trees",
+            "a city street at night"
+        ]
+
         for ep in range(self.start_epoch, cfg.epochs):
+            # Sample before epoch to monitor progress
+            if ep > 0 or self.start_epoch > 0:  # Skip first ever epoch
+                print(f"\n🎨 Sampling test images before epoch {ep+1}...")
+                try:
+                    test_imgs = self.sample(test_prompts, steps=30, guidance=7.5)
+
+                    # Save individual images
+                    sample_dir = Path(cfg.out_dir) / "samples"
+                    sample_dir.mkdir(exist_ok=True, parents=True)
+
+                    for i, (img, prompt) in enumerate(zip(test_imgs, test_prompts)):
+                        # Convert to PIL
+                        img_np = ((img.cpu().permute(1,2,0).numpy() + 1) / 2 * 255).astype('uint8')
+                        from PIL import Image
+                        pil_img = Image.fromarray(img_np)
+
+                        # Save with epoch number
+                        safe_prompt = prompt.replace(" ", "_")[:30]
+                        img_path = sample_dir / f"e{ep}_p{i}_{safe_prompt}.png"
+                        pil_img.save(img_path)
+
+                        # Log to tensorboard
+                        self.writer.add_image(f"samples/{safe_prompt}",
+                                              (img + 1) / 2,  # Normalize to [0,1]
+                                              global_step=ep)
+
+                    print(f"✓ Saved {len(test_imgs)} test images to {sample_dir}")
+
+                except Exception as e:
+                    print(f"⚠️ Sampling failed: {e}")
+
             self.student.train()
             pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                         dynamic_ncols=True, leave=True, position=0)
@@ -778,6 +894,33 @@ class FlowMatchDavidTrainer:
             self._save(ep+1, gstep)

         self._save("final", gstep)
+
+        # Final comprehensive sampling
+        print("\n🎨 Generating final test samples...")
+        final_prompts = [
+            "a castle at sunset",
+            "a mountain landscape with trees",
+            "a city street at night",
+            "a portrait of a person",
+            "abstract geometric shapes"
+        ]
+        try:
+            final_imgs = self.sample(final_prompts, steps=30, guidance=7.5)
+
+            sample_dir = Path(cfg.out_dir) / "samples"
+            sample_dir.mkdir(exist_ok=True, parents=True)
+
+            for i, (img, prompt) in enumerate(zip(final_imgs, final_prompts)):
+                from PIL import Image
+                img_np = ((img.cpu().permute(1,2,0).numpy() + 1) / 2 * 255).astype('uint8')
+                pil_img = Image.fromarray(img_np)
+                safe_prompt = prompt.replace(" ", "_")[:30]
+                pil_img.save(sample_dir / f"final_{safe_prompt}.png")
+
+            print(f"✓ Saved {len(final_imgs)} final images to {sample_dir}")
+        except Exception as e:
+            print(f"⚠️ Final sampling failed: {e}")
+
         self.writer.close()

     def _save(self, tag, gstep):
@@ -820,11 +963,17 @@ class FlowMatchDavidTrainer:
     def sample(self, prompts: List[str], steps: Optional[int]=None, guidance: Optional[float]=None) -> torch.Tensor:
         steps = steps or self.cfg.sample_steps
         guidance = guidance if guidance is not None else self.cfg.guidance_scale
+
+        # Get model dtype from student
+        model_dtype = next(self.student.unet.parameters()).dtype
+
         cond_e = self.teacher.encode(prompts)
         uncond_e = self.teacher.encode([""]*len(prompts))
         sched = self.teacher.sched
         sched.set_timesteps(steps, device=self.device)
-
+
+        # Create latents with correct dtype
+        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=model_dtype)

         for t_scalar in sched.timesteps:
             t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
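The dtype fix in the last hunk matters whenever the student runs in fp16: PyTorch requires input and weight dtypes to match, so fp32 latents into an fp16 UNet raise an error on the first layer. A minimal sketch with a stand-in module:

# Why x_t takes its dtype from the UNet weights (nn.Conv2d here is only a
# stand-in for the student UNet, not the trainer's actual model).
import torch
import torch.nn as nn

unet = nn.Conv2d(4, 4, 3, padding=1).to(torch.float16)
model_dtype = next(unet.parameters()).dtype         # torch.float16
x_t = torch.randn(1, 4, 64, 64, dtype=model_dtype)  # latents match the weights
assert x_t.dtype == model_dtype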