# app/main.py
import os
from typing import List, Tuple, Optional

import requests
from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from ragbench_eval.pipeline import RagBenchExperiment
from ragbench_eval.retriever import ExampleRetriever
from ragbench_eval.generator import RAGGenerator
from ragbench_eval.judge import RAGJudge
from ragbench_eval.metrics import trace_from_attributes
from ragbench_eval.config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS

# ---------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------
app = FastAPI(title="RAGBench Chat + RAG Evaluation API")

# ---------------------------------------------------------------------
# Config for Hugging Face router (LLM chat)
# ---------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
HF_CHAT_BASE_URL = os.getenv("HF_CHAT_BASE_URL", "https://router.huggingface.co/v1")
HF_CHAT_MODEL = os.getenv(
    "HF_CHAT_MODEL",
    "meta-llama/Meta-Llama-3-8B-Instruct",
)

if HF_TOKEN is None or HF_TOKEN.strip() == "":
    # We don't crash the app, but /chat will raise a clear error.
    print("[RAGBench Chat] WARNING: HF_TOKEN is not set. /chat will fail.")

# ---------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------
class RunRequest(BaseModel):
    domain: str  # "biomedical", "general_knowledge", ...
    k: int = 3
    max_examples: Optional[int] = 20
    split: str = "test"  # "test" or "validation"


class QAExampleRequest(BaseModel):
    subset: str  # e.g. "covidqa", "pubmedqa"
    index: int = 0  # which row from that subset
    k: int = 3
    split: str = "test"


class ChatRequest(BaseModel):
    domain: str  # must be one of DOMAIN_TO_SUBSETS keys
    question: str


# ---------------------------------------------------------------------
# LLM Chat endpoint (using HF router OpenAI-compatible API)
# ---------------------------------------------------------------------
@app.post("/chat")
def chat(req: ChatRequest):
    """
    Simple domain-aware chat endpoint.

    Uses the Hugging Face router's OpenAI-compatible Chat Completions API:
    POST {HF_CHAT_BASE_URL}/chat/completions
    """
    if not HF_TOKEN:
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN environment variable is not set in the backend.",
        )

    if req.domain not in DOMAIN_TO_SUBSETS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown domain '{req.domain}'. "
            f"Valid domains: {', '.join(DOMAIN_TO_SUBSETS.keys())}",
        )

    system_prompt = (
        "You are an assistant answering questions in the domain: "
        f"{req.domain}. "
        "Answer using correct, verifiable information. "
        "If you are not sure, clearly say that you are not sure instead of "
        "guessing. Be concise and avoid fabricating facts."
    )
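    # The payload below follows the OpenAI Chat Completions request schema,
    # which the HF router accepts (see the docstring above); the low
    # temperature keeps answers factual and fairly deterministic.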

    payload = {
        "model": HF_CHAT_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": req.question},
        ],
        "temperature": 0.2,
        "max_tokens": 512,
    }

    try:
        resp = requests.post(
            f"{HF_CHAT_BASE_URL}/chat/completions",
            headers={
                "Authorization": f"Bearer {HF_TOKEN}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Surface a clear error to the frontend (shows up as "Error: HTTP 500").
        raise HTTPException(
            status_code=500,
            detail=f"Hugging Face router request failed: {e}",
        )

    data = resp.json()
    try:
        answer = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected response format from HF router: {e}",
        )

    return {"answer": answer}


# ---------------------------------------------------------------------
# RAGBench evaluation endpoints
# ---------------------------------------------------------------------
@app.post("/run_domain")
def run_domain(req: RunRequest):
    """
    Run a full RAGBench evaluation over a domain (biomedical, finance, etc.).
    """
    exp = RagBenchExperiment(
        k=req.k,
        max_examples=req.max_examples,
        split=req.split,
    )
    result = exp.run_domain(req.domain)
    return result


@app.post("/qa_example")
def qa_example(req: QAExampleRequest):
    """
    Run RAG (retriever + generator + judge) on a single RAGBench example.
    """
    ds = load_dataset(RAGBENCH_DATASET, req.subset, split=req.split)
    if req.index < 0 or req.index >= len(ds):
        raise HTTPException(
            status_code=404,
            detail=f"index {req.index} out of range (0..{len(ds) - 1})",
        )

    row = ds[req.index]

    # Build full per-document sentence lists.
    docs_sentences_full: List[List[Tuple[str, str]]] = []
    for doc in row["documents_sentences"]:
        docs_sentences_full.append([(key, s) for key, s in doc])

    question = row["question"]

    # 1) Retrieve top-k docs
    retriever = ExampleRetriever()
    doc_indices = retriever.rank_docs(question, docs_sentences_full, k=req.k)
    selected_docs = [docs_sentences_full[j] for j in doc_indices]

    # 2) Generate answer from retrieved docs
    generator = RAGGenerator()
    answer = generator.generate(question, selected_docs)

    # 3) Judge + metrics
    judge = RAGJudge()
    attrs = judge.annotate(question, answer, selected_docs)
    pred_metrics = trace_from_attributes(attrs, selected_docs)

    docs_view = []
    for i, doc in enumerate(selected_docs):
        docs_view.append(
            {
                "doc_index": doc_indices[i],
                "sentences": [{"key": key, "text": s} for key, s in doc],
            }
        )

    return {
        "subset": req.subset,
        "index": req.index,
        "question": question,
        "answer": answer,
        "retrieved_docs": docs_view,
        "judge_attributes": attrs,
        "predicted_trace_metrics": pred_metrics,
        "ground_truth": {
            "relevance_score": row.get("relevance_score"),
            "utilization_score": row.get("utilization_score"),
            "completeness_score": row.get("completeness_score"),
            "adherence_score": row.get("adherence_score"),
        },
    }


@app.get("/health")
def health():
    return {"status": "ok"}
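
# Example requests (a sketch, assuming the API is served locally on port 8000,
# e.g. via `uvicorn app.main:app`; the request bodies are illustrative):
#
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"domain": "biomedical", "question": "What causes anemia?"}'
#
#   curl -X POST http://localhost:8000/run_domain \
#     -H "Content-Type: application/json" \
#     -d '{"domain": "biomedical", "k": 3, "max_examples": 5, "split": "test"}'
#
#   curl -X POST http://localhost:8000/qa_example \
#     -H "Content-Type: application/json" \
#     -d '{"subset": "covidqa", "index": 0, "k": 3, "split": "test"}'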

# ---------------------------------------------------------------------
# HTML Chat UI at root "/"
# ---------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    html = """
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>RAGBench Chat</title>
</head>
<body>
  <h1>RAGBench Chat</h1>
  <p>Select a domain, then start chatting.</p>
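
  <!-- Minimal chat widget sketch: the original page's markup and script were
       not recoverable, so this version only assumes the /chat contract visible
       in the backend above (POST {domain, question} -> {answer}). -->
  <select id="domain">
    <option value="biomedical">biomedical</option>
    <option value="general_knowledge">general_knowledge</option>
  </select>
  <input id="question" size="60" placeholder="Ask a question...">
  <button onclick="send()">Send</button>
  <div id="log"></div>
  <script>
    async function send() {
      const domain = document.getElementById("domain").value;
      const question = document.getElementById("question").value;
      const resp = await fetch("/chat", {
        method: "POST",
        headers: {"Content-Type": "application/json"},
        body: JSON.stringify({domain: domain, question: question}),
      });
      const data = await resp.json();
      // On error, FastAPI returns {"detail": ...} instead of {"answer": ...}.
      document.getElementById("log").textContent =
        resp.ok ? data.answer : "Error: " + JSON.stringify(data.detail);
    }
  </script>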
""" return HTMLResponse(content=html) # --------------------------------------------------------------------- # HTML RAGBench Evaluation UI at "/eval" # --------------------------------------------------------------------- @app.get("/eval", response_class=HTMLResponse) def eval_ui(): """Simple page to run /run_domain evaluations from the browser.""" html = """ RAGBench RAG Evaluation

  <h1>RAGBench RAG Evaluation</h1>

  <h2>Run Domain Evaluation</h2>

  <h2>Results</h2>
  <pre id="results">{}</pre>
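
  <!-- Minimal evaluation form sketch (placed below the results pane; the
       original controls and script were not recoverable). It assumes only the
       /run_domain contract defined in the backend above:
       POST {domain, k, max_examples, split} -> metrics JSON. -->
  <select id="domain">
    <option value="biomedical">biomedical</option>
    <option value="general_knowledge">general_knowledge</option>
  </select>
  k: <input id="k" type="number" value="3">
  max_examples: <input id="max_examples" type="number" value="20">
  <button onclick="runEval()">Run</button>
  <script>
    async function runEval() {
      const payload = {
        domain: document.getElementById("domain").value,
        k: parseInt(document.getElementById("k").value, 10),
        max_examples: parseInt(document.getElementById("max_examples").value, 10),
        split: "test",
      };
      const resp = await fetch("/run_domain", {
        method: "POST",
        headers: {"Content-Type": "application/json"},
        body: JSON.stringify(payload),
      });
      // Pretty-print the metrics JSON into the results pane.
      document.getElementById("results").textContent =
        JSON.stringify(await resp.json(), null, 2);
    }
  </script>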
""" return HTMLResponse(content=html)