"""
Lyra/Lune Flow-Matching Inference Space
Author: AbstractPhil
License: MIT
SD1.5 and SDXL-based flow matching with geometric crystalline architectures.
Supports Illustrious XL, standard SDXL, and SD1.5 variants.
Lyra VAE Versions:
- v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra
- v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2
"""
import os
import json
import torch
import gradio as gr
import numpy as np
from PIL import Image
from typing import Optional, Dict, Tuple
import spaces
from safetensors.torch import load_file as load_safetensors
from diffusers import (
UNet2DConditionModel,
AutoencoderKL,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSDEScheduler,
)
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPTextModelWithProjection,
T5EncoderModel,
T5Tokenizer
)
from huggingface_hub import hf_hub_download
# Lazy imports for Lyra
LYRA_V1_AVAILABLE = False
LYRA_V2_AVAILABLE = False
LyraV1 = None
LyraV1Config = None
LyraV2 = None
LyraV2Config = None
def _load_lyra_imports():
"""Lazy load Lyra VAE modules."""
global LYRA_V1_AVAILABLE, LYRA_V2_AVAILABLE
global LyraV1, LyraV1Config, LyraV2, LyraV2Config
try:
from geofractal.model.vae.vae_lyra import MultiModalVAE as _LyraV1, MultiModalVAEConfig as _LyraV1Config
LyraV1 = _LyraV1
LyraV1Config = _LyraV1Config
LYRA_V1_AVAILABLE = True
except ImportError:
print("⚠️ Lyra VAE v1 not available")
try:
from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as _LyraV2, MultiModalVAEConfig as _LyraV2Config
LyraV2 = _LyraV2
LyraV2Config = _LyraV2Config
LYRA_V2_AVAILABLE = True
except ImportError:
print("⚠️ Lyra VAE v2 not available")
# ============================================================================
# CONSTANTS
# ============================================================================
ARCH_SD15 = "sd15"
ARCH_SDXL = "sdxl"
# Scheduler options
SCHEDULER_EULER_A = "Euler Ancestral"
SCHEDULER_EULER = "Euler"
SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE"
SCHEDULER_DPM_2M = "DPM++ 2M"
SDXL_SCHEDULERS = [SCHEDULER_EULER_A, SCHEDULER_EULER, SCHEDULER_DPM_2M_SDE, SCHEDULER_DPM_2M]
# ============================================================================
# SCHEDULER FACTORY
# ============================================================================
def get_scheduler(scheduler_name: str, config_path: str = "stabilityai/stable-diffusion-xl-base-1.0"):
"""Create scheduler by name."""
if scheduler_name == SCHEDULER_EULER_A:
return EulerAncestralDiscreteScheduler.from_pretrained(
config_path, subfolder="scheduler"
)
elif scheduler_name == SCHEDULER_EULER:
return EulerDiscreteScheduler.from_pretrained(
config_path, subfolder="scheduler"
)
elif scheduler_name == SCHEDULER_DPM_2M_SDE:
return DPMSolverSDEScheduler.from_pretrained(
config_path, subfolder="scheduler",
algorithm_type="sde-dpmsolver++",
solver_order=2,
)
elif scheduler_name == SCHEDULER_DPM_2M:
return DPMSolverMultistepScheduler.from_pretrained(
config_path, subfolder="scheduler",
algorithm_type="dpmsolver++",
solver_order=2,
)
else:
# Default to Euler Ancestral
return EulerAncestralDiscreteScheduler.from_pretrained(
config_path, subfolder="scheduler"
)
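# Illustrative example: get_scheduler(SCHEDULER_DPM_2M) returns a DPMSolverMultistepScheduler
# with algorithm_type="dpmsolver++" and solver_order=2, using the SDXL base repo's
# scheduler config by default.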
# ============================================================================
# MODEL LOADING UTILITIES
# ============================================================================
def get_clip_hidden_state(
model_output,
clip_skip: int = 1,
output_hidden_states: bool = True
) -> torch.Tensor:
"""Extract hidden state with clip_skip support."""
if clip_skip == 1 or not output_hidden_states:
return model_output.last_hidden_state
if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
return model_output.hidden_states[-clip_skip]
return model_output.last_hidden_state
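# Example: with clip_skip=2 and output_hidden_states=True this returns hidden_states[-2]
# (the penultimate CLIP layer), matching the common "CLIP skip 2" convention;
# clip_skip=1 always uses last_hidden_state.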
# ============================================================================
# LAZY LOADERS
# ============================================================================
class LazyT5Encoder:
"""Lazy loader for T5 encoder - only loads when first accessed."""
def __init__(self, model_name: str = "google/flan-t5-xl", device: str = "cuda"):
self.model_name = model_name
self.device = device
self._encoder = None
self._tokenizer = None
@property
def encoder(self):
if self._encoder is None:
print(f"📥 Loading T5 encoder: {self.model_name}...")
self._encoder = T5EncoderModel.from_pretrained(
self.model_name,
torch_dtype=torch.float16
).to(self.device)
self._encoder.eval()
print("✓ T5 encoder loaded")
return self._encoder
@property
def tokenizer(self):
if self._tokenizer is None:
print(f"📥 Loading T5 tokenizer: {self.model_name}...")
self._tokenizer = T5Tokenizer.from_pretrained(self.model_name)
print("✓ T5 tokenizer loaded")
return self._tokenizer
def is_loaded(self):
return self._encoder is not None
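# Minimal usage sketch (hypothetical, not part of the Space flow): nothing downloads
# until the first attribute access, e.g.
#   t5 = LazyT5Encoder()                                             # no download yet
#   ids = t5.tokenizer("a prompt", return_tensors="pt").input_ids    # loads the tokenizer
#   emb = t5.encoder(ids.to(t5.device)).last_hidden_state            # loads the fp16 encoder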
class LazyLyraModel:
"""Lazy loader for Lyra VAE - only loads when first accessed."""
def __init__(self, repo_id: str, device: str = "cuda", version: int = 2):
self.repo_id = repo_id
self.device = device
self.version = version
self._model = None
@property
def model(self):
if self._model is None:
_load_lyra_imports()
if self.version == 2:
self._model = self._load_v2()
else:
self._model = self._load_v1()
return self._model
def _load_v2(self):
if not LYRA_V2_AVAILABLE:
print("⚠️ Lyra VAE v2 not available")
return None
print(f"🎵 Loading Lyra VAE v2 from {self.repo_id}...")
try:
from huggingface_hub import list_repo_files
config_path = hf_hub_download(
repo_id=self.repo_id,
filename="config.json",
repo_type="model"
)
with open(config_path, 'r') as f:
config_dict = json.load(f)
print(f" ✓ Config: {config_dict.get('fusion_strategy', 'unknown')} fusion")
# Auto-detect checkpoint
repo_files = list_repo_files(self.repo_id, repo_type="model")
checkpoint_files = [f for f in repo_files if f.endswith('.pt')]
checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower()]
if not checkpoint_files:
raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
import re
def extract_step(name):
match = re.search(r'(\d+)\.pt', name)
return int(match.group(1)) if match else 0
checkpoint_files.sort(key=extract_step, reverse=True)
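            # Sorted descending by training step, so the first entry is the newest checkpoint.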
checkpoint_filename = checkpoint_files[0]
print(f" ✓ Using: {checkpoint_filename}")
checkpoint_path = hf_hub_download(
repo_id=self.repo_id,
filename=checkpoint_filename,
repo_type="model"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
vae_config = LyraV2Config(
modality_dims=config_dict.get('modality_dims', {
"clip_l": 768, "clip_g": 1280,
"t5_xl_l": 2048, "t5_xl_g": 2048
}),
modality_seq_lens=config_dict.get('modality_seq_lens', {
"clip_l": 77, "clip_g": 77,
"t5_xl_l": 512, "t5_xl_g": 512
}),
binding_config=config_dict.get('binding_config', {
"clip_l": {"t5_xl_l": 0.3},
"clip_g": {"t5_xl_g": 0.3},
"t5_xl_l": {},
"t5_xl_g": {}
}),
latent_dim=config_dict.get('latent_dim', 2048),
seq_len=config_dict.get('seq_len', 77),
encoder_layers=config_dict.get('encoder_layers', 3),
decoder_layers=config_dict.get('decoder_layers', 3),
hidden_dim=config_dict.get('hidden_dim', 2048),
dropout=config_dict.get('dropout', 0.1),
fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
fusion_heads=config_dict.get('fusion_heads', 8),
fusion_dropout=config_dict.get('fusion_dropout', 0.1),
cantor_depth=config_dict.get('cantor_depth', 8),
cantor_local_window=config_dict.get('cantor_local_window', 3),
alpha_init=config_dict.get('alpha_init', 1.0),
beta_init=config_dict.get('beta_init', 0.3),
)
lyra_model = LyraV2(vae_config)
state_dict = checkpoint.get('model_state_dict', checkpoint)
missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
if missing:
print(f" ⚠️ Missing keys: {len(missing)}")
if unexpected:
print(f" ⚠️ Unexpected keys: {len(unexpected)}")
lyra_model.to(self.device)
lyra_model.eval()
total_params = sum(p.numel() for p in lyra_model.parameters())
print(f"✅ Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
return lyra_model
except Exception as e:
print(f"❌ Failed to load Lyra VAE v2: {e}")
import traceback
traceback.print_exc()
return None
def _load_v1(self):
if not LYRA_V1_AVAILABLE:
print("⚠️ Lyra VAE v1 not available")
return None
        # v1 (SD1.5) loading is not implemented in this Space; return None so callers
        # fall back to standard CLIP conditioning.
        return None
def is_loaded(self):
return self._model is not None
# ============================================================================
# SDXL PIPELINE
# ============================================================================
class SDXLFlowMatchingPipeline:
"""Pipeline for SDXL-based flow-matching inference with dual CLIP encoders."""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler,
device: str = "cuda",
t5_loader: Optional[LazyT5Encoder] = None,
lyra_loader: Optional[LazyLyraModel] = None,
clip_skip: int = 1
):
self.vae = vae
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.unet = unet
self.scheduler = scheduler
self.device = device
# Lazy loaders
self.t5_loader = t5_loader
self.lyra_loader = lyra_loader
# Settings
self.clip_skip = clip_skip
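        # 0.13025 is SDXL's latent scaling factor (vae.config.scaling_factor), not the 8x
        # spatial downscale; __call__ divides the latents by it before decoding.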
self.vae_scale_factor = 0.13025
self.arch = ARCH_SDXL
def set_scheduler(self, scheduler_name: str):
"""Switch scheduler."""
self.scheduler = get_scheduler(scheduler_name)
@property
def t5_encoder(self):
return self.t5_loader.encoder if self.t5_loader else None
@property
def t5_tokenizer(self):
return self.t5_loader.tokenizer if self.t5_loader else None
@property
def lyra_model(self):
return self.lyra_loader.model if self.lyra_loader else None
def encode_prompt(
self,
prompt: str,
negative_prompt: str = "",
clip_skip: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompts using dual CLIP encoders for SDXL."""
# CLIP-L encoding
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
output_hidden_states = clip_skip > 1
clip_l_output = self.text_encoder(
text_input_ids,
output_hidden_states=output_hidden_states
)
prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states)
# CLIP-G encoding
text_inputs_2 = self.tokenizer_2(
prompt,
padding="max_length",
max_length=self.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids_2 = text_inputs_2.input_ids.to(self.device)
with torch.no_grad():
clip_g_output = self.text_encoder_2(
text_input_ids_2,
output_hidden_states=output_hidden_states
)
prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
pooled_prompt_embeds = clip_g_output.text_embeds
prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
# Negative prompt
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
uncond_inputs_2 = self.tokenizer_2(
negative_prompt,
padding="max_length",
max_length=self.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device)
with torch.no_grad():
uncond_output_l = self.text_encoder(
uncond_input_ids,
output_hidden_states=output_hidden_states
)
negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states)
uncond_output_g = self.text_encoder_2(
uncond_input_ids_2,
output_hidden_states=output_hidden_states
)
negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states)
negative_pooled = uncond_output_g.text_embeds
negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled = torch.zeros_like(pooled_prompt_embeds)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled
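    # Resulting shapes for a single prompt: prompt_embeds is [1, 77, 2048]
    # (CLIP-L [1, 77, 768] concatenated with CLIP-G [1, 77, 1280]); the pooled
    # embeddings come from CLIP-G's projection head ([1, 1280]).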
def encode_prompt_lyra(
self,
prompt: str,
negative_prompt: str = "",
clip_skip: int = 1,
t5_summary: str = "",
lyra_strength: float = 0.3
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompts using Lyra VAE v2 fusion (CLIP + T5)."""
if self.lyra_model is None or self.t5_encoder is None:
raise ValueError("Lyra VAE components not initialized")
# Get standard CLIP embeddings first
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
prompt, negative_prompt, clip_skip
)
# Format T5 input
SUMMARY_SEPARATOR = "¶"
if t5_summary.strip():
t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
else:
t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {prompt}"
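        # T5 receives "<prompt> ¶ <summary>"; with no summary the prompt is repeated on
        # both sides of the separator so the sequence format stays consistent.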
# Get T5 embeddings
t5_inputs = self.t5_tokenizer(
t5_prompt,
max_length=512,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
clip_l_dim = 768
clip_l_embeds = prompt_embeds[..., :clip_l_dim]
clip_g_embeds = prompt_embeds[..., clip_l_dim:]
with torch.no_grad():
modality_inputs = {
'clip_l': clip_l_embeds.float(),
'clip_g': clip_g_embeds.float(),
't5_xl_l': t5_embeds.float(),
't5_xl_g': t5_embeds.float()
}
reconstructions, mu, logvar, _ = self.lyra_model(
modality_inputs,
target_modalities=['clip_l', 'clip_g']
)
lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
        # If Lyra's reconstruction statistics drift more than 2x from the CLIP embeddings,
        # renormalize them to match the CLIP mean/std before blending.
clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5:
lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8)
lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean()
if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5:
lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
# Blend
fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
        # Negative prompt: reuse the original CLIP embeddings (Lyra fusion is applied to the positive prompt only)
return prompt_embeds_fused, negative_prompt_embeds, pooled, negative_pooled
def _get_add_time_ids(
self,
original_size: Tuple[int, int],
crops_coords_top_left: Tuple[int, int],
target_size: Tuple[int, int],
dtype: torch.dtype
) -> torch.Tensor:
"""Create time embedding IDs for SDXL."""
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
return add_time_ids
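    # Example: a 1024x1024 generation with no cropping yields [[1024, 1024, 0, 0, 1024, 1024]]
    # (original_size, crop top-left, target_size), SDXL's micro-conditioning input.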
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 25,
guidance_scale: float = 7.0,
seed: Optional[int] = None,
use_lyra: bool = False,
clip_skip: int = 2,
t5_summary: str = "",
lyra_strength: float = 1.0,
progress_callback=None
):
"""Generate image using SDXL architecture."""
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
else:
generator = None
# Encode prompts
if use_lyra and self.lyra_loader is not None:
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
)
else:
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
prompt, negative_prompt, clip_skip
)
# Prepare latents
latent_channels = 4
latent_height = height // 8
latent_width = width // 8
latents = torch.randn(
(1, latent_channels, latent_height, latent_width),
generator=generator,
device=self.device,
dtype=torch.float16
)
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
latents = latents * self.scheduler.init_noise_sigma
# Time embeddings for SDXL
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=torch.float16
)
negative_add_time_ids = add_time_ids
# Denoising loop
for i, t in enumerate(timesteps):
if progress_callback:
progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
if guidance_scale > 1.0:
text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
add_text_embeds = torch.cat([negative_pooled, pooled])
add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids])
else:
text_embeds = prompt_embeds
add_text_embeds = pooled
add_time_ids_input = add_time_ids
added_cond_kwargs = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids_input
}
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeds,
added_cond_kwargs=added_cond_kwargs,
return_dict=False
)[0]
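            # Classifier-free guidance: push the conditional prediction away from the unconditional one.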
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# Decode
latents = latents / self.vae_scale_factor
with torch.no_grad():
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image[0])
return image
# ============================================================================
# MODEL LOADERS
# ============================================================================
def load_illustrious_xl(
repo_id: str = "AbstractPhil/illustrious-xl-v1",
filename: str = "illustriousXL_v01.safetensors",
device: str = "cuda"
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
"""Load Illustrious XL from single safetensors file."""
from diffusers import StableDiffusionXLPipeline
print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
print(f"✓ Downloaded: {checkpoint_path}")
print("📦 Loading pipeline...")
pipe = StableDiffusionXLPipeline.from_single_file(
checkpoint_path,
torch_dtype=torch.float16,
use_safetensors=True,
)
unet = pipe.unet.to(device)
vae = pipe.vae.to(device)
text_encoder = pipe.text_encoder.to(device)
text_encoder_2 = pipe.text_encoder_2.to(device)
tokenizer = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
del pipe
torch.cuda.empty_cache()
print("✅ Illustrious XL loaded!")
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
# ============================================================================
# PIPELINE INITIALIZATION
# ============================================================================
def initialize_sdxl_pipeline(
model_choice: str,
scheduler_name: str = SCHEDULER_EULER_A,
device: str = "cuda"
):
"""Initialize SDXL pipeline with lazy T5/Lyra loading."""
print(f"🚀 Initializing {model_choice} pipeline...")
# Load base model
if "Illustrious" in model_choice:
unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
else:
# SDXL Base
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
)
unet = pipe.unet.to(device)
vae = pipe.vae.to(device)
text_encoder = pipe.text_encoder.to(device)
text_encoder_2 = pipe.text_encoder_2.to(device)
tokenizer = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
del pipe
torch.cuda.empty_cache()
# Create lazy loaders (don't download yet)
t5_loader = LazyT5Encoder(model_name="google/flan-t5-xl", device=device)
lyra_loader = LazyLyraModel(
repo_id="AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
device=device,
version=2
)
# Get scheduler
scheduler = get_scheduler(scheduler_name)
pipeline = SDXLFlowMatchingPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
device=device,
t5_loader=t5_loader,
lyra_loader=lyra_loader,
clip_skip=2
)
print("✅ Pipeline initialized (T5/Lyra will load on first use)")
return pipeline
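# Minimal usage sketch outside the Gradio UI (hypothetical values):
#   pipe = initialize_sdxl_pipeline("Illustrious XL")
#   img = pipe("1girl, silver hair, night sky", negative_prompt="lowres",
#              num_inference_steps=25, guidance_scale=7.0, seed=42)
#   img.save("out.png")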
# ============================================================================
# GLOBAL STATE
# ============================================================================
CURRENT_PIPELINE = None
CURRENT_MODEL = None
CURRENT_SCHEDULER = None
def get_pipeline(model_choice: str, scheduler_name: str = SCHEDULER_EULER_A):
"""Get or create pipeline for selected model."""
global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_SCHEDULER
if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
CURRENT_PIPELINE = initialize_sdxl_pipeline(model_choice, scheduler_name, device="cuda")
CURRENT_MODEL = model_choice
CURRENT_SCHEDULER = scheduler_name
elif CURRENT_SCHEDULER != scheduler_name:
CURRENT_PIPELINE.set_scheduler(scheduler_name)
CURRENT_SCHEDULER = scheduler_name
return CURRENT_PIPELINE
# ============================================================================
# INFERENCE
# ============================================================================
@spaces.GPU(duration=120)
def generate_image(
prompt: str,
t5_summary: str,
negative_prompt: str,
model_choice: str,
scheduler_name: str,
clip_skip: int,
num_steps: int,
cfg_scale: float,
width: int,
height: int,
use_lyra: bool,
lyra_strength: float,
seed: int,
randomize_seed: bool,
progress=gr.Progress()
):
"""Generate image with ZeroGPU support."""
if randomize_seed:
seed = np.random.randint(0, 2**32 - 1)
def progress_callback(step, total, desc):
progress((step + 1) / total, desc=desc)
try:
pipeline = get_pipeline(model_choice, scheduler_name)
if not use_lyra or pipeline.lyra_loader is None:
progress(0.05, desc="Generating...")
image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
seed=seed,
use_lyra=False,
clip_skip=clip_skip,
progress_callback=progress_callback
)
progress(1.0, desc="Complete!")
return image, None, seed
else:
progress(0.05, desc="Generating standard...")
image_standard = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
seed=seed,
use_lyra=False,
clip_skip=clip_skip,
progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
)
progress(0.5, desc="Loading Lyra + T5 (first run only)...")
image_lyra = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
seed=seed,
use_lyra=True,
clip_skip=clip_skip,
t5_summary=t5_summary,
lyra_strength=lyra_strength,
progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d)
)
progress(1.0, desc="Complete!")
return image_standard, image_lyra, seed
except Exception as e:
print(f"❌ Generation failed: {e}")
import traceback
traceback.print_exc()
raise e
# ============================================================================
# GRADIO UI
# ============================================================================
def create_demo():
"""Create Gradio interface."""
with gr.Blocks() as demo:
gr.Markdown("""
# 🌙 Lyra/Illustrious XL Image Generation
**Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
| Model | Architecture | Lyra Version | Best For |
|-------|-------------|--------------|----------|
| **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
| **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |
**Lyra VAE** fuses CLIP + T5-XL embeddings using adaptive Cantor attention.
T5 and Lyra only load when you enable the Lyra checkbox!
""")
with gr.Row():
with gr.Column(scale=1):
prompt = gr.TextArea(
label="Prompt",
value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
lines=3
)
t5_summary = gr.TextArea(
label="T5 Summary (for Lyra)",
value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms",
lines=2,
info="Natural language description for T5. Leave empty to use prompt."
)
negative_prompt = gr.TextArea(
label="Negative Prompt",
value="lowres, bad anatomy, bad hands, text, error, worst quality, low quality",
lines=2
)
with gr.Row():
model_choice = gr.Dropdown(
label="Model",
choices=["Illustrious XL", "SDXL Base"],
value="Illustrious XL"
)
scheduler_name = gr.Dropdown(
label="Scheduler",
choices=SDXL_SCHEDULERS,
value=SCHEDULER_EULER_A
)
clip_skip = gr.Slider(
label="CLIP Skip",
minimum=1, maximum=4, value=2, step=1,
info="2 recommended for Illustrious"
)
use_lyra = gr.Checkbox(
label="Enable Lyra VAE (loads T5-XL on first use)",
value=False,
info="Compare standard vs geometric fusion"
)
lyra_strength = gr.Slider(
label="Lyra Blend Strength",
minimum=0.0, maximum=2.0, value=1.0, step=0.05,
info="0.0 = pure CLIP, 1.0 = pure Lyra"
)
with gr.Accordion("Generation Settings", open=True):
num_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=25, step=1)
cfg_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, value=7.0, step=0.5)
with gr.Row():
width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=64)
height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=64)
seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, value=42, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
with gr.Column(scale=1):
with gr.Row():
output_image_standard = gr.Image(label="Standard", type="pil")
output_image_lyra = gr.Image(label="Lyra Fusion 🎵", type="pil", visible=True)
output_seed = gr.Number(label="Seed", precision=0)
# Event handlers
def on_lyra_toggle(enabled):
if enabled:
return {
output_image_standard: gr.update(visible=True, label="Standard"),
output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵")
}
else:
return {
output_image_standard: gr.update(visible=True, label="Generated Image"),
output_image_lyra: gr.update(visible=False)
}
use_lyra.change(
fn=on_lyra_toggle,
inputs=[use_lyra],
outputs=[output_image_standard, output_image_lyra]
)
generate_btn.click(
fn=generate_image,
inputs=[
prompt, t5_summary, negative_prompt, model_choice, scheduler_name,
clip_skip, num_steps, cfg_scale, width, height,
use_lyra, lyra_strength, seed, randomize_seed
],
outputs=[output_image_standard, output_image_lyra, output_seed]
)
return demo
# ============================================================================
# LAUNCH
# ============================================================================
if __name__ == "__main__":
demo = create_demo()
demo.queue(max_size=20)
demo.launch()