AbstractPhil commited on
Commit
b20f699
·
verified ·
1 Parent(s): c2e9f6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -27
app.py CHANGED
@@ -327,20 +327,29 @@ class LazyLyraModel:
327
  # Config already prefetched
328
  config_dict = self._config
329
 
330
- # Find checkpoint
331
- from huggingface_hub import list_repo_files
332
-
333
- repo_files = list_repo_files(self.repo_id, repo_type="model")
334
- checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')]
335
-
336
- # Prefer weights/ folder
337
- weights_files = [f for f in checkpoint_files if f.startswith('weights/')]
338
- if weights_files:
339
- checkpoint_file = sorted(weights_files)[-1] # Latest
340
- elif checkpoint_files:
341
- checkpoint_file = checkpoint_files[0]
342
  else:
343
- raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  checkpoint_path = hf_hub_download(
346
  repo_id=self.repo_id,
@@ -1295,12 +1304,16 @@ def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
1295
 
1296
  def load_illustrious_xl(
1297
  repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
1298
- filename: str = "illustriousXL_v01.safetensors",
1299
  device: str = "cuda"
1300
  ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
1301
  """Load Illustrious XL from single safetensors file."""
1302
  from diffusers import StableDiffusionXLPipeline
1303
 
 
 
 
 
1304
  print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
1305
 
1306
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
@@ -1377,13 +1390,23 @@ def load_sdxl_base(device: str = "cuda"):
1377
  # PIPELINE INITIALIZATION
1378
  # ============================================================================
1379
 
1380
- def initialize_pipeline(model_choice: str, device: str = "cuda"):
1381
  """Initialize the complete pipeline based on model choice.
1382
 
1383
  Uses lazy loading for T5 and Lyra - they won't be downloaded until first use.
 
 
 
 
 
 
1384
  """
1385
 
1386
  print(f"🚀 Initializing {model_choice} pipeline...")
 
 
 
 
1387
 
1388
  is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
1389
  is_lune = "Lune" in model_choice
@@ -1391,7 +1414,10 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
1391
  if is_sdxl:
1392
  # SDXL-based models
1393
  if "Illustrious" in model_choice:
1394
- unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
 
 
 
1395
  else:
1396
  unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
1397
 
@@ -1404,7 +1430,8 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
1404
  )
1405
  lyra_loader = LazyLyraModel(
1406
  repo_id=LYRA_ILLUSTRIOUS_REPO,
1407
- device=device
 
1408
  )
1409
 
1410
  # Default scheduler: Euler Ancestral
@@ -1448,13 +1475,18 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
1448
  )
1449
  lyra_loader = LazyLyraModel(
1450
  repo_id=LYRA_SD15_REPO,
1451
- device=device
 
1452
  )
1453
 
1454
  # Load UNet
1455
  if is_lune:
1456
  repo_id = "AbstractPhil/sd15-flow-lune"
1457
- filename = "sd15_flow_lune_e34_s34000.pt"
 
 
 
 
1458
  unet = load_lune_checkpoint(repo_id, filename, device)
1459
  else:
