davanstrien HF Staff Claude Opus 4.6 (1M context) commited on
Commit
21a271a
·
1 Parent(s): 756e837

Migrate off persistent storage to bucket-mounted ChromaDB

Browse files

- Simplify main.py: remove setup_database() and all indexing logic.
Space now reads pre-built ChromaDB from mounted storage bucket.
- Add build_chroma_index.py: standalone uv script that builds the
ChromaDB index as an HF Job on GPU (much faster than CPU).
- Update generate_summaries_uv.py: support mounted volumes for model
and input data, pin transformers<4.52, fix vllm version, reduce
content truncation to 3000 chars to avoid exceeding model max length.
- Update HFJOBS_COMMANDS.md: correct output repo names, add index
build command, use hf jobs uv run with volume mounts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (4) hide show
  1. HFJOBS_COMMANDS.md +134 -0
  2. build_chroma_index.py +293 -0
  3. generate_summaries_uv.py +43 -18
  4. main.py +14 -295
HFJOBS_COMMANDS.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HFJobs Commands for Summary Generation
2
+
3
+ This document contains the hfjobs commands for running the summary generation pipeline.
4
+
5
+ ## Performance Optimizations
6
+
7
+ For batch inference workloads (processing thousands of short summaries), consider these vLLM optimizations:
8
+
9
+ ### Memory and Throughput Settings
10
+
11
+ 1. **GPU Memory Utilization** (`gpu_memory_utilization`)
12
+ - Default: 0.9 (90%)
13
+ - Recommended: 0.95 or 0.98 for batch workloads
14
+ - Allocates more GPU memory for KV cache, allowing more concurrent sequences
15
+
16
+ 2. **Chunked Prefill** (`enable_chunked_prefill`)
17
+ - Set to `True` for many short requests
18
+ - Interleaves prefill and decode phases more efficiently
19
+ - Particularly beneficial for uniform, short outputs like summaries
20
+
21
+ 3. **Max Batched Tokens** (`max_num_batched_tokens`)
22
+ - Default: 512
23
+ - Recommended: 4096 or 8192 for better throughput
24
+ - Controls tokens processed together in a single iteration
25
+
26
+ 4. **Max Number of Sequences** (`max_num_seqs`)
27
+ - Increase to 256 or 512 for batch workloads
28
+ - More concurrent sequences = better throughput
29
+ - L4 GPU (24GB) can handle aggressive settings
30
+
31
+ ### Example Optimized Configuration
32
+
33
+ ```python
34
+ llm = LLM(
35
+ model=local_model_path,
36
+ max_model_len=4096,
37
+ gpu_memory_utilization=0.95, # Use 95% of GPU memory
38
+ enable_chunked_prefill=True, # Better for short requests
39
+ max_num_batched_tokens=8192, # High throughput batching
40
+ max_num_seqs=256, # Many concurrent sequences
41
+ )
42
+ ```
43
+
44
+ ## Summary Generation (hf jobs uv run)
45
+
46
+ Uses `generate_summaries_uv.py` with volume mounts for fast startup (no download step).
47
+
48
+ ### Dataset Summaries
49
+
50
+ ```bash
51
+ hf jobs uv run --flavor l4x1 \
52
+ -v hf://datasets/librarian-bots/dataset_cards_with_metadata:/input:ro \
53
+ -v hf://davanstrien/Smol-Hub-tldr:/model:ro \
54
+ -s HF_TOKEN \
55
+ --timeout 2h \
56
+ generate_summaries_uv.py \
57
+ /model \
58
+ librarian-bots/dataset_cards_with_metadata \
59
+ davanstrien/datasets_with_metadata_and_summaries \
60
+ --card-type dataset \
61
+ --input-path /input \
62
+ --batch-size 2000
63
+ ```
64
+
65
+ ### Model Summaries
66
+
67
+ ```bash
68
+ hf jobs uv run --flavor l4x1 \
69
+ -v hf://datasets/librarian-bots/model_cards_with_metadata:/input:ro \
70
+ -v hf://davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02:/model:ro \
71
+ -s HF_TOKEN \
72
+ --timeout 2h \
73
+ generate_summaries_uv.py \
74
+ /model \
75
+ librarian-bots/model_cards_with_metadata \
76
+ davanstrien/models_with_metadata_and_summaries \
77
+ --card-type model \
78
+ --min-likes 5 \
79
+ --min-downloads 1000 \
80
+ --input-path /input \
81
+ --batch-size 2000
82
+ ```
83
+
84
+ ### Without volume mounts (downloads data instead)
85
+
86
+ If volumes aren't available, the script falls back to downloading:
87
+
88
+ ```bash
89
+ hf jobs uv run --flavor l4x1 \
90
+ -s HF_TOKEN \
91
+ --timeout 2h \
92
+ generate_summaries_uv.py \
93
+ davanstrien/Smol-Hub-tldr \
94
+ librarian-bots/dataset_cards_with_metadata \
95
+ davanstrien/datasets_with_metadata_and_summaries \
96
+ --card-type dataset \
97
+ --batch-size 2000
98
+ ```
99
+
100
+ ## ChromaDB Index Build
101
+
102
+ Builds/updates the ChromaDB vector index from the summary datasets. Must run after summary generation to update search results. Writes to a Storage Bucket mounted as a volume.
103
+
104
+ ```bash
105
+ hf jobs uv run --flavor l4x1 \
106
+ -v hf://buckets/davanstrien/search-v2-chroma:/data \
107
+ -s HF_TOKEN \
108
+ https://huggingface.co/spaces/davanstrien/huggingface-datasets-search-v2/raw/main/build_chroma_index.py
109
+ ```
110
+
111
+ For a full rebuild (delete existing collections first):
112
+
113
+ ```bash
114
+ hf jobs uv run --flavor l4x1 \
115
+ -v hf://buckets/davanstrien/search-v2-chroma:/data \
116
+ -s HF_TOKEN \
117
+ https://huggingface.co/spaces/davanstrien/huggingface-datasets-search-v2/raw/main/build_chroma_index.py \
118
+ --full-rebuild
119
+ ```
120
+
121
+ ### Full Pipeline (summaries → index)
122
+
123
+ Run summary generation first, then rebuild the index:
124
+
125
+ 1. Generate dataset summaries (see Dataset Summaries above)
126
+ 2. Generate model summaries (see Model Summaries above)
127
+ 3. Build the ChromaDB index (this section)
128
+
129
+ ## Notes
130
+
131
+ - The vLLM Docker image approach is preferred over the uv:debian image because it includes all necessary system dependencies (Python headers, CUDA libraries, etc.)
132
+ - The script is run directly from the Hugging Face Space URL using `uv run`
133
+ - Adjust `--batch-size` based on available GPU memory
134
+ - For models, adjust `--min-likes` and `--min-downloads` thresholds as needed
build_chroma_index.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11"
3
+ # dependencies = [
4
+ # "chromadb==1.0.12",
5
+ # "hf-transfer",
6
+ # "hf-xet",
7
+ # "huggingface-hub",
8
+ # "polars",
9
+ # "python-dateutil",
10
+ # "sentence-transformers",
11
+ # "torch",
12
+ # ]
13
+ # ///
14
+ """
15
+ Build ChromaDB index for the datasets-search-v2 Space.
16
+
17
+ Reads summary parquets from the Hub, embeds them with Qwen3-Embedding-0.6B,
18
+ and writes the ChromaDB index to a mounted Storage Bucket.
19
+
20
+ Usage (via hf jobs):
21
+ hf jobs uv run \
22
+ --flavor l4x1 \
23
+ -v hf://buckets/davanstrien/search-v2-chroma:/data \
24
+ -s HF_TOKEN \
25
+ build_chroma_index.py
26
+
27
+ Local usage:
28
+ uv run build_chroma_index.py --data-dir ./data
29
+ """
30
+ import argparse
31
+ import logging
32
+ import os
33
+ import sys
34
+
35
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
36
+
37
+ import chromadb
38
+ import dateutil.parser
39
+ import polars as pl
40
+ import torch
41
+ from chromadb.config import Settings
42
+ from chromadb.utils import embedding_functions
43
+ from huggingface_hub import login
44
+
45
+ logging.basicConfig(
46
+ level=logging.INFO,
47
+ format="%(asctime)s - %(levelname)s - %(message)s",
48
+ datefmt="%Y-%m-%d %H:%M:%S",
49
+ )
50
+ logger = logging.getLogger(__name__)
51
+
52
+ EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
53
+ BATCH_SIZE = 2000
54
+
55
+ DATASET_SOURCE = "davanstrien/datasets_with_metadata_and_summaries"
56
+ MODEL_SOURCE = "davanstrien/models_with_metadata_and_summaries"
57
+
58
+
59
+ def get_device():
60
+ if torch.cuda.is_available():
61
+ return "cuda"
62
+ elif torch.backends.mps.is_available():
63
+ return "mps"
64
+ return "cpu"
65
+
66
+
67
+ def get_embedding_function(device):
68
+ logger.info(f"Loading embedding model {EMBEDDING_MODEL} on {device}")
69
+ return embedding_functions.SentenceTransformerEmbeddingFunction(
70
+ model_name=EMBEDDING_MODEL, device=device
71
+ )
72
+
73
+
74
+ def build_dataset_collection(client, embedding_function):
75
+ """Build/update the dataset_cards collection."""
76
+ logger.info("=== Building dataset collection ===")
77
+
78
+ collection = client.get_or_create_collection(
79
+ embedding_function=embedding_function,
80
+ name="dataset_cards",
81
+ metadata={"hnsw:space": "cosine"},
82
+ )
83
+
84
+ df = pl.scan_parquet(
85
+ f"hf://datasets/{DATASET_SOURCE}/data/train-*.parquet"
86
+ )
87
+ df = df.filter(
88
+ pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
89
+ )
90
+ df = df.filter(
91
+ pl.col("datasetId")
92
+ .str.contains_any(["gemma-2-2B-it-thinking-function_calling-V0"])
93
+ .not_()
94
+ )
95
+
96
+ # Check for incremental update
97
+ latest_update = None
98
+ if collection.count() > 0:
99
+ metadata = collection.get(include=["metadatas"]).get("metadatas")
100
+ logger.info(f"Found {len(metadata)} existing records in collection")
101
+ last_modifieds = [
102
+ dateutil.parser.parse(m.get("last_modified")) for m in metadata
103
+ ]
104
+ latest_update = max(last_modifieds)
105
+ logger.info(f"Most recent record in DB from: {latest_update}")
106
+
107
+ df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])
108
+
109
+ total_incoming = df.select(pl.len()).collect().item()
110
+ logger.info(f"Total incoming records from source: {total_incoming}")
111
+
112
+ if latest_update:
113
+ logger.info(f"Filtering records newer than {latest_update}")
114
+ df = df.with_columns(pl.col("last_modified").str.to_datetime())
115
+ df = df.filter(pl.col("last_modified") > latest_update)
116
+ filtered_count = df.select(pl.len()).collect().item()
117
+ logger.info(f"Found {filtered_count} records to update after filtering")
118
+
119
+ df = df.collect()
120
+ total_rows = len(df)
121
+
122
+ if total_rows > 0:
123
+ logger.info(f"Updating dataset collection with {total_rows} new records")
124
+ for i in range(0, total_rows, BATCH_SIZE):
125
+ batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
126
+ batch_size = len(batch_df)
127
+
128
+ collection.upsert(
129
+ ids=batch_df.select(["datasetId"]).to_series().to_list(),
130
+ documents=batch_df.select(["summary"]).to_series().to_list(),
131
+ metadatas=[
132
+ {
133
+ "likes": int(likes),
134
+ "downloads": int(downloads),
135
+ "last_modified": str(last_modified),
136
+ }
137
+ for likes, downloads, last_modified in zip(
138
+ batch_df.select(["likes"]).to_series().to_list(),
139
+ batch_df.select(["downloads"]).to_series().to_list(),
140
+ batch_df.select(["last_modified"]).to_series().to_list(),
141
+ )
142
+ ],
143
+ )
144
+ logger.info(f"Processed {i + batch_size:,} / {total_rows:,} dataset records")
145
+ else:
146
+ logger.info("No new dataset records to update")
147
+
148
+ final_count = collection.count()
149
+ logger.info(f"Dataset collection: {final_count:,} total records")
150
+
151
+
152
+ def build_model_collection(client, embedding_function):
153
+ """Build/update the model_cards collection."""
154
+ logger.info("=== Building model collection ===")
155
+
156
+ collection = client.get_or_create_collection(
157
+ embedding_function=embedding_function,
158
+ name="model_cards",
159
+ metadata={"hnsw:space": "cosine"},
160
+ )
161
+
162
+ model_lazy_df = pl.scan_parquet(
163
+ f"hf://datasets/{MODEL_SOURCE}/data/train-*.parquet"
164
+ )
165
+
166
+ # Check for incremental update
167
+ model_latest_update = None
168
+ if collection.count() > 0:
169
+ model_metadata = collection.get(include=["metadatas"]).get("metadatas")
170
+ logger.info(f"Found {len(model_metadata)} existing model records in collection")
171
+ model_last_modifieds = [
172
+ dateutil.parser.parse(m.get("last_modified")) for m in model_metadata
173
+ ]
174
+ model_latest_update = max(model_last_modifieds)
175
+ logger.info(f"Most recent model record in DB from: {model_latest_update}")
176
+
177
+ # Set up columns to select
178
+ schema = model_lazy_df.collect_schema()
179
+ select_columns = ["modelId", "summary", "likes", "downloads", "last_modified"]
180
+ if "param_count" in schema:
181
+ logger.info("Found 'param_count' column in model data schema.")
182
+ select_columns.append("param_count")
183
+ else:
184
+ logger.warning("'param_count' column not found. Will add with null values.")
185
+
186
+ model_df = model_lazy_df.select(select_columns)
187
+ model_row_count = model_df.select(pl.len()).collect().item()
188
+ logger.info(f"Total model records in source: {model_row_count}")
189
+
190
+ if model_latest_update:
191
+ logger.info(f"Filtering model records newer than {model_latest_update}")
192
+ model_df = model_df.with_columns(pl.col("last_modified").str.to_datetime())
193
+ model_df = model_df.filter(pl.col("last_modified") > model_latest_update)
194
+ model_filtered_count = model_df.select(pl.len()).collect().item()
195
+ logger.info(f"Found {model_filtered_count} model records to update")
196
+ else:
197
+ model_filtered_count = model_df.select(pl.len()).collect().item()
198
+ logger.info(f"Initial model load: processing all {model_filtered_count} records")
199
+
200
+ if model_filtered_count > 0:
201
+ model_df = model_df.collect()
202
+
203
+ if "param_count" not in model_df.columns:
204
+ model_df = model_df.with_columns(
205
+ pl.lit(None).cast(pl.Int64).alias("param_count")
206
+ )
207
+
208
+ total_rows = len(model_df)
209
+ logger.info(f"Updating model collection with {total_rows} new records")
210
+
211
+ for i in range(0, total_rows, BATCH_SIZE):
212
+ batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
213
+
214
+ collection.upsert(
215
+ ids=batch_df.select(["modelId"]).to_series().to_list(),
216
+ documents=batch_df.select(["summary"]).to_series().to_list(),
217
+ metadatas=[
218
+ {
219
+ "likes": int(likes),
220
+ "downloads": int(downloads),
221
+ "last_modified": str(last_modified),
222
+ "param_count": int(param_count)
223
+ if param_count is not None
224
+ else 0,
225
+ }
226
+ for likes, downloads, last_modified, param_count in zip(
227
+ batch_df.select(["likes"]).to_series().to_list(),
228
+ batch_df.select(["downloads"]).to_series().to_list(),
229
+ batch_df.select(["last_modified"]).to_series().to_list(),
230
+ batch_df.select(["param_count"]).to_series().to_list(),
231
+ )
232
+ ],
233
+ )
234
+ logger.info(
235
+ f"Processed {i + len(batch_df):,} / {total_rows:,} model records"
236
+ )
237
+ else:
238
+ logger.info("No new model records to update")
239
+
240
+ logger.info(f"Model collection: {collection.count():,} total records")
241
+
242
+
243
+ def main():
244
+ parser = argparse.ArgumentParser(
245
+ description="Build ChromaDB index for datasets-search-v2"
246
+ )
247
+ parser.add_argument(
248
+ "--data-dir",
249
+ default="/data",
250
+ help="Path to write ChromaDB data (default: /data, the bucket mount point)",
251
+ )
252
+ parser.add_argument(
253
+ "--full-rebuild",
254
+ action="store_true",
255
+ help="Delete existing collections and rebuild from scratch",
256
+ )
257
+ args = parser.parse_args()
258
+
259
+ # Login
260
+ HF_TOKEN = os.environ.get("HF_TOKEN")
261
+ if HF_TOKEN:
262
+ login(token=HF_TOKEN)
263
+
264
+ chroma_path = os.path.join(args.data_dir, "chroma")
265
+ logger.info(f"ChromaDB path: {chroma_path}")
266
+ logger.info(f"ChromaDB version: {chromadb.__version__}")
267
+
268
+ client = chromadb.PersistentClient(
269
+ path=chroma_path,
270
+ settings=Settings(anonymized_telemetry=False, is_persistent=True),
271
+ )
272
+
273
+ if args.full_rebuild:
274
+ logger.info("Full rebuild requested — deleting existing collections")
275
+ for name in ["dataset_cards", "model_cards"]:
276
+ try:
277
+ client.delete_collection(name)
278
+ logger.info(f"Deleted collection: {name}")
279
+ except Exception:
280
+ pass
281
+
282
+ device = get_device()
283
+ logger.info(f"Using device: {device}")
284
+ embedding_function = get_embedding_function(device)
285
+
286
+ build_dataset_collection(client, embedding_function)
287
+ build_model_collection(client, embedding_function)
288
+
289
+ logger.info("=== Index build complete ===")
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()
generate_summaries_uv.py CHANGED
@@ -6,13 +6,11 @@
6
  # "huggingface-hub[hf_xet]",
