#!/usr/bin/env python3
"""
Chronos-2 Inference Pipeline with Past-Only Covariate Masking

Standalone inference script for HuggingFace Space deployment.
Uses the predict_df() API with ALL 2,514 features, leveraging Chronos-2's
mask-based attention.

FORCE REBUILD: v1.6.0 - Extended context window (2,160 hours = 90 days)
optimized for 96GB VRAM
"""

import gc
import os
import time
import traceback
from typing import List, Dict, Optional
from datetime import datetime, timedelta

# CRITICAL: Set the PyTorch memory allocator config BEFORE importing torch.
# This prevents memory fragmentation issues that cause OOM even with
# sufficient free memory.
# See: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import polars as pl
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset
from chronos import Chronos2Pipeline

from .dynamic_forecast import DynamicForecast
from .feature_availability import FeatureAvailability


class ChronosInferencePipeline:
    """
    Production inference pipeline for Chronos-2 zero-shot forecasting
    WITH PAST-ONLY MASKING.

    Uses the predict_df() API with ALL 2,514 features (known-future +
    past-only covariates). Past-only covariates (CNEC, volatility,
    historical flows) are masked in the future, so the model learns
    cross-feature correlations from historical context via its attention
    mechanism.

    Designed for deployment as an API endpoint on HuggingFace Spaces.
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device: str = "cuda",
        dtype: str = "bfloat16"
    ):
        """
        Initialize the inference pipeline.

        Args:
            model_name: HuggingFace model identifier (chronos-2 supports covariates)
            device: Device for inference ('cuda' or 'cpu')
            dtype: Data type for model weights (bfloat16 for memory efficiency)
        """
        self.model_name = model_name
        self.device = device
        self.dtype = dtype

        # Model and dataset loaded on first use (lazy loading)
        self._pipeline = None
        self._dataset = None
        self._borders = None

    def _load_model(self):
        """Load the Chronos-2 model (cached after first call)."""
        if self._pipeline is None:
            print(f"Loading {self.model_name}...")
            start_time = time.time()

            dtype_map = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
                "float32": torch.float32
            }
            self._pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map="auto",  # Auto-distribute across all available GPUs
                torch_dtype=dtype_map.get(self.dtype, torch.float32)
            )

            # Set model to evaluation mode (disables dropout, etc.)
            self._pipeline.model.eval()

            print(f"Model loaded in {time.time() - start_time:.1f}s")
            print(f"  Device: {next(self._pipeline.model.parameters()).device}")

            # GPU detection and memory profiling diagnostics
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                total_vram = sum(
                    torch.cuda.get_device_properties(i).total_memory
                    for i in range(gpu_count)
                )
                print(f"  [GPU] Detected {gpu_count} GPU(s)")
                print(f"  [GPU] Total VRAM: {total_vram/1e9:.1f} GB")
                print(f"  [MEMORY] After model load:")
                print(f"    GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                print(f"    GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

        return self._pipeline
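    # The class docstring above describes past-only masking; the sketch below
    # is illustrative only and is NOT called by the pipeline. It assumes the
    # NaN-masking convention (past-only columns present in the future frame
    # but blanked out); the actual frame preparation lives in
    # DynamicForecast.prepare_forecast_data().
    @staticmethod
    def _sketch_past_only_masking(
        future_df: pd.DataFrame,
        past_only_cols: List[str]
    ) -> pd.DataFrame:
        """Hypothetical helper: blank out past-only covariates in the future frame."""
        masked = future_df.copy()
        # Past-only covariates carry no future information, so the model must
        # infer their effect from historical context via attention.
        masked[past_only_cols] = np.nan
        return masked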
    def _load_dataset(self):
        """Load dataset from HuggingFace (cached after first call)."""
        if self._dataset is None:
            print("Loading dataset from HuggingFace...")
            start_time = time.time()

            hf_token = os.getenv("HF_TOKEN")
            dataset = load_dataset(
                "evgueni-p/fbmc-features-24month",
                split="train",
                token=hf_token
            )

            # Convert to Polars
            self._dataset = pl.from_arrow(dataset.data.table)

            # Extract available borders
            target_cols = [
                col for col in self._dataset.columns
                if col.startswith('target_border_')
            ]
            self._borders = [col.replace('target_border_', '') for col in target_cols]

            print(f"Dataset loaded in {time.time() - start_time:.1f}s")
            print(f"  Shape: {self._dataset.shape}")
            print(f"  Borders: {len(self._borders)}")

            # Memory profiling diagnostics
            if torch.cuda.is_available():
                print(f"  [MEMORY] After dataset load:")
                print(f"    GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                print(f"    GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

        return self._dataset, self._borders
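    # Rough arithmetic behind the batch_size=32 choice in run_forecast()
    # below (a back-of-envelope sketch, not a profiled measurement): the
    # in-code comment reports a 41.57 GB attention tensor at batch size 64
    # and 20.79 GB at 32, i.e. memory scaling roughly linearly with batch
    # size. This helper is hypothetical and not called by the pipeline.
    @staticmethod
    def _estimate_attention_memory_gb(
        batch_size: int,
        gb_at_batch_64: float = 41.57
    ) -> float:
        """Illustrative linear estimate: attention memory scales with batch size."""
        return gb_at_batch_64 * (batch_size / 64)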
    def run_forecast(
        self,
        run_date: str,
        borders: Optional[List[str]] = None,
        forecast_days: int = 7,
        context_hours: int = 1125,  # 1,125 hours = 46.9 days (~1.5 months, fits A100-80GB)
        num_samples: int = 20
    ) -> Dict:
        """
        Run a zero-shot forecast for the specified borders.

        Args:
            run_date: Forecast run date (YYYY-MM-DD format)
            borders: List of borders to forecast (None = all borders)
            forecast_days: Forecast horizon in days (7 or 14)
            context_hours: Historical context window in hours
            num_samples: Number of probabilistic samples

        Returns:
            Dictionary with forecast results and metadata
        """
        # Load model and dataset (cached)
        pipeline = self._load_model()
        df, all_borders = self._load_dataset()

        # Parse run date
        run_datetime = datetime.strptime(run_date, "%Y-%m-%d")
        run_datetime = run_datetime.replace(hour=23, minute=0)

        # Determine borders to forecast
        forecast_borders = borders if borders else all_borders
        prediction_hours = forecast_days * 24

        print(f"\nForecast configuration:")
        print(f"  Run date: {run_datetime}")
        print(f"  Borders: {len(forecast_borders)}")
        print(f"  Forecast horizon: {forecast_days} days ({prediction_hours} hours)")
        print(f"  Context window: {context_hours} hours")

        # Initialize dynamic forecast system
        forecaster = DynamicForecast(
            dataset=df,
            context_hours=context_hours,
            forecast_hours=prediction_hours
        )

        # Run forecasts for each border
        results = {
            'run_date': run_date,
            'forecast_days': forecast_days,
            'borders': {},
            'metadata': {
                'model': self.model_name,
                'device': self.device,
                'num_samples': num_samples,
                'context_hours': context_hours
            }
        }

        total_start = time.time()

        # PER-BORDER INFERENCE WITH PAST-ONLY COVARIATE MASKING
        # Using the predict_df() API with ALL 2,514 features (known-future + past-only masked)
        print(f"\n[PAST-ONLY MASKING] Running inference for {len(forecast_borders)} borders with 2,514 features...")
        print(f"  Known-future: weather, generation, load forecasts (615 features)")
        print(f"  Past-only masked: CNEC outages, volatility, historical flows (1,899 features)")

        for i, border in enumerate(forecast_borders, 1):
            # Clear the GPU cache BEFORE each border to prevent memory accumulation.
            # This releases tensors from the previous border (no-op on the first iteration).
            # Does NOT affect model weights (120M params stay loaded).
            # Does NOT affect forecast accuracy (each border is independent).
            if i > 1:  # Skip on the first border (clean GPU state)
                torch.cuda.empty_cache()
                gc.collect()  # Force the Python garbage collector to free tensors

            border_start = time.time()
            print(f"\n  [{i}/{len(forecast_borders)}] {border}...", flush=True)

            try:
                # Extract data WITH covariates
                context_data, future_data = forecaster.prepare_forecast_data(
                    run_date=run_datetime,
                    border=border
                )
                print(f"    Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
                print(f"    Using {len(future_data.columns)-2} features (known-future + past-only masked)", flush=True)

                # Run covariate-informed inference using the DataFrame API.
                # Note: predict_df() returns quantiles directly.
                # Request 9 quantiles to capture learned uncertainty and tail events.
                # Use torch.inference_mode() to disable gradient tracking (saves ~2-5 GB VRAM).
                with torch.inference_mode():
                    forecasts_df = pipeline.predict_df(
                        context_data,           # Historical data with ALL features
                        future_df=future_data,  # All 2,514 features (past-only masked)
                        prediction_length=prediction_hours,
                        id_column='border',
                        timestamp_column='timestamp',
                        target='target',
                        batch_size=32,  # Reduced from 64 (41.57 GB -> 20.79 GB attention tensor, fits a single GPU)
                        quantile_levels=[0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99]  # 9 quantiles for volatility
                    )

                # Extract all 9 quantiles from the predict_df() output.
                # predict_df() returns quantiles directly as string columns.
                if isinstance(forecasts_df, pd.DataFrame):
                    # Expected columns: '0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99'
                    quantile_cols = ['0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99']

                    # Extract all quantiles
                    quantiles = {}
                    for q in quantile_cols:
                        if q in forecasts_df.columns:
                            quantiles[q] = forecasts_df[q].values
                        else:
                            # Fallback if a quantile is missing
                            if '0.5' in forecasts_df.columns:
                                quantiles[q] = forecasts_df['0.5'].values  # Use median as fallback
                            elif 'predictions' in forecasts_df.columns:
                                quantiles[q] = forecasts_df['predictions'].values
                            else:
                                raise ValueError(
                                    f"Missing quantile {q} and no fallback available. "
                                    f"Columns: {forecasts_df.columns.tolist()}"
                                )

                    # Backward compatibility: still extract median, q10, q90
                    median = quantiles['0.5']
                    q10 = quantiles['0.1']
                    q90 = quantiles['0.9']
                else:
                    raise TypeError(f"Expected DataFrame from predict_df(), got {type(forecasts_df)}")

                # Round all quantiles to the nearest integer (capacity values are always whole MW)
                median = np.round(median).astype(int)
                q10 = np.round(q10).astype(int)
                q90 = np.round(q90).astype(int)
                # Round all other quantiles
                for q_key in quantiles:
                    quantiles[q_key] = np.round(quantiles[q_key]).astype(int)

                inference_time = time.time() - border_start

                # Store results (backward compatible + all quantiles)
                results['borders'][border] = {
                    'median': median.tolist(),
                    'q10': q10.tolist(),
                    'q90': q90.tolist(),
                    # Add all 9 quantiles for adaptive selection
                    'q01': quantiles['0.01'].tolist(),
                    'q05': quantiles['0.05'].tolist(),
                    'q25': quantiles['0.25'].tolist(),
                    'q75': quantiles['0.75'].tolist(),
                    'q95': quantiles['0.95'].tolist(),
                    'q99': quantiles['0.99'].tolist(),
                    'inference_time_s': inference_time,
                    'used_covariates': True,
                    'num_features': len(future_data.columns) - 2  # Exclude border and timestamp
                }

                print(f"    [OK] Complete in {inference_time:.1f}s "
                      f"({len(future_data.columns)-2} features with past-only masking)", flush=True)

            except Exception as e:
                error_msg = f"{type(e).__name__}: {str(e)}"
                traceback_str = traceback.format_exc()
                print(f"    [ERROR] {error_msg}", flush=True)
                print(f"Traceback:\n{traceback_str}", flush=True)
                results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}

        # Add summary metadata
        results['metadata']['total_time_s'] = time.time() - total_start
        results['metadata']['successful_borders'] = sum(
            1 for b in results['borders'].values() if 'error' not in b
        )

        print(f"\n{'='*60}")
        print(f"FORECAST COMPLETE")
        print(f"{'='*60}")
        print(f"Total time: {results['metadata']['total_time_s']:.1f}s")
        print(f"Successful: {results['metadata']['successful_borders']}/{len(forecast_borders)} borders")

        return results
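    # export_to_parquet() below exports an optional 'adaptive' series per
    # border ("learned uncertainty-based selection") that this pipeline does
    # not itself compute. The sketch below shows one hypothetical selection
    # rule, assuming a wide q10-q90 band signals high uncertainty that
    # warrants a more conservative quantile; it is illustrative only and
    # not called anywhere in this module.
    @staticmethod
    def _sketch_adaptive_selection(
        border_result: Dict,
        width_threshold_mw: float = 500.0
    ) -> List[int]:
        """Hypothetical rule: choose q25 where the q10-q90 band is wide, else the median."""
        q10 = np.array(border_result['q10'])
        q90 = np.array(border_result['q90'])
        median = np.array(border_result['median'])
        q25 = np.array(border_result['q25'])
        adaptive = np.where(q90 - q10 > width_threshold_mw, q25, median)
        return adaptive.astype(int).tolist()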
    def export_to_parquet(self, results: Dict, output_path: str):
        """
        Export forecast results to parquet format.

        Args:
            results: Forecast results from run_forecast()
            output_path: Path to save parquet file
        """
        # Create forecast timestamps
        run_datetime = datetime.strptime(results['run_date'], "%Y-%m-%d")
        forecast_start = run_datetime + timedelta(days=1)  # Next day at midnight, not +1 hour
        forecast_hours = results['forecast_days'] * 24
        timestamps = [
            forecast_start + timedelta(hours=h)
            for h in range(forecast_hours)
        ]

        # Build DataFrame
        data = {'timestamp': timestamps}
        successful_borders = []
        failed_borders = []

        for border, forecast_data in results['borders'].items():
            if 'error' not in forecast_data:
                data[f'{border}_median'] = forecast_data['median']
                data[f'{border}_q10'] = forecast_data['q10']
                data[f'{border}_q90'] = forecast_data['q90']
                # Add adaptive forecast if available (learned uncertainty-based selection)
                if 'adaptive' in forecast_data:
                    data[f'{border}_adaptive'] = forecast_data['adaptive']
                successful_borders.append(border)
            else:
                failed_borders.append((border, forecast_data['error']))

        # Log results
        print(f"[EXPORT] Forecast export summary:", flush=True)
        print(f"  Successful: {len(successful_borders)} borders", flush=True)
        print(f"  Failed: {len(failed_borders)} borders", flush=True)
        if failed_borders:
            print(f"[EXPORT] Errors:", flush=True)
            for border, error in failed_borders:
                print(f"  {border}: {error}", flush=True)

        df = pl.DataFrame(data)
        df.write_parquet(output_path)

        print(f"[EXPORT] Exported to: {output_path}", flush=True)
        print(f"[EXPORT] Shape: {df.shape}, Columns: {len(df.columns)}", flush=True)

        return output_path


# Convenience function for API usage
def run_inference(
    run_date: str,
    forecast_type: str = "smoke_test",
    borders: Optional[List[str]] = None,
    output_dir: str = "/tmp"
) -> str:
    """
    Run a forecast and return the path to the results file.

    Args:
        run_date: Forecast run date (YYYY-MM-DD)
        forecast_type: 'smoke_test' (7 days, 1 border) or 'full_14day' (14 days, all borders)
        borders: Specific borders to forecast (None = use forecast_type defaults)
        output_dir: Directory to save results

    Returns:
        Path to the forecast results parquet file
    """
    # Initialize pipeline
    pipeline = ChronosInferencePipeline()

    # Configure based on forecast type
    if forecast_type == "smoke_test":
        forecast_days = 7
        if borders is None:
            # Load the dataset just to get the first border
            _, all_borders = pipeline._load_dataset()
            borders = [all_borders[0]]
    else:  # full_14day
        forecast_days = 14
        # borders = None means all borders

    # Run forecast
    results = pipeline.run_forecast(
        run_date=run_date,
        borders=borders,
        forecast_days=forecast_days
    )

    # Write debug file
    debug_filename = f"debug_{run_date}_{forecast_type}.txt"
    debug_path = os.path.join(output_dir, debug_filename)
    with open(debug_path, 'w') as f:
        f.write(f"Results summary:\n")
        f.write(f"  Run date: {results['run_date']}\n")
        f.write(f"  Forecast days: {results['forecast_days']}\n")
        f.write(f"  Borders in results: {list(results['borders'].keys())}\n\n")
        for border, data in results['borders'].items():
            if 'error' in data:
                f.write(f"  {border}: ERROR - {data['error']}\n")
                if 'traceback' in data:
                    f.write(f"\nFull Traceback:\n{data['traceback']}\n")
            else:
                f.write(f"  {border}: OK\n")
                f.write(f"    median count: {len(data.get('median', []))}\n")
                f.write(f"    q10 count: {len(data.get('q10', []))}\n")
                f.write(f"    q90 count: {len(data.get('q90', []))}\n")
    print(f"Debug file written to: {debug_path}", flush=True)

    # Export to parquet
    output_filename = f"forecast_{run_date}_{forecast_type}.parquet"
    output_path = os.path.join(output_dir, output_filename)
    pipeline.export_to_parquet(results, output_path)
    # Check if the forecast has data; if not, return the debug file instead
    successful_count = sum(1 for data in results['borders'].values() if 'error' not in data)
    if successful_count == 0:
        print(f"[WARNING] No successful forecasts! Returning debug file instead.", flush=True)
        return debug_path

    return output_path
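

if __name__ == "__main__":
    # Minimal usage sketch: run a 7-day smoke-test forecast for the first
    # available border and print where the results landed. The run date
    # below is illustrative, not taken from the project.
    result_path = run_inference(
        run_date="2024-06-01",  # hypothetical run date
        forecast_type="smoke_test",
        output_dir="/tmp"
    )
    print(f"Forecast written to: {result_path}")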