"""
Lyra/Lune Flow-Matching Inference Space
Author: AbstractPhil
License: MIT
SD1.5 and SDXL-based flow matching with geometric crystalline architectures.
Supports Illustrious XL, standard SDXL, and SD1.5 variants.
Lyra VAE Versions:
- v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra
- v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2
Features:
- Lazy loading: T5 and Lyra only download when first used
- Multiple schedulers: Euler Ancestral, Euler, DPM++ 2M SDE, DPM++ 2M
- Integrated loader module for automatic version detection
"""
import os
import json
import torch
import gradio as gr
import numpy as np
from PIL import Image
from typing import Optional, Dict, Tuple, Union
import spaces
from safetensors.torch import load_file as load_safetensors
from diffusers import (
UNet2DConditionModel,
AutoencoderKL,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSDEScheduler,
)
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPTextModelWithProjection,
T5EncoderModel,
T5Tokenizer
)
from huggingface_hub import hf_hub_download
# Import Lyra VAE v1 (SD1.5) from geofractal
try:
from geofractal.model.vae.vae_lyra import MultiModalVAE as LyraV1, MultiModalVAEConfig as LyraV1Config
LYRA_V1_AVAILABLE = True
except ImportError:
print("⚠️ Lyra VAE v1 not available")
LYRA_V1_AVAILABLE = False
# Import Lyra VAE v2 (SDXL/Illustrious) from geofractal
try:
from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as LyraV2, MultiModalVAEConfig as LyraV2Config
LYRA_V2_AVAILABLE = True
except ImportError:
print("⚠️ Lyra VAE v2 not available")
LYRA_V2_AVAILABLE = False
# Import Lyra loader module
try:
from geofractal.model.vae.loader import load_vae_lyra, load_lyra_illustrious
LYRA_LOADER_AVAILABLE = True
except ImportError:
print("⚠️ Lyra loader module not available, using fallback")
LYRA_LOADER_AVAILABLE = False
# ============================================================================
# CONSTANTS
# ============================================================================
ARCH_SD15 = "sd15"
ARCH_SDXL = "sdxl"
# Scheduler names
SCHEDULER_EULER_A = "Euler Ancestral"
SCHEDULER_EULER = "Euler"
SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE"
SCHEDULER_DPM_2M = "DPM++ 2M"
SCHEDULER_CHOICES = [
SCHEDULER_EULER_A,
SCHEDULER_EULER,
SCHEDULER_DPM_2M_SDE,
SCHEDULER_DPM_2M,
]
# ComfyUI key prefixes for SDXL single-file checkpoints
COMFYUI_UNET_PREFIX = "model.diffusion_model."
COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer."
COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model."
COMFYUI_VAE_PREFIX = "first_stage_model."
# Lyra repos
LYRA_ILLUSTRIOUS_REPO = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious"
LYRA_SD15_REPO = "AbstractPhil/vae-lyra"
# T5 encoders: Lyra v2 was trained against google/flan-t5-xl; Lyra v1 uses google/flan-t5-base
T5_XL_MODEL = "google/flan-t5-xl"
T5_BASE_MODEL = "google/flan-t5-base"
# ============================================================================
# LAZY LOADERS
# ============================================================================
class LazyT5Encoder:
"""Lazy loader for T5 encoder - only downloads/loads when first accessed."""
def __init__(self, model_name: str = T5_XL_MODEL, device: str = "cuda", dtype=torch.float16):
self.model_name = model_name
self.device = device
self.dtype = dtype
self._encoder = None
self._tokenizer = None
self._loaded = False
@property
def encoder(self) -> T5EncoderModel:
if self._encoder is None:
print(f"📥 Lazy loading T5 encoder: {self.model_name}...")
self._encoder = T5EncoderModel.from_pretrained(
self.model_name,
torch_dtype=self.dtype
).to(self.device)
self._encoder.eval()
print(f"✓ T5 encoder loaded ({sum(p.numel() for p in self._encoder.parameters())/1e6:.1f}M params)")
self._loaded = True
return self._encoder
@property
def tokenizer(self) -> T5Tokenizer:
if self._tokenizer is None:
print(f"📥 Loading T5 tokenizer: {self.model_name}...")
self._tokenizer = T5Tokenizer.from_pretrained(self.model_name)
print("✓ T5 tokenizer loaded")
return self._tokenizer
@property
def is_loaded(self) -> bool:
return self._loaded
def unload(self):
"""Free VRAM by unloading the encoder."""
if self._encoder is not None:
del self._encoder
self._encoder = None
self._loaded = False
torch.cuda.empty_cache()
print("🗑️ T5 encoder unloaded")
class LazyLyraModel:
"""Lazy loader for Lyra VAE - only downloads/loads when first accessed.
Exposes config with modality_seq_lens for proper tokenization lengths.
"""
def __init__(
self,
repo_id: str = LYRA_ILLUSTRIOUS_REPO,
device: str = "cuda",
checkpoint: Optional[str] = None
):
self.repo_id = repo_id
self.device = device
self.checkpoint = checkpoint
self._model = None
self._info = None
self._config = None
self._loaded = False
# Pre-fetch config without loading model (lightweight)
self._prefetch_config()
def _prefetch_config(self):
"""Fetch config.json to get sequence lengths without loading the full model."""
try:
config_path = hf_hub_download(
repo_id=self.repo_id,
filename="config.json",
repo_type="model"
)
with open(config_path, 'r') as f:
self._config = json.load(f)
# Detect version from config
is_v2 = 'modality_seq_lens' in self._config or 'binding_config' in self._config
version = "v2" if is_v2 else "v1"
print(f"📋 Lyra config prefetched: {self.repo_id} ({version})")
if is_v2:
print(f" Sequence lengths: {self._config.get('modality_seq_lens', {})}")
else:
print(f" Sequence length: {self._config.get('seq_len', 77)}")
except Exception as e:
print(f"⚠️ Could not prefetch Lyra config: {e}")
# Detect version from repo name and use appropriate defaults
is_illustrious = 'illustrious' in self.repo_id.lower() or 'xl' in self.repo_id.lower()
if is_illustrious:
# v2 defaults for SDXL/Illustrious
self._config = {
"modality_dims": {
"clip_l": 768,
"clip_g": 1280,
"t5_xl_l": 2048,
"t5_xl_g": 2048
},
"modality_seq_lens": {
"clip_l": 77,
"clip_g": 77,
"t5_xl_l": 512,
"t5_xl_g": 512
},
"fusion_strategy": "adaptive_cantor",
"latent_dim": 2048
}
else:
# v1 defaults for SD1.5
self._config = {
"modality_dims": {
"clip": 768,
"t5": 768
},
"seq_len": 77,
"fusion_strategy": "cantor",
"latent_dim": 768
}
@property
def config(self) -> Dict:
"""Get model config (available before full model load)."""
return self._config or {}
@property
def modality_seq_lens(self) -> Dict[str, int]:
"""Get sequence lengths for each modality.
Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats.
"""
# v2 format: modality_seq_lens dict
if 'modality_seq_lens' in self.config:
return self.config['modality_seq_lens']
# v1 format: derive from single seq_len
seq_len = self.config.get('seq_len', 77)
modality_dims = self.config.get('modality_dims', {})
# Return seq_len for all modalities in v1
return {name: seq_len for name in modality_dims.keys()}
@property
def t5_max_length(self) -> int:
"""Get T5 max sequence length from config.
Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats.
"""
# v2 format: modality_seq_lens dict
if 'modality_seq_lens' in self.config:
seq_lens = self.config['modality_seq_lens']
return seq_lens.get('t5_xl_l', seq_lens.get('t5_xl_g', 512))
# v1 format: single seq_len
return self.config.get('seq_len', 77)
@property
def clip_max_length(self) -> int:
"""Get CLIP max sequence length from config.
Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats.
"""
# v2 format: modality_seq_lens dict
if 'modality_seq_lens' in self.config:
seq_lens = self.config['modality_seq_lens']
return seq_lens.get('clip_l', 77)
# v1 format: single seq_len (same for all modalities)
return self.config.get('seq_len', 77)
@property
def model(self):
if self._model is None:
print(f"📥 Lazy loading Lyra VAE: {self.repo_id}...")
if LYRA_LOADER_AVAILABLE:
# Use the loader module
self._model, self._info = load_vae_lyra(
self.repo_id,
checkpoint=self.checkpoint,
device=self.device,
return_info=True
)
# Update config from loaded info
if self._info and 'config' in self._info:
self._config = self._info['config']
else:
# Fallback to manual loading
self._model = self._load_fallback()
self._info = {"repo_id": self.repo_id, "version": "v2", "config": self._config}
self._model.eval()
self._loaded = True
print(f"✓ Lyra VAE loaded")
return self._model
@property
def info(self) -> Optional[Dict]:
if self._info is None:
return {"repo_id": self.repo_id, "config": self._config}
return self._info
@property
def is_loaded(self) -> bool:
return self._loaded
def _load_fallback(self):
"""Fallback loading if loader module not available."""
if not LYRA_V2_AVAILABLE:
raise ImportError("Lyra VAE v2 not available")
# Config already prefetched
config_dict = self._config
# Use provided checkpoint or find one
if self.checkpoint and self.checkpoint.strip():
checkpoint_file = self.checkpoint.strip()
# Add weights/ prefix if not present and file doesn't exist at root
if not checkpoint_file.startswith('weights/'):
checkpoint_file = f"weights/{checkpoint_file}"
print(f"[Lyra] Using specified checkpoint: {checkpoint_file}")
else:
# Find checkpoint automatically
from huggingface_hub import list_repo_files
repo_files = list_repo_files(self.repo_id, repo_type="model")
checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')]
# Prefer weights/ folder
weights_files = [f for f in checkpoint_files if f.startswith('weights/')]
if weights_files:
                checkpoint_file = sorted(weights_files)[-1]  # Lexicographically last (intended as the latest checkpoint)
elif checkpoint_files:
checkpoint_file = checkpoint_files[0]
else:
raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
print(f"[Lyra] Auto-selected checkpoint: {checkpoint_file}")
checkpoint_path = hf_hub_download(
repo_id=self.repo_id,
filename=checkpoint_file,
repo_type="model"
)
# Load weights
if checkpoint_file.endswith('.safetensors'):
state_dict = load_safetensors(checkpoint_path, device="cpu")
else:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint.get('model_state_dict', checkpoint)
# Build config with all fields from prefetched config
vae_config = LyraV2Config(
modality_dims=config_dict.get('modality_dims'),
modality_seq_lens=config_dict.get('modality_seq_lens'),
binding_config=config_dict.get('binding_config'),
latent_dim=config_dict.get('latent_dim', 2048),
hidden_dim=config_dict.get('hidden_dim', 2048),
fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
encoder_layers=config_dict.get('encoder_layers', 3),
decoder_layers=config_dict.get('decoder_layers', 3),
fusion_heads=config_dict.get('fusion_heads', 8),
cantor_depth=config_dict.get('cantor_depth', 8),
cantor_local_window=config_dict.get('cantor_local_window', 3),
alpha_init=config_dict.get('alpha_init', 1.0),
beta_init=config_dict.get('beta_init', 0.3),
)
model = LyraV2(vae_config)
model.load_state_dict(state_dict, strict=False)
model.to(self.device)
return model
def unload(self):
"""Free VRAM by unloading the model."""
if self._model is not None:
del self._model
self._model = None
self._info = None
self._loaded = False
torch.cuda.empty_cache()
print("🗑️ Lyra VAE unloaded")
# ============================================================================
# SCHEDULER FACTORY
# ============================================================================
def get_scheduler(
scheduler_name: str,
config_source: str = "stabilityai/stable-diffusion-xl-base-1.0",
is_sdxl: bool = True
):
"""Create scheduler by name.
Args:
scheduler_name: One of SCHEDULER_CHOICES
config_source: HF repo to load scheduler config from
        is_sdxl: Whether this is for SDXL (currently unused; reserved for SDXL-specific defaults)
Returns:
Configured scheduler instance
"""
subfolder = "scheduler"
if scheduler_name == SCHEDULER_EULER_A:
return EulerAncestralDiscreteScheduler.from_pretrained(
config_source,
subfolder=subfolder
)
elif scheduler_name == SCHEDULER_EULER:
return EulerDiscreteScheduler.from_pretrained(
config_source,
subfolder=subfolder
)
elif scheduler_name == SCHEDULER_DPM_2M_SDE:
# DPM++ 2M SDE - good for detailed images
return DPMSolverSDEScheduler.from_pretrained(
config_source,
subfolder=subfolder,
algorithm_type="sde-dpmsolver++",
solver_order=2,
use_karras_sigmas=True,
)
elif scheduler_name == SCHEDULER_DPM_2M:
        # DPM++ 2M - fast with good quality
return DPMSolverMultistepScheduler.from_pretrained(
config_source,
subfolder=subfolder,
algorithm_type="dpmsolver++",
solver_order=2,
use_karras_sigmas=True,
)
else:
print(f"⚠️ Unknown scheduler '{scheduler_name}', defaulting to Euler Ancestral")
return EulerAncestralDiscreteScheduler.from_pretrained(
config_source,
subfolder=subfolder
)
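# Example usage (sketch): schedulers are rebuilt from the SDXL base scheduler config, so
# switching samplers never requires reloading the UNet or text encoders.
#
#   scheduler = get_scheduler(SCHEDULER_DPM_2M_SDE)
#   scheduler.set_timesteps(25, device="cuda")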
# ============================================================================
# UTILITIES
# ============================================================================
def get_clip_hidden_state(
model_output,
clip_skip: int = 1,
output_hidden_states: bool = True
) -> torch.Tensor:
"""Extract hidden state with clip_skip support."""
if clip_skip == 1 or not output_hidden_states:
return model_output.last_hidden_state
if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
return model_output.hidden_states[-clip_skip]
return model_output.last_hidden_state
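# Worked example (sketch, hypothetical names): for an Illustrious-style "CLIP skip 2" setting,
#
#   out = text_encoder(input_ids, output_hidden_states=True)
#   hidden = get_clip_hidden_state(out, clip_skip=2, output_hidden_states=True)
#
# returns out.hidden_states[-2] (the penultimate transformer layer), while clip_skip=1
# falls back to out.last_hidden_state.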
# ============================================================================
# SDXL PIPELINE
# ============================================================================
class SDXLFlowMatchingPipeline:
"""Pipeline for SDXL-based flow-matching inference with dual CLIP encoders.
Uses lazy loading for T5 and Lyra - they're only downloaded when actually used.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler,
device: str = "cuda",
t5_loader: Optional[LazyT5Encoder] = None,
lyra_loader: Optional[LazyLyraModel] = None,
clip_skip: int = 1
):
self.vae = vae
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.unet = unet
self.scheduler = scheduler
self.device = device
# Lazy loaders for Lyra components
self.t5_loader = t5_loader
self.lyra_loader = lyra_loader
# Settings
self.clip_skip = clip_skip
        self.vae_scale_factor = 0.13025  # SDXL latent scaling factor; latents are divided by this before VAE decode
self.arch = ARCH_SDXL
# Track current scheduler name for UI
self._scheduler_name = SCHEDULER_EULER_A
def set_scheduler(self, scheduler_name: str):
"""Switch scheduler without reloading model."""
if scheduler_name != self._scheduler_name:
self.scheduler = get_scheduler(
scheduler_name,
config_source="stabilityai/stable-diffusion-xl-base-1.0",
is_sdxl=True
)
self._scheduler_name = scheduler_name
print(f"✓ Scheduler changed to: {scheduler_name}")
@property
def t5_encoder(self) -> Optional[T5EncoderModel]:
"""Access T5 encoder (triggers lazy load if needed)."""
return self.t5_loader.encoder if self.t5_loader else None
@property
def t5_tokenizer(self) -> Optional[T5Tokenizer]:
"""Access T5 tokenizer (triggers lazy load if needed)."""
return self.t5_loader.tokenizer if self.t5_loader else None
@property
def lyra_model(self):
"""Access Lyra model (triggers lazy load if needed)."""
return self.lyra_loader.model if self.lyra_loader else None
@property
def lyra_available(self) -> bool:
"""Check if Lyra components are configured (not necessarily loaded)."""
return self.t5_loader is not None and self.lyra_loader is not None
def encode_prompt(
self,
prompt: str,
negative_prompt: str = "",
clip_skip: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompts using dual CLIP encoders for SDXL."""
# CLIP-L encoding
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
output_hidden_states = clip_skip > 1
clip_l_output = self.text_encoder(
text_input_ids,
output_hidden_states=output_hidden_states
)
prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states)
# CLIP-G encoding
text_inputs_2 = self.tokenizer_2(
prompt,
padding="max_length",
max_length=self.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids_2 = text_inputs_2.input_ids.to(self.device)
with torch.no_grad():
clip_g_output = self.text_encoder_2(
text_input_ids_2,
output_hidden_states=output_hidden_states
)
prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
pooled_prompt_embeds = clip_g_output.text_embeds
# Concatenate CLIP-L and CLIP-G embeddings
prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
# Negative prompt
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
uncond_inputs_2 = self.tokenizer_2(
negative_prompt,
padding="max_length",
max_length=self.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device)
with torch.no_grad():
uncond_output_l = self.text_encoder(
uncond_input_ids,
output_hidden_states=output_hidden_states
)
negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states)
uncond_output_g = self.text_encoder_2(
uncond_input_ids_2,
output_hidden_states=output_hidden_states
)
negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states)
negative_pooled = uncond_output_g.text_embeds
negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled = torch.zeros_like(pooled_prompt_embeds)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled
def encode_prompt_lyra(
self,
prompt: str,
negative_prompt: str = "",
clip_skip: int = 1,
t5_summary: str = "",
lyra_strength: float = 0.3,
use_separator: bool = True,
clip_include_summary: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompts using Lyra VAE v2 fusion (CLIP + T5).
CLIP encoders receive tags only (prompt field).
T5 encoder receives tags + separator + summary.
Args:
prompt: Tags/keywords for CLIP encoding
negative_prompt: Negative tags
clip_skip: CLIP skip layers
t5_summary: Natural language summary for T5
            lyra_strength: Blend factor (0 = pure CLIP, 1 = pure Lyra reconstruction; values above 1 extrapolate)
use_separator: If True, use ¶ separator between tags and summary
clip_include_summary: If True, append summary to CLIP input (default False)
This triggers lazy loading of T5 and Lyra if not already loaded.
Uses sequence lengths from Lyra config for proper tokenization.
"""
if not self.lyra_available:
raise ValueError("Lyra VAE components not configured")
# Get sequence lengths from Lyra config (available before full load)
t5_max_length = self.lyra_loader.t5_max_length # 512 for Illustrious
clip_max_length = self.lyra_loader.clip_max_length # 77 for Illustrious
print(f"[Lyra] Using sequence lengths: CLIP={clip_max_length}, T5={t5_max_length}")
# Access properties triggers lazy load
t5_encoder = self.t5_encoder
t5_tokenizer = self.t5_tokenizer
lyra_model = self.lyra_model
# === CLIP ENCODING ===
# CLIP sees tags only (unless clip_include_summary is True)
if clip_include_summary and t5_summary.strip():
clip_prompt = f"{prompt} {t5_summary}"
else:
clip_prompt = prompt
# Get CLIP embeddings with tags only
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
clip_prompt, negative_prompt, clip_skip
)
# === T5 ENCODING ===
# T5 sees tags + separator + summary (or tags + summary if no separator)
SUMMARY_SEPARATOR = "¶"
if t5_summary.strip():
if use_separator:
t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
else:
t5_prompt = f"{prompt} {t5_summary}"
else:
# No summary provided - T5 just sees the tags
t5_prompt = prompt
print(f"[Lyra] CLIP input: {clip_prompt[:80]}...")
print(f"[Lyra] T5 input: {t5_prompt[:80]}...")
# Get T5 embeddings with config-specified max_length
t5_inputs = t5_tokenizer(
t5_prompt,
max_length=t5_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds = t5_encoder(**t5_inputs).last_hidden_state
# === LYRA FUSION ===
clip_l_dim = 768
clip_g_dim = 1280
clip_l_embeds = prompt_embeds[..., :clip_l_dim]
clip_g_embeds = prompt_embeds[..., clip_l_dim:]
with torch.no_grad():
modality_inputs = {
'clip_l': clip_l_embeds.float(),
'clip_g': clip_g_embeds.float(),
't5_xl_l': t5_embeds.float(),
't5_xl_g': t5_embeds.float()
}
reconstructions, mu, logvar, _ = lyra_model(
modality_inputs,
target_modalities=['clip_l', 'clip_g']
)
lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
# Normalize reconstructions to match input statistics
clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5:
lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8)
lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean()
if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5:
lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
# Blend original CLIP with Lyra reconstruction
fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
# === NEGATIVE PROMPT ===
# Negative uses same logic: CLIP sees negative tags only
if negative_prompt:
neg_strength = lyra_strength * 0.5 # Less aggressive for negative
# T5 negative: tags only (no summary for negative)
t5_neg_prompt = negative_prompt
t5_inputs_neg = t5_tokenizer(
t5_neg_prompt,
max_length=t5_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds_neg = t5_encoder(**t5_inputs_neg).last_hidden_state
neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
modality_inputs_neg = {
'clip_l': neg_clip_l.float(),
'clip_g': neg_clip_g.float(),
't5_xl_l': t5_embeds_neg.float(),
't5_xl_g': t5_embeds_neg.float()
}
recon_neg, _, _, _ = lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])
lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype)
lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype)
# Normalize
neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8)
neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8)
if neg_l_ratio > 2.0 or neg_l_ratio < 0.5:
lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8)
lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean()
if neg_g_ratio > 2.0 or neg_g_ratio < 0.5:
lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8)
lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean()
fused_neg_l = (1 - neg_strength) * neg_clip_l + neg_strength * lyra_neg_l
fused_neg_g = (1 - neg_strength) * neg_clip_g + neg_strength * lyra_neg_g
negative_prompt_embeds_fused = torch.cat([fused_neg_l, fused_neg_g], dim=-1)
else:
negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)
return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled
def _get_add_time_ids(
self,
original_size: Tuple[int, int],
crops_coords_top_left: Tuple[int, int],
target_size: Tuple[int, int],
dtype: torch.dtype
) -> torch.Tensor:
"""Create time embedding IDs for SDXL."""
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
return add_time_ids
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 20,
guidance_scale: float = 7.5,
shift: float = 0.0,
use_flow_matching: bool = False,
prediction_type: str = "epsilon",
seed: Optional[int] = None,
use_lyra: bool = False,
clip_skip: int = 1,
t5_summary: str = "",
lyra_strength: float = 1.0,
use_separator: bool = True,
clip_include_summary: bool = False,
progress_callback=None
):
"""Generate image using SDXL architecture.
Args:
prompt: Tags/keywords for image generation
negative_prompt: Negative tags
t5_summary: Natural language summary (T5 only, unless clip_include_summary=True)
use_separator: Use ¶ separator between tags and summary in T5 input
clip_include_summary: If True, append summary to CLIP input (default False)
seed: Random seed for reproducibility (must be int)
"""
# Create generator with seed for deterministic generation
if seed is not None:
seed = int(seed)
generator = torch.Generator(device=self.device).manual_seed(seed)
print(f"[SDXL Pipeline] Using seed: {seed}")
else:
generator = None
print("[SDXL Pipeline] No seed provided, using random")
# Encode prompts (Lyra triggers lazy load only if use_lyra=True)
if use_lyra and self.lyra_available:
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
prompt, negative_prompt, clip_skip, t5_summary, lyra_strength,
use_separator=use_separator,
clip_include_summary=clip_include_summary
)
else:
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
prompt, negative_prompt, clip_skip
)
# Prepare latents
latent_channels = 4
latent_height = height // 8
latent_width = width // 8
latents = torch.randn(
(1, latent_channels, latent_height, latent_width),
generator=generator,
device=self.device,
dtype=torch.float16
)
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
if not use_flow_matching:
latents = latents * self.scheduler.init_noise_sigma
# Prepare added time embeddings for SDXL
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=torch.float16
)
negative_add_time_ids = add_time_ids
# Denoising loop
for i, t in enumerate(timesteps):
if progress_callback:
progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
if use_flow_matching and shift > 0:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
scaling = torch.sqrt(1 + sigma_shifted ** 2)
latent_model_input = latent_model_input / scaling
else:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
if guidance_scale > 1.0:
text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
add_text_embeds = torch.cat([negative_pooled, pooled])
add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids])
else:
text_embeds = prompt_embeds
add_text_embeds = pooled
add_time_ids_input = add_time_ids
added_cond_kwargs = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids_input
}
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeds,
added_cond_kwargs=added_cond_kwargs,
return_dict=False
)[0]
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if use_flow_matching:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
if prediction_type == "v_prediction":
v_pred = noise_pred
alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
sigma_t = sigma_shifted
noise_pred = alpha_t * v_pred + sigma_t * latents
dt = -1.0 / num_inference_steps
latents = latents + dt * noise_pred
else:
# Pass generator for deterministic ancestral/SDE sampling
latents = self.scheduler.step(
noise_pred, t, latents, generator=generator, return_dict=False
)[0]
# Decode
latents = latents / self.vae_scale_factor
with torch.no_grad():
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image[0])
return image
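# Note on the flow-matching shift used in __call__ above (worked example, not a derivation):
# the code maps sigma = t / 1000 to sigma' = (shift * sigma) / (1 + (shift - 1) * sigma),
# which keeps sigma = 0 and sigma = 1 fixed while pushing intermediate steps toward higher
# noise for shift > 1. For example, shift = 2.5 and sigma = 0.5 give
# sigma' = 1.25 / 1.75 ≈ 0.714.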
# ============================================================================
# SD1.5 PIPELINE
# ============================================================================
class SD15FlowMatchingPipeline:
"""Pipeline for SD1.5-based flow-matching inference."""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler,
device: str = "cuda",
t5_loader: Optional[LazyT5Encoder] = None,
lyra_loader: Optional[LazyLyraModel] = None,
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
self.device = device
self.t5_loader = t5_loader
self.lyra_loader = lyra_loader
        self.vae_scale_factor = 0.18215  # SD1.5 latent scaling factor; latents are divided by this before VAE decode
self.arch = ARCH_SD15
self.is_lune_model = False
@property
def t5_encoder(self):
return self.t5_loader.encoder if self.t5_loader else None
@property
def t5_tokenizer(self):
return self.t5_loader.tokenizer if self.t5_loader else None
@property
def lyra_model(self):
return self.lyra_loader.model if self.lyra_loader else None
@property
def lyra_available(self) -> bool:
return self.t5_loader is not None and self.lyra_loader is not None
def encode_prompt(self, prompt: str, negative_prompt: str = ""):
"""Encode text prompts to embeddings."""
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
prompt_embeds = self.text_encoder(text_input_ids)[0]
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
with torch.no_grad():
negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0]
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
return prompt_embeds, negative_prompt_embeds
def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
"""Encode using Lyra VAE v1 (CLIP + T5 fusion).
Uses sequence lengths from Lyra config for proper tokenization.
"""
if not self.lyra_available:
raise ValueError("Lyra VAE components not configured")
# Get sequence length from config (v1 uses same length for clip and t5)
# Default to 77 for SD1.5/v1
t5_max_length = self.lyra_loader.config.get('seq_len', 77)
print(f"[Lyra v1] Using sequence length: {t5_max_length}")
t5_encoder = self.t5_encoder
t5_tokenizer = self.t5_tokenizer
lyra_model = self.lyra_model
# CLIP
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
clip_embeds = self.text_encoder(text_input_ids)[0]
# T5 with config-specified max_length
t5_inputs = t5_tokenizer(
prompt,
max_length=t5_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds = t5_encoder(**t5_inputs).last_hidden_state
# Fuse
modality_inputs = {'clip': clip_embeds, 't5': t5_embeds}
with torch.no_grad():
reconstructions, mu, logvar = lyra_model(
modality_inputs,
target_modalities=['clip']
)
prompt_embeds = reconstructions['clip']
# Negative
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
with torch.no_grad():
clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]
t5_inputs_uncond = t5_tokenizer(
negative_prompt,
max_length=t5_max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds_uncond = t5_encoder(**t5_inputs_uncond).last_hidden_state
modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond}
with torch.no_grad():
reconstructions_uncond, _, _ = lyra_model(
modality_inputs_uncond,
target_modalities=['clip']
)
negative_prompt_embeds = reconstructions_uncond['clip']
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
return prompt_embeds, negative_prompt_embeds
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
height: int = 512,
width: int = 512,
num_inference_steps: int = 20,
guidance_scale: float = 7.5,
shift: float = 2.5,
use_flow_matching: bool = True,
prediction_type: str = "epsilon",
seed: Optional[int] = None,
use_lyra: bool = False,
clip_skip: int = 1,
t5_summary: str = "",
lyra_strength: float = 1.0,
progress_callback=None
):
"""Generate image."""
# Create generator with seed for deterministic generation
if seed is not None:
seed = int(seed)
generator = torch.Generator(device=self.device).manual_seed(seed)
print(f"[SD1.5 Pipeline] Using seed: {seed}")
else:
generator = None
print("[SD1.5 Pipeline] No seed provided, using random")
if use_lyra and self.lyra_available:
prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt)
latent_channels = 4
latent_height = height // 8
latent_width = width // 8
latents = torch.randn(
(1, latent_channels, latent_height, latent_width),
generator=generator,
device=self.device,
dtype=torch.float32
)
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
if not use_flow_matching:
latents = latents * self.scheduler.init_noise_sigma
for i, t in enumerate(timesteps):
if progress_callback:
progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
if use_flow_matching and shift > 0:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
scaling = torch.sqrt(1 + sigma_shifted ** 2)
latent_model_input = latent_model_input / scaling
else:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeds,
return_dict=False
)[0]
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if use_flow_matching:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
if prediction_type == "v_prediction":
v_pred = noise_pred
alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
sigma_t = sigma_shifted
noise_pred = alpha_t * v_pred + sigma_t * latents
dt = -1.0 / num_inference_steps
latents = latents + dt * noise_pred
else:
# Pass generator for deterministic ancestral/SDE sampling
latents = self.scheduler.step(
noise_pred, t, latents, generator=generator, return_dict=False
)[0]
latents = latents / self.vae_scale_factor
        if self.is_lune_model:
            latents = latents * 5.52  # Empirical rescale applied only to Lune latents before decoding
with torch.no_grad():
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image[0])
return image
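# Example usage (sketch): the Lune variant of this pipeline is normally called with flow
# matching enabled, a non-zero shift, and v-prediction, mirroring the defaults applied by
# generate_image() for "Flow-Lune". initialize_pipeline() is defined further below.
#
#   pipe = initialize_pipeline("Flow-Lune (SD1.5)", device="cuda")
#   image = pipe(
#       prompt="a lighthouse at dusk",
#       use_flow_matching=True,
#       shift=2.5,
#       prediction_type="v_prediction",
#       seed=42,
#   )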
# ============================================================================
# MODEL LOADERS
# ============================================================================
def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
"""Load Lune checkpoint from .pt file."""
print(f"📥 Downloading: {repo_id}/{filename}")
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"🏗️ Initializing SD1.5 UNet...")
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
torch_dtype=torch.float32
)
student_state_dict = checkpoint["student"]
cleaned_dict = {}
for key, value in student_state_dict.items():
if key.startswith("unet."):
cleaned_dict[key[5:]] = value
else:
cleaned_dict[key] = value
unet.load_state_dict(cleaned_dict, strict=False)
step = checkpoint.get("gstep", "unknown")
print(f"✅ Loaded Lune from step {step}")
return unet.to(device)
def load_illustrious_xl(
repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
filename: str = "",
device: str = "cuda"
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
"""Load Illustrious XL from single safetensors file."""
from diffusers import StableDiffusionXLPipeline
# Default checkpoint if none specified
if not filename or not filename.strip():
filename = "illustriousXL_v01.safetensors"
print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
print(f"✓ Downloaded: {checkpoint_path}")
print("📦 Loading with StableDiffusionXLPipeline.from_single_file()...")
pipe = StableDiffusionXLPipeline.from_single_file(
checkpoint_path,
torch_dtype=torch.float16,
use_safetensors=True,
)
unet = pipe.unet.to(device)
vae = pipe.vae.to(device)
text_encoder = pipe.text_encoder.to(device)
text_encoder_2 = pipe.text_encoder_2.to(device)
tokenizer = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
del pipe
torch.cuda.empty_cache()
print("✅ Illustrious XL loaded!")
print(f" UNet params: {sum(p.numel() for p in unet.parameters()):,}")
print(f" VAE params: {sum(p.numel() for p in vae.parameters()):,}")
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
def load_sdxl_base(device: str = "cuda"):
"""Load standard SDXL base model."""
print("📥 Loading SDXL Base 1.0...")
unet = UNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="unet",
torch_dtype=torch.float16
).to(device)
# Use fp16-fix VAE to avoid NaN issues with SDXL's original VAE in fp16
print(" Using madebyollin/sdxl-vae-fp16-fix for stable fp16 decoding...")
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix",
torch_dtype=torch.float16
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="text_encoder",
torch_dtype=torch.float16
).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="text_encoder_2",
torch_dtype=torch.float16
).to(device)
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="tokenizer"
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="tokenizer_2"
)
print("✅ SDXL Base loaded!")
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
# ============================================================================
# PIPELINE INITIALIZATION
# ============================================================================
def initialize_pipeline(model_choice: str, device: str = "cuda", checkpoint: str = "", lyra_checkpoint: str = ""):
"""Initialize the complete pipeline based on model choice.
Uses lazy loading for T5 and Lyra - they won't be downloaded until first use.
Args:
model_choice: Model selection from dropdown
device: Target device
checkpoint: Optional custom checkpoint filename (e.g., "my_model.safetensors")
lyra_checkpoint: Optional custom Lyra VAE checkpoint filename
"""
print(f"🚀 Initializing {model_choice} pipeline...")
if checkpoint:
print(f" Custom model checkpoint: {checkpoint}")
if lyra_checkpoint:
print(f" Custom Lyra checkpoint: {lyra_checkpoint}")
is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
is_lune = "Lune" in model_choice
if is_sdxl:
# SDXL-based models
if "Illustrious" in model_choice:
unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(
device=device,
filename=checkpoint
)
else:
unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
# Create LAZY loaders for T5 and Lyra (no download yet!)
print("📋 Configuring lazy loaders for T5-XL and Lyra VAE (will download on first use)")
t5_loader = LazyT5Encoder(
model_name=T5_XL_MODEL, # google/flan-t5-xl
device=device,
dtype=torch.float16
)
lyra_loader = LazyLyraModel(
repo_id=LYRA_ILLUSTRIOUS_REPO,
device=device,
checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None
)
# Default scheduler: Euler Ancestral
scheduler = get_scheduler(SCHEDULER_EULER_A, is_sdxl=True)
pipeline = SDXLFlowMatchingPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
device=device,
t5_loader=t5_loader,
lyra_loader=lyra_loader,
clip_skip=1
)
else:
# SD1.5-based models
vae = AutoencoderKL.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="vae",
torch_dtype=torch.float32
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14",
torch_dtype=torch.float32
).to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Lazy loaders for SD1.5 Lyra (T5-base)
print("📋 Configuring lazy loaders for T5-base and Lyra VAE v1 (will download on first use)")
t5_loader = LazyT5Encoder(
model_name=T5_BASE_MODEL, # google/flan-t5-base
device=device,
dtype=torch.float32
)
lyra_loader = LazyLyraModel(
repo_id=LYRA_SD15_REPO,
device=device,
checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None
)
# Load UNet
if is_lune:
repo_id = "AbstractPhil/sd15-flow-lune"
# Use custom checkpoint or default
if checkpoint and checkpoint.strip():
filename = checkpoint
else:
filename = "sd15_flow_lune_e34_s34000.pt"
unet = load_lune_checkpoint(repo_id, filename, device)
else:
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
torch_dtype=torch.float32
).to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="scheduler"
)
pipeline = SD15FlowMatchingPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
device=device,
t5_loader=t5_loader,
lyra_loader=lyra_loader,
)
pipeline.is_lune_model = is_lune
print("✅ Pipeline initialized! (T5 and Lyra will load on first use)")
return pipeline
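# Example usage (sketch): outside Gradio, a pipeline can be built and called directly.
# "Illustrious XL" is one of the dropdown choices defined in create_demo() below.
#
#   pipe = initialize_pipeline("Illustrious XL", device="cuda")
#   image = pipe(
#       prompt="1girl, blue hair, school uniform",
#       negative_prompt="lowres, bad anatomy",
#       num_inference_steps=25,
#       guidance_scale=7.0,
#       seed=42,
#   )
#   image.save("out.png")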
# ============================================================================
# GLOBAL STATE
# ============================================================================
CURRENT_PIPELINE = None
CURRENT_MODEL = None
CURRENT_CHECKPOINT = None
CURRENT_LYRA_CHECKPOINT = None
def get_pipeline(model_choice: str, checkpoint: str = "", lyra_checkpoint: str = ""):
"""Get or create pipeline for selected model."""
global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_CHECKPOINT, CURRENT_LYRA_CHECKPOINT
# Normalize empty values
checkpoint = checkpoint.strip() if checkpoint else ""
lyra_checkpoint = lyra_checkpoint.strip() if lyra_checkpoint else ""
# Reinitialize if model or any checkpoint changed
if (CURRENT_PIPELINE is None or
CURRENT_MODEL != model_choice or
CURRENT_CHECKPOINT != checkpoint or
CURRENT_LYRA_CHECKPOINT != lyra_checkpoint):
CURRENT_PIPELINE = initialize_pipeline(
model_choice, device="cuda",
checkpoint=checkpoint,
lyra_checkpoint=lyra_checkpoint
)
CURRENT_MODEL = model_choice
CURRENT_CHECKPOINT = checkpoint
CURRENT_LYRA_CHECKPOINT = lyra_checkpoint
return CURRENT_PIPELINE
# ============================================================================
# INFERENCE
# ============================================================================
def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int:
"""Estimate GPU duration."""
base_time_per_step = 0.5 if is_sdxl else 0.3
resolution_factor = (width * height) / (512 * 512)
estimated = num_steps * base_time_per_step * resolution_factor
if use_lyra:
estimated *= 2
estimated += 10 # Extra time for lazy loading on first use
return int(estimated + 20)
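# Worked example (sketch): 25 steps at 1024x1024 on SDXL with Lyra enabled gives
# 25 * 0.5 * (1024*1024) / (512*512) = 50 s, doubled to 100 s for Lyra, +10 s for lazy
# loading, +20 s headroom -> ~130 s requested from ZeroGPU.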
@spaces.GPU(duration=lambda *args: estimate_duration(
args[8], args[10], args[11], args[14],
"SDXL" in args[3] or "Illustrious" in args[3]
))
def generate_image(
prompt: str,
t5_summary: str,
negative_prompt: str,
model_choice: str,
checkpoint: str,
lyra_checkpoint: str,
scheduler_choice: str,
clip_skip: int,
num_steps: int,
cfg_scale: float,
width: int,
height: int,
shift: float,
use_flow_matching: bool,
use_lyra: bool,
lyra_strength: float,
use_separator: bool,
clip_include_summary: bool,
seed: int,
randomize_seed: bool,
progress=gr.Progress()
):
"""Generate image with ZeroGPU support.
Args:
prompt: Tags/keywords (CLIP input)
t5_summary: Natural language summary (T5 input, unless clip_include_summary)
checkpoint: Custom model checkpoint filename (empty for default)
lyra_checkpoint: Custom Lyra VAE checkpoint filename (empty for default)
use_separator: Use ¶ separator between tags and summary
clip_include_summary: If True, CLIP also sees the summary
"""
# Ensure seed is an integer (Gradio sliders return floats)
seed = int(seed)
if randomize_seed:
seed = np.random.randint(0, 2**32 - 1)
print(f"🎲 Using seed: {seed} (randomize={randomize_seed})")
def progress_callback(step, total, desc):
progress((step + 1) / total, desc=desc)
try:
pipeline = get_pipeline(model_choice, checkpoint, lyra_checkpoint)
# Update scheduler if needed (SDXL only)
is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
if is_sdxl and hasattr(pipeline, 'set_scheduler'):
pipeline.set_scheduler(scheduler_choice)
prediction_type = "epsilon"
if not is_sdxl and "Lune" in model_choice:
prediction_type = "v_prediction"
if not use_lyra or not pipeline.lyra_available:
progress(0.05, desc="Generating...")
image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
shift=shift,
use_flow_matching=use_flow_matching,
prediction_type=prediction_type,
seed=seed,
use_lyra=False,
clip_skip=clip_skip,
progress_callback=progress_callback
)
progress(1.0, desc="Complete!")
return image, None, seed
else:
# Side-by-side comparison: SAME seed for both!
progress(0.05, desc="Generating standard...")
image_standard = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
shift=shift,
use_flow_matching=use_flow_matching,
prediction_type=prediction_type,
seed=seed, # Same seed
use_lyra=False,
clip_skip=clip_skip,
progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
)
progress(0.5, desc="Generating Lyra fusion (loading T5 + Lyra if needed)...")
image_lyra = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
shift=shift,
use_flow_matching=use_flow_matching,
prediction_type=prediction_type,
seed=seed, # Same seed for deterministic comparison
use_lyra=True,
clip_skip=clip_skip,
t5_summary=t5_summary,
lyra_strength=lyra_strength,
use_separator=use_separator,
clip_include_summary=clip_include_summary,
progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d)
)
progress(1.0, desc="Complete!")
return image_standard, image_lyra, seed
except Exception as e:
print(f"❌ Generation failed: {e}")
import traceback
traceback.print_exc()
raise e
# ============================================================================
# GRADIO UI
# ============================================================================
def create_demo():
"""Create Gradio interface."""
with gr.Blocks() as demo:
gr.Markdown("""
# 🌙 Lyra/Lune Flow-Matching Image Generation
**Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
Generate images using SD1.5 and SDXL-based models with geometric deep learning:
| Model | Architecture | Lyra Version | Best For |
|-------|-------------|--------------|----------|
| **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
| **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |
| **Flow-Lune** | SD1.5 | v1 (T5-base) | Fast flow matching (15-25 steps) |
| **SD1.5 Base** | SD1.5 | v1 (T5-base) | Baseline comparison |
**Lazy Loading**: T5 and Lyra VAE are only downloaded when you enable Lyra fusion!
""")
with gr.Row():
with gr.Column(scale=1):
prompt = gr.TextArea(
label="Prompt (Tags for CLIP)",
value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
lines=3,
info="CLIP encoders see these tags. T5 also sees these + the summary below."
)
t5_summary = gr.TextArea(
label="T5 Summary (Natural Language - T5 Only)",
value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky",
lines=2,
info="T5 sees: tags ¶ summary. CLIP sees: tags only (unless 'Include Summary in CLIP' is enabled)."
)
negative_prompt = gr.TextArea(
label="Negative Prompt",
value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality",
lines=2
)
model_choice = gr.Dropdown(
label="Model",
choices=[
"Illustrious XL",
"SDXL Base",
"Flow-Lune (SD1.5)",
"SD1.5 Base"
],
value="Illustrious XL"
)
with gr.Accordion("Advanced Options", open=False):
checkpoint = gr.Textbox(
label="Model Checkpoint (optional)",
value="",
placeholder="e.g., illustriousXL_v01.safetensors",
info="Leave empty for default. Illustrious: .safetensors, Lune: .pt"
)
lyra_checkpoint = gr.Textbox(
label="Lyra VAE Checkpoint (optional)",
value="weights/lyra_illustrious_step_12000.safetensors",
placeholder="e.g., lyra_e100_s50000.safetensors",
info="Leave empty for latest. Loaded from weights/ folder in Lyra repo."
)
scheduler_choice = gr.Dropdown(
label="Scheduler (SDXL only)",
choices=SCHEDULER_CHOICES,
value=SCHEDULER_EULER_A,
info="Euler Ancestral recommended for Illustrious"
)
clip_skip = gr.Slider(
label="CLIP Skip",
minimum=1,
maximum=4,
value=2,
step=1,
info="2 recommended for Illustrious, 1 for others"
)
use_lyra = gr.Checkbox(
label="Enable Lyra VAE (CLIP+T5 Fusion)",
value=True, # DEFAULT: ON
info="Enables lazy loading of T5 and Lyra on first use"
)
lyra_strength = gr.Slider(
label="Lyra Blend Strength",
minimum=0.0,
maximum=3.0,
value=1.0,
step=0.05,
info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction"
)
with gr.Accordion("Lyra Advanced Settings", open=False):
use_separator = gr.Checkbox(
label="Use ¶ Separator",
value=True,
info="Insert ¶ between tags and summary in T5 input"
)
clip_include_summary = gr.Checkbox(
label="Include Summary in CLIP",
value=False,
info="By default CLIP sees tags only. Enable to append summary to CLIP input."
)
with gr.Accordion("Generation Settings", open=True):
num_steps = gr.Slider(
label="Steps",
minimum=1,
maximum=50,
value=25,
step=1
)
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=20.0,
value=7.0,
step=0.5
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=1536,
value=1024,
step=64
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=1536,
value=1024,
step=64
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=2**32 - 1,
value=42, # DEFAULT: 42
step=1
)
randomize_seed = gr.Checkbox(
label="Randomize Seed",
value=False # DEFAULT: OFF for reproducibility
)
with gr.Accordion("Advanced (Flow Matching)", open=False):
use_flow_matching = gr.Checkbox(
label="Enable Flow Matching",
value=False,
info="Use flow matching ODE (for Lune only)"
)
shift = gr.Slider(
label="Shift",
minimum=0.0,
maximum=5.0,
value=0.0,
step=0.1,
info="Flow matching shift (0=disabled)"
)
generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
with gr.Column(scale=1):
with gr.Row():
output_image_standard = gr.Image(
label="Standard",
type="pil"
)
output_image_lyra = gr.Image(
label="Lyra Fusion 🎵",
type="pil",
visible=True # Visible by default since Lyra is on
)
output_seed = gr.Number(label="Seed Used", precision=0)
gr.Markdown("""
### Tips
- **Lazy Loading**: T5-XL (~3GB) and Lyra VAE only download when you enable Lyra
- **Illustrious XL**: Use CLIP skip 2, Euler Ancestral scheduler
- **Schedulers**: DPM++ 2M SDE for detail, Euler A for speed
- **Lyra v2**: Uses `google/flan-t5-xl` for richer semantics
- **Same Seed**: Both Standard and Lyra use the same seed for fair comparison
""")
# Event handlers
def on_model_change(model_name):
"""Update defaults based on model."""
if "Illustrious" in model_name:
return {
clip_skip: gr.update(value=2),
width: gr.update(value=1024),
height: gr.update(value=1024),
num_steps: gr.update(value=25),
use_flow_matching: gr.update(value=False),
shift: gr.update(value=0.0),
scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A)
}
elif "SDXL" in model_name:
return {
clip_skip: gr.update(value=1),
width: gr.update(value=1024),
height: gr.update(value=1024),
num_steps: gr.update(value=30),
use_flow_matching: gr.update(value=False),
shift: gr.update(value=0.0),
scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A)
}
elif "Lune" in model_name:
return {
clip_skip: gr.update(value=1),
width: gr.update(value=512),
height: gr.update(value=512),
num_steps: gr.update(value=20),
use_flow_matching: gr.update(value=True),
shift: gr.update(value=2.5),
scheduler_choice: gr.update(visible=False)
}
else: # SD1.5 Base
return {
clip_skip: gr.update(value=1),
width: gr.update(value=512),
height: gr.update(value=512),
num_steps: gr.update(value=30),
use_flow_matching: gr.update(value=False),
shift: gr.update(value=0.0),
scheduler_choice: gr.update(visible=False)
}
def on_lyra_toggle(enabled):
"""Show/hide Lyra comparison."""
if enabled:
return {
output_image_standard: gr.update(visible=True, label="Standard"),
output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵")
}
else:
return {
output_image_standard: gr.update(visible=True, label="Generated Image"),
output_image_lyra: gr.update(visible=False)
}
model_choice.change(
fn=on_model_change,
inputs=[model_choice],
outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift, scheduler_choice]
)
use_lyra.change(
fn=on_lyra_toggle,
inputs=[use_lyra],
outputs=[output_image_standard, output_image_lyra]
)
generate_btn.click(
fn=generate_image,
inputs=[
prompt, t5_summary, negative_prompt, model_choice, checkpoint, lyra_checkpoint,
scheduler_choice, clip_skip,
num_steps, cfg_scale, width, height, shift,
use_flow_matching, use_lyra, lyra_strength, use_separator, clip_include_summary,
seed, randomize_seed
],
outputs=[output_image_standard, output_image_lyra, output_seed]
)
return demo
# ============================================================================
# LAUNCH
# ============================================================================
if __name__ == "__main__":
demo = create_demo()
demo.queue(max_size=20)
demo.launch()