| """ | |
| Lyra/Lune Flow-Matching Inference Space | |
| Author: AbstractPhil | |
| License: MIT | |
| SD1.5 and SDXL-based flow matching with geometric crystalline architectures. | |
| Supports Illustrious XL, standard SDXL, and SD1.5 variants. | |
| Lyra VAE Versions: | |
| - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra | |
| - v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2 | |
| """ | |
import os
import json
import torch
import gradio as gr
import numpy as np
from PIL import Image
from typing import Optional, Dict, Tuple

import spaces
from safetensors.torch import load_file as load_safetensors
from diffusers import (
    UNet2DConditionModel,
    AutoencoderKL,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSDEScheduler,
)
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPTextModelWithProjection,
    T5EncoderModel,
    T5Tokenizer,
)
from huggingface_hub import hf_hub_download
# Lazy imports for Lyra
LYRA_V1_AVAILABLE = False
LYRA_V2_AVAILABLE = False
LyraV1 = None
LyraV1Config = None
LyraV2 = None
LyraV2Config = None


def _load_lyra_imports():
    """Lazy load Lyra VAE modules."""
    global LYRA_V1_AVAILABLE, LYRA_V2_AVAILABLE
    global LyraV1, LyraV1Config, LyraV2, LyraV2Config

    try:
        from geofractal.model.vae.vae_lyra import MultiModalVAE as _LyraV1, MultiModalVAEConfig as _LyraV1Config
        LyraV1 = _LyraV1
        LyraV1Config = _LyraV1Config
        LYRA_V1_AVAILABLE = True
    except ImportError:
        print("⚠️ Lyra VAE v1 not available")

    try:
        from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as _LyraV2, MultiModalVAEConfig as _LyraV2Config
        LyraV2 = _LyraV2
        LyraV2Config = _LyraV2Config
        LYRA_V2_AVAILABLE = True
    except ImportError:
        print("⚠️ Lyra VAE v2 not available")
# ============================================================================
# CONSTANTS
# ============================================================================

ARCH_SD15 = "sd15"
ARCH_SDXL = "sdxl"

# Scheduler options
SCHEDULER_EULER_A = "Euler Ancestral"
SCHEDULER_EULER = "Euler"
SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE"
SCHEDULER_DPM_2M = "DPM++ 2M"

SDXL_SCHEDULERS = [SCHEDULER_EULER_A, SCHEDULER_EULER, SCHEDULER_DPM_2M_SDE, SCHEDULER_DPM_2M]
# ============================================================================
# SCHEDULER FACTORY
# ============================================================================

def get_scheduler(scheduler_name: str, config_path: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """Create scheduler by name, using the SDXL base scheduler config."""
    if scheduler_name == SCHEDULER_EULER_A:
        return EulerAncestralDiscreteScheduler.from_pretrained(
            config_path, subfolder="scheduler"
        )
    elif scheduler_name == SCHEDULER_EULER:
        return EulerDiscreteScheduler.from_pretrained(
            config_path, subfolder="scheduler"
        )
    elif scheduler_name == SCHEDULER_DPM_2M_SDE:
        # "DPM++ 2M SDE" is the multistep DPM-Solver++ with its SDE variant;
        # DPMSolverSDEScheduler does not accept algorithm_type/solver_order.
        return DPMSolverMultistepScheduler.from_pretrained(
            config_path, subfolder="scheduler",
            algorithm_type="sde-dpmsolver++",
            solver_order=2,
        )
    elif scheduler_name == SCHEDULER_DPM_2M:
        return DPMSolverMultistepScheduler.from_pretrained(
            config_path, subfolder="scheduler",
            algorithm_type="dpmsolver++",
            solver_order=2,
        )
    else:
        # Default to Euler Ancestral
        return EulerAncestralDiscreteScheduler.from_pretrained(
            config_path, subfolder="scheduler"
        )
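# Example: get_scheduler(SCHEDULER_DPM_2M) loads the SDXL base scheduler config and
# reconfigures it as a 2nd-order DPM-Solver++ multistep sampler; any unrecognized
# name falls back to Euler Ancestral.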
# ============================================================================
# MODEL LOADING UTILITIES
# ============================================================================

def get_clip_hidden_state(
    model_output,
    clip_skip: int = 1,
    output_hidden_states: bool = True
) -> torch.Tensor:
    """Extract hidden state with clip_skip support."""
    if clip_skip == 1 or not output_hidden_states:
        return model_output.last_hidden_state
    if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
        return model_output.hidden_states[-clip_skip]
    return model_output.last_hidden_state
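# Convention used here: hidden_states[-1] is the final encoder layer, so clip_skip=2
# selects the penultimate layer (the usual "CLIP skip 2" behaviour for anime checkpoints).
# Note that no final layer norm is re-applied to the skipped hidden state.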
# ============================================================================
# LAZY LOADERS
# ============================================================================

class LazyT5Encoder:
    """Lazy loader for T5 encoder - only loads when first accessed."""

    def __init__(self, model_name: str = "google/flan-t5-xl", device: str = "cuda"):
        self.model_name = model_name
        self.device = device
        self._encoder = None
        self._tokenizer = None

    @property
    def encoder(self):
        if self._encoder is None:
            print(f"📥 Loading T5 encoder: {self.model_name}...")
            self._encoder = T5EncoderModel.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16
            ).to(self.device)
            self._encoder.eval()
            print("✓ T5 encoder loaded")
        return self._encoder

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            print(f"📥 Loading T5 tokenizer: {self.model_name}...")
            self._tokenizer = T5Tokenizer.from_pretrained(self.model_name)
            print("✓ T5 tokenizer loaded")
        return self._tokenizer

    def is_loaded(self):
        return self._encoder is not None
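# The lazy wrappers here (and LazyLyraModel below) keep pipeline start-up light:
# the flan-t5-xl encoder and the Lyra checkpoint are only downloaded and moved to
# the GPU the first time the Lyra path is actually used.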
class LazyLyraModel:
    """Lazy loader for Lyra VAE - only loads when first accessed."""

    def __init__(self, repo_id: str, device: str = "cuda", version: int = 2):
        self.repo_id = repo_id
        self.device = device
        self.version = version
        self._model = None

    @property
    def model(self):
        if self._model is None:
            _load_lyra_imports()
            if self.version == 2:
                self._model = self._load_v2()
            else:
                self._model = self._load_v1()
        return self._model
    def _load_v2(self):
        if not LYRA_V2_AVAILABLE:
            print("⚠️ Lyra VAE v2 not available")
            return None
        print(f"🎵 Loading Lyra VAE v2 from {self.repo_id}...")
        try:
            from huggingface_hub import list_repo_files

            config_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="config.json",
                repo_type="model"
            )
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            print(f" ✓ Config: {config_dict.get('fusion_strategy', 'unknown')} fusion")

            # Auto-detect the latest checkpoint in the repo
            repo_files = list_repo_files(self.repo_id, repo_type="model")
            checkpoint_files = [f for f in repo_files if f.endswith('.pt')]
            checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower()]
            if not checkpoint_files:
                raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")

            import re

            def extract_step(name):
                match = re.search(r'(\d+)\.pt', name)
                return int(match.group(1)) if match else 0

            checkpoint_files.sort(key=extract_step, reverse=True)
            checkpoint_filename = checkpoint_files[0]
            print(f" ✓ Using: {checkpoint_filename}")

            checkpoint_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=checkpoint_filename,
                repo_type="model"
            )
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

            vae_config = LyraV2Config(
                modality_dims=config_dict.get('modality_dims', {
                    "clip_l": 768, "clip_g": 1280,
                    "t5_xl_l": 2048, "t5_xl_g": 2048
                }),
                modality_seq_lens=config_dict.get('modality_seq_lens', {
                    "clip_l": 77, "clip_g": 77,
                    "t5_xl_l": 512, "t5_xl_g": 512
                }),
                binding_config=config_dict.get('binding_config', {
                    "clip_l": {"t5_xl_l": 0.3},
                    "clip_g": {"t5_xl_g": 0.3},
                    "t5_xl_l": {},
                    "t5_xl_g": {}
                }),
                latent_dim=config_dict.get('latent_dim', 2048),
                seq_len=config_dict.get('seq_len', 77),
                encoder_layers=config_dict.get('encoder_layers', 3),
                decoder_layers=config_dict.get('decoder_layers', 3),
                hidden_dim=config_dict.get('hidden_dim', 2048),
                dropout=config_dict.get('dropout', 0.1),
                fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
                fusion_heads=config_dict.get('fusion_heads', 8),
                fusion_dropout=config_dict.get('fusion_dropout', 0.1),
                cantor_depth=config_dict.get('cantor_depth', 8),
                cantor_local_window=config_dict.get('cantor_local_window', 3),
                alpha_init=config_dict.get('alpha_init', 1.0),
                beta_init=config_dict.get('beta_init', 0.3),
            )

            lyra_model = LyraV2(vae_config)
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
            if missing:
                print(f" ⚠️ Missing keys: {len(missing)}")
            if unexpected:
                print(f" ⚠️ Unexpected keys: {len(unexpected)}")

            lyra_model.to(self.device)
            lyra_model.eval()

            total_params = sum(p.numel() for p in lyra_model.parameters())
            print(f"✅ Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
            return lyra_model
        except Exception as e:
            print(f"❌ Failed to load Lyra VAE v2: {e}")
            import traceback
            traceback.print_exc()
            return None
    def _load_v1(self):
        if not LYRA_V1_AVAILABLE:
            print("⚠️ Lyra VAE v1 not available")
            return None
        # Similar implementation for v1...
        return None

    def is_loaded(self):
        return self._model is not None
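# Expected layout of the Lyra repo, as assumed by _load_v2 above: a config.json holding
# the fusion hyperparameters plus one or more "*checkpoint*<step>.pt" files; the file with
# the highest step number is used, and its weights may be nested under 'model_state_dict'.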
# ============================================================================
# SDXL PIPELINE
# ============================================================================

class SDXLFlowMatchingPipeline:
    """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders."""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler,
        device: str = "cuda",
        t5_loader: Optional[LazyT5Encoder] = None,
        lyra_loader: Optional[LazyLyraModel] = None,
        clip_skip: int = 1
    ):
        self.vae = vae
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.unet = unet
        self.scheduler = scheduler
        self.device = device
        # Lazy loaders
        self.t5_loader = t5_loader
        self.lyra_loader = lyra_loader
        # Settings
        self.clip_skip = clip_skip
        self.vae_scale_factor = 0.13025  # SDXL VAE latent scaling factor
        self.arch = ARCH_SDXL

    def set_scheduler(self, scheduler_name: str):
        """Switch scheduler."""
        self.scheduler = get_scheduler(scheduler_name)

    @property
    def t5_encoder(self):
        return self.t5_loader.encoder if self.t5_loader else None

    @property
    def t5_tokenizer(self):
        return self.t5_loader.tokenizer if self.t5_loader else None

    @property
    def lyra_model(self):
        return self.lyra_loader.model if self.lyra_loader else None
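    # Accessing these properties triggers the lazy download/load on first use; code that
    # only needs to know whether Lyra is configured should check self.lyra_loader instead.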
    def encode_prompt(
        self,
        prompt: str,
        negative_prompt: str = "",
        clip_skip: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode prompts using dual CLIP encoders for SDXL."""
        # CLIP-L encoding
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        with torch.no_grad():
            output_hidden_states = clip_skip > 1
            clip_l_output = self.text_encoder(
                text_input_ids,
                output_hidden_states=output_hidden_states
            )
            prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states)

        # CLIP-G encoding
        text_inputs_2 = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_2.input_ids.to(self.device)
        with torch.no_grad():
            clip_g_output = self.text_encoder_2(
                text_input_ids_2,
                output_hidden_states=output_hidden_states
            )
            prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
            pooled_prompt_embeds = clip_g_output.text_embeds

        prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)

        # Negative prompt
        if negative_prompt:
            uncond_inputs = self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids = uncond_inputs.input_ids.to(self.device)
            uncond_inputs_2 = self.tokenizer_2(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer_2.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device)
            with torch.no_grad():
                uncond_output_l = self.text_encoder(
                    uncond_input_ids,
                    output_hidden_states=output_hidden_states
                )
                negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states)
                uncond_output_g = self.text_encoder_2(
                    uncond_input_ids_2,
                    output_hidden_states=output_hidden_states
                )
                negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states)
                negative_pooled = uncond_output_g.text_embeds
            negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1)
        else:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled = torch.zeros_like(pooled_prompt_embeds)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled
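    # Resulting shapes for a single prompt: prompt_embeds is [1, 77, 2048]
    # (CLIP-L's 768 dims concatenated with CLIP-G's 1280 dims), and the pooled
    # embedding from CLIP-G's projection head is [1, 1280].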
    def encode_prompt_lyra(
        self,
        prompt: str,
        negative_prompt: str = "",
        clip_skip: int = 1,
        t5_summary: str = "",
        lyra_strength: float = 0.3
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode prompts using Lyra VAE v2 fusion (CLIP + T5)."""
        if self.lyra_model is None or self.t5_encoder is None:
            raise ValueError("Lyra VAE components not initialized")

        # Get standard CLIP embeddings first
        prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
            prompt, negative_prompt, clip_skip
        )

        # Format T5 input: "<prompt> ¶ <summary>", falling back to the prompt itself
        SUMMARY_SEPARATOR = "¶"
        if t5_summary.strip():
            t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
        else:
            t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {prompt}"

        # Get T5 embeddings
        t5_inputs = self.t5_tokenizer(
            t5_prompt,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        with torch.no_grad():
            t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state

        # Split the concatenated SDXL embedding back into CLIP-L and CLIP-G parts
        clip_l_dim = 768
        clip_l_embeds = prompt_embeds[..., :clip_l_dim]
        clip_g_embeds = prompt_embeds[..., clip_l_dim:]

        with torch.no_grad():
            modality_inputs = {
                'clip_l': clip_l_embeds.float(),
                'clip_g': clip_g_embeds.float(),
                't5_xl_l': t5_embeds.float(),
                't5_xl_g': t5_embeds.float()
            }
            reconstructions, mu, logvar, _ = self.lyra_model(
                modality_inputs,
                target_modalities=['clip_l', 'clip_g']
            )
        lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
        lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)

        # Re-normalize the reconstructions if their statistics drift too far from CLIP's
        clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
        clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
        if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5:
            lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8)
            lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean()
        if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5:
            lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
            lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()

        # Blend original CLIP embeddings with the Lyra reconstructions
        fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
        fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
        prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)

        # Negative prompt stays as the original CLIP encoding
        return prompt_embeds_fused, negative_prompt_embeds, pooled, negative_pooled
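    # The fusion is a per-channel linear interpolation: strength 0.0 returns the original
    # CLIP embeddings, 1.0 returns the (re-normalized) Lyra reconstructions, and values
    # above 1.0 extrapolate past the reconstruction. Pooled embeddings are left untouched.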
    def _get_add_time_ids(
        self,
        original_size: Tuple[int, int],
        crops_coords_top_left: Tuple[int, int],
        target_size: Tuple[int, int],
        dtype: torch.dtype
    ) -> torch.Tensor:
        """Create time embedding IDs for SDXL."""
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
        return add_time_ids
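    # The result is a [1, 6] tensor of (orig_h, orig_w, crop_top, crop_left, target_h, target_w),
    # the micro-conditioning inputs SDXL was trained with; here the crop is fixed at (0, 0) and
    # original/target sizes both equal the requested resolution.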
    @torch.no_grad()  # inference only; avoids building autograd graphs during the denoising loop
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.0,
        seed: Optional[int] = None,
        use_lyra: bool = False,
        clip_skip: int = 2,
        t5_summary: str = "",
        lyra_strength: float = 1.0,
        progress_callback=None
    ):
        """Generate image using SDXL architecture."""
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
        else:
            generator = None

        # Encode prompts
        if use_lyra and self.lyra_loader is not None:
            prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
                prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
            )
        else:
            prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
                prompt, negative_prompt, clip_skip
            )

        # Prepare latents (SDXL latent space is 4 channels at 1/8 resolution)
        latent_channels = 4
        latent_height = height // 8
        latent_width = width // 8
        latents = torch.randn(
            (1, latent_channels, latent_height, latent_width),
            generator=generator,
            device=self.device,
            dtype=torch.float16
        )

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        latents = latents * self.scheduler.init_noise_sigma

        # Time embeddings for SDXL
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)
        add_time_ids = self._get_add_time_ids(
            original_size, crops_coords_top_left, target_size, dtype=torch.float16
        )
        negative_add_time_ids = add_time_ids

        # Denoising loop
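        # With guidance_scale > 1.0, each step runs the UNet on a batch of two
        # (unconditional + conditional) and applies classifier-free guidance:
        # noise = uncond + scale * (cond - uncond).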
        for i, t in enumerate(timesteps):
            if progress_callback:
                progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")

            latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            timestep = t.expand(latent_model_input.shape[0])

            if guidance_scale > 1.0:
                text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
                add_text_embeds = torch.cat([negative_pooled, pooled])
                add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids])
            else:
                text_embeds = prompt_embeds
                add_text_embeds = pooled
                add_time_ids_input = add_time_ids

            added_cond_kwargs = {
                "text_embeds": add_text_embeds,
                "time_ids": add_time_ids_input
            }
            noise_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False
            )[0]

            if guidance_scale > 1.0:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Decode
        latents = latents / self.vae_scale_factor
        with torch.no_grad():
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = Image.fromarray(image[0])
        return image
# ============================================================================
# MODEL LOADERS
# ============================================================================

def load_illustrious_xl(
    repo_id: str = "AbstractPhil/illustrious-xl-v1",
    filename: str = "illustriousXL_v01.safetensors",
    device: str = "cuda"
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
    """Load Illustrious XL from a single safetensors file."""
    from diffusers import StableDiffusionXLPipeline

    print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    print(f"✓ Downloaded: {checkpoint_path}")

    print("📦 Loading pipeline...")
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    unet = pipe.unet.to(device)
    vae = pipe.vae.to(device)
    text_encoder = pipe.text_encoder.to(device)
    text_encoder_2 = pipe.text_encoder_2.to(device)
    tokenizer = pipe.tokenizer
    tokenizer_2 = pipe.tokenizer_2
    del pipe
    torch.cuda.empty_cache()

    print("✅ Illustrious XL loaded!")
    return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
# ============================================================================
# PIPELINE INITIALIZATION
# ============================================================================

def initialize_sdxl_pipeline(
    model_choice: str,
    scheduler_name: str = SCHEDULER_EULER_A,
    device: str = "cuda"
):
    """Initialize SDXL pipeline with lazy T5/Lyra loading."""
    print(f"🚀 Initializing {model_choice} pipeline...")

    # Load base model
    if "Illustrious" in model_choice:
        unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
    else:
        # SDXL Base
        from diffusers import StableDiffusionXLPipeline
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
        )
        unet = pipe.unet.to(device)
        vae = pipe.vae.to(device)
        text_encoder = pipe.text_encoder.to(device)
        text_encoder_2 = pipe.text_encoder_2.to(device)
        tokenizer = pipe.tokenizer
        tokenizer_2 = pipe.tokenizer_2
        del pipe
        torch.cuda.empty_cache()

    # Create lazy loaders (nothing is downloaded yet)
    t5_loader = LazyT5Encoder(model_name="google/flan-t5-xl", device=device)
    lyra_loader = LazyLyraModel(
        repo_id="AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
        device=device,
        version=2
    )

    # Get scheduler
    scheduler = get_scheduler(scheduler_name)

    pipeline = SDXLFlowMatchingPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
        device=device,
        t5_loader=t5_loader,
        lyra_loader=lyra_loader,
        clip_skip=2
    )
    print("✅ Pipeline initialized (T5/Lyra will load on first use)")
    return pipeline
# ============================================================================
# GLOBAL STATE
# ============================================================================

CURRENT_PIPELINE = None
CURRENT_MODEL = None
CURRENT_SCHEDULER = None


def get_pipeline(model_choice: str, scheduler_name: str = SCHEDULER_EULER_A):
    """Get or create pipeline for the selected model."""
    global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_SCHEDULER
    if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
        CURRENT_PIPELINE = initialize_sdxl_pipeline(model_choice, scheduler_name, device="cuda")
        CURRENT_MODEL = model_choice
        CURRENT_SCHEDULER = scheduler_name
    elif CURRENT_SCHEDULER != scheduler_name:
        CURRENT_PIPELINE.set_scheduler(scheduler_name)
        CURRENT_SCHEDULER = scheduler_name
    return CURRENT_PIPELINE
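# The pipeline is cached at module level: switching models rebuilds everything, while
# switching only the scheduler swaps it in place on the cached pipeline.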
# ============================================================================
# INFERENCE
# ============================================================================

@spaces.GPU  # request a ZeroGPU slot for the duration of the call (assumed; see `import spaces`)
def generate_image(
    prompt: str,
    t5_summary: str,
    negative_prompt: str,
    model_choice: str,
    scheduler_name: str,
    clip_skip: int,
    num_steps: int,
    cfg_scale: float,
    width: int,
    height: int,
    use_lyra: bool,
    lyra_strength: float,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress()
):
    """Generate image with ZeroGPU support."""
    if randomize_seed:
        seed = np.random.randint(0, 2**32 - 1)

    def progress_callback(step, total, desc):
        progress((step + 1) / total, desc=desc)

    try:
        pipeline = get_pipeline(model_choice, scheduler_name)

        if not use_lyra or pipeline.lyra_loader is None:
            # Single image: standard CLIP conditioning only
            progress(0.05, desc="Generating...")
            image = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                seed=seed,
                use_lyra=False,
                clip_skip=clip_skip,
                progress_callback=progress_callback
            )
            progress(1.0, desc="Complete!")
            return image, None, seed
        else:
            # Comparison mode: generate standard and Lyra-fused images with the same seed
            progress(0.05, desc="Generating standard...")
            image_standard = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                seed=seed,
                use_lyra=False,
                clip_skip=clip_skip,
                progress_callback=lambda s, t, d: progress(0.05 + (s / t) * 0.45, desc=d)
            )
            progress(0.5, desc="Loading Lyra + T5 (first run only)...")
            image_lyra = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                seed=seed,
                use_lyra=True,
                clip_skip=clip_skip,
                t5_summary=t5_summary,
                lyra_strength=lyra_strength,
                progress_callback=lambda s, t, d: progress(0.5 + (s / t) * 0.45, desc=d)
            )
            progress(1.0, desc="Complete!")
            return image_standard, image_lyra, seed
    except Exception as e:
        print(f"❌ Generation failed: {e}")
        import traceback
        traceback.print_exc()
        raise e
# ============================================================================
# GRADIO UI
# ============================================================================

def create_demo():
    """Create Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🌙 Lyra/Illustrious XL Image Generation

        **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)

        | Model | Architecture | Lyra Version | Best For |
        |-------|-------------|--------------|----------|
        | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
        | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |

        **Lyra VAE** fuses CLIP + T5-XL embeddings using adaptive Cantor attention.
        T5 and Lyra only load when you enable the Lyra checkbox!
        """)

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.TextArea(
                    label="Prompt",
                    value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
                    lines=3
                )
                t5_summary = gr.TextArea(
                    label="T5 Summary (for Lyra)",
                    value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms",
                    lines=2,
                    info="Natural-language description for T5. Leave empty to use the prompt."
                )
                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    value="lowres, bad anatomy, bad hands, text, error, worst quality, low quality",
                    lines=2
                )
                with gr.Row():
                    model_choice = gr.Dropdown(
                        label="Model",
                        choices=["Illustrious XL", "SDXL Base"],
                        value="Illustrious XL"
                    )
                    scheduler_name = gr.Dropdown(
                        label="Scheduler",
                        choices=SDXL_SCHEDULERS,
                        value=SCHEDULER_EULER_A
                    )
                clip_skip = gr.Slider(
                    label="CLIP Skip",
                    minimum=1, maximum=4, value=2, step=1,
                    info="2 recommended for Illustrious"
                )
                use_lyra = gr.Checkbox(
                    label="Enable Lyra VAE (loads T5-XL on first use)",
                    value=False,
                    info="Compare standard vs geometric fusion"
                )
                lyra_strength = gr.Slider(
                    label="Lyra Blend Strength",
                    minimum=0.0, maximum=2.0, value=1.0, step=0.05,
                    info="0.0 = pure CLIP, 1.0 = pure Lyra, >1.0 extrapolates"
                )
                with gr.Accordion("Generation Settings", open=True):
                    num_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=25, step=1)
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, value=7.0, step=0.5)
                    with gr.Row():
                        width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=64)
                        height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=64)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, value=42, step=1)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

                generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")

            with gr.Column(scale=1):
                with gr.Row():
                    output_image_standard = gr.Image(label="Standard", type="pil")
                    output_image_lyra = gr.Image(label="Lyra Fusion 🎵", type="pil", visible=True)
                output_seed = gr.Number(label="Seed", precision=0)
        # Event handlers
        def on_lyra_toggle(enabled):
            if enabled:
                return {
                    output_image_standard: gr.update(visible=True, label="Standard"),
                    output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵")
                }
            else:
                return {
                    output_image_standard: gr.update(visible=True, label="Generated Image"),
                    output_image_lyra: gr.update(visible=False)
                }

        use_lyra.change(
            fn=on_lyra_toggle,
            inputs=[use_lyra],
            outputs=[output_image_standard, output_image_lyra]
        )

        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt, t5_summary, negative_prompt, model_choice, scheduler_name,
                clip_skip, num_steps, cfg_scale, width, height,
                use_lyra, lyra_strength, seed, randomize_seed
            ],
            outputs=[output_image_standard, output_image_lyra, output_seed]
        )

    return demo
# ============================================================================
# LAUNCH
# ============================================================================

if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=20)
    demo.launch()