""" Lyra/Lune Flow-Matching Inference Space Author: AbstractPhil License: MIT SD1.5 and SDXL-based flow matching with geometric crystalline architectures. Supports Illustrious XL, standard SDXL, and SD1.5 variants. Lyra VAE Versions: - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra - v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2 Features: - Lazy loading: T5 and Lyra only download when first used - Multiple schedulers: Euler Ancestral, Euler, DPM++ 2M SDE, DPM++ 2M - Integrated loader module for automatic version detection """ import os import json import torch import gradio as gr import numpy as np from PIL import Image from typing import Optional, Dict, Tuple, Union import spaces from safetensors.torch import load_file as load_safetensors from diffusers import ( UNet2DConditionModel, AutoencoderKL, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler, ) from transformers import ( CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5EncoderModel, T5Tokenizer ) from huggingface_hub import hf_hub_download # Import Lyra VAE v1 (SD1.5) from geofractal try: from geofractal.model.vae.vae_lyra import MultiModalVAE as LyraV1, MultiModalVAEConfig as LyraV1Config LYRA_V1_AVAILABLE = True except ImportError: print("⚠️ Lyra VAE v1 not available") LYRA_V1_AVAILABLE = False # Import Lyra VAE v2 (SDXL/Illustrious) from geofractal try: from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as LyraV2, MultiModalVAEConfig as LyraV2Config LYRA_V2_AVAILABLE = True except ImportError: print("⚠️ Lyra VAE v2 not available") LYRA_V2_AVAILABLE = False # Import Lyra loader module try: from geofractal.model.vae.loader import load_vae_lyra, load_lyra_illustrious LYRA_LOADER_AVAILABLE = True except ImportError: print("⚠️ Lyra loader module not available, using fallback") LYRA_LOADER_AVAILABLE = False # ============================================================================ # CONSTANTS # ============================================================================ ARCH_SD15 = "sd15" ARCH_SDXL = "sdxl" # Scheduler names SCHEDULER_EULER_A = "Euler Ancestral" SCHEDULER_EULER = "Euler" SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE" SCHEDULER_DPM_2M = "DPM++ 2M" SCHEDULER_CHOICES = [ SCHEDULER_EULER_A, SCHEDULER_EULER, SCHEDULER_DPM_2M_SDE, SCHEDULER_DPM_2M, ] # ComfyUI key prefixes for SDXL single-file checkpoints COMFYUI_UNET_PREFIX = "model.diffusion_model." COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer." COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model." COMFYUI_VAE_PREFIX = "first_stage_model." 
# Lyra repos LYRA_ILLUSTRIOUS_REPO = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious" LYRA_SD15_REPO = "AbstractPhil/vae-lyra" # T5 model - use flan-t5-xl (what Lyra was trained on) T5_XL_MODEL = "google/flan-t5-xl" T5_BASE_MODEL = "google/flan-t5-base" # ============================================================================ # LAZY LOADERS # ============================================================================ class LazyT5Encoder: """Lazy loader for T5 encoder - only downloads/loads when first accessed.""" def __init__(self, model_name: str = T5_XL_MODEL, device: str = "cuda", dtype=torch.float16): self.model_name = model_name self.device = device self.dtype = dtype self._encoder = None self._tokenizer = None self._loaded = False @property def encoder(self) -> T5EncoderModel: if self._encoder is None: print(f"📥 Lazy loading T5 encoder: {self.model_name}...") self._encoder = T5EncoderModel.from_pretrained( self.model_name, torch_dtype=self.dtype ).to(self.device) self._encoder.eval() print(f"✓ T5 encoder loaded ({sum(p.numel() for p in self._encoder.parameters())/1e6:.1f}M params)") self._loaded = True return self._encoder @property def tokenizer(self) -> T5Tokenizer: if self._tokenizer is None: print(f"📥 Loading T5 tokenizer: {self.model_name}...") self._tokenizer = T5Tokenizer.from_pretrained(self.model_name) print("✓ T5 tokenizer loaded") return self._tokenizer @property def is_loaded(self) -> bool: return self._loaded def unload(self): """Free VRAM by unloading the encoder.""" if self._encoder is not None: del self._encoder self._encoder = None self._loaded = False torch.cuda.empty_cache() print("🗑️ T5 encoder unloaded") class LazyLyraModel: """Lazy loader for Lyra VAE - only downloads/loads when first accessed. Exposes config with modality_seq_lens for proper tokenization lengths. 
""" def __init__( self, repo_id: str = LYRA_ILLUSTRIOUS_REPO, device: str = "cuda", checkpoint: Optional[str] = None ): self.repo_id = repo_id self.device = device self.checkpoint = checkpoint self._model = None self._info = None self._config = None self._loaded = False # Pre-fetch config without loading model (lightweight) self._prefetch_config() def _prefetch_config(self): """Fetch config.json to get sequence lengths without loading the full model.""" try: config_path = hf_hub_download( repo_id=self.repo_id, filename="config.json", repo_type="model" ) with open(config_path, 'r') as f: self._config = json.load(f) # Detect version from config is_v2 = 'modality_seq_lens' in self._config or 'binding_config' in self._config version = "v2" if is_v2 else "v1" print(f"📋 Lyra config prefetched: {self.repo_id} ({version})") if is_v2: print(f" Sequence lengths: {self._config.get('modality_seq_lens', {})}") else: print(f" Sequence length: {self._config.get('seq_len', 77)}") except Exception as e: print(f"⚠️ Could not prefetch Lyra config: {e}") # Detect version from repo name and use appropriate defaults is_illustrious = 'illustrious' in self.repo_id.lower() or 'xl' in self.repo_id.lower() if is_illustrious: # v2 defaults for SDXL/Illustrious self._config = { "modality_dims": { "clip_l": 768, "clip_g": 1280, "t5_xl_l": 2048, "t5_xl_g": 2048 }, "modality_seq_lens": { "clip_l": 77, "clip_g": 77, "t5_xl_l": 512, "t5_xl_g": 512 }, "fusion_strategy": "adaptive_cantor", "latent_dim": 2048 } else: # v1 defaults for SD1.5 self._config = { "modality_dims": { "clip": 768, "t5": 768 }, "seq_len": 77, "fusion_strategy": "cantor", "latent_dim": 768 } @property def config(self) -> Dict: """Get model config (available before full model load).""" return self._config or {} @property def modality_seq_lens(self) -> Dict[str, int]: """Get sequence lengths for each modality. Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats. """ # v2 format: modality_seq_lens dict if 'modality_seq_lens' in self.config: return self.config['modality_seq_lens'] # v1 format: derive from single seq_len seq_len = self.config.get('seq_len', 77) modality_dims = self.config.get('modality_dims', {}) # Return seq_len for all modalities in v1 return {name: seq_len for name in modality_dims.keys()} @property def t5_max_length(self) -> int: """Get T5 max sequence length from config. Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats. """ # v2 format: modality_seq_lens dict if 'modality_seq_lens' in self.config: seq_lens = self.config['modality_seq_lens'] return seq_lens.get('t5_xl_l', seq_lens.get('t5_xl_g', 512)) # v1 format: single seq_len return self.config.get('seq_len', 77) @property def clip_max_length(self) -> int: """Get CLIP max sequence length from config. Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats. 
""" # v2 format: modality_seq_lens dict if 'modality_seq_lens' in self.config: seq_lens = self.config['modality_seq_lens'] return seq_lens.get('clip_l', 77) # v1 format: single seq_len (same for all modalities) return self.config.get('seq_len', 77) @property def model(self): if self._model is None: print(f"📥 Lazy loading Lyra VAE: {self.repo_id}...") if LYRA_LOADER_AVAILABLE: # Use the loader module self._model, self._info = load_vae_lyra( self.repo_id, checkpoint=self.checkpoint, device=self.device, return_info=True ) # Update config from loaded info if self._info and 'config' in self._info: self._config = self._info['config'] else: # Fallback to manual loading self._model = self._load_fallback() self._info = {"repo_id": self.repo_id, "version": "v2", "config": self._config} self._model.eval() self._loaded = True print(f"✓ Lyra VAE loaded") return self._model @property def info(self) -> Optional[Dict]: if self._info is None: return {"repo_id": self.repo_id, "config": self._config} return self._info @property def is_loaded(self) -> bool: return self._loaded def _load_fallback(self): """Fallback loading if loader module not available.""" if not LYRA_V2_AVAILABLE: raise ImportError("Lyra VAE v2 not available") # Config already prefetched config_dict = self._config # Use provided checkpoint or find one if self.checkpoint and self.checkpoint.strip(): checkpoint_file = self.checkpoint.strip() # Add weights/ prefix if not present and file doesn't exist at root if not checkpoint_file.startswith('weights/'): checkpoint_file = f"weights/{checkpoint_file}" print(f"[Lyra] Using specified checkpoint: {checkpoint_file}") else: # Find checkpoint automatically from huggingface_hub import list_repo_files repo_files = list_repo_files(self.repo_id, repo_type="model") checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')] # Prefer weights/ folder weights_files = [f for f in checkpoint_files if f.startswith('weights/')] if weights_files: checkpoint_file = sorted(weights_files)[-1] # Latest elif checkpoint_files: checkpoint_file = checkpoint_files[0] else: raise FileNotFoundError(f"No checkpoint found in {self.repo_id}") print(f"[Lyra] Auto-selected checkpoint: {checkpoint_file}") checkpoint_path = hf_hub_download( repo_id=self.repo_id, filename=checkpoint_file, repo_type="model" ) # Load weights if checkpoint_file.endswith('.safetensors'): state_dict = load_safetensors(checkpoint_path, device="cpu") else: checkpoint = torch.load(checkpoint_path, map_location="cpu") state_dict = checkpoint.get('model_state_dict', checkpoint) # Build config with all fields from prefetched config vae_config = LyraV2Config( modality_dims=config_dict.get('modality_dims'), modality_seq_lens=config_dict.get('modality_seq_lens'), binding_config=config_dict.get('binding_config'), latent_dim=config_dict.get('latent_dim', 2048), hidden_dim=config_dict.get('hidden_dim', 2048), fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'), encoder_layers=config_dict.get('encoder_layers', 3), decoder_layers=config_dict.get('decoder_layers', 3), fusion_heads=config_dict.get('fusion_heads', 8), cantor_depth=config_dict.get('cantor_depth', 8), cantor_local_window=config_dict.get('cantor_local_window', 3), alpha_init=config_dict.get('alpha_init', 1.0), beta_init=config_dict.get('beta_init', 0.3), ) model = LyraV2(vae_config) model.load_state_dict(state_dict, strict=False) model.to(self.device) return model def unload(self): """Free VRAM by unloading the model.""" if self._model is not None: del 
self._model self._model = None self._info = None self._loaded = False torch.cuda.empty_cache() print("🗑️ Lyra VAE unloaded") # ============================================================================ # SCHEDULER FACTORY # ============================================================================ def get_scheduler( scheduler_name: str, config_source: str = "stabilityai/stable-diffusion-xl-base-1.0", is_sdxl: bool = True ): """Create scheduler by name. Args: scheduler_name: One of SCHEDULER_CHOICES config_source: HF repo to load scheduler config from is_sdxl: Whether this is for SDXL (reserved for future defaults; currently unused) Returns: Configured scheduler instance """ subfolder = "scheduler" if scheduler_name == SCHEDULER_EULER_A: return EulerAncestralDiscreteScheduler.from_pretrained( config_source, subfolder=subfolder ) elif scheduler_name == SCHEDULER_EULER: return EulerDiscreteScheduler.from_pretrained( config_source, subfolder=subfolder ) elif scheduler_name == SCHEDULER_DPM_2M_SDE: # DPM++ 2M SDE - multistep solver with SDE noise injection, good for detailed images return DPMSolverMultistepScheduler.from_pretrained( config_source, subfolder=subfolder, algorithm_type="sde-dpmsolver++", solver_order=2, use_karras_sigmas=True, ) elif scheduler_name == SCHEDULER_DPM_2M: # DPM++ 2M - fast with good quality return DPMSolverMultistepScheduler.from_pretrained( config_source, subfolder=subfolder, algorithm_type="dpmsolver++", solver_order=2, use_karras_sigmas=True, ) else: print(f"⚠️ Unknown scheduler '{scheduler_name}', defaulting to Euler Ancestral") return EulerAncestralDiscreteScheduler.from_pretrained( config_source, subfolder=subfolder ) # ============================================================================ # UTILITIES # ============================================================================ def get_clip_hidden_state( model_output, clip_skip: int = 1, output_hidden_states: bool = True ) -> torch.Tensor: """Extract hidden state with clip_skip support.""" if clip_skip == 1 or not output_hidden_states: return model_output.last_hidden_state if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None: return model_output.hidden_states[-clip_skip] return model_output.last_hidden_state # ============================================================================ # SDXL PIPELINE # ============================================================================ class SDXLFlowMatchingPipeline: """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders. Uses lazy loading for T5 and Lyra - they're only downloaded when actually used. 
""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler, device: str = "cuda", t5_loader: Optional[LazyT5Encoder] = None, lyra_loader: Optional[LazyLyraModel] = None, clip_skip: int = 1 ): self.vae = vae self.text_encoder = text_encoder self.text_encoder_2 = text_encoder_2 self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.unet = unet self.scheduler = scheduler self.device = device # Lazy loaders for Lyra components self.t5_loader = t5_loader self.lyra_loader = lyra_loader # Settings self.clip_skip = clip_skip self.vae_scale_factor = 0.13025 self.arch = ARCH_SDXL # Track current scheduler name for UI self._scheduler_name = SCHEDULER_EULER_A def set_scheduler(self, scheduler_name: str): """Switch scheduler without reloading model.""" if scheduler_name != self._scheduler_name: self.scheduler = get_scheduler( scheduler_name, config_source="stabilityai/stable-diffusion-xl-base-1.0", is_sdxl=True ) self._scheduler_name = scheduler_name print(f"✓ Scheduler changed to: {scheduler_name}") @property def t5_encoder(self) -> Optional[T5EncoderModel]: """Access T5 encoder (triggers lazy load if needed).""" return self.t5_loader.encoder if self.t5_loader else None @property def t5_tokenizer(self) -> Optional[T5Tokenizer]: """Access T5 tokenizer (triggers lazy load if needed).""" return self.t5_loader.tokenizer if self.t5_loader else None @property def lyra_model(self): """Access Lyra model (triggers lazy load if needed).""" return self.lyra_loader.model if self.lyra_loader else None @property def lyra_available(self) -> bool: """Check if Lyra components are configured (not necessarily loaded).""" return self.t5_loader is not None and self.lyra_loader is not None def encode_prompt( self, prompt: str, negative_prompt: str = "", clip_skip: int = 1 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Encode prompts using dual CLIP encoders for SDXL.""" # CLIP-L encoding text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.device) with torch.no_grad(): output_hidden_states = clip_skip > 1 clip_l_output = self.text_encoder( text_input_ids, output_hidden_states=output_hidden_states ) prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states) # CLIP-G encoding text_inputs_2 = self.tokenizer_2( prompt, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids_2 = text_inputs_2.input_ids.to(self.device) with torch.no_grad(): clip_g_output = self.text_encoder_2( text_input_ids_2, output_hidden_states=output_hidden_states ) prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states) pooled_prompt_embeds = clip_g_output.text_embeds # Concatenate CLIP-L and CLIP-G embeddings prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1) # Negative prompt if negative_prompt: uncond_inputs = self.tokenizer( negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_inputs.input_ids.to(self.device) uncond_inputs_2 = self.tokenizer_2( negative_prompt, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids_2 = 
uncond_inputs_2.input_ids.to(self.device) with torch.no_grad(): uncond_output_l = self.text_encoder( uncond_input_ids, output_hidden_states=output_hidden_states ) negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states) uncond_output_g = self.text_encoder_2( uncond_input_ids_2, output_hidden_states=output_hidden_states ) negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states) negative_pooled = uncond_output_g.text_embeds negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1) else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled = torch.zeros_like(pooled_prompt_embeds) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled def encode_prompt_lyra( self, prompt: str, negative_prompt: str = "", clip_skip: int = 1, t5_summary: str = "", lyra_strength: float = 0.3, use_separator: bool = True, clip_include_summary: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Encode prompts using Lyra VAE v2 fusion (CLIP + T5). CLIP encoders receive tags only (prompt field). T5 encoder receives tags + separator + summary. Args: prompt: Tags/keywords for CLIP encoding negative_prompt: Negative tags clip_skip: CLIP skip layers t5_summary: Natural language summary for T5 lyra_strength: Blend factor (0=pure CLIP, 1=pure Lyra) use_separator: If True, use ¶ separator between tags and summary clip_include_summary: If True, append summary to CLIP input (default False) This triggers lazy loading of T5 and Lyra if not already loaded. Uses sequence lengths from Lyra config for proper tokenization. """ if not self.lyra_available: raise ValueError("Lyra VAE components not configured") # Get sequence lengths from Lyra config (available before full load) t5_max_length = self.lyra_loader.t5_max_length # 512 for Illustrious clip_max_length = self.lyra_loader.clip_max_length # 77 for Illustrious print(f"[Lyra] Using sequence lengths: CLIP={clip_max_length}, T5={t5_max_length}") # Access properties triggers lazy load t5_encoder = self.t5_encoder t5_tokenizer = self.t5_tokenizer lyra_model = self.lyra_model # === CLIP ENCODING === # CLIP sees tags only (unless clip_include_summary is True) if clip_include_summary and t5_summary.strip(): clip_prompt = f"{prompt} {t5_summary}" else: clip_prompt = prompt # Get CLIP embeddings with tags only prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt( clip_prompt, negative_prompt, clip_skip ) # === T5 ENCODING === # T5 sees tags + separator + summary (or tags + summary if no separator) SUMMARY_SEPARATOR = "¶" if t5_summary.strip(): if use_separator: t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}" else: t5_prompt = f"{prompt} {t5_summary}" else: # No summary provided - T5 just sees the tags t5_prompt = prompt print(f"[Lyra] CLIP input: {clip_prompt[:80]}...") print(f"[Lyra] T5 input: {t5_prompt[:80]}...") # Get T5 embeddings with config-specified max_length t5_inputs = t5_tokenizer( t5_prompt, max_length=t5_max_length, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds = t5_encoder(**t5_inputs).last_hidden_state # === LYRA FUSION === clip_l_dim = 768 clip_g_dim = 1280 clip_l_embeds = prompt_embeds[..., :clip_l_dim] clip_g_embeds = prompt_embeds[..., clip_l_dim:] with torch.no_grad(): modality_inputs = { 'clip_l': clip_l_embeds.float(), 'clip_g': clip_g_embeds.float(), 't5_xl_l': t5_embeds.float(), 't5_xl_g': 
t5_embeds.float() } reconstructions, mu, logvar, _ = lyra_model( modality_inputs, target_modalities=['clip_l', 'clip_g'] ) lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype) lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype) # Normalize reconstructions to match input statistics clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8) clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8) if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5: lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8) lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean() if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5: lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8) lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean() # Blend original CLIP with Lyra reconstruction fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1) # === NEGATIVE PROMPT === # Negative uses same logic: CLIP sees negative tags only if negative_prompt: neg_strength = lyra_strength * 0.5 # Less aggressive for negative # T5 negative: tags only (no summary for negative) t5_neg_prompt = negative_prompt t5_inputs_neg = t5_tokenizer( t5_neg_prompt, max_length=t5_max_length, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds_neg = t5_encoder(**t5_inputs_neg).last_hidden_state neg_clip_l = negative_prompt_embeds[..., :clip_l_dim] neg_clip_g = negative_prompt_embeds[..., clip_l_dim:] modality_inputs_neg = { 'clip_l': neg_clip_l.float(), 'clip_g': neg_clip_g.float(), 't5_xl_l': t5_embeds_neg.float(), 't5_xl_g': t5_embeds_neg.float() } recon_neg, _, _, _ = lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g']) lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype) lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype) # Normalize neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8) neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8) if neg_l_ratio > 2.0 or neg_l_ratio < 0.5: lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8) lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean() if neg_g_ratio > 2.0 or neg_g_ratio < 0.5: lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8) lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean() fused_neg_l = (1 - neg_strength) * neg_clip_l + neg_strength * lyra_neg_l fused_neg_g = (1 - neg_strength) * neg_clip_g + neg_strength * lyra_neg_g negative_prompt_embeds_fused = torch.cat([fused_neg_l, fused_neg_g], dim=-1) else: negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused) return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled def _get_add_time_ids( self, original_size: Tuple[int, int], crops_coords_top_left: Tuple[int, int], target_size: Tuple[int, int], dtype: torch.dtype ) -> torch.Tensor: """Create time embedding IDs for SDXL.""" add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device) return add_time_ids @torch.no_grad() def __call__( self, prompt: str, negative_prompt: str = "", height: int = 1024, width: int = 1024, num_inference_steps: int = 20, guidance_scale: float = 7.5, shift: float = 0.0, 
use_flow_matching: bool = False, prediction_type: str = "epsilon", seed: Optional[int] = None, use_lyra: bool = False, clip_skip: int = 1, t5_summary: str = "", lyra_strength: float = 1.0, use_separator: bool = True, clip_include_summary: bool = False, progress_callback=None ): """Generate image using SDXL architecture. Args: prompt: Tags/keywords for image generation negative_prompt: Negative tags t5_summary: Natural language summary (T5 only, unless clip_include_summary=True) use_separator: Use ¶ separator between tags and summary in T5 input clip_include_summary: If True, append summary to CLIP input (default False) seed: Random seed for reproducibility (must be int) """ # Create generator with seed for deterministic generation if seed is not None: seed = int(seed) generator = torch.Generator(device=self.device).manual_seed(seed) print(f"[SDXL Pipeline] Using seed: {seed}") else: generator = None print("[SDXL Pipeline] No seed provided, using random") # Encode prompts (Lyra triggers lazy load only if use_lyra=True) if use_lyra and self.lyra_available: prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra( prompt, negative_prompt, clip_skip, t5_summary, lyra_strength, use_separator=use_separator, clip_include_summary=clip_include_summary ) else: prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt( prompt, negative_prompt, clip_skip ) # Prepare latents latent_channels = 4 latent_height = height // 8 latent_width = width // 8 latents = torch.randn( (1, latent_channels, latent_height, latent_width), generator=generator, device=self.device, dtype=torch.float16 ) # Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.scheduler.timesteps if not use_flow_matching: latents = latents * self.scheduler.init_noise_sigma # Prepare added time embeddings for SDXL original_size = (height, width) target_size = (height, width) crops_coords_top_left = (0, 0) add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=torch.float16 ) negative_add_time_ids = add_time_ids # Denoising loop for i, t in enumerate(timesteps): if progress_callback: progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents if use_flow_matching and shift > 0: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) scaling = torch.sqrt(1 + sigma_shifted ** 2) latent_model_input = latent_model_input / scaling else: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) timestep = t.expand(latent_model_input.shape[0]) if guidance_scale > 1.0: text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) add_text_embeds = torch.cat([negative_pooled, pooled]) add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids]) else: text_embeds = prompt_embeds add_text_embeds = pooled add_time_ids_input = add_time_ids added_cond_kwargs = { "text_embeds": add_text_embeds, "time_ids": add_time_ids_input } noise_pred = self.unet( latent_model_input, timestep, encoder_hidden_states=text_embeds, added_cond_kwargs=added_cond_kwargs, return_dict=False )[0] if guidance_scale > 1.0: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if use_flow_matching: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) if 
prediction_type == "v_prediction": v_pred = noise_pred alpha_t = torch.sqrt(1 - sigma_shifted ** 2) sigma_t = sigma_shifted noise_pred = alpha_t * v_pred + sigma_t * latents dt = -1.0 / num_inference_steps latents = latents + dt * noise_pred else: # Pass generator for deterministic ancestral/SDE sampling latents = self.scheduler.step( noise_pred, t, latents, generator=generator, return_dict=False )[0] # Decode latents = latents / self.vae_scale_factor with torch.no_grad(): image = self.vae.decode(latents.to(self.vae.dtype)).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = (image * 255).round().astype("uint8") image = Image.fromarray(image[0]) return image # ============================================================================ # SD1.5 PIPELINE # ============================================================================ class SD15FlowMatchingPipeline: """Pipeline for SD1.5-based flow-matching inference.""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler, device: str = "cuda", t5_loader: Optional[LazyT5Encoder] = None, lyra_loader: Optional[LazyLyraModel] = None, ): self.vae = vae self.text_encoder = text_encoder self.tokenizer = tokenizer self.unet = unet self.scheduler = scheduler self.device = device self.t5_loader = t5_loader self.lyra_loader = lyra_loader self.vae_scale_factor = 0.18215 self.arch = ARCH_SD15 self.is_lune_model = False @property def t5_encoder(self): return self.t5_loader.encoder if self.t5_loader else None @property def t5_tokenizer(self): return self.t5_loader.tokenizer if self.t5_loader else None @property def lyra_model(self): return self.lyra_loader.model if self.lyra_loader else None @property def lyra_available(self) -> bool: return self.t5_loader is not None and self.lyra_loader is not None def encode_prompt(self, prompt: str, negative_prompt: str = ""): """Encode text prompts to embeddings.""" text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.device) with torch.no_grad(): prompt_embeds = self.text_encoder(text_input_ids)[0] if negative_prompt: uncond_inputs = self.tokenizer( negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_inputs.input_ids.to(self.device) with torch.no_grad(): negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0] else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) return prompt_embeds, negative_prompt_embeds def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""): """Encode using Lyra VAE v1 (CLIP + T5 fusion). Uses sequence lengths from Lyra config for proper tokenization. 
""" if not self.lyra_available: raise ValueError("Lyra VAE components not configured") # Get sequence length from config (v1 uses same length for clip and t5) # Default to 77 for SD1.5/v1 t5_max_length = self.lyra_loader.config.get('seq_len', 77) print(f"[Lyra v1] Using sequence length: {t5_max_length}") t5_encoder = self.t5_encoder t5_tokenizer = self.t5_tokenizer lyra_model = self.lyra_model # CLIP text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.device) with torch.no_grad(): clip_embeds = self.text_encoder(text_input_ids)[0] # T5 with config-specified max_length t5_inputs = t5_tokenizer( prompt, max_length=t5_max_length, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds = t5_encoder(**t5_inputs).last_hidden_state # Fuse modality_inputs = {'clip': clip_embeds, 't5': t5_embeds} with torch.no_grad(): reconstructions, mu, logvar = lyra_model( modality_inputs, target_modalities=['clip'] ) prompt_embeds = reconstructions['clip'] # Negative if negative_prompt: uncond_inputs = self.tokenizer( negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_inputs.input_ids.to(self.device) with torch.no_grad(): clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0] t5_inputs_uncond = t5_tokenizer( negative_prompt, max_length=t5_max_length, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds_uncond = t5_encoder(**t5_inputs_uncond).last_hidden_state modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond} with torch.no_grad(): reconstructions_uncond, _, _ = lyra_model( modality_inputs_uncond, target_modalities=['clip'] ) negative_prompt_embeds = reconstructions_uncond['clip'] else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) return prompt_embeds, negative_prompt_embeds @torch.no_grad() def __call__( self, prompt: str, negative_prompt: str = "", height: int = 512, width: int = 512, num_inference_steps: int = 20, guidance_scale: float = 7.5, shift: float = 2.5, use_flow_matching: bool = True, prediction_type: str = "epsilon", seed: Optional[int] = None, use_lyra: bool = False, clip_skip: int = 1, t5_summary: str = "", lyra_strength: float = 1.0, progress_callback=None ): """Generate image.""" # Create generator with seed for deterministic generation if seed is not None: seed = int(seed) generator = torch.Generator(device=self.device).manual_seed(seed) print(f"[SD1.5 Pipeline] Using seed: {seed}") else: generator = None print("[SD1.5 Pipeline] No seed provided, using random") if use_lyra and self.lyra_available: prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt) else: prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt) latent_channels = 4 latent_height = height // 8 latent_width = width // 8 latents = torch.randn( (1, latent_channels, latent_height, latent_width), generator=generator, device=self.device, dtype=torch.float32 ) self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.scheduler.timesteps if not use_flow_matching: latents = latents * self.scheduler.init_noise_sigma for i, t in enumerate(timesteps): if progress_callback: progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") latent_model_input 
= torch.cat([latents] * 2) if guidance_scale > 1.0 else latents if use_flow_matching and shift > 0: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) scaling = torch.sqrt(1 + sigma_shifted ** 2) latent_model_input = latent_model_input / scaling else: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) timestep = t.expand(latent_model_input.shape[0]) text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds noise_pred = self.unet( latent_model_input, timestep, encoder_hidden_states=text_embeds, return_dict=False )[0] if guidance_scale > 1.0: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if use_flow_matching: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) if prediction_type == "v_prediction": v_pred = noise_pred alpha_t = torch.sqrt(1 - sigma_shifted ** 2) sigma_t = sigma_shifted noise_pred = alpha_t * v_pred + sigma_t * latents dt = -1.0 / num_inference_steps latents = latents + dt * noise_pred else: # Pass generator for deterministic ancestral/SDE sampling latents = self.scheduler.step( noise_pred, t, latents, generator=generator, return_dict=False )[0] latents = latents / self.vae_scale_factor if self.is_lune_model: latents = latents * 5.52 with torch.no_grad(): image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = (image * 255).round().astype("uint8") image = Image.fromarray(image[0]) return image # ============================================================================ # MODEL LOADERS # ============================================================================ def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"): """Load Lune checkpoint from .pt file.""" print(f"📥 Downloading: {repo_id}/{filename}") checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") checkpoint = torch.load(checkpoint_path, map_location="cpu") print(f"🏗️ Initializing SD1.5 UNet...") unet = UNet2DConditionModel.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32 ) student_state_dict = checkpoint["student"] cleaned_dict = {} for key, value in student_state_dict.items(): if key.startswith("unet."): cleaned_dict[key[5:]] = value else: cleaned_dict[key] = value unet.load_state_dict(cleaned_dict, strict=False) step = checkpoint.get("gstep", "unknown") print(f"✅ Loaded Lune from step {step}") return unet.to(device) def load_illustrious_xl( repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious", filename: str = "", device: str = "cuda" ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]: """Load Illustrious XL from single safetensors file.""" from diffusers import StableDiffusionXLPipeline # Default checkpoint if none specified if not filename or not filename.strip(): filename = "illustriousXL_v01.safetensors" print(f"📥 Loading Illustrious XL: {repo_id}/{filename}") checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") print(f"✓ Downloaded: {checkpoint_path}") print("📦 Loading with StableDiffusionXLPipeline.from_single_file()...") pipe = StableDiffusionXLPipeline.from_single_file( checkpoint_path, torch_dtype=torch.float16, use_safetensors=True, ) unet = 
pipe.unet.to(device) vae = pipe.vae.to(device) text_encoder = pipe.text_encoder.to(device) text_encoder_2 = pipe.text_encoder_2.to(device) tokenizer = pipe.tokenizer tokenizer_2 = pipe.tokenizer_2 del pipe torch.cuda.empty_cache() print("✅ Illustrious XL loaded!") print(f" UNet params: {sum(p.numel() for p in unet.parameters()):,}") print(f" VAE params: {sum(p.numel() for p in vae.parameters()):,}") return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 def load_sdxl_base(device: str = "cuda"): """Load standard SDXL base model.""" print("📥 Loading SDXL Base 1.0...") unet = UNet2DConditionModel.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16 ).to(device) # Use fp16-fix VAE to avoid NaN issues with SDXL's original VAE in fp16 print(" Using madebyollin/sdxl-vae-fp16-fix for stable fp16 decoding...") vae = AutoencoderKL.from_pretrained( "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 ).to(device) text_encoder = CLIPTextModel.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", torch_dtype=torch.float16 ).to(device) text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", torch_dtype=torch.float16 ).to(device) tokenizer = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer" ) tokenizer_2 = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2" ) print("✅ SDXL Base loaded!") return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 # ============================================================================ # PIPELINE INITIALIZATION # ============================================================================ def initialize_pipeline(model_choice: str, device: str = "cuda", checkpoint: str = "", lyra_checkpoint: str = ""): """Initialize the complete pipeline based on model choice. Uses lazy loading for T5 and Lyra - they won't be downloaded until first use. Args: model_choice: Model selection from dropdown device: Target device checkpoint: Optional custom checkpoint filename (e.g., "my_model.safetensors") lyra_checkpoint: Optional custom Lyra VAE checkpoint filename """ print(f"🚀 Initializing {model_choice} pipeline...") if checkpoint: print(f" Custom model checkpoint: {checkpoint}") if lyra_checkpoint: print(f" Custom Lyra checkpoint: {lyra_checkpoint}") is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice is_lune = "Lune" in model_choice if is_sdxl: # SDXL-based models if "Illustrious" in model_choice: unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl( device=device, filename=checkpoint ) else: unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device) # Create LAZY loaders for T5 and Lyra (no download yet!) 
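# Note (illustrative): constructing these loaders is cheap. The multi-GB
# weights are only fetched the first time .encoder / .model is accessed, e.g.
#     t5_loader = LazyT5Encoder(model_name=T5_XL_MODEL)   # no download yet
#     _ = t5_loader.encoder                               # first access triggers download + load
#     t5_loader.unload()                                  # frees VRAM again
# The same pattern applies to LazyLyraModel via its .model property.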
print("📋 Configuring lazy loaders for T5-XL and Lyra VAE (will download on first use)") t5_loader = LazyT5Encoder( model_name=T5_XL_MODEL, # google/flan-t5-xl device=device, dtype=torch.float16 ) lyra_loader = LazyLyraModel( repo_id=LYRA_ILLUSTRIOUS_REPO, device=device, checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None ) # Default scheduler: Euler Ancestral scheduler = get_scheduler(SCHEDULER_EULER_A, is_sdxl=True) pipeline = SDXLFlowMatchingPipeline( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, device=device, t5_loader=t5_loader, lyra_loader=lyra_loader, clip_skip=1 ) else: # SD1.5-based models vae = AutoencoderKL.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float32 ).to(device) text_encoder = CLIPTextModel.from_pretrained( "openai/clip-vit-large-patch14", torch_dtype=torch.float32 ).to(device) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") # Lazy loaders for SD1.5 Lyra (T5-base) print("📋 Configuring lazy loaders for T5-base and Lyra VAE v1 (will download on first use)") t5_loader = LazyT5Encoder( model_name=T5_BASE_MODEL, # google/flan-t5-base device=device, dtype=torch.float32 ) lyra_loader = LazyLyraModel( repo_id=LYRA_SD15_REPO, device=device, checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None ) # Load UNet if is_lune: repo_id = "AbstractPhil/sd15-flow-lune" # Use custom checkpoint or default if checkpoint and checkpoint.strip(): filename = checkpoint else: filename = "sd15_flow_lune_e34_s34000.pt" unet = load_lune_checkpoint(repo_id, filename, device) else: unet = UNet2DConditionModel.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32 ).to(device) scheduler = EulerDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler" ) pipeline = SD15FlowMatchingPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, device=device, t5_loader=t5_loader, lyra_loader=lyra_loader, ) pipeline.is_lune_model = is_lune print("✅ Pipeline initialized! 
(T5 and Lyra will load on first use)") return pipeline # ============================================================================ # GLOBAL STATE # ============================================================================ CURRENT_PIPELINE = None CURRENT_MODEL = None CURRENT_CHECKPOINT = None CURRENT_LYRA_CHECKPOINT = None def get_pipeline(model_choice: str, checkpoint: str = "", lyra_checkpoint: str = ""): """Get or create pipeline for selected model.""" global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_CHECKPOINT, CURRENT_LYRA_CHECKPOINT # Normalize empty values checkpoint = checkpoint.strip() if checkpoint else "" lyra_checkpoint = lyra_checkpoint.strip() if lyra_checkpoint else "" # Reinitialize if model or any checkpoint changed if (CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice or CURRENT_CHECKPOINT != checkpoint or CURRENT_LYRA_CHECKPOINT != lyra_checkpoint): CURRENT_PIPELINE = initialize_pipeline( model_choice, device="cuda", checkpoint=checkpoint, lyra_checkpoint=lyra_checkpoint ) CURRENT_MODEL = model_choice CURRENT_CHECKPOINT = checkpoint CURRENT_LYRA_CHECKPOINT = lyra_checkpoint return CURRENT_PIPELINE # ============================================================================ # INFERENCE # ============================================================================ def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int: """Estimate GPU duration.""" base_time_per_step = 0.5 if is_sdxl else 0.3 resolution_factor = (width * height) / (512 * 512) estimated = num_steps * base_time_per_step * resolution_factor if use_lyra: estimated *= 2 estimated += 10 # Extra time for lazy loading on first use return int(estimated + 20) @spaces.GPU(duration=lambda *args: estimate_duration( args[8], args[10], args[11], args[14], "SDXL" in args[3] or "Illustrious" in args[3] )) def generate_image( prompt: str, t5_summary: str, negative_prompt: str, model_choice: str, checkpoint: str, lyra_checkpoint: str, scheduler_choice: str, clip_skip: int, num_steps: int, cfg_scale: float, width: int, height: int, shift: float, use_flow_matching: bool, use_lyra: bool, lyra_strength: float, use_separator: bool, clip_include_summary: bool, seed: int, randomize_seed: bool, progress=gr.Progress() ): """Generate image with ZeroGPU support. 
Args: prompt: Tags/keywords (CLIP input) t5_summary: Natural language summary (T5 input, unless clip_include_summary) checkpoint: Custom model checkpoint filename (empty for default) lyra_checkpoint: Custom Lyra VAE checkpoint filename (empty for default) use_separator: Use ¶ separator between tags and summary clip_include_summary: If True, CLIP also sees the summary """ # Ensure seed is an integer (Gradio sliders return floats) seed = int(seed) if randomize_seed: seed = np.random.randint(0, 2**32 - 1) print(f"🎲 Using seed: {seed} (randomize={randomize_seed})") def progress_callback(step, total, desc): progress((step + 1) / total, desc=desc) try: pipeline = get_pipeline(model_choice, checkpoint, lyra_checkpoint) # Update scheduler if needed (SDXL only) is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice if is_sdxl and hasattr(pipeline, 'set_scheduler'): pipeline.set_scheduler(scheduler_choice) prediction_type = "epsilon" if not is_sdxl and "Lune" in model_choice: prediction_type = "v_prediction" if not use_lyra or not pipeline.lyra_available: progress(0.05, desc="Generating...") image = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, shift=shift, use_flow_matching=use_flow_matching, prediction_type=prediction_type, seed=seed, use_lyra=False, clip_skip=clip_skip, progress_callback=progress_callback ) progress(1.0, desc="Complete!") return image, None, seed else: # Side-by-side comparison: SAME seed for both! progress(0.05, desc="Generating standard...") image_standard = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, shift=shift, use_flow_matching=use_flow_matching, prediction_type=prediction_type, seed=seed, # Same seed use_lyra=False, clip_skip=clip_skip, progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d) ) progress(0.5, desc="Generating Lyra fusion (loading T5 + Lyra if needed)...") image_lyra = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, shift=shift, use_flow_matching=use_flow_matching, prediction_type=prediction_type, seed=seed, # Same seed for deterministic comparison use_lyra=True, clip_skip=clip_skip, t5_summary=t5_summary, lyra_strength=lyra_strength, use_separator=use_separator, clip_include_summary=clip_include_summary, progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d) ) progress(1.0, desc="Complete!") return image_standard, image_lyra, seed except Exception as e: print(f"❌ Generation failed: {e}") import traceback traceback.print_exc() raise e # ============================================================================ # GRADIO UI # ============================================================================ def create_demo(): """Create Gradio interface.""" with gr.Blocks() as demo: gr.Markdown(""" # 🌙 Lyra/Lune Flow-Matching Image Generation **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil) Generate images using SD1.5 and SDXL-based models with geometric deep learning: | Model | Architecture | Lyra Version | Best For | |-------|-------------|--------------|----------| | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail | | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose | | **Flow-Lune** | SD1.5 | v1 (T5-base) | Fast flow matching (15-25 steps) | | 
**SD1.5 Base** | SD1.5 | v1 (T5-base) | Baseline comparison | **Lazy Loading**: T5 and Lyra VAE are only downloaded when you enable Lyra fusion! """) with gr.Row(): with gr.Column(scale=1): prompt = gr.TextArea( label="Prompt (Tags for CLIP)", value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background", lines=3, info="CLIP encoders see these tags. T5 also sees these + the summary below." ) t5_summary = gr.TextArea( label="T5 Summary (Natural Language - T5 Only)", value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky", lines=2, info="T5 sees: tags ¶ summary. CLIP sees: tags only (unless 'Include Summary in CLIP' is enabled)." ) negative_prompt = gr.TextArea( label="Negative Prompt", value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality", lines=2 ) model_choice = gr.Dropdown( label="Model", choices=[ "Illustrious XL", "SDXL Base", "Flow-Lune (SD1.5)", "SD1.5 Base" ], value="Illustrious XL" ) with gr.Accordion("Advanced Options", open=False): checkpoint = gr.Textbox( label="Model Checkpoint (optional)", value="", placeholder="e.g., illustriousXL_v01.safetensors", info="Leave empty for default. Illustrious: .safetensors, Lune: .pt" ) lyra_checkpoint = gr.Textbox( label="Lyra VAE Checkpoint (optional)", value="weights/lyra_illustrious_step_12000.safetensors", placeholder="e.g., lyra_e100_s50000.safetensors", info="Leave empty for latest. Loaded from weights/ folder in Lyra repo." ) scheduler_choice = gr.Dropdown( label="Scheduler (SDXL only)", choices=SCHEDULER_CHOICES, value=SCHEDULER_EULER_A, info="Euler Ancestral recommended for Illustrious" ) clip_skip = gr.Slider( label="CLIP Skip", minimum=1, maximum=4, value=2, step=1, info="2 recommended for Illustrious, 1 for others" ) use_lyra = gr.Checkbox( label="Enable Lyra VAE (CLIP+T5 Fusion)", value=True, # DEFAULT: ON info="Enables lazy loading of T5 and Lyra on first use" ) lyra_strength = gr.Slider( label="Lyra Blend Strength", minimum=0.0, maximum=3.0, value=1.0, step=0.05, info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction" ) with gr.Accordion("Lyra Advanced Settings", open=False): use_separator = gr.Checkbox( label="Use ¶ Separator", value=True, info="Insert ¶ between tags and summary in T5 input" ) clip_include_summary = gr.Checkbox( label="Include Summary in CLIP", value=False, info="By default CLIP sees tags only. Enable to append summary to CLIP input." 
) with gr.Accordion("Generation Settings", open=True): num_steps = gr.Slider( label="Steps", minimum=1, maximum=50, value=25, step=1 ) cfg_scale = gr.Slider( label="CFG Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.5 ) with gr.Row(): width = gr.Slider( label="Width", minimum=512, maximum=1536, value=1024, step=64 ) height = gr.Slider( label="Height", minimum=512, maximum=1536, value=1024, step=64 ) seed = gr.Slider( label="Seed", minimum=0, maximum=2**32 - 1, value=42, # DEFAULT: 42 step=1 ) randomize_seed = gr.Checkbox( label="Randomize Seed", value=False # DEFAULT: OFF for reproducibility ) with gr.Accordion("Advanced (Flow Matching)", open=False): use_flow_matching = gr.Checkbox( label="Enable Flow Matching", value=False, info="Use flow matching ODE (for Lune only)" ) shift = gr.Slider( label="Shift", minimum=0.0, maximum=5.0, value=0.0, step=0.1, info="Flow matching shift (0=disabled)" ) generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg") with gr.Column(scale=1): with gr.Row(): output_image_standard = gr.Image( label="Standard", type="pil" ) output_image_lyra = gr.Image( label="Lyra Fusion 🎵", type="pil", visible=True # Visible by default since Lyra is on ) output_seed = gr.Number(label="Seed Used", precision=0) gr.Markdown(""" ### Tips - **Lazy Loading**: T5-XL (~3GB) and Lyra VAE only download when you enable Lyra - **Illustrious XL**: Use CLIP skip 2, Euler Ancestral scheduler - **Schedulers**: DPM++ 2M SDE for detail, Euler A for speed - **Lyra v2**: Uses `google/flan-t5-xl` for richer semantics - **Same Seed**: Both Standard and Lyra use the same seed for fair comparison """) # Event handlers def on_model_change(model_name): """Update defaults based on model.""" if "Illustrious" in model_name: return { clip_skip: gr.update(value=2), width: gr.update(value=1024), height: gr.update(value=1024), num_steps: gr.update(value=25), use_flow_matching: gr.update(value=False), shift: gr.update(value=0.0), scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A) } elif "SDXL" in model_name: return { clip_skip: gr.update(value=1), width: gr.update(value=1024), height: gr.update(value=1024), num_steps: gr.update(value=30), use_flow_matching: gr.update(value=False), shift: gr.update(value=0.0), scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A) } elif "Lune" in model_name: return { clip_skip: gr.update(value=1), width: gr.update(value=512), height: gr.update(value=512), num_steps: gr.update(value=20), use_flow_matching: gr.update(value=True), shift: gr.update(value=2.5), scheduler_choice: gr.update(visible=False) } else: # SD1.5 Base return { clip_skip: gr.update(value=1), width: gr.update(value=512), height: gr.update(value=512), num_steps: gr.update(value=30), use_flow_matching: gr.update(value=False), shift: gr.update(value=0.0), scheduler_choice: gr.update(visible=False) } def on_lyra_toggle(enabled): """Show/hide Lyra comparison.""" if enabled: return { output_image_standard: gr.update(visible=True, label="Standard"), output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵") } else: return { output_image_standard: gr.update(visible=True, label="Generated Image"), output_image_lyra: gr.update(visible=False) } model_choice.change( fn=on_model_change, inputs=[model_choice], outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift, scheduler_choice] ) use_lyra.change( fn=on_lyra_toggle, inputs=[use_lyra], outputs=[output_image_standard, output_image_lyra] ) generate_btn.click( fn=generate_image, inputs=[ prompt, t5_summary, 
negative_prompt, model_choice, checkpoint, lyra_checkpoint, scheduler_choice, clip_skip, num_steps, cfg_scale, width, height, shift, use_flow_matching, use_lyra, lyra_strength, use_separator, clip_include_summary, seed, randomize_seed ], outputs=[output_image_standard, output_image_lyra, output_seed] ) return demo # ============================================================================ # LAUNCH # ============================================================================ if __name__ == "__main__": demo = create_demo() demo.queue(max_size=20) demo.launch()
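# ----------------------------------------------------------------------------
# Programmatic usage sketch (illustrative, not executed by the Space): assumes
# a CUDA device plus network access for the Hugging Face downloads. Argument
# values are examples only.
#
#     pipe = initialize_pipeline("Illustrious XL", device="cuda")
#     pipe.set_scheduler(SCHEDULER_DPM_2M)
#     image = pipe(
#         prompt="masterpiece, 1girl, blue hair, school uniform",
#         negative_prompt="lowres, bad anatomy",
#         height=1024, width=1024,
#         num_inference_steps=25, guidance_scale=7.0,
#         seed=42,
#         use_lyra=True,             # triggers lazy T5-XL + Lyra download on first call
#         t5_summary="A girl with blue hair wearing a school uniform",
#         lyra_strength=1.0,
#     )
#     image.save("illustrious_lyra.png")
# ----------------------------------------------------------------------------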