johnbridges committed
Commit be6d3d6 · 1 Parent(s): bf6d44e
Files changed (1)
  1. hf_backend.py +63 -40
hf_backend.py CHANGED
@@ -10,46 +10,29 @@ from config import settings
 
 logger = logging.getLogger(__name__)
 
-# ---------- logging helpers ----------
 def _snippet(txt: str, n: int = 800) -> str:
     if not isinstance(txt, str):
         return f"<non-str:{type(txt)}>"
     return txt if len(txt) <= n else txt[:n] + f"... <+{len(txt)-n} chars>"
 
-def _json_snippet(obj: Any, n: int = 800) -> str:
-    try:
-        s = json.dumps(obj, ensure_ascii=False, indent=2)
-    except Exception:
-        s = str(obj)
-    return _snippet(s, n)
-
-
-# ---------- HF Spaces imports ----------
 try:
     import spaces
     from spaces.zero import client as zero_client
 except ImportError:
     spaces, zero_client = None, None
 
-# ---------- Model setup ----------
 MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
 logger.info(f"[init] MODEL_ID={MODEL_ID}")
 
 tokenizer, load_error = None, None
 try:
-    tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_ID,
-        trust_remote_code=True,
-        use_fast=False,
-    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
     has_template = hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
     logger.info(f"[init] tokenizer loaded. chat_template={'yes' if has_template else 'no'}")
 except Exception as e:
     load_error = f"Failed to load tokenizer: {e}"
     logger.exception(load_error)
 
-
-# ---------- helpers ----------
 def _pick_cpu_dtype() -> torch.dtype:
     try:
         if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported") and torch.cpu.is_bf16_supported():
@@ -60,11 +43,8 @@ def _pick_cpu_dtype() -> torch.dtype:
     logger.info("[dtype] fallback -> torch.float32")
     return torch.float32
 
-
-# ---------- global cache ----------
 _MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
 
-
 def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
     key = (device, dtype)
     if key in _MODEL_CACHE:
@@ -120,8 +100,40 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
     _MODEL_CACHE[(device, eff_dtype)] = model
     return model, eff_dtype
 
+def _max_context(model, tokenizer) -> int:
+    # Prefer model config; fallback to tokenizer hint
+    mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
+    if isinstance(mc, int) and mc > 0:
+        return mc
+    tk = getattr(tokenizer, "model_max_length", None)
+    if isinstance(tk, int) and tk > 0 and tk < 10**12:
+        return tk
+    return 32768  # safe default for Qwen3
+
+def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int, model, tokenizer):
+    toks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+    input_ids = toks["input_ids"]
+    attn = toks.get("attention_mask", None)
+
+    ctx = _max_context(model, tokenizer)
+    limit = max(8, ctx - max_new_tokens)
+    in_len = input_ids.shape[-1]
+    if in_len > limit:
+        # left-truncate to fit context
+        cut = in_len - limit
+        input_ids = input_ids[:, -limit:]
+        if attn is not None:
+            attn = attn[:, -limit:]
+        logger.warning(f"[truncate] prompt_tokens={in_len} > limit={limit}. truncated_left_by={cut} to fit ctx={ctx}, new_input={input_ids.shape[-1]}, max_new={max_new_tokens}")
+
+    inputs = {"input_ids": input_ids}
+    if attn is not None:
+        inputs["attention_mask"] = attn
+
+    # move to device
+    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
+    return inputs, in_len, ctx, limit
 
-# ---------- Chat Backend ----------
 class HFChatBackend(ChatBackend):
     async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
         if load_error:
@@ -130,16 +142,15 @@ class HFChatBackend(ChatBackend):
         messages = request.get("messages", [])
         tools = request.get("tools")
         temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
-        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
+        req_max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
 
         rid = f"chatcmpl-hf-{int(time.time())}"
         now = int(time.time())
 