7
  # "polars",
8
  # "stamina",
9
- # "transformers",
10
- # "vllm",
11
  # "tqdm",
12
  # "setuptools",
13
- # "flashinfer-python",
14
  # ]
15
- #
16
  # ///
17
  import argparse
18
  import logging
@@ -54,12 +52,17 @@ logger.info(f"PyTorch version: {torch.__version__}")
54
  logger.info(f"vLLM version: {vllm.__version__}")
55
 
56
 
57
- def format_prompt(content: str, card_type: str, tokenizer) -> str:
58
- """Format content as a prompt for the model."""
 
 
 
 
 
59
  if card_type == "model":
60
- messages = [{"role": "user", "content": f"<MODEL_CARD>{content[:4000]}"}]
61
  else:
62
- messages = [{"role": "user", "content": f"<DATASET_CARD>{content[:4000]}"}]
63
 
64
  return tokenizer.apply_chat_template(
65
  messages, add_generation_prompt=True, tokenize=False
@@ -67,12 +70,21 @@ def format_prompt(content: str, card_type: str, tokenizer) -> str:
67
 
68
 
69
  def load_and_filter_data(
70
- dataset_id: str, card_type: str, min_likes: int = 1, min_downloads: int = 1
 
71
  ) -> pl.DataFrame:
72
- """Load and filter dataset/model data."""
73
- logger.info(f"Loading data from {dataset_id}")
74
- ds = load_dataset(dataset_id, split="train")
75
- df = ds.to_polars().lazy()
 
 
 
 
 
 
 
 
76
 
77
  # Extract content after YAML frontmatter
78
  df = df.with_columns(
@@ -108,6 +120,7 @@ def generate_summaries(
108
  min_likes: int = 1,
109
  min_downloads: int = 1,
110
  hf_token: Optional[str] = None,
 
111
  ):
112
  """Main function to generate summaries."""
113
 
@@ -118,13 +131,19 @@ def generate_summaries(
118
 
119
  # Load and filter data
120
  df_filtered = load_and_filter_data(
121
- input_dataset_id, card_type, min_likes, min_downloads
 
122
  )
123
 
124
- # Download model to local directory first
125
- logger.info(f"Downloading model {model_id} to local directory...")
126
- local_model_path = snapshot_download(repo_id=model_id, resume_download=True)
127
- logger.info(f"Model downloaded to: {local_model_path}")
 
 
 
 
 
128
 
129
  # Initialize model and tokenizer from local path
130
  logger.info(f"Initializing vLLM model from local path: {local_model_path}")
@@ -229,6 +248,11 @@ def main():
229
  parser.add_argument(
230
  "--hf-token", help="Hugging Face token (uses HF_TOKEN env var if not provided)"
231
  )
 
 
 
 
 
232
 
233
  args = parser.parse_args()
234
 
@@ -243,6 +267,7 @@ def main():
243
  min_likes=args.min_likes,
244
  min_downloads=args.min_downloads,
245
  hf_token=args.hf_token,
 
246
  )
247
 
248
 
 
6
  # "huggingface-hub[hf_xet]",
7
  # "polars",
8
  # "stamina",
9
+ # "transformers<4.52",
10
+ # "vllm>=0.8",
11
  # "tqdm",
12
  # "setuptools",
 
13
  # ]
 
14
  # ///
15
  import argparse
16
  import logging
 
52
  logger.info(f"vLLM version: {vllm.__version__}")
53
 
54
 
55
+ def format_prompt(content: str, card_type: str, tokenizer, max_content_chars: int = 3000) -> str:
56
+ """Format content as a prompt for the model.
57
+
58
+ Truncates content to max_content_chars (default 3000) to stay safely
59
+ under the model's max sequence length after tokenization.
60
+ """
61
+ truncated = content[:max_content_chars]
62
  if card_type == "model":
63
+ messages = [{"role": "user", "content": f"<MODEL_CARD>{truncated}"}]
64
  else:
65
+ messages = [{"role": "user", "content": f"<DATASET_CARD>{truncated}"}]
66
 
67
  return tokenizer.apply_chat_template(
68
  messages, add_generation_prompt=True, tokenize=False
 
70
 
71
 
72
  def load_and_filter_data(
73
+ dataset_id: str, card_type: str, min_likes: int = 1, min_downloads: int = 1,
74
+ local_path: Optional[str] = None,
75
  ) -> pl.DataFrame:
76
+ """Load and filter dataset/model data.
77
+
78
+ If local_path is provided (e.g. a mounted volume), reads parquet files
79
+ directly from disk instead of downloading from the Hub.
80
+ """
81
+ if local_path:
82
+ logger.info(f"Loading data from local path: {local_path}")
83
+ df = pl.scan_parquet(os.path.join(local_path, "data", "train-*.parquet"))
84
+ else:
85
+ logger.info(f"Loading data from {dataset_id}")
86
+ ds = load_dataset(dataset_id, split="train")
87
+ df = ds.to_polars().lazy()
88
 
89
  # Extract content after YAML frontmatter
90
  df = df.with_columns(
 
120
  min_likes: int = 1,
121
  min_downloads: int = 1,
122
  hf_token: Optional[str] = None,
123
+ input_path: Optional[str] = None,
124
  ):
125
  """Main function to generate summaries."""
126
 
 
131
 
132
  # Load and filter data
133
  df_filtered = load_and_filter_data(
134
+ input_dataset_id, card_type, min_likes, min_downloads,
135
+ local_path=input_path,
136
  )
137
 
138
+ # Use model_id directly if it's a local path (e.g. mounted volume),
139
+ # otherwise download from the Hub
140
+ if os.path.isdir(model_id):
141
+ local_model_path = model_id
142
+ logger.info(f"Using model from local/mounted path: {local_model_path}")
143
+ else:
144
+ logger.info(f"Downloading model {model_id} to local directory...")
145
+ local_model_path = snapshot_download(repo_id=model_id, resume_download=True)
146
+ logger.info(f"Model downloaded to: {local_model_path}")
147
 
148
  # Initialize model and tokenizer from local path
149
  logger.info(f"Initializing vLLM model from local path: {local_model_path}")
 
248
  parser.add_argument(
249
  "--hf-token", help="Hugging Face token (uses HF_TOKEN env var if not provided)"
250
  )
251
+ parser.add_argument(
252
+ "--input-path",
253
+ help="Local/mounted path to input dataset (skips download). "
254
+ "E.g. /input when using -v hf://datasets/org/dataset:/input",
255
+ )
256
 
257
  args = parser.parse_args()
258
 
 
267
  min_likes=args.min_likes,
268
  min_downloads=args.min_downloads,
269
  hf_token=args.hf_token,
270
+ input_path=args.input_path,
271
  )
272
 
273
 
main.py CHANGED
@@ -3,32 +3,27 @@ import logging
3
  import os
4
  import sys
5
  from contextlib import asynccontextmanager
6
- from datetime import datetime
7
  from typing import List, Optional
8
 
9
  import chromadb
10
- import dateutil.parser
11
  import httpx
12
- import polars as pl
13
  import torch
14
  from cashews import cache
15
  from chromadb.utils import embedding_functions
16
  from fastapi import FastAPI, HTTPException, Query
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from pydantic import BaseModel
19
- from transformers import AutoTokenizer
20
  from dotenv import load_dotenv
21
  from huggingface_hub import login
22
 
23
  load_dotenv(override=True)
24
  HF_TOKEN = os.getenv("HF_TOKEN")
25
  login(token=HF_TOKEN)
 
26
  # Configuration constants
27
- MODEL_NAME = "davanstrien/Smol-Hub-tldr"
28
  EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
29
- BATCH_SIZE = 2000
30
  CACHE_TTL = "24h"
31
- TRENDING_CACHE_TTL = "1h" # 15 minutes cache for trending data
32
 
33
  if torch.cuda.is_available():
34
  DEVICE = "cuda"
@@ -37,34 +32,34 @@ elif torch.backends.mps.is_available():
37
  else:
38
  DEVICE = "cpu"
39
 
40
-
41
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
42
-
43
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" # turn on HF_TRANSFER
44
- # Set up logging
45
  logging.basicConfig(level=logging.INFO)
46
  logger = logging.getLogger(__name__)
47
 
48
- LOCAL = False
49
- if sys.platform == "darwin":
50
- LOCAL = True
51
  DATA_DIR = "data" if LOCAL else "/data"
 
52
  # Configure cache
53
  cache.setup("mem://", size_limit="8gb")
54
 
55
- # Initialize ChromaDB client
56
  client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
57
 
58
 
59
  # Initialize FastAPI app
60
  @asynccontextmanager
61
  async def lifespan(app: FastAPI):
62
- # Setup
63
- setup_database()
 
 
 
 
 
 
64
 
65
  yield
66
 
67
- # Cleanup
68
  await cache.close()
69
 
70
 
@@ -92,282 +87,6 @@ def get_embedding_function():
92
  )
93
 
94
 
95
- def setup_database():
96
- try:
97
- embedding_function = get_embedding_function()
98
- dataset_collection = client.get_or_create_collection(
99
- embedding_function=embedding_function,
100
- name="dataset_cards",
101
- metadata={"hnsw:space": "cosine"},
102
- )
103
- model_collection = client.get_or_create_collection(
104
- embedding_function=embedding_function,
105
- name="model_cards",
106
- metadata={"hnsw:space": "cosine"},
107
- )
108
-
109
- # Load dataset data
110
- df = pl.scan_parquet(
111
- "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
112
- )
113
- df = df.filter(
114
- pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
115
- )
116
- df = df.filter(
117
- pl.col("datasetId")
118
- .str.contains_any(
119
- ["gemma-2-2B-it-thinking-function_calling-V0"]
120
- ) # course model that's not useful for retrieving
121
- .not_()
122
- )
123
- # Get the most recent last_modified date from the collection
124
- latest_update = None
125
- if dataset_collection.count() > 0:
126
- metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
127
- logger.info(f"Found {len(metadata)} existing records in collection")
128
-
129
- last_modifieds = [
130
- dateutil.parser.parse(m.get("last_modified")) for m in metadata
131
- ]
132
- latest_update = max(last_modifieds)
133
- logger.info(f"Most recent record in DB from: {latest_update}")
134
- logger.info(f"Oldest record in DB from: {min(last_modifieds)}")
135
-
136
- # Log sample of existing timestamps for debugging
137
- sample_timestamps = sorted(last_modifieds, reverse=True)[:5]
138
- logger.info(f"Sample of most recent DB timestamps: {sample_timestamps}")
139
-
140
- # Filter and process only newer records
141
- df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])
142
-
143
- # Log some stats about the incoming data BEFORE collecting
144
- total_incoming = df.select(pl.len()).collect().item()
145
- logger.info(f"Total incoming records from source: {total_incoming}")
146
-
147
- # Get sample of dates to understand the data
148
- sample_df = (
149
- df.select(["datasetId", "last_modified"])
150
- .sort("last_modified", descending=True)
151
- .limit(5)
152
- .collect()
153
- )
154
- logger.info(f"Sample of most recent incoming records: {sample_df.rows()[:3]}")
155
-
156
- if latest_update:
157
- logger.info(f"Filtering records newer than {latest_update}")
158
- logger.info(f"Latest update type: {type(latest_update)}")
159
-
160
- # Get date range before filtering
161
- date_stats = df.select(
162
- [
163
- pl.col("last_modified").min().alias("min_date"),
164
- pl.col("last_modified").max().alias("max_date"),
165
- ]
166
- ).collect()
167
- logger.info(f"Incoming data date range: {date_stats.row(0)}")
168
-
169
- # Ensure last_modified is datetime before comparison
170
- df = df.with_columns(pl.col("last_modified").str.to_datetime())
171
- df = df.filter(pl.col("last_modified") > latest_update)
172
- filtered_count = df.select(pl.len()).collect().item()
173
- logger.info(f"Found {filtered_count} records to update after filtering")
174
-
175
- if filtered_count == 0:
176
- logger.warning(
177
- "No new records found after filtering! This might indicate a problem."
178
- )
179
- # Log a few records that were just below the cutoff
180
- just_before = (
181
- df.select(["datasetId", "last_modified"])
182
- .filter(pl.col("last_modified") <= latest_update)
183
- .sort("last_modified", descending=True)
184
- .limit(3)
185
- .collect()
186
- )
187
- if len(just_before) > 0:
188
- logger.info(f"Records just before cutoff: {just_before.rows()}")
189
-
190
- df = df.collect()
191
- total_rows = len(df)
192
-
193
- if total_rows > 0:
194
- logger.info(f"Updating dataset collection with {total_rows} new records")
195
- logger.info(
196
- f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
197
- )
198
-
199
- for i in range(0, total_rows, BATCH_SIZE):
200
- batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
201
- batch_size = len(batch_df)
202
- logger.info(
203
- f"Processing batch {i // BATCH_SIZE + 1}: {batch_size} records "
204
- f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
205
- )
206
-
207
- ids_to_upsert = batch_df.select(["datasetId"]).to_series().to_list()
208
-
209
- # Log progress for every batch
210
- if i == 0 or (i // BATCH_SIZE + 1) % 5 == 0: # Log every 5th batch
211
- logger.info(f"Upserting batch {i // BATCH_SIZE + 1} (sample IDs: {ids_to_upsert[:3]})")
212
-
213
- # Check if any of these already exist (sample only)
214
- if i == 0: # Only log for first batch to reduce noise
215
- existing_check = dataset_collection.get(
216
- ids=ids_to_upsert[:3], include=["metadatas"]
217
- )
218
- if existing_check["ids"]:
219
- logger.info(
220
- f"Sample: {len(existing_check['ids'])} existing records being updated"
221
- )
222
-
223
- dataset_collection.upsert(
224
- ids=ids_to_upsert,
225
- documents=batch_df.select(["summary"]).to_series().to_list(),
226
- metadatas=[
227
- {
228
- "likes": int(likes),
229
- "downloads": int(downloads),
230
- "last_modified": str(last_modified),
231
- }
232
- for likes, downloads, last_modified in zip(
233
- batch_df.select(["likes"]).to_series().to_list(),
234
- batch_df.select(["downloads"]).to_series().to_list(),
235
- batch_df.select(["last_modified"]).to_series().to_list(),
236
- )
237
- ],
238
- )
239
- logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")
240
-
241
- # Final validation
242
- final_count = dataset_collection.count()
243
- logger.info(f"Database initialized with {final_count:,} total rows")
244
-
245
- # Verify the update worked by checking latest records
246
- if final_count > 0:
247
- # Get ALL metadata to find the true latest timestamp (not just 5 records)
248
- final_metadata = dataset_collection.get(include=["metadatas"])
249
- final_timestamps = [
250
- dateutil.parser.parse(m.get("last_modified"))
251
- for m in final_metadata.get("metadatas")
252
- ]
253
- if final_timestamps:
254
- latest_after_update = max(final_timestamps)
255
- logger.info(f"Latest record after update: {latest_after_update}")
256
- if latest_update and latest_after_update <= latest_update:
257
- logger.error(
258
- "WARNING: No new records were added! Latest timestamp hasn't changed."
259
- )
260
- elif latest_update:
261
- logger.info(
262
- f"Successfully added records from {latest_update} to {latest_after_update}"
263
- )
264
- else:
265
- logger.info(f"Initial database setup completed. Latest record: {latest_after_update}")
266
-
267
- # Load model data
268
- model_lazy_df = pl.scan_parquet(
269
- "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
270
- )
271
- model_row_count = model_lazy_df.select(pl.len()).collect().item()
272
- logger.info(f"Total model records in source: {model_row_count}")
273
-
274
- # Get the most recent last_modified date from the model collection
275
- model_latest_update = None
276
- if model_collection.count() > 0:
277
- model_metadata = model_collection.get(include=["metadatas"]).get(
278
- "metadatas"
279
- )
280
- logger.info(
281
- f"Found {len(model_metadata)} existing model records in collection"
282
- )
283
-
284
- model_last_modifieds = [
285
- dateutil.parser.parse(m.get("last_modified")) for m in model_metadata
286
- ]
287
- model_latest_update = max(model_last_modifieds)
288
- logger.info(f"Most recent model record in DB from: {model_latest_update}")
289
-
290
- # Set up model schema columns
291
- schema = model_lazy_df.collect_schema()
292
- select_columns = [
293
- "modelId",
294
- "summary",
295
- "likes",
296
- "downloads",
297
- "last_modified",
298
- ]
299
- if "param_count" in schema:
300
- logger.info("Found 'param_count' column in model data schema.")
301
- select_columns.append("param_count")
302
- else:
303
- logger.warning(
304
- "'param_count' column not found in model data schema. Will add it with null values."
305
- )
306
-
307
- # Filter and process only newer model records
308
- model_df = model_lazy_df.select(select_columns)
309
-
310
- # Apply timestamp filtering like we do for datasets
311
- if model_latest_update:
312
- logger.info(f"Filtering model records newer than {model_latest_update}")
313
- model_df = model_df.with_columns(pl.col("last_modified").str.to_datetime())
314
- model_df = model_df.filter(pl.col("last_modified") > model_latest_update)
315
- model_filtered_count = model_df.select(pl.len()).collect().item()
316
- logger.info(f"Found {model_filtered_count} model records to update after filtering")
317
- else:
318
- model_filtered_count = model_df.select(pl.len()).collect().item()
319
- logger.info(f"Initial model load: processing all {model_filtered_count} model records")
320
-
321
- if model_filtered_count > 0:
322
- model_df = model_df.collect()
323
-
324
- # If param_count was not in the original schema, add it now to the collected DataFrame
325
- if "param_count" not in model_df.columns:
326
- model_df = model_df.with_columns(
327
- pl.lit(None).cast(pl.Int64).alias("param_count")
328
- )
329
-
330
- total_rows = len(model_df)
331
- logger.info(f"Updating model collection with {total_rows} new records")
332
-
333
- for i in range(0, total_rows, BATCH_SIZE):
334
- batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
335
-
336
- model_collection.upsert(
337
- ids=batch_df.select(["modelId"]).to_series().to_list(),
338
- documents=batch_df.select(["summary"]).to_series().to_list(),
339
- metadatas=[
340
- {
341
- "likes": int(likes),
342
- "downloads": int(downloads),
343
- "last_modified": str(last_modified),
344
- "param_count": int(param_count)
345
- if param_count is not None
346
- else 0,
347
- }
348
- for likes, downloads, last_modified, param_count in zip(
349
- batch_df.select(["likes"]).to_series().to_list(),
350
- batch_df.select(["downloads"]).to_series().to_list(),
351
- batch_df.select(["last_modified"]).to_series().to_list(),
352
- batch_df.select(["param_count"]).to_series().to_list(),
353
- )
354
- ],
355
- )
356
- logger.info(
357
- f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
358
- )
359
-
360
- logger.info(
361
- f"Model database initialized with {model_collection.count():,} rows"
362
- )
363
-
364
- except Exception as e:
365
- logger.error(f"Setup error: {e}")
366
-
367
-
368
- # Setup database is called in lifespan, not here
369
-
370
-
371
  class QueryResult(BaseModel):
372
  dataset_id: str
373
  similarity: float
 
3
  import os
4
  import sys
5
  from contextlib import asynccontextmanager
 
6
  from typing import List, Optional
7
 
8
  import chromadb
 
9
  import httpx
 
10
  import torch
11
  from cashews import cache
12
  from chromadb.utils import embedding_functions
13
  from fastapi import FastAPI, HTTPException, Query
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
 
16
  from dotenv import load_dotenv
17
  from huggingface_hub import login
18
 
19
  load_dotenv(override=True)
20
  HF_TOKEN = os.getenv("HF_TOKEN")
21
  login(token=HF_TOKEN)
22
+
23
  # Configuration constants
 
24
  EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
 
25
  CACHE_TTL = "24h"
26
+ TRENDING_CACHE_TTL = "1h"
27
 
28
  if torch.cuda.is_available():
29
  DEVICE = "cuda"
 
32
  else:
33
  DEVICE = "cpu"
34
 
35
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 
 
 
36
  logging.basicConfig(level=logging.INFO)
37
  logger = logging.getLogger(__name__)
38
 
39
+ LOCAL = sys.platform == "darwin"
 
 
40
  DATA_DIR = "data" if LOCAL else "/data"
41
+
42
  # Configure cache
43
  cache.setup("mem://", size_limit="8gb")
44
 
45
+ # Initialize ChromaDB client (index is pre-built by build_chroma_index.py Job)
46
  client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
47
 
48
 
49
  # Initialize FastAPI app
50
  @asynccontextmanager
51
  async def lifespan(app: FastAPI):
52
+ # Index is pre-built by build_chroma_index.py Job — no setup needed
53
+ logger.info(f"ChromaDB path: {DATA_DIR}/chroma")
54
+ try:
55
+ dc = client.get_collection("dataset_cards")
56
+ mc = client.get_collection("model_cards")
57
+ logger.info(f"dataset_cards: {dc.count():,} records, model_cards: {mc.count():,} records")
58
+ except Exception as e:
59
+ logger.error(f"Failed to read collections — is the bucket mounted at {DATA_DIR}? {e}")
60
 
61
  yield
62
 
 
63
  await cache.close()
64
 
65
 
 
87
  )
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  class QueryResult(BaseModel):
91
  dataset_id: str
92
  similarity: float