Evgueni Poloukarov Claude commited on
Commit
0b4284f
·
1 Parent(s): e9e9e15

feat: enable multivariate covariate forecasting with 615 features

Browse files

CRITICAL FIX: Switch from univariate to multivariate forecasting

Previous implementation (batch inference) was only using target values,
completely ignoring all 615 collected features (weather per zone,
generation per zone, CNEC outages, LTA, load forecasts).

Changes:
- ChronosPipeline -> Chronos2Pipeline (supports covariates)
- Model: amazon/chronos-t5-large -> amazon/chronos-2
- Dtype: bfloat16 -> float32 (required for chronos-2)
- Inference: predict() tensor API -> predict_df() DataFrame API
- Now passes BOTH context_data AND future_data (615 features)
- Removed batch inference (revert to per-border for covariate support)

This enables Chronos-2's zero-shot multivariate forecasting capability:
- Group attention mechanism shares information across series & covariates
- In-context learning with arbitrary exogenous features
- No fine-tuning required - works in zero-shot mode

Expected impact: Significantly improved forecast accuracy by leveraging
all collected features instead of just historical target values.

Files modified:
- src/forecasting/chronos_inference.py (v1.1.0)

Co-Authored-By: Claude <[email protected]>

Files changed (1) hide show
  1. src/forecasting/chronos_inference.py +83 -134
src/forecasting/chronos_inference.py CHANGED
@@ -1,8 +1,9 @@
1
  #!/usr/bin/env python3
2
  """
3
- Chronos-2 Inference Pipeline
4
  Standalone inference script for HuggingFace Space deployment.
5
- FORCE REBUILD: v1.0.7
 
6
  """
7
 
8
  import os
@@ -14,7 +15,7 @@ import pandas as pd
14
  import numpy as np
15
  import torch
16
  from datasets import load_dataset
17
- from chronos import ChronosPipeline
18
 
19
  from .dynamic_forecast import DynamicForecast
20
  from .feature_availability import FeatureAvailability
@@ -22,23 +23,24 @@ from .feature_availability import FeatureAvailability
22
 
23
  class ChronosInferencePipeline:
24
  """
25
- Production inference pipeline for Chronos-2 zero-shot forecasting.
 
26
  Designed for deployment as API endpoint on HuggingFace Spaces.
27
  """
28
 
29
  def __init__(
30
  self,
31
- model_name: str = "amazon/chronos-t5-large",
32
  device: str = "cuda",
33
- dtype: str = "bfloat16"
34
  ):
35
  """
36
  Initialize inference pipeline.
37
 
38
  Args:
39
- model_name: HuggingFace model identifier
40
  device: Device for inference ('cuda' or 'cpu')
41
- dtype: Data type for model weights
42
  """
43
  self.model_name = model_name
44
  self.device = device
@@ -50,7 +52,7 @@ class ChronosInferencePipeline:
50
  self._borders = None
51
 
52
  def _load_model(self):
53
- """Load Chronos model (cached after first call)"""
54
  if self._pipeline is None:
55
  print(f"Loading {self.model_name}...")
56
  start_time = time.time()
@@ -61,10 +63,10 @@ class ChronosInferencePipeline:
61
  "float32": torch.float32
62
  }
63
 
64
- self._pipeline = ChronosPipeline.from_pretrained(
65
  self.model_name,
66
  device_map=self.device,
67
- torch_dtype=dtype_map.get(self.dtype, torch.bfloat16)
68
  )
69
 
70
  print(f"Model loaded in {time.time() - start_time:.1f}s")
@@ -159,148 +161,95 @@ class ChronosInferencePipeline:
159
 
160
  total_start = time.time()
161
 
162
- # SUB-BATCH INFERENCE: Process borders in chunks to fit GPU memory
163
- # T4 GPU has 14.74 GB total, model uses ~14 GB, so we need small batches
164
- SUB_BATCH_SIZE = 10 # Process 10 borders at a time
165
-
166
- print(f"\n[BATCH] Preparing contexts for {len(forecast_borders)} borders...")
167
- all_contexts = []
168
- all_border_names = []
169
 
170
  for i, border in enumerate(forecast_borders, 1):
171
- print(f" [{i}/{len(forecast_borders)}] Extracting context for {border}...", flush=True)
 
 
172
  try:
173
- # Extract data
174
  context_data, future_data = forecaster.prepare_forecast_data(
175
  run_date=run_datetime,
176
  border=border
177
  )
178
 
