AbstractPhil committed
Commit f6dab9d · verified · 1 Parent(s): 93038cf

Update app.py

Files changed (1)
  1. app.py +21 -43
app.py CHANGED
@@ -309,37 +309,18 @@ class SDXLFlowMatchingPipeline:
         print(f"[Lyra Debug] T5 input: shape={t5_embeds.shape}, mean={t5_embeds.mean():.4f}, std={t5_embeds.std():.4f}")
 
         with torch.no_grad():
-            # Try approach 1: Cross-modal - encode T5 only, decode to CLIP
-            # This uses T5's semantic understanding to generate CLIP-compatible embeddings
-            t5_only_inputs = {
+            # Full forward pass with all modalities (model requires all)
+            modality_inputs = {
+                'clip_l': clip_l_embeds.float(),
+                'clip_g': clip_g_embeds.float(),
                 't5_xl_l': t5_embeds.float(),
                 't5_xl_g': t5_embeds.float()
             }
-
-            # Check if model has separate encode/decode methods
-            if hasattr(self.lyra_model, 'encode') and hasattr(self.lyra_model, 'decode'):
-                print("[Lyra Debug] Using separate encode/decode path")
-                # Encode T5 to latent space
-                mu, logvar = self.lyra_model.encode(t5_only_inputs)
-                z = mu  # Use mean for deterministic output
-                print(f"[Lyra Debug] Latent z: shape={z.shape}, mean={z.mean():.4f}, std={z.std():.4f}")
-
-                # Decode to CLIP space
-                reconstructions = self.lyra_model.decode(z, target_modalities=['clip_l', 'clip_g'])
-            else:
-                print("[Lyra Debug] Using forward pass with all modalities")
-                # Fall back to full forward pass with all modalities
-                modality_inputs = {
-                    'clip_l': clip_l_embeds.float(),
-                    'clip_g': clip_g_embeds.float(),
-                    't5_xl_l': t5_embeds.float(),
-                    't5_xl_g': t5_embeds.float()
-                }
-                reconstructions, mu, logvar, _ = self.lyra_model(
-                    modality_inputs,
-                    target_modalities=['clip_l', 'clip_g']
-                )
-            print(f"[Lyra Debug] Latent mu: shape={mu.shape}, mean={mu.mean():.4f}, std={mu.std():.4f}")
+            reconstructions, mu, logvar, _ = self.lyra_model(
+                modality_inputs,
+                target_modalities=['clip_l', 'clip_g']
+            )
+            print(f"[Lyra Debug] Latent mu: shape={mu.shape}, mean={mu.mean():.4f}, std={mu.std():.4f}")
 
         lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
         lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
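For reference, the pattern this hunk settles on is a single VAE-style forward pass over all four modalities at once, returning reconstructions for the requested targets plus the latent statistics. A minimal, runnable sketch of that interface with a stand-in module (the stub class, the 77-token sequence length, and the 2048-wide T5 embeddings are illustrative assumptions, not values from app.py; CLIP-L/CLIP-G widths are the standard SDXL 768/1280):

```python
import torch
import torch.nn as nn

class StubLyra(nn.Module):
    """Stand-in with the (reconstructions, mu, logvar, extra) return
    convention the diff assumes; NOT the real Lyra model."""
    def __init__(self, out_dims, latent_dim=64):
        super().__init__()
        self.out_dims = out_dims
        self.latent_dim = latent_dim

    def forward(self, modality_inputs, target_modalities):
        b, t = next(iter(modality_inputs.values())).shape[:2]
        # Fake reconstructions in each requested modality's width
        recon = {m: torch.randn(b, t, self.out_dims[m]) for m in target_modalities}
        mu = torch.randn(b, t, self.latent_dim)   # shared latent mean (fake)
        logvar = torch.zeros_like(mu)
        return recon, mu, logvar, None

model = StubLyra({'clip_l': 768, 'clip_g': 1280})
inputs = {
    'clip_l': torch.randn(1, 77, 768),    # CLIP ViT-L hidden size
    'clip_g': torch.randn(1, 77, 1280),   # OpenCLIP ViT-bigG hidden size
    't5_xl_l': torch.randn(1, 77, 2048),  # assumed T5-XL width
    't5_xl_g': torch.randn(1, 77, 2048),
}
with torch.no_grad():
    recon, mu, logvar, _ = model(inputs, target_modalities=['clip_l', 'clip_g'])
print({k: tuple(v.shape) for k, v in recon.items()})
```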
@@ -348,7 +329,6 @@ class SDXLFlowMatchingPipeline:
         print(f"[Lyra Debug] Lyra CLIP-G output: mean={lyra_clip_g.mean():.4f}, std={lyra_clip_g.std():.4f}")
 
         # Check if reconstruction stats are wildly different from input
-        # If so, we may need to normalize
         clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
         clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
         print(f"[Lyra Debug] Std ratio CLIP-L: {clip_l_std_ratio:.4f}, CLIP-G: {clip_g_std_ratio:.4f}")
@@ -393,27 +373,25 @@ class SDXLFlowMatchingPipeline:
         neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
         neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
 
-        if hasattr(self.lyra_model, 'encode') and hasattr(self.lyra_model, 'decode'):
-            t5_neg_inputs = {'t5_xl_l': t5_embeds_neg.float(), 't5_xl_g': t5_embeds_neg.float()}
-            mu_neg, _ = self.lyra_model.encode(t5_neg_inputs)
-            recon_neg = self.lyra_model.decode(mu_neg, target_modalities=['clip_l', 'clip_g'])
-        else:
-            modality_inputs_neg = {
-                'clip_l': neg_clip_l.float(),
-                'clip_g': neg_clip_g.float(),
-                't5_xl_l': t5_embeds_neg.float(),
-                't5_xl_g': t5_embeds_neg.float()
-            }
-            recon_neg, _, _, _ = self.lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])
+        # Full forward pass (model requires all modalities)
+        modality_inputs_neg = {
+            'clip_l': neg_clip_l.float(),
+            'clip_g': neg_clip_g.float(),
+            't5_xl_l': t5_embeds_neg.float(),
+            't5_xl_g': t5_embeds_neg.float()
+        }
+        recon_neg, _, _, _ = self.lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])
 
         lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype)
         lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype)
 
         # Normalize if needed
-        if lyra_neg_l.std() / (neg_clip_l.std() + 1e-8) > 2.0:
+        neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8)
+        neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8)
+        if neg_l_ratio > 2.0 or neg_l_ratio < 0.5:
             lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8)
             lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean()
-        if lyra_neg_g.std() / (neg_clip_g.std() + 1e-8) > 2.0:
+        if neg_g_ratio > 2.0 or neg_g_ratio < 0.5:
             lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8)
             lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean()
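The two-sided band this hunk introduces (renormalize when the ratio leaves roughly 0.5-2.0, where the old code only caught the > 2.0 case) is classic moment matching: re-standardize the reconstruction, then rescale and shift it to the reference's statistics. A minimal sketch of that logic as a standalone helper (the name and tolerance defaults are ours; app.py inlines it per tensor):

```python
import torch

def match_moments(x: torch.Tensor, ref: torch.Tensor,
                  lo: float = 0.5, hi: float = 2.0,
                  eps: float = 1e-8) -> torch.Tensor:
    """Return x re-standardized to ref's global mean/std, but only when
    the std ratio falls outside the [lo, hi] tolerance band."""
    ratio = (x.std() / (ref.std() + eps)).item()
    if lo <= ratio <= hi:
        return x                                 # scale already acceptable
    x = (x - x.mean()) / (x.std() + eps)         # zero mean, unit std
    return x * ref.std() + ref.mean()            # adopt reference moments
```

Matching only the global mean and std leaves the token-wise structure of the embedding intact while restoring the scale the UNet expects.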