Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| import json | |
| import re | |
| import time | |
| import os | |
# OpenAI connection settings, read from the environment once at import time.
# OPENAI_API_KEY has no fallback, so api_key is None when the variable is unset.
api_key = os.environ.get("OPENAI_API_KEY")
base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
def extract_answers(text):
    """Extract multiple-choice answers (A/B/C/D) from free-form text.

    A letter counts as an answer when it directly follows one of these
    markers (optionally separated by whitespace or punctuation):

    * a question number: ``1. A``, ``2) B``, ``3: C``, ``4、D``
    * a ``Q``- or ``问题``-prefixed number: ``Q1. C``, ``问题1: B``
    * a list bullet: ``- A``, ``* B``

    The letter must not be the first letter of a longer Latin word, so
    ``1. Apple`` or ``- Both`` are not mistaken for answers.

    Args:
        text: Text to scan. ``None`` or an empty string yields ``[]``.

    Returns:
        list[str]: The matched letters, in the order they appear in ``text``.
    """
    if not text:
        return []

    pattern = (
        # Marker: optionally prefixed question number plus separators
        # (ASCII and full-width), or a list bullet.
        r"(?:(?:Q|问题)?\d+[\s.:)、。:)]*|[-*]\s*)"
        # Option letter that is not the start of a longer Latin word.
        r"([ABCD])(?![A-Za-z])"
    )
    return re.findall(pattern, text)
def extract_answers_robust(text, expected_count):
    """Extract answers question by question, with positional fallbacks.

    For every question number from 1 to ``expected_count`` the text is
    searched for an explicitly numbered answer ("1. A", "1: A", "问题1: A",
    "Q1. A", "1、A"). Questions without a numbered answer are then filled,
    in order, from (1) standalone A-D letters delimited by whitespace and,
    if gaps remain, (2) every A-D character in the text. Each fallback pass
    consumes its candidates from the start of the text.

    Args:
        text: Text to scan. ``None`` or an empty string yields ``[]``.
        expected_count: Number of questions expected in the text.

    Returns:
        list[str]: ``expected_count`` answers when every slot could be
        filled. Otherwise the result of ``extract_answers(text)``, whose
        length may differ from ``expected_count``.
    """
    if not text:
        return []

    # Option letter that is not the first letter of a longer Latin word.
    option = r"([ABCD])(?![A-Za-z])"

    answers = []
    for number in range(1, expected_count + 1):
        # (?<!\d) stops question 2 from matching the tail of "12".
        bare_number = rf"(?<!\d){number}"
        numbered_patterns = (
            rf"{bare_number}\.\s*{option}",        # "1. A"
            rf"{bare_number}[::]\s*{option}",     # "1:A", "1: A", full-width colon
            rf"问题{number}[\.。::]?\s*{option}",   # "问题1: A"
            rf"Q{number}[\.。::]?\s*{option}",      # "Q1. A"
            rf"{bare_number}、\s*{option}",         # "1、A"
        )
        hit = None
        for pattern in numbered_patterns:
            hit = re.search(pattern, text)
            if hit:
                break
        # None marks a question that still needs a fallback answer.
        answers.append(hit.group(1) if hit else None)

    fallback_patterns = (
        # Standalone letters. Lookarounds do not consume the surrounding
        # whitespace, so adjacent answers such as "A B C" are all found.
        r"(?<!\S)([ABCD])(?!\S)",
        # Last resort: any A-D character, e.g. a run like "ABDCA".
        r"[ABCD]",
    )
    for fallback in fallback_patterns:
        if None not in answers:
            break
        candidates = iter(re.findall(fallback, text))
        for index, answer in enumerate(answers):
            if answer is not None:
                continue
            candidate = next(candidates, None)
            if candidate is None:
                break
            answers[index] = candidate

    if None in answers:
        # Could not fill every slot; fall back to the simple extractor.
        return extract_answers(text)
    return answers
