Spaces:
Sleeping
feat: enable multivariate covariate forecasting with 615 features
Browse filesCRITICAL FIX: Switch from univariate to multivariate forecasting
Previous implementation (batch inference) was only using target values,
completely ignoring all 615 collected features (weather per zone,
generation per zone, CNEC outages, LTA, load forecasts).
Changes:
- ChronosPipeline -> Chronos2Pipeline (supports covariates)
- Model: amazon/chronos-t5-large -> amazon/chronos-2
- Dtype: bfloat16 -> float32 (required for chronos-2)
- Inference: predict() tensor API -> predict_df() DataFrame API
- Now passes BOTH context_data AND future_data (615 features)
- Removed batch inference (revert to per-border for covariate support)
This enables Chronos-2's zero-shot multivariate forecasting capability:
- Group attention mechanism shares information across series & covariates
- In-context learning with arbitrary exogenous features
- No fine-tuning required - works in zero-shot mode
Expected impact: Significantly improved forecast accuracy by leveraging
all collected features instead of just historical target values.
Files modified:
- src/forecasting/chronos_inference.py (v1.1.0)
Co-Authored-By: Claude <[email protected]>
- src/forecasting/chronos_inference.py +83 -134
|
@@ -1,8 +1,9 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Chronos-2 Inference Pipeline
|
| 4 |
Standalone inference script for HuggingFace Space deployment.
|
| 5 |
-
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
@@ -14,7 +15,7 @@ import pandas as pd
|
|
| 14 |
import numpy as np
|
| 15 |
import torch
|
| 16 |
from datasets import load_dataset
|
| 17 |
-
from chronos import
|
| 18 |
|
| 19 |
from .dynamic_forecast import DynamicForecast
|
| 20 |
from .feature_availability import FeatureAvailability
|
|
@@ -22,23 +23,24 @@ from .feature_availability import FeatureAvailability
|
|
| 22 |
|
| 23 |
class ChronosInferencePipeline:
|
| 24 |
"""
|
| 25 |
-
Production inference pipeline for Chronos-2 zero-shot forecasting.
|
|
|
|
| 26 |
Designed for deployment as API endpoint on HuggingFace Spaces.
|
| 27 |
"""
|
| 28 |
|
| 29 |
def __init__(
|
| 30 |
self,
|
| 31 |
-
model_name: str = "amazon/chronos-
|
| 32 |
device: str = "cuda",
|
| 33 |
-
dtype: str = "
|
| 34 |
):
|
| 35 |
"""
|
| 36 |
Initialize inference pipeline.
|
| 37 |
|
| 38 |
Args:
|
| 39 |
-
model_name: HuggingFace model identifier
|
| 40 |
device: Device for inference ('cuda' or 'cpu')
|
| 41 |
-
dtype: Data type for model weights
|
| 42 |
"""
|
| 43 |
self.model_name = model_name
|
| 44 |
self.device = device
|
|
@@ -50,7 +52,7 @@ class ChronosInferencePipeline:
|
|
| 50 |
self._borders = None
|
| 51 |
|
| 52 |
def _load_model(self):
|
| 53 |
-
"""Load Chronos model (cached after first call)"""
|
| 54 |
if self._pipeline is None:
|
| 55 |
print(f"Loading {self.model_name}...")
|
| 56 |
start_time = time.time()
|
|
@@ -61,10 +63,10 @@ class ChronosInferencePipeline:
|
|
| 61 |
"float32": torch.float32
|
| 62 |
}
|
| 63 |
|
| 64 |
-
self._pipeline =
|
| 65 |
self.model_name,
|
| 66 |
device_map=self.device,
|
| 67 |
-
torch_dtype=dtype_map.get(self.dtype, torch.
|
| 68 |
)
|
| 69 |
|
| 70 |
print(f"Model loaded in {time.time() - start_time:.1f}s")
|
|
@@ -159,148 +161,95 @@ class ChronosInferencePipeline:
|
|
| 159 |
|
| 160 |
total_start = time.time()
|
| 161 |
|
| 162 |
-
#
|
| 163 |
-
#
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
print(f"\n[BATCH] Preparing contexts for {len(forecast_borders)} borders...")
|
| 167 |
-
all_contexts = []
|
| 168 |
-
all_border_names = []
|
| 169 |
|
| 170 |
for i, border in enumerate(forecast_borders, 1):
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
try:
|
| 173 |
-
# Extract data
|
| 174 |
context_data, future_data = forecaster.prepare_forecast_data(
|
| 175 |
run_date=run_datetime,
|
| 176 |
border=border
|
| 177 |
)
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
# Extract context values and convert to PyTorch tensor
|
| 183 |
-
context = torch.from_numpy(context_data[target_col].values).float()
|
| 184 |
-
all_contexts.append(context)
|
| 185 |
-
all_border_names.append(border)
|
| 186 |
-
|
| 187 |
-
except Exception as e:
|
| 188 |
-
import traceback
|
| 189 |
-
error_msg = f"{type(e).__name__}: {str(e)}"
|
| 190 |
-
traceback_str = traceback.format_exc()
|
| 191 |
-
print(f" [ERROR] {border}: {error_msg}", flush=True)
|
| 192 |
-
results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
|
| 193 |
-
|
| 194 |
-
# Process contexts in sub-batches
|
| 195 |
-
if all_contexts:
|
| 196 |
-
num_contexts = len(all_contexts)
|
| 197 |
-
num_sub_batches = (num_contexts + SUB_BATCH_SIZE - 1) // SUB_BATCH_SIZE
|
| 198 |
-
|
| 199 |
-
print(f"\n[BATCH] Running inference in {num_sub_batches} sub-batches of {SUB_BATCH_SIZE} borders...")
|
| 200 |
-
|
| 201 |
-
all_forecasts = []
|
| 202 |
-
total_inference_time = 0
|
| 203 |
-
|
| 204 |
-
for batch_idx in range(num_sub_batches):
|
| 205 |
-
start_idx = batch_idx * SUB_BATCH_SIZE
|
| 206 |
-
end_idx = min(start_idx + SUB_BATCH_SIZE, num_contexts)
|
| 207 |
-
|
| 208 |
-
# Get sub-batch
|
| 209 |
-
sub_batch_contexts = all_contexts[start_idx:end_idx]
|
| 210 |
-
sub_batch_names = all_border_names[start_idx:end_idx]
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
inference_start = time.time()
|
| 217 |
-
|
| 218 |
-
# Run batch inference
|
| 219 |
-
batch_forecasts = pipeline.predict(
|
| 220 |
-
inputs=batch_tensor,
|
| 221 |
prediction_length=prediction_hours,
|
|
|
|
|
|
|
|
|
|
| 222 |
num_samples=num_samples
|
| 223 |
)
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
try:
|
| 252 |
-
# Extract this border's forecast from batch
|
| 253 |
-
forecast = batch_forecasts[i] # Extract from batch dimension
|
| 254 |
-
|
| 255 |
-
# Calculate quantiles
|
| 256 |
-
forecast_numpy = forecast.numpy()
|
| 257 |
-
print(f"[DEBUG] Raw forecast shape: {forecast_numpy.shape}", flush=True)
|
| 258 |
-
|
| 259 |
-
# Chronos may return (batch, num_samples, time) or (num_samples, time)
|
| 260 |
-
# Squeeze any batch dimension (if present)
|
| 261 |
-
if forecast_numpy.ndim == 3:
|
| 262 |
-
print(f"[DEBUG] 3D forecast detected, squeezing batch dimension", flush=True)
|
| 263 |
-
forecast_numpy = forecast_numpy.squeeze(axis=0) # Remove batch dim
|
| 264 |
-
|
| 265 |
-
print(f"[DEBUG] Forecast shape after squeeze: {forecast_numpy.shape}, Expected: ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})", flush=True)
|
| 266 |
-
|
| 267 |
-
# Now forecast should be 2D: either (num_samples, time) or (time, num_samples)
|
| 268 |
-
# Compute median along samples axis to get (time,) shape
|
| 269 |
-
if forecast_numpy.shape[0] == num_samples and forecast_numpy.shape[1] == prediction_hours:
|
| 270 |
-
# Shape is (num_samples, time) - use axis=0
|
| 271 |
-
print(f"[DEBUG] Using axis=0 for shape (num_samples={num_samples}, time={prediction_hours})", flush=True)
|
| 272 |
median = np.median(forecast_numpy, axis=0)
|
| 273 |
q10 = np.quantile(forecast_numpy, 0.1, axis=0)
|
| 274 |
q90 = np.quantile(forecast_numpy, 0.9, axis=0)
|
| 275 |
-
|
| 276 |
-
# Shape is (time, num_samples) - use axis=1
|
| 277 |
-
print(f"[DEBUG] Using axis=1 for shape (time={prediction_hours}, num_samples={num_samples})", flush=True)
|
| 278 |
median = np.median(forecast_numpy, axis=1)
|
| 279 |
q10 = np.quantile(forecast_numpy, 0.1, axis=1)
|
| 280 |
q90 = np.quantile(forecast_numpy, 0.9, axis=1)
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
# Add summary metadata
|
| 306 |
results['metadata']['total_time_s'] = time.time() - total_start
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Chronos-2 Inference Pipeline with Covariate Support
|
| 4 |
Standalone inference script for HuggingFace Space deployment.
|
| 5 |
+
Uses predict_df() API to enable multivariate forecasting with weather, generation, CNEC outages.
|
| 6 |
+
FORCE REBUILD: v1.1.0
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
import torch
|
| 17 |
from datasets import load_dataset
|
| 18 |
+
from chronos import Chronos2Pipeline
|
| 19 |
|
| 20 |
from .dynamic_forecast import DynamicForecast
|
| 21 |
from .feature_availability import FeatureAvailability
|
|
|
|
| 23 |
|
| 24 |
class ChronosInferencePipeline:
|
| 25 |
"""
|
| 26 |
+
Production inference pipeline for Chronos-2 zero-shot forecasting WITH COVARIATES.
|
| 27 |
+
Uses predict_df() API to leverage all 615 collected features (weather, generation, outages, etc.)
|
| 28 |
Designed for deployment as API endpoint on HuggingFace Spaces.
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(
|
| 32 |
self,
|
| 33 |
+
model_name: str = "amazon/chronos-2",
|
| 34 |
device: str = "cuda",
|
| 35 |
+
dtype: str = "float32"
|
| 36 |
):
|
| 37 |
"""
|
| 38 |
Initialize inference pipeline.
|
| 39 |
|
| 40 |
Args:
|
| 41 |
+
model_name: HuggingFace model identifier (chronos-2 supports covariates)
|
| 42 |
device: Device for inference ('cuda' or 'cpu')
|
| 43 |
+
dtype: Data type for model weights (float32 for chronos-2)
|
| 44 |
"""
|
| 45 |
self.model_name = model_name
|
| 46 |
self.device = device
|
|
|
|
| 52 |
self._borders = None
|
| 53 |
|
| 54 |
def _load_model(self):
|
| 55 |
+
"""Load Chronos-2 model (cached after first call)"""
|
| 56 |
if self._pipeline is None:
|
| 57 |
print(f"Loading {self.model_name}...")
|
| 58 |
start_time = time.time()
|
|
|
|
| 63 |
"float32": torch.float32
|
| 64 |
}
|
| 65 |
|
| 66 |
+
self._pipeline = Chronos2Pipeline.from_pretrained(
|
| 67 |
self.model_name,
|
| 68 |
device_map=self.device,
|
| 69 |
+
torch_dtype=dtype_map.get(self.dtype, torch.float32)
|
| 70 |
)
|
| 71 |
|
| 72 |
print(f"Model loaded in {time.time() - start_time:.1f}s")
|
|
|
|
| 161 |
|
| 162 |
total_start = time.time()
|
| 163 |
|
| 164 |
+
# PER-BORDER INFERENCE WITH COVARIATES
|
| 165 |
+
# Using predict_df() API to leverage all 615 features (weather, generation, CNEC outages, etc.)
|
| 166 |
+
print(f"\n[COVARIATE FORECAST] Running inference for {len(forecast_borders)} borders with 615 features...")
|
| 167 |
+
print(f" Features: weather per zone, generation per zone, CNEC outages, LTA, load forecasts")
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
for i, border in enumerate(forecast_borders, 1):
|
| 170 |
+
border_start = time.time()
|
| 171 |
+
print(f"\n [{i}/{len(forecast_borders)}] {border}...", flush=True)
|
| 172 |
+
|
| 173 |
try:
|
| 174 |
+
# Extract data WITH covariates
|
| 175 |
context_data, future_data = forecaster.prepare_forecast_data(
|
| 176 |
run_date=run_datetime,
|
| 177 |
border=border
|
| 178 |
)
|
| 179 |
|
| 180 |
+
print(f" Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
|
| 181 |
+
print(f" Using {len(future_data.columns)-2} future covariates for multivariate forecast", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
+
# Run covariate-informed inference using DataFrame API
|
| 184 |
+
forecasts_df = pipeline.predict_df(
|
| 185 |
+
context_data, # Historical data with ALL features
|
| 186 |
+
future_df=future_data, # Future covariates (615 features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
prediction_length=prediction_hours,
|
| 188 |
+
id_column='border',
|
| 189 |
+
timestamp_column='timestamp',
|
| 190 |
+
target='target',
|
| 191 |
num_samples=num_samples
|
| 192 |
)
|
| 193 |
|
| 194 |
+
# Extract quantiles from probabilistic forecast
|
| 195 |
+
# predict_df returns samples - we need to compute quantiles
|
| 196 |
+
# The output format depends on Chronos2Pipeline implementation
|
| 197 |
+
# Typically returns DataFrame with columns per quantile or sample
|
| 198 |
+
|
| 199 |
+
# Convert to numpy for quantile calculation
|
| 200 |
+
if isinstance(forecasts_df, pd.DataFrame):
|
| 201 |
+
# Extract sample columns (format: sample_0, sample_1, ...)
|
| 202 |
+
sample_cols = [col for col in forecasts_df.columns if col.startswith('sample_')]
|
| 203 |
+
if sample_cols:
|
| 204 |
+
# Shape: (time, num_samples)
|
| 205 |
+
forecast_samples = forecasts_df[sample_cols].values
|
| 206 |
+
median = np.median(forecast_samples, axis=1)
|
| 207 |
+
q10 = np.quantile(forecast_samples, 0.1, axis=1)
|
| 208 |
+
q90 = np.quantile(forecast_samples, 0.9, axis=1)
|
| 209 |
+
else:
|
| 210 |
+
# Fallback: single prediction column
|
| 211 |
+
median = forecasts_df['prediction'].values if 'prediction' in forecasts_df.columns else forecasts_df.iloc[:, 0].values
|
| 212 |
+
q10 = median.copy() # No uncertainty if single prediction
|
| 213 |
+
q90 = median.copy()
|
| 214 |
+
else:
|
| 215 |
+
# Handle tensor output (fallback)
|
| 216 |
+
forecast_numpy = forecasts_df.numpy() if hasattr(forecasts_df, 'numpy') else np.array(forecasts_df)
|
| 217 |
+
if forecast_numpy.ndim == 2:
|
| 218 |
+
# (num_samples, time) or (time, num_samples)
|
| 219 |
+
if forecast_numpy.shape[0] == num_samples:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
median = np.median(forecast_numpy, axis=0)
|
| 221 |
q10 = np.quantile(forecast_numpy, 0.1, axis=0)
|
| 222 |
q90 = np.quantile(forecast_numpy, 0.9, axis=0)
|
| 223 |
+
else:
|
|
|
|
|
|
|
| 224 |
median = np.median(forecast_numpy, axis=1)
|
| 225 |
q10 = np.quantile(forecast_numpy, 0.1, axis=1)
|
| 226 |
q90 = np.quantile(forecast_numpy, 0.9, axis=1)
|
| 227 |
+
else:
|
| 228 |
+
median = forecast_numpy.flatten()
|
| 229 |
+
q10 = median.copy()
|
| 230 |
+
q90 = median.copy()
|
| 231 |
+
|
| 232 |
+
inference_time = time.time() - border_start
|
| 233 |
+
|
| 234 |
+
# Store results
|
| 235 |
+
results['borders'][border] = {
|
| 236 |
+
'median': median.tolist(),
|
| 237 |
+
'q10': q10.tolist(),
|
| 238 |
+
'q90': q90.tolist(),
|
| 239 |
+
'inference_time_s': inference_time,
|
| 240 |
+
'used_covariates': True,
|
| 241 |
+
'num_features': len(future_data.columns) - 2 # Exclude border and timestamp
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
print(f" [OK] Complete in {inference_time:.1f}s (WITH {len(future_data.columns)-2} covariates)", flush=True)
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
import traceback
|
| 248 |
+
error_msg = f"{type(e).__name__}: {str(e)}"
|
| 249 |
+
traceback_str = traceback.format_exc()
|
| 250 |
+
print(f" [ERROR] {error_msg}", flush=True)
|
| 251 |
+
print(f"Traceback:\n{traceback_str}", flush=True)
|
| 252 |
+
results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
|
| 253 |
|
| 254 |
# Add summary metadata
|
| 255 |
results['metadata']['total_time_s'] = time.time() - total_start
|