#!/usr/bin/env python3
"""
Chronos-2 Inference Pipeline with Past-Only Covariate Masking

Standalone inference script for HuggingFace Space deployment.
Uses the predict_df() API with all 2,514 features, leveraging Chronos-2's
mask-based attention.

FORCE REBUILD: v1.6.0 - Extended context window (2,160 hours = 90 days) optimized for 96GB VRAM
"""
import gc
import os
import time
from typing import List, Dict, Optional
from datetime import datetime, timedelta

# CRITICAL: Set the PyTorch memory allocator config BEFORE importing torch.
# This prevents memory fragmentation that can cause OOM errors even when
# sufficient free memory is available.
# See: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import polars as pl
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset
from chronos import Chronos2Pipeline

from .dynamic_forecast import DynamicForecast
from .feature_availability import FeatureAvailability
class ChronosInferencePipeline:
    """
    Production inference pipeline for Chronos-2 zero-shot forecasting WITH PAST-ONLY MASKING.

    Uses the predict_df() API with all 2,514 features (known-future + past-only covariates).
    Past-only covariates (CNEC, volatility, historical flows) are masked in the future
    window, so the model learns cross-feature correlations from the historical context
    via its attention mechanism.

    Designed for deployment as an API endpoint on HuggingFace Spaces.
    """
    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device: str = "cuda",
        dtype: str = "bfloat16"
    ):
        """
        Initialize the inference pipeline.

        Args:
            model_name: HuggingFace model identifier (chronos-2 supports covariates)
            device: Device for inference ('cuda' or 'cpu')
            dtype: Data type for model weights (bfloat16 for memory efficiency)
        """
        self.model_name = model_name
        self.device = device
        self.dtype = dtype

        # Model loaded on first inference (lazy loading)
        self._pipeline = None
        self._dataset = None
        self._borders = None
    def _load_model(self):
        """Load the Chronos-2 model (cached after the first call)."""
        if self._pipeline is None:
            print(f"Loading {self.model_name}...")
            start_time = time.time()

            dtype_map = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
                "float32": torch.float32
            }
            self._pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map="auto",  # Auto-distribute across all available GPUs
                torch_dtype=dtype_map.get(self.dtype, torch.float32)
            )
            # Set the model to evaluation mode (disables dropout, etc.)
            self._pipeline.model.eval()

            print(f"Model loaded in {time.time() - start_time:.1f}s")
            print(f"  Device: {next(self._pipeline.model.parameters()).device}")

            # GPU detection and memory profiling diagnostics
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                total_vram = sum(
                    torch.cuda.get_device_properties(i).total_memory
                    for i in range(gpu_count)
                )
                print(f"  [GPU] Detected {gpu_count} GPU(s)")
                print(f"  [GPU] Total VRAM: {total_vram/1e9:.1f} GB")
                print(f"  [MEMORY] After model load:")
                print(f"    GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                print(f"    GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

        return self._pipeline
    def _load_dataset(self):
        """Load the dataset from HuggingFace (cached after the first call)."""
        if self._dataset is None:
            print("Loading dataset from HuggingFace...")
            start_time = time.time()

            hf_token = os.getenv("HF_TOKEN")
            dataset = load_dataset(
                "evgueni-p/fbmc-features-24month",
                split="train",
                token=hf_token
            )
            # Convert to Polars
            self._dataset = pl.from_arrow(dataset.data.table)

            # Extract the available borders
            target_cols = [col for col in self._dataset.columns if col.startswith('target_border_')]
            self._borders = [col.replace('target_border_', '') for col in target_cols]

            print(f"Dataset loaded in {time.time() - start_time:.1f}s")
            print(f"  Shape: {self._dataset.shape}")
            print(f"  Borders: {len(self._borders)}")

            # Memory profiling diagnostics
            if torch.cuda.is_available():
                print(f"  [MEMORY] After dataset load:")
                print(f"    GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                print(f"    GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

        return self._dataset, self._borders
    def run_forecast(
        self,
        run_date: str,
        borders: Optional[List[str]] = None,
        forecast_days: int = 7,
        context_hours: int = 1125,  # 1,125 hours = 46.9 days (~1.5 months, fits an A100-80GB)
        num_samples: int = 20
    ) -> Dict:
        """
        Run a zero-shot forecast for the specified borders.

        Args:
            run_date: Forecast run date (YYYY-MM-DD format)
            borders: List of borders to forecast (None = all borders)
            forecast_days: Forecast horizon in days (7 or 14)
            context_hours: Historical context window in hours
            num_samples: Number of probabilistic samples (recorded in metadata only;
                predict_df() returns quantiles directly, so this does not affect inference)

        Returns:
            Dictionary with forecast results and metadata
        """
        # Load the model and dataset (both cached)
        pipeline = self._load_model()
        df, all_borders = self._load_dataset()

        # Parse the run date; the forecast is issued at 23:00 on the run day
        run_datetime = datetime.strptime(run_date, "%Y-%m-%d")
        run_datetime = run_datetime.replace(hour=23, minute=0)

        # Determine the borders to forecast
        forecast_borders = borders if borders else all_borders
        prediction_hours = forecast_days * 24

        print(f"\nForecast configuration:")
        print(f"  Run date: {run_datetime}")
        print(f"  Borders: {len(forecast_borders)}")
        print(f"  Forecast horizon: {forecast_days} days ({prediction_hours} hours)")
        print(f"  Context window: {context_hours} hours")

        # Initialize the dynamic forecast system
        forecaster = DynamicForecast(
            dataset=df,
            context_hours=context_hours,
            forecast_hours=prediction_hours
        )

        # Run forecasts for each border
        results = {
            'run_date': run_date,
            'forecast_days': forecast_days,
            'borders': {},
            'metadata': {
                'model': self.model_name,
                'device': self.device,
                'num_samples': num_samples,
                'context_hours': context_hours
            }
        }
        total_start = time.time()

        # PER-BORDER INFERENCE WITH PAST-ONLY COVARIATE MASKING
        # Uses the predict_df() API with all 2,514 features (known-future + past-only masked)
        print(f"\n[PAST-ONLY MASKING] Running inference for {len(forecast_borders)} borders with 2,514 features...")
        print(f"  Known-future: weather, generation, load forecasts (615 features)")
        print(f"  Past-only masked: CNEC outages, volatility, historical flows (1,899 features)")
        for i, border in enumerate(forecast_borders, 1):
            # Clear the GPU cache BEFORE each border to prevent memory accumulation.
            # This releases tensors from the previous border. It does NOT affect the
            # model weights (the 120M parameters stay loaded) and does NOT affect
            # forecast accuracy (each border is independent).
            if i > 1:  # Skip on the first border (clean GPU state)
                torch.cuda.empty_cache()
                gc.collect()  # Force the Python garbage collector to free tensors

            border_start = time.time()
            print(f"\n  [{i}/{len(forecast_borders)}] {border}...", flush=True)

            try:
                # Extract data WITH covariates
                context_data, future_data = forecaster.prepare_forecast_data(
                    run_date=run_datetime,
                    border=border
                )
                print(f"    Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
                print(f"    Using {len(future_data.columns) - 2} features (known-future + past-only masked)", flush=True)

                # Run covariate-informed inference using the DataFrame API.
                # Note: predict_df() returns quantiles directly. Request 9 quantiles
                # to capture learned uncertainty and tail events. torch.inference_mode()
                # disables gradient tracking (saves ~2-5 GB of VRAM).
                with torch.inference_mode():
                    forecasts_df = pipeline.predict_df(
                        context_data,            # Historical data with ALL features
                        future_df=future_data,   # All 2,514 features (past-only masked)
                        prediction_length=prediction_hours,
                        id_column='border',
                        timestamp_column='timestamp',
                        target='target',
                        batch_size=32,  # Reduced from 64 (41.57 GB -> 20.79 GB attention tensor, fits a single GPU)
                        quantile_levels=[0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99]  # 9 quantiles for volatility
                    )
                # Extract all 9 quantiles from the predict_df() output.
                # predict_df() returns quantiles directly as string-named columns.
                if isinstance(forecasts_df, pd.DataFrame):
                    # Expected columns: '0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99'
                    quantile_cols = ['0.01', '0.05', '0.1', '0.25', '0.5', '0.75', '0.9', '0.95', '0.99']

                    quantiles = {}
                    for q in quantile_cols:
                        if q in forecasts_df.columns:
                            quantiles[q] = forecasts_df[q].values
                        # Fallbacks if a quantile column is missing
                        elif '0.5' in forecasts_df.columns:
                            quantiles[q] = forecasts_df['0.5'].values  # Use the median as fallback
                        elif 'predictions' in forecasts_df.columns:
                            quantiles[q] = forecasts_df['predictions'].values
                        else:
                            raise ValueError(
                                f"Missing quantile {q} and no fallback available. "
                                f"Columns: {forecasts_df.columns.tolist()}"
                            )
                else:
                    raise TypeError(f"Expected DataFrame from predict_df(), got {type(forecasts_df)}")

                # Round all quantiles to the nearest integer (capacity values are always whole MW)
                for q_key in quantiles:
                    quantiles[q_key] = np.round(quantiles[q_key]).astype(int)

                # Backward compatibility: still expose median, q10, q90 directly
                median = quantiles['0.5']
                q10 = quantiles['0.1']
                q90 = quantiles['0.9']

                inference_time = time.time() - border_start

                # Store results (backward compatible + all quantiles)
                results['borders'][border] = {
                    'median': median.tolist(),
                    'q10': q10.tolist(),
                    'q90': q90.tolist(),
                    # All 9 quantiles, for adaptive selection
                    'q01': quantiles['0.01'].tolist(),
                    'q05': quantiles['0.05'].tolist(),
                    'q25': quantiles['0.25'].tolist(),
                    'q75': quantiles['0.75'].tolist(),
                    'q95': quantiles['0.95'].tolist(),
                    'q99': quantiles['0.99'].tolist(),
                    'inference_time_s': inference_time,
                    'used_covariates': True,
                    'num_features': len(future_data.columns) - 2  # Exclude the border and timestamp columns
                }
                print(f"    [OK] Complete in {inference_time:.1f}s ({len(future_data.columns) - 2} features with past-only masking)", flush=True)
            except Exception as e:
                import traceback
                error_msg = f"{type(e).__name__}: {str(e)}"
                traceback_str = traceback.format_exc()
                print(f"    [ERROR] {error_msg}", flush=True)
                print(f"Traceback:\n{traceback_str}", flush=True)
                results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}

        # Add summary metadata
        results['metadata']['total_time_s'] = time.time() - total_start
        results['metadata']['successful_borders'] = sum(
            1 for b in results['borders'].values() if 'error' not in b
        )

        print(f"\n{'='*60}")
        print(f"FORECAST COMPLETE")
        print(f"{'='*60}")
        print(f"Total time: {results['metadata']['total_time_s']:.1f}s")
        print(f"Successful: {results['metadata']['successful_borders']}/{len(forecast_borders)} borders")

        return results
    def export_to_parquet(self, results: Dict, output_path: str):
        """
        Export forecast results to parquet format.

        Args:
            results: Forecast results from run_forecast()
            output_path: Path to save the parquet file
        """
        # Create the forecast timestamps
        run_datetime = datetime.strptime(results['run_date'], "%Y-%m-%d")
        forecast_start = run_datetime + timedelta(days=1)  # Next day at midnight, not +1 hour
        forecast_hours = results['forecast_days'] * 24
        timestamps = [
            forecast_start + timedelta(hours=h)
            for h in range(forecast_hours)
        ]

        # Build the DataFrame
        data = {'timestamp': timestamps}
        successful_borders = []
        failed_borders = []
        for border, forecast_data in results['borders'].items():
            if 'error' not in forecast_data:
                data[f'{border}_median'] = forecast_data['median']
                data[f'{border}_q10'] = forecast_data['q10']
                data[f'{border}_q90'] = forecast_data['q90']
                # Add the adaptive forecast if available (learned uncertainty-based selection)
                if 'adaptive' in forecast_data:
                    data[f'{border}_adaptive'] = forecast_data['adaptive']
                successful_borders.append(border)
            else:
                failed_borders.append((border, forecast_data['error']))

        # Log the results
        print(f"[EXPORT] Forecast export summary:", flush=True)
        print(f"  Successful: {len(successful_borders)} borders", flush=True)
        print(f"  Failed: {len(failed_borders)} borders", flush=True)
        if failed_borders:
            print(f"[EXPORT] Errors:", flush=True)
            for border, error in failed_borders:
                print(f"  {border}: {error}", flush=True)

        df = pl.DataFrame(data)
        df.write_parquet(output_path)
        print(f"[EXPORT] Exported to: {output_path}", flush=True)
        print(f"[EXPORT] Shape: {df.shape}, Columns: {len(df.columns)}", flush=True)

        return output_path
# Convenience function for API usage
def run_inference(
    run_date: str,
    forecast_type: str = "smoke_test",
    borders: Optional[List[str]] = None,
    output_dir: str = "/tmp"
) -> str:
    """
    Run a forecast and return the path to the results file.

    Args:
        run_date: Forecast run date (YYYY-MM-DD)
        forecast_type: 'smoke_test' (7 days, 1 border) or 'full_14day' (14 days, all borders)
        borders: Specific borders to forecast (None = use the forecast_type defaults)
        output_dir: Directory to save the results

    Returns:
        Path to the forecast results parquet file (or to the debug file if
        every border failed)
    """
    # Initialize the pipeline
    pipeline = ChronosInferencePipeline()

    # Configure based on the forecast type
    if forecast_type == "smoke_test":
        forecast_days = 7
        if borders is None:
            # Load the dataset just to get the first border
            _, all_borders = pipeline._load_dataset()
            borders = [all_borders[0]]
    else:  # full_14day
        forecast_days = 14
        # borders = None means all borders

    # Run the forecast
    results = pipeline.run_forecast(
        run_date=run_date,
        borders=borders,
        forecast_days=forecast_days
    )

    # Write the debug file
    debug_filename = f"debug_{run_date}_{forecast_type}.txt"
    debug_path = os.path.join(output_dir, debug_filename)
    with open(debug_path, 'w') as f:
        f.write(f"Results summary:\n")
        f.write(f"  Run date: {results['run_date']}\n")
        f.write(f"  Forecast days: {results['forecast_days']}\n")
        f.write(f"  Borders in results: {list(results['borders'].keys())}\n\n")
        for border, data in results['borders'].items():
            if 'error' in data:
                f.write(f"  {border}: ERROR - {data['error']}\n")
                if 'traceback' in data:
                    f.write(f"\nFull Traceback:\n{data['traceback']}\n")
            else:
                f.write(f"  {border}: OK\n")
                f.write(f"    median count: {len(data.get('median', []))}\n")
                f.write(f"    q10 count: {len(data.get('q10', []))}\n")
                f.write(f"    q90 count: {len(data.get('q90', []))}\n")
    print(f"Debug file written to: {debug_path}", flush=True)

    # Export to parquet
    output_filename = f"forecast_{run_date}_{forecast_type}.parquet"
    output_path = os.path.join(output_dir, output_filename)
    pipeline.export_to_parquet(results, output_path)

    # If no forecast succeeded, return the debug file instead
    successful_count = sum(1 for data in results['borders'].values() if 'error' not in data)
    if successful_count == 0:
        print(f"[WARNING] No successful forecasts! Returning debug file instead.", flush=True)
        return debug_path

    return output_path
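

# A minimal usage sketch (hypothetical, not part of the deployed Space): it assumes
# the module is importable as a package (relative imports above), a CUDA device is
# available, and HF_TOKEN is set for the dataset. It runs the single-border smoke
# test and prints where the parquet file landed. The date is an arbitrary example;
# any date covered by the 24-month dataset should work.
if __name__ == "__main__":
    result_path = run_inference(
        run_date="2024-06-01",       # example date (assumption)
        forecast_type="smoke_test",  # 7 days, first border only
        output_dir="/tmp"
    )
    print(f"Results written to: {result_path}")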