Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from openai import OpenAI | |
| import time | |
| from datasets import load_dataset # 导入Hugging Face datasets库 | |
| # ======================================================================= | |
| # 1. 导入您的类和数据 | |
| # ======================================================================= | |
| # 请将 'your_agent_file' 替换为包含 MsPatient 类的实际文件名 | |
| from ms_patient import MsPatient | |
| # 使用Streamlit的缓存功能来加载和缓存数据集 | |
| def load_hf_dataset(): | |
| """ | |
| 从Hugging Face Hub加载并缓存数据集。 | |
| 这将返回一个字典列表,每个字典代表一个病人数据。 | |
| """ | |
| try: | |
| # 加载'train'分割部分的数据 | |
| dataset = load_dataset("sci-m-wang/Anna-CPsyCounD", split='train') | |
| # 转换为pandas DataFrame再转为字典列表,方便处理 | |
| return dataset.to_pandas().to_dict('records') | |
| except Exception as e: | |
| st.error(f"从Hugging Face加载数据集失败: {e}") | |
| return [] | |
| # 加载数据 | |
| ALL_PATIENTS = load_hf_dataset() | |
| # ======================================================================= | |
| # 2. Streamlit 应用界面 | |
| # ======================================================================= | |
| # --- 页面配置 --- | |
| st.set_page_config( | |
| page_title="与Anna对话", | |
| page_icon="�", | |
| layout="wide" | |
| ) | |
| # --- 自定义CSS样式 --- | |
| st.markdown(""" | |
| <style> | |
| /* 主聊天容器 */ | |
| .st-emotion-cache-1y4p8pa { | |
| padding-top: 2rem; | |
| } | |
| /* 聊天消息 */ | |
| .st-chat-message { | |
| border-radius: 0.8rem; | |
| padding: 0.9rem 1.2rem; | |
| box-shadow: 0 2px 5px rgba(0,0,0,0.05); | |
| background-color: #ffffff; | |
| } | |
| .st-chat-message[data-testid="chat-message-container-user"] { | |
| background-color: #dcf8c6; | |
| } | |
| /* 侧边栏 */ | |
| .st-sidebar { | |
| background-color: #f8f9fa; | |
| border-right: 1px solid #e9ecef; | |
| } | |
| .st-sidebar h2 { | |
| color: #343a40; | |
| } | |
| .st-expanderHeader { | |
| font-size: 1.1rem; | |
| font-weight: 600; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # --- 初始化 Session State --- | |
| if "patient_agent" not in st.session_state: | |
| st.session_state.patient_agent = None | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| if "selected_patient_id" not in st.session_state: | |
| st.session_state.selected_patient_id = None | |
| if "openai_client" not in st.session_state: | |
| st.session_state.openai_client = None | |
| if "model_name" not in st.session_state: | |
| st.session_state.model_name = "gpt-4o-mini" # 默认模型 | |
| # --- 侧边栏 --- | |
| with st.sidebar: | |
| st.title("👩 AnnaAgent 设置") | |
| st.markdown("---") | |
| # API Key 输入 | |
| with st.expander("🔑 API 设置", expanded=True): | |
| api_key = st.text_input("输入您的 OpenAI API Key", type="password", help="您的API Key将仅用于本次会话,不会被储存。") | |
| base_url = st.text_input("API Base URL (可选)", value="https://api.openai.com/v1") | |
| model_name = st.text_input("模型名称", value=st.session_state.model_name) | |
| if st.button("连接模型"): | |
| if api_key: | |
| try: | |
| st.session_state.openai_client = OpenAI(api_key=api_key, base_url=base_url) | |
| st.session_state.model_name = model_name | |
| st.success("连接成功!") | |
| if st.session_state.patient_agent: | |
| st.session_state.patient_agent.client = st.session_state.openai_client | |
| except Exception as e: | |
| st.error(f"连接失败: {e}") | |
| else: | |
| st.warning("请输入API Key。") | |
| st.markdown("---") | |
| # 病人选择 | |
| if not ALL_PATIENTS: | |
| st.error("无法加载病人数据,请检查网络连接或数据集名称。") | |
| else: | |
| patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptom']}" for p in ALL_PATIENTS} | |
| selected_id = st.selectbox( | |
| "选择一位病人进行对话", | |
| options=list(patient_options.keys()), | |
| format_func=lambda x: patient_options[x] | |
| ) | |
| # 当选择的病人变化时,重置状态 | |
| if st.session_state.selected_patient_id != selected_id: | |
| st.session_state.selected_patient_id = selected_id | |
| selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None) | |
| with st.spinner("正在生成病人角色..."): | |
| st.session_state.patient_agent = MsPatient( | |
| portrait=selected_patient_data["portrait"], | |
| report=selected_patient_data["report"], | |
| previous_conversations=selected_patient_data["conversation"], | |
| language="Chinese", | |
| client=st.session_state.openai_client | |
| ) | |
| st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}] | |
| st.rerun() | |
| # 显示病人信息 | |
| if st.session_state.patient_agent: | |
| st.markdown("---") | |
| st.subheader("病人信息") | |
| agent = st.session_state.patient_agent | |
| st.info(f""" | |
| **基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']} | |
| **近期状态**: {agent.status} | |
| """) | |
| with st.expander("查看完整系统提示 (System Prompt)"): | |
| st.code(agent.get_system_prompt(), language='markdown') | |
| # --- 主聊天界面 --- | |
| st.title("💬 与 Anna 对话") | |
| st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。") | |
| # 显示聊天记录 | |
| for message in st.session_state.messages: | |
| avatar = "👩" if message["role"] == "assistant" else "🧑⚕️" | |
| with st.chat_message(message["role"], avatar=avatar): | |
| st.markdown(message["content"]) | |
| # 获取用户输入 | |
| if prompt := st.chat_input("请输入您想说的话..."): | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| with st.chat_message("user", avatar="🧑⚕️"): | |
| st.markdown(prompt) | |
| if st.session_state.patient_agent: | |
| with st.chat_message("assistant", avatar="👩"): | |
| with st.spinner("Anna正在思考..."): | |
| response = st.session_state.patient_agent.chat(prompt) | |
| st.markdown(response) | |
| st.session_state.messages.append({"role": "assistant", "content": response}) | |
| else: | |
| st.warning("请先在左侧选择一位病人并配置API Key。") |