Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import numpy as np | |
| import faiss | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from groq import Groq | |
| import nltk | |
| import re | |
| from nltk.corpus import stopwords | |
| from nltk.tokenize import word_tokenize, sent_tokenize | |
| from nltk.stem import WordNetLemmatizer | |
| from multiprocessing import Pool, cpu_count | |
# Fetch only the NLTK resources this module uses (tokenizers, stopwords,
# WordNet) instead of the entire NLTK data collection, which makes start-up
# far heavier than necessary. Resource ids unknown to the installed NLTK
# version are reported by the downloader rather than raised.
for _nltk_resource in ("punkt", "punkt_tab", "stopwords", "wordnet", "omw-1.4"):
    nltk.download(_nltk_resource, quiet=True)

# Load stopwords and lemmatizer
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
| # Load dataset | |
def load_and_preprocess_dataset():
    """Load the ``MedRAG/textbooks`` dataset via ``load_dataset`` and return it.

    Despite the name, no preprocessing is applied here; the dataset object is
    returned exactly as loaded.
    """
    medical_dataset = load_dataset("MedRAG/textbooks")
    print("Dataset loaded successfully.")
    return medical_dataset
| # Preprocessing function | |
def preprocess_text(text):
    """Normalize *text* for embedding.

    The text is lowercased, every character that is neither a word character
    nor whitespace is removed, the result is tokenized, stopwords are dropped,
    and the remaining tokens are lemmatized and joined with single spaces.
    """
    lowered = text.lower()
    without_punctuation = re.sub(r"[^\w\s]", "", lowered)

    kept_tokens = []
    for token in word_tokenize(without_punctuation):
        if token in stop_words:
            continue
        kept_tokens.append(lemmatizer.lemmatize(token))
    return " ".join(kept_tokens)
| # Chunking function | |
def chunk_text(text, chunk_size=3):
    """Group the sentences of *text* into chunks of ``chunk_size`` sentences.

    Returns a list of strings, each made of up to ``chunk_size`` consecutive
    sentences joined by a space; the final chunk may be shorter.
    """
    sentences = sent_tokenize(text)
    chunks = []
    for start in range(0, len(sentences), chunk_size):
        window = sentences[start:start + chunk_size]
        chunks.append(" ".join(window))
    return chunks
| # Generate embeddings in parallel | |
def generate_embeddings_parallel(chunks):
    """Embed every text in *chunks* with the MiniLM sentence encoder.

    Args:
        chunks: Iterable of strings to embed.

    Returns:
        A list with one 1-D embedding array per input text, in input order.
        An empty input yields an empty list without loading the model.
    """
    chunks = list(chunks)
    if not chunks:
        return []

    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    # Encode in batches inside a single process. Mapping ``encode`` over a
    # multiprocessing pool would have to ship the model to every worker and
    # encode one string per task, defeating the encoder's own batching.
    vectors = embed_model.encode(
        chunks,
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
    )
    # Callers index the result per text, so hand back a list of row vectors.
    return list(vectors)
| # Generate embeddings for the dataset | |
def generate_embeddings(dataset):
    """Add ``cleaned_content``, ``chunks`` and ``embedding`` columns to *dataset*.

    Args:
        dataset: Mapping of splits that contains a ``"train"`` split whose rows
            have a ``"content"`` text field.

    Returns:
        The dataset with ``cleaned_content`` and ``chunks`` added to every
        split and one ``embedding`` per row added to the ``"train"`` split.
    """
    print("Preprocessing dataset...")
    dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
    # NOTE(review): preprocess_text strips punctuation before chunk_text runs,
    # so sentence splitting probably finds few boundaries here - confirm
    # whether chunking was meant to happen before cleaning.
    dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})

    print("Generating embeddings...")
    train_split = dataset["train"]
    # Embed exactly one text per row so that embeddings[idx] always belongs to
    # row idx. Flattening all chunks and then indexing by row number misaligns
    # as soon as any row has zero or several chunks.
    row_texts = [" ".join(row_chunks) for row_chunks in train_split["chunks"]]
    embeddings = generate_embeddings_parallel(row_texts)

    # Only the "train" split was embedded, so only that split gets the column.
    dataset["train"] = train_split.map(
        lambda row, idx: {"embedding": embeddings[idx]},
        with_indices=True,
    )
    return dataset
| # Create FAISS index | |
def create_faiss_index(dataset):
    """Build a flat L2 FAISS index from the train split and write it to disk.

    Reads the ``"embedding"`` field of every row in ``dataset["train"]``,
    stacks the vectors into a float32 matrix, adds them to an ``IndexFlatL2``
    and saves the index as ``faiss_medical.index``.
    """
    vectors = []
    for row in dataset["train"]:
        # Flatten so each row contributes a plain 1-D list of floats.
        vectors.append(np.ravel(row["embedding"]).tolist())
    embeddings_np = np.array(vectors, dtype=np.float32)

    dimension = embeddings_np.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings_np)

    faiss.write_index(index, "faiss_medical.index")
    print("FAISS index created and saved.")
