# huggingface-datasets-search-v2 / build_embeddings.py
# Author: davanstrien (HF Staff)
# Commit 12683d4: Fix: use 'document' prompt for index embeddings, not 'query'
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "datasets",
# "hf-transfer",
# "hf-xet",
# "huggingface-hub",
# "polars",
# "sentence-transformers",
# "torch",
# ]
# ///
"""
Compute embeddings for dataset/model summaries and push them to a Hub dataset.

The output dataset has two configs:
- dataset_cards: id, summary, embedding, likes, downloads, last_modified
- model_cards: id, summary, embedding, likes, downloads, last_modified, param_count
  (param_count is included only when the source data provides it)

The Space then loads these pre-computed embeddings into ChromaDB on startup
(no GPU needed, just HNSW index building).

Usage:
    hf jobs uv run --flavor l4x1 -s HF_TOKEN build_embeddings.py

    # With mounted model for faster startup:
    hf jobs uv run --flavor l4x1 -s HF_TOKEN \
        -v hf://Qwen/Qwen3-Embedding-0.6B:/model:ro \
        build_embeddings.py --model-path /model
"""
import argparse
import logging
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import polars as pl
import torch
from datasets import Dataset
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
# Root logging: timestamped INFO-level messages for the job logs.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hub model loaded when no --model-path is supplied.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
# Number of summaries passed to model.encode per batch.
BATCH_SIZE = 512

# Hub dataset repos: the two summary sources read by this script and the
# repo the computed embeddings are pushed to by default.
DATASET_SOURCE = "davanstrien/datasets_with_metadata_and_summaries"
MODEL_SOURCE = "davanstrien/models_with_metadata_and_summaries"
OUTPUT_REPO = "davanstrien/search-v2-embeddings"
def get_device():
    """Return the torch device name to run on: "cuda", then "mps", else "cpu"."""
    # Each probe is a lambda so the MPS backend is only queried when CUDA is
    # unavailable, matching a plain if/elif chain.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"
def embed_datasets(model, dataset_source):
    """Embed dataset-card summaries and return them as a ``datasets.Dataset``.

    Reads ``data/train-*.parquet`` from the Hub dataset repo ``dataset_source``,
    removes excluded dataset IDs and rows without a summary, and encodes each
    summary with the model's ``"document"`` prompt.

    Args:
        model: A loaded ``SentenceTransformer`` used to encode the summaries.
        dataset_source: Hub dataset repo id holding the summarised dataset cards.

    Returns:
        A ``Dataset`` with columns ``id``, ``summary``, ``embedding``, ``likes``,
        ``downloads`` and ``last_modified`` (cast to string).
    """
    logger.info(f"=== Embedding dataset cards from {dataset_source} ===")
    # Dataset IDs containing any of these substrings are kept out of the index.
    excluded_id_substrings = [
        "open-llm-leaderboard-old/",
        "gemma-2-2B-it-thinking-function_calling-V0",
    ]
    df = (
        pl.scan_parquet(f"hf://datasets/{dataset_source}/data/train-*.parquet")
        .filter(pl.col("datasetId").str.contains_any(excluded_id_substrings).not_())
        .select(["datasetId", "summary", "likes", "downloads", "last_modified"])
        .collect()
    )
    logger.info(f"Loaded {len(df):,} dataset records")

    # A None summary would make the tokenizer fail inside model.encode and
    # abort the whole job, so drop such rows up front and report how many.
    n_null = df["summary"].null_count()
    if n_null:
        logger.warning(f"Dropping {n_null:,} dataset records with a null summary")
        df = df.filter(pl.col("summary").is_not_null())

    summaries = df["summary"].to_list()
    logger.info(f"Computing embeddings for {len(summaries):,} summaries...")
    # These vectors form the searchable index, so they use the "document"
    # prompt rather than the "query" prompt.
    embeddings = model.encode(
        summaries,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        prompt_name="document",
    )
    logger.info("Embeddings computed")
    return Dataset.from_dict({
        "id": df["datasetId"].to_list(),
        "summary": summaries,
        "embedding": [emb.tolist() for emb in embeddings],
        "likes": df["likes"].to_list(),
        "downloads": df["downloads"].to_list(),
        "last_modified": df["last_modified"].cast(pl.Utf8).to_list(),
    })