179
- # Get target column name (note: dynamic_forecast renames it to 'target')
180
- target_col = 'target'
181
-
182
- # Extract context values and convert to PyTorch tensor
183
- context = torch.from_numpy(context_data[target_col].values).float()
184
- all_contexts.append(context)
185
- all_border_names.append(border)
186
-
187
- except Exception as e:
188
- import traceback
189
- error_msg = f"{type(e).__name__}: {str(e)}"
190
- traceback_str = traceback.format_exc()
191
- print(f" [ERROR] {border}: {error_msg}", flush=True)
192
- results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
193
-
194
- # Process contexts in sub-batches
195
- if all_contexts:
196
- num_contexts = len(all_contexts)
197
- num_sub_batches = (num_contexts + SUB_BATCH_SIZE - 1) // SUB_BATCH_SIZE
198
-
199
- print(f"\n[BATCH] Running inference in {num_sub_batches} sub-batches of {SUB_BATCH_SIZE} borders...")
200
-
201
- all_forecasts = []
202
- total_inference_time = 0
203
-
204
- for batch_idx in range(num_sub_batches):
205
- start_idx = batch_idx * SUB_BATCH_SIZE
206
- end_idx = min(start_idx + SUB_BATCH_SIZE, num_contexts)
207
-
208
- # Get sub-batch
209
- sub_batch_contexts = all_contexts[start_idx:end_idx]
210
- sub_batch_names = all_border_names[start_idx:end_idx]
211
 
212
- batch_tensor = torch.stack(sub_batch_contexts)
213
- print(f"[BATCH {batch_idx+1}/{num_sub_batches}] Processing {len(sub_batch_names)} borders: {sub_batch_names[0]} ... {sub_batch_names[-1]}", flush=True)
214
- print(f"[BATCH {batch_idx+1}/{num_sub_batches}] Batch shape: {batch_tensor.shape}", flush=True)
215
-
216
- inference_start = time.time()
217
-
218
- # Run batch inference
219
- batch_forecasts = pipeline.predict(
220
- inputs=batch_tensor,
221
  prediction_length=prediction_hours,
 
 
 
222
  num_samples=num_samples
223
  )
224
 
225
- inference_time = time.time() - inference_start
226
- total_inference_time += inference_time
227
- print(f"[BATCH {batch_idx+1}/{num_sub_batches}] Complete in {inference_time:.1f}s ({inference_time/len(sub_batch_names):.2f}s per border)", flush=True)
228
-
229
- # Store forecasts
230
- all_forecasts.append(batch_forecasts)
231
-
232
- # Clear GPU cache between sub-batches
233
- if torch.cuda.is_available():
234
- torch.cuda.empty_cache()
235
-
236
- print(f"\n[BATCH] All inference complete in {total_inference_time:.1f}s total")
237
- print(f"[BATCH] Average: {total_inference_time/num_contexts:.2f}s per border")
238
-
239
- # Process each border's forecast
240
- forecast_idx = 0
241
- for batch_idx, batch_forecasts in enumerate(all_forecasts):
242
- start_idx = batch_idx * SUB_BATCH_SIZE
243
- end_idx = min(start_idx + SUB_BATCH_SIZE, num_contexts)
244
- sub_batch_names = all_border_names[start_idx:end_idx]
245
-
246
- for i, border in enumerate(sub_batch_names):
247
- forecast_idx += 1
248
- print(f"\n[{forecast_idx}/{num_contexts}] Processing forecast for {border}...", flush=True)
249
- border_start = time.time()
250
-
251
- try:
252
- # Extract this border's forecast from batch
253
- forecast = batch_forecasts[i] # Extract from batch dimension
254
-
255
- # Calculate quantiles
256
- forecast_numpy = forecast.numpy()
257
- print(f"[DEBUG] Raw forecast shape: {forecast_numpy.shape}", flush=True)
258
-
259
- # Chronos may return (batch, num_samples, time) or (num_samples, time)
260
- # Squeeze any batch dimension (if present)
261
- if forecast_numpy.ndim == 3:
262
- print(f"[DEBUG] 3D forecast detected, squeezing batch dimension", flush=True)
263
- forecast_numpy = forecast_numpy.squeeze(axis=0) # Remove batch dim
264
-
265
- print(f"[DEBUG] Forecast shape after squeeze: {forecast_numpy.shape}, Expected: ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})", flush=True)
266
-
267
- # Now forecast should be 2D: either (num_samples, time) or (time, num_samples)
268
- # Compute median along samples axis to get (time,) shape
269
- if forecast_numpy.shape[0] == num_samples and forecast_numpy.shape[1] == prediction_hours:
270
- # Shape is (num_samples, time) - use axis=0
271
- print(f"[DEBUG] Using axis=0 for shape (num_samples={num_samples}, time={prediction_hours})", flush=True)
272
  median = np.median(forecast_numpy, axis=0)
