Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Lyra/Lune Flow-Matching Inference Space | |
| Author: AbstractPhil | |
| License: MIT | |
| SD1.5 and SDXL-based flow matching with geometric crystalline architectures. | |
| Supports Illustrious XL, standard SDXL, and SD1.5 variants. | |
| Lyra VAE Versions: | |
| - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra | |
| - v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2 | |
| Features: | |
| - Lazy loading: T5 and Lyra only download when first used | |
| - Multiple schedulers: Euler Ancestral, Euler, DPM++ 2M SDE, DPM++ 2M | |
| - Integrated loader module for automatic version detection | |
| """ | |
| import os | |
| import json | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Optional, Dict, Tuple, Union | |
| import spaces | |
| from safetensors.torch import load_file as load_safetensors | |
| from diffusers import ( | |
| UNet2DConditionModel, | |
| AutoencoderKL, | |
| EulerDiscreteScheduler, | |
| EulerAncestralDiscreteScheduler, | |
| DPMSolverMultistepScheduler, | |
| DPMSolverSDEScheduler, | |
| ) | |
| from transformers import ( | |
| CLIPTextModel, | |
| CLIPTokenizer, | |
| CLIPTextModelWithProjection, | |
| T5EncoderModel, | |
| T5Tokenizer | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| # Import Lyra VAE v1 (SD1.5) from geofractal | |
| try: | |
| from geofractal.model.vae.vae_lyra import MultiModalVAE as LyraV1, MultiModalVAEConfig as LyraV1Config | |
| LYRA_V1_AVAILABLE = True | |
| except ImportError: | |
| print("⚠️ Lyra VAE v1 not available") | |
| LYRA_V1_AVAILABLE = False | |
| # Import Lyra VAE v2 (SDXL/Illustrious) from geofractal | |
| try: | |
| from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as LyraV2, MultiModalVAEConfig as LyraV2Config | |
| LYRA_V2_AVAILABLE = True | |
| except ImportError: | |
| print("⚠️ Lyra VAE v2 not available") | |
| LYRA_V2_AVAILABLE = False | |
| # Import Lyra loader module | |
| try: | |
| from geofractal.model.vae.loader import load_vae_lyra, load_lyra_illustrious | |
| LYRA_LOADER_AVAILABLE = True | |
| except ImportError: | |
| print("⚠️ Lyra loader module not available, using fallback") | |
| LYRA_LOADER_AVAILABLE = False | |
| # ============================================================================ | |
| # CONSTANTS | |
| # ============================================================================ | |
| ARCH_SD15 = "sd15" | |
| ARCH_SDXL = "sdxl" | |
| # Scheduler names | |
| SCHEDULER_EULER_A = "Euler Ancestral" | |
| SCHEDULER_EULER = "Euler" | |
| SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE" | |
| SCHEDULER_DPM_2M = "DPM++ 2M" | |
| SCHEDULER_CHOICES = [ | |
| SCHEDULER_EULER_A, | |
| SCHEDULER_EULER, | |
| SCHEDULER_DPM_2M_SDE, | |
| SCHEDULER_DPM_2M, | |
| ] | |
| # ComfyUI key prefixes for SDXL single-file checkpoints | |
| COMFYUI_UNET_PREFIX = "model.diffusion_model." | |
| COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer." | |
| COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model." | |
| COMFYUI_VAE_PREFIX = "first_stage_model." | |
| # Lyra repos | |
| LYRA_ILLUSTRIOUS_REPO = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious" | |
| LYRA_SD15_REPO = "AbstractPhil/vae-lyra" | |
| # T5 model - use flan-t5-xl (what Lyra was trained on) | |
| T5_XL_MODEL = "google/flan-t5-xl" | |
| T5_BASE_MODEL = "google/flan-t5-base" | |
| # ============================================================================ | |
| # LAZY LOADERS | |
| # ============================================================================ | |
| class LazyT5Encoder: | |
| """Lazy loader for T5 encoder - only downloads/loads when first accessed.""" | |
| def __init__(self, model_name: str = T5_XL_MODEL, device: str = "cuda", dtype=torch.float16): | |
| self.model_name = model_name | |
| self.device = device | |
| self.dtype = dtype | |
| self._encoder = None | |
| self._tokenizer = None | |
| self._loaded = False | |
| def encoder(self) -> T5EncoderModel: | |
| if self._encoder is None: | |
| print(f"📥 Lazy loading T5 encoder: {self.model_name}...") | |
| self._encoder = T5EncoderModel.from_pretrained( | |
| self.model_name, | |
| torch_dtype=self.dtype | |
| ).to(self.device) | |
| self._encoder.eval() | |
| print(f"✓ T5 encoder loaded ({sum(p.numel() for p in self._encoder.parameters())/1e6:.1f}M params)") | |
| self._loaded = True | |
| return self._encoder | |
| def tokenizer(self) -> T5Tokenizer: | |
| if self._tokenizer is None: | |
| print(f"📥 Loading T5 tokenizer: {self.model_name}...") | |
| self._tokenizer = T5Tokenizer.from_pretrained(self.model_name) | |
| print("✓ T5 tokenizer loaded") | |
| return self._tokenizer | |
| def is_loaded(self) -> bool: | |
| return self._loaded | |
| def unload(self): | |
| """Free VRAM by unloading the encoder.""" | |
| if self._encoder is not None: | |
| del self._encoder | |
| self._encoder = None | |
| self._loaded = False | |
| torch.cuda.empty_cache() | |
| print("🗑️ T5 encoder unloaded") | |
| class LazyLyraModel: | |
| """Lazy loader for Lyra VAE - only downloads/loads when first accessed. | |
| Exposes config with modality_seq_lens for proper tokenization lengths. | |
| """ | |
| def __init__( | |
| self, | |
| repo_id: str = LYRA_ILLUSTRIOUS_REPO, | |
| device: str = "cuda", | |
| checkpoint: Optional[str] = None | |
| ): | |
| self.repo_id = repo_id | |
| self.device = device | |
| self.checkpoint = checkpoint | |
| self._model = None | |
| self._info = None | |
| self._config = None | |
| self._loaded = False | |
| # Pre-fetch config without loading model (lightweight) | |
| self._prefetch_config() | |
| def _prefetch_config(self): | |
| """Fetch config.json to get sequence lengths without loading the full model.""" | |
| try: | |
| config_path = hf_hub_download( | |
| repo_id=self.repo_id, | |
| filename="config.json", | |
| repo_type="model" | |
| ) | |
| with open(config_path, 'r') as f: | |
| self._config = json.load(f) | |
| # Detect version from config | |
| is_v2 = 'modality_seq_lens' in self._config or 'binding_config' in self._config | |
| version = "v2" if is_v2 else "v1" | |
| print(f"📋 Lyra config prefetched: {self.repo_id} ({version})") | |
| if is_v2: | |
| print(f" Sequence lengths: {self._config.get('modality_seq_lens', {})}") | |
| else: | |
| print(f" Sequence length: {self._config.get('seq_len', 77)}") | |
| except Exception as e: | |
| print(f"⚠️ Could not prefetch Lyra config: {e}") | |
| # Detect version from repo name and use appropriate defaults | |
| is_illustrious = 'illustrious' in self.repo_id.lower() or 'xl' in self.repo_id.lower() | |
| if is_illustrious: | |
| # v2 defaults for SDXL/Illustrious | |
| self._config = { | |
| "modality_dims": { | |
| "clip_l": 768, | |
| "clip_g": 1280, | |
| "t5_xl_l": 2048, | |
| "t5_xl_g": 2048 | |
| }, | |
| "modality_seq_lens": { | |
| "clip_l": 77, | |
| "clip_g": 77, | |
| "t5_xl_l": 512, | |
| "t5_xl_g": 512 | |
| }, | |
| "fusion_strategy": "adaptive_cantor", | |
| "latent_dim": 2048 | |
| } | |
| else: | |
| # v1 defaults for SD1.5 | |
| self._config = { | |
| "modality_dims": { | |
| "clip": 768, | |
| "t5": 768 | |
| }, | |
| "seq_len": 77, | |
| "fusion_strategy": "cantor", | |
| "latent_dim": 768 | |
| } | |
| def config(self) -> Dict: | |
| """Get model config (available before full model load).""" | |
| return self._config or {} | |
| def modality_seq_lens(self) -> Dict[str, int]: | |
| """Get sequence lengths for each modality. | |
| Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats. | |
| """ | |
| # v2 format: modality_seq_lens dict | |
| if 'modality_seq_lens' in self.config: | |
| return self.config['modality_seq_lens'] | |
| # v1 format: derive from single seq_len | |
| seq_len = self.config.get('seq_len', 77) | |
| modality_dims = self.config.get('modality_dims', {}) | |
| # Return seq_len for all modalities in v1 | |
| return {name: seq_len for name in modality_dims.keys()} | |
| def t5_max_length(self) -> int: | |
| """Get T5 max sequence length from config. | |
| Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats. | |
| """ | |
| # v2 format: modality_seq_lens dict | |
| if 'modality_seq_lens' in self.config: | |
| seq_lens = self.config['modality_seq_lens'] | |
| return seq_lens.get('t5_xl_l', seq_lens.get('t5_xl_g', 512)) | |
| # v1 format: single seq_len | |
| return self.config.get('seq_len', 77) | |
| def clip_max_length(self) -> int: | |
| """Get CLIP max sequence length from config. | |
| Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats. | |
| """ | |
| # v2 format: modality_seq_lens dict | |
| if 'modality_seq_lens' in self.config: | |
| seq_lens = self.config['modality_seq_lens'] | |
| return seq_lens.get('clip_l', 77) | |
| # v1 format: single seq_len (same for all modalities) | |
| return self.config.get('seq_len', 77) | |
| def model(self): | |
| if self._model is None: | |
| print(f"📥 Lazy loading Lyra VAE: {self.repo_id}...") | |
| if LYRA_LOADER_AVAILABLE: | |
| # Use the loader module | |
| self._model, self._info = load_vae_lyra( | |
| self.repo_id, | |
| checkpoint=self.checkpoint, | |
| device=self.device, | |
| return_info=True | |
| ) | |
| # Update config from loaded info | |
| if self._info and 'config' in self._info: | |
| self._config = self._info['config'] | |
| else: | |
| # Fallback to manual loading | |
| self._model = self._load_fallback() | |
| self._info = {"repo_id": self.repo_id, "version": "v2", "config": self._config} | |
| self._model.eval() | |
| self._loaded = True | |
| print(f"✓ Lyra VAE loaded") | |
| return self._model | |
| def info(self) -> Optional[Dict]: | |
| if self._info is None: | |
| return {"repo_id": self.repo_id, "config": self._config} | |
| return self._info | |
| def is_loaded(self) -> bool: | |
| return self._loaded | |
| def _load_fallback(self): | |
| """Fallback loading if loader module not available.""" | |
| if not LYRA_V2_AVAILABLE: | |
| raise ImportError("Lyra VAE v2 not available") | |
| # Config already prefetched | |
| config_dict = self._config | |
| # Use provided checkpoint or find one | |
| if self.checkpoint and self.checkpoint.strip(): | |
| checkpoint_file = self.checkpoint.strip() | |
| # Add weights/ prefix if not present and file doesn't exist at root | |
| if not checkpoint_file.startswith('weights/'): | |
| checkpoint_file = f"weights/{checkpoint_file}" | |
| print(f"[Lyra] Using specified checkpoint: {checkpoint_file}") | |
| else: | |
| # Find checkpoint automatically | |
| from huggingface_hub import list_repo_files | |
| repo_files = list_repo_files(self.repo_id, repo_type="model") | |
| checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')] | |
| # Prefer weights/ folder | |
| weights_files = [f for f in checkpoint_files if f.startswith('weights/')] | |
| if weights_files: | |
| checkpoint_file = sorted(weights_files)[-1] # Latest | |
| elif checkpoint_files: | |
| checkpoint_file = checkpoint_files[0] | |
| else: | |
| raise FileNotFoundError(f"No checkpoint found in {self.repo_id}") | |
| print(f"[Lyra] Auto-selected checkpoint: {checkpoint_file}") | |
| checkpoint_path = hf_hub_download( | |
| repo_id=self.repo_id, | |
| filename=checkpoint_file, | |
| repo_type="model" | |
| ) | |
| # Load weights | |
| if checkpoint_file.endswith('.safetensors'): | |
| state_dict = load_safetensors(checkpoint_path, device="cpu") | |
| else: | |
| checkpoint = torch.load(checkpoint_path, map_location="cpu") | |
| state_dict = checkpoint.get('model_state_dict', checkpoint) | |
| # Build config with all fields from prefetched config | |
| vae_config = LyraV2Config( | |
| modality_dims=config_dict.get('modality_dims'), | |
| modality_seq_lens=config_dict.get('modality_seq_lens'), | |
| binding_config=config_dict.get('binding_config'), | |
| latent_dim=config_dict.get('latent_dim', 2048), | |
| hidden_dim=config_dict.get('hidden_dim', 2048), | |
| fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'), | |
| encoder_layers=config_dict.get('encoder_layers', 3), | |
| decoder_layers=config_dict.get('decoder_layers', 3), | |
| fusion_heads=config_dict.get('fusion_heads', 8), | |
| cantor_depth=config_dict.get('cantor_depth', 8), | |
| cantor_local_window=config_dict.get('cantor_local_window', 3), | |
| alpha_init=config_dict.get('alpha_init', 1.0), | |
| beta_init=config_dict.get('beta_init', 0.3), | |
| ) | |
| model = LyraV2(vae_config) | |
| model.load_state_dict(state_dict, strict=False) | |
| model.to(self.device) | |
| return model | |
| def unload(self): | |
| """Free VRAM by unloading the model.""" | |
| if self._model is not None: | |
| del self._model | |
| self._model = None | |
| self._info = None | |
| self._loaded = False | |
| torch.cuda.empty_cache() | |
| print("🗑️ Lyra VAE unloaded") | |
| # ============================================================================ | |
| # SCHEDULER FACTORY | |
| # ============================================================================ | |
| def get_scheduler( | |
| scheduler_name: str, | |
| config_source: str = "stabilityai/stable-diffusion-xl-base-1.0", | |
| is_sdxl: bool = True | |
| ): | |
| """Create scheduler by name. | |
| Args: | |
| scheduler_name: One of SCHEDULER_CHOICES | |
| config_source: HF repo to load scheduler config from | |
| is_sdxl: Whether this is for SDXL (affects some defaults) | |
| Returns: | |
| Configured scheduler instance | |
| """ | |
| subfolder = "scheduler" | |
| if scheduler_name == SCHEDULER_EULER_A: | |
| return EulerAncestralDiscreteScheduler.from_pretrained( | |
| config_source, | |
| subfolder=subfolder | |
| ) | |
| elif scheduler_name == SCHEDULER_EULER: | |
| return EulerDiscreteScheduler.from_pretrained( | |
| config_source, | |
| subfolder=subfolder | |
| ) | |
| elif scheduler_name == SCHEDULER_DPM_2M_SDE: | |
| # DPM++ 2M SDE - good for detailed images | |
| return DPMSolverSDEScheduler.from_pretrained( | |
| config_source, | |
| subfolder=subfolder, | |
| algorithm_type="sde-dpmsolver++", | |
| solver_order=2, | |
| use_karras_sigmas=True, | |
| ) | |
| elif scheduler_name == SCHEDULER_DPM_2M: | |
| # DPM++ 2M - fast and quality | |
| return DPMSolverMultistepScheduler.from_pretrained( | |
| config_source, | |
| subfolder=subfolder, | |
| algorithm_type="dpmsolver++", | |
| solver_order=2, | |
| use_karras_sigmas=True, | |
| ) | |
| else: | |
| print(f"⚠️ Unknown scheduler '{scheduler_name}', defaulting to Euler Ancestral") | |
| return EulerAncestralDiscreteScheduler.from_pretrained( | |
| config_source, | |
| subfolder=subfolder | |
| ) | |
| # ============================================================================ | |
| # UTILITIES | |
| # ============================================================================ | |
| def get_clip_hidden_state( | |
| model_output, | |
| clip_skip: int = 1, | |
| output_hidden_states: bool = True | |
| ) -> torch.Tensor: | |
| """Extract hidden state with clip_skip support.""" | |
| if clip_skip == 1 or not output_hidden_states: | |
| return model_output.last_hidden_state | |
| if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None: | |
| return model_output.hidden_states[-clip_skip] | |
| return model_output.last_hidden_state | |
| # ============================================================================ | |
| # SDXL PIPELINE | |
| # ============================================================================ | |
| class SDXLFlowMatchingPipeline: | |
| """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders. | |
| Uses lazy loading for T5 and Lyra - they're only downloaded when actually used. | |
| """ | |
| def __init__( | |
| self, | |
| vae: AutoencoderKL, | |
| text_encoder: CLIPTextModel, | |
| text_encoder_2: CLIPTextModelWithProjection, | |
| tokenizer: CLIPTokenizer, | |
| tokenizer_2: CLIPTokenizer, | |
| unet: UNet2DConditionModel, | |
| scheduler, | |
| device: str = "cuda", | |
| t5_loader: Optional[LazyT5Encoder] = None, | |
| lyra_loader: Optional[LazyLyraModel] = None, | |
| clip_skip: int = 1 | |
| ): | |
| self.vae = vae | |
| self.text_encoder = text_encoder | |
| self.text_encoder_2 = text_encoder_2 | |
| self.tokenizer = tokenizer | |
| self.tokenizer_2 = tokenizer_2 | |
| self.unet = unet | |
| self.scheduler = scheduler | |
| self.device = device | |
| # Lazy loaders for Lyra components | |
| self.t5_loader = t5_loader | |
| self.lyra_loader = lyra_loader | |
| # Settings | |
| self.clip_skip = clip_skip | |
| self.vae_scale_factor = 0.13025 | |
| self.arch = ARCH_SDXL | |
| # Track current scheduler name for UI | |
| self._scheduler_name = SCHEDULER_EULER_A | |
| def set_scheduler(self, scheduler_name: str): | |
| """Switch scheduler without reloading model.""" | |
| if scheduler_name != self._scheduler_name: | |
| self.scheduler = get_scheduler( | |
| scheduler_name, | |
| config_source="stabilityai/stable-diffusion-xl-base-1.0", | |
| is_sdxl=True | |
| ) | |
| self._scheduler_name = scheduler_name | |
| print(f"✓ Scheduler changed to: {scheduler_name}") | |
| def t5_encoder(self) -> Optional[T5EncoderModel]: | |
| """Access T5 encoder (triggers lazy load if needed).""" | |
| return self.t5_loader.encoder if self.t5_loader else None | |
| def t5_tokenizer(self) -> Optional[T5Tokenizer]: | |
| """Access T5 tokenizer (triggers lazy load if needed).""" | |
| return self.t5_loader.tokenizer if self.t5_loader else None | |
| def lyra_model(self): | |
| """Access Lyra model (triggers lazy load if needed).""" | |
| return self.lyra_loader.model if self.lyra_loader else None | |
| def lyra_available(self) -> bool: | |
| """Check if Lyra components are configured (not necessarily loaded).""" | |
| return self.t5_loader is not None and self.lyra_loader is not None | |
| def encode_prompt( | |
| self, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| clip_skip: int = 1 | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Encode prompts using dual CLIP encoders for SDXL.""" | |
| # CLIP-L encoding | |
| text_inputs = self.tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| output_hidden_states = clip_skip > 1 | |
| clip_l_output = self.text_encoder( | |
| text_input_ids, | |
| output_hidden_states=output_hidden_states | |
| ) | |
| prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states) | |
| # CLIP-G encoding | |
| text_inputs_2 = self.tokenizer_2( | |
| prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer_2.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids_2 = text_inputs_2.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| clip_g_output = self.text_encoder_2( | |
| text_input_ids_2, | |
| output_hidden_states=output_hidden_states | |
| ) | |
| prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states) | |
| pooled_prompt_embeds = clip_g_output.text_embeds | |
| # Concatenate CLIP-L and CLIP-G embeddings | |
| prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1) | |
| # Negative prompt | |
| if negative_prompt: | |
| uncond_inputs = self.tokenizer( | |
| negative_prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| uncond_input_ids = uncond_inputs.input_ids.to(self.device) | |
| uncond_inputs_2 = self.tokenizer_2( | |
| negative_prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer_2.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| uncond_output_l = self.text_encoder( | |
| uncond_input_ids, | |
| output_hidden_states=output_hidden_states | |
| ) | |
| negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states) | |
| uncond_output_g = self.text_encoder_2( | |
| uncond_input_ids_2, | |
| output_hidden_states=output_hidden_states | |
| ) | |
| negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states) | |
| negative_pooled = uncond_output_g.text_embeds | |
| negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1) | |
| else: | |
| negative_prompt_embeds = torch.zeros_like(prompt_embeds) | |
| negative_pooled = torch.zeros_like(pooled_prompt_embeds) | |
| return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled | |
| def encode_prompt_lyra( | |
| self, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| clip_skip: int = 1, | |
| t5_summary: str = "", | |
| lyra_strength: float = 0.3, | |
| use_separator: bool = True, | |
| clip_include_summary: bool = False | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Encode prompts using Lyra VAE v2 fusion (CLIP + T5). | |
| CLIP encoders receive tags only (prompt field). | |
| T5 encoder receives tags + separator + summary. | |
| Args: | |
| prompt: Tags/keywords for CLIP encoding | |
| negative_prompt: Negative tags | |
| clip_skip: CLIP skip layers | |
| t5_summary: Natural language summary for T5 | |
| lyra_strength: Blend factor (0=pure CLIP, 1=pure Lyra) | |
| use_separator: If True, use ¶ separator between tags and summary | |
| clip_include_summary: If True, append summary to CLIP input (default False) | |
| This triggers lazy loading of T5 and Lyra if not already loaded. | |
| Uses sequence lengths from Lyra config for proper tokenization. | |
| """ | |
| if not self.lyra_available: | |
| raise ValueError("Lyra VAE components not configured") | |
| # Get sequence lengths from Lyra config (available before full load) | |
| t5_max_length = self.lyra_loader.t5_max_length # 512 for Illustrious | |
| clip_max_length = self.lyra_loader.clip_max_length # 77 for Illustrious | |
| print(f"[Lyra] Using sequence lengths: CLIP={clip_max_length}, T5={t5_max_length}") | |
| # Access properties triggers lazy load | |
| t5_encoder = self.t5_encoder | |
| t5_tokenizer = self.t5_tokenizer | |
| lyra_model = self.lyra_model | |
| # === CLIP ENCODING === | |
| # CLIP sees tags only (unless clip_include_summary is True) | |
| if clip_include_summary and t5_summary.strip(): | |
| clip_prompt = f"{prompt} {t5_summary}" | |
| else: | |
| clip_prompt = prompt | |
| # Get CLIP embeddings with tags only | |
| prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt( | |
| clip_prompt, negative_prompt, clip_skip | |
| ) | |
| # === T5 ENCODING === | |
| # T5 sees tags + separator + summary (or tags + summary if no separator) | |
| SUMMARY_SEPARATOR = "¶" | |
| if t5_summary.strip(): | |
| if use_separator: | |
| t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}" | |
| else: | |
| t5_prompt = f"{prompt} {t5_summary}" | |
| else: | |
| # No summary provided - T5 just sees the tags | |
| t5_prompt = prompt | |
| print(f"[Lyra] CLIP input: {clip_prompt[:80]}...") | |
| print(f"[Lyra] T5 input: {t5_prompt[:80]}...") | |
| # Get T5 embeddings with config-specified max_length | |
| t5_inputs = t5_tokenizer( | |
| t5_prompt, | |
| max_length=t5_max_length, | |
| padding='max_length', | |
| truncation=True, | |
| return_tensors='pt' | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| t5_embeds = t5_encoder(**t5_inputs).last_hidden_state | |
| # === LYRA FUSION === | |
| clip_l_dim = 768 | |
| clip_g_dim = 1280 | |
| clip_l_embeds = prompt_embeds[..., :clip_l_dim] | |
| clip_g_embeds = prompt_embeds[..., clip_l_dim:] | |
| with torch.no_grad(): | |
| modality_inputs = { | |
| 'clip_l': clip_l_embeds.float(), | |
| 'clip_g': clip_g_embeds.float(), | |
| 't5_xl_l': t5_embeds.float(), | |
| 't5_xl_g': t5_embeds.float() | |
| } | |
| reconstructions, mu, logvar, _ = lyra_model( | |
| modality_inputs, | |
| target_modalities=['clip_l', 'clip_g'] | |
| ) | |
| lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype) | |
| lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype) | |
| # Normalize reconstructions to match input statistics | |
| clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8) | |
| clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8) | |
| if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5: | |
| lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8) | |
| lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean() | |
| if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5: | |
| lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8) | |
| lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean() | |
| # Blend original CLIP with Lyra reconstruction | |
| fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l | |
| fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g | |
| prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1) | |
| # === NEGATIVE PROMPT === | |
| # Negative uses same logic: CLIP sees negative tags only | |
| if negative_prompt: | |
| neg_strength = lyra_strength * 0.5 # Less aggressive for negative | |
| # T5 negative: tags only (no summary for negative) | |
| t5_neg_prompt = negative_prompt | |
| t5_inputs_neg = t5_tokenizer( | |
| t5_neg_prompt, | |
| max_length=t5_max_length, | |
| padding='max_length', | |
| truncation=True, | |
| return_tensors='pt' | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| t5_embeds_neg = t5_encoder(**t5_inputs_neg).last_hidden_state | |
| neg_clip_l = negative_prompt_embeds[..., :clip_l_dim] | |
| neg_clip_g = negative_prompt_embeds[..., clip_l_dim:] | |
| modality_inputs_neg = { | |
| 'clip_l': neg_clip_l.float(), | |
| 'clip_g': neg_clip_g.float(), | |
| 't5_xl_l': t5_embeds_neg.float(), | |
| 't5_xl_g': t5_embeds_neg.float() | |
| } | |
| recon_neg, _, _, _ = lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g']) | |
| lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype) | |
| lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype) | |
| # Normalize | |
| neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8) | |
| neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8) | |
| if neg_l_ratio > 2.0 or neg_l_ratio < 0.5: | |
| lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8) | |
| lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean() | |
| if neg_g_ratio > 2.0 or neg_g_ratio < 0.5: | |
| lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8) | |
| lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean() | |
| fused_neg_l = (1 - neg_strength) * neg_clip_l + neg_strength * lyra_neg_l | |
| fused_neg_g = (1 - neg_strength) * neg_clip_g + neg_strength * lyra_neg_g | |
| negative_prompt_embeds_fused = torch.cat([fused_neg_l, fused_neg_g], dim=-1) | |
| else: | |
| negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused) | |
| return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled | |
| def _get_add_time_ids( | |
| self, | |
| original_size: Tuple[int, int], | |
| crops_coords_top_left: Tuple[int, int], | |
| target_size: Tuple[int, int], | |
| dtype: torch.dtype | |
| ) -> torch.Tensor: | |
| """Create time embedding IDs for SDXL.""" | |
| add_time_ids = list(original_size + crops_coords_top_left + target_size) | |
| add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device) | |
| return add_time_ids | |
| def __call__( | |
| self, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| height: int = 1024, | |
| width: int = 1024, | |
| num_inference_steps: int = 20, | |
| guidance_scale: float = 7.5, | |
| shift: float = 0.0, | |
| use_flow_matching: bool = False, | |
| prediction_type: str = "epsilon", | |
| seed: Optional[int] = None, | |
| use_lyra: bool = False, | |
| clip_skip: int = 1, | |
| t5_summary: str = "", | |
| lyra_strength: float = 1.0, | |
| use_separator: bool = True, | |
| clip_include_summary: bool = False, | |
| progress_callback=None | |
| ): | |
| """Generate image using SDXL architecture. | |
| Args: | |
| prompt: Tags/keywords for image generation | |
| negative_prompt: Negative tags | |
| t5_summary: Natural language summary (T5 only, unless clip_include_summary=True) | |
| use_separator: Use ¶ separator between tags and summary in T5 input | |
| clip_include_summary: If True, append summary to CLIP input (default False) | |
| seed: Random seed for reproducibility (must be int) | |
| """ | |
| # Create generator with seed for deterministic generation | |
| if seed is not None: | |
| seed = int(seed) | |
| generator = torch.Generator(device=self.device).manual_seed(seed) | |
| print(f"[SDXL Pipeline] Using seed: {seed}") | |
| else: | |
| generator = None | |
| print("[SDXL Pipeline] No seed provided, using random") | |
| # Encode prompts (Lyra triggers lazy load only if use_lyra=True) | |
| if use_lyra and self.lyra_available: | |
| prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra( | |
| prompt, negative_prompt, clip_skip, t5_summary, lyra_strength, | |
| use_separator=use_separator, | |
| clip_include_summary=clip_include_summary | |
| ) | |
| else: | |
| prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt( | |
| prompt, negative_prompt, clip_skip | |
| ) | |
| # Prepare latents | |
| latent_channels = 4 | |
| latent_height = height // 8 | |
| latent_width = width // 8 | |
| latents = torch.randn( | |
| (1, latent_channels, latent_height, latent_width), | |
| generator=generator, | |
| device=self.device, | |
| dtype=torch.float16 | |
| ) | |
| # Set timesteps | |
| self.scheduler.set_timesteps(num_inference_steps, device=self.device) | |
| timesteps = self.scheduler.timesteps | |
| if not use_flow_matching: | |
| latents = latents * self.scheduler.init_noise_sigma | |
| # Prepare added time embeddings for SDXL | |
| original_size = (height, width) | |
| target_size = (height, width) | |
| crops_coords_top_left = (0, 0) | |
| add_time_ids = self._get_add_time_ids( | |
| original_size, crops_coords_top_left, target_size, dtype=torch.float16 | |
| ) | |
| negative_add_time_ids = add_time_ids | |
| # Denoising loop | |
| for i, t in enumerate(timesteps): | |
| if progress_callback: | |
| progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") | |
| latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents | |
| if use_flow_matching and shift > 0: | |
| sigma = t.float() / 1000.0 | |
| sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) | |
| scaling = torch.sqrt(1 + sigma_shifted ** 2) | |
| latent_model_input = latent_model_input / scaling | |
| else: | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| timestep = t.expand(latent_model_input.shape[0]) | |
| if guidance_scale > 1.0: | |
| text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) | |
| add_text_embeds = torch.cat([negative_pooled, pooled]) | |
| add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids]) | |
| else: | |
| text_embeds = prompt_embeds | |
| add_text_embeds = pooled | |
| add_time_ids_input = add_time_ids | |
| added_cond_kwargs = { | |
| "text_embeds": add_text_embeds, | |
| "time_ids": add_time_ids_input | |
| } | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| timestep, | |
| encoder_hidden_states=text_embeds, | |
| added_cond_kwargs=added_cond_kwargs, | |
| return_dict=False | |
| )[0] | |
| if guidance_scale > 1.0: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| if use_flow_matching: | |
| sigma = t.float() / 1000.0 | |
| sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) | |
| if prediction_type == "v_prediction": | |
| v_pred = noise_pred | |
| alpha_t = torch.sqrt(1 - sigma_shifted ** 2) | |
| sigma_t = sigma_shifted | |
| noise_pred = alpha_t * v_pred + sigma_t * latents | |
| dt = -1.0 / num_inference_steps | |
| latents = latents + dt * noise_pred | |
| else: | |
| # Pass generator for deterministic ancestral/SDE sampling | |
| latents = self.scheduler.step( | |
| noise_pred, t, latents, generator=generator, return_dict=False | |
| )[0] | |
| # Decode | |
| latents = latents / self.vae_scale_factor | |
| with torch.no_grad(): | |
| image = self.vae.decode(latents.to(self.vae.dtype)).sample | |
| image = (image / 2 + 0.5).clamp(0, 1) | |
| image = image.cpu().permute(0, 2, 3, 1).float().numpy() | |
| image = (image * 255).round().astype("uint8") | |
| image = Image.fromarray(image[0]) | |
| return image | |
| # ============================================================================ | |
| # SD1.5 PIPELINE | |
| # ============================================================================ | |
| class SD15FlowMatchingPipeline: | |
| """Pipeline for SD1.5-based flow-matching inference.""" | |
| def __init__( | |
| self, | |
| vae: AutoencoderKL, | |
| text_encoder: CLIPTextModel, | |
| tokenizer: CLIPTokenizer, | |
| unet: UNet2DConditionModel, | |
| scheduler, | |
| device: str = "cuda", | |
| t5_loader: Optional[LazyT5Encoder] = None, | |
| lyra_loader: Optional[LazyLyraModel] = None, | |
| ): | |
| self.vae = vae | |
| self.text_encoder = text_encoder | |
| self.tokenizer = tokenizer | |
| self.unet = unet | |
| self.scheduler = scheduler | |
| self.device = device | |
| self.t5_loader = t5_loader | |
| self.lyra_loader = lyra_loader | |
| self.vae_scale_factor = 0.18215 | |
| self.arch = ARCH_SD15 | |
| self.is_lune_model = False | |
| def t5_encoder(self): | |
| return self.t5_loader.encoder if self.t5_loader else None | |
| def t5_tokenizer(self): | |
| return self.t5_loader.tokenizer if self.t5_loader else None | |
| def lyra_model(self): | |
| return self.lyra_loader.model if self.lyra_loader else None | |
| def lyra_available(self) -> bool: | |
| return self.t5_loader is not None and self.lyra_loader is not None | |
| def encode_prompt(self, prompt: str, negative_prompt: str = ""): | |
| """Encode text prompts to embeddings.""" | |
| text_inputs = self.tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| prompt_embeds = self.text_encoder(text_input_ids)[0] | |
| if negative_prompt: | |
| uncond_inputs = self.tokenizer( | |
| negative_prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| uncond_input_ids = uncond_inputs.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0] | |
| else: | |
| negative_prompt_embeds = torch.zeros_like(prompt_embeds) | |
| return prompt_embeds, negative_prompt_embeds | |
| def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""): | |
| """Encode using Lyra VAE v1 (CLIP + T5 fusion). | |
| Uses sequence lengths from Lyra config for proper tokenization. | |
| """ | |
| if not self.lyra_available: | |
| raise ValueError("Lyra VAE components not configured") | |
| # Get sequence length from config (v1 uses same length for clip and t5) | |
| # Default to 77 for SD1.5/v1 | |
| t5_max_length = self.lyra_loader.config.get('seq_len', 77) | |
| print(f"[Lyra v1] Using sequence length: {t5_max_length}") | |
| t5_encoder = self.t5_encoder | |
| t5_tokenizer = self.t5_tokenizer | |
| lyra_model = self.lyra_model | |
| # CLIP | |
| text_inputs = self.tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| clip_embeds = self.text_encoder(text_input_ids)[0] | |
| # T5 with config-specified max_length | |
| t5_inputs = t5_tokenizer( | |
| prompt, | |
| max_length=t5_max_length, | |
| padding='max_length', | |
| truncation=True, | |
| return_tensors='pt' | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| t5_embeds = t5_encoder(**t5_inputs).last_hidden_state | |
| # Fuse | |
| modality_inputs = {'clip': clip_embeds, 't5': t5_embeds} | |
| with torch.no_grad(): | |
| reconstructions, mu, logvar = lyra_model( | |
| modality_inputs, | |
| target_modalities=['clip'] | |
| ) | |
| prompt_embeds = reconstructions['clip'] | |
| # Negative | |
| if negative_prompt: | |
| uncond_inputs = self.tokenizer( | |
| negative_prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| uncond_input_ids = uncond_inputs.input_ids.to(self.device) | |
| with torch.no_grad(): | |
| clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0] | |
| t5_inputs_uncond = t5_tokenizer( | |
| negative_prompt, | |
| max_length=t5_max_length, | |
| padding='max_length', | |
| truncation=True, | |
| return_tensors='pt' | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| t5_embeds_uncond = t5_encoder(**t5_inputs_uncond).last_hidden_state | |
| modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond} | |
| with torch.no_grad(): | |
| reconstructions_uncond, _, _ = lyra_model( | |
| modality_inputs_uncond, | |
| target_modalities=['clip'] | |
| ) | |
| negative_prompt_embeds = reconstructions_uncond['clip'] | |
| else: | |
| negative_prompt_embeds = torch.zeros_like(prompt_embeds) | |
| return prompt_embeds, negative_prompt_embeds | |
| def __call__( | |
| self, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| height: int = 512, | |
| width: int = 512, | |
| num_inference_steps: int = 20, | |
| guidance_scale: float = 7.5, | |
| shift: float = 2.5, | |
| use_flow_matching: bool = True, | |
| prediction_type: str = "epsilon", | |
| seed: Optional[int] = None, | |
| use_lyra: bool = False, | |
| clip_skip: int = 1, | |
| t5_summary: str = "", | |
| lyra_strength: float = 1.0, | |
| progress_callback=None | |
| ): | |
| """Generate image.""" | |
| # Create generator with seed for deterministic generation | |
| if seed is not None: | |
| seed = int(seed) | |
| generator = torch.Generator(device=self.device).manual_seed(seed) | |
| print(f"[SD1.5 Pipeline] Using seed: {seed}") | |
| else: | |
| generator = None | |
| print("[SD1.5 Pipeline] No seed provided, using random") | |
| if use_lyra and self.lyra_available: | |
| prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt) | |
| else: | |
| prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt) | |
| latent_channels = 4 | |
| latent_height = height // 8 | |
| latent_width = width // 8 | |
| latents = torch.randn( | |
| (1, latent_channels, latent_height, latent_width), | |
| generator=generator, | |
| device=self.device, | |
| dtype=torch.float32 | |
| ) | |
| self.scheduler.set_timesteps(num_inference_steps, device=self.device) | |
| timesteps = self.scheduler.timesteps | |
| if not use_flow_matching: | |
| latents = latents * self.scheduler.init_noise_sigma | |
| for i, t in enumerate(timesteps): | |
| if progress_callback: | |
| progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") | |
| latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents | |
| if use_flow_matching and shift > 0: | |
| sigma = t.float() / 1000.0 | |
| sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) | |
| scaling = torch.sqrt(1 + sigma_shifted ** 2) | |
| latent_model_input = latent_model_input / scaling | |
| else: | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| timestep = t.expand(latent_model_input.shape[0]) | |
| text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| timestep, | |
| encoder_hidden_states=text_embeds, | |
| return_dict=False | |
| )[0] | |
| if guidance_scale > 1.0: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| if use_flow_matching: | |
| sigma = t.float() / 1000.0 | |
| sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) | |
| if prediction_type == "v_prediction": | |
| v_pred = noise_pred | |
| alpha_t = torch.sqrt(1 - sigma_shifted ** 2) | |
| sigma_t = sigma_shifted | |
| noise_pred = alpha_t * v_pred + sigma_t * latents | |
| dt = -1.0 / num_inference_steps | |
| latents = latents + dt * noise_pred | |
| else: | |
| # Pass generator for deterministic ancestral/SDE sampling | |
| latents = self.scheduler.step( | |
| noise_pred, t, latents, generator=generator, return_dict=False | |
| )[0] | |
| latents = latents / self.vae_scale_factor | |
| if self.is_lune_model: | |
| latents = latents * 5.52 | |
| with torch.no_grad(): | |
| image = self.vae.decode(latents).sample | |
| image = (image / 2 + 0.5).clamp(0, 1) | |
| image = image.cpu().permute(0, 2, 3, 1).float().numpy() | |
| image = (image * 255).round().astype("uint8") | |
| image = Image.fromarray(image[0]) | |
| return image | |
| # ============================================================================ | |
| # MODEL LOADERS | |
| # ============================================================================ | |
| def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"): | |
| """Load Lune checkpoint from .pt file.""" | |
| print(f"📥 Downloading: {repo_id}/{filename}") | |
| checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") | |
| checkpoint = torch.load(checkpoint_path, map_location="cpu") | |
| print(f"🏗️ Initializing SD1.5 UNet...") | |
| unet = UNet2DConditionModel.from_pretrained( | |
| "runwayml/stable-diffusion-v1-5", | |
| subfolder="unet", | |
| torch_dtype=torch.float32 | |
| ) | |
| student_state_dict = checkpoint["student"] | |
| cleaned_dict = {} | |
| for key, value in student_state_dict.items(): | |
| if key.startswith("unet."): | |
| cleaned_dict[key[5:]] = value | |
| else: | |
| cleaned_dict[key] = value | |
| unet.load_state_dict(cleaned_dict, strict=False) | |
| step = checkpoint.get("gstep", "unknown") | |
| print(f"✅ Loaded Lune from step {step}") | |
| return unet.to(device) | |
| def load_illustrious_xl( | |
| repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious", | |
| filename: str = "", | |
| device: str = "cuda" | |
| ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]: | |
| """Load Illustrious XL from single safetensors file.""" | |
| from diffusers import StableDiffusionXLPipeline | |
| # Default checkpoint if none specified | |
| if not filename or not filename.strip(): | |
| filename = "illustriousXL_v01.safetensors" | |
| print(f"📥 Loading Illustrious XL: {repo_id}/{filename}") | |
| checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") | |
| print(f"✓ Downloaded: {checkpoint_path}") | |
| print("📦 Loading with StableDiffusionXLPipeline.from_single_file()...") | |
| pipe = StableDiffusionXLPipeline.from_single_file( | |
| checkpoint_path, | |
| torch_dtype=torch.float16, | |
| use_safetensors=True, | |
| ) | |
| unet = pipe.unet.to(device) | |
| vae = pipe.vae.to(device) | |
| text_encoder = pipe.text_encoder.to(device) | |
| text_encoder_2 = pipe.text_encoder_2.to(device) | |
| tokenizer = pipe.tokenizer | |
| tokenizer_2 = pipe.tokenizer_2 | |
| del pipe | |
| torch.cuda.empty_cache() | |
| print("✅ Illustrious XL loaded!") | |
| print(f" UNet params: {sum(p.numel() for p in unet.parameters()):,}") | |
| print(f" VAE params: {sum(p.numel() for p in vae.parameters()):,}") | |
| return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 | |
| def load_sdxl_base(device: str = "cuda"): | |
| """Load standard SDXL base model.""" | |
| print("📥 Loading SDXL Base 1.0...") | |
| unet = UNet2DConditionModel.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| subfolder="unet", | |
| torch_dtype=torch.float16 | |
| ).to(device) | |
| # Use fp16-fix VAE to avoid NaN issues with SDXL's original VAE in fp16 | |
| print(" Using madebyollin/sdxl-vae-fp16-fix for stable fp16 decoding...") | |
| vae = AutoencoderKL.from_pretrained( | |
| "madebyollin/sdxl-vae-fp16-fix", | |
| torch_dtype=torch.float16 | |
| ).to(device) | |
| text_encoder = CLIPTextModel.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| subfolder="text_encoder", | |
| torch_dtype=torch.float16 | |
| ).to(device) | |
| text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| subfolder="text_encoder_2", | |
| torch_dtype=torch.float16 | |
| ).to(device) | |
| tokenizer = CLIPTokenizer.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| subfolder="tokenizer" | |
| ) | |
| tokenizer_2 = CLIPTokenizer.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| subfolder="tokenizer_2" | |
| ) | |
| print("✅ SDXL Base loaded!") | |
| return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 | |
| # ============================================================================ | |
| # PIPELINE INITIALIZATION | |
| # ============================================================================ | |
| def initialize_pipeline(model_choice: str, device: str = "cuda", checkpoint: str = "", lyra_checkpoint: str = ""): | |
| """Initialize the complete pipeline based on model choice. | |
| Uses lazy loading for T5 and Lyra - they won't be downloaded until first use. | |
| Args: | |
| model_choice: Model selection from dropdown | |
| device: Target device | |
| checkpoint: Optional custom checkpoint filename (e.g., "my_model.safetensors") | |
| lyra_checkpoint: Optional custom Lyra VAE checkpoint filename | |
| """ | |
| print(f"🚀 Initializing {model_choice} pipeline...") | |
| if checkpoint: | |
| print(f" Custom model checkpoint: {checkpoint}") | |
| if lyra_checkpoint: | |
| print(f" Custom Lyra checkpoint: {lyra_checkpoint}") | |
| is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice | |
| is_lune = "Lune" in model_choice | |
| if is_sdxl: | |
| # SDXL-based models | |
| if "Illustrious" in model_choice: | |
| unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl( | |
| device=device, | |
| filename=checkpoint | |
| ) | |
| else: | |
| unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device) | |
| # Create LAZY loaders for T5 and Lyra (no download yet!) | |
| print("📋 Configuring lazy loaders for T5-XL and Lyra VAE (will download on first use)") | |
| t5_loader = LazyT5Encoder( | |
| model_name=T5_XL_MODEL, # google/flan-t5-xl | |
| device=device, | |
| dtype=torch.float16 | |
| ) | |
| lyra_loader = LazyLyraModel( | |
| repo_id=LYRA_ILLUSTRIOUS_REPO, | |
| device=device, | |
| checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None | |
| ) | |
| # Default scheduler: Euler Ancestral | |
| scheduler = get_scheduler(SCHEDULER_EULER_A, is_sdxl=True) | |
| pipeline = SDXLFlowMatchingPipeline( | |
| vae=vae, | |
| text_encoder=text_encoder, | |
| text_encoder_2=text_encoder_2, | |
| tokenizer=tokenizer, | |
| tokenizer_2=tokenizer_2, | |
| unet=unet, | |
| scheduler=scheduler, | |
| device=device, | |
| t5_loader=t5_loader, | |
| lyra_loader=lyra_loader, | |
| clip_skip=1 | |
| ) | |
| else: | |
| # SD1.5-based models | |
| vae = AutoencoderKL.from_pretrained( | |
| "runwayml/stable-diffusion-v1-5", | |
| subfolder="vae", | |
| torch_dtype=torch.float32 | |
| ).to(device) | |
| text_encoder = CLIPTextModel.from_pretrained( | |
| "openai/clip-vit-large-patch14", | |
| torch_dtype=torch.float32 | |
| ).to(device) | |
| tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | |
| # Lazy loaders for SD1.5 Lyra (T5-base) | |
| print("📋 Configuring lazy loaders for T5-base and Lyra VAE v1 (will download on first use)") | |
| t5_loader = LazyT5Encoder( | |
| model_name=T5_BASE_MODEL, # google/flan-t5-base | |
| device=device, | |
| dtype=torch.float32 | |
| ) | |
| lyra_loader = LazyLyraModel( | |
| repo_id=LYRA_SD15_REPO, | |
| device=device, | |
| checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None | |
| ) | |
| # Load UNet | |
| if is_lune: | |
| repo_id = "AbstractPhil/sd15-flow-lune" | |
| # Use custom checkpoint or default | |
| if checkpoint and checkpoint.strip(): | |
| filename = checkpoint | |
| else: | |
| filename = "sd15_flow_lune_e34_s34000.pt" | |
| unet = load_lune_checkpoint(repo_id, filename, device) | |
| else: | |
| unet = UNet2DConditionModel.from_pretrained( | |
| "runwayml/stable-diffusion-v1-5", | |
| subfolder="unet", | |
| torch_dtype=torch.float32 | |
| ).to(device) | |
| scheduler = EulerDiscreteScheduler.from_pretrained( | |
| "runwayml/stable-diffusion-v1-5", | |
| subfolder="scheduler" | |
| ) | |
| pipeline = SD15FlowMatchingPipeline( | |
| vae=vae, | |
| text_encoder=text_encoder, | |
| tokenizer=tokenizer, | |
| unet=unet, | |
| scheduler=scheduler, | |
| device=device, | |
| t5_loader=t5_loader, | |
| lyra_loader=lyra_loader, | |
| ) | |
| pipeline.is_lune_model = is_lune | |
| print("✅ Pipeline initialized! (T5 and Lyra will load on first use)") | |
| return pipeline | |
| # ============================================================================ | |
| # GLOBAL STATE | |
| # ============================================================================ | |
| CURRENT_PIPELINE = None | |
| CURRENT_MODEL = None | |
| CURRENT_CHECKPOINT = None | |
| CURRENT_LYRA_CHECKPOINT = None | |
| def get_pipeline(model_choice: str, checkpoint: str = "", lyra_checkpoint: str = ""): | |
| """Get or create pipeline for selected model.""" | |
| global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_CHECKPOINT, CURRENT_LYRA_CHECKPOINT | |
| # Normalize empty values | |
| checkpoint = checkpoint.strip() if checkpoint else "" | |
| lyra_checkpoint = lyra_checkpoint.strip() if lyra_checkpoint else "" | |
| # Reinitialize if model or any checkpoint changed | |
| if (CURRENT_PIPELINE is None or | |
| CURRENT_MODEL != model_choice or | |
| CURRENT_CHECKPOINT != checkpoint or | |
| CURRENT_LYRA_CHECKPOINT != lyra_checkpoint): | |
| CURRENT_PIPELINE = initialize_pipeline( | |
| model_choice, device="cuda", | |
| checkpoint=checkpoint, | |
| lyra_checkpoint=lyra_checkpoint | |
| ) | |
| CURRENT_MODEL = model_choice | |
| CURRENT_CHECKPOINT = checkpoint | |
| CURRENT_LYRA_CHECKPOINT = lyra_checkpoint | |
| return CURRENT_PIPELINE | |
| # ============================================================================ | |
| # INFERENCE | |
| # ============================================================================ | |
| def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int: | |
| """Estimate GPU duration.""" | |
| base_time_per_step = 0.5 if is_sdxl else 0.3 | |
| resolution_factor = (width * height) / (512 * 512) | |
| estimated = num_steps * base_time_per_step * resolution_factor | |
| if use_lyra: | |
| estimated *= 2 | |
| estimated += 10 # Extra time for lazy loading on first use | |
| return int(estimated + 20) | |
| def generate_image( | |
| prompt: str, | |
| t5_summary: str, | |
| negative_prompt: str, | |
| model_choice: str, | |
| checkpoint: str, | |
| lyra_checkpoint: str, | |
| scheduler_choice: str, | |
| clip_skip: int, | |
| num_steps: int, | |
| cfg_scale: float, | |
| width: int, | |
| height: int, | |
| shift: float, | |
| use_flow_matching: bool, | |
| use_lyra: bool, | |
| lyra_strength: float, | |
| use_separator: bool, | |
| clip_include_summary: bool, | |
| seed: int, | |
| randomize_seed: bool, | |
| progress=gr.Progress() | |
| ): | |
| """Generate image with ZeroGPU support. | |
| Args: | |
| prompt: Tags/keywords (CLIP input) | |
| t5_summary: Natural language summary (T5 input, unless clip_include_summary) | |
| checkpoint: Custom model checkpoint filename (empty for default) | |
| lyra_checkpoint: Custom Lyra VAE checkpoint filename (empty for default) | |
| use_separator: Use ¶ separator between tags and summary | |
| clip_include_summary: If True, CLIP also sees the summary | |
| """ | |
| # Ensure seed is an integer (Gradio sliders return floats) | |
| seed = int(seed) | |
| if randomize_seed: | |
| seed = np.random.randint(0, 2**32 - 1) | |
| print(f"🎲 Using seed: {seed} (randomize={randomize_seed})") | |
| def progress_callback(step, total, desc): | |
| progress((step + 1) / total, desc=desc) | |
| try: | |
| pipeline = get_pipeline(model_choice, checkpoint, lyra_checkpoint) | |
| # Update scheduler if needed (SDXL only) | |
| is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice | |
| if is_sdxl and hasattr(pipeline, 'set_scheduler'): | |
| pipeline.set_scheduler(scheduler_choice) | |
| prediction_type = "epsilon" | |
| if not is_sdxl and "Lune" in model_choice: | |
| prediction_type = "v_prediction" | |
| if not use_lyra or not pipeline.lyra_available: | |
| progress(0.05, desc="Generating...") | |
| image = pipeline( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| height=height, | |
| width=width, | |
| num_inference_steps=num_steps, | |
| guidance_scale=cfg_scale, | |
| shift=shift, | |
| use_flow_matching=use_flow_matching, | |
| prediction_type=prediction_type, | |
| seed=seed, | |
| use_lyra=False, | |
| clip_skip=clip_skip, | |
| progress_callback=progress_callback | |
| ) | |
| progress(1.0, desc="Complete!") | |
| return image, None, seed | |
| else: | |
| # Side-by-side comparison: SAME seed for both! | |
| progress(0.05, desc="Generating standard...") | |
| image_standard = pipeline( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| height=height, | |
| width=width, | |
| num_inference_steps=num_steps, | |
| guidance_scale=cfg_scale, | |
| shift=shift, | |
| use_flow_matching=use_flow_matching, | |
| prediction_type=prediction_type, | |
| seed=seed, # Same seed | |
| use_lyra=False, | |
| clip_skip=clip_skip, | |
| progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d) | |
| ) | |
| progress(0.5, desc="Generating Lyra fusion (loading T5 + Lyra if needed)...") | |
| image_lyra = pipeline( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| height=height, | |
| width=width, | |
| num_inference_steps=num_steps, | |
| guidance_scale=cfg_scale, | |
| shift=shift, | |
| use_flow_matching=use_flow_matching, | |
| prediction_type=prediction_type, | |
| seed=seed, # Same seed for deterministic comparison | |
| use_lyra=True, | |
| clip_skip=clip_skip, | |
| t5_summary=t5_summary, | |
| lyra_strength=lyra_strength, | |
| use_separator=use_separator, | |
| clip_include_summary=clip_include_summary, | |
| progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d) | |
| ) | |
| progress(1.0, desc="Complete!") | |
| return image_standard, image_lyra, seed | |
| except Exception as e: | |
| print(f"❌ Generation failed: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| raise e | |
| # ============================================================================ | |
| # GRADIO UI | |
| # ============================================================================ | |
| def create_demo(): | |
| """Create Gradio interface.""" | |
| with gr.Blocks() as demo: | |
| gr.Markdown(""" | |
| # 🌙 Lyra/Lune Flow-Matching Image Generation | |
| **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil) | |
| Generate images using SD1.5 and SDXL-based models with geometric deep learning: | |
| | Model | Architecture | Lyra Version | Best For | | |
| |-------|-------------|--------------|----------| | |
| | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail | | |
| | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose | | |
| | **Flow-Lune** | SD1.5 | v1 (T5-base) | Fast flow matching (15-25 steps) | | |
| | **SD1.5 Base** | SD1.5 | v1 (T5-base) | Baseline comparison | | |
| **Lazy Loading**: T5 and Lyra VAE are only downloaded when you enable Lyra fusion! | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| prompt = gr.TextArea( | |
| label="Prompt (Tags for CLIP)", | |
| value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background", | |
| lines=3, | |
| info="CLIP encoders see these tags. T5 also sees these + the summary below." | |
| ) | |
| t5_summary = gr.TextArea( | |
| label="T5 Summary (Natural Language - T5 Only)", | |
| value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky", | |
| lines=2, | |
| info="T5 sees: tags ¶ summary. CLIP sees: tags only (unless 'Include Summary in CLIP' is enabled)." | |
| ) | |
| negative_prompt = gr.TextArea( | |
| label="Negative Prompt", | |
| value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality", | |
| lines=2 | |
| ) | |
| model_choice = gr.Dropdown( | |
| label="Model", | |
| choices=[ | |
| "Illustrious XL", | |
| "SDXL Base", | |
| "Flow-Lune (SD1.5)", | |
| "SD1.5 Base" | |
| ], | |
| value="Illustrious XL" | |
| ) | |
| with gr.Accordion("Advanced Options", open=False): | |
| checkpoint = gr.Textbox( | |
| label="Model Checkpoint (optional)", | |
| value="", | |
| placeholder="e.g., illustriousXL_v01.safetensors", | |
| info="Leave empty for default. Illustrious: .safetensors, Lune: .pt" | |
| ) | |
| lyra_checkpoint = gr.Textbox( | |
| label="Lyra VAE Checkpoint (optional)", | |
| value="weights/lyra_illustrious_step_12000.safetensors", | |
| placeholder="e.g., lyra_e100_s50000.safetensors", | |
| info="Leave empty for latest. Loaded from weights/ folder in Lyra repo." | |
| ) | |
| scheduler_choice = gr.Dropdown( | |
| label="Scheduler (SDXL only)", | |
| choices=SCHEDULER_CHOICES, | |
| value=SCHEDULER_EULER_A, | |
| info="Euler Ancestral recommended for Illustrious" | |
| ) | |
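| # A minimal sketch of how the scheduler names above could map onto the diffusers classes | |
| # imported at the top of this file (assumption: the pipeline's set_scheduler does something | |
| # equivalent; its actual implementation lives elsewhere in this file). | |
| def _build_scheduler_sketch(name, config): | |
|     if name == SCHEDULER_EULER_A: | |
|         return EulerAncestralDiscreteScheduler.from_config(config) | |
|     if name == SCHEDULER_EULER: | |
|         return EulerDiscreteScheduler.from_config(config) | |
|     if name == SCHEDULER_DPM_2M_SDE: | |
|         return DPMSolverMultistepScheduler.from_config(config, algorithm_type="sde-dpmsolver++") | |
|     return DPMSolverMultistepScheduler.from_config(config)  # DPM++ 2M | |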
| clip_skip = gr.Slider( | |
| label="CLIP Skip", | |
| minimum=1, | |
| maximum=4, | |
| value=2, | |
| step=1, | |
| info="2 recommended for Illustrious, 1 for others" | |
| ) | |
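| # Sketch of the usual "CLIP skip" convention the slider above refers to (assumption: the | |
| # common A1111-style meaning, where 1 = last hidden layer and 2 = penultimate; this helper | |
| # is illustrative only, the real pipeline handles CLIP skip internally). | |
| def _clip_skip_sketch(text_encoder: CLIPTextModel, input_ids: torch.Tensor, clip_skip: int = 1) -> torch.Tensor: | |
|     out = text_encoder(input_ids, output_hidden_states=True) | |
|     hidden = out.hidden_states[-clip_skip]  # -1 = final layer, -2 = penultimate, ... | |
|     return text_encoder.text_model.final_layer_norm(hidden) | |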
| use_lyra = gr.Checkbox( | |
| label="Enable Lyra VAE (CLIP+T5 Fusion)", | |
| value=True, # DEFAULT: ON | |
| info="Enables lazy loading of T5 and Lyra on first use" | |
| ) | |
| lyra_strength = gr.Slider( | |
| label="Lyra Blend Strength", | |
| minimum=0.0, | |
| maximum=3.0, | |
| value=1.0, | |
| step=0.05, | |
| info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction" | |
| ) | |
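| # Minimal sketch of what the blend strength above implies (assumption: a simple linear | |
| # interpolation between the CLIP embedding and the Lyra reconstruction; the real fusion | |
| # lives inside the pipeline and may differ). Values above 1.0 would extrapolate past Lyra. | |
| def _lyra_blend_sketch(clip_emb: torch.Tensor, lyra_emb: torch.Tensor, strength: float) -> torch.Tensor: | |
|     return clip_emb + strength * (lyra_emb - clip_emb)  # 0.0 -> pure CLIP, 1.0 -> pure Lyra | |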
| with gr.Accordion("Lyra Advanced Settings", open=False): | |
| use_separator = gr.Checkbox( | |
| label="Use ¶ Separator", | |
| value=True, | |
| info="Insert ¶ between tags and summary in T5 input" | |
| ) | |
| clip_include_summary = gr.Checkbox( | |
| label="Include Summary in CLIP", | |
| value=False, | |
| info="By default CLIP sees tags only. Enable to append summary to CLIP input." | |
| ) | |
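| # Sketch of the prompt routing described in the field help texts above (assumption: | |
| # hypothetical helper; the actual pipeline builds these strings internally). T5 sees | |
| # "tags ¶ summary" while CLIP sees the tags only, unless the summary is explicitly included. | |
| def _compose_prompts_sketch(tags: str, summary: str, use_separator: bool = True, clip_include_summary: bool = False) -> Tuple[str, str]: | |
|     sep = " ¶ " if use_separator else " " | |
|     t5_text = f"{tags}{sep}{summary}" if summary else tags | |
|     clip_text = f"{tags}, {summary}" if (clip_include_summary and summary) else tags | |
|     return clip_text, t5_text | |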
| with gr.Accordion("Generation Settings", open=True): | |
| num_steps = gr.Slider( | |
| label="Steps", | |
| minimum=1, | |
| maximum=50, | |
| value=25, | |
| step=1 | |
| ) | |
| cfg_scale = gr.Slider( | |
| label="CFG Scale", | |
| minimum=1.0, | |
| maximum=20.0, | |
| value=7.0, | |
| step=0.5 | |
| ) | |
| with gr.Row(): | |
| width = gr.Slider( | |
| label="Width", | |
| minimum=512, | |
| maximum=1536, | |
| value=1024, | |
| step=64 | |
| ) | |
| height = gr.Slider( | |
| label="Height", | |
| minimum=512, | |
| maximum=1536, | |
| value=1024, | |
| step=64 | |
| ) | |
| seed = gr.Slider( | |
| label="Seed", | |
| minimum=0, | |
| maximum=2**32 - 1, | |
| value=42, # DEFAULT: 42 | |
| step=1 | |
| ) | |
| randomize_seed = gr.Checkbox( | |
| label="Randomize Seed", | |
| value=False # DEFAULT: OFF for reproducibility | |
| ) | |
| with gr.Accordion("Advanced (Flow Matching)", open=False): | |
| use_flow_matching = gr.Checkbox( | |
| label="Enable Flow Matching", | |
| value=False, | |
| info="Use flow matching ODE (for Lune only)" | |
| ) | |
| shift = gr.Slider( | |
| label="Shift", | |
| minimum=0.0, | |
| maximum=5.0, | |
| value=0.0, | |
| step=0.1, | |
| info="Flow matching shift (0=disabled)" | |
| ) | |
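| # One common formulation of a flow-matching timestep "shift" (assumption: this is only an | |
| # illustrative sketch of the rectified-flow-style shift map; whether Lune's sampler uses | |
| # exactly this form is not shown here, and in this UI a value of 0 simply disables it). | |
| def _shift_timestep_sketch(t: float, shift: float) -> float: | |
|     if shift <= 0.0: | |
|         return t  # 0 = disabled, as noted on the slider above | |
|     return shift * t / (1.0 + (shift - 1.0) * t) | |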
| generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg") | |
| with gr.Column(scale=1): | |
| with gr.Row(): | |
| output_image_standard = gr.Image( | |
| label="Standard", | |
| type="pil" | |
| ) | |
| output_image_lyra = gr.Image( | |
| label="Lyra Fusion 🎵", | |
| type="pil", | |
| visible=True # Visible by default since Lyra is on | |
| ) | |
| output_seed = gr.Number(label="Seed Used", precision=0) | |
| gr.Markdown(""" | |
| ### Tips | |
| - **Lazy Loading**: T5-XL (~3GB) and Lyra VAE only download when you enable Lyra | |
| - **Illustrious XL**: Use CLIP skip 2, Euler Ancestral scheduler | |
| - **Schedulers**: DPM++ 2M SDE for detail, Euler A for speed | |
| - **Lyra v2**: Uses `google/flan-t5-xl` for richer semantics | |
| - **Same Seed**: Both Standard and Lyra use the same seed for fair comparison | |
| """) | |
| # Event handlers | |
| def on_model_change(model_name): | |
| """Update defaults based on model.""" | |
| if "Illustrious" in model_name: | |
| return { | |
| clip_skip: gr.update(value=2), | |
| width: gr.update(value=1024), | |
| height: gr.update(value=1024), | |
| num_steps: gr.update(value=25), | |
| use_flow_matching: gr.update(value=False), | |
| shift: gr.update(value=0.0), | |
| scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A) | |
| } | |
| elif "SDXL" in model_name: | |
| return { | |
| clip_skip: gr.update(value=1), | |
| width: gr.update(value=1024), | |
| height: gr.update(value=1024), | |
| num_steps: gr.update(value=30), | |
| use_flow_matching: gr.update(value=False), | |
| shift: gr.update(value=0.0), | |
| scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A) | |
| } | |
| elif "Lune" in model_name: | |
| return { | |
| clip_skip: gr.update(value=1), | |
| width: gr.update(value=512), | |
| height: gr.update(value=512), | |
| num_steps: gr.update(value=20), | |
| use_flow_matching: gr.update(value=True), | |
| shift: gr.update(value=2.5), | |
| scheduler_choice: gr.update(visible=False) | |
| } | |
| else: # SD1.5 Base | |
| return { | |
| clip_skip: gr.update(value=1), | |
| width: gr.update(value=512), | |
| height: gr.update(value=512), | |
| num_steps: gr.update(value=30), | |
| use_flow_matching: gr.update(value=False), | |
| shift: gr.update(value=0.0), | |
| scheduler_choice: gr.update(visible=False) | |
| } | |
| def on_lyra_toggle(enabled): | |
| """Show/hide Lyra comparison.""" | |
| if enabled: | |
| return { | |
| output_image_standard: gr.update(visible=True, label="Standard"), | |
| output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵") | |
| } | |
| else: | |
| return { | |
| output_image_standard: gr.update(visible=True, label="Generated Image"), | |
| output_image_lyra: gr.update(visible=False) | |
| } | |
| model_choice.change( | |
| fn=on_model_change, | |
| inputs=[model_choice], | |
| outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift, scheduler_choice] | |
| ) | |
| use_lyra.change( | |
| fn=on_lyra_toggle, | |
| inputs=[use_lyra], | |
| outputs=[output_image_standard, output_image_lyra] | |
| ) | |
| generate_btn.click( | |
| fn=generate_image, | |
| inputs=[ | |
| prompt, t5_summary, negative_prompt, model_choice, checkpoint, lyra_checkpoint, | |
| scheduler_choice, clip_skip, | |
| num_steps, cfg_scale, width, height, shift, | |
| use_flow_matching, use_lyra, lyra_strength, use_separator, clip_include_summary, | |
| seed, randomize_seed | |
| ], | |
| outputs=[output_image_standard, output_image_lyra, output_seed] | |
| ) | |
| return demo | |
| # ============================================================================ | |
| # LAUNCH | |
| # ============================================================================ | |
| if __name__ == "__main__": | |
| demo = create_demo() | |
| demo.queue(max_size=20) | |
| demo.launch() |