Commit · b416f51
1 Parent(s): 552430d

- hf_backend.py  +6 -13

hf_backend.py  CHANGED
@@ -1,5 +1,5 @@
 # hf_backend.py
-import time, logging, json
+import time, logging, json, asyncio
 from contextlib import nullcontext
 from typing import Any, Dict, AsyncIterable, Tuple
 
@@ -101,7 +101,6 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
     return model, eff_dtype
 
 def _max_context(model, tokenizer) -> int:
-    # Prefer model config; fallback to tokenizer hint
     mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
     if isinstance(mc, int) and mc > 0:
         return mc
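For orientation, a minimal sketch of how a helper like _max_context can resolve the limit end to end, assuming the fallback hinted at by the removed comment is the tokenizer's model_max_length (only the first branch appears in this hunk; the name max_context_sketch and the default value are invented):

def max_context_sketch(model, tokenizer, default: int = 2048) -> int:
    # Prefer the model config's positional limit when it is a sane integer.
    mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if isinstance(mc, int) and mc > 0:
        return mc
    # Fallback: tokenizer hint; some tokenizers report a huge sentinel value instead.
    tm = getattr(tokenizer, "model_max_length", None)
    if isinstance(tm, int) and 0 < tm < 10**9:
        return tm
    return default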
@@ -119,7 +118,6 @@ def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int,
     limit = max(8, ctx - max_new_tokens)
     in_len = input_ids.shape[-1]
     if in_len > limit:
-        # left-truncate to fit context
         cut = in_len - limit
         input_ids = input_ids[:, -limit:]
         if attn is not None:
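The truncation keeps the newest tokens when prompt length plus max_new_tokens would overflow the context window. A standalone sketch of the same arithmetic (function name invented, torch assumed):

import torch

def left_truncate(input_ids: torch.Tensor, ctx: int, max_new_tokens: int) -> torch.Tensor:
    # Reserve room for generation, but never allow fewer than 8 prompt tokens.
    limit = max(8, ctx - max_new_tokens)
    if input_ids.shape[-1] > limit:
        # Drop the oldest tokens; keep only the most recent `limit` tokens.
        input_ids = input_ids[:, -limit:]
    return input_ids

# Example: ctx=4096, max_new_tokens=512 -> limit=3584,
# so a 4000-token prompt keeps only its last 3584 tokens.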
@@ -130,7 +128,6 @@ def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int,
     if attn is not None:
         inputs["attention_mask"] = attn
 
-    # move to device
     inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
     return inputs, in_len, ctx, limit
 
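The hasattr(v, "to") guard lets non-tensor entries pass through when the batch is moved to the target device. A tiny illustration with invented values:

import torch

inputs = {"input_ids": torch.tensor([[1, 2, 3]]), "some_flag": True}  # invented example
device = "cpu"
# Tensors are moved; plain Python values are left untouched.
inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}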
@@ -156,7 +153,7 @@ class HFChatBackend(ChatBackend):
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.info("[req] injected X-IP-Token into ZeroGPU headers")
 
-        # Build prompt
+        # Build prompt (pass tools to template)
        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
             try:
                 prompt = tokenizer.apply_chat_template(
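The updated comment suggests the template call now forwards a tool list. In recent transformers versions apply_chat_template accepts a tools= argument; a hedged sketch of such a call, assuming tokenizer is this backend's already-loaded tokenizer and with the messages and tool schema invented for illustration:

messages = [{"role": "user", "content": "What is the weather in Paris?"}]
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]
prompt = tokenizer.apply_chat_template(
    messages,
    tools=tools,                   # only honored by chat templates that support tools
    tokenize=False,
    add_generation_prompt=True,
)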
@@ -176,20 +173,16 @@ class HFChatBackend(ChatBackend):
 
         def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
             model, eff_dtype = _get_model(device, req_dtype)
-
             max_new_tokens = req_max_tokens
 
-            # Build inputs with context-aware truncation
             inputs, orig_in_len, ctx, limit = _build_inputs_with_truncation(prompt, device, max_new_tokens, model, tokenizer)
 
             logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={inputs['input_ids'].shape[-1]} "
                         f"(orig={orig_in_len}) max_ctx={ctx} limit_for_input={limit} max_new_tokens={max_new_tokens}")
 
-            # Sampling settings
             do_sample = temperature > 1e-6
             temp = max(1e-5, temperature) if do_sample else 0.0
 
-            # ids
             eos_id = tokenizer.eos_token_id
             pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id
 
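The do_sample / temp pair presumably feeds the gen_kwargs passed to model.generate a few lines later (that block sits between hunks and is not shown). One consistent way to assemble it, as an assumption rather than the file's exact code:

gen_kwargs = dict(
    max_new_tokens=max_new_tokens,
    do_sample=do_sample,              # greedy decoding when temperature is ~0
    eos_token_id=eos_id,
    pad_token_id=pad_id,
)
if do_sample:
    gen_kwargs["temperature"] = temp  # clamped away from 0 to keep sampling stable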
@@ -212,7 +205,6 @@ class HFChatBackend(ChatBackend):
             with autocast_ctx:
                 outputs = model.generate(**inputs, **gen_kwargs)
 
-            # Slice generated continuation only
             input_len = inputs["input_ids"].shape[-1]
             generated_ids = outputs[0][input_len:]
             logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
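Slicing outputs[0][input_len:] isolates the continuation because generate returns the prompt tokens followed by the new ones. The decode that produces text falls outside this hunk; a plausible one-liner, stated as an assumption:

text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()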
@@ -220,18 +212,19 @@ class HFChatBackend(ChatBackend):
             logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
             return text
 
+        # Offload heavy work to a worker thread so asyncio heartbeats continue
         if spaces:
             @spaces.GPU(duration=120)
-            def
+            def run_once_sync(prompt: str) -> str:
                 if torch.cuda.is_available():
                     logger.info("[path] ZeroGPU + CUDA")
                     return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                 logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
                 return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
-            text =
+            text = await asyncio.to_thread(run_once_sync, prompt)
         else:
             logger.info("[path] CPU-only runtime")
-            text = _run_once
+            text = await asyncio.to_thread(_run_once, prompt, "cpu", _pick_cpu_dtype())
 
         chunk = {
             "id": rid,
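This is the substantive change in the commit: the blocking generation call is wrapped in asyncio.to_thread, so the event loop (and any keep-alive/heartbeat work) stays responsive while model.generate runs in a worker thread. A self-contained sketch of the pattern with invented names:

import asyncio, time

def blocking_generate(prompt: str) -> str:
    time.sleep(5)                      # stands in for model.generate(...)
    return f"response to: {prompt}"

async def heartbeat() -> None:
    while True:
        print("heartbeat")             # e.g. an SSE keep-alive chunk
        await asyncio.sleep(1)

async def handle_request(prompt: str) -> str:
    hb = asyncio.create_task(heartbeat())
    try:
        # Offload the blocking call; the loop keeps running heartbeat() meanwhile.
        return await asyncio.to_thread(blocking_generate, prompt)
    finally:
        hb.cancel()

# asyncio.run(handle_request("hello"))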