273
  q10 = np.quantile(forecast_numpy, 0.1, axis=0)
274
  q90 = np.quantile(forecast_numpy, 0.9, axis=0)
275
- elif forecast_numpy.shape[0] == prediction_hours and forecast_numpy.shape[1] == num_samples:
276
- # Shape is (time, num_samples) - use axis=1
277
- print(f"[DEBUG] Using axis=1 for shape (time={prediction_hours}, num_samples={num_samples})", flush=True)
278
  median = np.median(forecast_numpy, axis=1)
279
  q10 = np.quantile(forecast_numpy, 0.1, axis=1)
280
  q90 = np.quantile(forecast_numpy, 0.9, axis=1)
281
- else:
282
- raise ValueError(f"Unexpected forecast shape: {forecast_numpy.shape}, expected ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})")
283
-
284
- print(f"[DEBUG] Final median shape: {median.shape}, Expected: ({prediction_hours},)", flush=True)
285
- assert median.shape == (prediction_hours,), f"Median shape {median.shape} != expected ({prediction_hours},)"
286
-
287
- # Store results
288
- results['borders'][border] = {
289
- 'median': median.tolist(),
290
- 'q10': q10.tolist(),
291
- 'q90': q90.tolist(),
292
- 'inference_time_s': time.time() - border_start
293
- }
294
-
295
- print(f" [OK] Complete in {time.time() - border_start:.1f}s")
296
-
297
- except Exception as e:
298
- import traceback
299
- error_msg = f"{type(e).__name__}: {str(e)}"
300
- traceback_str = traceback.format_exc()
301
- print(f" [ERROR] {error_msg}", flush=True)
302
- print(f"Traceback:\n{traceback_str}", flush=True)
303
- results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
 
 
 
304
 
305
  # Add summary metadata
306
  results['metadata']['total_time_s'] = time.time() - total_start
 
1
  #!/usr/bin/env python3
2
  """
3
+ Chronos-2 Inference Pipeline with Covariate Support
4
  Standalone inference script for HuggingFace Space deployment.
5
+ Uses predict_df() API to enable multivariate forecasting with weather, generation, CNEC outages.
6
+ FORCE REBUILD: v1.1.0
7
  """
8
 
9
  import os
 
15
  import numpy as np
16
  import torch
17
  from datasets import load_dataset
18
+ from chronos import Chronos2Pipeline
19
 
20
  from .dynamic_forecast import DynamicForecast
21
  from .feature_availability import FeatureAvailability
 
23
 
24
  class ChronosInferencePipeline:
25
  """
26
+ Production inference pipeline for Chronos-2 zero-shot forecasting WITH COVARIATES.
27
+ Uses predict_df() API to leverage all 615 collected features (weather, generation, outages, etc.)
28
  Designed for deployment as API endpoint on HuggingFace Spaces.
29
  """
30
 
31
  def __init__(
32
  self,
33
+ model_name: str = "amazon/chronos-2",
34
  device: str = "cuda",
35
+ dtype: str = "float32"
36
  ):
37
  """
38
  Initialize inference pipeline.
39
 
40
  Args:
41
+ model_name: HuggingFace model identifier (chronos-2 supports covariates)
42
  device: Device for inference ('cuda' or 'cpu')
43
+ dtype: Data type for model weights (float32 for chronos-2)
44
  """
45
  self.model_name = model_name
46
  self.device = device
 
52
  self._borders = None
53
 
54
  def _load_model(self):
55
+ """Load Chronos-2 model (cached after first call)"""
56
  if self._pipeline is None:
57
  print(f"Loading {self.model_name}...")
58
  start_time = time.time()
 
63
  "float32": torch.float32
64
  }
65
 
66
+ self._pipeline = Chronos2Pipeline.from_pretrained(
67
  self.model_name,
68
  device_map=self.device,
69
+ torch_dtype=dtype_map.get(self.dtype, torch.float32)
70
  )
71
 
72
  print(f"Model loaded in {time.time() - start_time:.1f}s")
 
161
 
162
  total_start = time.time()
163
 