1460
  unet = UNet2DConditionModel.from_pretrained(
@@ -1491,15 +1523,31 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
1491
 
1492
  CURRENT_PIPELINE = None
1493
  CURRENT_MODEL = None
 
 
1494
 
1495
 
1496
- def get_pipeline(model_choice: str):
1497
  """Get or create pipeline for selected model."""
1498
- global CURRENT_PIPELINE, CURRENT_MODEL
1499
-
1500
- if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
1501
- CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
 
 
 
 
 
 
 
 
 
 
 
 
1502
  CURRENT_MODEL = model_choice
 
 
1503
 
1504
  return CURRENT_PIPELINE
1505
 
@@ -1522,7 +1570,7 @@ def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool =
1522
 
1523
 
1524
  @spaces.GPU(duration=lambda *args: estimate_duration(
1525
- args[6], args[8], args[9], args[12],
1526
  "SDXL" in args[3] or "Illustrious" in args[3]
1527
  ))
1528
  def generate_image(
@@ -1530,6 +1578,8 @@ def generate_image(
1530
  t5_summary: str,
1531
  negative_prompt: str,
1532
  model_choice: str,
 
 
1533
  scheduler_choice: str,
1534
  clip_skip: int,
1535
  num_steps: int,
@@ -1551,6 +1601,8 @@ def generate_image(
1551
  Args:
1552
  prompt: Tags/keywords (CLIP input)
1553
  t5_summary: Natural language summary (T5 input, unless clip_include_summary)
 
 
1554
  use_separator: Use ¶ separator between tags and summary
1555
  clip_include_summary: If True, CLIP also sees the summary
1556
  """
@@ -1567,7 +1619,7 @@ def generate_image(
1567
  progress((step + 1) / total, desc=desc)
1568
 
1569
  try:
1570
- pipeline = get_pipeline(model_choice)
1571
 
1572
  # Update scheduler if needed (SDXL only)
1573
  is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
@@ -1710,6 +1762,21 @@ def create_demo():
1710
  value="Illustrious XL"
1711
  )
1712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1713
  scheduler_choice = gr.Dropdown(
1714
  label="Scheduler (SDXL only)",
1715
  choices=SCHEDULER_CHOICES,
@@ -1913,7 +1980,8 @@ def create_demo():
1913
  generate_btn.click(
1914
  fn=generate_image,
1915
  inputs=[
1916
- prompt, t5_summary, negative_prompt, model_choice, scheduler_choice, clip_skip,
 
1917
  num_steps, cfg_scale, width, height, shift,
1918
  use_flow_matching, use_lyra, lyra_strength, use_separator, clip_include_summary,
1919
  seed, randomize_seed
 
327
  # Config already prefetched
328
  config_dict = self._config
329
 
330
+ # Use provided checkpoint or find one
331
+ if self.checkpoint and self.checkpoint.strip():
332
+ checkpoint_file = self.checkpoint.strip()
333
+ # Add weights/ prefix if not present and file doesn't exist at root
334
+ if not checkpoint_file.startswith('weights/'):
335
+ checkpoint_file = f"weights/{checkpoint_file}"
336
+ print(f"[Lyra] Using specified checkpoint: {checkpoint_file}")
 
 
 
 
 
337
  else:
338
+ # Find checkpoint automatically
339
+ from huggingface_hub import list_repo_files
340
+
341
+ repo_files = list_repo_files(self.repo_id, repo_type="model")
342
+ checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')]
343
+
344
+ # Prefer weights/ folder
345
+ weights_files = [f for f in checkpoint_files if f.startswith('weights/')]
346
+ if weights_files:
347
+ checkpoint_file = sorted(weights_files)[-1] # Latest
348
+ elif checkpoint_files:
349
+ checkpoint_file = checkpoint_files[0]
350
+ else:
351
+ raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
352
+ print(f"[Lyra] Auto-selected checkpoint: {checkpoint_file}")
353
 
354
  checkpoint_path = hf_hub_download(
355
  repo_id=self.repo_id,
 
1304
 
1305
  def load_illustrious_xl(
1306
  repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
1307
+ filename: str = "",
1308
  device: str = "cuda"
1309
  ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
1310
  """Load Illustrious XL from single safetensors file."""
1311
  from diffusers import StableDiffusionXLPipeline
1312
 
1313
+ # Default checkpoint if none specified
1314
+ if not filename or not filename.strip():
1315
+ filename = "illustriousXL_v01.safetensors"
1316
+
1317
  print(f"📥 Loading Illustrious XL: {repo_id}/{filename}")
1318
 
1319
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
 
1390
  # PIPELINE INITIALIZATION
1391
  # ============================================================================
1392
 
1393
+ def initialize_pipeline(model_choice: str, device: str = "cuda", checkpoint: str = "", lyra_checkpoint: str = ""):
1394
  """Initialize the complete pipeline based on model choice.
1395
 
1396
  Uses lazy loading for T5 and Lyra - they won't be downloaded until first use.
1397
+
1398
+ Args:
1399
+ model_choice: Model selection from dropdown
1400
+ device: Target device
1401
+ checkpoint: Optional custom checkpoint filename (e.g., "my_model.safetensors")
1402
+ lyra_checkpoint: Optional custom Lyra VAE checkpoint filename
1403
  """
1404
 
1405
  print(f"🚀 Initializing {model_choice} pipeline...")
1406
+ if checkpoint:
1407
+ print(f" Custom model checkpoint: {checkpoint}")
1408
+ if lyra_checkpoint:
1409
+ print(f" Custom Lyra checkpoint: {lyra_checkpoint}")
1410
 
1411
  is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
1412
  is_lune = "Lune" in model_choice
 
1414
  if is_sdxl:
1415
  # SDXL-based models
1416
  if "Illustrious" in model_choice:
1417
+ unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(
1418
+ device=device,
1419
+ filename=checkpoint
1420
+ )
1421
  else:
1422
  unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
1423
 
 
1430
  )
1431
  lyra_loader = LazyLyraModel(
1432
  repo_id=LYRA_ILLUSTRIOUS_REPO,
1433
+ device=device,
1434
+ checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None
1435
  )
1436
 
1437
  # Default scheduler: Euler Ancestral
 
1475
  )
1476
  lyra_loader = LazyLyraModel(
1477
  repo_id=LYRA_SD15_REPO,
1478
+ device=device,
1479
+ checkpoint=lyra_checkpoint if lyra_checkpoint and lyra_checkpoint.strip() else None
1480
  )
1481
 
1482
  # Load UNet
1483
  if is_lune:
1484
  repo_id = "AbstractPhil/sd15-flow-lune"
1485
+ # Use custom checkpoint or default
1486
+ if checkpoint and checkpoint.strip():
1487
+ filename = checkpoint
1488
+ else:
1489
+ filename = "sd15_flow_lune_e34_s34000.pt"
1490
  unet = load_lune_checkpoint(repo_id, filename, device)
1491
  else:
1492
  unet = UNet2DConditionModel.from_pretrained(
 
1523
 
1524
  CURRENT_PIPELINE = None
1525
  CURRENT_MODEL = None
1526
+ CURRENT_CHECKPOINT = None
1527
+ CURRENT_LYRA_CHECKPOINT = None
1528
 
1529
 
1530
+ def get_pipeline(model_choice: str, checkpoint: str = "", lyra_checkpoint: str = ""):
1531
  """Get or create pipeline for selected model."""
1532
+ global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_CHECKPOINT, CURRENT_LYRA_CHECKPOINT
1533
+
1534
+ # Normalize empty values
1535
+ checkpoint = checkpoint.strip() if checkpoint else ""
1536
+ lyra_checkpoint = lyra_checkpoint.strip() if lyra_checkpoint else ""
1537
+
1538
+ # Reinitialize if model or any checkpoint changed
1539
+ if (CURRENT_PIPELINE is None or
1540
+ CURRENT_MODEL != model_choice or
1541
+ CURRENT_CHECKPOINT != checkpoint or
1542
+ CURRENT_LYRA_CHECKPOINT != lyra_checkpoint):
1543
+ CURRENT_PIPELINE = initialize_pipeline(
1544
+ model_choice, device="cuda",
1545
+ checkpoint=checkpoint,
1546
+ lyra_checkpoint=lyra_checkpoint
1547
+ )
1548
  CURRENT_MODEL = model_choice
1549
+ CURRENT_CHECKPOINT = checkpoint
1550
+ CURRENT_LYRA_CHECKPOINT = lyra_checkpoint
1551
 
1552
  return CURRENT_PIPELINE
1553
 
 
1570
 
1571
 
1572
  @spaces.GPU(duration=lambda *args: estimate_duration(
1573
+ args[8], args[10], args[11], args[14],
1574
  "SDXL" in args[3] or "Illustrious" in args[3]
1575
  ))
1576
  def generate_image(
 
1578
  t5_summary: str,
1579
  negative_prompt: str,
1580
  model_choice: str,
1581
+ checkpoint: str,
1582
+ lyra_checkpoint: str,
1583
  scheduler_choice: str,
1584
  clip_skip: int,
1585
  num_steps: int,
 
1601
  Args:
1602
  prompt: Tags/keywords (CLIP input)
1603
  t5_summary: Natural language summary (T5 input, unless clip_include_summary)
1604
+ checkpoint: Custom model checkpoint filename (empty for default)
1605
+ lyra_checkpoint: Custom Lyra VAE checkpoint filename (empty for default)
1606
  use_separator: Use ¶ separator between tags and summary
1607
  clip_include_summary: If True, CLIP also sees the summary
1608
  """
 
1619
  progress((step + 1) / total, desc=desc)
1620
 
1621
  try:
1622
+ pipeline = get_pipeline(model_choice, checkpoint, lyra_checkpoint)
1623
 
1624
  # Update scheduler if needed (SDXL only)
1625
  is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
 
1762
  value="Illustrious XL"
1763
  )
1764
 
1765
+ with gr.Accordion("Advanced Options", open=False):
1766
+ checkpoint = gr.Textbox(
1767
+ label="Model Checkpoint (optional)",
1768
+ value="",
1769
+ placeholder="e.g., illustriousXL_v01.safetensors",
1770
+ info="Leave empty for default. Illustrious: .safetensors, Lune: .pt"
1771
+ )
1772
+
1773
+ lyra_checkpoint = gr.Textbox(
1774
+ label="Lyra VAE Checkpoint (optional)",
1775
+ value="weights/lyra_illustrious_step_41000.safetensors",
1776
+ placeholder="e.g., lyra_e100_s50000.safetensors",
1777
+ info="Leave empty for latest. Loaded from weights/ folder in Lyra repo."
1778
+ )
1779
+
1780
  scheduler_choice = gr.Dropdown(
1781
  label="Scheduler (SDXL only)",
1782
  choices=SCHEDULER_CHOICES,
 
1980
  generate_btn.click(
1981
  fn=generate_image,
1982
  inputs=[
1983
+ prompt, t5_summary, negative_prompt, model_choice, checkpoint, lyra_checkpoint,
1984
+ scheduler_choice, clip_skip,
1985
  num_steps, cfg_scale, width, height, shift,
1986
  use_flow_matching, use_lyra, lyra_strength, use_separator, clip_include_summary,
1987
  seed, randomize_seed