# timesfm_backend.py
import time
import json
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import torch

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

# ---------------- TimesFM import (fallback-safe) ----------------
try:
    from timesfm import TimesFm  # Google TimesFM 2.5+
    _TIMESFM_AVAILABLE = True
except Exception as e:
    logger.warning("timesfm not available (%s) — using naive fallback.", e)
    TimesFm = None  # type: ignore
    _TIMESFM_AVAILABLE = False


# ---------------- helpers ----------------
def _parse_series(series: Any) -> np.ndarray:
    """Convert a user-supplied series into a 1D float32 numpy array.

    Accepted shapes:
      - list/tuple of numbers,
      - list/tuple of dicts carrying a 'y' or 'value' key,
      - dict wrapping one of the above under 'values' or 'y'.

    Dict items that carry neither 'y' nor 'value' are skipped.

    Raises:
        ValueError: if ``series`` is None, is not a list/tuple (after
            unwrapping a dict), or yields no values.
    """
    if series is None:
        raise ValueError("series is required")

    if isinstance(series, dict):
        # allow {"values":[...]} or {"y":[...]}
        series = series.get("values") or series.get("y")

    if not isinstance(series, (list, tuple)):
        raise ValueError("series must be a list/tuple or dict with 'values'/'y'")

    vals: List[float] = []
    if series and isinstance(series[0], dict):
        for item in series:
            if "y" in item:
                vals.append(float(item["y"]))
            elif "value" in item:
                vals.append(float(item["value"]))
    else:
        vals = [float(x) for x in series]

    if not vals:
        raise ValueError("series is empty")
    return np.asarray(vals, dtype=np.float32)


def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
    """Naive fallback: mean of the last 4 points (or all if <4), repeated.

    Returns a float32 array of length ``horizon``; an empty array when
    ``horizon`` is not positive.
    """
    if horizon <= 0:
        return np.zeros((0,), dtype=np.float32)
    k = min(4, y.shape[0])
    base = float(np.mean(y[-k:]))
    return np.full((horizon,), base, dtype=np.float32)


def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
    """Try to parse a JSON object from a plain string or a fenced code block.

    The whole (stripped) string is tried first when it looks like a JSON
    object/array. Otherwise every fenced block is tried in order and the
    first one that parses to a dict wins.

    Returns:
        The parsed dict, or None when no JSON object could be extracted.
    """
    s = s.strip()

    # whole-string JSON object/array
    if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
        try:
            obj = json.loads(s)
        except (ValueError, RecursionError):
            pass
        else:
            return obj if isinstance(obj, dict) else None

    # fenced code blocks: odd-indexed parts of the split are fence interiors
    if "```" in s:
        for block in s.split("```")[1::2]:
            body = block.lstrip()
            if body.lower().startswith("json"):
                # Drop the language-tag line; if the tag is glued to the JSON
                # on a single line ("json{...}"), drop just the tag.
                _, newline, rest = body.partition("\n")
                body = rest if newline else body[4:]
            try:
                obj = json.loads(body.strip())
            except (ValueError, RecursionError):
                continue
            # A block holding a non-object must not hide a later valid block.
            if isinstance(obj, dict):
                return obj
    return None


