AliHashir committed on
Commit
947ed8c
·
1 Parent(s): 2d6a8ea

feat: implement evidence selection logic and add debug endpoint

Browse files
Files changed (4) hide show
  1. app/logic/selector.py +62 -1
  2. app/main.py +10 -0
  3. app/nlp/embed.py +26 -1
  4. requirements.txt +3 -1
app/logic/selector.py CHANGED
@@ -1 +1,62 @@
1
- """Evidence selection and ranking logic."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/logic/selector.py
2
+ from __future__ import annotations
3
+ import asyncio
4
+ from typing import List
5
+ import numpy as np
6
+
7
+ from app.schemas import Source
8
+ from app.fetch.fetcher import get_paragraphs_with_fallback
9
+ from app.nlp.embed import embed_text, embed_texts
10
+
11
+ SIM_THRESHOLD = 0.25 # drop very weak matches
12
+
13
async def select_evidence(
    claim: str,
    sources: List[Source],
    per_source: int = 2,
    max_total: int = 8,
) -> List[Source]:
    """Attach the most claim-relevant paragraphs to each source as evidence.

    For every source, fetch its paragraphs, rank them by cosine similarity
    against the embedded claim, keep up to ``per_source`` paragraphs above
    ``SIM_THRESHOLD``, and cap the total evidence count across all sources
    at ``max_total`` (trimmed round-robin, weakest-first per source).

    Args:
        claim: The claim text to rank evidence against.
        sources: Candidate sources (title/url/snippet) from search.
        per_source: Max evidence paragraphs kept per source.
        max_total: Max evidence paragraphs kept across all sources.

    Returns:
        A new list of Source objects, each carrying an ``evidence`` list
        (possibly empty). Order of ``sources`` is preserved.
    """
    claim_vec = embed_text(claim)

    # Fetch paragraphs for all sources concurrently. A single failing fetch
    # must not abort selection for every other source, so collect exceptions
    # instead of letting gather() re-raise the first one.
    tasks = [get_paragraphs_with_fallback(s.url, s.snippet) for s in sources]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    selected_sources: list[Source] = []
    for s, paras in zip(sources, results):
        if isinstance(paras, BaseException) or not paras:
            # No usable paragraphs: keep the source, but with an explicit
            # empty evidence list so the trimming pass below can rely on
            # ``s.evidence`` existing on every selected source.
            selected_sources.append(
                Source(title=s.title, url=s.url, snippet=s.snippet, evidence=[])
            )
            continue

        para_vecs = embed_texts(paras)
        # Embeddings are unit-normalized, so the dot product IS the cosine.
        sims = para_vecs @ claim_vec
        top_idx = np.argsort(-sims)[:per_source]  # best-first

        evidence: list[str] = []
        for i in top_idx:
            score = float(sims[i])
            if score < SIM_THRESHOLD:
                # argsort is descending, but later indices can still be
                # below threshold; skip rather than break to keep it simple.
                continue
            text = paras[i].strip()
            if len(text) > 500:
                text = text[:497] + "..."
            evidence.append(text)

        selected_sources.append(
            Source(title=s.title, url=s.url, snippet=s.snippet, evidence=evidence)
        )

    # Cap total evidence across all sources.
    def total_evidence() -> int:
        return sum(len(s.evidence) for s in selected_sources)

    # Trim round-robin: each pass removes one (weakest, since lists are
    # best-first) paragraph from each source that still has any, stopping
    # as soon as the cap is met. Terminates because every pass through the
    # while body removes at least one item while the cap is exceeded.
    while total_evidence() > max_total:
        for s in selected_sources:
            if s.evidence:
                s.evidence.pop()
            if total_evidence() <= max_total:
                break

    return selected_sources
app/main.py CHANGED
@@ -36,6 +36,16 @@ async def _fetch(u: str = Query(..., min_length=10, max_length=2000)):
36
  return {"count": len(paras), "samples": paras[:3]}
37
 
38
 
 
 
 
 
 
 
 
 
 
 
39
  # Root endpoint
40
  @app.get("/")
41
  async def root():
 
36
  return {"count": len(paras), "samples": paras[:3]}
37
 
38
 
39
@app.get("/_select")
async def _select(claim: str = Query(..., min_length=8, max_length=300)):
    """Debug endpoint: search for *claim*, then run evidence selection on the hits."""
    # Imported lazily so the embedding model is only pulled in when this
    # debug route is actually exercised.
    from app.logic.selector import select_evidence

    search = get_search()
    hits = await search(claim)
    picked = await select_evidence(claim, hits, per_source=2, max_total=8)
    payload = [s.model_dump() for s in picked]
    return {"n_sources": len(picked), "items": payload}
48
+
49
  # Root endpoint
50
  @app.get("/")
51
  async def root():
app/nlp/embed.py CHANGED
@@ -1 +1,26 @@
1
- """Text embedding utilities using sentence-transformers."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/nlp/embed.py
2
+ from __future__ import annotations
3
+ from functools import lru_cache
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
8
+
9
@lru_cache(maxsize=1)
def _load_model() -> SentenceTransformer:
    """Load and memoize the sentence-transformer model.

    Cached so the (expensive) model load happens once per process.
    CPU inference is sufficient for this small model.
    """
    return SentenceTransformer(MODEL_NAME)
13
+
14
def embed_texts(texts: list[str]) -> np.ndarray:
    """Embed *texts* into unit-normalized float32 vectors.

    Args:
        texts: Paragraph/sentence strings to embed.

    Returns:
        A float32 array of shape ``(len(texts), dim)``; rows are
        L2-normalized, so dot products are cosine similarities. An empty
        input returns a well-shaped ``(0, dim)`` array instead of whatever
        ``model.encode([])`` happens to produce.
    """
    model = _load_model()
    if not texts:
        # Short-circuit the empty case so callers always get a 2-D array.
        dim = model.get_sentence_embedding_dimension() or 0
        return np.zeros((0, dim), dtype="float32")
    vecs = model.encode(
        texts,
        batch_size=32,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return vecs.astype("float32")
24
+
25
def embed_text(text: str) -> np.ndarray:
    """Embed a single string; returns a 1-D unit-normalized float32 vector."""
    batch = embed_texts([text])
    return batch[0]
requirements.txt CHANGED
@@ -11,7 +11,7 @@ lxml==4.9.3
11
 
12
  # ML and NLP
13
  transformers==4.35.2
14
- sentence-transformers==2.2.2
15
  torch==2.1.1
16
  scikit-learn==1.3.2
17
 
@@ -20,3 +20,5 @@ python-dotenv==1.0.0
20
 
21
  # Optional web interface
22
  jinja2==3.1.2
 
 
 
11
 
12
  # ML and NLP
13
  transformers==4.35.2
14
+ sentence-transformers==2.7.0
15
  torch==2.1.1
16
  scikit-learn==1.3.2
17
 
 
20
 
21
  # Optional web interface
22
  jinja2==3.1.2
23
+
24
+ numpy==1.24.4