AbstractPhil committed · verified
Commit 9348685 · 1 Parent(s): 757995e

Update app.py

Files changed (1)
  1. app.py +78 -36
app.py CHANGED
@@ -12,6 +12,7 @@ Lyra VAE Versions:
 """
 
 import os
+import json
 import torch
 import gradio as gr
 import numpy as np
@@ -901,17 +902,20 @@ def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
     print(f"🎡 Loading Lyra VAE v1 from {repo_id}...")
 
     try:
-        checkpoint_path = hf_hub_download(
-            repo_id=repo_id,
-            filename="best_model.pt",
-            repo_type="model"
-        )
-
-        checkpoint = torch.load(checkpoint_path, map_location="cpu")
-
-        if 'config' in checkpoint:
-            config_dict = checkpoint['config']
-        else:
+        # Try to download config.json first
+        try:
+            print(" 📥 Downloading config.json...")
+            config_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="config.json",
+                repo_type="model"
+            )
+            with open(config_path, 'r') as f:
+                config_dict = json.load(f)
+            print(f" ✓ Config loaded: {config_dict.get('fusion_strategy', 'unknown')} fusion")
+        except Exception:
+            # Fallback to defaults if no config.json
+            print(" ⚠️ No config.json found, using defaults")
             config_dict = {
                 'modality_dims': {"clip": 768, "t5": 768},
                 'latent_dim': 768,
@@ -925,6 +929,24 @@ def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
                 'fusion_dropout': 0.1
             }
 
+        # Download model weights
+        print(" 📥 Downloading model weights...")
+        try:
+            checkpoint_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="model.pt",
+                repo_type="model"
+            )
+        except Exception:
+            # Fallback to best_model.pt
+            checkpoint_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="best_model.pt",
+                repo_type="model"
+            )
+
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
         vae_config = LyraV1Config(
             modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}),
             latent_dim=config_dict.get('latent_dim', 768),
@@ -948,11 +970,18 @@ def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
         lyra_model.to(device)
         lyra_model.eval()
 
-        print(f"✅ Lyra VAE v1 (SD1.5) loaded")
+        print(f"✅ Lyra VAE v1 loaded")
+        print(f" Fusion: {config_dict.get('fusion_strategy')}")
+        print(f" Latent dim: {config_dict.get('latent_dim')}")
+        if 'global_step' in checkpoint:
+            print(f" Step: {checkpoint['global_step']:,}")
+
         return lyra_model
 
     except Exception as e:
         print(f"❌ Failed to load Lyra VAE v1: {e}")
+        import traceback
+        traceback.print_exc()
         return None
 
 
@@ -968,46 +997,52 @@ def load_lyra_vae_xl(
     print(f"🎡 Loading Lyra VAE v2 from {repo_id}...")
 
     try:
+        # Download config.json first to get model architecture
+        print(" 📥 Downloading config.json...")
+        config_path = hf_hub_download(
+            repo_id=repo_id,
+            filename="config.json",
+            repo_type="model"
+        )
+
+        with open(config_path, 'r') as f:
+            config_dict = json.load(f)
+
+        print(f" ✓ Config loaded: {config_dict.get('fusion_strategy', 'unknown')} fusion")
+
+        # Download model weights
+        print(" 📥 Downloading model.pt...")
         checkpoint_path = hf_hub_download(
             repo_id=repo_id,
-            filename="best_model.pt",
+            filename="model.pt",
             repo_type="model"
         )
 
         checkpoint = torch.load(checkpoint_path, map_location="cpu")
 
-        if 'config' in checkpoint:
-            config_dict = checkpoint['config']
-        else:
-            # XL v2 defaults - larger dimensions for SDXL
-            config_dict = {
-                'modality_dims': {"clip": 768, "t5": 2048},  # T5-XL
-                'latent_dim': 2048,
-                'seq_len': 77,
-                'encoder_layers': 4,
-                'decoder_layers': 4,
-                'hidden_dim': 2048,
-                'dropout': 0.1,
-                'fusion_strategy': 'adaptive_cantor',
-                'fusion_heads': 16,
-                'fusion_dropout': 0.1
-            }
-
+        # Build config from repo's config.json
         vae_config = LyraV2Config(
-            modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 2048}),
+            modality_dims=config_dict.get('modality_dims', {"clip_l": 768, "clip_g": 1280, "t5_xl_l": 2048, "t5_xl_g": 2048}),
+            modality_seq_lens=config_dict.get('modality_seq_lens', {"clip_l": 77, "clip_g": 77, "t5_xl_l": 512, "t5_xl_g": 512}),
+            binding_config=config_dict.get('binding_config'),
             latent_dim=config_dict.get('latent_dim', 2048),
             seq_len=config_dict.get('seq_len', 77),
-            encoder_layers=config_dict.get('encoder_layers', 4),
-            decoder_layers=config_dict.get('decoder_layers', 4),
+            encoder_layers=config_dict.get('encoder_layers', 3),
+            decoder_layers=config_dict.get('decoder_layers', 3),
             hidden_dim=config_dict.get('hidden_dim', 2048),
             dropout=config_dict.get('dropout', 0.1),
            fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
-            fusion_heads=config_dict.get('fusion_heads', 16),
-            fusion_dropout=config_dict.get('fusion_dropout', 0.1)
+            fusion_heads=config_dict.get('fusion_heads', 8),
+            fusion_dropout=config_dict.get('fusion_dropout', 0.1),
+            cantor_depth=config_dict.get('cantor_depth', 8),
+            cantor_local_window=config_dict.get('cantor_local_window', 3),
+            alpha_init=config_dict.get('alpha_init', 1.0),
+            beta_init=config_dict.get('beta_init', 0.3),
        )
 
         lyra_model = LyraV2(vae_config)
 
+        # Load weights from checkpoint
         if 'model_state_dict' in checkpoint:
             lyra_model.load_state_dict(checkpoint['model_state_dict'])
         else:
@@ -1016,14 +1051,21 @@ def load_lyra_vae_xl(
         lyra_model.to(device)
         lyra_model.eval()
 
-        print(f"✅ Lyra VAE v2 (SDXL) loaded")
+        print(f"✅ Lyra VAE v2 loaded")
+        print(f" Fusion: {config_dict.get('fusion_strategy')}")
+        print(f" Latent dim: {config_dict.get('latent_dim')}")
+        print(f" Hidden dim: {config_dict.get('hidden_dim')}")
         if 'global_step' in checkpoint:
             print(f" Step: {checkpoint['global_step']:,}")
+        if 'best_loss' in checkpoint:
+            print(f" Loss: {checkpoint['best_loss']:.4f}")
 
         return lyra_model
 
     except Exception as e:
         print(f"❌ Failed to load Lyra VAE v2: {e}")
+        import traceback
+        traceback.print_exc()
         return None
 
 
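For reference, a config.json matching the defaults the v2 loader falls back to would look roughly like the sketch below. It is reconstructed from the .get(...) fallbacks in this commit rather than copied from the published repo, so every field and value here is an assumption; the v1 loader reads the same kind of file with its smaller defaults.

import json

# Assumed schema: mirrors the defaults passed to LyraV2Config above, not the actual repo file.
v2_config = {
    "modality_dims": {"clip_l": 768, "clip_g": 1280, "t5_xl_l": 2048, "t5_xl_g": 2048},
    "modality_seq_lens": {"clip_l": 77, "clip_g": 77, "t5_xl_l": 512, "t5_xl_g": 512},
    "binding_config": None,
    "latent_dim": 2048,
    "seq_len": 77,
    "encoder_layers": 3,
    "decoder_layers": 3,
    "hidden_dim": 2048,
    "dropout": 0.1,
    "fusion_strategy": "adaptive_cantor",
    "fusion_heads": 8,
    "fusion_dropout": 0.1,
    "cantor_depth": 8,
    "cantor_local_window": 3,
    "alpha_init": 1.0,
    "beta_init": 0.3,
}

# Writing it with json.dump keeps the file compatible with the json.load call added above.
with open("config.json", "w") as f:
    json.dump(v2_config, f, indent=2)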
 
 
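The v1 loader resolves weights by trying model.pt and falling back to best_model.pt through nested try/except. If more filename candidates ever need to be supported, the same behaviour could be factored into a small helper; this is a hypothetical sketch (the helper name is invented, not part of app.py):

from huggingface_hub import hf_hub_download

def download_first_available(repo_id: str, filenames=("model.pt", "best_model.pt")) -> str:
    """Return a local path for the first candidate file that exists in the repo."""
    last_err = None
    for name in filenames:
        try:
            return hf_hub_download(repo_id=repo_id, filename=name, repo_type="model")
        except Exception as err:  # missing file, gated repo, or network error: try the next name
            last_err = err
    raise FileNotFoundError(f"None of {filenames} found in {repo_id}") from last_err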
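Both loaders unpack the checkpoint the same way: use checkpoint['model_state_dict'] when it exists, otherwise treat the file as a bare state dict. A minimal sketch of that step under the same assumption about checkpoint layout; unlike the strict default used in the commit, it passes strict=False so key mismatches are reported instead of raising:

import torch

def load_weights(model, checkpoint_path: str):
    """Load a wrapped checkpoint ({'model_state_dict': ...}) or a bare state dict into model."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get("model_state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"⚠️ {len(missing)} missing / {len(unexpected)} unexpected keys")
    return model.eval()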
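Calling code is unchanged by this commit: both loaders still return None on failure instead of raising. A usage sketch, assuming load_lyra_vae_xl keeps defaults analogous to load_lyra_vae (its full signature is not visible in this diff):

# Inside app.py, where both loader functions are defined.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

lyra_v1 = load_lyra_vae(device=device)  # defaults to AbstractPhil/vae-lyra
lyra_v2 = load_lyra_vae_xl()            # v2 repo id comes from the function's own defaults

# The loaders swallow exceptions and return None, so guard before using the models.
if lyra_v1 is None or lyra_v2 is None:
    raise RuntimeError("Lyra VAE failed to load; see the traceback printed above.")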