johnbridges committed
Commit d76b941 · 1 Parent(s): d279e64
Files changed (1)
  1. hf_backend.py +35 -22
hf_backend.py CHANGED
@@ -1,6 +1,7 @@
 # hf_backend.py
 import time, logging
-from typing import Any, Dict, AsyncIterable
+from contextlib import nullcontext
+from typing import Any, Dict, AsyncIterable, Tuple

 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
@@ -33,6 +34,7 @@ except Exception as e:

 # ---------------- helpers ----------------
 def _pick_cpu_dtype() -> torch.dtype:
+    # Prefer BF16 if the CPU supports it
     if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
         try:
             if torch.cpu.is_bf16_supported():
@@ -45,19 +47,22 @@ def _pick_cpu_dtype() -> torch.dtype:


 # ---------------- global cache ----------------
-_MODEL_CACHE: dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
+_MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}


-def _get_model(device: str, dtype: torch.dtype):
-    key = (device, dtype)
-    if key in _MODEL_CACHE:
-        return _MODEL_CACHE[key]
+def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
+    # Return the model and the effective dtype it was actually loaded with
+    # (handles the CPU BF16 -> FP32 fallback).
+    effective_key = (device, dtype)
+    if effective_key in _MODEL_CACHE:
+        return _MODEL_CACHE[effective_key], dtype

     cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
     if hasattr(cfg, "quantization_config"):
         logger.warning("Removing quantization_config from model config")
-        delattr(cfg, "quantization_config")
+        delattr(cfg, "quantization_config")  # delete instead of setting None

+    eff_dtype = dtype
     try:
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
@@ -69,20 +74,20 @@ def _get_model(device: str, dtype: torch.dtype):
     except Exception as e:
         if device == "cpu" and dtype == torch.bfloat16:
             logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
+            eff_dtype = torch.float32
             model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
                 config=cfg,
-                torch_dtype=torch.float32,
+                torch_dtype=eff_dtype,
                 trust_remote_code=True,
                 device_map={"": "cpu"},
             )
-            dtype = torch.float32
         else:
             raise

     model.eval()
-    _MODEL_CACHE[(device, dtype)] = model
-    return model
+    _MODEL_CACHE[(device, eff_dtype)] = model
+    return model, eff_dtype


 # ---------------- Chat Backend ----------------
@@ -105,7 +110,7 @@ class HFChatBackend(ChatBackend):
             logger.debug("Injected X-IP-Token into ZeroGPU headers")

         # Build prompt using chat template if available
-        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
             try:
                 prompt = tokenizer.apply_chat_template(
                     messages,
@@ -119,15 +124,20 @@ class HFChatBackend(ChatBackend):
         else:
             prompt = messages[-1]["content"] if messages else "(empty)"

-        def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
-            model = _get_model(device, dtype)
-            inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
+            model, eff_dtype = _get_model(device, req_dtype)
+
+            inputs = tokenizer(prompt, return_tensors="pt")
+            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

             with torch.inference_mode():
                 if device != "cpu":
-                    autocast_ctx = torch.autocast(device_type=device, dtype=dtype)
+                    autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
                 else:
-                    autocast_ctx = torch.cpu.amp.autocast(dtype=dtype)
+                    if eff_dtype == torch.bfloat16:
+                        autocast_ctx = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+                    else:
+                        autocast_ctx = nullcontext()

                 with autocast_ctx:
                     outputs = model.generate(
@@ -135,21 +145,24 @@ class HFChatBackend(ChatBackend):
                         max_new_tokens=max_tokens,
                         temperature=temperature,
                         do_sample=True,
+                        use_cache=True,
                     )

             return tokenizer.decode(outputs[0], skip_special_tokens=True)

         if spaces:
-            # --- GPU path with ZeroGPU ---
+            # Always dispatch via the ZeroGPU decorator if available.
             @spaces.GPU(duration=120)
             def run_once(prompt: str) -> str:
-                return _run_once(prompt, device="cuda", dtype=torch.float16)
+                if torch.cuda.is_available():
+                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
+                # Fall back to CPU inside the GPU context if CUDA is unavailable
+                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

             text = run_once(prompt)
         else:
-            # --- CPU-only fallback with auto dtype detection ---
-            dtype = _pick_cpu_dtype()
-            text = _run_once(prompt, device="cpu", dtype=dtype)
+            # CPU-only runtime
+            text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

         yield {
             "id": rid,