from __future__ import annotations

import os, json, math, random, re, shutil
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem

from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download
from safetensors.torch import load_file

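# Overview of the pipeline defined below (descriptive summary of this file only):
#   1. BaseConfig                     - run / data / training / weighting hyperparameters.
#   2. DavidWeightedTimestepSampler   - samples timesteps biased toward bins David finds hard.
#   3. SymbolicPromptDataset, collate - synthetic prompts plus pre-sampled timesteps.
#   4. HookBank, spatial_pool         - capture and pool UNet block activations.
#   5. SD15Teacher, StudentUNet       - frozen epsilon-prediction teacher, trainable v-prediction student.
#   6. DavidLoader, DavidAssessor, BlockPenaltyFusion - geometric assessor and per-block loss weights.
#   7. FlowMatchDavidTrainer, main    - training loop, checkpointing, HF upload, sampling.
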
@dataclass
class BaseConfig:
    # Run / output
    run_name: str = "sd15_flowmatch_david_weighted"
    out_dir: str = "./runs/sd15_flowmatch_david_weighted"
    ckpt_dir: str = "./checkpoints_sd15_flow_david_weighted"
    save_every: int = 1

    # Data
    num_samples: int = 200_000
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # Teacher / feature capture
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0", "down_1", "down_2", "down_3", "mid", "up_0", "up_1", "up_2", "up_3")
    pooling: str = "mean"

    # Optimization
    epochs: int = 20
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    # Loss weights
    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # Knowledge distillation
    use_kd: bool = True
    kd_weight: float = 0.25

    # David (geometric assessor)
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None

    # Block-penalty fusion
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0

    block_weights: Optional[Dict[str, float]] = None

    # Timestep weighting
    use_timestep_weighting: bool = True
    use_david_weights: bool = True
    timestep_shift: float = 3.0
    base_jitter: int = 5
    adaptive_chaos: bool = True
    profile_samples: int = 2500
    reliability_threshold: float = 0.15

    # Diffusion
    num_train_timesteps: int = 1000

    # Sampling
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # HuggingFace
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True)
        Path(self.david_cache_dir).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            self.block_weights = {
                'down_0': 0.7, 'down_1': 0.9, 'down_2': 1.0, 'down_3': 1.1, 'mid': 1.2,
                'up_0': 1.1, 'up_1': 1.0, 'up_2': 0.9, 'up_3': 0.7,
            }

class DavidWeightedTimestepSampler:
    """
    Samples timesteps weighted by David's inherent difficulty + SD3 shift + adaptive chaos.
    FIXED: Properly handles nested GeoDavidCollective output structure.
    FIXED: Filters out unreliable bins (accuracy < threshold).
    """
    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5,
                 adaptive_chaos=True, reliability_threshold=0.15):
        self.num_timesteps = num_timesteps
        self.num_bins = num_bins
        self.shift = shift
        self.base_jitter = base_jitter
        self.adaptive_chaos = adaptive_chaos
        self.reliability_threshold = reliability_threshold

        self.difficulty_weights = None
        self.pattern_difficulty = None

    def _apply_shift(self, t: float) -> float:
        """Apply SD3-style timestep shift (operates on normalized t ∈ [0,1])."""
        if self.shift <= 0:
            return t
        return self.shift * t / (1.0 + (self.shift - 1.0) * t)

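    # Worked example of the shift above (illustrative only, using the default shift=3.0):
    #   t=0.25 -> 3*0.25 / (1 + 2*0.25) = 0.50
    #   t=0.50 -> 3*0.50 / (1 + 2*0.50) = 0.75
    #   t=0.90 -> 3*0.90 / (1 + 2*0.90) ≈ 0.96
    # i.e. normalized timesteps are pushed toward the high-noise end, as in SD3.
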
    def compute_difficulty_from_david(self, david, teacher, device, num_samples=500):
        """Profile David's confusion patterns to create difficulty map."""
        print("🔍 Profiling David's timestep & pattern difficulty...")

        david.eval()
        teacher.eval()

        correct_per_bin = torch.zeros(self.num_bins)
        total_per_bin = torch.zeros(self.num_bins)
        entropy_per_bin = torch.zeros(self.num_bins)
        entropy_count_per_bin = torch.zeros(self.num_bins)

        with torch.no_grad():
            for _ in tqdm(range(num_samples // 32), desc="Profiling David", leave=False):
                # Random latents, timesteps and text embeddings (bin width 10 assumes 100 bins over 1000 steps).
                x = torch.randn(32, 4, 64, 64, device=device, dtype=torch.float16)
                t = torch.randint(0, self.num_timesteps, (32,), device=device)
                t_bins = (t // 10)

                ehs = torch.randn(32, 77, 768, device=device, dtype=torch.float16)

                # Run the frozen teacher once to populate its hook bank with block features.
                teacher.hooks.clear()
                _ = teacher.unet(x, t, encoder_hidden_states=ehs)
                feats = {k: v.float() for k, v in teacher.hooks.bank.items()}

                pooled = {name: f.mean(dim=(2, 3)) for name, f in feats.items()}

                outputs = david(pooled, t.float())

                # Timestep classification accuracy, averaged over David's blocks.
                timestep_logits_list = []
                for block_name, block_out in outputs.items():
                    if 'timestep_logits' in block_out:
                        timestep_logits_list.append(block_out['timestep_logits'])

                if timestep_logits_list:
                    ts_logits = torch.stack(timestep_logits_list).mean(0)
                    preds = ts_logits.argmax(dim=-1)

                    for pred, true_bin in zip(preds, t_bins):
                        bin_idx = true_bin.item()
                        correct_per_bin[bin_idx] += (pred == true_bin).float().item()
                        total_per_bin[bin_idx] += 1

                # Pattern entropy per bin, averaged over David's blocks.
                pattern_logits_list = []
                for block_name, block_out in outputs.items():
                    if 'pattern_logits' in block_out:
                        pattern_logits_list.append(block_out['pattern_logits'])

                if pattern_logits_list:
                    pt_logits = torch.stack(pattern_logits_list).mean(0)

                    P = pt_logits.softmax(-1)
                    ent = -(P * P.clamp_min(1e-9).log()).sum(-1)
                    norm_ent = ent / math.log(P.shape[-1])

                    for i, true_bin in enumerate(t_bins):
                        bin_idx = true_bin.item()
                        entropy_per_bin[bin_idx] += norm_ent[i].item()
                        entropy_count_per_bin[bin_idx] += 1

        accuracy_per_bin = correct_per_bin / (total_per_bin.clamp(min=1))

        # Disable bins where David is too unreliable to guide sampling.
        reliable_mask = accuracy_per_bin >= self.reliability_threshold
        num_reliable = reliable_mask.sum().item()
        num_disabled = self.num_bins - num_reliable

        print(f"\n🎯 Reliability Analysis:")
        print(f" Threshold: {self.reliability_threshold:.0%}")
        print(f" Reliable bins: {num_reliable}/{self.num_bins}")
        print(f" Disabled bins: {num_disabled}/{self.num_bins}")

        if num_disabled > 0:
            disabled_bins = torch.where(~reliable_mask)[0].tolist()
            disabled_accs = [accuracy_per_bin[i].item() for i in disabled_bins]
            print(f" Disabled: {disabled_bins[:10]}{'...' if len(disabled_bins) > 10 else ''}")
            print(f" (accuracies: {[f'{a:.1%}' for a in disabled_accs[:10]]})")

        if num_reliable == 0:
            print("\n⚠️ WARNING: No reliable bins found! Falling back to uniform sampling.")
            self.difficulty_weights = torch.ones(self.num_bins) / self.num_bins
            self.pattern_difficulty = torch.ones(self.num_bins) * 0.5
            return self.difficulty_weights

        # Difficulty = inaccuracy (plus a small floor) on reliable bins; disabled bins get zero mass.
        timestep_difficulty = torch.zeros(self.num_bins)
        timestep_difficulty[reliable_mask] = (1.0 - accuracy_per_bin[reliable_mask]) + 0.1
        timestep_difficulty[~reliable_mask] = 0.0

        self.difficulty_weights = timestep_difficulty / timestep_difficulty.sum()

        # Pattern difficulty drives the adaptive jitter; disabled bins get a neutral 0.5.
        self.pattern_difficulty = entropy_per_bin / (entropy_count_per_bin.clamp(min=1))
        self.pattern_difficulty = self.pattern_difficulty.clamp(min=0.1, max=1.0)
        self.pattern_difficulty[~reliable_mask] = 0.5

        print(f"\n✓ David difficulty map computed (filtered):")
        print(f" Avg timestep accuracy (all bins): {accuracy_per_bin.mean():.2%}")
        print(f" Avg timestep accuracy (reliable): {accuracy_per_bin[reliable_mask].mean():.2%}")

        reliable_indices = torch.where(reliable_mask)[0]
        if len(reliable_indices) > 0:
            hardest_idx = reliable_indices[accuracy_per_bin[reliable_mask].argmin()].item()
            easiest_idx = reliable_indices[accuracy_per_bin[reliable_mask].argmax()].item()

            print(f" Hardest reliable bin: {hardest_idx} ({accuracy_per_bin[hardest_idx]:.2%} acc)")
            print(f" Easiest reliable bin: {easiest_idx} ({accuracy_per_bin[easiest_idx]:.2%} acc)")

        print(f" Avg pattern entropy (reliable): {self.pattern_difficulty[reliable_mask].mean():.3f}")

        top_weights, top_bins = self.difficulty_weights.topk(10)
        print(f"\n📊 Top 10 sampled bins (by difficulty weight):")
        for i, (bin_idx, weight) in enumerate(zip(top_bins.tolist(), top_weights.tolist())):
            acc = accuracy_per_bin[bin_idx].item()
            print(f" {i+1}. Bin {bin_idx:2d}: weight={weight:.3f} (acc={acc:.1%})")

        return self.difficulty_weights

    def sample(self, batch_size: int) -> List[int]:
        """Sample timesteps with David weighting + shift + adaptive chaos."""
        if self.difficulty_weights is None:
            # Not profiled yet: fall back to uniform sampling.
            return [random.randint(0, self.num_timesteps - 1) for _ in range(batch_size)]

        timesteps = []
        for _ in range(batch_size):
            # Draw a bin proportionally to its difficulty weight.
            bin_idx = torch.multinomial(self.difficulty_weights, 1).item()

            # Map the bin to its center timestep, normalized to [0, 1].
            bin_center_raw = bin_idx * (self.num_timesteps // self.num_bins) + (self.num_timesteps // self.num_bins) // 2
            t_normalized = bin_center_raw / self.num_timesteps

            # SD3-style shift toward the high-noise end.
            t_shifted = self._apply_shift(t_normalized)

            # Jitter: wider for bins where David's pattern entropy is high.
            if self.adaptive_chaos:
                chaos_scale = self.pattern_difficulty[bin_idx].item()
                jitter = int(self.base_jitter * (0.5 + chaos_scale))
            else:
                jitter = self.base_jitter

            t_raw = int(t_shifted * self.num_timesteps)
            t_raw += random.randint(-jitter, jitter)
            t_raw = max(0, min(self.num_timesteps - 1, t_raw))

            timesteps.append(t_raw)

        return timesteps

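# Minimal usage sketch for the sampler above (it mirrors how FlowMatchDavidTrainer wires it up
# further down; `david`, `teacher` and `device` are assumed to already exist):
#
#   sampler = DavidWeightedTimestepSampler(num_timesteps=1000, num_bins=100, shift=3.0)
#   sampler.compute_difficulty_from_david(david=david, teacher=teacher, device=device, num_samples=2500)
#   ts = sampler.sample(batch_size=32)   # list of 32 ints in [0, 999], biased toward hard bins
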
class SymbolicPromptDataset(Dataset):
    def __init__(self, n: int, seed: int = 42, timestep_sampler=None):
        self.n = n
        self.timestep_sampler = timestep_sampler
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self): return self.n

    def __getitem__(self, idx):
        r = self.sys.synthesize(complexity=random.choice([1, 2, 3, 4, 5]))
        prompt = r['text']

        if self.timestep_sampler:
            t = self.timestep_sampler.sample(1)[0]
        else:
            t = random.randint(0, 999)

        return {"prompt": prompt, "t": t}


def collate(batch: List[dict]):
    prompts = [b["prompt"] for b in batch]
    t = torch.tensor([b["t"] for b in batch], dtype=torch.long)
    t_bins = t // 10
    return {"prompts": prompts, "t": t, "t_bins": t_bins}

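# Example of one collated batch (shapes for batch_size=32; illustrative only):
#   {"prompts": [str] * 32, "t": LongTensor[32], "t_bins": LongTensor[32]}   # t_bins = t // 10
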
class HookBank:
    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def mk(name):
            def h(m, i, o):
                out = o[0] if isinstance(o, (tuple, list)) else o
                self.bank[name] = out
            return h

        for i, blk in enumerate(unet.down_blocks):
            nm = f"down_{i}"
            if nm in self.active: self.hooks.append(blk.register_forward_hook(mk(nm)))
        if "mid" in self.active:
            self.hooks.append(unet.mid_block.register_forward_hook(mk("mid")))
        for i, blk in enumerate(unet.up_blocks):
            nm = f"up_{i}"
            if nm in self.active: self.hooks.append(blk.register_forward_hook(mk(nm)))

    def clear(self): self.bank.clear()

    def close(self):
        for h in self.hooks: h.remove()
        self.hooks.clear()


def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    if policy == "mean": return x.mean(dim=(2, 3))
    if policy == "max": return x.amax(dim=(2, 3))
    if policy == "adaptive":
        return x.mean(dim=(2, 3)) if (name.startswith("down") or name == "mid") else x.amax(dim=(2, 3))
    raise ValueError(f"Unknown pooling: {policy}")

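# Shapes, for reference: HookBank stores one (B, C, H, W) activation per active block, and
# spatial_pool reduces it to (B, C) — e.g. a (32, 1280, 8, 8) mid-block feature becomes (32, 1280).
# (The 1280-channel / 8x8 figures are the usual SD1.5 mid-block sizes for 64x64 latents; stated as an assumption.)
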
class SD15Teacher(nn.Module):
    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        self.pipe = StableDiffusionPipeline.from_pretrained(
            cfg.model_id, torch_dtype=torch.float16, safety_checker=None
        ).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        self.sched = DDPMScheduler(num_train_timesteps=cfg.num_train_timesteps)
        self.device = device
        for p in self.parameters(): p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        tok = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1, 1, 1, 1).float()
        sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
        return alpha, sigma

class StudentUNet(nn.Module):
    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str, ...], use_local_heads: bool):
        super().__init__()
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        # Note: heads are created lazily on the first forward pass, so they are not covered
        # by an optimizer that was constructed before that pass.
        if not self.use_local_heads: return
        if len(self.local_heads) == len(feats): return

        target_dtype = next(self.unet.parameters()).dtype

        for name, f in feats.items():
            c = f.shape[1]
            if name not in self.local_heads:
                head = nn.Conv2d(c, 4, kernel_size=1)
                head = head.to(dtype=target_dtype, device=f.device)
                self.local_heads[name] = head

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v for k, v in self.hooks.bank.items()}
        self._ensure_heads(feats)
        return v_hat, feats

class DavidLoader:
    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir,
                                          local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r") as f:
            self.hf_config = json.load(f)

        self.gdc = GeoDavidCollective(
            block_configs=self.hf_config["block_configs"],
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            block_weights=self.hf_config.get("block_weights", {k: 1.0 for k in self.hf_config["block_configs"].keys()}),
            loss_config=self.hf_config.get("loss_config", {})
        ).to(device).eval()

        state = load_file(self.weights_path)
        self.gdc.load_state_dict(state, strict=False)
        for p in self.gdc.parameters(): p.requires_grad_(False)

        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")

        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]

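# The loader above expects the David repo to carry at least these entries in config.json
# (inferred from the accesses in DavidLoader; the exact schema is defined by GeoDavidCollective):
#   block_configs, num_timestep_bins, num_patterns_per_bin, and optionally block_weights, loss_config,
# plus a model.safetensors file holding the collective's weights.
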
class DavidAssessor(nn.Module):
    """
    CORRECTED: Properly handles GeoDavidCollective's nested multi-block output structure.

    GeoDavidCollective returns: Dict[block_name, Dict[str, Tensor]]
    Not a flat Dict[str, Tensor]!
    """
    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {k: spatial_pool(v, k, self.pooling) for k, v in feats.items()}

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
        """
        Assess student features using David's geometric knowledge.

        Returns:
            e_t: Dict[block_name, timestep_error] - classification error per block
            e_p: Dict[block_name, pattern_entropy] - normalized entropy per block
            coh: Dict[block_name, coherence] - geometric coherence per block
        """
        Zs = self._pool(feats_student)

        outs = self.gdc(Zs, t.float())

        e_t, e_p, coh = {}, {}, {}

        # Ground-truth timestep bins (width 10, matching the 100-bin collective).
        t_bins = (t // 10).to(next(self.gdc.parameters()).device)

        # Timestep-classification error per block.
        for block_name, block_out in outs.items():
            if 'timestep_logits' in block_out:
                ts_logits = block_out['timestep_logits']
                ce = F.cross_entropy(ts_logits, t_bins, reduction="mean")
                e_t[block_name] = float(ce.item())

        if not e_t:
            for name in Zs.keys():
                e_t[name] = 0.0

        # Normalized pattern entropy per block.
        for block_name, block_out in outs.items():
            if 'pattern_logits' in block_out:
                pt_logits = block_out['pattern_logits']

                P = pt_logits.softmax(-1)
                ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
                norm_ent = ent / math.log(P.shape[-1])

                e_p[block_name] = float(norm_ent.item())

        if not e_p:
            for name in Zs.keys():
                e_p[name] = 0.0

        # Geometric coherence from David's Cantor alphas: 1.0 at alpha=0.5, 0.0 at the extremes.
        try:
            alphas = self.gdc.get_cantor_alphas()

            for name, alpha in alphas.items():
                coherence = 1.0 - 2.0 * abs(alpha - 0.5)
                coh[name] = max(0.0, min(1.0, coherence))
        except Exception:
            for name in Zs.keys():
                coh[name] = 1.0

        # Fill any blocks the collective did not report with the current averages.
        for name in Zs.keys():
            if name not in e_t:
                e_t[name] = sum(e_t.values()) / max(len(e_t), 1) if e_t else 0.0
            if name not in e_p:
                e_p[name] = sum(e_p.values()) / max(len(e_p), 1) if e_p else 0.0
            if name not in coh:
                coh[name] = sum(coh.values()) / max(len(coh), 1) if coh else 1.0

        return e_t, e_p, coh

class BlockPenaltyFusion:
    def __init__(self, cfg: BaseConfig): self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float], coh: Dict[str, float]) -> Dict[str, float]:
        lam = {}
        for name, base in self.cfg.block_weights.items():
            val = base * (1.0
                          + self.cfg.alpha_timestep * float(e_t.get(name, 0.0))
                          + self.cfg.beta_pattern * float(e_p.get(name, 0.0))
                          + self.cfg.delta_incoherence * (1.0 - float(coh.get(name, 1.0))))
            lam[name] = float(max(self.cfg.lambda_min, min(self.cfg.lambda_max, val)))
        return lam

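# Worked example for the fusion rule above, using the default config
# (alpha_timestep=0.5, beta_pattern=0.25, delta_incoherence=0.25, lambda range [0.5, 3.0]);
# the e_t / e_p / coh values are made up purely for illustration:
#   mid block: base=1.2, e_t=1.0, e_p=0.5, coh=0.8
#   val = 1.2 * (1 + 0.5*1.0 + 0.25*0.5 + 0.25*(1 - 0.8)) = 1.2 * 1.675 = 2.01  -> inside [0.5, 3.0], so lambda=2.01
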

class FlowMatchDavidTrainer:
    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # David (frozen geometric assessor) and the penalty fusion.
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Frozen SD1.5 teacher.
        self.teacher = SD15Teacher(cfg, device).eval()

        # Optional David-weighted timestep sampler.
        self.timestep_sampler = None
        if cfg.use_timestep_weighting:
            print("\n" + "="*70)
            print("🎯 ADAPTIVE TIMESTEP SAMPLING ENABLED")
            print(f" David weighting: {cfg.use_david_weights}")
            print(f" SD3 shift: {cfg.timestep_shift}")
            print(f" Base jitter: ±{cfg.base_jitter}")
            print(f" Adaptive chaos: {cfg.adaptive_chaos}")
            print(f" Reliability threshold: {cfg.reliability_threshold:.0%}")

            self.timestep_sampler = DavidWeightedTimestepSampler(
                num_timesteps=cfg.num_train_timesteps,
                num_bins=100,
                shift=cfg.timestep_shift if cfg.use_david_weights else 0.0,
                base_jitter=cfg.base_jitter,
                adaptive_chaos=cfg.adaptive_chaos,
                reliability_threshold=cfg.reliability_threshold
            )

            if cfg.use_david_weights:
                self.timestep_sampler.compute_difficulty_from_david(
                    david=self.david,
                    teacher=self.teacher,
                    device=device,
                    num_samples=cfg.profile_samples
                )
            print("="*70 + "\n")

        # Data.
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed, self.timestep_sampler)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Student, optimizer, LR schedule, AMP scaler.
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        if cfg.continue_training:
            self._load_latest_from_hf()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    def _load_latest_from_hf(self):
        """Load the most recent checkpoint from HuggingFace repo."""
        if not self.cfg.hf_repo_id:
            print("ℹ️ No HuggingFace repo specified, starting from scratch\n")
            return

        try:
            api = HfApi()
            print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")

            # Collect epoch-numbered checkpoints (e.g. "..._e7.pt"), skipping any "final" file.
            epochs = []
            for f in files:
                if f.endswith('.pt') and 'final' not in f.lower():
                    match = re.search(r'_e(\d+)\.pt$', f)
                    if match:
                        epoch_num = int(match.group(1))
                        epochs.append((epoch_num, f))

            if not epochs:
                print("ℹ️ No previous checkpoints found, starting from scratch\n")
                return

            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"📥 Found latest checkpoint: {latest_file} (epoch {latest_epoch})")

            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir
            )

            print("📦 Loading checkpoint...")
            checkpoint = torch.load(local_path, map_location='cpu')

            if 'student' in checkpoint:
                missing, unexpected = self.student.load_state_dict(checkpoint['student'], strict=False)
                if missing:
                    print(f" ⚠️ Missing keys: {len(missing)}")
                if unexpected:
                    print(f" ⚠️ Unexpected keys: {len(unexpected)}")
                print(" ✓ Loaded student model")
            else:
                print(" ⚠️ Warning: 'student' key not found in checkpoint")
                return

            if 'opt' in checkpoint:
                try:
                    self.opt.load_state_dict(checkpoint['opt'])
                    print(" ✓ Loaded optimizer state")
                except Exception as e:
                    print(f" ⚠️ Failed to load optimizer state: {e}")

            if 'sched' in checkpoint:
                try:
                    self.sched.load_state_dict(checkpoint['sched'])
                    print(" ✓ Loaded scheduler state")
                except Exception as e:
                    print(f" ⚠️ Failed to load scheduler state: {e}")

            if 'gstep' in checkpoint:
                self.start_gstep = checkpoint['gstep']
                self.start_epoch = latest_epoch
                print(f" ✓ Resuming from epoch {self.start_epoch + 1}, global step {self.start_gstep}")
            else:
                self.start_epoch = latest_epoch
                self.start_gstep = latest_epoch * len(self.loader)
                print(f" ✓ Resuming from epoch {self.start_epoch + 1} (estimated step {self.start_gstep})")

            del checkpoint
            torch.cuda.empty_cache()

            print("✅ Successfully resumed from checkpoint!\n")

        except Exception as e:
            print(f"⚠️ Failed to load checkpoint: {e}")
            print(" Starting training from scratch...\n")

    def _v_star(self, x_t, t, eps_hat):
        # Convert the teacher's epsilon prediction into a v-prediction target:
        # recover x0_hat from (x_t, eps_hat), then v* = alpha * eps_hat - sigma * x0_hat.
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        s = F.normalize(s, dim=-1); t = F.normalize(t, dim=-1)
        return 1.0 - (s * t).sum(-1).mean()

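    # Derivation behind _v_star above (standard v-parameterization, stated for reference):
    #   forward process:  x_t = alpha_t * x_0 + sigma_t * eps,  with alpha_t = sqrt(abar_t), sigma_t = sqrt(1 - abar_t)
    #   invert for x_0:   x0_hat = (x_t - sigma_t * eps_hat) / alpha_t
    #   v-target:         v* = alpha_t * eps_hat - sigma_t * x0_hat
    # so the student is regressed onto the teacher-implied v rather than onto eps directly.
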

    def train(self):
        cfg = self.cfg
        gstep = self.start_gstep

        # Fixed prompts used for the per-epoch sanity samples.
        test_prompts = [
            "a castle at sunset",
            "a mountain landscape with trees",
            "a city street at night"
        ]

        for ep in range(self.start_epoch, cfg.epochs):
            # Sample a few test images before every epoch except the very first cold start.
            if ep > 0 or self.start_epoch > 0:
                print(f"\n🎨 Sampling test images before epoch {ep+1}...")
                try:
                    test_imgs = self.sample(test_prompts, steps=30, guidance=7.5)

                    sample_dir = Path(cfg.out_dir) / "samples"
                    sample_dir.mkdir(exist_ok=True, parents=True)

                    for i, (img, prompt) in enumerate(zip(test_imgs, test_prompts)):
                        img_np = ((img.cpu().permute(1, 2, 0).numpy() + 1) / 2 * 255).astype('uint8')
                        from PIL import Image
                        pil_img = Image.fromarray(img_np)

                        safe_prompt = prompt.replace(" ", "_")[:30]
                        img_path = sample_dir / f"e{ep}_p{i}_{safe_prompt}.png"
                        pil_img.save(img_path)

                        self.writer.add_image(f"samples/{safe_prompt}",
                                              (img + 1) / 2,
                                              global_step=ep)

                    print(f"✓ Saved {len(test_imgs)} test images to {sample_dir}")

                except Exception as e:
                    print(f"⚠️ Sampling failed: {e}")

            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                        dynamic_ncols=True, leave=True, position=0)
            acc = {"L": 0.0, "Lf": 0.0, "Lb": 0.0}

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)

                x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)

                # Frozen teacher: epsilon prediction, block features, and the v-target.
                with torch.no_grad():
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                    v_star = self._v_star(x_t, t, eps_hat)

                with torch.cuda.amp.autocast(enabled=cfg.amp):
                    v_hat, s_feats_spatial = self.student(x_t, t, ehs)
                    L_flow = F.mse_loss(v_hat, v_star)

                    # David-derived per-block penalties.
                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)

                    L_blocks = torch.zeros((), device=self.device)
                    for name, s_feat in s_feats_spatial.items():
                        L_kd = torch.zeros((), device=self.device)
                        if cfg.use_kd:
                            s_pool = spatial_pool(s_feat, name, cfg.pooling)
                            t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
                            L_kd = self._kd_cos(s_pool, t_pool)

                        L_lf = torch.zeros((), device=self.device)
                        if cfg.use_local_flow_heads and name in self.student.local_heads:
                            v_loc = self.student.local_heads[name](s_feat)
                            v_ds = self._down_like(v_star, v_loc)
                            L_lf = F.mse_loss(v_loc, v_ds)
                        L_blocks = L_blocks + lam.get(name, 1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)

                    L_total = cfg.global_flow_weight * L_flow + cfg.block_penalty_weight * L_blocks

                self.opt.zero_grad(set_to_none=True)
                if cfg.amp:
                    self.scaler.scale(L_total).backward()
                    self.scaler.unscale_(self.opt)  # unscale before clipping so the clip threshold applies to real gradients
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.scaler.step(self.opt); self.scaler.update()
                else:
                    L_total.backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.opt.step()
                self.sched.step(); gstep += 1

                acc["L"] += float(L_total.item())
                acc["Lf"] += float(L_flow.item())
                acc["Lb"] += float(L_blocks.item())

                if it % 50 == 0:
                    self.writer.add_scalar("train/total", float(L_total.item()), gstep)
                    self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
                    self.writer.add_scalar("train/blocks", float(L_blocks.item()), gstep)
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                if it % 10 == 0 or it == len(self.loader) - 1:
                    pbar.set_postfix({
                        "L": f"{float(L_total.item()):.4f}",
                        "Lf": f"{float(L_flow.item()):.4f}",
                        "Lb": f"{float(L_blocks.item()):.4f}"
                    }, refresh=False)

                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

            pbar.close()

            n = len(self.loader)
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L'] / n, ep + 1)
            self.writer.add_scalar("epoch/flow", acc['Lf'] / n, ep + 1)
            self.writer.add_scalar("epoch/blocks", acc['Lb'] / n, ep + 1)

            if (ep + 1) % cfg.save_every == 0:
                self._save(ep + 1, gstep)

        self._save("final", gstep)

        # Final qualitative check with a slightly larger prompt set.
        print("\n🎨 Generating final test samples...")
        final_prompts = [
            "a castle at sunset",
            "a mountain landscape with trees",
            "a city street at night",
            "a portrait of a person",
            "abstract geometric shapes"
        ]
        try:
            final_imgs = self.sample(final_prompts, steps=30, guidance=7.5)

            sample_dir = Path(cfg.out_dir) / "samples"
            sample_dir.mkdir(exist_ok=True, parents=True)

            for i, (img, prompt) in enumerate(zip(final_imgs, final_prompts)):
                from PIL import Image
                img_np = ((img.cpu().permute(1, 2, 0).numpy() + 1) / 2 * 255).astype('uint8')
                pil_img = Image.fromarray(img_np)
                safe_prompt = prompt.replace(" ", "_")[:30]
                pil_img.save(sample_dir / f"final_{safe_prompt}.png")

            print(f"✓ Saved {len(final_imgs)} final images to {sample_dir}")
        except Exception as e:
            print(f"⚠️ Final sampling failed: {e}")

        self.writer.close()

    def _save(self, tag, gstep):
        """Save checkpoint and upload to HuggingFace."""
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep
        }, pt_path)

        size_mb = pt_path.stat().st_size / 1e6
        print(f"✓ Saved checkpoint: {pt_path.name} ({size_mb:.1f} MB)")

        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
            self._upload_to_hf(pt_path, tag)

    def _upload_to_hf(self, path: Path, tag):
        """Upload checkpoint to HuggingFace."""
        try:
            api = HfApi()
            create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")

            print(f"📤 Uploading {path.name} to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}"
            )
            print(f"✅ Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")

        except Exception as e:
            print(f"⚠️ Upload failed: {e}")

    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int] = None, guidance: Optional[float] = None) -> torch.Tensor:
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale

        was_training = self.student.training
        self.student.eval()

        with torch.cuda.amp.autocast(enabled=self.cfg.amp):
            cond_e = self.teacher.encode(prompts)
            uncond_e = self.teacher.encode([""] * len(prompts))

            sched = self.teacher.sched
            sched.set_timesteps(steps, device=self.device)

            x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)

            for t_scalar in sched.timesteps:
                t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
                # Classifier-free guidance on the student's v-predictions.
                v_u, _ = self.student(x_t, t, uncond_e)
                v_c, _ = self.student(x_t, t, cond_e)
                v_hat = v_u + guidance * (v_c - v_u)

                # Convert the guided v back to epsilon for the DDPM scheduler step.
                alpha, sigma = self.teacher.alpha_sigma(t)
                denom = (alpha**2 + sigma**2)
                x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
                eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

                step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
                x_t = step.prev_sample

            # Decode latents with the SD1.5 VAE scaling factor.
            imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample

        if was_training:
            self.student.train()

        return imgs.clamp(-1, 1)

def main():
    cfg = BaseConfig()
    print(json.dumps(asdict(cfg), indent=2))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("⚠️ A100 strongly recommended.")
    trainer = FlowMatchDavidTrainer(cfg, device=device)
    trainer.train()
    _ = trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Training complete.")


if __name__ == "__main__":
    main()