fbmc-chronos2 / smoke_test.py
Evgueni Poloukarov
feat: complete Day 3 zero-shot inference pipeline
74bde7a
raw
history blame
6.14 kB
#!/usr/bin/env python3
"""
Smoke Test for Chronos 2 Zero-Shot Inference
Tests: 1 border × 7 days (168 hours)
"""
import time
import pandas as pd
import numpy as np
import polars as pl
from datetime import datetime, timedelta
from chronos import Chronos2Pipeline
import torch
# Opening banner: a title framed by two 60-character rules.
banner_rule = "=" * 60
for banner_line in (banner_rule, "CHRONOS 2 ZERO-SHOT INFERENCE - SMOKE TEST", banner_rule):
    print(banner_line)
# Step 1: Load dataset
print("\n[1/6] Loading dataset from HuggingFace...")
# start_time also anchors the total-runtime report at the end of the script.
start_time = time.time()

from datasets import load_dataset
import os

# The dataset is private, so an access token is needed. It is read from the
# HF_TOKEN environment variable instead of being written into the script, so
# no credential is stored in the source file. When the variable is unset,
# token=None is passed and the `datasets` library is left to use whatever
# credentials are cached locally (e.g. from `huggingface-cli login`).
hf_token = os.environ.get("HF_TOKEN") or None
if hf_token is None:
    print("[!] HF_TOKEN not set; relying on cached HuggingFace credentials")

dataset = load_dataset(
    "evgueni-p/fbmc-features-24month",
    split="train",
    token=hf_token
)
# Work in polars from here on; the dataset is converted via pandas.
df = pl.from_pandas(dataset.to_pandas())

# Ensure timestamp is datetime: parse it when it arrived as text, cast it when
# it is any other non-datetime dtype, and leave it alone otherwise.
if df['timestamp'].dtype == pl.String:
    df = df.with_columns(pl.col('timestamp').str.to_datetime())
elif df['timestamp'].dtype != pl.Datetime:
    df = df.with_columns(pl.col('timestamp').cast(pl.Datetime))

print(f"[OK] Loaded {len(df)} rows, {len(df.columns)} columns")
print(f" Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")
print(f" Load time: {time.time() - start_time:.1f}s")
# Step 2: Identify target borders
print("\n[2/6] Identifying target borders...")
# Every column named 'target_border_<name>' is a forecast target. Only the
# leading prefix is stripped, so a border name that happens to contain the
# prefix text again is kept intact.
target_prefix = 'target_border_'
target_cols = [col for col in df.columns if col.startswith(target_prefix)]
borders = [col[len(target_prefix):] for col in target_cols]
print(f"[OK] Found {len(borders)} borders")

if not borders:
    # Stop with a readable message instead of an IndexError on borders[0].
    raise SystemExit("[ERROR] No 'target_border_*' columns found in dataset")

# Select first border for test
test_border = borders[0]
print(f"[*] Test border: {test_border}")
# Step 3: Prepare test data
print("\n[3/6] Preparing test data...")
# The latest timestamp in the dataset is used as the forecast origin. The row
# at that timestamp is excluded from the context below and becomes the first
# hour of the forecast horizon.
forecast_date = df['timestamp'].max()
context_hours = 512
prediction_hours = 168  # 7 days

# Context rows: timestamps in [context_start, forecast_date).
context_start = forecast_date - timedelta(hours=context_hours)
context_df = df.filter(
    (pl.col('timestamp') >= context_start) &
    (pl.col('timestamp') < forecast_date)
)
print(f"[OK] Context: {len(context_df)} hours ({context_start} to {forecast_date})")

# Prepare context DataFrame for Chronos: one constant id column ('border') and
# the selected border's target column renamed to 'target'.
target_col = f'target_border_{test_border}'
context_data = context_df.select([
    'timestamp',
    pl.lit(test_border).alias('border'),
    pl.col(target_col).alias('target')
]).to_pandas()

# Simple future covariates (just timestamp and border for smoke test).
# The hourly step is given as a timedelta rather than the 'H' alias, which
# pandas has deprecated (FutureWarning since pandas 2.2).
future_timestamps = pd.date_range(
    start=forecast_date,
    periods=prediction_hours,
    freq=timedelta(hours=1)
)
future_data = pd.DataFrame({
    'timestamp': future_timestamps,
    'border': [test_border] * prediction_hours,
    'target': [np.nan] * prediction_hours  # NaN for future values to predict
})
print(f"[OK] Future: {len(future_data)} hours")
print(f" Context shape: {context_data.shape}")
print(f" Future shape: {future_data.shape}")
# Step 4: Load model
# Fetch the pretrained 'amazon/chronos-2' pipeline onto CUDA in float32 and
# report how long loading took and where the weights ended up.
print("\n[4/6] Loading Chronos 2 model on GPU...")
load_began = time.time()
pipeline = Chronos2Pipeline.from_pretrained(
    'amazon/chronos-2', device_map='cuda', dtype=torch.float32
)
model_time = time.time() - load_began
print(f"[OK] Model loaded in {model_time:.1f}s")
# The device of the first model parameter is printed as the model's device.
first_parameter = next(pipeline.model.parameters())
print(f" Device: {first_parameter.device}")
# Step 5: Run inference
import sys  # local import, like this script's other mid-file imports

