# space/tools/ts_forecast_tool.py
import os
import logging
from typing import Any, Dict, Optional

import torch
import pandas as pd
import numpy as np

from utils.tracing import Tracer
from utils.config import AppConfig
from transformers import AutoModel, AutoConfig

logger = logging.getLogger(__name__)

# Constants
MIN_SERIES_LENGTH = 2
MAX_SERIES_LENGTH = 10000
MIN_HORIZON = 1
MAX_HORIZON = 365
DEFAULT_MODEL_ID = "ibm-granite/granite-timeseries-ttm-r1"


class ForecastToolError(Exception):
    """Custom exception for forecast tool errors."""

    pass


class TimeseriesForecastTool:
    """
    Lightweight wrapper around Granite Time Series models for zero-shot forecasting.

    This wrapper:
    - Loads the model with AutoModel.from_pretrained (lazily, on first use)
    - Validates input series and horizon
    - Attempts multiple inference methods (predict, forward with prediction_length)
    - Returns a Pandas DataFrame with a 'forecast' column of exactly `horizon` rows
    - Raises ForecastToolError for every failure and logs progress along the way

    Expected input:
    - series: pd.Series with DatetimeIndex (regular frequency recommended);
      only the values are passed to the model, the index itself is not used
    - horizon: int, number of future steps to forecast
    """

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
    ):
        """
        Args:
            cfg: Application config (stored; not read by this class).
            tracer: Optional tracer; when set, forecast success/error events are emitted.
            model_id: Hugging Face model identifier to load.
            device: Torch device string; defaults to "cuda" when available, else "cpu".
        """
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.model = None
        self.config = None

        # Determine device
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"TimeseriesForecastTool initialized with device: {self.device}")

        # Lazy loading - model loaded on first use
        self._initialized = False

    def _ensure_loaded(self):
        """
        Lazily load the model config and weights on first use.

        The config is optional: if it cannot be loaded, a warning is logged and
        ``self.config`` stays None. The model is moved to ``self.device`` and put
        in eval mode; it is only stored on ``self.model`` once all of that
        succeeded, so a failed load leaves no half-initialized model behind.

        Raises:
            ForecastToolError: If the model cannot be loaded.
        """
        if self._initialized:
            return

        try:
            logger.info(f"Loading Granite time series model: {self.model_id}")

            # Load configuration (best effort; only used by get_model_info)
            try:
                # trust_remote_code mirrors the model load below so custom
                # architectures can resolve their config class as well.
                self.config = AutoConfig.from_pretrained(
                    self.model_id, trust_remote_code=True
                )
                logger.info(f"Model config loaded: {type(self.config).__name__}")
            except Exception as e:
                logger.warning(f"Could not load model config: {e}")
                self.config = None

            # Load model
            try:
                model = AutoModel.from_pretrained(
                    self.model_id,
                    trust_remote_code=True,  # Required for some custom models
                )
                model.to(self.device)
                model.eval()
            except Exception as e:
                raise ForecastToolError(
                    f"Failed to load model '{self.model_id}': {e}\n"
                    "Ensure the model is available and transformers is up to date."
                ) from e

            self.model = model
            logger.info(f"Model loaded successfully: {type(self.model).__name__}")
            self._initialized = True

        except ForecastToolError:
            raise
        except Exception as e:
            raise ForecastToolError(f"Model initialization failed: {e}") from e

    def _validate_series(self, series: pd.Series) -> tuple[bool, str]:
        """
        Validate input time series.

        Returns (is_valid, error_message); error_message is "" when valid.
        """
        if not isinstance(series, pd.Series):
            return False, "Input must be a pandas Series"

        if series.empty:
            return False, "Series is empty"

        if len(series) < MIN_SERIES_LENGTH:
            return False, f"Series too short (min {MIN_SERIES_LENGTH} points required)"

        if len(series) > MAX_SERIES_LENGTH:
            return False, f"Series too long (max {MAX_SERIES_LENGTH} points allowed)"

        # Check for nulls
        if series.isnull().any():
            null_count = series.isnull().sum()
            return False, f"Series contains {null_count} null values. Please handle missing data first."

        # The dtype check must come before np.isfinite, which raises TypeError
        # on non-numeric (e.g. object/string) data instead of returning False.
        if not pd.api.types.is_numeric_dtype(series):
            return False, f"Series must be numeric, got dtype: {series.dtype}"

        # Check for infinite values
        if not np.isfinite(series).all():
            return False, "Series contains infinite values"

        return True, ""

    def _validate_horizon(self, horizon: int) -> tuple[bool, str]:
        """
        Validate forecast horizon.

        Accepts anything int() can convert (int, numpy integer, integral float,
        numeric string) as long as no fractional part would be silently dropped.

        Returns (is_valid, error_message); error_message is "" when valid.
        """
        try:
            h = int(horizon)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float('inf'))
            return False, f"Horizon must be an integer, got: {horizon}"

        # Reject e.g. 5.7 instead of silently truncating it to 5.
        if isinstance(horizon, (float, np.floating)) and h != horizon:
            return False, f"Horizon must be an integer, got: {horizon}"

        if h < MIN_HORIZON:
            return False, f"Horizon too small (min {MIN_HORIZON})"

        if h > MAX_HORIZON:
            return False, f"Horizon too large (max {MAX_HORIZON})"

        return True, ""

    def _prepare_input_tensor(self, series: pd.Series) -> torch.Tensor:
        """
        Convert a pandas Series to a float32 tensor of shape [1, seq_len] on self.device.

        Raises:
            ForecastToolError: If the conversion fails.
        """
        try:
            # Convert to float32 numpy array
            values = series.astype("float32").to_numpy()

            # Create tensor and move to device
            tensor = torch.tensor(values, dtype=torch.float32, device=self.device)

            # Add batch dimension [1, seq_len]
            tensor = tensor.unsqueeze(0)

            logger.debug(f"Input tensor shape: {tensor.shape}, device: {tensor.device}")
            return tensor

        except Exception as e:
            raise ForecastToolError(f"Failed to prepare input tensor: {e}") from e

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Convert a prediction tensor to a numpy array of at least one dimension.

        squeeze() removes size-1 dims, which turns a single-step forecast into a
        0-d array; atleast_1d restores it to shape (1,). The float() cast lets
        half-precision outputs (e.g. bfloat16, which numpy cannot represent) convert.
        """
        return np.atleast_1d(tensor.squeeze().detach().float().cpu().numpy())

    @staticmethod
    def _fit_to_horizon(output: np.ndarray, horizon: int) -> np.ndarray:
        """
        Return `output` as a 1-D array of exactly `horizon` values.

        Longer outputs are truncated; shorter non-empty outputs are padded by
        repeating the last value.

        Raises:
            ValueError: If `output` is empty (there is no value to pad with).
        """
        flat = np.asarray(output).reshape(-1)
        if flat.size == horizon:
            return flat

        logger.warning(
            f"Output length {flat.size} doesn't match horizon {horizon}. Truncating/padding."
        )
        if flat.size == 0:
            raise ValueError("Model returned an empty prediction")
        if flat.size > horizon:
            return flat[:horizon]
        # Pad with last value
        return np.pad(flat, (0, horizon - flat.size), mode="edge")

    def _try_predict_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Try using the model's .predict() method.

        Returns a 1-D array of length `horizon`, or None if the model has no
        `predict` attribute or the call/post-processing fails.
        """
        if not hasattr(self.model, "predict"):
            logger.debug("Model has no 'predict' method")
            return None

        try:
            logger.info("Attempting forecast with .predict() method")
            preds = self.model.predict(x, prediction_length=horizon)

            # Convert to tensor if needed
            if not isinstance(preds, torch.Tensor):
                preds = torch.tensor(preds, device=self.device)

            # Same length normalization as the forward() path, so both methods
            # always hand back exactly `horizon` values.
            output = self._fit_to_horizon(self._tensor_to_numpy(preds), horizon)

            logger.info(f"Forecast successful via .predict(): {output.shape}")
            return output

        except Exception as e:
            logger.warning(f"predict() method failed: {e}")
            return None

    def _extract_prediction_tensor(self, outputs: Any) -> Optional[torch.Tensor]:
        """
        Pull a prediction tensor out of a forward() result.

        Checks a list of common attribute names in order, skipping attributes
        that are absent or None; tuple/list values contribute their first
        element. Falls back to `outputs` itself if it is a tensor. Returns None
        if nothing usable is found.
        """
        for attr in ("predictions", "prediction", "logits", "forecast", "output"):
            candidate = getattr(outputs, attr, None)
            if candidate is None:
                continue

            # Handle tuple/list outputs
            if isinstance(candidate, (tuple, list)):
                candidate = candidate[0]

            # Convert to tensor if needed
            if not isinstance(candidate, torch.Tensor):
                candidate = torch.tensor(candidate, device=self.device)

            logger.debug(f"Found predictions in attribute: {attr}")
            return candidate

        # If outputs is directly a tensor
        if isinstance(outputs, torch.Tensor):
            logger.debug("Using raw tensor output")
            return outputs

        return None

    def _try_forward_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Try using the model's forward() method with a prediction_length parameter.

        Returns a 1-D array of length `horizon`, or None if the call or
        post-processing fails.
        """
        logger.info("Attempting forecast with forward(prediction_length=...)")

        # Only the model call is guarded for TypeError, so that TypeErrors from
        # post-processing are not misreported as a signature mismatch.
        try:
            outputs = self.model(x, prediction_length=horizon)
        except TypeError as e:
            # Usually means forward() has no prediction_length kwarg.
            logger.warning(f"forward() doesn't accept prediction_length: {e}")
            return None
        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

        try:
            prediction_tensor = self._extract_prediction_tensor(outputs)
            if prediction_tensor is None:
                logger.warning("Could not extract predictions from forward() output")
                return None

            output = self._tensor_to_numpy(prediction_tensor)

            # Handle multi-dimensional outputs
            if output.ndim > 1:
                # NOTE(review): heuristic kept from the original implementation.
                # For a (horizon, channels>1) array, flatten() interleaves the
                # channels — confirm the intended channel selection.
                if output.shape[0] == horizon:
                    output = output.flatten()
                elif output.shape[0] < output.shape[1]:
                    output = output[-1]
                else:
                    output = output.flatten()

            # Ensure correct length (also flattens anything still >1-D)
            output = self._fit_to_horizon(output, horizon)

            logger.info(f"Forecast successful via forward(): {output.shape}")
            return output

        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        """
        Generate zero-shot forecast for input time series.

        Args:
            series: Input time series (pd.Series with numeric values)
            horizon: Number of periods to forecast (default: 96); must be an
                integer in [MIN_HORIZON, MAX_HORIZON]

        Returns:
            DataFrame with a 'forecast' column containing `horizon` predictions
            (default RangeIndex)

        Raises:
            ForecastToolError: If validation, model loading, or forecasting fails
        """
        try:
            # Validate inputs
            is_valid, error_msg = self._validate_series(series)
            if not is_valid:
                raise ForecastToolError(f"Invalid series: {error_msg}")

            is_valid, error_msg = self._validate_horizon(horizon)
            if not is_valid:
                raise ForecastToolError(f"Invalid horizon: {error_msg}")

            # Use the validated integer from here on (e.g. 24.0 / "24" / np.int64 -> 24).
            horizon = int(horizon)

            # Ensure model is loaded
            self._ensure_loaded()

            # Log input statistics
            logger.info(
                f"Forecasting: series_length={len(series)}, "
                f"horizon={horizon}, "
                f"series_mean={series.mean():.2f}, "
                f"series_std={series.std():.2f}"
            )

            # Prepare input tensor
            x = self._prepare_input_tensor(series)

            # Try prediction methods in order of preference
            with torch.no_grad():
                # Method 1: Try .predict()
                output = self._try_predict_method(x, horizon)

                # Method 2: Try forward with prediction_length
                if output is None:
                    output = self._try_forward_method(x, horizon)

            # If all methods failed
            if output is None:
                raise ForecastToolError(
                    "Could not generate forecast using available model methods.\n"
                    "The model may not support zero-shot forecasting with this interface.\n"
                    "Suggestions:\n"
                    "  • Check model documentation for correct usage\n"
                    "  • Ensure transformers library is up to date\n"
                    "  • Try a different model or use traditional forecasting (ARIMA, Prophet)\n"
                    f"  • Model type: {type(self.model).__name__}"
                )

            if not np.isfinite(output).all():
                logger.warning("Forecast contains non-finite values (NaN or inf)")

            # Create output DataFrame
            result_df = pd.DataFrame({"forecast": output})

            # Log output statistics
            logger.info(
                f"Forecast complete: "
                f"mean={output.mean():.2f}, "
                f"std={output.std():.2f}, "
                f"min={output.min():.2f}, "
                f"max={output.max():.2f}"
            )

            # Trace event
            if self.tracer:
                self.tracer.trace_event("forecast", {
                    "series_length": len(series),
                    "horizon": horizon,
                    "forecast_mean": float(output.mean()),
                    "forecast_std": float(output.std()),
                })

            return result_df

        except ForecastToolError:
            raise
        except Exception as e:
            error_msg = f"Forecasting failed unexpectedly: {str(e)}"
            # logger.exception keeps the traceback for this unexpected path.
            logger.exception(error_msg)
            if self.tracer:
                self.tracer.trace_event("forecast_error", {"error": error_msg})
            raise ForecastToolError(error_msg) from e

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model (loads it first if needed).

        Raises:
            ForecastToolError: If the model cannot be loaded.
        """
        self._ensure_loaded()
        return {
            "model_id": self.model_id,
            "model_type": type(self.model).__name__,
            "device": str(self.device),
            "has_predict": hasattr(self.model, "predict"),
            "config": str(self.config) if self.config else None,
        }