johnbridges committed on
Commit 7471f75 · 1 Parent(s): 7f43efb
Files changed (1)
  1. hf_backend.py +57 -65
hf_backend.py CHANGED
@@ -1,6 +1,6 @@
-# hf_backend.py
-import time, logging, os, contextlib
-from typing import Any, Dict, AsyncIterable, List
+# hf_backend.py (patched)
+import time, logging, os
+from typing import Any, Dict, AsyncIterable
 
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -9,23 +9,21 @@ from config import settings
 
 try:
     import spaces
+    from spaces.zero.client import SpaceZeroClient
 except ImportError:
-    spaces = None
+    spaces, SpaceZeroClient = None, None
 
 logger = logging.getLogger(__name__)
 
-# --- Load model/tokenizer on CPU at import time (ZeroGPU safe) ---
 MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
 logger.info(f"Loading {MODEL_ID} on CPU at startup (ZeroGPU safe)...")
 
-tokenizer = None
-model = None
-load_error = None
+tokenizer, model, load_error = None, None, None
 try:
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
-        torch_dtype=torch.float32,  # CPU-safe default
+        torch_dtype=torch.float32,
         trust_remote_code=True,
     )
     model.eval()
@@ -34,11 +32,7 @@ except Exception as e:
     logger.exception(load_error)
 
 
-# --- Device helpers ---
 def pick_device() -> str:
-    forced = os.getenv("FORCE_DEVICE", "").lower().strip()
-    if forced in {"cpu", "cuda", "mps"}:
-        return forced
     if torch.cuda.is_available():
         return "cuda"
     if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
@@ -54,7 +48,6 @@ def pick_dtype(device: str) -> torch.dtype:
     return torch.float32
 
 
-# --- Backend class ---
 class HFChatBackend(ChatBackend):
     async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
         if load_error:
@@ -68,55 +61,54 @@ class HFChatBackend(ChatBackend):
         rid = f"chatcmpl-hf-{int(time.time())}"
         now = int(time.time())
 
-        if spaces:
-            @spaces.GPU(duration=120)  # allow longer run
-            def run_once(prompt: str) -> str:
-                device = pick_device()
-                dtype = pick_dtype(device)
-
-                # Move model to GPU if needed
-                model.to(device=device, dtype=dtype).eval()
-
-                inputs = tokenizer(prompt, return_tensors="pt").to(device)
-                with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
-                    outputs = model.generate(
-                        **inputs,
-                        max_new_tokens=max_tokens,
-                        temperature=temperature,
-                        do_sample=True,
-                    )
-                return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # --- ✅ Extract X-IP-Token from RabbitMQ message
+        x_ip_token = request.get("x_ip_token")
+        headers = {}
+        if x_ip_token:
+            headers["X-IP-Token"] = x_ip_token
+            logger.info("Using X-IP-Token from request for ZeroGPU attribution")
+
+        def _gpu_inference_fn(prompt: str) -> str:
+            device = pick_device()
+            dtype = pick_dtype(device)
+            model.to(device=device, dtype=dtype).eval()
+
+            inputs = tokenizer(prompt, return_tensors="pt").to(device)
+            with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True,
+                )
+            return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        if spaces and SpaceZeroClient:
+            # Use a custom SpaceZeroClient with headers
+            client = SpaceZeroClient(headers=headers or None)
+            try:
+                text = await client.run(_gpu_inference_fn, args=[prompt], duration=120)
+            except Exception:
+                logger.exception("HF inference (ZeroGPU) failed")
+                raise
         else:
-            def run_once(prompt: str) -> str:
-                inputs = tokenizer(prompt, return_tensors="pt")
-                with torch.inference_mode():
-                    outputs = model.generate(
-                        **inputs,
-                        max_new_tokens=max_tokens,
-                        temperature=temperature,
-                        do_sample=True,
-                    )
-                return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        try:
-            text = run_once(prompt)
-            yield {
-                "id": rid,
-                "object": "chat.completion.chunk",
-                "created": now,
-                "model": MODEL_ID,
-                "choices": [
-                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
-                ],
-            }
-        except Exception:
-            logger.exception("HF inference failed")
-            raise
-
-
-class StubImagesBackend(ImagesBackend):
-    async def generate_b64(self, request: Dict[str, Any]) -> str:
-        logger.warning("Image generation not supported in HF backend.")
-        return (
-            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
-        )
+            # CPU fallback
+            inputs = tokenizer(prompt, return_tensors="pt")
+            with torch.inference_mode():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True,
+                )
+            text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        yield {
+            "id": rid,
+            "object": "chat.completion.chunk",
+            "created": now,
+            "model": MODEL_ID,
+            "choices": [
+                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+            ],
+        }
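
For context, a minimal consumer sketch of the patched backend follows. It assumes the request dict is the already-decoded RabbitMQ payload, that hf_backend.py is importable as patched above, and that the prompt/max_tokens/temperature handling omitted from this diff reads the usual OpenAI-style request fields; the field values and the asyncio driver are illustrative, not part of the commit.

# consume_stream_sketch.py -- illustrative only, not part of this commit.
# Assumes hf_backend.HFChatBackend as patched above and an OpenAI-style
# request payload; the x_ip_token value would come from the caller's
# X-IP-Token header before the message is queued.
import asyncio

from hf_backend import HFChatBackend


async def main() -> None:
    request = {
        "model": "Qwen/Qwen2.5-1.5B-Instruct",
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 64,
        "temperature": 0.7,
        # Forwarded so ZeroGPU usage is attributed to the end user.
        "x_ip_token": "<value of the caller's X-IP-Token header>",
    }
    backend = HFChatBackend()
    # stream() is an async generator yielding chat.completion.chunk dicts.
    async for chunk in backend.stream(request):
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())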