johnbridges committed
Commit b416f51 · 1 Parent(s): 552430d
Files changed (1)
  1. hf_backend.py +6 -13
hf_backend.py CHANGED
@@ -1,5 +1,5 @@
  # hf_backend.py
- import time, logging, json
+ import time, logging, json, asyncio
  from contextlib import nullcontext
  from typing import Any, Dict, AsyncIterable, Tuple

@@ -101,7 +101,6 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
      return model, eff_dtype

  def _max_context(model, tokenizer) -> int:
-     # Prefer model config; fallback to tokenizer hint
      mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
      if isinstance(mc, int) and mc > 0:
          return mc
@@ -119,7 +118,6 @@ def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int,
      limit = max(8, ctx - max_new_tokens)
      in_len = input_ids.shape[-1]
      if in_len > limit:
-         # left-truncate to fit context
          cut = in_len - limit
          input_ids = input_ids[:, -limit:]
          if attn is not None:
@@ -130,7 +128,6 @@ def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int,
      if attn is not None:
          inputs["attention_mask"] = attn

-     # move to device
      inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
      return inputs, in_len, ctx, limit

@@ -156,7 +153,7 @@ class HFChatBackend(ChatBackend):
              zero_client.HEADERS["X-IP-Token"] = x_ip_token
              logger.info("[req] injected X-IP-Token into ZeroGPU headers")

-         # Build prompt
+         # Build prompt (pass tools to template)
          if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
              try:
                  prompt = tokenizer.apply_chat_template(
@@ -176,20 +173,16 @@

          def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
              model, eff_dtype = _get_model(device, req_dtype)
-
              max_new_tokens = req_max_tokens

-             # Build inputs with context-aware truncation
              inputs, orig_in_len, ctx, limit = _build_inputs_with_truncation(prompt, device, max_new_tokens, model, tokenizer)

              logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={inputs['input_ids'].shape[-1]} "
                          f"(orig={orig_in_len}) max_ctx={ctx} limit_for_input={limit} max_new_tokens={max_new_tokens}")

-             # Sampling settings
              do_sample = temperature > 1e-6
              temp = max(1e-5, temperature) if do_sample else 0.0

-             # ids
              eos_id = tokenizer.eos_token_id
              pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

@@ -212,7 +205,6 @@
              with autocast_ctx:
                  outputs = model.generate(**inputs, **gen_kwargs)

-             # Slice generated continuation only
              input_len = inputs["input_ids"].shape[-1]
              generated_ids = outputs[0][input_len:]
              logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
@@ -220,18 +212,19 @@
              logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
              return text

+         # Offload heavy work to a worker thread so asyncio heartbeats continue
          if spaces:
              @spaces.GPU(duration=120)
-             def run_once(prompt: str) -> str:
+             def run_once_sync(prompt: str) -> str:
                  if torch.cuda.is_available():
                      logger.info("[path] ZeroGPU + CUDA")
                      return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                  logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
                  return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
-             text = run_once(prompt)
+             text = await asyncio.to_thread(run_once_sync, prompt)
          else:
              logger.info("[path] CPU-only runtime")
-             text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
+             text = await asyncio.to_thread(_run_once, prompt, "cpu", _pick_cpu_dtype())

          chunk = {
              "id": rid,
 
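For reference, the context hunks show _build_inputs_with_truncation reserving room for max_new_tokens and left-truncating the prompt when it would overflow the model's context window. A small sketch of that arithmetic with illustrative values (the real helper also takes model and tokenizer; this stripped-down version is not the file's exact code):

import torch

def left_truncate(input_ids: torch.Tensor, ctx: int, max_new_tokens: int) -> torch.Tensor:
    # Reserve space for generation; keep at least 8 prompt tokens.
    limit = max(8, ctx - max_new_tokens)
    if input_ids.shape[-1] > limit:
        # Drop the oldest tokens, keep the most recent ones.
        input_ids = input_ids[:, -limit:]
    return input_ids

ids = torch.arange(5000).unsqueeze(0)  # a pretend 5000-token prompt
print(left_truncate(ids, ctx=4096, max_new_tokens=512).shape)  # torch.Size([1, 3584])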
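The reworded comment, "Build prompt (pass tools to template)", refers to the branch that prefers the tokenizer's chat template. A hedged sketch of that shape; the fallback body sits outside the hunks shown, so the plain-join branch below is an assumption:

def build_prompt(tokenizer, messages, tools=None) -> str:
    # Prefer the tokenizer's chat template when the model ships one.
    if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tools=tools,                 # forwarded per the new comment
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            pass  # fall through to the naive join below
    # Assumed fallback: plain concatenation of message contents.
    return "\n".join(str(m.get("content", "")) for m in messages)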