def embed_models(model, model_source):
    """Embed model-card summaries and return them as a ``datasets.Dataset``.

    Reads ``data/train-*.parquet`` from the Hub dataset repo ``model_source``,
    removes rows without a summary, and encodes each summary with the model's
    ``"document"`` prompt.

    Args:
        model: A loaded ``SentenceTransformer`` used to encode the summaries.
        model_source: Hub dataset repo id holding the summarised model cards.

    Returns:
        A ``Dataset`` with columns ``id``, ``summary``, ``embedding``, ``likes``,
        ``downloads``, ``last_modified`` (cast to string) and, only when the
        source has that column, ``param_count``.
    """
    logger.info(f"=== Embedding model cards from {model_source} ===")
    lf = pl.scan_parquet(f"hf://datasets/{model_source}/data/train-*.parquet")
    select_cols = ["modelId", "summary", "likes", "downloads", "last_modified"]
    # param_count is optional in the source; inspect the lazy frame's schema
    # to decide whether to carry it through.
    has_param_count = "param_count" in lf.collect_schema()
    if has_param_count:
        select_cols.append("param_count")
    df = lf.select(select_cols).collect()
    logger.info(f"Loaded {len(df):,} model records")

    # A None summary would make the tokenizer fail inside model.encode and
    # abort the whole job, so drop such rows up front and report how many.
    n_null = df["summary"].null_count()
    if n_null:
        logger.warning(f"Dropping {n_null:,} model records with a null summary")
        df = df.filter(pl.col("summary").is_not_null())

    summaries = df["summary"].to_list()
    logger.info(f"Computing embeddings for {len(summaries):,} summaries...")
    # These vectors form the searchable index, so they use the "document"
    # prompt rather than the "query" prompt.
    embeddings = model.encode(
        summaries,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        prompt_name="document",
    )
    logger.info("Embeddings computed")
    data = {
        "id": df["modelId"].to_list(),
        "summary": summaries,
        "embedding": [emb.tolist() for emb in embeddings],
        "likes": df["likes"].to_list(),
        "downloads": df["downloads"].to_list(),
        "last_modified": df["last_modified"].cast(pl.Utf8).to_list(),
    }
    if has_param_count:
        data["param_count"] = df["param_count"].to_list()
    return Dataset.from_dict(data)
def main():
    """Parse CLI options, embed dataset and model summaries, and push both configs.

    Loads the embedding model from ``--model-path`` if given, otherwise
    ``EMBEDDING_MODEL`` from the Hub. It then embeds and pushes the
    ``dataset_cards`` config, followed by the ``model_cards`` config, to
    ``--output-repo``. The source repos default to ``DATASET_SOURCE`` and
    ``MODEL_SOURCE`` and can be overridden on the command line.
    """
    parser = argparse.ArgumentParser(
        description="Compute embeddings and push to Hub dataset"
    )
    parser.add_argument(
        "--model-path",
        help="Local/mounted path to embedding model (skips download)",
    )
    parser.add_argument(
        "--output-repo",
        default=OUTPUT_REPO,
        help=f"Hub dataset repo to push embeddings to (default: {OUTPUT_REPO})",
    )
    parser.add_argument(
        "--dataset-source",
        default=DATASET_SOURCE,
        help=f"Hub dataset repo with dataset-card summaries (default: {DATASET_SOURCE})",
    )
    parser.add_argument(
        "--model-source",
        default=MODEL_SOURCE,
        help=f"Hub dataset repo with model-card summaries (default: {MODEL_SOURCE})",
    )
    args = parser.parse_args()

    # Token from the job secret (`-s HF_TOKEN`). When it is unset, token=None is
    # passed to push_to_hub, which leaves credential lookup to huggingface_hub.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    device = get_device()
    logger.info(f"Device: {device}")
    model_name = args.model_path or EMBEDDING_MODEL
    logger.info(f"Loading embedding model: {model_name}")
    model = SentenceTransformer(model_name, device=device)

    # Embed datasets
    ds_datasets = embed_datasets(model, args.dataset_source)
    logger.info(f"Pushing dataset_cards config ({len(ds_datasets):,} rows) to {args.output_repo}...")
    ds_datasets.push_to_hub(args.output_repo, config_name="dataset_cards", token=hf_token)
    # Already pushed; release it before building the model-card table to keep
    # peak memory down.
    del ds_datasets

    # Embed models
    ds_models = embed_models(model, args.model_source)
    logger.info(f"Pushing model_cards config ({len(ds_models):,} rows) to {args.output_repo}...")
    ds_models.push_to_hub(args.output_repo, config_name="model_cards", token=hf_token)
    logger.info("=== Done ===")
# Script entry point: run the pipeline only when executed directly (see the
# Usage section of the module docstring), not when imported.
if __name__ == "__main__":
    main()