Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.vectorstores import SKLearnVectorStore | |
| from langchain_openai import ChatOpenAI | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_pinecone import PineconeVectorStore | |
| from langchain.prompts import PromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from pydantic import BaseModel, Field | |
| from typing import List, TypedDict, Optional | |
| from langchain.schema import Document | |
| from langgraph.graph import START, END, StateGraph | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file into os.environ so the
# os.environ.get(...) lookups below can see the API keys.
load_dotenv()

# Reference personal-finance websites.
# NOTE(review): `url` is not referenced anywhere else in this file as seen;
# confirm whether it is still needed.
url = [
    f"https://www.{site}/"
    for site in (
        "investopedia.com",
        "fool.com",
        "morningstar.com",
        "kiplinger.com",
        "nerdwallet.com",
    )
]
# Sentence-transformer embedding model used to embed queries for the
# Pinecone index below.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Connect to the Pinecone index and expose it as a top-10 retriever.
# If the connection fails there is no alternative vector store: `retriever`
# stays None and `retrieve_db` returns an empty document list.
# NOTE(review): the env var name 'PINCE_CONE_LIGHT' looks unusual -- confirm
# it matches the secret configured for this deployment.
retriever = None
try:
    pc = PineconeVectorStore(
        pinecone_api_key=os.environ.get('PINCE_CONE_LIGHT'),
        embedding=embedding_model,
        index_name='rag-rubic',
        namespace='vectors_lightmodel',
    )
    retriever = pc.as_retriever(search_kwargs={"k": 10})
except Exception as e:
    # Best effort: keep the app running without retrieval.
    print(f"Pinecone connection error: {e}")
# Chat model shared by the relevance grader and the answer-generation chain.
# The key is read from the OPEN_AI_KEY environment variable.
llm = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0.2,
    api_key=os.environ.get('OPEN_AI_KEY'),
)
# Structured-output schema for the relevance grader: one 'yes'/'no' field.
# Deliberately documented with comments rather than a class docstring, since
# with_structured_output may forward a docstring to the model as the schema
# description.
class GradeDocuments(BaseModel):
    # 'yes' when the document is relevant to the question, otherwise 'no'.
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


# LLM wrapper whose output is parsed into a GradeDocuments instance.
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# System instructions for the relevance grader.
system = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""

# Human turn: the document under review followed by the user's question.
_grade_human_template = "Retrieved document: \n\n {documents} \n\n User question: {question}"

grade_prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", _grade_human_template)]
)

# prompt -> structured grader; the result carries a `binary_score` field.
retrieval_grader = grade_prompt | structured_llm_grader
# Answer-generation prompt: frames the model as an investment advisor and
# restricts it to the supplied documents.
_RAG_TEMPLATE = '''
You are a Registered Investment Advisor with expertise in Indian financial markets and client relations.
You must understand what the user is asking about their financial investments and respond to their queries based on the information in the documents only.
Use the following documents to answer the question. If you do not know the answer, say you don't know.
Query: {question}
Documents: {context}
'''

prompt = PromptTemplate(
    input_variables=['question', 'context'],
    template=_RAG_TEMPLATE,
)

# prompt -> chat model -> plain string answer.
rag_chain = prompt | llm | StrOutputParser()
# Tavily web-search tool, used when the retrieved documents are insufficient.
# TavilySearchResults is configured through `max_results` (results per query)
# and `tavily_api_key`; `k` / `api_key` are not parameters of this tool.
web_search_tool = TavilySearchResults(
    max_results=10,
    tavily_api_key=os.environ.get('TAVILY_API_KEY'),
)
# State schema shared by all nodes of the corrective-RAG graph.
class GraphState(TypedDict):
    """Dictionary passed between (and partially updated by) the graph nodes."""

    question: str  # the user's question, carried through every node
    generation: Optional[str]  # final answer text, produced by `generate`
    need_web_search: Optional[str]  # 'yes'/'no' flag set by retrieve_db/grade_docs, read by `decide`
    documents: List  # context documents used to answer the question
def retrieve_db(state):
    """Fetch candidate documents for the question from the vector store.

    Args:
        state: Graph state; only ``state['question']`` is read.

    Returns:
        ``{'documents': <retriever results>, 'question': ...}`` on success.
        When no retriever is configured, or the lookup raises, an empty
        document list with ``need_web_search`` set to 'yes'.
    """
    question = state['question']
    no_results = {'documents': [], 'question': question, 'need_web_search': 'yes'}

    if not retriever:
        return no_results

    try:
        return {'documents': retriever.invoke(question), 'question': question}
    except Exception as e:
        print(f"Retriever error: {e}")
        return no_results
def grade_docs(state):
    """Keep only the retrieved documents the LLM grader judges relevant.

    Each document is graded on its own with ``retrieval_grader``. A web
    search is requested (``need_web_search='yes'``) when any document is
    graded irrelevant, or when there are no documents to grade at all.

    Args:
        state: Graph state; reads ``question`` and ``documents``.

    Returns:
        dict with the relevant ``documents``, the ``question`` and the
        ``need_web_search`` flag ('yes' or 'no') consumed by ``decide``.
    """
    question = state['question']
    docs = state.get('documents') or []
    filtered_docs = []
    # Nothing retrieved (e.g. the vector store is unavailable) means there is
    # no context to answer from, so route to the web search.
    web = "no" if docs else "yes"
    for doc in docs:
        # Grade this single document, not the whole batch, so each one is
        # judged on its own content.
        score = retrieval_grader.invoke(
            {'question': question, 'documents': doc.page_content}
        )
        # Tolerate a missing structured result and case/whitespace variants
        # such as "Yes" or " yes".
        grade = (getattr(score, "binary_score", None) or "").strip().lower()
        if grade == 'yes':
            filtered_docs.append(doc)
        else:
            web = 'yes'
    return {"documents": filtered_docs, "question": question, "need_web_search": web}