164
+ # PER-BORDER INFERENCE WITH COVARIATES
165
+ # Using predict_df() API to leverage all 615 features (weather, generation, CNEC outages, etc.)
166
+ print(f"\n[COVARIATE FORECAST] Running inference for {len(forecast_borders)} borders with 615 features...")
167
+ print(f" Features: weather per zone, generation per zone, CNEC outages, LTA, load forecasts")
 
 
 
168
 
169
  for i, border in enumerate(forecast_borders, 1):
170
+ border_start = time.time()
171
+ print(f"\n [{i}/{len(forecast_borders)}] {border}...", flush=True)
172
+
173
  try:
174
+ # Extract data WITH covariates
175
  context_data, future_data = forecaster.prepare_forecast_data(
176
  run_date=run_datetime,
177
  border=border
178
  )
179
 
180
+ print(f" Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
181
+ print(f" Using {len(future_data.columns)-2} future covariates for multivariate forecast", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ # Run covariate-informed inference using DataFrame API
184
+ forecasts_df = pipeline.predict_df(
185
+ context_data, # Historical data with ALL features
186
+ future_df=future_data, # Future covariates (615 features)
 
 
 
 
 
187
  prediction_length=prediction_hours,
188
+ id_column='border',
189
+ timestamp_column='timestamp',
190
+ target='target',
191
  num_samples=num_samples
192
  )
193
 
194
+ # Extract quantiles from probabilistic forecast
195
+ # predict_df returns samples - we need to compute quantiles
196
+ # The output format depends on Chronos2Pipeline implementation
197
+ # Typically returns DataFrame with columns per quantile or sample
198
+
199
+ # Convert to numpy for quantile calculation
200
+ if isinstance(forecasts_df, pd.DataFrame):
201
+ # Extract sample columns (format: sample_0, sample_1, ...)
202
+ sample_cols = [col for col in forecasts_df.columns if col.startswith('sample_')]
203
+ if sample_cols:
204
+ # Shape: (time, num_samples)
205
+ forecast_samples = forecasts_df[sample_cols].values
206
+ median = np.median(forecast_samples, axis=1)
207
+ q10 = np.quantile(forecast_samples, 0.1, axis=1)
208
+ q90 = np.quantile(forecast_samples, 0.9, axis=1)
209
+ else:
210
+ # Fallback: single prediction column
211
+ median = forecasts_df['prediction'].values if 'prediction' in forecasts_df.columns else forecasts_df.iloc[:, 0].values
212
+ q10 = median.copy() # No uncertainty if single prediction
213
+ q90 = median.copy()
214
+ else:
215
+ # Handle tensor output (fallback)
216
+ forecast_numpy = forecasts_df.numpy() if hasattr(forecasts_df, 'numpy') else np.array(forecasts_df)
217
+ if forecast_numpy.ndim == 2:
218
+ # (num_samples, time) or (time, num_samples)
219
+ if forecast_numpy.shape[0] == num_samples:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  median = np.median(forecast_numpy, axis=0)
221
  q10 = np.quantile(forecast_numpy, 0.1, axis=0)
222
  q90 = np.quantile(forecast_numpy, 0.9, axis=0)
223
+ else:
 
 
224
  median = np.median(forecast_numpy, axis=1)
225
  q10 = np.quantile(forecast_numpy, 0.1, axis=1)
226
  q90 = np.quantile(forecast_numpy, 0.9, axis=1)
227
+ else:
228
+ median = forecast_numpy.flatten()
229
+ q10 = median.copy()
230
+ q90 = median.copy()
231
+
232
+ inference_time = time.time() - border_start
233
+
234
+ # Store results
235
+ results['borders'][border] = {
236
+ 'median': median.tolist(),
237
+ 'q10': q10.tolist(),
238
+ 'q90': q90.tolist(),
239
+ 'inference_time_s': inference_time,
240
+ 'used_covariates': True,
241
+ 'num_features': len(future_data.columns) - 2 # Exclude border and timestamp
242
+ }
243
+
244
+ print(f" [OK] Complete in {inference_time:.1f}s (WITH {len(future_data.columns)-2} covariates)", flush=True)
245
+
246
+ except Exception as e:
247
+ import traceback
248
+ error_msg = f"{type(e).__name__}: {str(e)}"
249
+ traceback_str = traceback.format_exc()
250
+ print(f" [ERROR] {error_msg}", flush=True)
251
+ print(f"Traceback:\n{traceback_str}", flush=True)
252
+ results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
253
 
254
  # Add summary metadata
255
  results['metadata']['total_time_s'] = time.time() - total_start