mwitiderrick committed 2304b58 (verified)
Parent: e2c5c85

Upload 4 files

Files changed (4)
  1. app.py +30 -0
  2. rag_dspy.py +74 -0
  3. readme.md +43 -0
  4. requirements.txt +8 -3
app.py ADDED
@@ -0,0 +1,30 @@
+ import streamlit as st
+ from rag_dspy import MedicalRAG
+
+ st.set_page_config(page_title="Medical QA Bot", page_icon="🩺")
+ st.title("🩺 Medical QA Bot")
+ st.write("Ask a medical question and get an answer based on retrieved medical literature.")
+
+ if "history" not in st.session_state:
+     st.session_state["history"] = []
+
+ rag_chain = MedicalRAG()
+
+ with st.form("chat_form"):
+     user_question = st.text_input("Enter your medical question:", "")
+     submitted = st.form_submit_button("Get Answer")
+
+ if submitted and user_question.strip():
+     with st.spinner("Retrieving answer..."):
+         result = rag_chain.forward(user_question)
+         answer = result.final_answer
+     st.session_state["history"].append((user_question, answer))
+     st.markdown(f"**Answer:** {answer}")
+
+ if st.session_state["history"]:
+     st.markdown("---")
+     st.markdown("### Conversation History")
+     for q, a in reversed(st.session_state["history"]):
+         st.markdown(f"**Q:** {q}")
+         st.markdown(f"**A:** {a}")
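The app expects `rag_dspy.py` (below) to expose a `MedicalRAG` module whose `forward` call returns a prediction with a `final_answer` field. A quick way to smoke-test the pipeline without the Streamlit UI, assuming the environment variables used in `rag_dspy.py` are already set, is a small script like this:

```python
# Minimal local smoke test of the RAG pipeline, bypassing the Streamlit UI.
# Assumes OPENAI_API_KEY, QDRANT_CLOUD_URL, and QDRANT_API_KEY are available
# (e.g. via a .env file picked up by load_dotenv() in rag_dspy.py).
from rag_dspy import MedicalRAG

rag = MedicalRAG()
result = rag.forward("What are the most common symptoms of lupus?")
print(result.final_answer)
```

The Streamlit app itself can be started locally with `streamlit run app.py`.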
rag_dspy.py ADDED
@@ -0,0 +1,74 @@
+ # rag_dspy.py
+
+ import dspy
+ from dspy_qdrant import QdrantRM
+ from qdrant_client import QdrantClient, models
+ from dotenv import load_dotenv
+ import os
+
+ load_dotenv()
+
+ # DSPy setup: OpenAI LM for generation, Qdrant as the retrieval model
+ lm = dspy.LM("gpt-4", max_tokens=512, api_key=os.environ.get("OPENAI_API_KEY"))
+ client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
+ collection_name = "medical_chat_bot"
+ rm = QdrantRM(
+     qdrant_collection_name=collection_name,
+     qdrant_client=client,
+     vector_name="dense",            # <-- matches the vector field used at upsert time
+     document_field="passage_text",  # <-- matches the payload field used at upsert time
+     k=20,
+ )
+
+ dspy.settings.configure(lm=lm, rm=rm)
+
+ # Manual reranker using Qdrant's native prefetch + ColBERT multivector query
+ def rerank_with_colbert(query_text):
+     from fastembed import TextEmbedding, LateInteractionTextEmbedding
+
+     # Encode the query once with both models
+     dense_model = TextEmbedding("BAAI/bge-small-en")
+     colbert_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
+
+     dense_query = list(dense_model.embed(query_text))[0]
+     colbert_query = list(colbert_model.embed(query_text))[0]
+
+     # Combined query: retrieve candidates with the dense vector, rerank with ColBERT
+     results = client.query_points(
+         collection_name=collection_name,
+         prefetch=models.Prefetch(
+             query=dense_query,
+             using="dense"
+         ),
+         query=colbert_query,
+         using="colbert",
+         limit=5,
+         with_payload=True
+     )
+
+     points = results.points
+     docs = []
+
+     for point in points:
+         docs.append(point.payload["passage_text"])
+
+     return docs
+
+ # DSPy Signature and Module
+ class MedicalAnswer(dspy.Signature):
+     question = dspy.InputField(desc="The medical question to answer")
+     context = dspy.InputField(desc="Retrieved medical passages to ground the answer")
+     final_answer = dspy.OutputField(desc="The answer to the medical question")
+
+ class MedicalRAG(dspy.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, question):
+         reranked_docs = rerank_with_colbert(question)
+
+         context_str = "\n".join(reranked_docs)
+
+         return dspy.ChainOfThought(MedicalAnswer)(
+             question=question,
+             context=context_str
+         )
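The retriever above assumes a `medical_chat_bot` collection with named `dense` and `colbert` vectors and a `passage_text` payload field. The ingestion script is not part of this commit; the following is a minimal sketch of how such a collection could be created and populated, assuming `BAAI/bge-small-en` (384-dim dense vectors) and `colbert-ir/colbertv2.0` (128-dim token multivectors) as in the retriever. The URL, API key, and example passage are placeholders.

```python
# Not part of this commit: sketch of building the "medical_chat_bot" collection
# with "dense" and "colbert" named vectors and a "passage_text" payload field.
from fastembed import TextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models

client = QdrantClient(url="https://YOUR-CLUSTER.qdrant.io", api_key="YOUR_KEY")  # placeholders

dense_model = TextEmbedding("BAAI/bge-small-en")                        # 384-dim dense vectors
colbert_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")  # 128-dim token vectors

client.create_collection(
    collection_name="medical_chat_bot",
    vectors_config={
        "dense": models.VectorParams(size=384, distance=models.Distance.COSINE),
        "colbert": models.VectorParams(
            size=128,
            distance=models.Distance.COSINE,
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        ),
    },
)

passages = ["Lupus commonly presents with fatigue, joint pain, and a malar rash."]  # example data
points = []
for i, text in enumerate(passages):
    dense_vec = list(dense_model.embed(text))[0].tolist()
    colbert_vecs = list(colbert_model.embed(text))[0].tolist()  # one vector per token
    points.append(
        models.PointStruct(
            id=i,
            vector={"dense": dense_vec, "colbert": colbert_vecs},
            payload={"passage_text": text},
        )
    )
client.upsert(collection_name="medical_chat_bot", points=points)
```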
readme.md ADDED
@@ -0,0 +1,43 @@
+ # Medical QA Chatbot
+
+ This is a Chain-of-Thought-powered medical chatbot that:
+
+ - Retrieves answers from a Qdrant Cloud vector DB using dense + ColBERT multivectors
+ - Uses Stanford DSPy to reason step by step over the retrieved context
+ - Serves answers through a Streamlit chat interface with conversation history
+ - Is deployable on Hugging Face Spaces
+
+ ---
+
+ ## How to Deploy
+
+ - Add your `OPENAI_API_KEY`, `QDRANT_CLOUD_URL`, and `QDRANT_API_KEY` as secret environment variables in the Hugging Face Space settings
+ - Make sure the Qdrant client in `rag_dspy.py` points to your Qdrant Cloud instance and that the `medical_chat_bot` collection exists
+ - Run the Space
+
+ ## Sample Questions
+
+ ### General Medical Knowledge
+ - What are the most common symptoms of lupus?
+ - How is type 2 diabetes usually managed in adults?
+ - What is the difference between viral and bacterial pneumonia?
+
+ ### Treatment & Medication
+ - What are the first-line medications for treating hypertension?
+ - How does metformin work to lower blood sugar?
+
+ ### Diagnosis & Tests
+ - What diagnostic tests are used to detect rheumatoid arthritis?
+ - When is a colonoscopy recommended for cancer screening?
+
+ ### Hospital & Patient Care
+ - What are the psychosocial challenges faced by cancer patients?
+ - How do hospitals manage patients with multidrug-resistant infections?
+
+ ### Clinical Guidelines / Rare Topics
+ - What is the recommended treatment for acute myocardial infarction in elderly patients?
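The deploy notes above list three secrets; for local development, the `load_dotenv()` call in `rag_dspy.py` implies a `.env` file along these lines (all values are placeholders):

```
OPENAI_API_KEY=sk-...
QDRANT_CLOUD_URL=https://YOUR-CLUSTER.qdrant.io
QDRANT_API_KEY=YOUR_QDRANT_KEY
```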
requirements.txt CHANGED
@@ -1,3 +1,8 @@
- altair
- pandas
- streamlit
+ datasets==3.6.0
+ streamlit
+ git+https://github.com/stanfordnlp/dspy.git
+ python-dotenv==1.1.0
+ cachetools
+ cloudpickle
+ qdrant-client[fastembed]>=1.14.2
+ dspy-qdrant