def decide(state):
    """Pick the next graph node after grading.

    Returns 'web_search' when the state's ``need_web_search`` flag is exactly
    'yes'; otherwise (missing, 'no', anything else) returns 'generate'.
    """
    wants_web = state.get('need_web_search', 'no') == 'yes'
    return 'web_search' if wants_web else 'generate'
def web_search(state):
    """Add Tavily web-search results to the context documents.

    Each result becomes a ``Document`` whose ``page_content`` is the result's
    ``content`` and whose ``metadata['source']`` is its ``url``, so that
    ``generate`` can cite it.

    Args:
        state: Graph state; reads ``question`` and ``documents``.

    Returns:
        dict with ``documents`` (the incoming documents followed by the web
        results) and ``question``. If the search yields nothing usable, the
        incoming documents are returned unchanged rather than discarded.
    """
    question = state['question']
    # Copy so the list held in the incoming state is not mutated.
    documents = list(state.get('documents') or [])

    results = web_search_tool.invoke({"query": question})
    # The tool may return an error message string instead of a list of
    # result dicts; treat anything that is not a list as "no results".
    if not isinstance(results, list):
        print(f"Web search returned no usable results: {results!r}")
        results = []

    web_docs = [
        Document(page_content=res["content"], metadata={"source": res.get("url", "")})
        for res in results
        if isinstance(res, dict) and res.get("content")
    ]
    return {"documents": documents + web_docs, "question": question}
def generate(state):
    """Answer the question from the context documents and append citations.

    Args:
        state: Graph state; reads ``question`` and ``documents``.

    Returns:
        dict with the unchanged ``documents`` and ``question`` plus
        ``generation``: the model's answer, followed by a "Sources:" section
        listing each distinct non-empty ``metadata['source']`` once (in
        first-seen order) when any document has one.
    """
    documents = state.get('documents') or []
    question = state['question']

    # Give the model the document text only; interpolating the list of
    # Document objects directly would send their Python repr, metadata and all.
    context = "\n\n".join(doc.page_content for doc in documents)
    response = rag_chain.invoke({'context': context, 'question': question})

    # Several chunks often come from the same page: de-duplicate while
    # keeping order, and skip empty/None sources so the join cannot fail.
    sources = list(dict.fromkeys(
        str(doc.metadata["source"])
        for doc in documents
        if doc.metadata.get("source")
    ))

    formatted_response = response
    if sources:
        formatted_response = response + "\n\nSources:\n" + "\n".join(sources)

    return {
        'documents': documents,
        'question': question,
        'generation': formatted_response,
    }
# Corrective-RAG graph:
#   START -> retrieve -> grader --decide--> web_search -> generate -> END
#                                    \------------------> generate
workflow = StateGraph(GraphState)

for _node_name, _node_fn in (
    ("retrieve", retrieve_db),
    ("grader", grade_docs),
    ("web_search", web_search),
    ("generate", generate),
):
    workflow.add_node(_node_name, _node_fn)

workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grader")
# `decide` returns the name of the node to run next.
workflow.add_conditional_edges(
    "grader",
    decide,
    {target: target for target in ("web_search", "generate")},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compiled, invokable graph used by the chat handler.
crag = workflow.compile()
# Gradio callback: runs one chat turn and maintains the chat history.
def process_query(user_input, history):
    """Run one chat turn through the CRAG graph and update the chat history.

    Args:
        user_input: Text from the question textbox.
        history: Chatbot history as a list of ``(user, bot)`` tuples, or None.

    Returns:
        ``(history, "")``: the updated history for the Chatbot component and
        an empty string that clears the textbox. Blank input leaves the
        history unchanged and does not invoke the graph.
    """
    if history is None:
        history = []
    # Don't spend retrieval / LLM calls on an empty or whitespace-only message.
    if not user_input or not user_input.strip():
        return history, ""

    try:
        result = crag.invoke({"question": user_input})
        generation = result.get('generation') if result else None
        response = generation or "I couldn't find relevant information to answer your question."
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        print(f"Error in crag execution: {e}")
        response = "I encountered an error while processing your request. Please try again."

    history.append((user_input, response))
    return history, ""
# Gradio UI: a chat window, a question box with a Send button, and a button
# that empties the conversation.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAG-Powered Financial Advisor Chatbot")
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        height=600,
        avatar_images=(None, "🤖"),
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Ask me anything about Indian financial markets...",
            label="Your question:",
            scale=9,
        )
        submit_btn = gr.Button("Send", scale=1)

    clear_btn = gr.Button("Clear Chat")

    def _bind_query(trigger):
        # Send the textbox text and current history to process_query; its
        # outputs are the new history and an empty string for the textbox.
        return trigger(process_query, inputs=[msg, chatbot], outputs=[chatbot, msg])

    # Clicking "Send" and pressing Enter in the textbox do the same thing.
    submit_click_event = _bind_query(submit_btn.click)
    _bind_query(msg.submit)
    clear_btn.click(lambda: [], outputs=[chatbot])

if __name__ == "__main__":
    demo.launch()