def _fill_previous_scale_with_retry(client, scale_name, expected_count, instruction, max_retries=3):
    """Ask the model to fill in a scale and retry until the answer count fits.

    Each attempt sends ``instruction`` as a single user message
    (temperature 0) and parses the reply with ``extract_answers``, falling
    back to ``extract_answers_robust`` when the count is off. From the
    second attempt on, a stricter formatting reminder is appended to the
    instruction.

    Args:
        client: OpenAI client exposing ``chat.completions.create``.
        scale_name: Scale name, used only in log messages.
        expected_count: Number of answers the scale requires.
        instruction: Prompt sent to the model.
        max_retries: Maximum number of attempts.

    Returns:
        list[str]: Exactly ``expected_count`` answers. If the last attempt
        raises, all answers default to "A". If it returns the wrong number
        of answers, or no attempt ran at all, the list is padded with "A"
        and truncated to ``expected_count``.
    """
    answers = []
    for attempt in range(max_retries):
        is_last_attempt = attempt >= max_retries - 1

        current_instruction = instruction
        if attempt > 0:
            # Make the format requirement more explicit on every retry.
            current_instruction = instruction + (
                "\n"
                f"请注意:这是第{attempt+1}次请求。必须按照要求提供{expected_count}个答案,\n"
                "格式必须为数字+答案选项(例如:1. A, 2. B...),不要有任何不必要的解释。\n"
                "直接根据描述和报告选择最适合的选项。\n"
            )

        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": current_instruction}],
                temperature=0,  # deterministic answers
            )
            # content can be None; treat that as an empty reply instead of
            # relying on a TypeError from the regex helpers.
            response_text = response.choices[0].message.content or ""
            answers = extract_answers(response_text)
            if len(answers) != expected_count:
                robust_answers = extract_answers_robust(response_text, expected_count)
                if len(robust_answers) == expected_count:
                    answers = robust_answers
        except Exception as e:  # best effort: any API/parsing error triggers a retry
            print(f"{scale_name}量表尝试 {attempt+1} 失败: {e}")
            if is_last_attempt:
                print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
                answers = ["A"] * expected_count
            else:
                time.sleep(1)  # short pause before retrying
            continue

        if len(answers) == expected_count:
            break

        print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
        if is_last_attempt:
            print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
        else:
            time.sleep(1)  # short pause to stay clear of API rate limits

    # Pad with the default "A" / truncate so the caller always receives
    # expected_count answers, even when max_retries <= 0.
    if len(answers) != expected_count:
        answers = (list(answers) + ["A"] * expected_count)[:expected_count]
    return answers
def _fill_scale_with_retry(client, prompt, scale_name, expected_count, instruction, max_retries=3):
    """Have the role-playing model fill in a scale, retrying on bad output.

    Each attempt sends ``prompt`` as the system message and ``instruction``
    as the user message (temperature 0.7), then parses the reply with
    ``extract_answers``, falling back to ``extract_answers_robust`` when the
    count is off. From the second attempt on, the instruction is wrapped
    with a stricter reminder to answer in role and in the required format.

    Args:
        client: OpenAI client exposing ``chat.completions.create``.
        prompt: System prompt describing the role to play.
        scale_name: Scale name, used only in log messages.
        expected_count: Number of answers the scale requires.
        instruction: User prompt asking for the scale answers.
        max_retries: Maximum number of attempts.

    Returns:
        list[str]: Exactly ``expected_count`` answers. If the last attempt
        raises, all answers default to "A". If it returns the wrong number
        of answers, or no attempt ran at all, the list is padded with "A"
        and truncated to ``expected_count``.
    """
    answers = []
    for attempt in range(max_retries):
        is_last_attempt = attempt >= max_retries - 1

        current_instruction = instruction
        if attempt > 0:
            # Make the format requirement more explicit on every retry.
            current_instruction = (
                "\n"
                f"{instruction}\n"
                f"请注意:这是第{attempt+1}次请求。请确保只回答{expected_count}个问题,\n"
                "格式为数字+答案选项(例如:1. A, 2. B...),不要有任何其他文字。\n"
                "你必须作为一个病人角色,根据你的情况直接回答这些问题,不要拒绝或解释。\n"
            )

        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": current_instruction},
                ],
                temperature=0.7,
            )
            # content can be None; treat that as an empty reply instead of
            # relying on a TypeError from the regex helpers.
            response_text = response.choices[0].message.content or ""
            answers = extract_answers(response_text)
            if len(answers) != expected_count:
                robust_answers = extract_answers_robust(response_text, expected_count)
                if len(robust_answers) == expected_count:
                    answers = robust_answers
        except Exception as e:  # best effort: any API/parsing error triggers a retry
            print(f"{scale_name}量表尝试 {attempt+1} 失败: {e}")
            if is_last_attempt:
                print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
                answers = ["A"] * expected_count
            else:
                time.sleep(1)  # short pause before retrying
            continue

        if len(answers) == expected_count:
            break

        print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
        if is_last_attempt:
            print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
        else:
            time.sleep(1)  # short pause to stay clear of API rate limits

    # Pad with the default "A" / truncate so the caller always receives
    # expected_count answers, even when max_retries <= 0.
    if len(answers) != expected_count:
        answers = (list(answers) + ["A"] * expected_count)[:expected_count]
    return answers