print("\n[5/6] Running zero-shot inference...")
print(f" Border: {test_border}")
print(f" Prediction: {prediction_hours} hours (7 days)")
# No sample count is reported: none is passed to predict_df() below.
inference_start = time.time()
try:
    # Only the observed history is passed as `df`. predict_df() forecasts the
    # `prediction_length` steps that follow the last timestamp in `df`, so
    # appending NaN placeholder rows for the future would shift the forecast
    # window by another `prediction_hours` and end the context with missing
    # values. NOTE(review): confirm against the installed chronos version.
    forecasts = pipeline.predict_df(
        df=context_data,
        prediction_length=prediction_hours,
        id_column='border',
        timestamp_column='timestamp',
        target='target'
    )
except Exception as e:
    # Only the model call is guarded, so this message is never printed for a
    # failure in the validation or summary code further down.
    print(f"\n[ERROR] Inference failed: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

inference_time = time.time() - inference_start
print(f"[OK] Inference complete in {inference_time:.1f}s")
print(f" Forecast shape: {forecasts.shape}")

# Step 6: Validate results
print("\n[6/6] Validating results...")

# Hard failure: any NaN anywhere in the returned frame.
nan_count = forecasts.isna().sum().sum()
print(f" NaN values: {nan_count}")
if nan_count > 0:
    print(" [ERROR] Forecast contains NaN values")
    sys.exit(1)

# Locate the point-forecast column. 'predictions' is tried first, then 'mean'
# and the median quantile '0.5' as fallbacks.
# NOTE(review): 'predictions' is the name Chronos-2's predict_df() is expected
# to use -- confirm against the installed chronos version.
point_col = next(
    (col for col in ('predictions', 'mean', '0.5') if col in forecasts.columns),
    None,
)
if point_col is None:
    # Hard failure: without a point forecast nothing can be validated, so the
    # test must not report success.
    print(f" [ERROR] No point-forecast column found; columns: {list(forecasts.columns)}")
    sys.exit(1)

mean_forecast = forecasts[point_col]
print(f" Forecast statistics:")
print(f" Mean: {mean_forecast.mean():.2f} MW")
print(f" Min: {mean_forecast.min():.2f} MW")
print(f" Max: {mean_forecast.max():.2f} MW")
print(f" Std: {mean_forecast.std():.2f} MW")

# Sanity checks: these only warn, but they are reflected in the final banner.
has_warnings = False
if len(forecasts) != prediction_hours:
    print(f" [!] WARNING: Expected {prediction_hours} forecast rows, got {len(forecasts)}")
    has_warnings = True
if mean_forecast.min() < 0:
    print(" [!] WARNING: Negative forecasts detected")
    has_warnings = True
if mean_forecast.max() >= 20000:
    print(" [!] WARNING: Unreasonably high forecasts")
    has_warnings = True
if not has_warnings:
    print(" [OK] Validation passed!")

# Performance summary
print("\n" + "="*60)
print("SMOKE TEST SUMMARY")
print("="*60)
print(f"Border tested: {test_border}")
print(f"Forecast length: {prediction_hours} hours (7 days)")
print(f"Inference time: {inference_time:.1f}s")
print(f"Speed: {prediction_hours / inference_time:.1f} hours/second")

# Estimate full run time by scaling the measured per-hour cost linearly to
# every border and a 14-day horizon.
total_borders = len(borders)
full_forecast_hours = 336  # 14 days
estimated_time = (inference_time / prediction_hours) * full_forecast_hours * total_borders
print(f"\nEstimated time for full run:")
print(f" {total_borders} borders × {full_forecast_hours} hours")
print(f" = {estimated_time / 60:.1f} minutes ({estimated_time / 3600:.1f} hours)")

# Target check
if inference_time < 300:  # 5 minutes
    print(f"\n[OK] Performance target met! (<5 min for 7-day forecast)")
else:
    print(f"\n[!] Performance slower than target (expected <5 min)")

print("="*60)
if has_warnings:
    print("[!] SMOKE TEST COMPLETED WITH WARNINGS")
else:
    print("[OK] SMOKE TEST PASSED!")
print("="*60)
# Total time: wall-clock seconds since start_time, also shown in minutes.
total_time = time.time() - start_time
total_minutes = total_time / 60
print(f"\nTotal test time: {total_time:.1f}s ({total_minutes:.1f} min)")