#!/usr/bin/env python3
"""
Compute Embeddings for Major-TOM Sentinel-2 Images

This script generates embeddings for Sentinel-2 imagery using one of several models:

- DINOv2:  Vision Transformer trained with self-supervised learning
- SigLIP:  Vision-language model trained with a sigmoid loss
- FarSLIP: CLIP variant fine-tuned for remote sensing
- SatCLIP: Location-aware CLIP for satellite imagery

Usage:
    python compute_embeddings.py --model dinov2 --device cuda:1
    python compute_embeddings.py --model siglip --device cuda:5
    python compute_embeddings.py --model satclip --device cuda:3
    python compute_embeddings.py --model farslip --device cuda:4

Author: Generated by Copilot
"""

import argparse
import logging
import sys
import traceback
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm.auto import tqdm

# Add the project root to sys.path so local modules resolve
PROJECT_ROOT = Path(__file__).parent.absolute()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from models.load_config import load_and_process_config

# =============================================================================
# Configuration
# =============================================================================
METADATA_PATH = Path("/data1/zyj/Core-S2L2A-249k/Core_S2L2A_249k_crop_384x384_metadata.parquet")
IMAGE_PARQUET_DIR = Path("/data1/zyj/Core-S2L2A-249k/images")
OUTPUT_BASE_DIR = Path("/data1/zyj/EarthEmbeddings/Core-S2L2A-249k")

# Columns to remove from the output
COLUMNS_TO_REMOVE = ['cloud_cover', 'nodata', 'geometry_wkt', 'bands', 'image_shape', 'image_dtype']

# Columns to rename
COLUMNS_RENAME = {'crs': 'utm_crs'}

# Pixel bbox of the center 384x384 crop within the 1068x1068 original:
# (1068 - 384) / 2 = 342
PIXEL_BBOX = [342, 342, 726, 726]  # [x_min, y_min, x_max, y_max]
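
# Sanity check (added for illustration): the bbox should span exactly 384 px
# per side, consistent with the (1068 - 384) / 2 = 342 offset computed above.
assert PIXEL_BBOX[2] - PIXEL_BBOX[0] == PIXEL_BBOX[3] - PIXEL_BBOX[1] == 384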

# Model output paths
MODEL_OUTPUT_PATHS = {
    'dinov2': OUTPUT_BASE_DIR / 'dinov2' / 'DINOv2_crop_384x384.parquet',
    'siglip': OUTPUT_BASE_DIR / 'siglip' / 'SigLIP_crop_384x384.parquet',
    'farslip': OUTPUT_BASE_DIR / 'farslip' / 'FarSLIP_crop_384x384.parquet',
    'satclip': OUTPUT_BASE_DIR / 'satclip' / 'SatCLIP_crop_384x384.parquet',
}

# Per-model batch sizes
BATCH_SIZES = {
    'dinov2': 64,
    'siglip': 64,
    'farslip': 64,
    'satclip': 128,
}


# =============================================================================
# Setup Logging
# =============================================================================
def setup_logging(model_name: str) -> logging.Logger:
    """Configure logging to both file and console."""
    log_dir = PROJECT_ROOT / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"compute_embeddings_{model_name}.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout),
        ],
    )
    return logging.getLogger(__name__)


# =============================================================================
# Image Preprocessing Functions
# =============================================================================
def decode_image_bytes(row: pd.Series) -> np.ndarray:
    """
    Decode image bytes from a parquet row into a numpy array.

    Args:
        row: pandas Series with 'image_bytes', 'image_shape', and 'image_dtype'

    Returns:
        np.ndarray of shape (H, W, 12) with uint16 values
    """
    shape = tuple(map(int, row['image_shape']))
    dtype = np.dtype(row['image_dtype'])
    img_flat = np.frombuffer(row['image_bytes'], dtype=dtype)
    return img_flat.reshape(shape)
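
# Illustrative usage (not executed here); the file name is hypothetical but
# matches the batch_*.parquet pattern globbed in main():
#
#   df = pd.read_parquet(IMAGE_PARQUET_DIR / "batch_00000.parquet")
#   img = decode_image_bytes(df.iloc[0])  # (H, W, 12) uint16 reflectances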


def extract_rgb_image(img_array: np.ndarray, clip_max: float = 4000.0) -> Image.Image:
    """
    Extract the RGB channels from a 12-band Sentinel-2 array.

    Sentinel-2 bands: [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12]
    RGB mapping: R=B04 (idx 3), G=B03 (idx 2), B=B02 (idx 1)

    Args:
        img_array: numpy array of shape (H, W, 12)
        clip_max: reflectance value mapped to full brightness; higher values are clipped

    Returns:
        PIL.Image: RGB image
    """
    # Select RGB channels: R=B04(3), G=B03(2), B=B02(1)
    rgb_bands = img_array[:, :, [3, 2, 1]].astype(np.float32)
    # Scale by clip_max and clip to [0, 1]
    rgb_normalized = np.clip(rgb_bands / clip_max, 0, 1)
    # Convert to 8-bit
    rgb_uint8 = (rgb_normalized * 255).astype(np.uint8)
    return Image.fromarray(rgb_uint8)
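
# Worked example (illustrative): a raw reflectance of 2000 maps to
# 2000 / 4000 = 0.5, i.e. pixel value int(0.5 * 255) = 127 after the uint8
# truncation; anything at or above clip_max (4000) saturates to 255.
#
#   rgb = extract_rgb_image(img)  # PIL image ready for the RGB-based models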


# =============================================================================
# Model Loading Functions
# =============================================================================
def load_model(model_name: str, device: str, config: dict):
    """
    Load the specified model.

    Args:
        model_name: One of 'dinov2', 'siglip', 'farslip', 'satclip'
        device: Device string such as 'cuda:0' or 'cpu'
        config: Configuration dictionary from local.yaml

    Returns:
        Model instance
    """
    logger = logging.getLogger(__name__)

    if model_name == 'dinov2':
        from models.dinov2_model import DINOv2Model
        model_config = config.get('dinov2', {})
        model = DINOv2Model(
            ckpt_path=model_config.get('ckpt_path', '/data1/zyj/checkpoints/dinov2-large'),
            model_name='facebook/dinov2-large',
            embedding_path=None,  # We are generating embeddings, not loading them
            device=device,
        )
        logger.info(f"DINOv2 model loaded on {device}")
        return model

    elif model_name == 'siglip':
        from models.siglip_model import SigLIPModel
        model_config = config.get('siglip', {})
        model = SigLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin'),
            model_name='ViT-SO400M-14-SigLIP-384',
            tokenizer_path=model_config.get('tokenizer_path', './checkpoints/ViT-SO400M-14-SigLIP-384'),
            embedding_path=None,
            device=device,
        )
        # Disable embedding loading since embedding_path is None
        model.df_embed = None
        model.image_embeddings = None
        logger.info(f"SigLIP model loaded on {device}")
        return model

    elif model_name == 'farslip':
        from models.farslip_model import FarSLIPModel
        model_config = config.get('farslip', {})
        model = FarSLIPModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt'),
            model_name='ViT-B-16',
            embedding_path=None,
            device=device,
        )
        logger.info(f"FarSLIP model loaded on {device}")
        return model

    elif model_name == 'satclip':
        from models.satclip_ms_model import SatCLIPMSModel
        model_config = config.get('satclip', {})
        model = SatCLIPMSModel(
            ckpt_path=model_config.get('ckpt_path', './checkpoints/SatCLIP/satclip-vit16-l40.ckpt'),
            embedding_path=None,
            device=device,
        )
        logger.info(f"SatCLIP-MS model loaded on {device}")
        return model

    else:
        raise ValueError(f"Unknown model: {model_name}")
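
# Illustrative usage (model classes and config keys come from this repo's
# models/ package; checkpoint paths fall back to the defaults shown above):
#
#   config = load_and_process_config() or {}
#   model = load_model('dinov2', 'cuda:0', config)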


# =============================================================================
# Embedding Computation Functions
# =============================================================================
def compute_embedding_single(model, model_name: str, img_array: np.ndarray) -> Optional[np.ndarray]:
    """
    Compute the embedding for a single image.

    Args:
        model: Model instance
        model_name: Model identifier
        img_array: numpy array of shape (H, W, 12)

    Returns:
        1D embedding vector, or None on failure
    """
    if model_name in ['dinov2', 'siglip', 'farslip']:
        # These models take RGB input
        rgb_img = extract_rgb_image(img_array)
        feature = model.encode_image(rgb_img)
        if feature is not None:
            return feature.cpu().numpy().flatten()
        return None
    elif model_name == 'satclip':
        # SatCLIP consumes the multi-spectral input directly
        feature = model.encode_image(img_array, is_multispectral=True)
        if feature is not None:
            return feature.cpu().numpy().flatten()
        return None
    return None
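
# Illustrative usage (assumes `model` was created via load_model above):
#
#   vec = compute_embedding_single(model, 'dinov2', img)
#   if vec is not None:
#       print(vec.shape)  # (D,) flat embedding vector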


def compute_embedding_batch(model, model_name: str, img_arrays: list) -> list:
    """
    Compute embeddings for a batch of images.

    Falls back to single-image processing if no batch method is available
    or if batch encoding fails.

    Args:
        model: Model instance
        model_name: Model identifier
        img_arrays: List of numpy arrays of shape (H, W, 12)

    Returns:
        List of 1D embedding vectors (numpy arrays), with None for failed items
    """
    n_images = len(img_arrays)

    if model_name in ['dinov2', 'siglip', 'farslip']:
        # These models take RGB input
        rgb_imgs = [extract_rgb_image(arr) for arr in img_arrays]
        # Try batch encoding first
        if hasattr(model, 'encode_images'):
            try:
                features = model.encode_images(rgb_imgs)
                if features is not None:
                    return [features[i].cpu().numpy().flatten() for i in range(len(features))]
            except Exception:
                pass  # Fall back to single-image processing
        # Fall back to single-image encoding
        results = []
        for img in rgb_imgs:
            try:
                feature = model.encode_image(img)
                results.append(feature.cpu().numpy().flatten() if feature is not None else None)
            except Exception:
                results.append(None)
        return results

    elif model_name == 'satclip':
        # SatCLIP takes multi-spectral input
        if hasattr(model, 'encode_images'):
            try:
                features = model.encode_images(img_arrays, is_multispectral=True)
                if features is not None:
                    return [features[i].cpu().numpy().flatten() for i in range(len(features))]
            except Exception:
                pass  # Fall back to single-image processing
        # Fall back to single-image encoding
        results = []
        for arr in img_arrays:
            try:
                feature = model.encode_image(arr, is_multispectral=True)
                results.append(feature.cpu().numpy().flatten() if feature is not None else None)
            except Exception:
                results.append(None)
        return results

    return [None] * n_images
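
# Illustrative usage: the result aligns 1:1 with the input list, so failed
# items can be filtered out by index downstream.
#
#   vecs = compute_embedding_batch(model, 'satclip', arrays)
#   good = [(i, v) for i, v in enumerate(vecs) if v is not None]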


def process_parquet_file(
    file_path: Path,
    model,
    model_name: str,
    batch_size: int = 64,
) -> Optional[pd.DataFrame]:
    """
    Process a single parquet file and generate embeddings in batches.

    Args:
        file_path: Path to the input parquet file
        model: Model instance
        model_name: Model identifier
        batch_size: Batch size for processing

    Returns:
        DataFrame with embeddings, or None if no row produced a valid embedding
    """
    logger = logging.getLogger(__name__)

    # Load data
    df = pd.read_parquet(file_path)
    n_rows = len(df)
    embeddings_list = [None] * n_rows
    valid_mask = [False] * n_rows

    # Process in batches
    for batch_start in range(0, n_rows, batch_size):
        batch_end = min(batch_start + batch_size, n_rows)
        batch_indices = list(range(batch_start, batch_end))

        # Decode the images for this batch
        batch_arrays = []
        batch_valid_indices = []
        for idx in batch_indices:
            try:
                row = df.iloc[idx]
                img_array = decode_image_bytes(row)
                batch_arrays.append(img_array)
                batch_valid_indices.append(idx)
            except Exception as e:
                logger.warning(f"Error decoding row {idx}: {e}")
                continue

        if not batch_arrays:
            continue

        # Compute embeddings for this batch
        try:
            batch_embeddings = compute_embedding_batch(model, model_name, batch_arrays)
            # Store results
            for i, idx in enumerate(batch_valid_indices):
                if batch_embeddings[i] is not None:
                    embeddings_list[idx] = batch_embeddings[i]
                    valid_mask[idx] = True
        except Exception as e:
            logger.warning(f"Error computing batch embeddings: {e}")
            # Fall back to single-image processing for this batch
            for i, idx in enumerate(batch_valid_indices):
                try:
                    embedding = compute_embedding_single(model, model_name, batch_arrays[i])
                    if embedding is not None:
                        embeddings_list[idx] = embedding
                        valid_mask[idx] = True
                except Exception as inner_e:
                    logger.warning(f"Error processing row {idx}: {inner_e}")
                    continue

    # Keep only the rows with valid embeddings
    valid_indices = [i for i, v in enumerate(valid_mask) if v]
    if not valid_indices:
        logger.warning(f"No valid embeddings for {file_path.name}")
        return None

    # Build the result DataFrame
    result_df = df.iloc[valid_indices].copy()
    valid_embeddings = [embeddings_list[i] for i in valid_indices]

    # Drop unwanted columns
    cols_to_drop = [c for c in COLUMNS_TO_REMOVE if c in result_df.columns]
    if cols_to_drop:
        result_df = result_df.drop(columns=cols_to_drop)
    # Drop image_bytes (large binary data)
    if 'image_bytes' in result_df.columns:
        result_df = result_df.drop(columns=['image_bytes'])
    # Drop the binary geometry column
    if 'geometry' in result_df.columns:
        result_df = result_df.drop(columns=['geometry'])

    # Rename columns
    result_df = result_df.rename(columns=COLUMNS_RENAME)
    # Add pixel_bbox
    result_df['pixel_bbox'] = [PIXEL_BBOX] * len(result_df)
    # Add the embeddings
    result_df['embedding'] = valid_embeddings

    return result_df
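
# Illustrative output row (the remaining metadata columns depend on the
# upstream parquet schema): each row keeps its metadata plus
#   'utm_crs'     - renamed from 'crs'
#   'pixel_bbox'  - [342, 342, 726, 726] for every row
#   'embedding'   - 1D numpy array of model features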


# =============================================================================
# Main Processing Pipeline
# =============================================================================
def main() -> int:
    parser = argparse.ArgumentParser(description='Compute embeddings for Major-TOM images')
    parser.add_argument('--model', type=str, required=True,
                        choices=['dinov2', 'siglip', 'farslip', 'satclip'],
                        help='Model to use for embedding computation')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for processing (default: model-specific)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum number of files to process (for testing)')
    args = parser.parse_args()

    # Set up logging
    logger = setup_logging(args.model)
    logger.info("=" * 80)
    logger.info(f"Computing {args.model.upper()} embeddings")
    logger.info(f"Timestamp: {datetime.now().isoformat()}")
    logger.info(f"Device: {args.device}")
    logger.info("=" * 80)

    # Load configuration
    config = load_and_process_config()
    if config is None:
        logger.warning("No config file found, using default paths")
        config = {}

    # Determine the batch size
    batch_size = args.batch_size or BATCH_SIZES.get(args.model, 64)
    logger.info(f"Batch size: {batch_size}")

    # Resolve the output path
    output_path = MODEL_OUTPUT_PATHS[args.model]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output path: {output_path}")

    # Load the model
    logger.info(f"Loading {args.model} model...")
    model = load_model(args.model, args.device, config)

    # Collect the input files
    parquet_files = sorted(IMAGE_PARQUET_DIR.glob("batch_*.parquet"))
    if args.max_files:
        parquet_files = parquet_files[:args.max_files]
    logger.info(f"Found {len(parquet_files)} input files")

    # Process the files
    all_results = []
    total_rows = 0
    for file_path in tqdm(parquet_files, desc=f"Processing {args.model}"):
        try:
            result_df = process_parquet_file(file_path, model, args.model, batch_size)
            if result_df is not None:
                all_results.append(result_df)
                total_rows += len(result_df)
                logger.info(f"[{file_path.name}] Processed {len(result_df)} rows")
        except Exception as e:
            logger.error(f"Error processing {file_path.name}: {e}")
            logger.error(traceback.format_exc())
            continue

    # Merge and save
    if all_results:
        logger.info("Merging all results...")
        final_df = pd.concat(all_results, ignore_index=True)

        # Validate columns
        logger.info(f"Final columns: {list(final_df.columns)}")
        removed = [c for c in COLUMNS_TO_REMOVE if c in final_df.columns]
        if removed:
            logger.warning(f"Columns still present that should have been removed: {removed}")
        else:
            logger.info("✓ All unwanted columns removed")
        if 'utm_crs' in final_df.columns and 'crs' not in final_df.columns:
            logger.info("✓ Column 'crs' renamed to 'utm_crs'")
        if 'pixel_bbox' in final_df.columns:
            logger.info("✓ Column 'pixel_bbox' added")

        # Save
        logger.info(f"Saving to {output_path}...")
        final_df.to_parquet(output_path, index=False)

        logger.info("=" * 80)
        logger.info("Processing complete!")
        logger.info(f"Total rows: {len(final_df):,}")
        logger.info(f"Embedding dimension: {len(final_df['embedding'].iloc[0])}")
        logger.info(f"Output file: {output_path}")
        logger.info("=" * 80)
    else:
        logger.error("No data processed!")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
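
# Illustrative downstream use (not part of this script): load the saved
# parquet and stack the embeddings into an (N, D) matrix.
#
#   df = pd.read_parquet(MODEL_OUTPUT_PATHS['dinov2'])
#   X = np.stack(df['embedding'].to_numpy())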