| # Load FAISS index | |
def load_faiss_index():
    """Read ``faiss_medical.index`` from the working directory and return it."""
    loaded_index = faiss.read_index("faiss_medical.index")
    print("FAISS index loaded.")
    return loaded_index
| # Retrieve medical summary | |
_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_embed_model_cache = {}


def _get_embed_model():
    """Return the shared query encoder, constructing it on first use only."""
    if _EMBED_MODEL_NAME not in _embed_model_cache:
        _embed_model_cache[_EMBED_MODEL_NAME] = SentenceTransformer(_EMBED_MODEL_NAME)
    return _embed_model_cache[_EMBED_MODEL_NAME]


def retrieve_medical_summary(query, index, id_to_text, k=3):
    """Retrieve the texts most similar to *query* from the FAISS index.

    Args:
        query: The user's question.
        index: Search index exposing ``search(vectors, k)``.
        id_to_text: Mapping from index id to a text or a list of text chunks.
            Keys may be ints or their string form (as after a JSON round-trip).
        k: Number of neighbours to request.

    Returns:
        The retrieved texts joined by ``"\\n\\n---\\n\\n"``, or
        ``"No relevant data found."`` when the query is blank or nothing
        usable was retrieved.
    """
    if not query or not query.strip():
        return "No relevant data found."

    # Reuse one encoder across queries instead of reloading it per request.
    query_embedding = _get_embed_model().encode([query])
    _, neighbor_ids = index.search(np.asarray(query_embedding, dtype="float32"), k)

    retrieved_docs = []
    for raw_id in neighbor_ids[0]:
        doc_id = int(raw_id)
        # FAISS pads the result with -1 when fewer than k vectors exist.
        if doc_id < 0:
            continue
        doc = id_to_text.get(doc_id)
        if doc is None:
            doc = id_to_text.get(str(doc_id))
        if doc is None:
            # Skip unknown ids rather than injecting a placeholder sentence
            # into the context passed on to the language model.
            continue
        retrieved_docs.append(doc if isinstance(doc, str) else " ".join(doc))

    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."
| # Generate medical answer using Groq | |
def generate_medical_answer_groq(query, index, id_to_text):
    """Answer *query* with Groq's chat API, grounded on retrieved literature.

    Args:
        query: The user's medical question.
        index: Search index passed through to ``retrieve_medical_summary``.
        id_to_text: Id-to-text mapping passed through to
            ``retrieve_medical_summary``.

    Returns:
        The model's answer with surrounding whitespace stripped, a fixed
        fallback message when nothing relevant was retrieved, or a string
        starting with ``"Error generating response:"`` when the API key is
        missing or the request fails.
    """
    retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."

    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        # Report the misconfiguration as text instead of letting client
        # construction raise into the UI callback.
        return "Error generating response: GROQ_API_KEY environment variable is not set."

    prompt = (
        "Summarize the following medical literature and provide a structured medical answer:\n\n"
        f"### Medical Literature ###\n{retrieved_summary}\n\n"
        f"### Patient Question ###\n{query}\n\n"
        "### Medical Advice ###"
    )

    # This is the boundary to the UI: any failure (client construction,
    # network, malformed response) is converted into an error string.
    try:
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=500,
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error generating response: {e}"
| # Gradio interface | |
def ask_medical_question(question):
    """Gradio callback: answer *question* using the module-level index and texts.

    Relies on the globals ``index`` and ``id_to_text`` that ``main`` assigns.
    """
    answer = generate_medical_answer_groq(question, index, id_to_text)
    return answer
| # Main function | |
def main():
    """Build the retrieval pipeline and launch the Gradio app.

    Loads the dataset, computes embeddings, writes and reloads the FAISS
    index, stores the id-to-text mapping in the module globals (and in
    ``id_to_text.json``), then starts the web interface.
    """
    global index, id_to_text

    # Dataset -> embeddings
    dataset = generate_embeddings(load_and_preprocess_dataset())

    # Persist the index, then read it back for querying
    create_faiss_index(dataset)
    index = load_faiss_index()

    # Row number -> chunks lookup used when presenting search hits
    id_to_text = dict(enumerate(dataset["train"]["chunks"]))
    with open("id_to_text.json", "w") as f:
        json.dump(id_to_text, f)

    # Gradio front end
    question_box = gr.Textbox(lines=2, placeholder="Enter your medical question here...")
    answer_box = gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here...")
    iface = gr.Interface(
        fn=ask_medical_question,
        inputs=question_box,
        outputs=answer_box,
        title="Medical Question Answering System",
        description="Ask any medical question, and the AI will provide an answer based on medical literature.",
    )
    iface.launch()


# Run the main function
if __name__ == "__main__":
    main()