#!/usr/bin/env python3
"""
Chronos-2 Inference Pipeline with Covariate Support
Standalone inference script for HuggingFace Space deployment.
Uses the predict_df() API to enable multivariate forecasting with weather, generation, and CNEC outage covariates.
FORCE REBUILD: v1.3.0 - Context reduced to 128h for memory
"""

import os
import time
from typing import List, Dict, Optional
from datetime import datetime, timedelta

# CRITICAL: Set PyTorch memory allocator config BEFORE importing torch
# This prevents memory fragmentation issues that cause OOM even with sufficient free memory
# See: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import polars as pl
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset
from chronos import Chronos2Pipeline

from .dynamic_forecast import DynamicForecast
from .feature_availability import FeatureAvailability


class ChronosInferencePipeline:
    """
    Production inference pipeline for Chronos-2 zero-shot forecasting WITH COVARIATES.
    Uses the predict_df() API to leverage all 615 collected features (weather, generation, outages, etc.).
    Designed for deployment as API endpoint on HuggingFace Spaces.
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device: str = "cuda",
        dtype: str = "bfloat16"
    ):
        """
        Initialize inference pipeline.

        Args:
            model_name: HuggingFace model identifier (chronos-2 supports covariates)
            device: Device for inference ('cuda' or 'cpu')
            dtype: Data type for model weights (bfloat16 for memory efficiency)
        """
        self.model_name = model_name
        self.device = device
        self.dtype = dtype

        # Model loaded on first inference (lazy loading)
        self._pipeline = None
        self._dataset = None
        self._borders = None

    def _load_model(self):
        """Load Chronos-2 model (cached after first call)"""
        if self._pipeline is None:
            print(f"Loading {self.model_name}...")
            start_time = time.time()

            dtype_map = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
                "float32": torch.float32
            }
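            # NOTE: dtype strings not in dtype_map fall back to float32 via .get() below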

            self._pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map=self.device,
                torch_dtype=dtype_map.get(self.dtype, torch.float32)
            )

            # Set model to evaluation mode (disables dropout, etc.)
            self._pipeline.model.eval()

            print(f"Model loaded in {time.time() - start_time:.1f}s")
            print(f"  Device: {next(self._pipeline.model.parameters()).device}")

        return self._pipeline

    def _load_dataset(self):
        """Load dataset from HuggingFace (cached after first call)"""
        if self._dataset is None:
            print("Loading dataset from HuggingFace...")
            start_time = time.time()

            hf_token = os.getenv("HF_TOKEN")
            dataset = load_dataset(
                "evgueni-p/fbmc-features-24month",
                split="train",
                token=hf_token
            )

            # Convert the Arrow-backed HF dataset to Polars (no pandas round-trip)
            self._dataset = pl.from_arrow(dataset.data.table)

            # Extract available borders
            target_cols = [col for col in self._dataset.columns if col.startswith('target_border_')]
            self._borders = [col.replace('target_border_', '') for col in target_cols]

            print(f"Dataset loaded in {time.time() - start_time:.1f}s")
            print(f"  Shape: {self._dataset.shape}")
            print(f"  Borders: {len(self._borders)}")

        return self._dataset, self._borders

    def run_forecast(
        self,
        run_date: str,
        borders: Optional[List[str]] = None,
        forecast_days: int = 7,
        context_hours: int = 504,
        num_samples: int = 20
    ) -> Dict:
        """
        Run zero-shot forecast for specified borders.

        Args:
            run_date: Forecast run date (YYYY-MM-DD format)
            borders: List of borders to forecast (None = all borders)
            forecast_days: Forecast horizon in days (7 or 14)
            context_hours: Historical context window length in hours
            num_samples: Number of probabilistic samples (recorded in metadata
                only; predict_df() returns quantiles directly)

        Returns:
            Dictionary with per-border forecasts ('median', 'q10', 'q90',
            'inference_time_s') or error details, plus run metadata
        """
        # Load model and dataset (cached)
        pipeline = self._load_model()
        df, all_borders = self._load_dataset()

        # Parse the run date and pin it to 23:00: the context window ends at the last
        # hour of the run date, so the forecast horizon starts at midnight the next day
        run_datetime = datetime.strptime(run_date, "%Y-%m-%d")
        run_datetime = run_datetime.replace(hour=23, minute=0)

        # Determine borders to forecast
        forecast_borders = borders if borders else all_borders
        prediction_hours = forecast_days * 24

        print(f"\nForecast configuration:")
        print(f"  Run date: {run_datetime}")
        print(f"  Borders: {len(forecast_borders)}")
        print(f"  Forecast horizon: {forecast_days} days ({prediction_hours} hours)")
        print(f"  Context window: {context_hours} hours")

        # Initialize dynamic forecast system
        forecaster = DynamicForecast(
            dataset=df,
            context_hours=context_hours,
            forecast_hours=prediction_hours
        )

        # Run forecasts for each border
        results = {
            'run_date': run_date,
            'forecast_days': forecast_days,
            'borders': {},
            'metadata': {
                'model': self.model_name,
                'device': self.device,
                'num_samples': num_samples,
                'context_hours': context_hours
            }
        }

        total_start = time.time()

        # PER-BORDER INFERENCE WITH COVARIATES
        # Using predict_df() API to leverage all 615 features (weather, generation, CNEC outages, etc.)
        print(f"\n[COVARIATE FORECAST] Running inference for {len(forecast_borders)} borders with 615 features...")
        print(f"  Features: weather per zone, generation per zone, CNEC outages, LTA, load forecasts")

        for i, border in enumerate(forecast_borders, 1):
            # Clear the GPU cache BEFORE each border to prevent memory accumulation.
            # This releases tensors left over from the previous border (no-op on the
            # first iteration). It does NOT unload the model weights (the 710M
            # parameters stay resident) and does NOT affect forecast accuracy,
            # since each border is forecast independently.
            if i > 1:  # Skip on the first border (GPU state is already clean)
                torch.cuda.empty_cache()

            border_start = time.time()
            print(f"\n  [{i}/{len(forecast_borders)}] {border}...", flush=True)

            try:
                # Extract data WITH covariates
                context_data, future_data = forecaster.prepare_forecast_data(
                    run_date=run_datetime,
                    border=border
                )

                print(f"    Context shape: {context_data.shape}, Future shape: {future_data.shape}", flush=True)
                print(f"    Using {len(future_data.columns)-2} future covariates for multivariate forecast", flush=True)

                # Run covariate-informed inference using DataFrame API
                # Note: predict_df() returns quantiles directly (0.1, 0.5, 0.9 by default)
                # Use torch.inference_mode() to disable gradient tracking (saves ~2-5 GB VRAM)
                # Memory optimizations: batch_size=32 (from 256), 3 quantiles (from 9)
                with torch.inference_mode():
                    forecasts_df = pipeline.predict_df(
                        context_data,  # Historical data with ALL features
                        future_df=future_data,  # Future covariates (615 features)
                        prediction_length=prediction_hours,
                        id_column='border',
                        timestamp_column='timestamp',
                        target='target',
                        batch_size=32,  # Reduce from default 256 to save GPU memory
                        quantile_levels=[0.1, 0.5, 0.9]  # Only compute needed quantiles (not all 9)
                    )

                # Extract quantiles from predict_df() output
                # predict_df() returns quantiles directly as string columns: "0.1", "0.5", "0.9"
                if isinstance(forecasts_df, pd.DataFrame):
                    # Chronos-2 predict_df() returns columns: 'predictions', '0.1', '0.5', '0.9'
                    if {'0.1', '0.5', '0.9'}.issubset(forecasts_df.columns):
                        median = forecasts_df['0.5'].values
                        q10 = forecasts_df['0.1'].values
                        q90 = forecasts_df['0.9'].values
                    elif 'predictions' in forecasts_df.columns:
                        # Fallback: use predictions as median (no uncertainty bounds)
                        median = forecasts_df['predictions'].values
                        q10 = median.copy()
                        q90 = median.copy()
                    else:
                        raise ValueError(f"Unexpected predict_df output format. Columns: {forecasts_df.columns.tolist()}")
                else:
                    raise TypeError(f"Expected DataFrame from predict_df(), got {type(forecasts_df)}")

                inference_time = time.time() - border_start

                # Store results
                results['borders'][border] = {
                    'median': median.tolist(),
                    'q10': q10.tolist(),
                    'q90': q90.tolist(),
                    'inference_time_s': inference_time,
                    'used_covariates': True,
                    'num_features': len(future_data.columns) - 2  # Exclude border and timestamp
                }

                print(f"    [OK] Complete in {inference_time:.1f}s (WITH {len(future_data.columns)-2} covariates)", flush=True)

            except Exception as e:
                import traceback
                error_msg = f"{type(e).__name__}: {str(e)}"
                traceback_str = traceback.format_exc()
                print(f"    [ERROR] {error_msg}", flush=True)
                print(f"Traceback:\n{traceback_str}", flush=True)
                results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}

        # Add summary metadata
        results['metadata']['total_time_s'] = time.time() - total_start
        results['metadata']['successful_borders'] = sum(
            1 for b in results['borders'].values() if 'error' not in b
        )

        print(f"\n{'='*60}")
        print(f"FORECAST COMPLETE")
        print(f"{'='*60}")
        print(f"Total time: {results['metadata']['total_time_s']:.1f}s")
        print(f"Successful: {results['metadata']['successful_borders']}/{len(forecast_borders)} borders")

        return results

    def export_to_parquet(self, results: Dict, output_path: str):
        """
        Export forecast results to parquet format.

        Args:
            results: Forecast results from run_forecast()
            output_path: Path to save parquet file
        """
        # Create forecast timestamps
        run_datetime = datetime.strptime(results['run_date'], "%Y-%m-%d")
        forecast_start = run_datetime + timedelta(days=1)  # Next day at midnight, not +1 hour
        forecast_hours = results['forecast_days'] * 24

        timestamps = [
            forecast_start + timedelta(hours=h)
            for h in range(forecast_hours)
        ]

        # Build DataFrame
        data = {'timestamp': timestamps}
        
        successful_borders = []
        failed_borders = []

        for border, forecast_data in results['borders'].items():
            if 'error' not in forecast_data:
                data[f'{border}_median'] = forecast_data['median']
                data[f'{border}_q10'] = forecast_data['q10']
                data[f'{border}_q90'] = forecast_data['q90']
                successful_borders.append(border)
            else:
                failed_borders.append((border, forecast_data['error']))

        # Log results
        print("[EXPORT] Forecast export summary:", flush=True)
        print(f"  Successful: {len(successful_borders)} borders", flush=True)
        print(f"  Failed: {len(failed_borders)} borders", flush=True)
        if failed_borders:
            print("[EXPORT] Errors:", flush=True)
            for border, error in failed_borders:
                print(f"  {border}: {error}", flush=True)
        
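        # Build one wide parquet: a shared 'timestamp' column plus
        # {border}_median / {border}_q10 / {border}_q90 per successful border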
        df = pl.DataFrame(data)
        df.write_parquet(output_path)

        print(f"[EXPORT] Exported to: {output_path}", flush=True)
        print(f"[EXPORT] Shape: {df.shape}, Columns: {len(df.columns)}", flush=True)

        return output_path


# Convenience function for API usage
def run_inference(
    run_date: str,
    forecast_type: str = "smoke_test",
    borders: Optional[List[str]] = None,
    output_dir: str = "/tmp"
) -> str:
    """
    Run forecast and return path to results file.

    Args:
        run_date: Forecast run date (YYYY-MM-DD)
        forecast_type: 'smoke_test' (7 days, 1 border) or 'full_14day' (14 days, all borders)
        borders: Specific borders to forecast (None = use forecast_type defaults)
        output_dir: Directory to save results

    Returns:
        Path to forecast results parquet file
    """
    # Initialize pipeline
    pipeline = ChronosInferencePipeline()

    # Configure based on forecast type
    if forecast_type == "smoke_test":
        forecast_days = 7
        if borders is None:
            # Load just to get first border
            _, all_borders = pipeline._load_dataset()
            borders = [all_borders[0]]
    else:  # 'full_14day' (any unrecognized value also falls through to the full run)
        forecast_days = 14
        # borders = None means all borders

    # Run forecast
    results = pipeline.run_forecast(
        run_date=run_date,
        borders=borders,
        forecast_days=forecast_days
    )

    # Write a plain-text debug summary (returned as a fallback artifact if all borders fail)
    debug_filename = f"debug_{run_date}_{forecast_type}.txt"
    debug_path = os.path.join(output_dir, debug_filename)
    with open(debug_path, 'w') as f:
        f.write(f"Results summary:\n")
        f.write(f"  Run date: {results['run_date']}\n")
        f.write(f"  Forecast days: {results['forecast_days']}\n")
        f.write(f"  Borders in results: {list(results['borders'].keys())}\n\n")
        for border, data in results['borders'].items():
            if 'error' in data:
                f.write(f"  {border}: ERROR - {data['error']}\n")
                if 'traceback' in data:
                    f.write(f"\nFull Traceback:\n{data['traceback']}\n")
            else:
                f.write(f"  {border}: OK\n")
                f.write(f"    median count: {len(data.get('median', []))}\n")
                f.write(f"    q10 count: {len(data.get('q10', []))}\n")
                f.write(f"    q90 count: {len(data.get('q90', []))}\n")
    print(f"Debug file written to: {debug_path}", flush=True)
    
    # Export to parquet
    output_filename = f"forecast_{run_date}_{forecast_type}.parquet"
    output_path = os.path.join(output_dir, output_filename)
    pipeline.export_to_parquet(results, output_path)
    
    # If no forecasts succeeded, return the debug file instead of an empty parquet
    successful_count = sum(1 for data in results['borders'].values() if 'error' not in data)
    if successful_count == 0:
        print("[WARNING] No successful forecasts! Returning debug file instead.", flush=True)
        return debug_path
    
    return output_path
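

# Example local invocation: a minimal sketch, not part of the deployed Space API.
# Assumes HF_TOKEN is set in the environment, a CUDA device is available, and the
# run_date below (a hypothetical example) falls inside the dataset's coverage window.
if __name__ == "__main__":
    result_path = run_inference(
        run_date="2024-06-01",       # hypothetical date; pick one the dataset covers
        forecast_type="smoke_test",  # 7-day horizon, first border only
    )
    print(f"Forecast written to: {result_path}")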