johnbridges committed
Commit 849364d · 1 Parent(s): 1d79762
Files changed (1)
  1. hf_backend.py +27 -15
hf_backend.py CHANGED
@@ -45,6 +45,32 @@ def _pick_cpu_dtype() -> torch.dtype:
     return torch.float32
 
 
+# ---------------- global cache ----------------
+_MODEL_CACHE: dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
+
+
+def _get_model(device: str, dtype: torch.dtype):
+    key = (device, dtype)
+    if key in _MODEL_CACHE:
+        return _MODEL_CACHE[key]
+
+    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+    if hasattr(cfg, "quantization_config"):
+        logger.warning("Removing quantization_config from model config")
+        delattr(cfg, "quantization_config")  # delete instead of setting None
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        config=cfg,
+        torch_dtype=dtype,
+        trust_remote_code=True,
+        device_map="auto" if device != "cpu" else {"": "cpu"},
+    )
+    model.eval()
+    _MODEL_CACHE[key] = model
+    return model
+
+
 # ---------------- Chat Backend ----------------
 class HFChatBackend(ChatBackend):
     async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
@@ -66,21 +92,7 @@ class HFChatBackend(ChatBackend):
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
         def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
-            # Load config and strip any quantization settings (fix FP8 issue)
-            cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
-            if hasattr(cfg, "quantization_config"):
-                logger.warning("Removing quantization_config from model config")
-                cfg.quantization_config = None
-
-            model = AutoModelForCausalLM.from_pretrained(
-                MODEL_ID,
-                config=cfg,
-                torch_dtype=dtype,
-                trust_remote_code=True,
-                device_map="auto" if device != "cpu" else {"": "cpu"},
-            )
-            model.eval()
-
+            model = _get_model(device, dtype)
             inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
             with torch.inference_mode():
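
The net effect of the commit is that the AutoModelForCausalLM.from_pretrained call is memoized on a (device, dtype) key, so repeated calls to _run_once reuse the already-loaded model instead of reloading it on every request. A minimal, self-contained sketch of the same pattern, with a dummy load_model standing in for the real (expensive) loader — the names below are illustrative only and are not part of hf_backend.py:

    from typing import Any, Dict, Tuple

    _CACHE: Dict[Tuple[str, str], Any] = {}

    def load_model(device: str, dtype: str) -> Any:
        # Stand-in for the expensive from_pretrained call.
        print(f"loading for {device}/{dtype}")
        return object()

    def get_model(device: str, dtype: str) -> Any:
        key = (device, dtype)
        if key not in _CACHE:           # first call pays the load cost
            _CACHE[key] = load_model(device, dtype)
        return _CACHE[key]              # later calls return the cached object

    assert get_model("cpu", "float32") is get_model("cpu", "float32")

The other behavioural change is delattr(cfg, "quantization_config") in place of assigning None: code that gates on hasattr(cfg, "quantization_config") still sees the attribute when it is merely set to None, whereas deleting it makes that check fail as intended.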