from __future__ import annotations import json from pathlib import Path from typing import Any, Dict, List, Tuple, Optional import faiss import numpy as np from datasets import Dataset, load_dataset from sentence_transformers import SentenceTransformer from .config import RAGBENCH_DATASET, EMBEDDING_MODEL class SubsetVectorDB: """ Simple FAISS-based vector database for a single RAGBench subset + split. This class is intentionally lightweight and file-based: - Each (subset, split) pair gets its own folder under ``vector_store/``. - We build a single FAISS index over all documents' concatenated text. - We also persist a small ``meta.json`` mapping index -> (row_index, doc_index). At evaluation time we can: - Lazily build the index once (or load it if it already exists). - Retrieve the top-k most similar documents for a given question. - Optionally restrict results to a particular example row. """ def __init__( self, subset: str, split: str = "test", root_dir: Optional[Path] = None, ) -> None: self.subset = subset self.split = split project_root = Path(__file__).resolve().parents[1] self.root_dir = (root_dir or project_root / "vector_store").resolve() self.index_dir = self.root_dir / subset / split self.index_dir.mkdir(parents=True, exist_ok=True) self.index_path = self.index_dir / "index.faiss" self.meta_path = self.index_dir / "meta.json" # Will be populated by ``build_or_load`` self.embedder: Optional[SentenceTransformer] = None self.index: Optional[faiss.Index] = None self.meta: List[Dict[str, Any]] = [] # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _load_embedder(self) -> SentenceTransformer: if self.embedder is None: self.embedder = SentenceTransformer(EMBEDDING_MODEL) return self.embedder def _load_index_files(self) -> bool: """ Try to load index + meta files from disk. Returns True if successful, False if anything is missing. """ if not self.index_path.exists() or not self.meta_path.exists(): return False self.index = faiss.read_index(str(self.index_path)) with self.meta_path.open("r", encoding="utf-8") as f: self.meta = json.load(f) return True # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def build_or_load(self, ds: Optional[Dataset] = None) -> None: """ Ensure the FAISS index exists for (subset, split). If the index files are already on disk we simply load them. Otherwise we: - iterate over the dataset - concatenate each document's sentences into a single string - build a dense embedding using SentenceTransformers - create a cosine-similarity FAISS index and persist it """ if self._load_index_files(): return if ds is None: ds = load_dataset(RAGBENCH_DATASET, self.subset, split=self.split) texts: List[str] = [] meta: List[Dict[str, Any]] = [] for row_idx, row in enumerate(ds): # ``documents_sentences`` is a list of docs; # each doc is a list of (sentence_key, sentence_text) pairs. for doc_idx, doc in enumerate(row["documents_sentences"]): doc_text = " ".join(sentence_text for _, sentence_text in doc) texts.append(doc_text) meta.append({"row_index": int(row_idx), "doc_index": int(doc_idx)}) if not texts: raise ValueError( f"No documents found while building vector DB for subset={self.subset}, split={self.split}" ) embedder = self._load_embedder() embeddings = embedder.encode( texts, batch_size=32, show_progress_bar=True, convert_to_numpy=True, ) # FAISS expects float32 embeddings = np.asarray(embeddings, dtype="float32") # Use cosine similarity via inner product on L2-normalized vectors faiss.normalize_L2(embeddings) dim = embeddings.shape[1] index = faiss.IndexFlatIP(dim) index.add(embeddings) # Persist to disk so subsequent runs are cheap faiss.write_index(index, str(self.index_path)) with self.meta_path.open("w", encoding="utf-8") as f: json.dump(meta, f, indent=2) self.index = index self.meta = meta def search( self, query: str, k: int = 10, restrict_row_index: Optional[int] = None, ) -> List[Tuple[int, int, float]]: """ Search the vector DB for the top-k documents relevant to ``query``. Returns a list of (row_index, doc_index, score) tuples. If ``restrict_row_index`` is provided, we will over-sample and then filter to only documents that belong to that example row. """ if self.index is None or not self.meta: if not self._load_index_files(): raise RuntimeError( "Vector DB has not been built yet. Call build_or_load() first." ) embedder = self._load_embedder() q_emb = embedder.encode([query], convert_to_numpy=True) q_emb = np.asarray(q_emb, dtype="float32") faiss.normalize_L2(q_emb) # For restricted searches we over-sample so that filtering still leaves # enough candidates. For unrestricted we just use k. search_k = k * 10 if restrict_row_index is not None else k search_k = max(search_k, k) scores, indices = self.index.search(q_emb, search_k) scores = scores[0] indices = indices[0] results: List[Tuple[int, int, float]] = [] for idx, score in zip(indices, scores): if idx < 0 or idx >= len(self.meta): continue meta = self.meta[int(idx)] row_index = meta["row_index"] doc_index = meta["doc_index"] if restrict_row_index is not None and row_index != restrict_row_index: continue results.append((row_index, doc_index, float(score))) if len(results) >= k: break return results