Commit · d76b941
Parent(s): d279e64

hf_backend.py  CHANGED  (+35 -22)
@@ -1,6 +1,7 @@
 # hf_backend.py
 import time, logging
-from …
+from contextlib import nullcontext
+from typing import Any, Dict, AsyncIterable, Tuple

 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

@@ -33,6 +34,7 @@ except Exception as e:

 # ---------------- helpers ----------------
 def _pick_cpu_dtype() -> torch.dtype:
+    # Prefer BF16 if CPU supports it
     if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
         try:
             if torch.cpu.is_bf16_supported():

@@ -45,19 +47,22 @@ def _pick_cpu_dtype() -> torch.dtype:


 # ---------------- global cache ----------------
-_MODEL_CACHE: …
+_MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}


-def _get_model(device: str, dtype: torch.dtype):
-    …
-    …
-    …
+def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
+    # Return model and the effective dtype actually loaded with
+    # (handles CPU BF16 -> FP32 fallback)
+    effective_key = (device, dtype)
+    if effective_key in _MODEL_CACHE:
+        return _MODEL_CACHE[effective_key], dtype

     cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
     if hasattr(cfg, "quantization_config"):
         logger.warning("Removing quantization_config from model config")
-        delattr(cfg, "quantization_config")
+        delattr(cfg, "quantization_config") # delete instead of setting None

+    eff_dtype = dtype
     try:
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,

@@ -69,20 +74,20 @@ def _get_model(device: str, dtype: torch.dtype):
     except Exception as e:
         if device == "cpu" and dtype == torch.bfloat16:
             logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
+            eff_dtype = torch.float32
             model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
                 config=cfg,
-                torch_dtype=…
+                torch_dtype=eff_dtype,
                 trust_remote_code=True,
                 device_map={"": "cpu"},
             )
-            dtype = torch.float32
         else:
             raise

     model.eval()
-    _MODEL_CACHE[(device, …
-    return model
+    _MODEL_CACHE[(device, eff_dtype)] = model
+    return model, eff_dtype


 # ---------------- Chat Backend ----------------

@@ -105,7 +110,7 @@ class HFChatBackend(ChatBackend):
             logger.debug("Injected X-IP-Token into ZeroGPU headers")

         # Build prompt using chat template if available
-        if hasattr(tokenizer, "apply_chat_template") and tokenizer…
+        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
             try:
                 prompt = tokenizer.apply_chat_template(
                     messages,

@@ -119,15 +124,20 @@ class HFChatBackend(ChatBackend):
         else:
             prompt = messages[-1]["content"] if messages else "(empty)"

-        def _run_once(prompt: str, device: str, …
-            model = _get_model(device, …
-            …
+        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
+            model, eff_dtype = _get_model(device, req_dtype)
+
+            inputs = tokenizer(prompt, return_tensors="pt")
+            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

             with torch.inference_mode():
                 if device != "cpu":
-                    autocast_ctx = torch.autocast(device_type=device, dtype=…
+                    autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
                 else:
-                    …
+                    if eff_dtype == torch.bfloat16:
+                        autocast_ctx = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+                    else:
+                        autocast_ctx = nullcontext()

                 with autocast_ctx:
                     outputs = model.generate(

@@ -135,21 +145,24 @@ class HFChatBackend(ChatBackend):
                         max_new_tokens=max_tokens,
                         temperature=temperature,
                         do_sample=True,
+                        use_cache=True,
                     )

             return tokenizer.decode(outputs[0], skip_special_tokens=True)

         if spaces:
-            # …
+            # Always dispatch via ZeroGPU decorator if available.
             @spaces.GPU(duration=120)
             def run_once(prompt: str) -> str:
-                …
+                if torch.cuda.is_available():
+                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
+                # Fallback to CPU inside the GPU context if CUDA is unavailable
+                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

             text = run_once(prompt)
         else:
-            # …
-            …
-            text = _run_once(prompt, device="cpu", dtype=dtype)
+            # CPU-only runtime
+            text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

         yield {
             "id": rid,
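The caching and dtype-fallback pattern introduced in _get_model is easier to trace in isolation. The sketch below is not part of the commit: it swaps AutoModelForCausalLM.from_pretrained for a hypothetical stub loader (_load_stub, with a local _CACHE and get_model) so the control flow runs without downloading a model. It also makes one consequence of the committed code visible: the cache is keyed by the effective dtype, so after a BF16 -> FP32 fallback a repeated (cpu, bfloat16) request misses the cache and retries the BF16 load.

# Minimal sketch (assumed names, not the commit's code) of the
# (device, dtype) cache plus CPU BF16 -> FP32 fallback used by _get_model.
from typing import Dict, Tuple

import torch

_CACHE: Dict[Tuple[str, torch.dtype], object] = {}


def _load_stub(device: str, dtype: torch.dtype):
    # Hypothetical stand-in for AutoModelForCausalLM.from_pretrained(...):
    # pretend BF16 is unsupported on CPU so the fallback path is exercised.
    if device == "cpu" and dtype == torch.bfloat16:
        raise RuntimeError("BF16 load failed (simulated)")
    return f"model[{device}/{dtype}]"


def get_model(device: str, dtype: torch.dtype) -> Tuple[object, torch.dtype]:
    key = (device, dtype)
    if key in _CACHE:
        return _CACHE[key], dtype

    eff_dtype = dtype
    try:
        model = _load_stub(device, dtype)
    except RuntimeError:
        if device == "cpu" and dtype == torch.bfloat16:
            eff_dtype = torch.float32  # remember what was actually loaded
            model = _load_stub(device, eff_dtype)
        else:
            raise

    # Cache under the *effective* dtype, mirroring _MODEL_CACHE[(device, eff_dtype)]
    _CACHE[(device, eff_dtype)] = model
    return model, eff_dtype


print(get_model("cpu", torch.bfloat16))  # ('model[cpu/torch.float32]', torch.float32) via fallback
print(get_model("cpu", torch.bfloat16))  # falls back again: stored only under the float32 key
print(get_model("cpu", torch.float32))   # cache hit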