AbstractPhil committed
Commit eb18eb8 · verified · 1 Parent(s): 95fa965

Update app.py

Files changed (1):
  1. app.py +1524 -125
app.py CHANGED
@@ -1,154 +1,1553 @@
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
- result = gr.Image(label="Result", show_label=False)
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
  )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
  )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
  minimum=1,
131
- maximum=50,
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
1
+ """
2
+ Lyra/Lune Flow-Matching Inference Space
3
+ Author: AbstractPhil
4
+ License: MIT
5
+
6
+ SD1.5 and SDXL-based flow matching with geometric crystalline architectures.
7
+ Supports Illustrious XL, standard SDXL, and SD1.5 variants.
8
+ """
9
+
10
+ import os
11
+ import torch
12
  import gradio as gr
13
  import numpy as np
14
+ from PIL import Image
15
+ from typing import Optional, Dict, Tuple
16
+ import spaces
17
+ from safetensors.torch import load_file as load_safetensors
18
 
19
+ from diffusers import (
20
+ UNet2DConditionModel,
21
+ AutoencoderKL,
22
+ EulerDiscreteScheduler,
23
+ EulerAncestralDiscreteScheduler
24
+ )
25
+ from diffusers.models import UNet2DConditionModel as DiffusersUNet
26
+ from transformers import (
27
+ CLIPTextModel,
28
+ CLIPTokenizer,
29
+ CLIPTextModelWithProjection,
30
+ T5EncoderModel,
31
+ T5Tokenizer
32
+ )
33
+ from huggingface_hub import hf_hub_download
34
 
35
+ # Import Lyra VAE from geofractal
36
+ try:
37
+ from geofractal.models.vae.vae_lyra import MultiModalVAE, MultiModalVAEConfig
38
+ LYRA_AVAILABLE = True
39
+ except ImportError:
40
+ try:
41
+ from geofractal.train.model.vae.vae_lyra import MultiModalVAE, MultiModalVAEConfig
42
+ LYRA_AVAILABLE = True
43
+ except ImportError:
44
+ print("⚠️ Lyra VAE not available - install geofractal")
45
+ LYRA_AVAILABLE = False
46
 
 
 
 
 
47
 
48
+ # ============================================================================
49
+ # CONSTANTS
50
+ # ============================================================================
51
 
52
+ # Model architectures
53
+ ARCH_SD15 = "sd15"
54
+ ARCH_SDXL = "sdxl"
55
 
56
+ # ComfyUI key prefixes for SDXL single-file checkpoints
57
+ COMFYUI_UNET_PREFIX = "model.diffusion_model."
58
+ COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer."
59
+ COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model."
60
+ COMFYUI_VAE_PREFIX = "first_stage_model."
61
 
 
 
62
 
63
+ # ============================================================================
64
+ # MODEL LOADING UTILITIES
65
+ # ============================================================================
66
 
67
+ def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
68
+ """Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint."""
69
+
70
+ components = {
71
+ "unet": {},
72
+ "clip_l": {},
73
+ "clip_g": {},
74
+ "vae": {}
75
+ }
76
+
77
+ for key, value in state_dict.items():
78
+ if key.startswith(COMFYUI_UNET_PREFIX):
79
+ new_key = key[len(COMFYUI_UNET_PREFIX):]
80
+ components["unet"][new_key] = value
81
+ elif key.startswith(COMFYUI_CLIP_L_PREFIX):
82
+ new_key = key[len(COMFYUI_CLIP_L_PREFIX):]
83
+ components["clip_l"][new_key] = value
84
+ elif key.startswith(COMFYUI_CLIP_G_PREFIX):
85
+ new_key = key[len(COMFYUI_CLIP_G_PREFIX):]
86
+ components["clip_g"][new_key] = value
87
+ elif key.startswith(COMFYUI_VAE_PREFIX):
88
+ new_key = key[len(COMFYUI_VAE_PREFIX):]
89
+ components["vae"][new_key] = value
90
+
91
+ print(f" Extracted components:")
92
+ print(f" UNet: {len(components['unet'])} keys")
93
+ print(f" CLIP-L: {len(components['clip_l'])} keys")
94
+ print(f" CLIP-G: {len(components['clip_g'])} keys")
95
+ print(f" VAE: {len(components['vae'])} keys")
96
+
97
+ return components
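# --- Usage sketch (illustrative, not part of the committed app.py) ----------
# Toy checkpoint with synthetic keys, showing how the ComfyUI prefixes above
# are split into per-component state dicts by extract_comfyui_components().
import torch
_toy_ckpt = {
    "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(1),
    "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": torch.zeros(1),
    "conditioner.embedders.1.model.text_projection": torch.zeros(1),
    "first_stage_model.decoder.conv_in.weight": torch.zeros(1),
}
_parts = extract_comfyui_components(_toy_ckpt)
assert list(_parts["unet"]) == ["input_blocks.0.0.weight"]
assert list(_parts["clip_l"]) == ["text_model.final_layer_norm.weight"]
assert list(_parts["vae"]) == ["decoder.conv_in.weight"]
# -----------------------------------------------------------------------------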
98
+
99
+
100
+ def get_clip_hidden_state(
101
+ model_output,
102
+ clip_skip: int = 1,
103
+ output_hidden_states: bool = True
104
+ ) -> torch.Tensor:
105
+ """Extract hidden state with clip_skip support."""
106
+ if clip_skip == 1 or not output_hidden_states:
107
+ return model_output.last_hidden_state
108
+
109
+ if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
110
+ # hidden_states is tuple: (embedding, layer1, ..., layerN)
111
+ # clip_skip=2 means penultimate layer = hidden_states[-2]
112
+ return model_output.hidden_states[-clip_skip]
113
+
114
+ return model_output.last_hidden_state
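# --- Usage sketch (illustrative, not part of the committed app.py) ----------
# hidden_states is the tuple (embeddings, layer_1, ..., layer_N), so with
# clip_skip=2 the helper above returns the penultimate transformer layer.
import torch
from transformers import CLIPTextModel, CLIPTokenizer
_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
_ids = _tok("a test prompt", padding="max_length", truncation=True,
            max_length=_tok.model_max_length, return_tensors="pt").input_ids
with torch.no_grad():
    _out = _enc(_ids, output_hidden_states=True)
_h = get_clip_hidden_state(_out, clip_skip=2, output_hidden_states=True)
assert torch.equal(_h, _out.hidden_states[-2])  # shape (1, 77, 768)
# -----------------------------------------------------------------------------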
115
 
 
116
 
117
+ # ============================================================================
118
+ # SDXL PIPELINE
119
+ # ============================================================================
120
+
121
+ class SDXLFlowMatchingPipeline:
122
+ """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders."""
123
+
124
+ def __init__(
125
+ self,
126
+ vae: AutoencoderKL,
127
+ text_encoder: CLIPTextModel, # CLIP-L
128
+ text_encoder_2: CLIPTextModelWithProjection, # CLIP-G
129
+ tokenizer: CLIPTokenizer,
130
+ tokenizer_2: CLIPTokenizer,
131
+ unet: UNet2DConditionModel,
132
+ scheduler,
133
+ device: str = "cuda",
134
+ t5_encoder: Optional[T5EncoderModel] = None,
135
+ t5_tokenizer: Optional[T5Tokenizer] = None,
136
+ lyra_model: Optional[any] = None,
137
+ clip_skip: int = 1
138
+ ):
139
+ self.vae = vae
140
+ self.text_encoder = text_encoder
141
+ self.text_encoder_2 = text_encoder_2
142
+ self.tokenizer = tokenizer
143
+ self.tokenizer_2 = tokenizer_2
144
+ self.unet = unet
145
+ self.scheduler = scheduler
146
+ self.device = device
147
+
148
+ # Lyra components
149
+ self.t5_encoder = t5_encoder
150
+ self.t5_tokenizer = t5_tokenizer
151
+ self.lyra_model = lyra_model
152
+
153
+ # Settings
154
+ self.clip_skip = clip_skip
155
+ self.vae_scale_factor = 0.13025 # SDXL VAE scaling
156
+ self.arch = ARCH_SDXL
157
+
158
+ def encode_prompt(
159
+ self,
160
+ prompt: str,
161
+ negative_prompt: str = "",
162
+ clip_skip: int = 1
163
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
164
+ """Encode prompts using dual CLIP encoders for SDXL."""
165
+
166
+ # CLIP-L encoding
167
+ text_inputs = self.tokenizer(
168
+ prompt,
169
+ padding="max_length",
170
+ max_length=self.tokenizer.model_max_length,
171
+ truncation=True,
172
+ return_tensors="pt",
173
+ )
174
+ text_input_ids = text_inputs.input_ids.to(self.device)
175
+
176
+ with torch.no_grad():
177
+ output_hidden_states = clip_skip > 1
178
+ clip_l_output = self.text_encoder(
179
+ text_input_ids,
180
+ output_hidden_states=output_hidden_states
181
+ )
182
+ prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states)
183
+
184
+ # CLIP-G encoding
185
+ text_inputs_2 = self.tokenizer_2(
186
+ prompt,
187
+ padding="max_length",
188
+ max_length=self.tokenizer_2.model_max_length,
189
+ truncation=True,
190
+ return_tensors="pt",
191
+ )
192
+ text_input_ids_2 = text_inputs_2.input_ids.to(self.device)
193
+
194
+ with torch.no_grad():
195
+ clip_g_output = self.text_encoder_2(
196
+ text_input_ids_2,
197
+ output_hidden_states=output_hidden_states
198
+ )
199
+ prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
200
+
201
+ # Get pooled output from CLIP-G
202
+ pooled_prompt_embeds = clip_g_output.text_embeds
203
+
204
+ # Concatenate CLIP-L and CLIP-G embeddings
205
+ prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
206
+
207
+ # Negative prompt
208
+ if negative_prompt:
209
+ uncond_inputs = self.tokenizer(
210
+ negative_prompt,
211
+ padding="max_length",
212
+ max_length=self.tokenizer.model_max_length,
213
+ truncation=True,
214
+ return_tensors="pt",
215
+ )
216
+ uncond_input_ids = uncond_inputs.input_ids.to(self.device)
217
+
218
+ uncond_inputs_2 = self.tokenizer_2(
219
+ negative_prompt,
220
+ padding="max_length",
221
+ max_length=self.tokenizer_2.model_max_length,
222
+ truncation=True,
223
+ return_tensors="pt",
224
+ )
225
+ uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device)
226
+
227
+ with torch.no_grad():
228
+ uncond_output_l = self.text_encoder(
229
+ uncond_input_ids,
230
+ output_hidden_states=output_hidden_states
231
+ )
232
+ negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states)
233
+
234
+ uncond_output_g = self.text_encoder_2(
235
+ uncond_input_ids_2,
236
+ output_hidden_states=output_hidden_states
237
+ )
238
+ negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states)
239
+ negative_pooled = uncond_output_g.text_embeds
240
+
241
+ negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1)
242
+ else:
243
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
244
+ negative_pooled = torch.zeros_like(pooled_prompt_embeds)
245
+
246
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled
247
 
248
+ def encode_prompt_lyra(
249
+ self,
250
+ prompt: str,
251
+ negative_prompt: str = "",
252
+ clip_skip: int = 1
253
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
254
+ """Encode prompts using Lyra VAE fusion (CLIP + T5)."""
255
+ if self.lyra_model is None or self.t5_encoder is None:
256
+ raise ValueError("Lyra VAE components not initialized")
257
+
258
+ # Get standard CLIP embeddings first
259
+ prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
260
+ prompt, negative_prompt, clip_skip
261
+ )
262
+
263
+ # Get T5 embeddings
264
+ t5_inputs = self.t5_tokenizer(
265
+ prompt,
266
+ max_length=77,
267
+ padding='max_length',
268
+ truncation=True,
269
+ return_tensors='pt'
270
+ ).to(self.device)
271
+
272
+ with torch.no_grad():
273
+ t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
274
+
275
+ # For SDXL, we need to handle the concatenated CLIP-L + CLIP-G embeddings
276
+ # Split them, fuse CLIP-L through Lyra, then recombine
277
+ clip_l_dim = 768
278
+ clip_g_dim = 1280
279
+
280
+ clip_l_embeds = prompt_embeds[..., :clip_l_dim]
281
+ clip_g_embeds = prompt_embeds[..., clip_l_dim:]
282
+
283
+ # Fuse CLIP-L through Lyra
284
+ modality_inputs = {
285
+ 'clip': clip_l_embeds,
286
+ 't5': t5_embeds
287
+ }
288
+
289
+ with torch.no_grad():
290
+ reconstructions, mu, logvar = self.lyra_model(
291
+ modality_inputs,
292
+ target_modalities=['clip']
293
  )
294
+ fused_clip_l = reconstructions['clip']
295
+
296
+ # Recombine with CLIP-G
297
+ prompt_embeds_fused = torch.cat([fused_clip_l, clip_g_embeds], dim=-1)
298
+
299
+ # Process negative prompt similarly if present
300
+ if negative_prompt:
301
+ t5_inputs_neg = self.t5_tokenizer(
302
+ negative_prompt,
303
+ max_length=77,
304
+ padding='max_length',
305
+ truncation=True,
306
+ return_tensors='pt'
307
+ ).to(self.device)
308
+
309
+ with torch.no_grad():
310
+ t5_embeds_neg = self.t5_encoder(**t5_inputs_neg).last_hidden_state
311
+
312
+ neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
313
+ neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
314
+
315
+ modality_inputs_neg = {
316
+ 'clip': neg_clip_l,
317
+ 't5': t5_embeds_neg
318
+ }
319
+
320
+ with torch.no_grad():
321
+ reconstructions_neg, _, _ = self.lyra_model(
322
+ modality_inputs_neg,
323
+ target_modalities=['clip']
324
+ )
325
+ fused_neg_clip_l = reconstructions_neg['clip']
326
+
327
+ negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, neg_clip_g], dim=-1)
328
+ else:
329
+ negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)
330
+
331
+ return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled
332
+
333
+ def _get_add_time_ids(
334
+ self,
335
+ original_size: Tuple[int, int],
336
+ crops_coords_top_left: Tuple[int, int],
337
+ target_size: Tuple[int, int],
338
+ dtype: torch.dtype
339
+ ) -> torch.Tensor:
340
+ """Create time embedding IDs for SDXL."""
341
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
342
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
343
+ return add_time_ids
344
 
345
+ @torch.no_grad()
346
+ def __call__(
347
+ self,
348
+ prompt: str,
349
+ negative_prompt: str = "",
350
+ height: int = 1024,
351
+ width: int = 1024,
352
+ num_inference_steps: int = 20,
353
+ guidance_scale: float = 7.5,
354
+ shift: float = 0.0,
355
+ use_flow_matching: bool = False,
356
+ prediction_type: str = "epsilon",
357
+ seed: Optional[int] = None,
358
+ use_lyra: bool = False,
359
+ clip_skip: int = 1,
360
+ progress_callback=None
361
+ ):
362
+ """Generate image using SDXL architecture."""
363
+
364
+ # Set seed
365
+ if seed is not None:
366
+ generator = torch.Generator(device=self.device).manual_seed(seed)
367
+ else:
368
+ generator = None
369
+
370
+ # Encode prompts
371
+ if use_lyra and self.lyra_model is not None:
372
+ prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
373
+ prompt, negative_prompt, clip_skip
374
  )
375
+ else:
376
+ prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
377
+ prompt, negative_prompt, clip_skip
378
+ )
379
+
380
+ # Prepare latents
381
+ latent_channels = 4
382
+ latent_height = height // 8
383
+ latent_width = width // 8
384
+
385
+ latents = torch.randn(
386
+ (1, latent_channels, latent_height, latent_width),
387
+ generator=generator,
388
+ device=self.device,
389
+ dtype=torch.float16
390
+ )
391
+
392
+ # Set timesteps
393
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
394
+ timesteps = self.scheduler.timesteps
395
+
396
+ # Scale initial latents
397
+ if not use_flow_matching:
398
+ latents = latents * self.scheduler.init_noise_sigma
399
+
400
+ # Prepare added time embeddings for SDXL
401
+ original_size = (height, width)
402
+ target_size = (height, width)
403
+ crops_coords_top_left = (0, 0)
404
+
405
+ add_time_ids = self._get_add_time_ids(
406
+ original_size, crops_coords_top_left, target_size, dtype=torch.float16
407
+ )
408
+ negative_add_time_ids = add_time_ids # Same for negative
409
+
410
+ # Denoising loop
411
+ for i, t in enumerate(timesteps):
412
+ if progress_callback:
413
+ progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
414
+
415
+ # Expand for CFG
416
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
417
+
418
+ # Flow matching scaling
419
+ if use_flow_matching and shift > 0:
420
+ sigma = t.float() / 1000.0
421
+ sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
422
+ scaling = torch.sqrt(1 + sigma_shifted ** 2)
423
+ latent_model_input = latent_model_input / scaling
424
+ else:
425
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
426
+
427
+ # Prepare timestep
428
+ timestep = t.expand(latent_model_input.shape[0])
429
+
430
+ # Prepare added conditions
431
+ if guidance_scale > 1.0:
432
+ text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
433
+ add_text_embeds = torch.cat([negative_pooled, pooled])
434
+ add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids])
435
+ else:
436
+ text_embeds = prompt_embeds
437
+ add_text_embeds = pooled
438
+ add_time_ids_input = add_time_ids
439
+
440
+ # Prepare added cond kwargs for SDXL UNet
441
+ added_cond_kwargs = {
442
+ "text_embeds": add_text_embeds,
443
+ "time_ids": add_time_ids_input
444
+ }
445
+
446
+ # Predict noise
447
+ noise_pred = self.unet(
448
+ latent_model_input,
449
+ timestep,
450
+ encoder_hidden_states=text_embeds,
451
+ added_cond_kwargs=added_cond_kwargs,
452
+ return_dict=False
453
+ )[0]
454
+
455
+ # CFG
456
+ if guidance_scale > 1.0:
457
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
458
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
459
+
460
+ # Step
461
+ if use_flow_matching:
462
+ sigma = t.float() / 1000.0
463
+ sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
464
+
465
+ if prediction_type == "v_prediction":
466
+ v_pred = noise_pred
467
+ alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
468
+ sigma_t = sigma_shifted
469
+ noise_pred = alpha_t * v_pred + sigma_t * latents
470
+
471
+ dt = -1.0 / num_inference_steps
472
+ latents = latents + dt * noise_pred
473
+ else:
474
+ latents = self.scheduler.step(
475
+ noise_pred, t, latents, return_dict=False
476
+ )[0]
477
+
478
+ # Decode
479
+ latents = latents / self.vae_scale_factor
480
+
481
+ with torch.no_grad():
482
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
483
+
484
+ # Convert to PIL
485
+ image = (image / 2 + 0.5).clamp(0, 1)
486
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
487
+ image = (image * 255).round().astype("uint8")
488
+ image = Image.fromarray(image[0])
489
+
490
+ return image
491
 
 
492
 
493
+ # ============================================================================
494
+ # SD1.5 PIPELINE (Original)
495
+ # ============================================================================
 
 
 
 
 
496
 
497
+ class SD15FlowMatchingPipeline:
498
+ """Pipeline for SD1.5-based flow-matching inference."""
499
+
500
+ def __init__(
501
+ self,
502
+ vae: AutoencoderKL,
503
+ text_encoder: CLIPTextModel,
504
+ tokenizer: CLIPTokenizer,
505
+ unet: UNet2DConditionModel,
506
+ scheduler,
507
+ device: str = "cuda",
508
+ t5_encoder: Optional[T5EncoderModel] = None,
509
+ t5_tokenizer: Optional[T5Tokenizer] = None,
510
+ lyra_model: Optional[any] = None
511
+ ):
512
+ self.vae = vae
513
+ self.text_encoder = text_encoder
514
+ self.tokenizer = tokenizer
515
+ self.unet = unet
516
+ self.scheduler = scheduler
517
+ self.device = device
518
+
519
+ self.t5_encoder = t5_encoder
520
+ self.t5_tokenizer = t5_tokenizer
521
+ self.lyra_model = lyra_model
522
+
523
+ self.vae_scale_factor = 0.18215
524
+ self.arch = ARCH_SD15
525
+ self.is_lune_model = False
526
+
527
+ def encode_prompt(self, prompt: str, negative_prompt: str = ""):
528
+ """Encode text prompts to embeddings."""
529
+ text_inputs = self.tokenizer(
530
+ prompt,
531
+ padding="max_length",
532
+ max_length=self.tokenizer.model_max_length,
533
+ truncation=True,
534
+ return_tensors="pt",
535
+ )
536
+ text_input_ids = text_inputs.input_ids.to(self.device)
537
+
538
+ with torch.no_grad():
539
+ prompt_embeds = self.text_encoder(text_input_ids)[0]
540
+
541
+ if negative_prompt:
542
+ uncond_inputs = self.tokenizer(
543
+ negative_prompt,
544
+ padding="max_length",
545
+ max_length=self.tokenizer.model_max_length,
546
+ truncation=True,
547
+ return_tensors="pt",
548
+ )
549
+ uncond_input_ids = uncond_inputs.input_ids.to(self.device)
550
+
551
+ with torch.no_grad():
552
+ negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0]
553
+ else:
554
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
555
+
556
+ return prompt_embeds, negative_prompt_embeds
557
+
558
+ def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
559
+ """Encode using Lyra VAE (CLIP + T5 fusion)."""
560
+ if self.lyra_model is None or self.t5_encoder is None:
561
+ raise ValueError("Lyra VAE components not initialized")
562
+
563
+ # CLIP
564
+ text_inputs = self.tokenizer(
565
+ prompt,
566
+ padding="max_length",
567
+ max_length=self.tokenizer.model_max_length,
568
+ truncation=True,
569
+ return_tensors="pt",
570
+ )
571
+ text_input_ids = text_inputs.input_ids.to(self.device)
572
+
573
+ with torch.no_grad():
574
+ clip_embeds = self.text_encoder(text_input_ids)[0]
575
+
576
+ # T5
577
+ t5_inputs = self.t5_tokenizer(
578
+ prompt,
579
+ max_length=77,
580
+ padding='max_length',
581
+ truncation=True,
582
+ return_tensors='pt'
583
+ ).to(self.device)
584
+
585
+ with torch.no_grad():
586
+ t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
587
+
588
+ # Fuse
589
+ modality_inputs = {'clip': clip_embeds, 't5': t5_embeds}
590
+
591
+ with torch.no_grad():
592
+ reconstructions, mu, logvar = self.lyra_model(
593
+ modality_inputs,
594
+ target_modalities=['clip']
595
+ )
596
+ prompt_embeds = reconstructions['clip']
597
+
598
+ # Negative
599
+ if negative_prompt:
600
+ uncond_inputs = self.tokenizer(
601
+ negative_prompt,
602
+ padding="max_length",
603
+ max_length=self.tokenizer.model_max_length,
604
+ truncation=True,
605
+ return_tensors="pt",
606
+ )
607
+ uncond_input_ids = uncond_inputs.input_ids.to(self.device)
608
+
609
+ with torch.no_grad():
610
+ clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]
611
+
612
+ t5_inputs_uncond = self.t5_tokenizer(
613
+ negative_prompt,
614
+ max_length=77,
615
+ padding='max_length',
616
+ truncation=True,
617
+ return_tensors='pt'
618
+ ).to(self.device)
619
+
620
+ with torch.no_grad():
621
+ t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state
622
+
623
+ modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond}
624
+
625
+ with torch.no_grad():
626
+ reconstructions_uncond, _, _ = self.lyra_model(
627
+ modality_inputs_uncond,
628
+ target_modalities=['clip']
629
  )
630
+ negative_prompt_embeds = reconstructions_uncond['clip']
631
+ else:
632
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
633
+
634
+ return prompt_embeds, negative_prompt_embeds
635
+
636
+ @torch.no_grad()
637
+ def __call__(
638
+ self,
639
+ prompt: str,
640
+ negative_prompt: str = "",
641
+ height: int = 512,
642
+ width: int = 512,
643
+ num_inference_steps: int = 20,
644
+ guidance_scale: float = 7.5,
645
+ shift: float = 2.5,
646
+ use_flow_matching: bool = True,
647
+ prediction_type: str = "epsilon",
648
+ seed: Optional[int] = None,
649
+ use_lyra: bool = False,
650
+ clip_skip: int = 1, # Unused for SD1.5 but kept for API consistency
651
+ progress_callback=None
652
+ ):
653
+ """Generate image."""
654
+
655
+ if seed is not None:
656
+ generator = torch.Generator(device=self.device).manual_seed(seed)
657
+ else:
658
+ generator = None
659
+
660
+ if use_lyra and self.lyra_model is not None:
661
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt)
662
+ else:
663
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt)
664
+
665
+ latent_channels = 4
666
+ latent_height = height // 8
667
+ latent_width = width // 8
668
+
669
+ latents = torch.randn(
670
+ (1, latent_channels, latent_height, latent_width),
671
+ generator=generator,
672
+ device=self.device,
673
+ dtype=torch.float32
674
+ )
675
+
676
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
677
+ timesteps = self.scheduler.timesteps
678
+
679
+ if not use_flow_matching:
680
+ latents = latents * self.scheduler.init_noise_sigma
681
+
682
+ for i, t in enumerate(timesteps):
683
+ if progress_callback:
684
+ progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
685
+
686
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
687
+
688
+ if use_flow_matching and shift > 0:
689
+ sigma = t.float() / 1000.0
690
+ sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
691
+ scaling = torch.sqrt(1 + sigma_shifted ** 2)
692
+ latent_model_input = latent_model_input / scaling
693
+ else:
694
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
695
+
696
+ timestep = t.expand(latent_model_input.shape[0])
697
+ text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds
698
+
699
+ noise_pred = self.unet(
700
+ latent_model_input,
701
+ timestep,
702
+ encoder_hidden_states=text_embeds,
703
+ return_dict=False
704
+ )[0]
705
+
706
+ if guidance_scale > 1.0:
707
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
708
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
709
+
710
+ if use_flow_matching:
711
+ sigma = t.float() / 1000.0
712
+ sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
713
+
714
+ if prediction_type == "v_prediction":
715
+ v_pred = noise_pred
716
+ alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
717
+ sigma_t = sigma_shifted
718
+ noise_pred = alpha_t * v_pred + sigma_t * latents
719
+
720
+ dt = -1.0 / num_inference_steps
721
+ latents = latents + dt * noise_pred
722
+ else:
723
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
724
+
725
+ latents = latents / self.vae_scale_factor
726
+
727
+ if self.is_lune_model:
728
+ latents = latents * 5.52
729
+
730
+ with torch.no_grad():
731
+ image = self.vae.decode(latents).sample
732
+
733
+ image = (image / 2 + 0.5).clamp(0, 1)
734
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
735
+ image = (image * 255).round().astype("uint8")
736
+ image = Image.fromarray(image[0])
737
+
738
+ return image
739
 
 
 
 
 
 
 
 
 
740
 
741
+ # ============================================================================
742
+ # MODEL LOADERS
743
+ # ============================================================================
744
+
745
+ def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
746
+ """Load Lune checkpoint from .pt file."""
747
+ print(f"📥 Downloading: {repo_id}/{filename}")
748
+
749
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
750
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
751
+
752
+ print(f"🏗️ Initializing SD1.5 UNet...")
753
+ unet = UNet2DConditionModel.from_pretrained(
754
+ "runwayml/stable-diffusion-v1-5",
755
+ subfolder="unet",
756
+ torch_dtype=torch.float32
757
+ )
758
+
759
+ student_state_dict = checkpoint["student"]
760
+ cleaned_dict = {}
761
+ for key, value in student_state_dict.items():
762
+ if key.startswith("unet."):
763
+ cleaned_dict[key[5:]] = value
764
+ else:
765
+ cleaned_dict[key] = value
766
+
767
+ unet.load_state_dict(cleaned_dict, strict=False)
768
+
769
+ step = checkpoint.get("gstep", "unknown")
770
+ print(f"✅ Loaded Lune from step {step}")
771
+
772
+ return unet.to(device)
773
+
774
+
775
+ def load_illustrious_xl(
776
+ repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
777
+ filename: str = "illustriousXL_v01.safetensors",
778
+ device: str = "cuda"
779
+ ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
780
+ """Load Illustrious XL from single safetensors file."""
781
+
782
+ print(f"📥 Downloading Illustrious XL: {repo_id}/{filename}")
783
+
784
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
785
+ print(f"✓ Downloaded: {checkpoint_path}")
786
+
787
+ print("📦 Loading safetensors...")
788
+ state_dict = load_safetensors(checkpoint_path)
789
+
790
+ # Extract components
791
+ components = extract_comfyui_components(state_dict)
792
+
793
+ # Load UNet from SDXL base config, then load weights
794
+ print("🏗️ Initializing SDXL UNet...")
795
+ unet = UNet2DConditionModel.from_pretrained(
796
+ "stabilityai/stable-diffusion-xl-base-1.0",
797
+ subfolder="unet",
798
+ torch_dtype=torch.float16
799
+ )
800
+
801
+ if components["unet"]:
802
+ missing, unexpected = unet.load_state_dict(components["unet"], strict=False)
803
+ print(f" UNet: {len(missing)} missing, {len(unexpected)} unexpected keys")
804
+
805
+ # Load VAE
806
+ print("🏗️ Initializing SDXL VAE...")
807
+ vae = AutoencoderKL.from_pretrained(
808
+ "stabilityai/stable-diffusion-xl-base-1.0",
809
+ subfolder="vae",
810
+ torch_dtype=torch.float16
811
+ )
812
+
813
+ if components["vae"]:
814
+ missing, unexpected = vae.load_state_dict(components["vae"], strict=False)
815
+ print(f" VAE: {len(missing)} missing, {len(unexpected)} unexpected keys")
816
+
817
+ # Load CLIP-L
818
+ print("🏗️ Loading CLIP-L...")
819
+ text_encoder = CLIPTextModel.from_pretrained(
820
+ "openai/clip-vit-large-patch14",
821
+ torch_dtype=torch.float16
822
+ )
823
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
824
+
825
+ # Load CLIP-G
826
+ print("🏗️ Loading CLIP-G...")
827
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
828
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
829
+ torch_dtype=torch.float16
830
+ )
831
+ tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
832
+
833
+ # Move to device
834
+ unet = unet.to(device)
835
+ vae = vae.to(device)
836
+ text_encoder = text_encoder.to(device)
837
+ text_encoder_2 = text_encoder_2.to(device)
838
+
839
+ print("✅ Illustrious XL loaded!")
840
+
841
+ return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
842
+
843
+
844
+ def load_sdxl_base(device: str = "cuda"):
845
+ """Load standard SDXL base model."""
846
+ print("📥 Loading SDXL Base 1.0...")
847
+
848
+ unet = UNet2DConditionModel.from_pretrained(
849
+ "stabilityai/stable-diffusion-xl-base-1.0",
850
+ subfolder="unet",
851
+ torch_dtype=torch.float16
852
+ ).to(device)
853
+
854
+ vae = AutoencoderKL.from_pretrained(
855
+ "stabilityai/stable-diffusion-xl-base-1.0",
856
+ subfolder="vae",
857
+ torch_dtype=torch.float16
858
+ ).to(device)
859
+
860
+ text_encoder = CLIPTextModel.from_pretrained(
861
+ "stabilityai/stable-diffusion-xl-base-1.0",
862
+ subfolder="text_encoder",
863
+ torch_dtype=torch.float16
864
+ ).to(device)
865
+
866
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
867
+ "stabilityai/stable-diffusion-xl-base-1.0",
868
+ subfolder="text_encoder_2",
869
+ torch_dtype=torch.float16
870
+ ).to(device)
871
+
872
+ tokenizer = CLIPTokenizer.from_pretrained(
873
+ "stabilityai/stable-diffusion-xl-base-1.0",
874
+ subfolder="tokenizer"
875
+ )
876
+
877
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
878
+ "stabilityai/stable-diffusion-xl-base-1.0",
879
+ subfolder="tokenizer_2"
880
+ )
881
+
882
+ print("✅ SDXL Base loaded!")
883
+
884
+ return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
885
+
886
+
887
+ def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
888
+ """Load Lyra VAE (SD1.5 version) from HuggingFace."""
889
+ if not LYRA_AVAILABLE:
890
+ print("⚠️ Lyra VAE not available")
891
+ return None
892
+
893
+ print(f"🎵 Loading Lyra VAE from {repo_id}...")
894
+
895
+ try:
896
+ checkpoint_path = hf_hub_download(
897
+ repo_id=repo_id,
898
+ filename="best_model.pt",
899
+ repo_type="model"
900
+ )
901
+
902
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
903
+
904
+ if 'config' in checkpoint:
905
+ config_dict = checkpoint['config']
906
+ else:
907
+ config_dict = {
908
+ 'modality_dims': {"clip": 768, "t5": 768},
909
+ 'latent_dim': 768,
910
+ 'seq_len': 77,
911
+ 'encoder_layers': 3,
912
+ 'decoder_layers': 3,
913
+ 'hidden_dim': 1024,
914
+ 'dropout': 0.1,
915
+ 'fusion_strategy': 'cantor',
916
+ 'fusion_heads': 8,
917
+ 'fusion_dropout': 0.1
918
+ }
919
+
920
+ vae_config = MultiModalVAEConfig(
921
+ modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}),
922
+ latent_dim=config_dict.get('latent_dim', 768),
923
+ seq_len=config_dict.get('seq_len', 77),
924
+ encoder_layers=config_dict.get('encoder_layers', 3),
925
+ decoder_layers=config_dict.get('decoder_layers', 3),
926
+ hidden_dim=config_dict.get('hidden_dim', 1024),
927
+ dropout=config_dict.get('dropout', 0.1),
928
+ fusion_strategy=config_dict.get('fusion_strategy', 'cantor'),
929
+ fusion_heads=config_dict.get('fusion_heads', 8),
930
+ fusion_dropout=config_dict.get('fusion_dropout', 0.1)
931
+ )
932
+
933
+ lyra_model = MultiModalVAE(vae_config)
934
+
935
+ if 'model_state_dict' in checkpoint:
936
+ lyra_model.load_state_dict(checkpoint['model_state_dict'])
937
+ else:
938
+ lyra_model.load_state_dict(checkpoint)
939
+
940
+ lyra_model.to(device)
941
+ lyra_model.eval()
942
+
943
+ print(f"✅ Lyra VAE (SD1.5) loaded")
944
+ return lyra_model
945
+
946
+ except Exception as e:
947
+ print(f"❌ Failed to load Lyra VAE: {e}")
948
+ return None
949
+
950
+
951
+ def load_lyra_vae_xl(
952
+ repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
953
+ device: str = "cuda"
954
+ ):
955
+ """Load Lyra VAE XL version for SDXL/Illustrious."""
956
+ if not LYRA_AVAILABLE:
957
+ print("⚠️ Lyra VAE not available")
958
+ return None
959
+
960
+ print(f"🎵 Loading Lyra VAE XL from {repo_id}...")
961
+
962
+ try:
963
+ checkpoint_path = hf_hub_download(
964
+ repo_id=repo_id,
965
+ filename="best_model.pt",
966
+ repo_type="model"
967
+ )
968
+
969
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
970
+
971
+ if 'config' in checkpoint:
972
+ config_dict = checkpoint['config']
973
+ else:
974
+ # XL defaults - note larger dimensions
975
+ config_dict = {
976
+ 'modality_dims': {"clip": 768, "t5": 2048}, # T5-XL
977
+ 'latent_dim': 2048,
978
+ 'seq_len': 77,
979
+ 'encoder_layers': 4,
980
+ 'decoder_layers': 4,
981
+ 'hidden_dim': 2048,
982
+ 'dropout': 0.1,
983
+ 'fusion_strategy': 'adaptive_cantor',
984
+ 'fusion_heads': 16,
985
+ 'fusion_dropout': 0.1
986
+ }
987
+
988
+ vae_config = MultiModalVAEConfig(
989
+ modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 2048}),
990
+ latent_dim=config_dict.get('latent_dim', 2048),
991
+ seq_len=config_dict.get('seq_len', 77),
992
+ encoder_layers=config_dict.get('encoder_layers', 4),
993
+ decoder_layers=config_dict.get('decoder_layers', 4),
994
+ hidden_dim=config_dict.get('hidden_dim', 2048),
995
+ dropout=config_dict.get('dropout', 0.1),
996
+ fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
997
+ fusion_heads=config_dict.get('fusion_heads', 16),
998
+ fusion_dropout=config_dict.get('fusion_dropout', 0.1)
999
+ )
1000
+
1001
+ lyra_model = MultiModalVAE(vae_config)
1002
+
1003
+ if 'model_state_dict' in checkpoint:
1004
+ lyra_model.load_state_dict(checkpoint['model_state_dict'])
1005
+ else:
1006
+ lyra_model.load_state_dict(checkpoint)
1007
+
1008
+ lyra_model.to(device)
1009
+ lyra_model.eval()
1010
+
1011
+ print(f"✅ Lyra VAE XL loaded")
1012
+ if 'global_step' in checkpoint:
1013
+ print(f" Step: {checkpoint['global_step']:,}")
1014
+
1015
+ return lyra_model
1016
+
1017
+ except Exception as e:
1018
+ print(f"❌ Failed to load Lyra VAE XL: {e}")
1019
+ return None
1020
+
1021
+
1022
+ # ============================================================================
1023
+ # PIPELINE INITIALIZATION
1024
+ # ============================================================================
1025
+
1026
+ def initialize_pipeline(model_choice: str, device: str = "cuda"):
1027
+ """Initialize the complete pipeline based on model choice."""
1028
+
1029
+ print(f"🚀 Initializing {model_choice} pipeline...")
1030
+
1031
+ # Determine architecture
1032
+ is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
1033
+ is_lune = "Lune" in model_choice
1034
+
1035
+ if is_sdxl:
1036
+ # SDXL-based models
1037
+ if "Illustrious" in model_choice:
1038
+ unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
1039
+ else:
1040
+ unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
1041
+
1042
+ # T5-XL for Lyra
1043
+ print("Loading T5-XL encoder...")
1044
+ t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xl")
1045
+ t5_encoder = T5EncoderModel.from_pretrained(
1046
+ "google/t5-v1_1-xl",
1047
+ torch_dtype=torch.float16
1048
+ ).to(device)
1049
+ t5_encoder.eval()
1050
+ print("✓ T5-XL loaded")
1051
+
1052
+ # Lyra XL
1053
+ lyra_model = load_lyra_vae_xl(device=device)
1054
+
1055
+ # Scheduler (epsilon for SDXL)
1056
+ scheduler = EulerDiscreteScheduler.from_pretrained(
1057
+ "stabilityai/stable-diffusion-xl-base-1.0",
1058
+ subfolder="scheduler"
1059
+ )
1060
+
1061
+ pipeline = SDXLFlowMatchingPipeline(
1062
+ vae=vae,
1063
+ text_encoder=text_encoder,
1064
+ text_encoder_2=text_encoder_2,
1065
+ tokenizer=tokenizer,
1066
+ tokenizer_2=tokenizer_2,
1067
+ unet=unet,
1068
+ scheduler=scheduler,
1069
+ device=device,
1070
+ t5_encoder=t5_encoder,
1071
+ t5_tokenizer=t5_tokenizer,
1072
+ lyra_model=lyra_model,
1073
+ clip_skip=1
1074
+ )
1075
+
1076
+ else:
1077
+ # SD1.5-based models
1078
+ vae = AutoencoderKL.from_pretrained(
1079
+ "runwayml/stable-diffusion-v1-5",
1080
+ subfolder="vae",
1081
+ torch_dtype=torch.float32
1082
+ ).to(device)
1083
+
1084
+ text_encoder = CLIPTextModel.from_pretrained(
1085
+ "openai/clip-vit-large-patch14",
1086
+ torch_dtype=torch.float32
1087
+ ).to(device)
1088
+
1089
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
1090
+
1091
+ # T5-base for SD1.5 Lyra
1092
+ print("Loading T5-base encoder...")
1093
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
1094
+ t5_encoder = T5EncoderModel.from_pretrained(
1095
+ "t5-base",
1096
+ torch_dtype=torch.float32
1097
+ ).to(device)
1098
+ t5_encoder.eval()
1099
+ print("✓ T5-base loaded")
1100
+
1101
+ # Lyra (SD1.5 version)
1102
+ lyra_model = load_lyra_vae(device=device)
1103
+
1104
+ # Load UNet
1105
+ if is_lune:
1106
+ repo_id = "AbstractPhil/sd15-flow-lune"
1107
+ filename = "sd15_flow_lune_e34_s34000.pt"
1108
+ unet = load_lune_checkpoint(repo_id, filename, device)
1109
+ else:
1110
+ unet = UNet2DConditionModel.from_pretrained(
1111
+ "runwayml/stable-diffusion-v1-5",
1112
+ subfolder="unet",
1113
+ torch_dtype=torch.float32
1114
+ ).to(device)
1115
+
1116
+ scheduler = EulerDiscreteScheduler.from_pretrained(
1117
+ "runwayml/stable-diffusion-v1-5",
1118
+ subfolder="scheduler"
1119
+ )
1120
+
1121
+ pipeline = SD15FlowMatchingPipeline(
1122
+ vae=vae,
1123
+ text_encoder=text_encoder,
1124
+ tokenizer=tokenizer,
1125
+ unet=unet,
1126
+ scheduler=scheduler,
1127
+ device=device,
1128
+ t5_encoder=t5_encoder,
1129
+ t5_tokenizer=t5_tokenizer,
1130
+ lyra_model=lyra_model
1131
+ )
1132
+
1133
+ pipeline.is_lune_model = is_lune
1134
+
1135
+ print("✅ Pipeline initialized!")
1136
+ return pipeline
1137
+
1138
+
1139
+ # ============================================================================
1140
+ # GLOBAL STATE
1141
+ # ============================================================================
1142
+
1143
+ CURRENT_PIPELINE = None
1144
+ CURRENT_MODEL = None
1145
+
1146
+
1147
+ def get_pipeline(model_choice: str):
1148
+ """Get or create pipeline for selected model."""
1149
+ global CURRENT_PIPELINE, CURRENT_MODEL
1150
+
1151
+ if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
1152
+ CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
1153
+ CURRENT_MODEL = model_choice
1154
+
1155
+ return CURRENT_PIPELINE
1156
+
1157
+
1158
+ # ============================================================================
1159
+ # INFERENCE
1160
+ # ============================================================================
1161
+
1162
+ def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int:
1163
+ """Estimate GPU duration."""
1164
+ base_time_per_step = 0.5 if is_sdxl else 0.3
1165
+ resolution_factor = (width * height) / (512 * 512)
1166
+ estimated = num_steps * base_time_per_step * resolution_factor
1167
+
1168
+ if use_lyra:
1169
+ estimated *= 2
1170
+ estimated += 3
1171
+
1172
+ return int(estimated + 20)
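# --- Worked example (illustrative, not part of the committed app.py) --------
# A 25-step 1024x1024 SDXL run with the Lyra comparison enabled:
#   resolution_factor = (1024*1024)/(512*512) = 4
#   25 steps * 0.5 s/step * 4 = 50 s, doubled for the second (Lyra) pass
#   = 100 s, + 3 s fusion overhead, + 20 s fixed buffer = 123 s requested.
assert estimate_duration(25, 1024, 1024, use_lyra=True, is_sdxl=True) == 123
# -----------------------------------------------------------------------------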
1173
+
1174
+
1175
+ @spaces.GPU(duration=lambda *args: estimate_duration(
1176
+ args[4], args[6], args[7], args[10],
1177
+ "SDXL" in args[2] or "Illustrious" in args[2]
1178
+ ))
1179
+ def generate_image(
1180
+ prompt: str,
1181
+ negative_prompt: str,
1182
+ model_choice: str,
1183
+ clip_skip: int,
1184
+ num_steps: int,
1185
+ cfg_scale: float,
1186
+ width: int,
1187
+ height: int,
1188
+ shift: float,
1189
+ use_flow_matching: bool,
1190
+ use_lyra: bool,
1191
+ seed: int,
1192
+ randomize_seed: bool,
1193
+ progress=gr.Progress()
1194
+ ):
1195
+ """Generate image with ZeroGPU support."""
1196
+
1197
+ if randomize_seed:
1198
+ seed = np.random.randint(0, 2**32 - 1)
1199
+
1200
+ def progress_callback(step, total, desc):
1201
+ progress((step + 1) / total, desc=desc)
1202
+
1203
+ try:
1204
+ pipeline = get_pipeline(model_choice)
1205
+
1206
+ # Determine prediction type based on model
1207
+ is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
1208
+ prediction_type = "epsilon" # SDXL always uses epsilon
1209
+
1210
+ if not is_sdxl and "Lune" in model_choice:
1211
+ prediction_type = "v_prediction"
1212
+
1213
+ if not use_lyra or pipeline.lyra_model is None:
1214
+ progress(0.05, desc="Generating...")
1215
+
1216
+ image = pipeline(
1217
+ prompt=prompt,
1218
+ negative_prompt=negative_prompt,
1219
+ height=height,
1220
+ width=width,
1221
+ num_inference_steps=num_steps,
1222
+ guidance_scale=cfg_scale,
1223
+ shift=shift,
1224
+ use_flow_matching=use_flow_matching,
1225
+ prediction_type=prediction_type,
1226
+ seed=seed,
1227
+ use_lyra=False,
1228
+ clip_skip=clip_skip,
1229
+ progress_callback=progress_callback
1230
+ )
1231
+
1232
+ progress(1.0, desc="Complete!")
1233
+ return image, None, seed
1234
+
1235
+ else:
1236
+ progress(0.05, desc="Generating standard...")
1237
+
1238
+ image_standard = pipeline(
1239
+ prompt=prompt,
1240
+ negative_prompt=negative_prompt,
1241
+ height=height,
1242
+ width=width,
1243
+ num_inference_steps=num_steps,
1244
+ guidance_scale=cfg_scale,
1245
+ shift=shift,
1246
+ use_flow_matching=use_flow_matching,
1247
+ prediction_type=prediction_type,
1248
+ seed=seed,
1249
+ use_lyra=False,
1250
+ clip_skip=clip_skip,
1251
+ progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
1252
+ )
1253
+
1254
+ progress(0.5, desc="Generating Lyra fusion...")
1255
+
1256
+ image_lyra = pipeline(
1257
+ prompt=prompt,
1258
+ negative_prompt=negative_prompt,
1259
+ height=height,
1260
+ width=width,
1261
+ num_inference_steps=num_steps,
1262
+ guidance_scale=cfg_scale,
1263
+ shift=shift,
1264
+ use_flow_matching=use_flow_matching,
1265
+ prediction_type=prediction_type,
1266
+ seed=seed,
1267
+ use_lyra=True,
1268
+ clip_skip=clip_skip,
1269
+ progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d)
1270
+ )
1271
+
1272
+ progress(1.0, desc="Complete!")
1273
+ return image_standard, image_lyra, seed
1274
+
1275
+ except Exception as e:
1276
+ print(f"❌ Generation failed: {e}")
1277
+ raise e
1278
+
1279
+
1280
+ # ============================================================================
1281
+ # GRADIO UI
1282
+ # ============================================================================
1283
+
1284
+ def create_demo():
1285
+ """Create Gradio interface."""
1286
+
1287
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
1288
+ gr.Markdown("""
1289
+ # 🌙 Lyra/Lune Flow-Matching Image Generation
1290
+
1291
+ **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
1292
+
1293
+ Generate images using SD1.5 and SDXL-based models with geometric deep learning:
1294
+
1295
+ | Model | Architecture | Best For |
1296
+ |-------|-------------|----------|
1297
+ | **Illustrious XL** | SDXL | Anime/illustration, high detail |
1298
+ | **SDXL Base** | SDXL | Photorealistic, general purpose |
1299
+ | **Flow-Lune** | SD1.5 | Fast flow matching (15-25 steps) |
1300
+ | **SD1.5 Base** | SD1.5 | Baseline comparison |
1301
+
1302
+ Enable **Lyra VAE** for CLIP+T5 fusion comparison!
1303
+ """)
1304
+
1305
+ with gr.Row():
1306
+ with gr.Column(scale=1):
1307
+ prompt = gr.TextArea(
1308
+ label="Prompt",
1309
+ value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
1310
+ lines=3
1311
+ )
1312
+
1313
+ negative_prompt = gr.TextArea(
1314
+ label="Negative Prompt",
1315
+ value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality",
1316
+ lines=2
1317
+ )
1318
+
1319
+ model_choice = gr.Dropdown(
1320
+ label="Model",
1321
+ choices=[
1322
+ "Illustrious XL",
1323
+ "SDXL Base",
1324
+ "Flow-Lune (SD1.5)",
1325
+ "SD1.5 Base"
1326
+ ],
1327
+ value="Illustrious XL"
1328
+ )
1329
+
1330
+ clip_skip = gr.Slider(
1331
+ label="CLIP Skip",
1332
  minimum=1,
1333
+ maximum=4,
1334
+ value=2,
1335
  step=1,
1336
+ info="2 recommended for Illustrious, 1 for others"
1337
+ )
1338
+
1339
+ use_lyra = gr.Checkbox(
1340
+ label="Enable Lyra VAE (CLIP+T5 Fusion)",
1341
+ value=False,
1342
+ info="Compare standard vs geometric fusion"
1343
  )
1344
+
1345
+ with gr.Accordion("Generation Settings", open=True):
1346
+ num_steps = gr.Slider(
1347
+ label="Steps",
1348
+ minimum=1,
1349
+ maximum=50,
1350
+ value=25,
1351
+ step=1
1352
+ )
1353
+
1354
+ cfg_scale = gr.Slider(
1355
+ label="CFG Scale",
1356
+ minimum=1.0,
1357
+ maximum=20.0,
1358
+ value=7.0,
1359
+ step=0.5
1360
+ )
1361
+
1362
+ with gr.Row():
1363
+ width = gr.Slider(
1364
+ label="Width",
1365
+ minimum=512,
1366
+ maximum=1536,
1367
+ value=1024,
1368
+ step=64
1369
+ )
1370
+ height = gr.Slider(
1371
+ label="Height",
1372
+ minimum=512,
1373
+ maximum=1536,
1374
+ value=1024,
1375
+ step=64
1376
+ )
1377
+
1378
+ seed = gr.Slider(
1379
+ label="Seed",
1380
+ minimum=0,
1381
+ maximum=2**32 - 1,
1382
+ value=42,
1383
+ step=1
1384
+ )
1385
+
1386
+ randomize_seed = gr.Checkbox(
1387
+ label="Randomize Seed",
1388
+ value=True
1389
+ )
1390
+
1391
+ with gr.Accordion("Advanced (Flow Matching)", open=False):
1392
+ use_flow_matching = gr.Checkbox(
1393
+ label="Enable Flow Matching",
1394
+ value=False,
1395
+ info="Use flow matching ODE (for Lune only)"
1396
+ )
1397
+
1398
+ shift = gr.Slider(
1399
+ label="Shift",
1400
+ minimum=0.0,
1401
+ maximum=5.0,
1402
+ value=0.0,
1403
+ step=0.1,
1404
+ info="Flow matching shift (0=disabled)"
1405
+ )
1406
+
1407
+ generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
1408
+
1409
+ with gr.Column(scale=1):
1410
+ with gr.Row():
1411
+ output_image_standard = gr.Image(
1412
+ label="Generated Image",
1413
+ type="pil"
1414
+ )
1415
+ output_image_lyra = gr.Image(
1416
+ label="Lyra Fusion 🎵",
1417
+ type="pil",
1418
+ visible=False
1419
+ )
1420
+
1421
+ output_seed = gr.Number(label="Seed", precision=0)
1422
+
1423
+ gr.Markdown("""
1424
+ ### Tips
1425
+ - **Illustrious XL**: Use CLIP skip 2, booru-style tags
1426
+ - **SDXL Base**: Natural language prompts work well
1427
+ - **Flow-Lune**: Enable flow matching, shift ~2.5, fewer steps
1428
+ - **Lyra**: Generates both standard and fused for comparison
1429
+
1430
+ ### Model Info
1431
+ - SDXL models use **epsilon** prediction
1432
+ - Lune uses **v_prediction** with flow matching
1433
+ - Lyra fuses CLIP + T5 for richer semantics
1434
+ """)
1435
+
1436
+ # Examples
1437
+ gr.Examples(
1438
+ examples=[
1439
+ [
1440
+ "masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
1441
+ "lowres, bad anatomy, worst quality, low quality",
1442
+ "Illustrious XL",
1443
+ 2, 25, 7.0, 1024, 1024, 0.0, False, False, 42, False
1444
+ ],
1445
+ [
1446
+ "A majestic mountain landscape at golden hour, crystal clear lake, photorealistic, 8k",
1447
+ "blurry, low quality",
1448
+ "SDXL Base",
1449
+ 1, 30, 7.5, 1024, 1024, 0.0, False, False, 123, False
1450
+ ],
1451
+ [
1452
+ "cyberpunk city at night, neon lights, rain, highly detailed",
1453
+ "low quality, blurry",
1454
+ "Flow-Lune (SD1.5)",
1455
+ 1, 20, 7.5, 512, 512, 2.5, True, False, 456, False
1456
+ ],
1457
+ ],
1458
+ inputs=[
1459
+ prompt, negative_prompt, model_choice, clip_skip,
1460
+ num_steps, cfg_scale, width, height, shift,
1461
+ use_flow_matching, use_lyra, seed, randomize_seed
1462
+ ],
1463
+ outputs=[output_image_standard, output_image_lyra, output_seed],
1464
+ fn=generate_image,
1465
+ cache_examples=False
1466
+ )
1467
+
1468
+ # Event handlers
1469
+ def on_model_change(model_name):
1470
+ """Update defaults based on model."""
1471
+ if "Illustrious" in model_name:
1472
+ return {
1473
+ clip_skip: gr.update(value=2),
1474
+ width: gr.update(value=1024),
1475
+ height: gr.update(value=1024),
1476
+ num_steps: gr.update(value=25),
1477
+ use_flow_matching: gr.update(value=False),
1478
+ shift: gr.update(value=0.0)
1479
+ }
1480
+ elif "SDXL" in model_name:
1481
+ return {
1482
+ clip_skip: gr.update(value=1),
1483
+ width: gr.update(value=1024),
1484
+ height: gr.update(value=1024),
1485
+ num_steps: gr.update(value=30),
1486
+ use_flow_matching: gr.update(value=False),
1487
+ shift: gr.update(value=0.0)
1488
+ }
1489
+ elif "Lune" in model_name:
1490
+ return {
1491
+ clip_skip: gr.update(value=1),
1492
+ width: gr.update(value=512),
1493
+ height: gr.update(value=512),
1494
+ num_steps: gr.update(value=20),
1495
+ use_flow_matching: gr.update(value=True),
1496
+ shift: gr.update(value=2.5)
1497
+ }
1498
+ else: # SD1.5 Base
1499
+ return {
1500
+ clip_skip: gr.update(value=1),
1501
+ width: gr.update(value=512),
1502
+ height: gr.update(value=512),
1503
+ num_steps: gr.update(value=30),
1504
+ use_flow_matching: gr.update(value=False),
1505
+ shift: gr.update(value=0.0)
1506
+ }
1507
+
1508
+ def on_lyra_toggle(enabled):
1509
+ """Show/hide Lyra comparison."""
1510
+ if enabled:
1511
+ return {
1512
+ output_image_standard: gr.update(visible=True, label="Standard"),
1513
+ output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵")
1514
+ }
1515
+ else:
1516
+ return {
1517
+ output_image_standard: gr.update(visible=True, label="Generated Image"),
1518
+ output_image_lyra: gr.update(visible=False)
1519
+ }
1520
+
1521
+ model_choice.change(
1522
+ fn=on_model_change,
1523
+ inputs=[model_choice],
1524
+ outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift]
1525
+ )
1526
+
1527
+ use_lyra.change(
1528
+ fn=on_lyra_toggle,
1529
+ inputs=[use_lyra],
1530
+ outputs=[output_image_standard, output_image_lyra]
1531
+ )
1532
+
1533
+ generate_btn.click(
1534
+ fn=generate_image,
1535
+ inputs=[
1536
+ prompt, negative_prompt, model_choice, clip_skip,
1537
+ num_steps, cfg_scale, width, height, shift,
1538
+ use_flow_matching, use_lyra, seed, randomize_seed
1539
+ ],
1540
+ outputs=[output_image_standard, output_image_lyra, output_seed]
1541
+ )
1542
+
1543
+ return demo
1544
 
1545
+
1546
+ # ============================================================================
1547
+ # LAUNCH
1548
+ # ============================================================================
1549
 
1550
  if __name__ == "__main__":
1551
+ demo = create_demo()
1552
+ demo.queue(max_size=20)
1553
+ demo.launch(show_api=False)
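
# --- Reference sketch (illustrative, not part of the committed app.py) ------
# The flow-matching branch that appears in both pipelines above remaps the
# timestep-derived sigma with a shift and then takes plain Euler steps:
#   sigma         = t / 1000
#   sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)
#   latents      <- latents + dt * noise_pred,  with dt = -1 / num_inference_steps
def shifted_sigma(t: float, shift: float) -> float:
    sigma = t / 1000.0
    return (shift * sigma) / (1 + (shift - 1) * sigma)

def euler_step(latents, noise_pred, num_inference_steps: int):
    dt = -1.0 / num_inference_steps
    return latents + dt * noise_pred

# shift=2.5 (the Lune default above) pushes mid-range timesteps toward noise:
assert abs(shifted_sigma(500, 2.5) - 0.7142857) < 1e-6
# -----------------------------------------------------------------------------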