#!/usr/bin/env python3
"""
Chronos-2 Inference Pipeline with Past-Only Covariate Masking
Standalone inference script for HuggingFace Space deployment.
Uses the predict_df() API with ALL 2,514 features, leveraging Chronos-2's mask-based attention.
FORCE REBUILD: v1.6.0 - Context window reduced to 1,125 hours (~47 days, 1.5 months) to fit an A100-80GB
"""
import gc
import os
import time
from typing import List, Dict, Optional
from datetime import datetime, timedelta
# CRITICAL: Set PyTorch memory allocator config BEFORE importing torch
# This prevents memory fragmentation issues that cause OOM even with sufficient free memory
# See: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import polars as pl
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset
from chronos import Chronos2Pipeline
from .dynamic_forecast import DynamicForecast
from .feature_availability import FeatureAvailability
class ChronosInferencePipeline:
"""
Production inference pipeline for Chronos-2 zero-shot forecasting WITH PAST-ONLY MASKING.
    Uses predict_df() API with ALL 2,514 features (known-future + past-only covariates).
Past-only covariates (CNEC, volatility, historical flows) are masked in future → model
learns cross-feature correlations from historical context via attention mechanism.
Designed for deployment as API endpoint on HuggingFace Spaces.
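
    Minimal usage sketch (illustrative only: the date and border name are
    placeholders; real border names come from the dataset's target_border_*
    columns):

        pipeline = ChronosInferencePipeline()
        results = pipeline.run_forecast("2025-06-02", borders=["AT-DE"], forecast_days=7)
        pipeline.export_to_parquet(results, "/tmp/forecast.parquet")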
"""
def __init__(
self,
model_name: str = "amazon/chronos-2",
device: str = "cuda",
dtype: str = "bfloat16"
):
"""
Initialize inference pipeline.
Args:
model_name: HuggingFace model identifier (chronos-2 supports covariates)
device: Device for inference ('cuda' or 'cpu')
dtype: Data type for model weights (bfloat16 for memory efficiency)
"""
self.model_name = model_name
self.device = device
self.dtype = dtype
# Model loaded on first inference (lazy loading)
self._pipeline = None
self._dataset = None
self._borders = None
def _load_model(self):
"""Load Chronos-2 model (cached after first call)"""
if self._pipeline is None:
print(f"Loading {self.model_name}...")
start_time = time.time()
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32
}
self._pipeline = Chronos2Pipeline.from_pretrained(
self.model_name,
device_map="auto", # Auto-distribute across all available GPUs
torch_dtype=dtype_map.get(self.dtype, torch.float32)
)
# Set model to evaluation mode (disables dropout, etc.)
self._pipeline.model.eval()
print(f"Model loaded in {time.time() - start_time:.1f}s")
print(f" Device: {next(self._pipeline.model.parameters()).device}")
# GPU detection and memory profiling diagnostics
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
total_vram = sum(torch.cuda.get_device_properties(i).total_memory for i in range(gpu_count))
print(f" [GPU] Detected {gpu_count} GPU(s)")
print(f" [GPU] Total VRAM: {total_vram/1e9:.1f} GB")
print(f" [MEMORY] After model load:")
print(f" GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f" GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")
return self._pipeline
def _load_dataset(self):
"""Load dataset from HuggingFace (cached after first call)"""
if self._dataset is None:
print("Loading dataset from HuggingFace...")
start_time = time.time()
hf_token = os.getenv("HF_TOKEN")
dataset = load_dataset(
"evgueni-p/fbmc-features-24month",
split="train",
token=hf_token
)
# Convert to Polars
self._dataset = pl.from_arrow(dataset.data.table)
# Extract available borders
target_cols = [col for col in self._dataset.columns if col.startswith('target_border_')]
self._borders = [col.replace('target_border_', '') for col in target_cols]
print(f"Dataset loaded in {time.time() - start_time:.1f}s")
print(f" Shape: {self._dataset.shape}")
print(f" Borders: {len(self._borders)}")
# Memory profiling diagnostics
if torch.cuda.is_available():
print(f" [MEMORY] After dataset load:")
print(f" GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f" GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")
return self._dataset, self._borders
def run_forecast(
self,
run_date: str,
borders: Optional[List[str]] = None,
forecast_days: int = 7,
context_hours: int = 1125, # 1,125 hours = 46.9 days (1.5 months, fits A100-80GB)
num_samples: int = 20
) -> Dict:
"""
Run zero-shot forecast for specified borders.
Args:
run_date: Forecast run date (YYYY-MM-DD format)
borders: List of borders to forecast (None = all borders)
forecast_days: Forecast horizon in days (7 or 14)
context_hours: Historical context window
num_samples: Number of probabilistic samples
Returns:
Dictionary with forecast results and metadata
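
        Example of the returned structure (values elided; keys match what this
        method builds below):

            {
                'run_date': '<YYYY-MM-DD>',
                'forecast_days': 7,
                'borders': {
                    '<border>': {'median': [...], 'q10': [...], 'q90': [...],
                                 'q01': [...], 'q05': [...], 'q25': [...],
                                 'q75': [...], 'q95': [...], 'q99': [...],
                                 'inference_time_s': ..., 'used_covariates': True,
                                 'num_features': ...}
                },
                'metadata': {'model': ..., 'device': ..., 'num_samples': ...,
                             'context_hours': ..., 'total_time_s': ...,
                             'successful_borders': ...}
            }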
"""
# Load model and dataset (cached)
pipeline = self._load_model()
df, all_borders = self._load_dataset()
        # Parse run date: anchor the context at 23:00 so the forecast horizon
        # starts the next day at 00:00 (see export_to_parquet)
        run_datetime = datetime.strptime(run_date, "%Y-%m-%d")
        run_datetime = run_datetime.replace(hour=23, minute=0)
# Determine borders to forecast
forecast_borders = borders if borders else all_borders
prediction_hours = forecast_days * 24
print(f"\nForecast configuration:")
print(f" Run date: {run_datetime}")
print(f" Borders: {len(forecast_borders)}")
print(f" Forecast horizon: {forecast_days} days ({prediction_hours} hours)")
print(f" Context window: {context_hours} hours")
# Initialize dynamic forecast system
forecaster = DynamicForecast(
dataset=df,
context_hours=context_hours,
forecast_hours=prediction_hours
)
# Run forecasts for each border
results = {
'run_date': run_date,
'forecast_days': forecast_days,
'borders': {},
'metadata': {
'model': self.model_name,
'device': self.device,
'num_samples': num_samples,
'context_hours': context_hours
}
}
total_start = time.time()
# PER-BORDER INFERENCE WITH PAST-ONLY COVARIATE MASKING
# Using predict_df() API with ALL 2,514 features (known-future + past-only masked)
print(f"\n[PAST-ONLY MASKING] Running inference for {len(forecast_borders)} borders with 2,514 features...")
print(f" Known-future: weather, generation, load forecasts (615 features)")
print(f" Past-only masked: CNEC outages, volatility, historical flows (1,899 features)")
for i, border in enumerate(forecast_borders, 1):
            # Clear GPU memory BEFORE each border to prevent accumulation across iterations.
            # This releases tensors from the previous border (no-op on the first iteration).
            # Does NOT affect model weights (120M params stay loaded).
            # Does NOT affect forecast accuracy (each border is independent).
            if i > 1:  # Skip on first border (clean GPU state)
                gc.collect()  # Drop Python references first so their tensors become collectable
                torch.cuda.empty_cache()  # Then release unoccupied cached blocks back to the GPU
border_start = time.time()
print(f"\n [{i}/{len(forecast_borders)}] {border}...", flush=True)
try:
# Extract data WITH covariates
context_data, future_data = forecaster.prepare_forecast_data(
run_date=run_datetime,
border=border
)
print(f" Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
print(f" Using {len(future_data.columns)-2} features (known-future + past-only masked)", flush=True)
# Run covariate-informed inference using DataFrame API
# Note: predict_df() returns quantiles directly
# Request 9 quantiles to capture learned uncertainty and tail events
# Use torch.inference_mode() to disable gradient tracking (saves ~2-5 GB VRAM)
with torch.inference_mode():
forecasts_df = pipeline.predict_df(
context_data, # Historical data with ALL features
                        future_df=future_data,  # All 2,514 features (past-only masked)
prediction_length=prediction_hours,
id_column='border',
timestamp_column='timestamp',
target='target',
batch_size=32, # Reduced from 64 (41.57GB -> 20.79GB attention tensor to fit single GPU)
quantile_levels=[0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99] # 9 quantiles for volatility
)
# Extract all 9 quantiles from predict_df() output
# predict_df() returns quantiles directly as string columns
if isinstance(forecasts_df, pd.DataFrame):
# Expected columns: '0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99'
quantile_cols = ['0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99']
# Extract all quantiles
quantiles = {}
for q in quantile_cols:
if q in forecasts_df.columns:
quantiles[q] = forecasts_df[q].values
else:
# Fallback if quantile missing
if '0.5' in forecasts_df.columns:
quantiles[q] = forecasts_df['0.5'].values # Use median as fallback
elif 'predictions' in forecasts_df.columns:
quantiles[q] = forecasts_df['predictions'].values
else:
raise ValueError(f"Missing quantile {q} and no fallback available. Columns: {forecasts_df.columns.tolist()}")
# Backward compatibility: still extract median, q10, q90
median = quantiles['0.5']
q10 = quantiles['0.1']
q90 = quantiles['0.9']
else:
raise TypeError(f"Expected DataFrame from predict_df(), got {type(forecasts_df)}")
# Round all quantiles to nearest integer (capacity values are always whole MW)
median = np.round(median).astype(int)
q10 = np.round(q10).astype(int)
q90 = np.round(q90).astype(int)
# Round all other quantiles
for q_key in quantiles:
quantiles[q_key] = np.round(quantiles[q_key]).astype(int)
inference_time = time.time() - border_start
# Store results (backward compatible + all quantiles)
results['borders'][border] = {
'median': median.tolist(),
'q10': q10.tolist(),
'q90': q90.tolist(),
# Add all 9 quantiles for adaptive selection
'q01': quantiles['0.01'].tolist(),
'q05': quantiles['0.05'].tolist(),
'q25': quantiles['0.25'].tolist(),
'q75': quantiles['0.75'].tolist(),
'q95': quantiles['0.95'].tolist(),
'q99': quantiles['0.99'].tolist(),
'inference_time_s': inference_time,
'used_covariates': True,
'num_features': len(future_data.columns) - 2 # Exclude border and timestamp
}
print(f" [OK] Complete in {inference_time:.1f}s ({len(future_data.columns)-2} features with past-only masking)", flush=True)
except Exception as e:
import traceback
error_msg = f"{type(e).__name__}: {str(e)}"
traceback_str = traceback.format_exc()
print(f" [ERROR] {error_msg}", flush=True)
print(f"Traceback:\n{traceback_str}", flush=True)
results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
# Add summary metadata
results['metadata']['total_time_s'] = time.time() - total_start
results['metadata']['successful_borders'] = sum(
1 for b in results['borders'].values() if 'error' not in b
)
print(f"\n{'='*60}")
print(f"FORECAST COMPLETE")
print(f"{'='*60}")
print(f"Total time: {results['metadata']['total_time_s']:.1f}s")
print(f"Successful: {results['metadata']['successful_borders']}/{len(forecast_borders)} borders")
return results
def export_to_parquet(self, results: Dict, output_path: str):
"""
Export forecast results to parquet format.
Args:
results: Forecast results from run_forecast()
output_path: Path to save parquet file
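
        Output layout (one row per forecast hour): a 'timestamp' column plus,
        per successful border, '{border}_median', '{border}_q10', '{border}_q90',
        and '{border}_adaptive' when an adaptive forecast is present.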
"""
# Create forecast timestamps
run_datetime = datetime.strptime(results['run_date'], "%Y-%m-%d")
forecast_start = run_datetime + timedelta(days=1) # Next day at midnight, not +1 hour
forecast_hours = results['forecast_days'] * 24
timestamps = [
forecast_start + timedelta(hours=h)
for h in range(forecast_hours)
]
# Build DataFrame
data = {'timestamp': timestamps}
successful_borders = []
failed_borders = []
for border, forecast_data in results['borders'].items():
if 'error' not in forecast_data:
data[f'{border}_median'] = forecast_data['median']
data[f'{border}_q10'] = forecast_data['q10']
data[f'{border}_q90'] = forecast_data['q90']
# Add adaptive forecast if available (learned uncertainty-based selection)
if 'adaptive' in forecast_data:
data[f'{border}_adaptive'] = forecast_data['adaptive']
successful_borders.append(border)
else:
failed_borders.append((border, forecast_data['error']))
# Log results
print(f"[EXPORT] Forecast export summary:", flush=True)
print(f" Successful: {len(successful_borders)} borders", flush=True)
print(f" Failed: {len(failed_borders)} borders", flush=True)
if failed_borders:
print(f"[EXPORT] Errors:", flush=True)
for border, error in failed_borders:
print(f" {border}: {error}", flush=True)
df = pl.DataFrame(data)
df.write_parquet(output_path)
print(f"[EXPORT] Exported to: {output_path}", flush=True)
print(f"[EXPORT] Shape: {df.shape}, Columns: {len(df.columns)}", flush=True)
return output_path
# Convenience function for API usage
def run_inference(
run_date: str,
forecast_type: str = "smoke_test",
borders: Optional[List[str]] = None,
output_dir: str = "/tmp"
) -> str:
"""
Run forecast and return path to results file.
Args:
run_date: Forecast run date (YYYY-MM-DD)
forecast_type: 'smoke_test' (7 days, 1 border) or 'full_14day' (14 days, all borders)
borders: Specific borders to forecast (None = use forecast_type defaults)
output_dir: Directory to save results
Returns:
Path to forecast results parquet file
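
    The parquet path follows 'forecast_{run_date}_{forecast_type}.parquet' under
    output_dir; if no border succeeds, the debug text file
    ('debug_{run_date}_{forecast_type}.txt') is returned instead.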
"""
# Initialize pipeline
pipeline = ChronosInferencePipeline()
# Configure based on forecast type
if forecast_type == "smoke_test":
forecast_days = 7
if borders is None:
# Load just to get first border
_, all_borders = pipeline._load_dataset()
borders = [all_borders[0]]
else: # full_14day
forecast_days = 14
# borders = None means all borders
# Run forecast
results = pipeline.run_forecast(
run_date=run_date,
borders=borders,
forecast_days=forecast_days
)
# Write debug file
debug_filename = f"debug_{run_date}_{forecast_type}.txt"
debug_path = os.path.join(output_dir, debug_filename)
with open(debug_path, 'w') as f:
f.write(f"Results summary:\n")
f.write(f" Run date: {results['run_date']}\n")
f.write(f" Forecast days: {results['forecast_days']}\n")
f.write(f" Borders in results: {list(results['borders'].keys())}\n\n")
for border, data in results['borders'].items():
if 'error' in data:
f.write(f" {border}: ERROR - {data['error']}\n")
if 'traceback' in data:
f.write(f"\nFull Traceback:\n{data['traceback']}\n")
else:
f.write(f" {border}: OK\n")
f.write(f" median count: {len(data.get('median', []))}\n")
f.write(f" q10 count: {len(data.get('q10', []))}\n")
f.write(f" q90 count: {len(data.get('q90', []))}\n")
print(f"Debug file written to: {debug_path}", flush=True)
# Export to parquet
output_filename = f"forecast_{run_date}_{forecast_type}.parquet"
output_path = os.path.join(output_dir, output_filename)
pipeline.export_to_parquet(results, output_path)
# Check if forecast has data, if not return debug file
successful_count = sum(1 for data in results['borders'].values() if 'error' not in data)
if successful_count == 0:
print(f"[WARNING] No successful forecasts! Returning debug file instead.", flush=True)
return debug_path
return output_path
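

# Illustrative entry point for ad-hoc local runs, assuming the package layout
# seen here (run as `python -m src.forecasting.chronos_inference` because of the
# relative imports above). The run date is a placeholder, not a tested value.
if __name__ == "__main__":
    result_path = run_inference(
        run_date="2025-06-02",
        forecast_type="smoke_test"  # 7-day horizon, first available border only
    )
    print(f"Results written to: {result_path}")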