#!/usr/bin/env python3
"""
Smoke Test for Chronos 2 Zero-Shot Inference
Tests: 1 border × 7 days (168 hours)
"""
import time
import pandas as pd
import numpy as np
import polars as pl
from datetime import datetime, timedelta
from chronos import Chronos2Pipeline
import torch

print("="*60)
print("CHRONOS 2 ZERO-SHOT INFERENCE - SMOKE TEST")
print("="*60)
# Step 1: Load dataset
print("\n[1/6] Loading dataset from HuggingFace...")
start_time = time.time()

from datasets import load_dataset
import os

# HF token for private dataset access (read from the HF_TOKEN environment variable)
hf_token = os.environ.get("HF_TOKEN")

dataset = load_dataset(
    "evgueni-p/fbmc-features-24month",
    split="train",
    token=hf_token
)
df = pl.from_pandas(dataset.to_pandas())

# Ensure timestamp is datetime (convert only if needed)
if df['timestamp'].dtype == pl.String:
    df = df.with_columns(pl.col('timestamp').str.to_datetime())
elif df['timestamp'].dtype != pl.Datetime:
    df = df.with_columns(pl.col('timestamp').cast(pl.Datetime))

print(f"[OK] Loaded {len(df)} rows, {len(df.columns)} columns")
print(f"    Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")
print(f"    Load time: {time.time() - start_time:.1f}s")
# Step 2: Identify target borders
print("\n[2/6] Identifying target borders...")

target_cols = [col for col in df.columns if col.startswith('target_border_')]
borders = [col.replace('target_border_', '') for col in target_cols]
print(f"[OK] Found {len(borders)} borders")

# Select the first border for the smoke test
test_border = borders[0]
print(f"[*] Test border: {test_border}")
# Step 3: Prepare test data
print("\n[3/6] Preparing test data...")

# Use the last available date as the forecast date
forecast_date = df['timestamp'].max()
context_hours = 512
prediction_hours = 168  # 7 days

# Get context data
context_start = forecast_date - timedelta(hours=context_hours)
context_df = df.filter(
    (pl.col('timestamp') >= context_start) &
    (pl.col('timestamp') < forecast_date)
)
print(f"[OK] Context: {len(context_df)} hours ({context_start} to {forecast_date})")

# Prepare context DataFrame for Chronos
target_col = f'target_border_{test_border}'
context_data = context_df.select([
    'timestamp',
    pl.lit(test_border).alias('border'),
    pl.col(target_col).alias('target')
]).to_pandas()
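
# Quick continuity check (sketch): an hourly series over `context_hours`
# should yield exactly that many rows; a mismatch hints at gaps in the data.
if len(context_df) != context_hours:
    print(f"[!] WARNING: expected {context_hours} context rows, got {len(context_df)}")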
# Simple future covariates (just timestamp and border for the smoke test)
future_timestamps = pd.date_range(
    start=forecast_date,
    periods=prediction_hours,
    freq='h'  # hourly
)
future_data = pd.DataFrame({
    'timestamp': future_timestamps,
    'border': [test_border] * prediction_hours,
    'target': [np.nan] * prediction_hours  # NaN for the future values to predict
})
print(f"[OK] Future: {len(future_data)} hours")
print(f"    Context shape: {context_data.shape}")
print(f"    Future shape: {future_data.shape}")
# Step 4: Load model
print("\n[4/6] Loading Chronos 2 model on GPU...")
model_start = time.time()

pipeline = Chronos2Pipeline.from_pretrained(
    'amazon/chronos-2',
    device_map='cuda',
    dtype=torch.float32
)

model_time = time.time() - model_start
print(f"[OK] Model loaded in {model_time:.1f}s")
print(f"    Device: {next(pipeline.model.parameters()).device}")
# Step 5: Run inference
print("\n[5/6] Running zero-shot inference...")
print(f"    Border: {test_border}")
print(f"    Prediction: {prediction_hours} hours (7 days)")
print("    Output: probabilistic forecast")

inference_start = time.time()

try:
    # Combine context and future rows into a single long-format frame
    combined_df = pd.concat([context_data, future_data], ignore_index=True)

    forecasts = pipeline.predict_df(
        df=combined_df,
        prediction_length=prediction_hours,
        id_column='border',
        timestamp_column='timestamp',
        target='target'
    )

    inference_time = time.time() - inference_start
    print(f"[OK] Inference complete in {inference_time:.1f}s")
    print(f"    Forecast shape: {forecasts.shape}")
    # Step 6: Validate results
    print("\n[6/6] Validating results...")

    # Check for NaN values
    nan_count = forecasts.isna().sum().sum()
    print(f"    NaN values: {nan_count}")

    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("    Forecast statistics:")
        print(f"      Mean: {mean_forecast.mean():.2f} MW")
        print(f"      Min:  {mean_forecast.min():.2f} MW")
        print(f"      Max:  {mean_forecast.max():.2f} MW")
        print(f"      Std:  {mean_forecast.std():.2f} MW")

        # Sanity checks
        if mean_forecast.min() < 0:
            print("    [!] WARNING: Negative forecasts detected")
        if mean_forecast.max() > 20000:
            print("    [!] WARNING: Unreasonably high forecasts")
        if nan_count == 0 and mean_forecast.min() >= 0 and mean_forecast.max() < 20000:
            print("    [OK] Validation passed!")
    # Performance summary
    print("\n" + "="*60)
    print("SMOKE TEST SUMMARY")
    print("="*60)
    print(f"Border tested: {test_border}")
    print(f"Forecast length: {prediction_hours} hours (7 days)")
    print(f"Inference time: {inference_time:.1f}s")
    print(f"Speed: {prediction_hours / inference_time:.1f} forecast hours/second")

    # Estimate full run time by scaling the per-hour inference cost linearly
    total_borders = len(borders)
    full_forecast_hours = 336  # 14 days
    estimated_time = (inference_time / prediction_hours) * full_forecast_hours * total_borders
    print("\nEstimated time for full run:")
    print(f"  {total_borders} borders × {full_forecast_hours} hours")
    print(f"  = {estimated_time / 60:.1f} minutes ({estimated_time / 3600:.1f} hours)")

    # Target check
    if inference_time < 300:  # 5 minutes
        print("\n[OK] Performance target met! (<5 min for 7-day forecast)")
    else:
        print("\n[!] Performance slower than target (expected <5 min)")

    print("="*60)
    print("[OK] SMOKE TEST PASSED!")
    print("="*60)
except Exception as e:
    print(f"\n[ERROR] Inference failed: {e}")
    import traceback
    traceback.print_exc()
    exit(1)
# Total time
total_time = time.time() - start_time
print(f"\nTotal test time: {total_time:.1f}s ({total_time / 60:.1f} min)")