Spaces:
Runtime error
Runtime error
feature/major backend update with agent
Browse files
.gitignore
CHANGED
|
@@ -5,4 +5,5 @@ __pycache__/utils.cpython-38.pyc
|
|
| 5 |
|
| 6 |
notebooks/
|
| 7 |
*.pyc
|
| 8 |
-
local_tests/
|
|
|
|
|
|
| 5 |
|
| 6 |
notebooks/
|
| 7 |
*.pyc
|
| 8 |
+
local_tests/
|
| 9 |
+
.vscode/
|
app.py
CHANGED
|
@@ -64,9 +64,9 @@ async def chat(query, history):
|
|
| 64 |
async for event in result:
|
| 65 |
print(event)
|
| 66 |
if event["event"] == "on_chat_model_stream":
|
| 67 |
-
print("line 66")
|
| 68 |
if start_streaming == False:
|
| 69 |
-
print("line 68")
|
| 70 |
start_streaming = True
|
| 71 |
history[-1] = (query, "")
|
| 72 |
|
|
@@ -77,17 +77,26 @@ async def chat(query, history):
|
|
| 77 |
answer_yet = parse_output_llm_with_sources(answer_yet)
|
| 78 |
history[-1] = (query, answer_yet)
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
elif (
|
| 81 |
event["name"] == "retrieve_documents"
|
| 82 |
and event["event"] == "on_chain_end"
|
| 83 |
):
|
| 84 |
try:
|
| 85 |
-
print(
|
|
|
|
| 86 |
docs = event["data"]["output"]["documents"]
|
| 87 |
docs_html = []
|
| 88 |
-
for i,
|
| 89 |
-
docs_html.append(make_html_source(
|
|
|
|
| 90 |
docs_html = "".join(docs_html)
|
|
|
|
| 91 |
except Exception as e:
|
| 92 |
print(f"Error getting documents: {e}")
|
| 93 |
print(event)
|
|
@@ -97,9 +106,9 @@ async def chat(query, history):
|
|
| 97 |
display_output,
|
| 98 |
) in steps_display.items():
|
| 99 |
if event["name"] == event_name:
|
| 100 |
-
print("line 99")
|
| 101 |
if event["event"] == "on_chain_start":
|
| 102 |
-
print("line 101")
|
| 103 |
answer_yet = event_description
|
| 104 |
history[-1] = (query, answer_yet)
|
| 105 |
|
|
|
|
| 64 |
async for event in result:
|
| 65 |
print(event)
|
| 66 |
if event["event"] == "on_chat_model_stream":
|
| 67 |
+
# print("line 66")
|
| 68 |
if start_streaming == False:
|
| 69 |
+
# print("line 68")
|
| 70 |
start_streaming = True
|
| 71 |
history[-1] = (query, "")
|
| 72 |
|
|
|
|
| 77 |
answer_yet = parse_output_llm_with_sources(answer_yet)
|
| 78 |
history[-1] = (query, answer_yet)
|
| 79 |
|
| 80 |
+
elif (
|
| 81 |
+
event["name"] == "answer_rag_wrong"
|
| 82 |
+
and event["event"] == "on_chain_stream"
|
| 83 |
+
):
|
| 84 |
+
history[-1] = (query, event["data"]["chunk"]["answer"])
|
| 85 |
+
|
| 86 |
elif (
|
| 87 |
event["name"] == "retrieve_documents"
|
| 88 |
and event["event"] == "on_chain_end"
|
| 89 |
):
|
| 90 |
try:
|
| 91 |
+
# print(event)
|
| 92 |
+
# print("line 84")
|
| 93 |
docs = event["data"]["output"]["documents"]
|
| 94 |
docs_html = []
|
| 95 |
+
for i, doc in enumerate(docs, 1):
|
| 96 |
+
docs_html.append(make_html_source(i, doc))
|
| 97 |
+
# print(docs_html)
|
| 98 |
docs_html = "".join(docs_html)
|
| 99 |
+
# print(docs_html)
|
| 100 |
except Exception as e:
|
| 101 |
print(f"Error getting documents: {e}")
|
| 102 |
print(event)
|
|
|
|
| 106 |
display_output,
|
| 107 |
) in steps_display.items():
|
| 108 |
if event["name"] == event_name:
|
| 109 |
+
# print("line 99")
|
| 110 |
if event["event"] == "on_chain_start":
|
| 111 |
+
# print("line 101")
|
| 112 |
answer_yet = event_description
|
| 113 |
history[-1] = (query, answer_yet)
|
| 114 |
|
celsius_csrd_chatbot/agent.py
CHANGED
|
@@ -39,16 +39,12 @@ def route_intent(state):
|
|
| 39 |
return "intent_esrs"
|
| 40 |
|
| 41 |
elif esrs == "wrong_esrs":
|
| 42 |
-
return "
|
| 43 |
|
| 44 |
else:
|
| 45 |
return "retrieve_documents"
|
| 46 |
|
| 47 |
|
| 48 |
-
def make_id_dict(values):
|
| 49 |
-
return {k: k for k in values}
|
| 50 |
-
|
| 51 |
-
|
| 52 |
def make_graph_agent(llm, vectorstore):
|
| 53 |
workflow = StateGraph(GraphState)
|
| 54 |
|
|
@@ -70,11 +66,7 @@ def make_graph_agent(llm, vectorstore):
|
|
| 70 |
workflow.set_entry_point("categorize_esrs")
|
| 71 |
|
| 72 |
# CONDITIONAL EDGES
|
| 73 |
-
workflow.add_conditional_edges(
|
| 74 |
-
"categorize_esrs",
|
| 75 |
-
route_intent,
|
| 76 |
-
make_id_dict(["intent_esrs", "retrieve_documents", "answer_rag_wrong"]),
|
| 77 |
-
)
|
| 78 |
|
| 79 |
# Define the edges
|
| 80 |
workflow.add_edge("intent_esrs", "retrieve_documents")
|
|
|
|
| 39 |
return "intent_esrs"
|
| 40 |
|
| 41 |
elif esrs == "wrong_esrs":
|
| 42 |
+
return "answer_rag_wrong"
|
| 43 |
|
| 44 |
else:
|
| 45 |
return "retrieve_documents"
|
| 46 |
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def make_graph_agent(llm, vectorstore):
|
| 49 |
workflow = StateGraph(GraphState)
|
| 50 |
|
|
|
|
| 66 |
workflow.set_entry_point("categorize_esrs")
|
| 67 |
|
| 68 |
# CONDITIONAL EDGES
|
| 69 |
+
workflow.add_conditional_edges("categorize_esrs", route_intent)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# Define the edges
|
| 72 |
workflow.add_edge("intent_esrs", "retrieve_documents")
|
celsius_csrd_chatbot/chains/answer_rag.py
CHANGED
|
@@ -36,6 +36,7 @@ answering_template = """
|
|
| 36 |
10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
|
| 37 |
11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
|
| 38 |
12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
|
|
|
|
| 39 |
|
| 40 |
Question: {query}
|
| 41 |
Answer:
|
|
|
|
| 36 |
10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
|
| 37 |
11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
|
| 38 |
12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
|
| 39 |
+
13. Never mention these guidelines as a source attribution in your response.
|
| 40 |
|
| 41 |
Question: {query}
|
| 42 |
Answer:
|
celsius_csrd_chatbot/chains/esrs_categorization.py
CHANGED
|
@@ -5,7 +5,7 @@ def make_esrs_categorization_node():
|
|
| 5 |
|
| 6 |
def categorize_message(state):
|
| 7 |
query = state["query"]
|
| 8 |
-
pattern = r"ESRS \d
|
| 9 |
esrs_truth = [
|
| 10 |
"ESRS 1",
|
| 11 |
"ESRS 2",
|
|
@@ -25,7 +25,6 @@ def make_esrs_categorization_node():
|
|
| 25 |
if matches:
|
| 26 |
true_matches = [match for match in matches if match in esrs_truth]
|
| 27 |
output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
|
| 28 |
-
|
| 29 |
else:
|
| 30 |
output = {"esrs_type": "none"}
|
| 31 |
|
|
|
|
| 5 |
|
| 6 |
def categorize_message(state):
|
| 7 |
query = state["query"]
|
| 8 |
+
pattern = r"ESRS \d+[A-Z0-9]*"
|
| 9 |
esrs_truth = [
|
| 10 |
"ESRS 1",
|
| 11 |
"ESRS 2",
|
|
|
|
| 25 |
if matches:
|
| 26 |
true_matches = [match for match in matches if match in esrs_truth]
|
| 27 |
output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
|
|
|
|
| 28 |
else:
|
| 29 |
output = {"esrs_type": "none"}
|
| 30 |
|
celsius_csrd_chatbot/chains/esrs_intent.py
CHANGED
|
@@ -23,51 +23,41 @@ class ESRSAnalysis(BaseModel):
|
|
| 23 |
"ESRS S3",
|
| 24 |
"ESRS S4",
|
| 25 |
"ESRS G1",
|
| 26 |
-
"
|
| 27 |
] = Field(
|
| 28 |
-
description="""
|
| 29 |
-
Given a user question choose which documents would be most relevant for answering their question :
|
| 30 |
-
|
| 31 |
-
- ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
|
| 32 |
-
- ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
|
| 33 |
-
- ESRS E1 is for questions about climate change, global warming, GES and energy
|
| 34 |
-
- ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
|
| 35 |
-
- ESRS E3 is for questions about water and marine resources
|
| 36 |
-
- ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
|
| 37 |
-
- ESRS E5 is for questions about resource use and circular economy
|
| 38 |
-
- ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
|
| 39 |
-
- ESRS S2 is for questions about workers in the value chain, workers' treatment
|
| 40 |
-
- SRS S3 is for questions about affected communities, impact on local communities
|
| 41 |
-
- ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
|
| 42 |
-
- ESRS G1 is for questions about governance, risk management, internal control, and business conduct
|
| 43 |
-
- none is for questions that do not fit into any of the above categories
|
| 44 |
-
|
| 45 |
-
Follow these guidelines :
|
| 46 |
-
|
| 47 |
-
- Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
|
| 48 |
-
- Remember, if the question is not related to any ESRS, the output should be 'none'.
|
| 49 |
-
""",
|
| 50 |
)
|
| 51 |
|
| 52 |
|
| 53 |
def make_esrs_intent_chain(llm):
|
| 54 |
-
parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
|
| 55 |
prompt_template = """
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
Question: '{query}'
|
| 63 |
Answer:
|
| 64 |
"""
|
| 65 |
-
|
| 66 |
-
prompt = PromptTemplate(
|
| 67 |
-
template=prompt_template,
|
| 68 |
-
input_variables=["query"],
|
| 69 |
-
partial_variables={"format_instructions": parser.get_format_instructions()},
|
| 70 |
-
)
|
| 71 |
chain = {"query": itemgetter("query")} | prompt | llm | parser
|
| 72 |
|
| 73 |
return chain
|
|
@@ -78,7 +68,9 @@ def make_esrs_intent_node(llm):
|
|
| 78 |
def intent_message(state):
|
| 79 |
query = state["query"]
|
| 80 |
categorization_chain = make_esrs_intent_chain(llm)
|
| 81 |
-
output =
|
|
|
|
|
|
|
| 82 |
|
| 83 |
return output
|
| 84 |
|
|
|
|
| 23 |
"ESRS S3",
|
| 24 |
"ESRS S4",
|
| 25 |
"ESRS G1",
|
| 26 |
+
"no_intent",
|
| 27 |
] = Field(
|
| 28 |
+
description="""The ESRS type that the user query refers to.""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
|
| 32 |
def make_esrs_intent_chain(llm):
|
|
|
|
| 33 |
prompt_template = """
|
| 34 |
+
Please analyze the question and indicate if it refers to a specific ESRS.
|
| 35 |
+
|
| 36 |
+
Follow these definitions in order to choose the appropriate ESRS :
|
| 37 |
+
- ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
|
| 38 |
+
- ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
|
| 39 |
+
- ESRS E1 is for questions about climate change, global warming, GES and energy
|
| 40 |
+
- ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
|
| 41 |
+
- ESRS E3 is for questions about water and marine resources
|
| 42 |
+
- ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
|
| 43 |
+
- ESRS E5 is for questions about resource use and circular economy
|
| 44 |
+
- ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
|
| 45 |
+
- ESRS S2 is for questions about workers in the value chain, workers' treatment
|
| 46 |
+
- ESRS S3 is for questions about affected communities, impact on local communities
|
| 47 |
+
- ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
|
| 48 |
+
- ESRS G1 is for questions about governance, risk management, internal control, and business conduct
|
| 49 |
+
- no_intent is for questions that do not fit into any of the above categories
|
| 50 |
+
|
| 51 |
+
Keep in mind these guidelines :
|
| 52 |
+
- Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
|
| 53 |
+
|
| 54 |
+
The output needs to respect a JSON format with 'esrs_type' as the key and the appropriate ESRS as the value.
|
| 55 |
|
| 56 |
Question: '{query}'
|
| 57 |
Answer:
|
| 58 |
"""
|
| 59 |
+
parser = PydanticOutputParser(pydantic_object=ESRSAnalysis, method="json_mode")
|
| 60 |
+
prompt = PromptTemplate(template=prompt_template, input_variables=["query"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
chain = {"query": itemgetter("query")} | prompt | llm | parser
|
| 62 |
|
| 63 |
return chain
|
|
|
|
| 68 |
def intent_message(state):
|
| 69 |
query = state["query"]
|
| 70 |
categorization_chain = make_esrs_intent_chain(llm)
|
| 71 |
+
output = {
|
| 72 |
+
"esrs_type": [categorization_chain.invoke({"query": query}).esrs_type]
|
| 73 |
+
}
|
| 74 |
|
| 75 |
return output
|
| 76 |
|
celsius_csrd_chatbot/chains/retriever.py
CHANGED
|
@@ -1,16 +1,15 @@
|
|
| 1 |
def make_retriever_node(vectorstore, k=10):
|
| 2 |
-
|
| 3 |
def retrieve_documents(state):
|
| 4 |
sources = state["esrs_type"]
|
| 5 |
query = state["query"]
|
| 6 |
-
if sources == "none":
|
| 7 |
-
|
| 8 |
else:
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
docs = []
|
| 11 |
-
docs_retrieved = vectorstore.similarity_search_with_score(
|
| 12 |
-
query=query, filter=filters_full, k=k
|
| 13 |
-
)
|
| 14 |
for doc in docs_retrieved:
|
| 15 |
doc_append = doc[0]
|
| 16 |
doc_append.metadata["similarity_score"] = doc[1]
|
|
|
|
| 1 |
def make_retriever_node(vectorstore, k=10):
|
|
|
|
| 2 |
def retrieve_documents(state):
|
| 3 |
sources = state["esrs_type"]
|
| 4 |
query = state["query"]
|
| 5 |
+
if sources == "none" or sources == "no_intent":
|
| 6 |
+
docs_retrieved = vectorstore.similarity_search_with_score(query=query, k=k)
|
| 7 |
else:
|
| 8 |
+
filters = {"ESRS_filter": {"$in": sources}}
|
| 9 |
+
docs_retrieved = vectorstore.similarity_search_with_score(
|
| 10 |
+
query=query, filter=filters, k=k
|
| 11 |
+
)
|
| 12 |
docs = []
|
|
|
|
|
|
|
|
|
|
| 13 |
for doc in docs_retrieved:
|
| 14 |
doc_append = doc[0]
|
| 15 |
doc_append.metadata["similarity_score"] = doc[1]
|