Update app.py
app.py CHANGED
@@ -309,37 +309,18 @@ class SDXLFlowMatchingPipeline:
         print(f"[Lyra Debug] T5 input: shape={t5_embeds.shape}, mean={t5_embeds.mean():.4f}, std={t5_embeds.std():.4f}")

         with torch.no_grad():
-            #
-
-            t5_only_inputs = {
+            # Full forward pass with all modalities (model requires all)
+            modality_inputs = {
+                'clip_l': clip_l_embeds.float(),
+                'clip_g': clip_g_embeds.float(),
                 't5_xl_l': t5_embeds.float(),
                 't5_xl_g': t5_embeds.float()
             }
-
-
-
-
-
-                mu, logvar = self.lyra_model.encode(t5_only_inputs)
-                z = mu  # Use mean for deterministic output
-                print(f"[Lyra Debug] Latent z: shape={z.shape}, mean={z.mean():.4f}, std={z.std():.4f}")
-
-                # Decode to CLIP space
-                reconstructions = self.lyra_model.decode(z, target_modalities=['clip_l', 'clip_g'])
-            else:
-                print("[Lyra Debug] Using forward pass with all modalities")
-                # Fall back to full forward pass with all modalities
-                modality_inputs = {
-                    'clip_l': clip_l_embeds.float(),
-                    'clip_g': clip_g_embeds.float(),
-                    't5_xl_l': t5_embeds.float(),
-                    't5_xl_g': t5_embeds.float()
-                }
-                reconstructions, mu, logvar, _ = self.lyra_model(
-                    modality_inputs,
-                    target_modalities=['clip_l', 'clip_g']
-                )
-                print(f"[Lyra Debug] Latent mu: shape={mu.shape}, mean={mu.mean():.4f}, std={mu.std():.4f}")
+            reconstructions, mu, logvar, _ = self.lyra_model(
+                modality_inputs,
+                target_modalities=['clip_l', 'clip_g']
+            )
+            print(f"[Lyra Debug] Latent mu: shape={mu.shape}, mean={mu.mean():.4f}, std={mu.std():.4f}")

         lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
         lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
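The hunk above drops the T5-only encode/decode branch and always runs a single full forward pass through the Lyra model, feeding all four modalities and requesting only the CLIP reconstructions. Below is a minimal sketch of that call pattern, assuming only the interface visible in the diff (a callable model that takes a dict of modality tensors plus target_modalities and returns (reconstructions, mu, logvar, extra)); the wrapper function name and argument names are illustrative and not part of app.py.

import torch

def lyra_refine_clip(lyra_model, clip_l_embeds, clip_g_embeds, t5_embeds):
    # Assemble every modality the model expects; per the diff's comment, the
    # model requires all of them even when only CLIP outputs are wanted.
    modality_inputs = {
        'clip_l': clip_l_embeds.float(),
        'clip_g': clip_g_embeds.float(),
        't5_xl_l': t5_embeds.float(),
        't5_xl_g': t5_embeds.float(),  # the same T5 features feed both T5 slots, as in the diff
    }
    with torch.no_grad():
        # Single full forward pass; only CLIP-L/G reconstructions are requested.
        reconstructions, mu, logvar, _ = lyra_model(
            modality_inputs,
            target_modalities=['clip_l', 'clip_g'],
        )
    # Cast back to the pipeline's working dtype before re-injecting into the SDXL conditioning.
    lyra_clip_l = reconstructions['clip_l'].to(clip_l_embeds.dtype)
    lyra_clip_g = reconstructions['clip_g'].to(clip_g_embeds.dtype)
    return lyra_clip_l, lyra_clip_g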
@@ -348,7 +329,6 @@ class SDXLFlowMatchingPipeline:
         print(f"[Lyra Debug] Lyra CLIP-G output: mean={lyra_clip_g.mean():.4f}, std={lyra_clip_g.std():.4f}")

         # Check if reconstruction stats are wildly different from input
-        # If so, we may need to normalize
         clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
         clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
         print(f"[Lyra Debug] Std ratio CLIP-L: {clip_l_std_ratio:.4f}, CLIP-G: {clip_g_std_ratio:.4f}")
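The std ratio in this hunk is a scale diagnostic: it compares the spread of the Lyra reconstruction to the spread of the original CLIP embedding, and the negative-prompt hunk below only rescales when that ratio leaves the 0.5-2.0 band. A toy check with made-up tensors (the 77x768 shape matches SDXL's CLIP-L token embeddings; the scales are invented for illustration only):

import torch

clip_l_embeds = torch.randn(1, 77, 768) * 0.30   # stand-in for the original CLIP-L embedding scale
lyra_clip_l = torch.randn(1, 77, 768) * 0.95     # stand-in for a reconstruction that came back too "hot"

clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
print(f"[Lyra Debug] Std ratio CLIP-L: {clip_l_std_ratio:.4f}")   # roughly 3.2 with these scales

# Same band the commit uses for the negative embeddings: rescale only when badly off.
needs_rescale = clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5
print("rescale" if needs_rescale else "leave as-is")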
@@ -393,27 +373,25 @@ class SDXLFlowMatchingPipeline:
         neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
         neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]

-
-
-
-
-
-
-
-
-                    't5_xl_l': t5_embeds_neg.float(),
-                    't5_xl_g': t5_embeds_neg.float()
-                }
-                recon_neg, _, _, _ = self.lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])
+        # Full forward pass (model requires all modalities)
+        modality_inputs_neg = {
+            'clip_l': neg_clip_l.float(),
+            'clip_g': neg_clip_g.float(),
+            't5_xl_l': t5_embeds_neg.float(),
+            't5_xl_g': t5_embeds_neg.float()
+        }
+        recon_neg, _, _, _ = self.lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])

         lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype)
         lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype)

         # Normalize if needed
-
+        neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8)
+        neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8)
+        if neg_l_ratio > 2.0 or neg_l_ratio < 0.5:
             lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8)
             lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean()
-        if
+        if neg_g_ratio > 2.0 or neg_g_ratio < 0.5:
             lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8)
             lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean()

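The negative-prompt hunk adds a guarded statistics-matching step: when the reconstruction's std drifts outside 0.5x-2x of the reference embedding's std, it is re-standardized and then shifted and scaled onto the reference's mean and std. A small helper capturing that logic; packaging it as a standalone function and its name are mine, while the formula, the epsilon, and the thresholds come from the diff.

import torch

def match_stats_if_needed(recon: torch.Tensor,
                          reference: torch.Tensor,
                          hi: float = 2.0,
                          lo: float = 0.5,
                          eps: float = 1e-8) -> torch.Tensor:
    """Rescale recon onto reference's global mean/std when its scale has drifted too far."""
    ratio = recon.std() / (reference.std() + eps)
    if ratio > hi or ratio < lo:
        # Re-standardize, then match the reference statistics (same formula as the diff).
        recon = (recon - recon.mean()) / (recon.std() + eps)
        recon = recon * reference.std() + reference.mean()
    return recon

# Hypothetical usage mirroring the negative-prompt branch above:
# lyra_neg_l = match_stats_if_needed(lyra_neg_l, neg_clip_l)
# lyra_neg_g = match_stats_if_needed(lyra_neg_g, neg_clip_g)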