fbmc-chronos2 / src /forecasting /dynamic_forecast.py
Evgueni Poloukarov
fix: revert context window to 512h and upgrade GPU to 1xL4 (24GB)
42acd7e
raw
history blame
10.2 kB
#!/usr/bin/env python3
"""
Dynamic Forecast Module
Time-aware data extraction for forecasting with run-date awareness.
Purpose: Prevent data leakage by extracting data AS IT WAS KNOWN at run time.
Key Concepts:
- run_date: When the forecast is made (e.g., "2025-09-30 23:00")
- forecast_horizon: Always 14 days (D+1 to D+14, fixed at 336 hours)
- context_window: Historical data before run_date (typically 512 hours)
- future_covariates: Features available for forecasting (603 full + 12 partial)
"""
from typing import Dict, Tuple, Optional
import pandas as pd
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.forecasting.feature_availability import FeatureAvailability
class DynamicForecast:
    """
    Handles time-aware data extraction for forecasting.

    Ensures no data leakage by only using data available at run_date:
    context rows are strictly before run_date, and the future window only
    carries covariates categorized as available over the forecast horizon.
    """

    def __init__(
        self,
        dataset: pl.DataFrame,
        context_hours: int = 512,
        forecast_hours: int = 336  # Fixed at 14 days
    ):
        """
        Initialize dynamic forecast handler.

        Args:
            dataset: Polars DataFrame with all features. Must contain a
                'timestamp' column and a 'target_border_<border>' column for
                every border passed to prepare_forecast_data().
            context_hours: Hours of historical context (default 512)
            forecast_hours: Forecast horizon in hours (default 336 = 14 days)
        """
        self.dataset = dataset
        self.context_hours = context_hours
        self.forecast_hours = forecast_hours

        # Categorize features once on initialization; every extraction step
        # below reuses this mapping of category name -> list of column names.
        self.categories = FeatureAvailability.categorize_features(dataset.columns)

        # Validate categorization. Problems are only reported, never raised,
        # so construction always succeeds.
        is_valid, warnings = FeatureAvailability.validate_categorization(
            self.categories, verbose=False
        )
        if not is_valid:
            print("[!] WARNING: Feature categorization issues detected")
            for w in warnings:
                print(f" - {w}")

    def _future_feature_columns(self) -> list:
        """Return the columns allowed in the future window (full-horizon D+14 + partial D+1)."""
        return (
            self.categories['full_horizon_d14'] +
            self.categories['partial_d1']
        )

    def prepare_forecast_data(
        self,
        run_date: datetime,
        border: str
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Prepare context and future data for a single border forecast.

        Args:
            run_date: When the forecast is made (all data before this is historical)
            border: Border to forecast (e.g., "AT_CZ")

        Returns:
            Tuple of (context_data, future_data):
            - context_data: Historical features + target (pandas DataFrame)
            - future_data: Future covariates only, with partial-availability
              masking applied (pandas DataFrame)
        """
        # Step 1: Extract historical context
        context_data = self._extract_context(run_date, border)
        # Step 2: Extract future covariates
        future_data = self._extract_future_covariates(run_date, border)
        # Step 3: Apply availability masking
        future_data = self._apply_masking(future_data, run_date)
        return context_data, future_data

    def _extract_context(
        self,
        run_date: datetime,
        border: str
    ) -> pd.DataFrame:
        """
        Extract historical context data.

        Context covers the half-open window
        [run_date - context_hours, run_date) and includes:
        - All features (full + partial + historical)
        - Target values for the requested border

        Args:
            run_date: Cutoff timestamp (excluded from the context)
            border: Border identifier

        Returns:
            Pandas DataFrame with columns: timestamp, border, target, all_features
        """
        # Calculate context window
        context_start = run_date - timedelta(hours=self.context_hours)

        # Filter data; the strict '<' keeps run_date itself out of the context.
        context_df = self.dataset.filter(
            (pl.col('timestamp') >= context_start) &
            (pl.col('timestamp') < run_date)
        )

        # Select target column for this border
        target_col = f'target_border_{border}'

        # All features (we'll use all for context, Chronos-2 handles it)
        all_features = (
            self.categories['full_horizon_d14'] +
            self.categories['partial_d1'] +
            self.categories['historical']
        )

        # Build context DataFrame
        context_cols = ['timestamp', target_col] + all_features
        context_data = context_df.select(context_cols).to_pandas()

        # Add border identifier and rename target
        context_data['border'] = border
        context_data = context_data.rename(columns={target_col: 'target'})

        # Reorder: timestamp, border, target, features
        context_data = context_data[['timestamp', 'border', 'target'] + all_features]
        return context_data

    def _extract_future_covariates(
        self,
        run_date: datetime,
        border: str
    ) -> pd.DataFrame:
        """
        Extract future covariate data for the forecast horizon.

        The window is [run_date, run_date + (forecast_hours - 1)h], inclusive.
        With hourly rows this is forecast_hours rows (assumes no gaps in the
        dataset; TODO confirm upstream).

        Future covariates include:
        - Full-horizon D+14 features (always available)
        - Partial D+1 features (load forecasts, masked later beyond D+1)

        Args:
            run_date: Forecast run timestamp
            border: Border identifier

        Returns:
            Pandas DataFrame with columns: timestamp, border, future_features
        """
        # Calculate future window.
        # Context rows are strictly < run_date, so the last context row is
        # run_date - 1h and this window starts on the hour right after it.
        # NOTE(review): an earlier comment said Chronos-2 predict_df() expects
        # future_df to start at the LAST context timestamp (Chronos dataset.py:549).
        # That does not match the strict context filter. Confirm against the
        # Chronos-2 assertion before changing either window.
        forecast_start = run_date
        forecast_end = forecast_start + timedelta(hours=self.forecast_hours - 1)

        # Filter data (both ends inclusive)
        future_df = self.dataset.filter(
            (pl.col('timestamp') >= forecast_start) &
            (pl.col('timestamp') <= forecast_end)
        )

        # Select only future covariate features (full-horizon + partial D+1)
        future_features = self._future_feature_columns()

        # Build future DataFrame
        future_cols = ['timestamp'] + future_features
        future_data = future_df.select(future_cols).to_pandas()

        # Add border identifier
        future_data['border'] = border

        # Reorder: timestamp, border, features
        future_data = future_data[['timestamp', 'border'] + future_features]
        return future_data

    def _apply_masking(
        self,
        future_data: pd.DataFrame,
        run_date: datetime
    ) -> pd.DataFrame:
        """
        Apply availability masking for partial features.

        Masking:
        - Partial D+1 features (load forecasts): set to NaN for every row
          whose timestamp is later than run_date + 24h.
        - LTA columns ('lta_*' within the full-horizon set): held constant
          at the value from the first future row.

        The input DataFrame is modified in place and also returned.

        Args:
            future_data: DataFrame with future covariates
            run_date: Forecast run timestamp

        Returns:
            DataFrame with masking applied
        """
        # Calculate D+1 cutoff (24 hours after run_date)
        d1_cutoff = run_date + timedelta(hours=24)

        # Mask load forecasts for D+2 onwards. The row mask does not depend on
        # the column, so compute it once and blank all partial columns together.
        partial_cols = self.categories['partial_d1']
        if partial_cols:
            beyond_d1 = future_data['timestamp'] > d1_cutoff
            future_data.loc[beyond_d1, partial_cols] = np.nan  # Chronos-2 handles NaN

        # LTA values should already be forward-filled during feature
        # engineering. Here they are pinned to the first future row so they
        # are constant across the forecast horizon.
        lta_cols = [c for c in self.categories['full_horizon_d14']
                    if c.startswith('lta_')]
        if len(lta_cols) > 0 and len(future_data) > 0:
            first_values = future_data[lta_cols].iloc[0]
            for col in lta_cols:
                future_data[col] = first_values[col]

        return future_data

    def validate_no_leakage(
        self,
        context_data: pd.DataFrame,
        future_data: pd.DataFrame,
        run_date: datetime
    ) -> Tuple[bool, list]:
        """
        Validate that no data leakage exists.

        Checks:
        1. All context timestamps < run_date
        2. All future timestamps >= run_date (the first hour of the window
           built by _extract_future_covariates)
        3. No overlap between context and future
        4. Future data only contains future covariates

        Args:
            context_data: Historical context
            future_data: Future covariates
            run_date: Forecast run timestamp

        Returns:
            Tuple of (is_valid, errors), where errors is a list of
            human-readable messages (empty when valid)
        """
        errors = []
        context_max = context_data['timestamp'].max()
        future_min = future_data['timestamp'].min()

        # Check 1: Context timestamps
        if context_max >= run_date:
            errors.append(
                f"Context data leaks into future: max timestamp "
                f"{context_max} >= run_date {run_date}"
            )

        # Check 2: Future timestamps. The future window starts at run_date
        # itself, so requiring run_date + 1h here flagged every correctly
        # extracted window as leaking.
        forecast_start = run_date
        if future_min < forecast_start:
            errors.append(
                f"Future data includes historical: min timestamp "
                f"{future_min} < forecast_start {forecast_start}"
            )

        # Check 3: No overlap
        if context_max >= future_min:
            errors.append("Overlap detected between context and future data")

        # Check 4: Future columns
        future_features = set(self._future_feature_columns())
        future_cols = set(future_data.columns) - {'timestamp', 'border'}
        if not future_cols.issubset(future_features):
            extra_cols = future_cols - future_features
            errors.append(
                f"Future data contains non-future features: {extra_cols}"
            )

        is_valid = len(errors) == 0
        return is_valid, errors

    def get_feature_summary(self) -> Dict[str, int]:
        """
        Get summary of feature categorization.

        Returns:
            Dictionary with feature counts per known category, plus 'total'
            summed over every category returned by the categorizer
        """
        return {
            'full_horizon_d14': len(self.categories['full_horizon_d14']),
            'partial_d1': len(self.categories['partial_d1']),
            'historical': len(self.categories['historical']),
            'total': sum(len(v) for v in self.categories.values())
        }