def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge JSON found in the last user message into ``payload``.

    OpenAI chat format compatibility: ``payload["messages"]`` may hold user
    JSON in the last user message, whose content is either a plain string or
    a list of parts ``[{"type": "text", "text": ...}]``. Text parts are
    scanned last-to-first; the first JSON object found is merged on top of
    ``payload`` (its keys win on conflict). Only the last user message is
    inspected.

    Returns:
        A new merged dict, or ``payload`` itself when nothing was found.
    """
    msgs = payload.get("messages")
    if not isinstance(msgs, list):
        return payload

    last_user = next(
        (m for m in reversed(msgs) if isinstance(m, dict) and m.get("role") == "user"),
        None,
    )
    if last_user is None:
        return payload

    content = last_user.get("content")
    texts: List[str] = []
    if isinstance(content, list):
        texts = [
            p.get("text")
            for p in content
            if isinstance(p, dict)
            and p.get("type") == "text"
            and isinstance(p.get("text"), str)
        ]
    elif isinstance(content, str):
        texts = [content]

    for t in reversed(texts):
        obj = _extract_json_from_text(t)
        if isinstance(obj, dict):
            return {**payload, **obj}
    return payload


# ---------------- backend ----------------
class TimesFMBackend(ChatBackend):
    """
    Accepts OpenAI chat-completions requests.
    Pulls timeseries config from:
      - top-level keys, OR
      - payload['data'] (CloudEvents wrapper), OR
      - last user message JSON (OpenAI format).

    Keys:
      series: list[float|int|{y|value}]
      horizon: int (>0)
      freq: optional str
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model: Optional[TimesFm] = None  # type: ignore
        # Set once a load attempt has failed so it is not repeated per request.
        self._load_failed = False

    def _ensure_model(self) -> None:
        """Lazily load the TimesFM model; on failure stay in fallback mode.

        A failed load is remembered and not retried on later calls, which
        avoids repeating the checkpoint load and traceback for every request.
        """
        if self._model is not None or not _TIMESFM_AVAILABLE:
            return
        # getattr keeps subclasses that skip __init__ working.
        if getattr(self, "_load_failed", False):
            return
        try:
            # Set lengths compatible with the 2.5 checkpoints.
            # NOTE(review): constructor kwargs and load_from_checkpoint are
            # kept exactly as before — confirm against the installed timesfm.
            model = TimesFm(
                context_len=512,
                horizon_len=128,
                input_patch_len=32,
            )
            model.load_from_checkpoint(self.model_id)
            try:
                model.to(self.device)  # type: ignore[attr-defined]
            except Exception as move_err:
                # Best effort: the model object may not support .to().
                logger.debug("TimesFM .to(%s) skipped: %s", self.device, move_err)
            self._model = model
            logger.info("TimesFM loaded from %s on %s", self.model_id, self.device)
        except Exception as e:
            logger.exception("TimesFM init failed; fallback only. %s", e)
            self._model = None
            self._load_failed = True

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Produce a forecast for the series described by ``payload``.

        ``payload['data']`` and ``payload['timeseries']`` (when dicts) are
        flattened into the payload, then JSON from the last user message is
        merged in. Uses TimesFM when loaded, otherwise the naive fallback.

        Returns:
            Dict with keys model, horizon, freq, forecast (list of floats)
            and note (None, or the reason the fallback was used).

        Raises:
            ValueError: if the series is missing/invalid or horizon is not a
                positive integer.
        """
        # unwrap CloudEvents .data and nested .timeseries
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}

        # merge JSON embedded in last user message (OpenAI format)
        payload = _merge_openai_message_json(payload)

        y = _parse_series(payload.get("series"))
        try:
            horizon = int(payload.get("horizon", 0))
        except (TypeError, ValueError, OverflowError):
            # null / non-numeric horizon: report it as the same input error
            raise ValueError("horizon must be a positive integer") from None
        freq = payload.get("freq")
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")

        self._ensure_model()

        note = None
        if _TIMESFM_AVAILABLE and self._model is not None:
            try:
                x = torch.tensor(y, dtype=torch.float32, device=self.device).unsqueeze(0)  # [1, T]
                preds = self._model.forecast_on_batch(x, horizon)  # -> [1, H]
                fc = preds[0].detach().cpu().numpy().astype(float).tolist()
            except Exception as e:
                logger.exception("TimesFM forecast failed; fallback used. %s", e)
                fc = _fallback_forecast(y, horizon).tolist()
                note = "fallback_used_due_to_predict_error"
        else:
            fc = _fallback_forecast(y, horizon).tolist()
            note = "fallback_used_timesfm_missing"

        return {
            "model": self.model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": fc,
            "note": note,
        }

    def _single_chunk(self, rid: str, created: int, content: str) -> Dict[str, Any]:
        """Build the one chat.completion.chunk that carries ``content``."""
        return {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": self.model_id,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
            ],
        }

    async def stream(self, request: Dict[str, Any]):
        """
        OA-compatible streaming shim:
        Emits exactly one chat.completion.chunk with compact JSON content.

        On failure the chunk content is ``{"error": "<message>"}`` instead of
        the forecast.
        """
        # One clock read so the id suffix and `created` always agree.
        now = int(time.time())
        rid = f"chatcmpl-timesfm-{now}"
        payload = dict(request) if isinstance(request, dict) else {}

        try:
            result = await self.forecast(payload)
        except Exception as e:
            content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)
            yield self._single_chunk(rid, now, content)
            return

        content = json.dumps(
            {
                "model": result["model"],
                "horizon": result["horizon"],
                "freq": result["freq"],
                "forecast": result["forecast"],
                "note": result.get("note"),
                "backend": "timesfm",
            },
            separators=(",", ":"),
            ensure_ascii=False,
        )
        yield self._single_chunk(rid, now, content)


# ---------------- images stub ----------------
class StubImagesBackend(ImagesBackend):
    """Images backend placeholder: logs a warning and returns a fixed PNG."""

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Return a constant base64-encoded PNG; ``request`` is ignored."""
        logger.warning("Image generation not supported in TimesFM backend.")
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="