# GradLLM/timesfm_backend.py
import time
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from backends_base import ChatBackend, ImagesBackend
from config import settings
logger = logging.getLogger(__name__)
# ---------- helpers ----------
def _parse_series(series: Any) -> np.ndarray:
"""
Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'.
Returns: 1D float32 numpy array.
"""
if series is None:
raise ValueError("series is required")
if isinstance(series, dict):
series = series.get("values") or series.get("y")
vals: List[float] = []
if isinstance(series, (list, tuple)):
if series and isinstance(series[0], dict):
            for item in series:
                if "y" in item:
                    vals.append(float(item["y"]))
                elif "value" in item:
                    vals.append(float(item["value"]))
                else:
                    # silently dropping a malformed point would shift the series
                    raise ValueError("each series item needs a 'y' or 'value' key")
else:
vals = [float(x) for x in series]
else:
raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
if not vals:
raise ValueError("series is empty")
return np.asarray(vals, dtype=np.float32)
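
# Illustrative: each of these forms parses to array([1., 2., 3.], dtype=float32):
#   _parse_series([1, 2, 3])
#   _parse_series([{"y": 1}, {"y": 2}, {"y": 3}])
#   _parse_series({"values": [1, 2, 3]})
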
def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
s = s.strip()
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
try:
obj = json.loads(s)
return obj if isinstance(obj, dict) else None
except Exception:
pass
if "```" in s:
parts = s.split("```")
for i in range(1, len(parts), 2):
block = parts[i]
if block.lstrip().lower().startswith("json"):
block = block.split("\n", 1)[-1]
try:
obj = json.loads(block.strip())
return obj if isinstance(obj, dict) else None
except Exception:
continue
return None
def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
msgs = payload.get("messages")
if not isinstance(msgs, list):
return payload
for m in reversed(msgs):
if not isinstance(m, dict) or m.get("role") != "user":
continue
content = m.get("content")
texts: List[str] = []
if isinstance(content, list):
texts = [
p.get("text")
for p in content
if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)
]
elif isinstance(content, str):
texts = [content]
for t in reversed(texts):
obj = _extract_json_from_text(t)
if isinstance(obj, dict):
return {**payload, **obj}
break
return payload
# ---------- backend ----------
class TimesFMBackend(ChatBackend):
"""
TimesFM 2.5 backend.
    Input JSON may arrive as top-level keys, inside CloudEvents .data, or
    embedded in the last user message.
Keys:
series: list[float|int|{y|value}] OR list of such lists for batch
horizon: int (>0)
Optional:
quantiles: bool (default True) -> include quantile forecasts
max_context, max_horizon: ints to override defaults
"""
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
# HF id for bookkeeping only
self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model = None  # lazy-loaded on first use
        self._compiled_dims: Optional[Tuple[int, int]] = None  # (max_context, max_horizon) last compiled
def _ensure_model(self) -> None:
if self._model is not None:
return
try:
import timesfm # 2.5 API
hf_token = getattr(settings, "HF_TOKEN", None) or os.environ.get("HF_TOKEN")
cache_dir = getattr(settings, "TIMESFM_CACHE_DIR", None)
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
self.model_id,
token=hf_token,
cache_dir=cache_dir,
local_files_only=False,
)
try:
# .model holds the underlying nn.Module; fall back to instance if absent.
target = getattr(model, "model", model)
target.to(self.device) # type: ignore[arg-type]
except Exception:
pass
cfg = timesfm.ForecastConfig(
max_context=1024,
max_horizon=256,
normalize_inputs=True,
use_continuous_quantile_head=True,
force_flip_invariance=True,
infer_is_positive=True,
fix_quantile_crossing=True,
)
            model.compile(cfg)
            self._model = model
            self._compiled_dims = (1024, 256)
logger.info("TimesFM 2.5 model loaded on %s", self.device)
except Exception as e:
logger.exception("TimesFM 2.5 init failed")
raise RuntimeError(f"timesfm 2.5 init failed: {e}") from e
def _prepare_inputs(self, payload: Dict[str, Any]) -> Tuple[List[np.ndarray], int, bool, Dict[str, int]]:
# unwrap CloudEvents and nested keys
if isinstance(payload.get("data"), dict):
payload = {**payload, **payload["data"]}
if isinstance(payload.get("timeseries"), dict):
payload = {**payload, **payload["timeseries"]}
# merge JSON in last user message
payload = _merge_openai_message_json(payload)
horizon = int(payload.get("horizon", 0))
if horizon <= 0:
raise ValueError("horizon must be a positive integer")
quantiles = bool(payload.get("quantiles", True))
mc = int(payload.get("max_context", 1024))
mh = int(payload.get("max_horizon", 256))
series = payload.get("series")
inputs: List[np.ndarray]
        if isinstance(series, list) and series and isinstance(series[0], (list, tuple)):
            # batch input: a list of series (a list of {y|value} dicts is a single series)
            inputs = [_parse_series(s) for s in series]
        else:
            # single series -> batch of 1
            inputs = [_parse_series(series)]
return inputs, horizon, quantiles, {"max_context": mc, "max_horizon": mh}
async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
inputs, horizon, want_quantiles, cfg_overrides = self._prepare_inputs(payload)
self._ensure_model()
        # If the caller asked for different limits, recompile once and remember them.
        desired = (cfg_overrides["max_context"], cfg_overrides["max_horizon"])
        if desired != self._compiled_dims:
            try:
                import timesfm
                cfg = timesfm.ForecastConfig(
                    max_context=desired[0],
                    max_horizon=desired[1],
                    normalize_inputs=True,
                    # keep the quantile head compiled in; the response simply
                    # returns None for quantiles when they are not requested
                    use_continuous_quantile_head=True,
                    force_flip_invariance=True,
                    infer_is_positive=True,
                    fix_quantile_crossing=True,
                )
                self._model.compile(cfg)
                self._compiled_dims = desired
            except Exception:
                logger.exception("TimesFM recompile with overrides failed; keeping current config")
try:
point, quant = self._model.forecast(horizon=horizon, inputs=inputs)
point_list = [row.astype(float).tolist() for row in point] # shape (B, H)
quant_list = None
if want_quantiles and quant is not None:
                # shape (B, H, 10): mean, then deciles q10..q90
                quant_list = [row.astype(float).tolist() for row in quant]
except Exception as e:
logger.exception("TimesFM 2.5 forecast failed")
raise RuntimeError(f"forecast failed: {e}") from e
# If single-series input, unwrap batch dim for convenience
single = len(inputs) == 1
return {
"model": self.model_id,
"horizon": horizon,
"forecast": point_list[0] if single else point_list,
"quantiles": (quant_list[0] if single else quant_list) if want_quantiles else None,
"backend": "timesfm-2.5",
}
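
    # Response sketch for a single input series (values illustrative):
    #   {"model": "google/timesfm-2.5-200m-pytorch", "horizon": 2,
    #    "forecast": [15.9, 16.4],
    #    "quantiles": [[mean, q10, ..., q90], [mean, q10, ..., q90]],
    #    "backend": "timesfm-2.5"}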
async def stream(self, request: Dict[str, Any]):
rid = f"chatcmpl-timesfm-{int(time.time())}"
now = int(time.time())
try:
result = await self.forecast(dict(request) if isinstance(request, dict) else {})
content = json.dumps(result, separators=(",", ":"), ensure_ascii=False)
except Exception as e:
content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)
yield {
"id": rid,
"object": "chat.completion.chunk",
"created": now,
"model": self.model_id,
"choices": [
{"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
],
}
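
    # Caller-side sketch (hypothetical driver): `stream` yields exactly one
    # OpenAI-style chunk, so a client can drain it with:
    #   async for chunk in backend.stream(request):
    #       text = chunk["choices"][0]["delta"]["content"]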
class StubImagesBackend(ImagesBackend):
async def generate_b64(self, request: Dict[str, Any]) -> str:
logger.warning("Image generation not supported in TimesFM backend.")
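        # 1x1 PNG placeholder, base64-encoded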
return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="