AbstractPhil committed on
Commit a7aafe6 · verified · 1 Parent(s): ee10b5a

Update app.py

Files changed (1)
  1. app.py +383 -1102
app.py CHANGED
@@ -8,7 +8,7 @@ Supports Illustrious XL, standard SDXL, and SD1.5 variants.
8
 
9
  Lyra VAE Versions:
10
  - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra
11
- - v2: SDXL/Illustrious (768 CLIP-L + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2
12
  """
13
 
14
  import os
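For reference, the two Lyra variants fuse different embedding streams; a minimal summary of the modality dimensions, taken from the config defaults used later in this file (illustrative constants, not part of the commit):

LYRA_V1_DIMS = {"clip": 768, "t5": 768}                          # SD1.5: CLIP-L + T5-base
LYRA_V2_DIMS = {"clip_l": 768, "clip_g": 1280,
                "t5_xl_l": 2048, "t5_xl_g": 2048}                # SDXL/Illustrious: CLIP-L + CLIP-G + T5-XL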
@@ -25,9 +25,10 @@ from diffusers import (
25
  UNet2DConditionModel,
26
  AutoencoderKL,
27
  EulerDiscreteScheduler,
28
- EulerAncestralDiscreteScheduler
29
  )
30
- from diffusers.models import UNet2DConditionModel as DiffusersUNet
31
  from transformers import (
32
  CLIPTextModel,
33
  CLIPTokenizer,
@@ -37,74 +38,90 @@ from transformers import (
37
  )
38
  from huggingface_hub import hf_hub_download
39
 
40
- # Import Lyra VAE v1 (SD1.5) from geofractal
41
- try:
42
- from geofractal.model.vae.vae_lyra import MultiModalVAE as LyraV1, MultiModalVAEConfig as LyraV1Config
43
- LYRA_V1_AVAILABLE = True
44
- except ImportError:
45
- print("⚠️ Lyra VAE v1 not available")
46
- LYRA_V1_AVAILABLE = False
47
 
48
- # Import Lyra VAE v2 (SDXL/Illustrious) from geofractal
49
- try:
50
- from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as LyraV2, MultiModalVAEConfig as LyraV2Config
51
- LYRA_V2_AVAILABLE = True
52
- except ImportError:
53
- print("⚠️ Lyra VAE v2 not available")
54
- LYRA_V2_AVAILABLE = False
 
55
 
56
 
57
  # ============================================================================
58
  # CONSTANTS
59
  # ============================================================================
60
 
61
- # Model architectures
62
  ARCH_SD15 = "sd15"
63
  ARCH_SDXL = "sdxl"
64
 
65
- # ComfyUI key prefixes for SDXL single-file checkpoints
66
- COMFYUI_UNET_PREFIX = "model.diffusion_model."
67
- COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer."
68
- COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model."
69
- COMFYUI_VAE_PREFIX = "first_stage_model."
 
 
70
 
71
 
72
  # ============================================================================
73
- # MODEL LOADING UTILITIES
74
  # ============================================================================
75
 
76
- def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
77
- """Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint."""
78
-
79
- components = {
80
- "unet": {},
81
- "clip_l": {},
82
- "clip_g": {},
83
- "vae": {}
84
- }
85
-
86
- for key, value in state_dict.items():
87
- if key.startswith(COMFYUI_UNET_PREFIX):
88
- new_key = key[len(COMFYUI_UNET_PREFIX):]
89
- components["unet"][new_key] = value
90
- elif key.startswith(COMFYUI_CLIP_L_PREFIX):
91
- new_key = key[len(COMFYUI_CLIP_L_PREFIX):]
92
- components["clip_l"][new_key] = value
93
- elif key.startswith(COMFYUI_CLIP_G_PREFIX):
94
- new_key = key[len(COMFYUI_CLIP_G_PREFIX):]
95
- components["clip_g"][new_key] = value
96
- elif key.startswith(COMFYUI_VAE_PREFIX):
97
- new_key = key[len(COMFYUI_VAE_PREFIX):]
98
- components["vae"][new_key] = value
99
-
100
- print(f" Extracted components:")
101
- print(f" UNet: {len(components['unet'])} keys")
102
- print(f" CLIP-L: {len(components['clip_l'])} keys")
103
- print(f" CLIP-G: {len(components['clip_g'])} keys")
104
- print(f" VAE: {len(components['vae'])} keys")
105
 
106
- return components
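A minimal usage sketch of the extraction above; the two checkpoint keys are representative examples, not keys taken from a specific file:

sd = {
    "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(1),
    "first_stage_model.encoder.conv_in.weight": torch.zeros(1),
}
parts = extract_comfyui_components(sd)
assert "input_blocks.0.0.weight" in parts["unet"]   # UNet prefix stripped
assert "encoder.conv_in.weight" in parts["vae"]     # VAE prefix stripped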
107
 
 
 
 
108
 
109
  def get_clip_hidden_state(
110
  model_output,
@@ -116,13 +133,180 @@ def get_clip_hidden_state(
116
  return model_output.last_hidden_state
117
 
118
  if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
119
- # hidden_states is tuple: (embedding, layer1, ..., layerN)
120
- # clip_skip=2 means penultimate layer = hidden_states[-2]
121
  return model_output.hidden_states[-clip_skip]
122
 
123
  return model_output.last_hidden_state
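A worked example of the indexing above (clip_model and input_ids are placeholders): with output_hidden_states=True a CLIP text encoder returns num_layers + 1 hidden states (the input embeddings plus one per transformer layer), so clip_skip=1 selects the final layer and clip_skip=2 the penultimate one.

out = clip_model(input_ids, output_hidden_states=True)   # placeholder encoder call
final_layer = out.hidden_states[-1]                       # clip_skip = 1
penultimate = out.hidden_states[-2]                       # clip_skip = 2 (recommended for Illustrious)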
124
 
125
 
126
  # ============================================================================
127
  # SDXL PIPELINE
128
  # ============================================================================
@@ -133,16 +317,15 @@ class SDXLFlowMatchingPipeline:
133
  def __init__(
134
  self,
135
  vae: AutoencoderKL,
136
- text_encoder: CLIPTextModel, # CLIP-L
137
- text_encoder_2: CLIPTextModelWithProjection, # CLIP-G
138
  tokenizer: CLIPTokenizer,
139
  tokenizer_2: CLIPTokenizer,
140
  unet: UNet2DConditionModel,
141
  scheduler,
142
  device: str = "cuda",
143
- t5_encoder: Optional[T5EncoderModel] = None,
144
- t5_tokenizer: Optional[T5Tokenizer] = None,
145
- lyra_model: Optional[any] = None,
146
  clip_skip: int = 1
147
  ):
148
  self.vae = vae
@@ -154,16 +337,31 @@ class SDXLFlowMatchingPipeline:
154
  self.scheduler = scheduler
155
  self.device = device
156
 
157
- # Lyra components
158
- self.t5_encoder = t5_encoder
159
- self.t5_tokenizer = t5_tokenizer
160
- self.lyra_model = lyra_model
161
 
162
  # Settings
163
  self.clip_skip = clip_skip
164
- self.vae_scale_factor = 0.13025 # SDXL VAE scaling
165
  self.arch = ARCH_SDXL
166
-
167
  def encode_prompt(
168
  self,
169
  prompt: str,
@@ -206,11 +404,8 @@ class SDXLFlowMatchingPipeline:
206
  output_hidden_states=output_hidden_states
207
  )
208
  prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
209
-
210
- # Get pooled output from CLIP-G
211
  pooled_prompt_embeds = clip_g_output.text_embeds
212
 
213
- # Concatenate CLIP-L and CLIP-G embeddings
214
  prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
215
 
216
  # Negative prompt
@@ -262,14 +457,8 @@ class SDXLFlowMatchingPipeline:
262
  t5_summary: str = "",
263
  lyra_strength: float = 0.3
264
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
265
- """Encode prompts using Lyra VAE v2 fusion (CLIP + T5).
266
 
267
- Uses cross-modal translation: encode T5 → decode to CLIP space,
268
- then blend with original CLIP embeddings.
269
-
270
- Args:
271
- lyra_strength: Blend factor (0.0 = pure CLIP, 1.0 = pure Lyra reconstruction)
272
- """
273
  if self.lyra_model is None or self.t5_encoder is None:
274
  raise ValueError("Lyra VAE components not initialized")
275
 
@@ -278,7 +467,7 @@ class SDXLFlowMatchingPipeline:
278
  prompt, negative_prompt, clip_skip
279
  )
280
 
281
- # Format T5 input with pilcrow separator (¶)
282
  SUMMARY_SEPARATOR = "¶"
283
  if t5_summary.strip():
284
  t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
@@ -298,18 +487,10 @@ class SDXLFlowMatchingPipeline:
298
  t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
299
 
300
  clip_l_dim = 768
301
- clip_g_dim = 1280
302
-
303
  clip_l_embeds = prompt_embeds[..., :clip_l_dim]
304
  clip_g_embeds = prompt_embeds[..., clip_l_dim:]
305
 
306
- # Debug: print input stats
307
- print(f"[Lyra Debug] CLIP-L input: shape={clip_l_embeds.shape}, mean={clip_l_embeds.mean():.4f}, std={clip_l_embeds.std():.4f}")
308
- print(f"[Lyra Debug] CLIP-G input: shape={clip_g_embeds.shape}, mean={clip_g_embeds.mean():.4f}, std={clip_g_embeds.std():.4f}")
309
- print(f"[Lyra Debug] T5 input: shape={t5_embeds.shape}, mean={t5_embeds.mean():.4f}, std={t5_embeds.std():.4f}")
310
-
311
  with torch.no_grad():
312
- # Full forward pass with all modalities (model requires all)
313
  modality_inputs = {
314
  'clip_l': clip_l_embeds.float(),
315
  'clip_g': clip_g_embeds.float(),
@@ -320,90 +501,30 @@ class SDXLFlowMatchingPipeline:
320
  modality_inputs,
321
  target_modalities=['clip_l', 'clip_g']
322
  )
323
- print(f"[Lyra Debug] Latent mu: shape={mu.shape}, mean={mu.mean():.4f}, std={mu.std():.4f}")
324
 
325
  lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
326
  lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
327
 
328
- print(f"[Lyra Debug] Lyra CLIP-L output: mean={lyra_clip_l.mean():.4f}, std={lyra_clip_l.std():.4f}")
329
- print(f"[Lyra Debug] Lyra CLIP-G output: mean={lyra_clip_g.mean():.4f}, std={lyra_clip_g.std():.4f}")
330
-
331
- # Check if reconstruction stats are wildly different from input
332
  clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
333
  clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
334
- print(f"[Lyra Debug] Std ratio CLIP-L: {clip_l_std_ratio:.4f}, CLIP-G: {clip_g_std_ratio:.4f}")
335
 
336
- # Normalize reconstructions to match input statistics if needed
337
  if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5:
338
- print("[Lyra Debug] Normalizing CLIP-L reconstruction to match input stats")
339
  lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8)
340
  lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean()
341
 
342
  if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5:
343
- print("[Lyra Debug] Normalizing CLIP-G reconstruction to match input stats")
344
  lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
345
  lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
346
-
347
 
348
- # Blend original CLIP with Lyra reconstruction
349
  fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
350
  fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
351
 
352
- print(f"[Lyra Debug] Final fused CLIP-L: mean={fused_clip_l.mean():.4f}, std={fused_clip_l.std():.4f}")
353
- print(f"[Lyra Debug] lyra_strength={lyra_strength}")
354
-
355
  prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
356
 
357
- # Process negative prompt (simpler - just use original CLIP for negative)
358
- if negative_prompt:
359
- # For negative, blend less aggressively
360
- neg_strength = lyra_strength
361
-
362
- t5_neg_prompt = f"{negative_prompt} {SUMMARY_SEPARATOR} {negative_prompt}"
363
- t5_inputs_neg = self.t5_tokenizer(
364
- t5_neg_prompt,
365
- max_length=512,
366
- padding='max_length',
367
- truncation=True,
368
- return_tensors='pt'
369
- ).to(self.device)
370
-
371
- with torch.no_grad():
372
- t5_embeds_neg = self.t5_encoder(**t5_inputs_neg).last_hidden_state
373
-
374
- neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
375
- neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
376
-
377
- # Full forward pass (model requires all modalities)
378
- modality_inputs_neg = {
379
- 'clip_l': neg_clip_l.float(),
380
- 'clip_g': neg_clip_g.float(),
381
- 't5_xl_l': t5_embeds_neg.float(),
382
- 't5_xl_g': t5_embeds_neg.float()
383
- }
384
- recon_neg, _, _, _ = self.lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])
385
-
386
- lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype)
387
- lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype)
388
-
389
- # Normalize if needed
390
- neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8)
391
- neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8)
392
- if neg_l_ratio > 2.0 or neg_l_ratio < 0.5:
393
- lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8)
394
- lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean()
395
- if neg_g_ratio > 2.0 or neg_g_ratio < 0.5:
396
- lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8)
397
- lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean()
398
-
399
- fused_neg_l = (1 - neg_strength) * neg_clip_l + neg_strength * lyra_neg_l
400
- fused_neg_g = (1 - neg_strength) * neg_clip_g + neg_strength * lyra_neg_g
401
-
402
- negative_prompt_embeds_fused = torch.cat([fused_neg_l, fused_neg_g], dim=-1)
403
- else:
404
- negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)
405
-
406
- return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled
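Condensed sketch of the fusion step above: the Lyra reconstruction is re-scaled whenever its statistics drift outside the 0.5-2.0 std-ratio band, then linearly blended by lyra_strength (0.0 = pure CLIP, 1.0 = pure Lyra reconstruction). The helper name is illustrative only:

def blend_with_lyra(clip_embeds: torch.Tensor, lyra_recon: torch.Tensor, strength: float) -> torch.Tensor:
    ratio = lyra_recon.std() / (clip_embeds.std() + 1e-8)
    if ratio > 2.0 or ratio < 0.5:                        # same guard band as above
        lyra_recon = (lyra_recon - lyra_recon.mean()) / (lyra_recon.std() + 1e-8)
        lyra_recon = lyra_recon * clip_embeds.std() + clip_embeds.mean()
    return (1 - strength) * clip_embeds + strength * lyra_recon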
407
 
408
  def _get_add_time_ids(
409
  self,
@@ -424,28 +545,24 @@ class SDXLFlowMatchingPipeline:
424
  negative_prompt: str = "",
425
  height: int = 1024,
426
  width: int = 1024,
427
- num_inference_steps: int = 20,
428
- guidance_scale: float = 7.5,
429
- shift: float = 0.0,
430
- use_flow_matching: bool = False,
431
- prediction_type: str = "epsilon",
432
  seed: Optional[int] = None,
433
  use_lyra: bool = False,
434
- clip_skip: int = 1,
435
  t5_summary: str = "",
436
  lyra_strength: float = 1.0,
437
  progress_callback=None
438
  ):
439
  """Generate image using SDXL architecture."""
440
 
441
- # Set seed
442
  if seed is not None:
443
  generator = torch.Generator(device=self.device).manual_seed(seed)
444
  else:
445
  generator = None
446
 
447
  # Encode prompts
448
- if use_lyra and self.lyra_model is not None:
449
  prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
450
  prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
451
  )
@@ -470,11 +587,9 @@ class SDXLFlowMatchingPipeline:
470
  self.scheduler.set_timesteps(num_inference_steps, device=self.device)
471
  timesteps = self.scheduler.timesteps
472
 
473
- # Scale initial latents
474
- if not use_flow_matching:
475
- latents = latents * self.scheduler.init_noise_sigma
476
 
477
- # Prepare added time embeddings for SDXL
478
  original_size = (height, width)
479
  target_size = (height, width)
480
  crops_coords_top_left = (0, 0)
@@ -482,29 +597,18 @@ class SDXLFlowMatchingPipeline:
482
  add_time_ids = self._get_add_time_ids(
483
  original_size, crops_coords_top_left, target_size, dtype=torch.float16
484
  )
485
- negative_add_time_ids = add_time_ids # Same for negative
486
 
487
  # Denoising loop
488
  for i, t in enumerate(timesteps):
489
  if progress_callback:
490
  progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
491
 
492
- # Expand for CFG
493
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
 
494
 
495
- # Flow matching scaling
496
- if use_flow_matching and shift > 0:
497
- sigma = t.float() / 1000.0
498
- sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
499
- scaling = torch.sqrt(1 + sigma_shifted ** 2)
500
- latent_model_input = latent_model_input / scaling
501
- else:
502
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
503
-
504
- # Prepare timestep
505
  timestep = t.expand(latent_model_input.shape[0])
506
 
507
- # Prepare added conditions
508
  if guidance_scale > 1.0:
509
  text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
510
  add_text_embeds = torch.cat([negative_pooled, pooled])
@@ -514,13 +618,11 @@ class SDXLFlowMatchingPipeline:
514
  add_text_embeds = pooled
515
  add_time_ids_input = add_time_ids
516
 
517
- # Prepare added cond kwargs for SDXL UNet
518
  added_cond_kwargs = {
519
  "text_embeds": add_text_embeds,
520
  "time_ids": add_time_ids_input
521
  }
522
 
523
- # Predict noise
524
  noise_pred = self.unet(
525
  latent_model_input,
526
  timestep,
@@ -529,28 +631,11 @@ class SDXLFlowMatchingPipeline:
529
  return_dict=False
530
  )[0]
531
 
532
- # CFG
533
  if guidance_scale > 1.0:
534
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
535
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
536
 
537
- # Step
538
- if use_flow_matching:
539
- sigma = t.float() / 1000.0
540
- sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
541
-
542
- if prediction_type == "v_prediction":
543
- v_pred = noise_pred
544
- alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
545
- sigma_t = sigma_shifted
546
- noise_pred = alpha_t * v_pred + sigma_t * latents
547
-
548
- dt = -1.0 / num_inference_steps
549
- latents = latents + dt * noise_pred
550
- else:
551
- latents = self.scheduler.step(
552
- noise_pred, t, latents, return_dict=False
553
- )[0]
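Condensed sketch of the flow-matching branch removed above, kept for reference: a shift-scheduled sigma, an optional v-prediction conversion, and a uniform Euler step toward sigma = 0 (function name is illustrative):

def flow_match_step(latents, noise_pred, t, shift, num_inference_steps, prediction_type="epsilon"):
    sigma = t.float() / 1000.0
    sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
    if prediction_type == "v_prediction":
        alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
        noise_pred = alpha_t * noise_pred + sigma_shifted * latents
    dt = -1.0 / num_inference_steps
    return latents + dt * noise_pred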
554
 
555
  # Decode
556
  latents = latents / self.vae_scale_factor
@@ -558,255 +643,6 @@ class SDXLFlowMatchingPipeline:
558
  with torch.no_grad():
559
  image = self.vae.decode(latents.to(self.vae.dtype)).sample
560
 
561
- # Convert to PIL
562
- image = (image / 2 + 0.5).clamp(0, 1)
563
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
564
- image = (image * 255).round().astype("uint8")
565
- image = Image.fromarray(image[0])
566
-
567
- return image
568
-
569
-
570
- # ============================================================================
571
- # SD1.5 PIPELINE (Original)
572
- # ============================================================================
573
-
574
- class SD15FlowMatchingPipeline:
575
- """Pipeline for SD1.5-based flow-matching inference."""
576
-
577
- def __init__(
578
- self,
579
- vae: AutoencoderKL,
580
- text_encoder: CLIPTextModel,
581
- tokenizer: CLIPTokenizer,
582
- unet: UNet2DConditionModel,
583
- scheduler,
584
- device: str = "cuda",
585
- t5_encoder: Optional[T5EncoderModel] = None,
586
- t5_tokenizer: Optional[T5Tokenizer] = None,
587
- lyra_model: Optional[any] = None
588
- ):
589
- self.vae = vae
590
- self.text_encoder = text_encoder
591
- self.tokenizer = tokenizer
592
- self.unet = unet
593
- self.scheduler = scheduler
594
- self.device = device
595
-
596
- self.t5_encoder = t5_encoder
597
- self.t5_tokenizer = t5_tokenizer
598
- self.lyra_model = lyra_model
599
-
600
- self.vae_scale_factor = 0.18215
601
- self.arch = ARCH_SD15
602
- self.is_lune_model = False
603
-
604
- def encode_prompt(self, prompt: str, negative_prompt: str = ""):
605
- """Encode text prompts to embeddings."""
606
- text_inputs = self.tokenizer(
607
- prompt,
608
- padding="max_length",
609
- max_length=self.tokenizer.model_max_length,
610
- truncation=True,
611
- return_tensors="pt",
612
- )
613
- text_input_ids = text_inputs.input_ids.to(self.device)
614
-
615
- with torch.no_grad():
616
- prompt_embeds = self.text_encoder(text_input_ids)[0]
617
-
618
- if negative_prompt:
619
- uncond_inputs = self.tokenizer(
620
- negative_prompt,
621
- padding="max_length",
622
- max_length=self.tokenizer.model_max_length,
623
- truncation=True,
624
- return_tensors="pt",
625
- )
626
- uncond_input_ids = uncond_inputs.input_ids.to(self.device)
627
-
628
- with torch.no_grad():
629
- negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0]
630
- else:
631
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
632
-
633
- return prompt_embeds, negative_prompt_embeds
634
-
635
- def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
636
- """Encode using Lyra VAE (CLIP + T5 fusion)."""
637
- if self.lyra_model is None or self.t5_encoder is None:
638
- raise ValueError("Lyra VAE components not initialized")
639
-
640
- # CLIP
641
- text_inputs = self.tokenizer(
642
- prompt,
643
- padding="max_length",
644
- max_length=self.tokenizer.model_max_length,
645
- truncation=True,
646
- return_tensors="pt",
647
- )
648
- text_input_ids = text_inputs.input_ids.to(self.device)
649
-
650
- with torch.no_grad():
651
- clip_embeds = self.text_encoder(text_input_ids)[0]
652
-
653
- # T5
654
- t5_inputs = self.t5_tokenizer(
655
- prompt,
656
- max_length=77,
657
- padding='max_length',
658
- truncation=True,
659
- return_tensors='pt'
660
- ).to(self.device)
661
-
662
- with torch.no_grad():
663
- t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
664
-
665
- # Fuse
666
- modality_inputs = {'clip': clip_embeds, 't5': t5_embeds}
667
-
668
- with torch.no_grad():
669
- reconstructions, mu, logvar = self.lyra_model(
670
- modality_inputs,
671
- target_modalities=['clip']
672
- )
673
- prompt_embeds = reconstructions['clip']
674
-
675
- # Negative
676
- if negative_prompt:
677
- uncond_inputs = self.tokenizer(
678
- negative_prompt,
679
- padding="max_length",
680
- max_length=self.tokenizer.model_max_length,
681
- truncation=True,
682
- return_tensors="pt",
683
- )
684
- uncond_input_ids = uncond_inputs.input_ids.to(self.device)
685
-
686
- with torch.no_grad():
687
- clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]
688
-
689
- t5_inputs_uncond = self.t5_tokenizer(
690
- negative_prompt,
691
- max_length=77,
692
- padding='max_length',
693
- truncation=True,
694
- return_tensors='pt'
695
- ).to(self.device)
696
-
697
- with torch.no_grad():
698
- t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state
699
-
700
- modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond}
701
-
702
- with torch.no_grad():
703
- reconstructions_uncond, _, _ = self.lyra_model(
704
- modality_inputs_uncond,
705
- target_modalities=['clip']
706
- )
707
- negative_prompt_embeds = reconstructions_uncond['clip']
708
- else:
709
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
710
-
711
- return prompt_embeds, negative_prompt_embeds
712
-
713
- @torch.no_grad()
714
- def __call__(
715
- self,
716
- prompt: str,
717
- negative_prompt: str = "",
718
- height: int = 512,
719
- width: int = 512,
720
- num_inference_steps: int = 20,
721
- guidance_scale: float = 7.5,
722
- shift: float = 2.5,
723
- use_flow_matching: bool = True,
724
- prediction_type: str = "epsilon",
725
- seed: Optional[int] = None,
726
- use_lyra: bool = False,
727
- clip_skip: int = 1, # Unused for SD1.5 but kept for API consistency
728
- progress_callback=None
729
- ):
730
- """Generate image."""
731
-
732
- if seed is not None:
733
- generator = torch.Generator(device=self.device).manual_seed(seed)
734
- else:
735
- generator = None
736
-
737
- if use_lyra and self.lyra_model is not None:
738
- prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt)
739
- else:
740
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt)
741
-
742
- latent_channels = 4
743
- latent_height = height // 8
744
- latent_width = width // 8
745
-
746
- latents = torch.randn(
747
- (1, latent_channels, latent_height, latent_width),
748
- generator=generator,
749
- device=self.device,
750
- dtype=torch.float32
751
- )
752
-
753
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
754
- timesteps = self.scheduler.timesteps
755
-
756
- if not use_flow_matching:
757
- latents = latents * self.scheduler.init_noise_sigma
758
-
759
- for i, t in enumerate(timesteps):
760
- if progress_callback:
761
- progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
762
-
763
- latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
764
-
765
- if use_flow_matching and shift > 0:
766
- sigma = t.float() / 1000.0
767
- sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
768
- scaling = torch.sqrt(1 + sigma_shifted ** 2)
769
- latent_model_input = latent_model_input / scaling
770
- else:
771
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
772
-
773
- timestep = t.expand(latent_model_input.shape[0])
774
- text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds
775
-
776
- noise_pred = self.unet(
777
- latent_model_input,
778
- timestep,
779
- encoder_hidden_states=text_embeds,
780
- return_dict=False
781
- )[0]
782
-
783
- if guidance_scale > 1.0:
784
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
785
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
786
-
787
- if use_flow_matching:
788
- sigma = t.float() / 1000.0
789
- sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
790
-
791
- if prediction_type == "v_prediction":
792
- v_pred = noise_pred
793
- alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
794
- sigma_t = sigma_shifted
795
- noise_pred = alpha_t * v_pred + sigma_t * latents
796
-
797
- dt = -1.0 / num_inference_steps
798
- latents = latents + dt * noise_pred
799
- else:
800
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
801
-
802
- latents = latents / self.vae_scale_factor
803
-
804
- if self.is_lune_model:
805
- latents = latents * 5.52
806
-
807
- with torch.no_grad():
808
- image = self.vae.decode(latents).sample
809
-
810
  image = (image / 2 + 0.5).clamp(0, 1)
811
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
812
  image = (image * 255).round().astype("uint8")
@@ -819,59 +655,26 @@ class SD15FlowMatchingPipeline:
819
  # MODEL LOADERS
820
  # ============================================================================
821
 
822
- def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
823
- """Load Lune checkpoint from .pt file."""
824
- print(f"📥 Downloading: {repo_id}/{filename}")
825
-
826
- checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
827
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
828
-
829
- print(f"🏗️ Initializing SD1.5 UNet...")
830
- unet = UNet2DConditionModel.from_pretrained(
831
- "runwayml/stable-diffusion-v1-5",
832
- subfolder="unet",
833
- torch_dtype=torch.float32
834
- )
835
-
836
- student_state_dict = checkpoint["student"]
837
- cleaned_dict = {}
838
- for key, value in student_state_dict.items():
839
- if key.startswith("unet."):
840
- cleaned_dict[key[5:]] = value
841
- else:
842
- cleaned_dict[key] = value
843
-
844
- unet.load_state_dict(cleaned_dict, strict=False)
845
-
846
- step = checkpoint.get("gstep", "unknown")
847
- print(f"✅ Loaded Lune from step {step}")
848
-
849
- return unet.to(device)
850
-
851
-
852
  def load_illustrious_xl(
853
- repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
854
  filename: str = "illustriousXL_v01.safetensors",
855
  device: str = "cuda"
856
  ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
857
- """Load Illustrious XL from single safetensors file using diffusers' single-file loader."""
858
  from diffusers import StableDiffusionXLPipeline
859
 
860
  print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
861
 
862
- # Download the checkpoint
863
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
864
  print(f"✓ Downloaded: {checkpoint_path}")
865
 
866
- # Use diffusers' built-in single-file loader which handles key remapping
867
- print("📦 Loading with StableDiffusionXLPipeline.from_single_file()...")
868
  pipe = StableDiffusionXLPipeline.from_single_file(
869
  checkpoint_path,
870
  torch_dtype=torch.float16,
871
  use_safetensors=True,
872
  )
873
 
874
- # Extract components
875
  unet = pipe.unet.to(device)
876
  vae = pipe.vae.to(device)
877
  text_encoder = pipe.text_encoder.to(device)
@@ -879,404 +682,72 @@ def load_illustrious_xl(
879
  tokenizer = pipe.tokenizer
880
  tokenizer_2 = pipe.tokenizer_2
881
 
882
- # Clean up the pipeline to free memory
883
  del pipe
884
  torch.cuda.empty_cache()
885
 
886
  print("✅ Illustrious XL loaded!")
887
- print(f" UNet params: {sum(p.numel() for p in unet.parameters()):,}")
888
- print(f" VAE params: {sum(p.numel() for p in vae.parameters()):,}")
889
 
890
  return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
891
 
892
 
893
- def load_sdxl_base(device: str = "cuda"):
894
- """Load standard SDXL base model."""
895
- print("📥 Loading SDXL Base 1.0...")
896
-
897
- unet = UNet2DConditionModel.from_pretrained(
898
- "stabilityai/stable-diffusion-xl-base-1.0",
899
- subfolder="unet",
900
- torch_dtype=torch.float16
901
- ).to(device)
902
-
903
- vae = AutoencoderKL.from_pretrained(
904
- "stabilityai/stable-diffusion-xl-base-1.0",
905
- subfolder="vae",
906
- torch_dtype=torch.float16
907
- ).to(device)
908
-
909
- text_encoder = CLIPTextModel.from_pretrained(
910
- "stabilityai/stable-diffusion-xl-base-1.0",
911
- subfolder="text_encoder",
912
- torch_dtype=torch.float16
913
- ).to(device)
914
-
915
- text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
916
- "stabilityai/stable-diffusion-xl-base-1.0",
917
- subfolder="text_encoder_2",
918
- torch_dtype=torch.float16
919
- ).to(device)
920
-
921
- tokenizer = CLIPTokenizer.from_pretrained(
922
- "stabilityai/stable-diffusion-xl-base-1.0",
923
- subfolder="tokenizer"
924
- )
925
-
926
- tokenizer_2 = CLIPTokenizer.from_pretrained(
927
- "stabilityai/stable-diffusion-xl-base-1.0",
928
- subfolder="tokenizer_2"
929
- )
930
-
931
- print("✅ SDXL Base loaded!")
932
-
933
- return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
934
-
935
-
936
- def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
937
- """Load Lyra VAE v1 (SD1.5 version) from HuggingFace."""
938
- if not LYRA_V1_AVAILABLE:
939
- print("⚠️ Lyra VAE v1 not available")
940
- return None
941
-
942
- print(f"🎵 Loading Lyra VAE v1 from {repo_id}...")
943
-
944
- try:
945
- # Try to download config.json first
946
- try:
947
- print(" 📥 Downloading config.json...")
948
- config_path = hf_hub_download(
949
- repo_id=repo_id,
950
- filename="config.json",
951
- repo_type="model"
952
- )
953
- with open(config_path, 'r') as f:
954
- config_dict = json.load(f)
955
- print(f" ✓ Config loaded: {config_dict.get('fusion_strategy', 'unknown')} fusion")
956
- except Exception:
957
- # Fallback to defaults if no config.json
958
- print(" ⚠️ No config.json found, using defaults")
959
- config_dict = {
960
- 'modality_dims': {"clip": 768, "t5": 768},
961
- 'latent_dim': 768,
962
- 'seq_len': 77,
963
- 'encoder_layers': 3,
964
- 'decoder_layers': 3,
965
- 'hidden_dim': 1024,
966
- 'dropout': 0.1,
967
- 'fusion_strategy': 'cantor',
968
- 'fusion_heads': 8,
969
- 'fusion_dropout': 0.1
970
- }
971
-
972
- # Download model weights
973
- print(" 📥 Downloading model weights...")
974
- try:
975
- checkpoint_path = hf_hub_download(
976
- repo_id=repo_id,
977
- filename="model.pt",
978
- repo_type="model"
979
- )
980
- except Exception:
981
- # Fallback to best_model.pt
982
- checkpoint_path = hf_hub_download(
983
- repo_id=repo_id,
984
- filename="best_model.pt",
985
- repo_type="model"
986
- )
987
-
988
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
989
-
990
- vae_config = LyraV1Config(
991
- modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}),
992
- latent_dim=config_dict.get('latent_dim', 768),
993
- seq_len=config_dict.get('seq_len', 77),
994
- encoder_layers=config_dict.get('encoder_layers', 3),
995
- decoder_layers=config_dict.get('decoder_layers', 3),
996
- hidden_dim=config_dict.get('hidden_dim', 1024),
997
- dropout=config_dict.get('dropout', 0.1),
998
- fusion_strategy=config_dict.get('fusion_strategy', 'cantor'),
999
- fusion_heads=config_dict.get('fusion_heads', 8),
1000
- fusion_dropout=config_dict.get('fusion_dropout', 0.1)
1001
- )
1002
-
1003
- lyra_model = LyraV1(vae_config)
1004
-
1005
- if 'model_state_dict' in checkpoint:
1006
- lyra_model.load_state_dict(checkpoint['model_state_dict'])
1007
- else:
1008
- lyra_model.load_state_dict(checkpoint)
1009
-
1010
- lyra_model.to(device)
1011
- lyra_model.eval()
1012
-
1013
- print(f"✅ Lyra VAE v1 loaded")
1014
- print(f" Fusion: {config_dict.get('fusion_strategy')}")
1015
- print(f" Latent dim: {config_dict.get('latent_dim')}")
1016
- if 'global_step' in checkpoint:
1017
- print(f" Step: {checkpoint['global_step']:,}")
1018
-
1019
- return lyra_model
1020
-
1021
- except Exception as e:
1022
- print(f"❌ Failed to load Lyra VAE v1: {e}")
1023
- import traceback
1024
- traceback.print_exc()
1025
- return None
1026
-
1027
-
1028
- def load_lyra_vae_xl(
1029
- repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
1030
- checkpoint_filename: str = None, # Auto-detect if None
1031
- device: str = "cuda"
1032
- ):
1033
- """Load Lyra VAE v2 (SDXL/Illustrious version) from HuggingFace."""
1034
- if not LYRA_V2_AVAILABLE:
1035
- print("⚠️ Lyra VAE v2 not available")
1036
- return None
1037
-
1038
- print(f"🎵 Loading Lyra VAE v2 from {repo_id}...")
1039
-
1040
- try:
1041
- from huggingface_hub import list_repo_files
1042
-
1043
- # Download config.json
1044
- print(" 📥 Downloading config.json...")
1045
- config_path = hf_hub_download(
1046
- repo_id=repo_id,
1047
- filename="config.json",
1048
- repo_type="model"
1049
- )
1050
-
1051
- with open(config_path, 'r') as f:
1052
- config_dict = json.load(f)
1053
-
1054
- print(f" ✓ Config: {config_dict.get('fusion_strategy', 'unknown')} fusion, latent_dim={config_dict.get('latent_dim')}")
1055
-
1056
- # Auto-detect checkpoint if not specified
1057
- if checkpoint_filename is None:
1058
- repo_files = list_repo_files(repo_id, repo_type="model")
1059
- checkpoint_files = [f for f in repo_files if f.endswith('.pt') or f.endswith('.safetensors')]
1060
- checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower() or 'model' in f.lower()]
1061
-
1062
- if not checkpoint_files:
1063
- raise FileNotFoundError(f"No checkpoint found in {repo_id}")
1064
-
1065
- # Prefer newest checkpoint (highest step number)
1066
- def extract_step(name):
1067
- import re
1068
- match = re.search(r'(\d+)\.pt', name)
1069
- return int(match.group(1)) if match else 0
1070
-
1071
- checkpoint_files.sort(key=extract_step, reverse=True)
1072
- checkpoint_filename = checkpoint_files[0]
1073
- print(f" ✓ Auto-selected checkpoint: {checkpoint_filename}")
1074
-
1075
- # Download checkpoint
1076
- print(f" 📥 Downloading {checkpoint_filename}...")
1077
- checkpoint_path = hf_hub_download(
1078
- repo_id=repo_id,
1079
- filename=checkpoint_filename,
1080
- repo_type="model"
1081
- )
1082
-
1083
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
1084
-
1085
- # Build config with all v2 fields
1086
- vae_config = LyraV2Config(
1087
- modality_dims=config_dict.get('modality_dims', {
1088
- "clip_l": 768, "clip_g": 1280,
1089
- "t5_xl_l": 2048, "t5_xl_g": 2048
1090
- }),
1091
- modality_seq_lens=config_dict.get('modality_seq_lens', {
1092
- "clip_l": 77, "clip_g": 77,
1093
- "t5_xl_l": 512, "t5_xl_g": 512
1094
- }),
1095
- binding_config=config_dict.get('binding_config', {
1096
- "clip_l": {"t5_xl_l": 0.3},
1097
- "clip_g": {"t5_xl_g": 0.3},
1098
- "t5_xl_l": {},
1099
- "t5_xl_g": {}
1100
- }),
1101
- latent_dim=config_dict.get('latent_dim', 2048),
1102
- seq_len=config_dict.get('seq_len', 77),
1103
- encoder_layers=config_dict.get('encoder_layers', 3),
1104
- decoder_layers=config_dict.get('decoder_layers', 3),
1105
- hidden_dim=config_dict.get('hidden_dim', 2048),
1106
- dropout=config_dict.get('dropout', 0.1),
1107
- fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
1108
- fusion_heads=config_dict.get('fusion_heads', 8),
1109
- fusion_dropout=config_dict.get('fusion_dropout', 0.1),
1110
- cantor_depth=config_dict.get('cantor_depth', 8),
1111
- cantor_local_window=config_dict.get('cantor_local_window', 3),
1112
- alpha_init=config_dict.get('alpha_init', 1.0),
1113
- beta_init=config_dict.get('beta_init', 0.3),
1114
- alpha_lr_scale=config_dict.get('alpha_lr_scale', 0.1),
1115
- beta_lr_scale=config_dict.get('beta_lr_scale', 1.0),
1116
- beta_kl=config_dict.get('beta_kl', 0.1),
1117
- beta_reconstruction=config_dict.get('beta_reconstruction', 1.0),
1118
- beta_cross_modal=config_dict.get('beta_cross_modal', 0.0),
1119
- beta_alpha_regularization=config_dict.get('beta_alpha_regularization', 0.01),
1120
- kl_clamp_max=config_dict.get('kl_clamp_max', 1.0),
1121
- logvar_clamp_min=config_dict.get('logvar_clamp_min', -10.0),
1122
- logvar_clamp_max=config_dict.get('logvar_clamp_max', 10.0),
1123
- )
1124
-
1125
- # Initialize model
1126
- lyra_model = LyraV2(vae_config)
1127
-
1128
- # Load weights
1129
- state_dict = checkpoint.get('model_state_dict', checkpoint)
1130
- missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
1131
-
1132
- if missing:
1133
- print(f" ⚠️ Missing keys: {len(missing)} (using initialized weights)")
1134
- if unexpected:
1135
- print(f" ⚠️ Unexpected keys: {len(unexpected)} (ignored)")
1136
-
1137
- lyra_model.to(device)
1138
- lyra_model.eval()
1139
-
1140
- # Print summary
1141
- total_params = sum(p.numel() for p in lyra_model.parameters())
1142
- print(f"✅ Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
1143
- print(f" Fusion: {vae_config.fusion_strategy}")
1144
- print(f" Latent: {vae_config.latent_dim}, Hidden: {vae_config.hidden_dim}")
1145
-
1146
- if 'global_step' in checkpoint:
1147
- print(f" Trained steps: {checkpoint['global_step']:,}")
1148
- if 'best_loss' in checkpoint:
1149
- print(f" Best loss: {checkpoint['best_loss']:.4f}")
1150
-
1151
- # Print binding info
1152
- fusion_params = lyra_model.get_fusion_params()
1153
- if fusion_params.get('alphas'):
1154
- alpha_vals = {k: torch.sigmoid(v).item() for k, v in fusion_params['alphas'].items()}
1155
- print(f" Alphas: {alpha_vals}")
1156
-
1157
- return lyra_model
1158
-
1159
- except Exception as e:
1160
- print(f"❌ Failed to load Lyra VAE v2: {e}")
1161
- import traceback
1162
- traceback.print_exc()
1163
- return None
1164
-
1165
-
1166
  # ============================================================================
1167
  # PIPELINE INITIALIZATION
1168
  # ============================================================================
1169
 
1170
- def initialize_pipeline(model_choice: str, device: str = "cuda"):
1171
- """Initialize the complete pipeline based on model choice."""
1172
 
1173
  print(f"🚀 Initializing {model_choice} pipeline...")
1174
 
1175
- # Determine architecture
1176
- is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
1177
- is_lune = "Lune" in model_choice
1178
-
1179
- if is_sdxl:
1180
- # SDXL-based models
1181
- if "Illustrious" in model_choice:
1182
- unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
1183
- else:
1184
- unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
1185
-
1186
- # T5-XL for Lyra
1187
- print("Loading T5-XL encoder...")
1188
- t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
1189
- t5_encoder = T5EncoderModel.from_pretrained(
1190
- "google/flan-t5-xl",
1191
- torch_dtype=torch.float16
1192
- ).to(device)
1193
- t5_encoder.eval()
1194
- print("✓ T5-XL loaded")
1195
-
1196
- # Lyra XL
1197
- lyra_model = load_lyra_vae_xl(device=device)
1198
-
1199
- # Scheduler (epsilon for SDXL)
1200
- scheduler = EulerDiscreteScheduler.from_pretrained(
1201
- "stabilityai/stable-diffusion-xl-base-1.0",
1202
- subfolder="scheduler"
1203
- )
1204
-
1205
- pipeline = SDXLFlowMatchingPipeline(
1206
- vae=vae,
1207
- text_encoder=text_encoder,
1208
- text_encoder_2=text_encoder_2,
1209
- tokenizer=tokenizer,
1210
- tokenizer_2=tokenizer_2,
1211
- unet=unet,
1212
- scheduler=scheduler,
1213
- device=device,
1214
- t5_encoder=t5_encoder,
1215
- t5_tokenizer=t5_tokenizer,
1216
- lyra_model=lyra_model,
1217
- clip_skip=1
1218
- )
1219
-
1220
  else:
1221
- # SD1.5-based models
1222
- vae = AutoencoderKL.from_pretrained(
1223
- "runwayml/stable-diffusion-v1-5",
1224
- subfolder="vae",
1225
- torch_dtype=torch.float32
1226
- ).to(device)
1227
-
1228
- text_encoder = CLIPTextModel.from_pretrained(
1229
- "openai/clip-vit-large-patch14",
1230
- torch_dtype=torch.float32
1231
- ).to(device)
1232
-
1233
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
1234
-
1235
- # T5-base for SD1.5 Lyra
1236
- print("Loading T5-base encoder...")
1237
- t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
1238
- t5_encoder = T5EncoderModel.from_pretrained(
1239
- "t5-base",
1240
- torch_dtype=torch.float32
1241
- ).to(device)
1242
- t5_encoder.eval()
1243
- print("✓ T5-base loaded")
1244
-
1245
- # Lyra (SD1.5 version)
1246
- lyra_model = load_lyra_vae(device=device)
1247
-
1248
- # Load UNet
1249
- if is_lune:
1250
- repo_id = "AbstractPhil/sd15-flow-lune"
1251
- filename = "sd15_flow_lune_e34_s34000.pt"
1252
- unet = load_lune_checkpoint(repo_id, filename, device)
1253
- else:
1254
- unet = UNet2DConditionModel.from_pretrained(
1255
- "runwayml/stable-diffusion-v1-5",
1256
- subfolder="unet",
1257
- torch_dtype=torch.float32
1258
- ).to(device)
1259
-
1260
- scheduler = EulerDiscreteScheduler.from_pretrained(
1261
- "runwayml/stable-diffusion-v1-5",
1262
- subfolder="scheduler"
1263
- )
1264
-
1265
- pipeline = SD15FlowMatchingPipeline(
1266
- vae=vae,
1267
- text_encoder=text_encoder,
1268
- tokenizer=tokenizer,
1269
- unet=unet,
1270
- scheduler=scheduler,
1271
- device=device,
1272
- t5_encoder=t5_encoder,
1273
- t5_tokenizer=t5_tokenizer,
1274
- lyra_model=lyra_model
1275
  )
1276
-
1277
- pipeline.is_lune_model = is_lune
1278
 
1279
- print("✅ Pipeline initialized!")
1280
  return pipeline
1281
 
1282
 
@@ -1286,15 +757,20 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
1286
 
1287
  CURRENT_PIPELINE = None
1288
  CURRENT_MODEL = None
 
1289
 
1290
 
1291
- def get_pipeline(model_choice: str):
1292
  """Get or create pipeline for selected model."""
1293
- global CURRENT_PIPELINE, CURRENT_MODEL
1294
 
1295
  if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
1296
- CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
1297
  CURRENT_MODEL = model_choice
1298
 
1299
  return CURRENT_PIPELINE
1300
 
@@ -1303,35 +779,18 @@ def get_pipeline(model_choice: str):
1303
  # INFERENCE
1304
  # ============================================================================
1305
 
1306
- def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int:
1307
- """Estimate GPU duration."""
1308
- base_time_per_step = 0.5 if is_sdxl else 0.3
1309
- resolution_factor = (width * height) / (512 * 512)
1310
- estimated = num_steps * base_time_per_step * resolution_factor
1311
-
1312
- if use_lyra:
1313
- estimated *= 2
1314
- estimated += 3
1315
-
1316
- return int(estimated + 20)
1317
-
1318
-
1319
- @spaces.GPU(duration=lambda *args: estimate_duration(
1320
- args[5], args[7], args[8], args[11],
1321
- "SDXL" in args[3] or "Illustrious" in args[3]
1322
- ))
1323
  def generate_image(
1324
  prompt: str,
1325
  t5_summary: str,
1326
  negative_prompt: str,
1327
  model_choice: str,
 
1328
  clip_skip: int,
1329
  num_steps: int,
1330
  cfg_scale: float,
1331
  width: int,
1332
  height: int,
1333
- shift: float,
1334
- use_flow_matching: bool,
1335
  use_lyra: bool,
1336
  lyra_strength: float,
1337
  seed: int,
@@ -1347,16 +806,9 @@ def generate_image(
1347
  progress((step + 1) / total, desc=desc)
1348
 
1349
  try:
1350
- pipeline = get_pipeline(model_choice)
1351
 
1352
- # Determine prediction type based on model
1353
- is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
1354
- prediction_type = "epsilon" # SDXL always uses epsilon
1355
-
1356
- if not is_sdxl and "Lune" in model_choice:
1357
- prediction_type = "v_prediction"
1358
-
1359
- if not use_lyra or pipeline.lyra_model is None:
1360
  progress(0.05, desc="Generating...")
1361
 
1362
  image = pipeline(
@@ -1366,9 +818,6 @@ def generate_image(
1366
  width=width,
1367
  num_inference_steps=num_steps,
1368
  guidance_scale=cfg_scale,
1369
- shift=shift,
1370
- use_flow_matching=use_flow_matching,
1371
- prediction_type=prediction_type,
1372
  seed=seed,
1373
  use_lyra=False,
1374
  clip_skip=clip_skip,
@@ -1388,16 +837,13 @@ def generate_image(
1388
  width=width,
1389
  num_inference_steps=num_steps,
1390
  guidance_scale=cfg_scale,
1391
- shift=shift,
1392
- use_flow_matching=use_flow_matching,
1393
- prediction_type=prediction_type,
1394
  seed=seed,
1395
  use_lyra=False,
1396
  clip_skip=clip_skip,
1397
  progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
1398
  )
1399
 
1400
- progress(0.5, desc="Generating Lyra fusion...")
1401
 
1402
  image_lyra = pipeline(
1403
  prompt=prompt,
@@ -1406,9 +852,6 @@ def generate_image(
1406
  width=width,
1407
  num_inference_steps=num_steps,
1408
  guidance_scale=cfg_scale,
1409
- shift=shift,
1410
- use_flow_matching=use_flow_matching,
1411
- prediction_type=prediction_type,
1412
  seed=seed,
1413
  use_lyra=True,
1414
  clip_skip=clip_skip,
@@ -1422,6 +865,8 @@ def generate_image(
1422
 
1423
  except Exception as e:
1424
  print(f"❌ Generation failed: {e}")
 
 
1425
  raise e
1426
 
1427
 
@@ -1434,251 +879,93 @@ def create_demo():
1434
 
1435
  with gr.Blocks() as demo:
1436
  gr.Markdown("""
1437
- # 🌙 Lyra/Lune Flow-Matching Image Generation
1438
 
1439
  **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
1440
 
1441
- Generate images using SD1.5 and SDXL-based models with geometric deep learning:
1442
-
1443
  | Model | Architecture | Lyra Version | Best For |
1444
  |-------|-------------|--------------|----------|
1445
  | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
1446
  | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |
1447
- | **Flow-Lune** | SD1.5 | v1 (T5-base) | Fast flow matching (15-25 steps) |
1448
- | **SD1.5 Base** | SD1.5 | v1 (T5-base) | Baseline comparison |
1449
-
1450
- **Lyra VAE** fuses CLIP + T5 embeddings using:
1451
- - **Prompt (Tags)**: Booru-style tags for CLIP encoding
1452
- - **T5 Summary**: Natural language description for T5 (format: `tags ¶ summary`)
1453
 
1454
- Enable **Lyra VAE** for side-by-side comparison!
 
1455
  """)
1456
 
1457
  with gr.Row():
1458
  with gr.Column(scale=1):
1459
  prompt = gr.TextArea(
1460
- label="Prompt (Tags for CLIP)",
1461
  value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
1462
  lines=3
1463
  )
1464
 
1465
  t5_summary = gr.TextArea(
1466
- label="T5 Summary (Natural Language for Lyra)",
1467
- value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky",
1468
  lines=2,
1469
- info="Used after separator for T5. Leave empty to use tags only."
1470
  )
1471
 
1472
  negative_prompt = gr.TextArea(
1473
  label="Negative Prompt",
1474
- value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality",
1475
  lines=2
1476
  )
1477
 
1478
- model_choice = gr.Dropdown(
1479
- label="Model",
1480
- choices=[
1481
- "Illustrious XL",
1482
- "SDXL Base",
1483
- "Flow-Lune (SD1.5)",
1484
- "SD1.5 Base"
1485
- ],
1486
- value="Illustrious XL"
1487
- )
 
 
1488
 
1489
  clip_skip = gr.Slider(
1490
  label="CLIP Skip",
1491
- minimum=1,
1492
- maximum=4,
1493
- value=2,
1494
- step=1,
1495
- info="2 recommended for Illustrious, 1 for others"
1496
  )
1497
 
1498
  use_lyra = gr.Checkbox(
1499
- label="Enable Lyra VAE (CLIP+T5 Fusion)",
1500
- value=True,
1501
  info="Compare standard vs geometric fusion"
1502
  )
1503
 
1504
  lyra_strength = gr.Slider(
1505
  label="Lyra Blend Strength",
1506
- minimum=0.0,
1507
- maximum=3.0,
1508
- value=1.0,
1509
- step=0.05,
1510
- info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction, 3.0 = way too much but try it anyway"
1511
  )
1512
 
1513
  with gr.Accordion("Generation Settings", open=True):
1514
- num_steps = gr.Slider(
1515
- label="Steps",
1516
- minimum=1,
1517
- maximum=50,
1518
- value=25,
1519
- step=1
1520
- )
1521
-
1522
- cfg_scale = gr.Slider(
1523
- label="CFG Scale",
1524
- minimum=1.0,
1525
- maximum=20.0,
1526
- value=7.0,
1527
- step=0.5
1528
- )
1529
 
1530
  with gr.Row():
1531
- width = gr.Slider(
1532
- label="Width",
1533
- minimum=512,
1534
- maximum=1536,
1535
- value=1024,
1536
- step=64
1537
- )
1538
- height = gr.Slider(
1539
- label="Height",
1540
- minimum=512,
1541
- maximum=1536,
1542
- value=1024,
1543
- step=64
1544
- )
1545
 
1546
- seed = gr.Slider(
1547
- label="Seed",
1548
- minimum=0,
1549
- maximum=2**32 - 1,
1550
- value=42,
1551
- step=1
1552
- )
1553
-
1554
- randomize_seed = gr.Checkbox(
1555
- label="Randomize Seed",
1556
- value=True
1557
- )
1558
-
1559
- with gr.Accordion("Advanced (Flow Matching)", open=False):
1560
- use_flow_matching = gr.Checkbox(
1561
- label="Enable Flow Matching",
1562
- value=False,
1563
- info="Use flow matching ODE (for Lune only)"
1564
- )
1565
-
1566
- shift = gr.Slider(
1567
- label="Shift",
1568
- minimum=0.0,
1569
- maximum=5.0,
1570
- value=0.0,
1571
- step=0.1,
1572
- info="Flow matching shift (0=disabled)"
1573
- )
1574
 
1575
  generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
1576
 
1577
  with gr.Column(scale=1):
1578
  with gr.Row():
1579
- output_image_standard = gr.Image(
1580
- label="Standard",
1581
- type="pil"
1582
- )
1583
- output_image_lyra = gr.Image(
1584
- label="Lyra Fusion 🎵",
1585
- type="pil",
1586
- visible=True
1587
- )
1588
 
1589
  output_seed = gr.Number(label="Seed", precision=0)
1590
-
1591
- gr.Markdown("""
1592
- ### Tips
1593
- - **Illustrious XL**: Use CLIP skip 2, booru-style tags
1594
- - **SDXL Base**: Natural language prompts work well
1595
- - **Flow-Lune**: Enable flow matching, shift ~2.5, fewer steps
1596
- - **Lyra v2**: SDXL models use T5-XL for richer semantics
1597
- - **Lyra v1**: SD1.5 models use T5-base
1598
-
1599
- ### Model Info
1600
- - SDXL models use **epsilon** prediction
1601
- - Lune uses **v_prediction** with flow matching
1602
- - Lyra fuses CLIP + T5 via geometric Cantor attention
1603
- """)
1604
-
1605
- # Examples
1606
- gr.Examples(
1607
- examples=[
1608
- [
1609
- "masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
1610
- "A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky",
1611
- "lowres, bad anatomy, worst quality, low quality",
1612
- "Illustrious XL",
1613
- 2, 25, 7.0, 1024, 1024, 0.0, False, True, 0.8, 42, False
1614
- ],
1615
- [
1616
- "A majestic mountain landscape at golden hour, crystal clear lake, photorealistic, 8k",
1617
- "A breathtaking mountain vista bathed in warm golden light at sunset, with a perfectly still crystal clear lake reflecting the peaks",
1618
- "blurry, low quality",
1619
- "SDXL Base",
1620
- 1, 30, 7.5, 1024, 1024, 0.0, False, True, 0.8, 123, False
1621
- ],
1622
- [
1623
- "cyberpunk city at night, neon lights, rain, highly detailed",
1624
- "A futuristic cyberpunk metropolis at night with vibrant neon lights reflecting off rain-slicked streets",
1625
- "low quality, blurry",
1626
- "Flow-Lune (SD1.5)",
1627
- 1, 20, 7.5, 512, 512, 2.5, True, True, 0.8, 456, False
1628
- ],
1629
- ],
1630
- inputs=[
1631
- prompt, t5_summary, negative_prompt, model_choice, clip_skip,
1632
- num_steps, cfg_scale, width, height, shift,
1633
- use_flow_matching, use_lyra, lyra_strength, seed, randomize_seed
1634
- ],
1635
- outputs=[output_image_standard, output_image_lyra, output_seed],
1636
- fn=generate_image,
1637
- cache_examples=False
1638
- )
1639
 
1640
  # Event handlers
1641
- def on_model_change(model_name):
1642
- """Update defaults based on model."""
1643
- if "Illustrious" in model_name:
1644
- return {
1645
- clip_skip: gr.update(value=2),
1646
- width: gr.update(value=1024),
1647
- height: gr.update(value=1024),
1648
- num_steps: gr.update(value=25),
1649
- use_flow_matching: gr.update(value=False),
1650
- shift: gr.update(value=0.0)
1651
- }
1652
- elif "SDXL" in model_name:
1653
- return {
1654
- clip_skip: gr.update(value=1),
1655
- width: gr.update(value=1024),
1656
- height: gr.update(value=1024),
1657
- num_steps: gr.update(value=30),
1658
- use_flow_matching: gr.update(value=False),
1659
- shift: gr.update(value=0.0)
1660
- }
1661
- elif "Lune" in model_name:
1662
- return {
1663
- clip_skip: gr.update(value=1),
1664
- width: gr.update(value=512),
1665
- height: gr.update(value=512),
1666
- num_steps: gr.update(value=20),
1667
- use_flow_matching: gr.update(value=True),
1668
- shift: gr.update(value=2.5)
1669
- }
1670
- else: # SD1.5 Base
1671
- return {
1672
- clip_skip: gr.update(value=1),
1673
- width: gr.update(value=512),
1674
- height: gr.update(value=512),
1675
- num_steps: gr.update(value=30),
1676
- use_flow_matching: gr.update(value=False),
1677
- shift: gr.update(value=0.0)
1678
- }
1679
-
1680
  def on_lyra_toggle(enabled):
1681
- """Show/hide Lyra comparison."""
1682
  if enabled:
1683
  return {
1684
  output_image_standard: gr.update(visible=True, label="Standard"),
@@ -1690,12 +977,6 @@ def create_demo():
1690
  output_image_lyra: gr.update(visible=False)
1691
  }
1692
 
1693
- model_choice.change(
1694
- fn=on_model_change,
1695
- inputs=[model_choice],
1696
- outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift]
1697
- )
1698
-
1699
  use_lyra.change(
1700
  fn=on_lyra_toggle,
1701
  inputs=[use_lyra],
@@ -1705,9 +986,9 @@ def create_demo():
1705
  generate_btn.click(
1706
  fn=generate_image,
1707
  inputs=[
1708
- prompt, t5_summary, negative_prompt, model_choice, clip_skip,
1709
- num_steps, cfg_scale, width, height, shift,
1710
- use_flow_matching, use_lyra, lyra_strength, seed, randomize_seed
1711
  ],
1712
  outputs=[output_image_standard, output_image_lyra, output_seed]
1713
  )
 
8
 
9
  Lyra VAE Versions:
10
  - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra
11
+ - v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2
12
  """
13
 
14
  import os
 
25
  UNet2DConditionModel,
26
  AutoencoderKL,
27
  EulerDiscreteScheduler,
28
+ EulerAncestralDiscreteScheduler,
29
+ DPMSolverMultistepScheduler,
30
+ DPMSolverSDEScheduler,
31
  )
 
32
  from transformers import (
33
  CLIPTextModel,
34
  CLIPTokenizer,
 
38
  )
39
  from huggingface_hub import hf_hub_download
40
 
41
+ # Lazy imports for Lyra
42
+ LYRA_V1_AVAILABLE = False
43
+ LYRA_V2_AVAILABLE = False
44
+ LyraV1 = None
45
+ LyraV1Config = None
46
+ LyraV2 = None
47
+ LyraV2Config = None
48
 
49
+
50
+ def _load_lyra_imports():
51
+ """Lazy load Lyra VAE modules."""
52
+ global LYRA_V1_AVAILABLE, LYRA_V2_AVAILABLE
53
+ global LyraV1, LyraV1Config, LyraV2, LyraV2Config
54
+
55
+ try:
56
+ from geofractal.model.vae.vae_lyra import MultiModalVAE as _LyraV1, MultiModalVAEConfig as _LyraV1Config
57
+ LyraV1 = _LyraV1
58
+ LyraV1Config = _LyraV1Config
59
+ LYRA_V1_AVAILABLE = True
60
+ except ImportError:
61
+ print("⚠️ Lyra VAE v1 not available")
62
+
63
+ try:
64
+ from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as _LyraV2, MultiModalVAEConfig as _LyraV2Config
65
+ LyraV2 = _LyraV2
66
+ LyraV2Config = _LyraV2Config
67
+ LYRA_V2_AVAILABLE = True
68
+ except ImportError:
69
+ print("⚠️ Lyra VAE v2 not available")
70
 
71
 
72
  # ============================================================================
73
  # CONSTANTS
74
  # ============================================================================
75
 
 
76
  ARCH_SD15 = "sd15"
77
  ARCH_SDXL = "sdxl"
78
 
79
+ # Scheduler options
80
+ SCHEDULER_EULER_A = "Euler Ancestral"
81
+ SCHEDULER_EULER = "Euler"
82
+ SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE"
83
+ SCHEDULER_DPM_2M = "DPM++ 2M"
84
+
85
+ SDXL_SCHEDULERS = [SCHEDULER_EULER_A, SCHEDULER_EULER, SCHEDULER_DPM_2M_SDE, SCHEDULER_DPM_2M]
86
 
87
 
88
  # ============================================================================
89
+ # SCHEDULER FACTORY
90
  # ============================================================================
91
 
92
+ def get_scheduler(scheduler_name: str, config_path: str = "stabilityai/stable-diffusion-xl-base-1.0"):
93
+ """Create scheduler by name."""
94
 
95
+ if scheduler_name == SCHEDULER_EULER_A:
96
+ return EulerAncestralDiscreteScheduler.from_pretrained(
97
+ config_path, subfolder="scheduler"
98
+ )
99
+ elif scheduler_name == SCHEDULER_EULER:
100
+ return EulerDiscreteScheduler.from_pretrained(
101
+ config_path, subfolder="scheduler"
102
+ )
103
+ elif scheduler_name == SCHEDULER_DPM_2M_SDE:
104
+ return DPMSolverSDEScheduler.from_pretrained(
105
+ config_path, subfolder="scheduler",
106
+ algorithm_type="sde-dpmsolver++",
107
+ solver_order=2,
108
+ )
109
+ elif scheduler_name == SCHEDULER_DPM_2M:
110
+ return DPMSolverMultistepScheduler.from_pretrained(
111
+ config_path, subfolder="scheduler",
112
+ algorithm_type="dpmsolver++",
113
+ solver_order=2,
114
+ )
115
+ else:
116
+ # Default to Euler Ancestral
117
+ return EulerAncestralDiscreteScheduler.from_pretrained(
118
+ config_path, subfolder="scheduler"
119
+ )
120
+
121
 
122
+ # ============================================================================
123
+ # MODEL LOADING UTILITIES
124
+ # ============================================================================
125
 
126
  def get_clip_hidden_state(
127
  model_output,
 
133
  return model_output.last_hidden_state
134
 
135
  if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
 
 
136
  return model_output.hidden_states[-clip_skip]
137
 
138
  return model_output.last_hidden_state
139
 
140
 
141
+ # ============================================================================
142
+ # LAZY LOADERS
143
+ # ============================================================================
144
+
145
+ class LazyT5Encoder:
146
+ """Lazy loader for T5 encoder - only loads when first accessed."""
147
+
148
+ def __init__(self, model_name: str = "google/flan-t5-xl", device: str = "cuda"):
149
+ self.model_name = model_name
150
+ self.device = device
151
+ self._encoder = None
152
+ self._tokenizer = None
153
+
154
+ @property
155
+ def encoder(self):
156
+ if self._encoder is None:
157
+ print(f"📥 Loading T5 encoder: {self.model_name}...")
158
+ self._encoder = T5EncoderModel.from_pretrained(
159
+ self.model_name,
160
+ torch_dtype=torch.float16
161
+ ).to(self.device)
162
+ self._encoder.eval()
163
+ print("✓ T5 encoder loaded")
164
+ return self._encoder
165
+
166
+ @property
167
+ def tokenizer(self):
168
+ if self._tokenizer is None:
169
+ print(f"📥 Loading T5 tokenizer: {self.model_name}...")
170
+ self._tokenizer = T5Tokenizer.from_pretrained(self.model_name)
171
+ print("✓ T5 tokenizer loaded")
172
+ return self._tokenizer
173
+
174
+ def is_loaded(self):
175
+ return self._encoder is not None
176
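+ # Usage sketch (comment only): construction is cheap; the first access to
+ # .tokenizer / .encoder downloads and loads flan-t5-xl, mirroring how
+ # encode_prompt_lyra uses it:
+ #   t5 = LazyT5Encoder(device="cuda")                  # nothing loaded yet
+ #   tokens = t5.tokenizer("a prompt", return_tensors="pt")
+ #   embeds = t5.encoder(**tokens.to("cuda")).last_hidden_state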
+
177
+
178
+ class LazyLyraModel:
179
+ """Lazy loader for Lyra VAE - only loads when first accessed."""
180
+
181
+ def __init__(self, repo_id: str, device: str = "cuda", version: int = 2):
182
+ self.repo_id = repo_id
183
+ self.device = device
184
+ self.version = version
185
+ self._model = None
186
+
187
+ @property
188
+ def model(self):
189
+ if self._model is None:
190
+ _load_lyra_imports()
191
+
192
+ if self.version == 2:
193
+ self._model = self._load_v2()
194
+ else:
195
+ self._model = self._load_v1()
196
+ return self._model
197
+
198
+ def _load_v2(self):
199
+ if not LYRA_V2_AVAILABLE:
200
+ print("⚠️ Lyra VAE v2 not available")
201
+ return None
202
+
203
+ print(f"🎵 Loading Lyra VAE v2 from {self.repo_id}...")
204
+
205
+ try:
206
+ from huggingface_hub import list_repo_files
207
+
208
+ config_path = hf_hub_download(
209
+ repo_id=self.repo_id,
210
+ filename="config.json",
211
+ repo_type="model"
212
+ )
213
+
214
+ with open(config_path, 'r') as f:
215
+ config_dict = json.load(f)
216
+
217
+ print(f" ✓ Config: {config_dict.get('fusion_strategy', 'unknown')} fusion")
218
+
219
+ # Auto-detect checkpoint
220
+ repo_files = list_repo_files(self.repo_id, repo_type="model")
221
+ checkpoint_files = [f for f in repo_files if f.endswith('.pt')]
222
+ checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower()]
223
+
224
+ if not checkpoint_files:
225
+ raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
226
+
227
+ import re
228
+ def extract_step(name):
229
+ match = re.search(r'(\d+)\.pt', name)
230
+ return int(match.group(1)) if match else 0
231
+
232
+ checkpoint_files.sort(key=extract_step, reverse=True)
233
+ checkpoint_filename = checkpoint_files[0]
234
+ print(f" ✓ Using: {checkpoint_filename}")
235
+
236
+ checkpoint_path = hf_hub_download(
237
+ repo_id=self.repo_id,
238
+ filename=checkpoint_filename,
239
+ repo_type="model"
240
+ )
241
+
242
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
243
+
244
+ vae_config = LyraV2Config(
245
+ modality_dims=config_dict.get('modality_dims', {
246
+ "clip_l": 768, "clip_g": 1280,
247
+ "t5_xl_l": 2048, "t5_xl_g": 2048
248
+ }),
249
+ modality_seq_lens=config_dict.get('modality_seq_lens', {
250
+ "clip_l": 77, "clip_g": 77,
251
+ "t5_xl_l": 512, "t5_xl_g": 512
252
+ }),
253
+ binding_config=config_dict.get('binding_config', {
254
+ "clip_l": {"t5_xl_l": 0.3},
255
+ "clip_g": {"t5_xl_g": 0.3},
256
+ "t5_xl_l": {},
257
+ "t5_xl_g": {}
258
+ }),
259
+ latent_dim=config_dict.get('latent_dim', 2048),
260
+ seq_len=config_dict.get('seq_len', 77),
261
+ encoder_layers=config_dict.get('encoder_layers', 3),
262
+ decoder_layers=config_dict.get('decoder_layers', 3),
263
+ hidden_dim=config_dict.get('hidden_dim', 2048),
264
+ dropout=config_dict.get('dropout', 0.1),
265
+ fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
266
+ fusion_heads=config_dict.get('fusion_heads', 8),
267
+ fusion_dropout=config_dict.get('fusion_dropout', 0.1),
268
+ cantor_depth=config_dict.get('cantor_depth', 8),
269
+ cantor_local_window=config_dict.get('cantor_local_window', 3),
270
+ alpha_init=config_dict.get('alpha_init', 1.0),
271
+ beta_init=config_dict.get('beta_init', 0.3),
272
+ )
273
+
274
+ lyra_model = LyraV2(vae_config)
275
+
276
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
277
+ missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
278
+
279
+ if missing:
280
+ print(f" ⚠️ Missing keys: {len(missing)}")
281
+ if unexpected:
282
+ print(f" ⚠️ Unexpected keys: {len(unexpected)}")
283
+
284
+ lyra_model.to(self.device)
285
+ lyra_model.eval()
286
+
287
+ total_params = sum(p.numel() for p in lyra_model.parameters())
288
+ print(f"✅ Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
289
+
290
+ return lyra_model
291
+
292
+ except Exception as e:
293
+ print(f"❌ Failed to load Lyra VAE v2: {e}")
294
+ import traceback
295
+ traceback.print_exc()
296
+ return None
297
+
298
+ def _load_v1(self):
299
+ if not LYRA_V1_AVAILABLE:
300
+ print("⚠️ Lyra VAE v1 not available")
301
+ return None
302
+
303
+ # v1 (SD1.5) loading is not wired up in this Space; only the v2 (SDXL) path is implemented.
+ print("⚠️ Lyra VAE v1 loading not implemented")
304
+ return None
305
+
306
+ def is_loaded(self):
307
+ return self._model is not None
308
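+ # Note: .model downloads config.json plus the highest-step *.pt checkpoint from the
+ # repo on first access and returns None instead of raising if anything fails, so
+ # callers should check the result (or is_loaded()) before relying on it.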
+
309
+
310
  # ============================================================================
311
  # SDXL PIPELINE
312
  # ============================================================================
 
317
  def __init__(
318
  self,
319
  vae: AutoencoderKL,
320
+ text_encoder: CLIPTextModel,
321
+ text_encoder_2: CLIPTextModelWithProjection,
322
  tokenizer: CLIPTokenizer,
323
  tokenizer_2: CLIPTokenizer,
324
  unet: UNet2DConditionModel,
325
  scheduler,
326
  device: str = "cuda",
327
+ t5_loader: Optional[LazyT5Encoder] = None,
328
+ lyra_loader: Optional[LazyLyraModel] = None,
 
329
  clip_skip: int = 1
330
  ):
331
  self.vae = vae
 
337
  self.scheduler = scheduler
338
  self.device = device
339
 
340
+ # Lazy loaders
341
+ self.t5_loader = t5_loader
342
+ self.lyra_loader = lyra_loader
 
343
 
344
  # Settings
345
  self.clip_skip = clip_skip
346
+ self.vae_scale_factor = 0.13025  # SDXL latent scaling factor, applied before VAE decode
347
  self.arch = ARCH_SDXL
348
+
349
+ def set_scheduler(self, scheduler_name: str):
350
+ """Switch scheduler."""
351
+ self.scheduler = get_scheduler(scheduler_name)
352
+
353
+ @property
354
+ def t5_encoder(self):
355
+ return self.t5_loader.encoder if self.t5_loader else None
356
+
357
+ @property
358
+ def t5_tokenizer(self):
359
+ return self.t5_loader.tokenizer if self.t5_loader else None
360
+
361
+ @property
362
+ def lyra_model(self):
363
+ return self.lyra_loader.model if self.lyra_loader else None
364
+
365
  def encode_prompt(
366
  self,
367
  prompt: str,
 
404
  output_hidden_states=output_hidden_states
405
  )
406
  prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
 
 
407
  pooled_prompt_embeds = clip_g_output.text_embeds
408
 
 
409
  prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
410
 
411
  # Negative prompt
 
457
  t5_summary: str = "",
458
  lyra_strength: float = 0.3
459
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
460
+ """Encode prompts using Lyra VAE v2 fusion (CLIP + T5)."""
461
 
 
 
 
 
 
 
462
  if self.lyra_model is None or self.t5_encoder is None:
463
  raise ValueError("Lyra VAE components not initialized")
464
 
 
467
  prompt, negative_prompt, clip_skip
468
  )
469
 
470
+ # Format T5 input
471
  SUMMARY_SEPARATOR = "¶"
472
  if t5_summary.strip():
473
  t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
 
487
  t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
488
 
489
  clip_l_dim = 768
 
 
490
  clip_l_embeds = prompt_embeds[..., :clip_l_dim]
491
  clip_g_embeds = prompt_embeds[..., clip_l_dim:]
492
 
 
 
 
 
 
493
  with torch.no_grad():
 
494
  modality_inputs = {
495
  'clip_l': clip_l_embeds.float(),
496
  'clip_g': clip_g_embeds.float(),
 
501
  modality_inputs,
502
  target_modalities=['clip_l', 'clip_g']
503
  )
 
504
 
505
  lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
506
  lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
507
 
508
+ # Re-normalize Lyra reconstructions whose statistics drift too far from the original CLIP embeddings
 
 
 
509
  clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
510
  clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
 
511
 
 
512
  if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5:
 
513
  lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8)
514
  lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean()
515
 
516
  if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5:
 
517
  lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
518
  lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
 
519
 
520
+ # Blend
521
  fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
522
  fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
523
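+ # Linear blend in embedding space: 0.0 keeps the original CLIP embeddings,
+ # 1.0 uses the Lyra reconstruction alone, and values above 1.0 (the UI slider
+ # allows up to 2.0) extrapolate past it.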
 
 
 
 
524
  prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
525
 
526
+ # Negative prompt: keep the original CLIP embeddings (Lyra fusion is only applied to the positive prompt)
527
+ return prompt_embeds_fused, negative_prompt_embeds, pooled, negative_pooled
 
528
 
529
  def _get_add_time_ids(
530
  self,
 
545
  negative_prompt: str = "",
546
  height: int = 1024,
547
  width: int = 1024,
548
+ num_inference_steps: int = 25,
549
+ guidance_scale: float = 7.0,
 
 
 
550
  seed: Optional[int] = None,
551
  use_lyra: bool = False,
552
+ clip_skip: int = 2,
553
  t5_summary: str = "",
554
  lyra_strength: float = 1.0,
555
  progress_callback=None
556
  ):
557
  """Generate image using SDXL architecture."""
558
 
 
559
  if seed is not None:
560
  generator = torch.Generator(device=self.device).manual_seed(seed)
561
  else:
562
  generator = None
563
 
564
  # Encode prompts
565
+ if use_lyra and self.lyra_loader is not None:
566
  prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
567
  prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
568
  )
 
587
  self.scheduler.set_timesteps(num_inference_steps, device=self.device)
588
  timesteps = self.scheduler.timesteps
589
 
590
+ latents = latents * self.scheduler.init_noise_sigma
 
 
591
 
592
+ # Time embeddings for SDXL
593
  original_size = (height, width)
594
  target_size = (height, width)
595
  crops_coords_top_left = (0, 0)
 
597
  add_time_ids = self._get_add_time_ids(
598
  original_size, crops_coords_top_left, target_size, dtype=torch.float16
599
  )
600
+ negative_add_time_ids = add_time_ids
601
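+ # SDXL micro-conditioning: in the standard SDXL layout these time_ids pack
+ # original_size, crop offset and target_size (six values) and are fed to the UNet
+ # alongside the pooled text embedding; the negative branch reuses the same ids.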
 
602
  # Denoising loop
603
  for i, t in enumerate(timesteps):
604
  if progress_callback:
605
  progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
606
 
 
607
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
608
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
609
 
 
 
 
 
 
 
 
 
 
 
610
  timestep = t.expand(latent_model_input.shape[0])
611
 
 
612
  if guidance_scale > 1.0:
613
  text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
614
  add_text_embeds = torch.cat([negative_pooled, pooled])
 
618
  add_text_embeds = pooled
619
  add_time_ids_input = add_time_ids
620
 
 
621
  added_cond_kwargs = {
622
  "text_embeds": add_text_embeds,
623
  "time_ids": add_time_ids_input
624
  }
625
 
 
626
  noise_pred = self.unet(
627
  latent_model_input,
628
  timestep,
 
631
  return_dict=False
632
  )[0]
633
 
 
634
  if guidance_scale > 1.0:
635
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
636
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
637
 
638
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
  # Decode
641
  latents = latents / self.vae_scale_factor
 
643
  with torch.no_grad():
644
  image = self.vae.decode(latents.to(self.vae.dtype)).sample
645
 
646
  image = (image / 2 + 0.5).clamp(0, 1)
647
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
648
  image = (image * 255).round().astype("uint8")
 
655
  # MODEL LOADERS
656
  # ============================================================================
657
 
658
  def load_illustrious_xl(
659
+ repo_id: str = "AbstractPhil/illustrious-xl-v1",
660
  filename: str = "illustriousXL_v01.safetensors",
661
  device: str = "cuda"
662
  ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
663
+ """Load Illustrious XL from single safetensors file."""
664
  from diffusers import StableDiffusionXLPipeline
665
 
666
  print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
667
 
 
668
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
669
  print(f"✓ Downloaded: {checkpoint_path}")
670
 
671
+ print("📦 Loading pipeline...")
 
672
  pipe = StableDiffusionXLPipeline.from_single_file(
673
  checkpoint_path,
674
  torch_dtype=torch.float16,
675
  use_safetensors=True,
676
  )
677
 
 
678
  unet = pipe.unet.to(device)
679
  vae = pipe.vae.to(device)
680
  text_encoder = pipe.text_encoder.to(device)
 
682
  tokenizer = pipe.tokenizer
683
  tokenizer_2 = pipe.tokenizer_2
684
 
 
685
  del pipe
686
  torch.cuda.empty_cache()
687
 
688
  print("✅ Illustrious XL loaded!")
 
 
689
 
690
  return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
691
 
692
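+ # Usage sketch (comment only): initialize_sdxl_pipeline() below feeds these
+ # components straight into SDXLFlowMatchingPipeline, e.g.
+ #   unet, vae, te_l, te_g, tok_l, tok_g = load_illustrious_xl(device="cuda")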
 
693
  # ============================================================================
694
  # PIPELINE INITIALIZATION
695
  # ============================================================================
696
 
697
+ def initialize_sdxl_pipeline(
698
+ model_choice: str,
699
+ scheduler_name: str = SCHEDULER_EULER_A,
700
+ device: str = "cuda"
701
+ ):
702
+ """Initialize SDXL pipeline with lazy T5/Lyra loading."""
703
 
704
  print(f"🚀 Initializing {model_choice} pipeline...")
705
 
706
+ # Load base model
707
+ if "Illustrious" in model_choice:
708
+ unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
 
709
  else:
710
+ # SDXL Base
711
+ from diffusers import StableDiffusionXLPipeline
712
+ pipe = StableDiffusionXLPipeline.from_pretrained(
713
+ "stabilityai/stable-diffusion-xl-base-1.0",
714
+ torch_dtype=torch.float16,
715
  )
716
+ unet = pipe.unet.to(device)
717
+ vae = pipe.vae.to(device)
718
+ text_encoder = pipe.text_encoder.to(device)
719
+ text_encoder_2 = pipe.text_encoder_2.to(device)
720
+ tokenizer = pipe.tokenizer
721
+ tokenizer_2 = pipe.tokenizer_2
722
+ del pipe
723
+ torch.cuda.empty_cache()
724
+
725
+ # Create lazy loaders (don't download yet)
726
+ t5_loader = LazyT5Encoder(model_name="google/flan-t5-xl", device=device)
727
+ lyra_loader = LazyLyraModel(
728
+ repo_id="AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
729
+ device=device,
730
+ version=2
731
+ )
732
+
733
+ # Get scheduler
734
+ scheduler = get_scheduler(scheduler_name)
735
+
736
+ pipeline = SDXLFlowMatchingPipeline(
737
+ vae=vae,
738
+ text_encoder=text_encoder,
739
+ text_encoder_2=text_encoder_2,
740
+ tokenizer=tokenizer,
741
+ tokenizer_2=tokenizer_2,
742
+ unet=unet,
743
+ scheduler=scheduler,
744
+ device=device,
745
+ t5_loader=t5_loader,
746
+ lyra_loader=lyra_loader,
747
+ clip_skip=2
748
+ )
749
 
750
+ print("✅ Pipeline initialized (T5/Lyra will load on first use)")
751
  return pipeline
752
 
753
 
 
757
 
758
  CURRENT_PIPELINE = None
759
  CURRENT_MODEL = None
760
+ CURRENT_SCHEDULER = None
761
 
762
 
763
+ def get_pipeline(model_choice: str, scheduler_name: str = SCHEDULER_EULER_A):
764
  """Get or create pipeline for selected model."""
765
+ global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_SCHEDULER
766
 
767
  if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
768
+ CURRENT_PIPELINE = initialize_sdxl_pipeline(model_choice, scheduler_name, device="cuda")
769
  CURRENT_MODEL = model_choice
770
+ CURRENT_SCHEDULER = scheduler_name
771
+ elif CURRENT_SCHEDULER != scheduler_name:
772
+ CURRENT_PIPELINE.set_scheduler(scheduler_name)
773
+ CURRENT_SCHEDULER = scheduler_name
774
 
775
  return CURRENT_PIPELINE
776
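+ # Only one pipeline is cached at module level: picking a different model rebuilds
+ # everything, while picking a different scheduler only swaps it on the cached pipeline.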
 
 
779
  # INFERENCE
780
  # ============================================================================
781
 
782
+ @spaces.GPU(duration=120)
783
  def generate_image(
784
  prompt: str,
785
  t5_summary: str,
786
  negative_prompt: str,
787
  model_choice: str,
788
+ scheduler_name: str,
789
  clip_skip: int,
790
  num_steps: int,
791
  cfg_scale: float,
792
  width: int,
793
  height: int,
 
 
794
  use_lyra: bool,
795
  lyra_strength: float,
796
  seed: int,
 
806
  progress((step + 1) / total, desc=desc)
807
 
808
  try:
809
+ pipeline = get_pipeline(model_choice, scheduler_name)
810
 
811
+ if not use_lyra or pipeline.lyra_loader is None:
 
 
 
 
 
 
 
812
  progress(0.05, desc="Generating...")
813
 
814
  image = pipeline(
 
818
  width=width,
819
  num_inference_steps=num_steps,
820
  guidance_scale=cfg_scale,
 
 
 
821
  seed=seed,
822
  use_lyra=False,
823
  clip_skip=clip_skip,
 
837
  width=width,
838
  num_inference_steps=num_steps,
839
  guidance_scale=cfg_scale,
 
 
 
840
  seed=seed,
841
  use_lyra=False,
842
  clip_skip=clip_skip,
843
  progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
844
  )
845
 
846
+ progress(0.5, desc="Loading Lyra + T5 (first run only)...")
847
 
848
  image_lyra = pipeline(
849
  prompt=prompt,
 
852
  width=width,
853
  num_inference_steps=num_steps,
854
  guidance_scale=cfg_scale,
 
 
 
855
  seed=seed,
856
  use_lyra=True,
857
  clip_skip=clip_skip,
 
865
 
866
  except Exception as e:
867
  print(f"❌ Generation failed: {e}")
868
+ import traceback
869
+ traceback.print_exc()
870
  raise e
871
 
872
 
 
879
 
880
  with gr.Blocks() as demo:
881
  gr.Markdown("""
882
+ # 🌙 Lyra/Illustrious XL Image Generation
883
 
884
  **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
885
 
 
 
886
  | Model | Architecture | Lyra Version | Best For |
887
  |-------|-------------|--------------|----------|
888
  | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
889
  | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |
 
 
 
 
 
 
890
 
891
+ **Lyra VAE** fuses CLIP + T5-XL embeddings using adaptive Cantor attention.
892
+ T5 and Lyra only load when you enable the Lyra checkbox!
893
  """)
894
 
895
  with gr.Row():
896
  with gr.Column(scale=1):
897
  prompt = gr.TextArea(
898
+ label="Prompt",
899
  value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
900
  lines=3
901
  )
902
 
903
  t5_summary = gr.TextArea(
904
+ label="T5 Summary (for Lyra)",
905
+ value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms",
906
  lines=2,
907
+ info="Natural language description for T5. Leave empty to use prompt."
908
  )
909
 
910
  negative_prompt = gr.TextArea(
911
  label="Negative Prompt",
912
+ value="lowres, bad anatomy, bad hands, text, error, worst quality, low quality",
913
  lines=2
914
  )
915
 
916
+ with gr.Row():
917
+ model_choice = gr.Dropdown(
918
+ label="Model",
919
+ choices=["Illustrious XL", "SDXL Base"],
920
+ value="Illustrious XL"
921
+ )
922
+
923
+ scheduler_name = gr.Dropdown(
924
+ label="Scheduler",
925
+ choices=SDXL_SCHEDULERS,
926
+ value=SCHEDULER_EULER_A
927
+ )
928
 
929
  clip_skip = gr.Slider(
930
  label="CLIP Skip",
931
+ minimum=1, maximum=4, value=2, step=1,
932
+ info="2 recommended for Illustrious"
 
 
 
933
  )
934
 
935
  use_lyra = gr.Checkbox(
936
+ label="Enable Lyra VAE (loads T5-XL on first use)",
937
+ value=False,
938
  info="Compare standard vs geometric fusion"
939
  )
940
 
941
  lyra_strength = gr.Slider(
942
  label="Lyra Blend Strength",
943
+ minimum=0.0, maximum=2.0, value=1.0, step=0.05,
944
+ info="0.0 = pure CLIP, 1.0 = pure Lyra"
 
 
 
945
  )
946
 
947
  with gr.Accordion("Generation Settings", open=True):
948
+ num_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=25, step=1)
949
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, value=7.0, step=0.5)
950
 
951
  with gr.Row():
952
+ width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=64)
953
+ height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=64)
954
 
955
+ seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, value=42, step=1)
956
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
957
 
958
  generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
959
 
960
  with gr.Column(scale=1):
961
  with gr.Row():
962
+ output_image_standard = gr.Image(label="Standard", type="pil")
963
+ output_image_lyra = gr.Image(label="Lyra Fusion 🎵", type="pil", visible=True)
 
 
 
 
 
 
 
964
 
965
  output_seed = gr.Number(label="Seed", precision=0)
966
 
967
  # Event handlers
968
  def on_lyra_toggle(enabled):
 
969
  if enabled:
970
  return {
971
  output_image_standard: gr.update(visible=True, label="Standard"),
 
977
  output_image_lyra: gr.update(visible=False)
978
  }
979
 
 
 
 
 
 
 
980
  use_lyra.change(
981
  fn=on_lyra_toggle,
982
  inputs=[use_lyra],
 
986
  generate_btn.click(
987
  fn=generate_image,
988
  inputs=[
989
+ prompt, t5_summary, negative_prompt, model_choice, scheduler_name,
990
+ clip_skip, num_steps, cfg_scale, width, height,
991
+ use_lyra, lyra_strength, seed, randomize_seed
992
  ],
993
  outputs=[output_image_standard, output_image_lyra, output_seed]
994
  )