""" Smoke test for zero-shot inference pipeline Tests: 1. Data loading and preparation 2. Chronos 2 model loading 3. Inference on single border (7 days) 4. Output validation 5. Performance metrics """ import sys from pathlib import Path # Add src to path sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) from inference.data_fetcher import DataFetcher from inference.chronos_pipeline import ChronosForecaster from datetime import datetime, timedelta import torch import pandas as pd def main(): print("="*60) print("FBMC Chronos 2 Zero-Shot Inference - Smoke Test") print("="*60) # Step 1: Check environment print("\n[1] Checking environment...") print(f"PyTorch version: {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") else: print("Running on CPU (inference will be slower)") # Step 2: Initialize DataFetcher print("\n[2] Initializing DataFetcher...") fetcher = DataFetcher( use_local=True, # Use local files for testing context_length=512 # Use 512 hours context ) # Step 3: Load data print("\n[3] Loading unified features...") fetcher.load_data() # Get available date range min_date, max_date = fetcher.get_available_dates() print(f"Available data: {min_date} to {max_date}") # Select forecast date (use last month as test) forecast_date = max_date - timedelta(days=30) print(f"Test forecast date: {forecast_date}") # Step 4: Prepare inference data (single border, 7 days) print("\n[4] Preparing inference data (1 border, 7 days)...") test_border = fetcher.target_borders[0] # Use first border print(f"Test border: {test_border}") context_df, future_df = fetcher.prepare_inference_data( forecast_date=forecast_date, prediction_length=168, # 7 days borders=[test_border] ) print(f"Context shape: {context_df.shape}") print(f"Future shape: {future_df.shape}") # Validate data print("\n[5] Validating prepared data...") assert 'timestamp' in context_df.columns, "Missing timestamp column" assert 'border' in context_df.columns, "Missing border column" assert 'target' in context_df.columns, "Missing target column" assert len(context_df) > 0, "Empty context data" assert len(future_df) > 0, "Empty future data" print("[+] Data validation passed!") # Check for NaN values context_nulls = context_df.isnull().sum().sum() future_nulls = future_df.isnull().sum().sum() print(f"Context NaN count: {context_nulls}") print(f"Future NaN count: {future_nulls}") if context_nulls > 0 or future_nulls > 0: print("[!] Warning: Data contains NaN values (will be handled by model)") # Step 6: Initialize Chronos 2 forecaster print("\n[6] Initializing Chronos 2 forecaster...") forecaster = ChronosForecaster( model_name="amazon/chronos-2-large", device="auto" # Will use GPU if available ) # Step 7: Load model print("\n[7] Loading Chronos 2 Large model...") print("(This may take a few minutes on first load)") forecaster.load_model() print("[+] Model loaded successfully!") # Step 8: Run inference print("\n[8] Running zero-shot inference...") print(f"Forecasting {test_border} for 7 days (168 hours)") forecasts = forecaster.predict_single_border( border=test_border, context_df=context_df, future_df=future_df, prediction_length=168, num_samples=100 # 100 samples for probabilistic forecast ) print(f"[+] Inference complete! 
    # Step 9: Validate forecasts
    print("\n[9] Validating forecasts...")
    assert len(forecasts) > 0, "Empty forecasts"
    assert 'timestamp' in forecasts.columns or forecasts.index.name == 'timestamp', "Missing timestamp"

    # Check for reasonable values
    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("Forecast statistics:")
        print(f"  Mean: {mean_forecast.mean():.2f} MW")
        print(f"  Min: {mean_forecast.min():.2f} MW")
        print(f"  Max: {mean_forecast.max():.2f} MW")
        print(f"  Std: {mean_forecast.std():.2f} MW")

        # Sanity check: values should be plausible for cross-border capacity
        assert mean_forecast.min() >= 0, "Negative forecasts detected"
        assert mean_forecast.max() < 20000, "Unreasonably high forecasts"
        print("[+] Forecast validation passed!")

    # Step 10: Benchmark performance
    print("\n[10] Benchmarking inference performance...")
    metrics = forecaster.benchmark_inference(
        context_df=context_df,
        future_df=future_df,
        prediction_length=168
    )
    print("Performance metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value}")

    # Check the 5-minute target for a 14-day forecast by scaling the
    # measured 7-day (168 h) time linearly up to 336 hours
    estimated_14d_time = metrics['inference_time_sec'] * (336 / 168)
    print(f"\nEstimated time for 14-day forecast: {estimated_14d_time:.1f}s ({estimated_14d_time / 60:.1f} min)")

    if estimated_14d_time < 300:  # 5 minutes
        print("[+] Performance target met! (<5 min for 14 days)")
    else:
        print("[!] Warning: May not meet 5-minute target for 14 days")

    # Step 11: Save test forecasts
    print("\n[11] Saving test forecasts...")
    output_path = "data/evaluation/smoke_test_forecast.parquet"
    forecaster.save_forecasts(forecasts, output_path)
    print(f"[+] Saved to: {output_path}")

    # Summary
    print("\n" + "=" * 60)
    print("SMOKE TEST SUMMARY")
    print("=" * 60)
    print("[+] All tests passed!")
    print(f"[+] Border: {test_border}")
    print("[+] Forecast length: 168 hours (7 days)")
    print(f"[+] Inference time: {metrics['inference_time_sec']:.1f}s")
    print(f"[+] Output shape: {forecasts.shape}")
    print("\n[+] Ready for full inference run!")
    print("=" * 60)


if __name__ == "__main__":
    main()