-        logger.info(f"[req] rid={rid} temp={temperature} max_tokens={max_tokens} "
+        logger.info(f"[req] rid={rid} temp={temperature} req_max_tokens={req_max_tokens} "
                     f"msgs={len(messages)} tools={'yes' if tools else 'no'} "
                     f"spaces={'yes' if spaces else 'no'} cuda={'yes' if torch.cuda.is_available() else 'no'}")
 
-        # X-IP-Token for ZeroGPU
         x_ip_token = request.get("x_ip_token")
         if x_ip_token and zero_client:
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
@@ -150,11 +161,11 @@ class HFChatBackend(ChatBackend):
         try:
             prompt = tokenizer.apply_chat_template(
                 messages,
-                tools=tools,
+                #tools=tools,
                 tokenize=False,
                 add_generation_prompt=True,
             )
-            logger.info(f"[prompt] built via chat_template. len={len(prompt)}\n{_snippet(prompt, 1200)}")
+            logger.info(f"[prompt] built via chat_template. len={len(prompt)}\n{_snippet(prompt, 800)}")
         except Exception as e:
             logger.warning(f"[prompt] chat_template failed -> fallback. err={e}")
             prompt = messages[-1]["content"] if messages else "(empty)"
@@ -166,11 +177,25 @@ class HFChatBackend(ChatBackend):
         def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
             model, eff_dtype = _get_model(device, req_dtype)
 
-            inputs = tokenizer(prompt, return_tensors="pt")
-            input_ids = inputs["input_ids"]
-            logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={input_ids.shape[-1]}")
+            # Clamp max_new_tokens for CPU to prevent stalls
+            if device == "cpu":
+                max_new_tokens = min(req_max_tokens, 512)
+            else:
+                max_new_tokens = req_max_tokens
+
+            # Build inputs with context-aware truncation
+            inputs, orig_in_len, ctx, limit = _build_inputs_with_truncation(prompt, device, max_new_tokens, model, tokenizer)
+
+            logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={inputs['input_ids'].shape[-1]} "
+                        f"(orig={orig_in_len}) max_ctx={ctx} limit_for_input={limit} max_new_tokens={max_new_tokens}")
 
-            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
+            # Sampling settings
+            do_sample = temperature > 1e-6
+            temp = max(1e-5, temperature) if do_sample else 0.0
+
+            # ids
+            eos_id = tokenizer.eos_token_id
+            pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id
 
             with torch.inference_mode():
                 if device != "cpu":
@@ -179,25 +204,26 @@ class HFChatBackend(ChatBackend):
                     autocast_ctx = torch.cpu.amp.autocast(dtype=torch.bfloat16) if eff_dtype == torch.bfloat16 else nullcontext()
 
                 gen_kwargs = dict(
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    do_sample=True,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temp,
+                    do_sample=do_sample,
                     use_cache=True,
+                    eos_token_id=eos_id,
+                    pad_token_id=pad_id,
                 )
                 logger.info(f"[gen] kwargs={gen_kwargs}")
 
                 with autocast_ctx:
                     outputs = model.generate(**inputs, **gen_kwargs)
 
-            # Only decode newly generated tokens
-            input_len = input_ids.shape[-1]
+            # Slice generated continuation only
+            input_len = inputs["input_ids"].shape[-1]
             generated_ids = outputs[0][input_len:]
             logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
             text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
             logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
             return text
 
-        # Dispatch with or without ZeroGPU
         if spaces:
             @spaces.GPU(duration=120)
             def run_once(prompt: str) -> str:
@@ -206,13 +232,11 @@ class HFChatBackend(ChatBackend):
                     return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                 logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
                 return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
-
             text = run_once(prompt)
         else:
             logger.info("[path] CPU-only runtime")
             text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
 
-        # Emit single OpenAI-style chunk
         chunk = {
             "id": rid,
             "object": "chat.completion.chunk",
@@ -226,7 +250,6 @@ class HFChatBackend(ChatBackend):
         yield chunk
 
 
-# ---------- Stub Images Backend ----------
 class StubImagesBackend(ImagesBackend):
     async def generate_b64(self, request: Dict[str, Any]) -> str:
         logger.warning("Image generation not supported in HF backend.")
 