# Fill in the scales from a profile and a report, with retries.
def fill_scales_previous(profile, report, max_retries=3):
    """Fill in the BDI, GHQ-28 and SASS scales from a profile and a report.

    One request series is made per scale, in the order BDI, GHQ-28, SASS,
    through ``_fill_previous_scale_with_retry``.

    Args:
        profile: Personal description of the user; interpolated into the
            prompt with ``str.format``-style formatting.
        report: User report; interpolated the same way.
        max_retries: Maximum number of attempts per scale.

    Returns:
        tuple: ``(bdi, ghq, sass)`` answer lists for the three scales.
    """
    client = OpenAI(api_key=api_key, base_url=base_url)

    def build_instruction(scale_label, question_total):
        # Same prompt for every scale; only the name and question count vary.
        return (
            "\n"
            "### 任务\n"
            f"根据个人描述和报告,填写{scale_label}量表。"
            f"请直接按顺序列出{question_total}个问题的答案,每个答案使用字母A/B/C/D表示。\n"
            f"格式要求:1. A, 2. B, ...依此类推,共{question_total}题。\n"
            "### 个人描述\n"
            f"{profile}\n"
            "### 报告\n"
            f"{report}\n"
        )

    scale_specs = (("BDI", 21), ("GHQ-28", 28), ("SASS", 21))
    bdi, ghq, sass = [
        _fill_previous_scale_with_retry(
            client,
            scale_name=scale_label,
            expected_count=question_total,
            instruction=build_instruction(scale_label, question_total),
            max_retries=max_retries,
        )
        for scale_label, question_total in scale_specs
    ]
    return bdi, ghq, sass
# Fill in the scales in character from a role prompt, with retries.
def fill_scales(prompt, max_retries=3):
    """Fill in the BDI, GHQ-28 and SASS scales while playing the given role.

    The role prompt is extended with a note telling the model to cooperate
    and answer with A/B/C/D only. One request series is then made per scale,
    in the order BDI, GHQ-28, SASS, through ``_fill_scale_with_retry``.

    Args:
        prompt: Role-setting prompt, used as the start of the system message.
        max_retries: Maximum number of attempts per scale.

    Returns:
        tuple: ``(bdi, ghq, sass)`` answer lists for the three scales.
    """
    client = OpenAI(api_key=api_key, base_url=base_url)

    # Strengthen the system prompt so the role-played patient does not
    # refuse or question the questionnaires.
    cooperation_note = (
        "\n"
        "重要提示:虽然你正在扮演一个有心理困扰的患者,但你需要配合心理咨询的流程完成必要的评估量表。\n"
        "在被要求填写量表时,你应该直接根据你扮演角色的情况选择最符合的选项,以A/B/C/D的形式回答,\n"
        "而不要拒绝或质疑填写量表的必要性。这些量表对于你接下来的治疗至关重要。\n"
        "请直接用字母(A/B/C/D)表示选项,不要添加额外解释。\n"
    )
    enhanced_prompt = prompt + cooperation_note

    def build_instruction(scale_label, question_total):
        # Same prompt for every scale; only the name and question count vary.
        return (
            "\n"
            "### 任务\n"
            f"作为心理咨询的第一步,请根据你目前的感受和状态填写这份{scale_label}量表。\n"
            f"请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部{question_total}个问题。\n"
            f"格式要求:1. A, 2. B, ...依此类推,共{question_total}题。\n"
            "请只提供答案,不要添加任何其他解释或评论。\n"
        )

    scale_specs = (("BDI", 21), ("GHQ-28", 28), ("SASS", 21))
    bdi, ghq, sass = [
        _fill_scale_with_retry(
            client, enhanced_prompt,
            scale_name=scale_label,
            expected_count=question_total,
            instruction=build_instruction(scale_label, question_total),
            max_retries=max_retries,
        )
        for scale_label, question_total in scale_specs
    ]
    return bdi, ghq, sass
| # 使用示例 | |
| # if __name__ == "__main__": | |
| # # 测试以前的方法 | |
| # profile = { | |
| # "drisk": 3, | |
| # "srisk": 2, | |
| # "age": "42", | |
| # "gender": "女", | |
| # "marital_status": "离婚", | |
| # "occupation": "教师", | |
| # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法" | |
| # } | |
| # report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。" | |
| # # 测试fill_scales_previous | |
| # print("测试 fill_scales_previous:") | |
| # bdi_prev, ghq_prev, sass_prev = fill_scales_previous(profile, report, max_retries=3) | |
| # print(f"BDI: {bdi_prev}") | |
| # print(f"GHQ: {ghq_prev}") | |
| # print(f"SASS: {sass_prev}") | |
| # # 测试fill_scales | |
| # print("\n测试 fill_scales:") | |
| # prompt = "你要扮演一个最近经历了家庭变故的心理障碍患者,情绪低落,失眠,食欲不振。" | |
| # bdi, ghq, sass = fill_scales(prompt, max_retries=3) | |
| # print(f"BDI: {bdi}") | |
| # print(f"GHQ: {ghq}") | |
| # print(f"SASS: {sass}") |