Spaces:
Runtime error
Runtime error
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "datasets", | |
| # "hf-transfer", | |
| # "hf-xet", | |
| # "huggingface-hub", | |
| # "polars", | |
| # "sentence-transformers", | |
| # "torch", | |
| # ] | |
| # /// | |
| """ | |
| Compute embeddings for dataset/model summaries and push to a Hub dataset. | |
| The output dataset has two configs: | |
| - dataset_cards: id, summary, embedding, likes, downloads, last_modified | |
| - model_cards: id, summary, embedding, likes, downloads, last_modified, param_count | |
| The Space then loads these pre-computed embeddings into ChromaDB on startup | |
| (no GPU needed, just HNSW index building). | |
| Usage: | |
| hf jobs uv run --flavor l4x1 -s HF_TOKEN build_embeddings.py | |
| # With mounted model for faster startup: | |
| hf jobs uv run --flavor l4x1 -s HF_TOKEN \ | |
| -v hf://Qwen/Qwen3-Embedding-0.6B:/model:ro \ | |
| build_embeddings.py --model-path /model | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
| import polars as pl | |
| import torch | |
| from datasets import Dataset | |
| from huggingface_hub import login | |
| from sentence_transformers import SentenceTransformer | |
# Hub id of the embedding model, used when --model-path is not supplied.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
# batch_size handed to model.encode() for both dataset and model summaries.
BATCH_SIZE = 512

# Hub datasets whose parquet shards contain the summaries to embed.
DATASET_SOURCE = "davanstrien/datasets_with_metadata_and_summaries"
MODEL_SOURCE = "davanstrien/models_with_metadata_and_summaries"
# Default Hub dataset repo that receives the computed embeddings.
OUTPUT_REPO = "davanstrien/search-v2-embeddings"

# Timestamped INFO-level logging for the whole script.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def get_device():
    """Pick the torch device string to load the embedding model on.

    Returns:
        ``"cuda"`` when a CUDA GPU is available, otherwise ``"mps"`` when the
        Apple MPS backend is available, otherwise ``"cpu"``.
    """
    if not torch.cuda.is_available():
        # No CUDA: prefer Apple's MPS backend, fall back to plain CPU.
        return "mps" if torch.backends.mps.is_available() else "cpu"
    return "cuda"
def embed_datasets(model, dataset_source):
    """Embed dataset-card summaries and return them as a ``datasets.Dataset``.

    Lazily scans ``data/train-*.parquet`` in the Hub dataset *dataset_source*,
    removes rows whose ``datasetId`` matches an exclusion substring and rows
    whose ``summary`` is null, then encodes every remaining summary with
    *model* using its ``"document"`` prompt.

    Args:
        model: Loaded ``SentenceTransformer`` used to compute the embeddings.
        dataset_source: Hub dataset id holding the dataset summaries.

    Returns:
        A ``Dataset`` with columns ``id``, ``summary``, ``embedding``,
        ``likes``, ``downloads`` and ``last_modified`` (cast to string).
    """
    logger.info(f"=== Embedding dataset cards from {dataset_source} ===")
    # Substrings of datasetId whose rows are left out of the output.
    excluded_ids = [
        "open-llm-leaderboard-old/",
        "gemma-2-2B-it-thinking-function_calling-V0",
    ]
    df = (
        pl.scan_parquet(f"hf://datasets/{dataset_source}/data/train-*.parquet")
        # One negated contains_any over both substrings keeps exactly the rows
        # that two separate negated filters would: NOT (a OR b) == NOT a AND NOT b.
        .filter(pl.col("datasetId").str.contains_any(excluded_ids).not_())
        # A null summary cannot be embedded (encode() expects strings); drop
        # such rows instead of letting the whole job fail during encoding.
        .filter(pl.col("summary").is_not_null())
        .select(["datasetId", "summary", "likes", "downloads", "last_modified"])
        .collect()
    )
    logger.info(f"Loaded {len(df):,} dataset records")

    summaries = df["summary"].to_list()
    logger.info(f"Computing embeddings for {len(summaries):,} summaries...")
    embeddings = model.encode(
        summaries,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        prompt_name="document",
    )
    logger.info("Embeddings computed")

    return Dataset.from_dict({
        "id": df["datasetId"].to_list(),
        "summary": summaries,
        # Assumes encode() returned a 2-D numpy array (its default output);
        # a single tolist() yields the same nested lists as a per-row loop.
        "embedding": embeddings.tolist(),
        "likes": df["likes"].to_list(),
        "downloads": df["downloads"].to_list(),
        "last_modified": df["last_modified"].cast(pl.Utf8).to_list(),
    })
def embed_models(model, model_source):
    """Embed model-card summaries and return them as a ``datasets.Dataset``.

    Lazily scans ``data/train-*.parquet`` in the Hub dataset *model_source*,
    drops rows whose ``summary`` is null, and encodes each summary with
    *model* using its ``"document"`` prompt. A ``param_count`` column is
    carried through only when the source schema contains it.

    Args:
        model: Loaded ``SentenceTransformer`` used to compute the embeddings.
        model_source: Hub dataset id holding the model summaries.

    Returns:
        A ``Dataset`` with columns ``id``, ``summary``, ``embedding``,
        ``likes``, ``downloads``, ``last_modified`` (cast to string) and,
        when present in the source, ``param_count``.
    """
    logger.info(f"=== Embedding model cards from {model_source} ===")
    lf = pl.scan_parquet(f"hf://datasets/{model_source}/data/train-*.parquet")

    # param_count is optional in the source; only select it when it exists.
    has_param_count = "param_count" in lf.collect_schema()
    select_cols = ["modelId", "summary", "likes", "downloads", "last_modified"]
    if has_param_count:
        select_cols.append("param_count")

    # A null summary cannot be embedded (encode() expects strings); drop
    # such rows instead of letting the whole job fail during encoding.
    df = lf.filter(pl.col("summary").is_not_null()).select(select_cols).collect()
    logger.info(f"Loaded {len(df):,} model records")

    summaries = df["summary"].to_list()
    logger.info(f"Computing embeddings for {len(summaries):,} summaries...")
    embeddings = model.encode(
        summaries,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        prompt_name="document",
    )
    logger.info("Embeddings computed")

    data = {
        "id": df["modelId"].to_list(),
        "summary": summaries,
        # Assumes encode() returned a 2-D numpy array (its default output);
        # a single tolist() yields the same nested lists as a per-row loop.
        "embedding": embeddings.tolist(),
        "likes": df["likes"].to_list(),
        "downloads": df["downloads"].to_list(),
        "last_modified": df["last_modified"].cast(pl.Utf8).to_list(),
    }
    if has_param_count:
        data["param_count"] = df["param_count"].to_list()
    return Dataset.from_dict(data)
def main():
    """Compute summary embeddings for datasets and models and push them to the Hub.

    Parses the command line, optionally logs in with ``HF_TOKEN`` from the
    environment, loads the embedding model (from ``--model-path`` if given,
    otherwise ``EMBEDDING_MODEL``), then embeds and pushes the
    ``dataset_cards`` config followed by the ``model_cards`` config to
    ``--output-repo``. The summary sources default to ``DATASET_SOURCE`` and
    ``MODEL_SOURCE`` and can be overridden with ``--dataset-source`` and
    ``--model-source``.
    """
    parser = argparse.ArgumentParser(
        description="Compute embeddings and push to Hub dataset"
    )
    parser.add_argument(
        "--model-path",
        help="Local/mounted path to embedding model (skips download)",
    )
    parser.add_argument(
        "--output-repo",
        default=OUTPUT_REPO,
        help=f"Hub dataset repo to push embeddings to (default: {OUTPUT_REPO})",
    )
    parser.add_argument(
        "--dataset-source",
        default=DATASET_SOURCE,
        help=f"Hub dataset with dataset-card summaries (default: {DATASET_SOURCE})",
    )
    parser.add_argument(
        "--model-source",
        default=MODEL_SOURCE,
        help=f"Hub dataset with model-card summaries (default: {MODEL_SOURCE})",
    )
    args = parser.parse_args()

    # HF_TOKEN is optional; when unset, token=None is passed to push_to_hub,
    # which then relies on whatever credentials huggingface_hub already has.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    device = get_device()
    logger.info(f"Device: {device}")
    model_name = args.model_path or EMBEDDING_MODEL
    logger.info(f"Loading embedding model: {model_name}")
    model = SentenceTransformer(model_name, device=device)

    # Embed datasets
    ds_datasets = embed_datasets(model, args.dataset_source)
    logger.info(f"Pushing dataset_cards config ({len(ds_datasets):,} rows) to {args.output_repo}...")
    ds_datasets.push_to_hub(args.output_repo, config_name="dataset_cards", token=hf_token)
    # Release the first table before building the second to lower peak memory.
    del ds_datasets

    # Embed models
    ds_models = embed_models(model, args.model_source)
    logger.info(f"Pushing model_cards config ({len(ds_models):,} rows) to {args.output_repo}...")
    ds_models.push_to_hub(args.output_repo, config_name="model_cards", token=hf_token)
    logger.info("=== Done ===")


if __name__ == "__main__":
    main()