From 3cc5e3f479cee82c2811db291ab2fb8b6c3be660 Mon Sep 17 00:00:00 2001 From: 666ghj <670939375@qq.com> Date: Tue, 2 Dec 2025 14:25:53 +0800 Subject: [PATCH] Refactor simulation management and enhance logging capabilities - Updated simulation preparation checks to exclude script files from the required files list, improving clarity on file management. - Implemented a robust retry mechanism for Zep API calls in the ZepEntityReader service, enhancing reliability. - Enhanced logging in simulation scripts to provide clearer insights into the simulation process and errors. - Updated simulation runner to manage stdout and stderr logs more effectively, ensuring better error tracking. - Improved profile generation to standardize gender fields and ensure all required fields are populated correctly. --- backend/app/api/simulation.py | 46 +++-- .../app/services/oasis_profile_generator.py | 157 ++++++++++++------ backend/app/services/simulation_manager.py | 66 ++------ backend/app/services/simulation_runner.py | 64 +++++-- backend/app/services/zep_entity_reader.py | 78 +++++++-- backend/scripts/run_parallel_simulation.py | 141 +++++++++++++++- backend/scripts/run_reddit_simulation.py | 103 +++++++++++- backend/scripts/run_twitter_simulation.py | 105 ++++++++++-- 8 files changed, 595 insertions(+), 165 deletions(-) diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index df445d2..e31558c 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -221,6 +221,8 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: 1. state.json 存在且 status 为 "ready" 2. 必要文件存在:reddit_profiles.json, twitter_profiles.csv, simulation_config.json + 注意:运行脚本(run_*.py)保留在 backend/scripts/ 目录,不再复制到模拟目录 + Args: simulation_id: 模拟ID @@ -236,15 +238,12 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: if not os.path.exists(simulation_dir): return False, {"reason": "模拟目录不存在"} - # 必要文件列表 + # 必要文件列表(不包括脚本,脚本位于 backend/scripts/) required_files = [ "state.json", "simulation_config.json", "reddit_profiles.json", - "twitter_profiles.csv", - "run_reddit_simulation.py", - "run_twitter_simulation.py", - "run_parallel_simulation.py" + "twitter_profiles.csv" ] # 检查文件是否存在 @@ -272,9 +271,13 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: state_data = json.load(f) status = state_data.get("status", "") + config_generated = state_data.get("config_generated", False) + + # 详细日志 + logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}") # 如果状态是ready或preparing(已有文件),认为准备完成 - if status in ["ready", "preparing"] and state_data.get("config_generated"): + if status in ["ready", "preparing"] and config_generated: # 获取文件统计信息 profiles_file = os.path.join(simulation_dir, "reddit_profiles.json") config_file = os.path.join(simulation_dir, "simulation_config.json") @@ -298,21 +301,23 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: except Exception as e: logger.warning(f"自动更新状态失败: {e}") + logger.info(f"模拟 {simulation_id} 检测结果: 已准备完成 (status={status}, config_generated={config_generated})") return True, { "status": status, "entities_count": state_data.get("entities_count", 0), "profiles_count": profiles_count, "entity_types": state_data.get("entity_types", []), - "config_generated": state_data.get("config_generated", False), + "config_generated": config_generated, "created_at": state_data.get("created_at"), "updated_at": state_data.get("updated_at"), "existing_files": existing_files } else: + logger.warning(f"模拟 {simulation_id} 
检测结果: 未准备完成 (status={status}, config_generated={config_generated})") return False, { - "reason": f"状态不是ready: {status}", + "reason": f"状态不是ready或config_generated为false: status={status}, config_generated={config_generated}", "status": status, - "config_generated": state_data.get("config_generated", False) + "config_generated": config_generated } except Exception as e: @@ -386,10 +391,13 @@ def prepare_simulation(): # 检查是否强制重新生成 force_regenerate = data.get('force_regenerate', False) + logger.info(f"开始处理 /prepare 请求: simulation_id={simulation_id}, force_regenerate={force_regenerate}") # 检查是否已经准备完成(避免重复生成) if not force_regenerate: + logger.debug(f"检查模拟 {simulation_id} 是否已准备完成...") is_prepared, prepare_info = _check_simulation_prepared(simulation_id) + logger.debug(f"检查结果: is_prepared={is_prepared}, prepare_info={prepare_info}") if is_prepared: logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成") return jsonify({ @@ -402,6 +410,8 @@ def prepare_simulation(): "prepare_info": prepare_info } }) + else: + logger.info(f"模拟 {simulation_id} 未准备完成,将启动准备任务") # 从项目获取必要信息 project = ProjectManager.get_project(state.project_id) @@ -850,25 +860,27 @@ def download_simulation_config(simulation_id: str): }), 500 -@simulation_bp.route('//script//download', methods=['GET']) -def download_simulation_script(simulation_id: str, script_name: str): +@simulation_bp.route('/script//download', methods=['GET']) +def download_simulation_script(script_name: str): """ - 下载模拟脚本文件 + 下载模拟运行脚本文件(通用脚本,位于 backend/scripts/) script_name可选值: - run_twitter_simulation.py - run_reddit_simulation.py - run_parallel_simulation.py + - action_logger.py """ try: - manager = SimulationManager() - sim_dir = manager._get_simulation_dir(simulation_id) + # 脚本位于 backend/scripts/ 目录 + scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) # 验证脚本名称 allowed_scripts = [ "run_twitter_simulation.py", "run_reddit_simulation.py", - "run_parallel_simulation.py" + "run_parallel_simulation.py", + "action_logger.py" ] if script_name not in allowed_scripts: @@ -877,12 +889,12 @@ def download_simulation_script(simulation_id: str, script_name: str): "error": f"未知脚本: {script_name},可选: {allowed_scripts}" }), 400 - script_path = os.path.join(sim_dir, script_name) + script_path = os.path.join(scripts_dir, script_name) if not os.path.exists(script_path): return jsonify({ "success": False, - "error": "脚本文件不存在,请先调用 /prepare 接口" + "error": f"脚本文件不存在: {script_name}" }), 404 return send_file( diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index f0ff6ee..00a5f4a 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -10,6 +10,7 @@ OASIS Agent Profile生成器 import json import random +import time from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime @@ -315,32 +316,54 @@ class OasisProfileGenerator: comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景" def search_edges(): - """搜索边(事实/关系)""" - try: - return self.zep_client.graph.search( - query=comprehensive_query, - graph_id=self.graph_id, - limit=30, - scope="edges", - reranker="rrf" - ) - except Exception as e: - logger.debug(f"Zep边搜索失败: {e}") - return None + """搜索边(事实/关系)- 带重试机制""" + max_retries = 3 + last_exception = None + delay = 2.0 + + for attempt in range(max_retries): + try: + return self.zep_client.graph.search( + query=comprehensive_query, + graph_id=self.graph_id, + limit=30, + scope="edges", + 
reranker="rrf" + ) + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") + time.sleep(delay) + delay *= 2 + else: + logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}") + return None def search_nodes(): - """搜索节点(实体摘要)""" - try: - return self.zep_client.graph.search( - query=comprehensive_query, - graph_id=self.graph_id, - limit=20, - scope="nodes", - reranker="rrf" - ) - except Exception as e: - logger.debug(f"Zep节点搜索失败: {e}") - return None + """搜索节点(实体摘要)- 带重试机制""" + max_retries = 3 + last_exception = None + delay = 2.0 + + for attempt in range(max_retries): + try: + return self.zep_client.graph.search( + query=comprehensive_query, + graph_id=self.graph_id, + limit=20, + scope="nodes", + reranker="rrf" + ) + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") + time.sleep(delay) + delay *= 2 + else: + logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}") + return None try: # 并行执行edges和nodes搜索 @@ -684,18 +707,20 @@ class OasisProfileGenerator: - 立场观点(对话题的态度、可能被激怒/感动的内容) - 独特特征(口头禅、特殊经历、个人爱好) - 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应) -3. age: 年龄数字 -4. gender: 性别(男/女) -5. mbti: MBTI类型 -6. country: 国家 +3. age: 年龄数字(必须是整数) +4. gender: 性别,必须是英文: "male" 或 "female" +5. mbti: MBTI类型(如INTJ、ENFP等) +6. country: 国家(使用中文,如"中国") 7. profession: 职业 8. interested_topics: 感兴趣话题数组 重要: - 所有字段值必须是字符串或数字,不要使用换行符 - persona必须是一段连贯的文字描述 -- 使用中文 -- 内容要与实体信息保持一致""" +- 使用中文(除了gender字段必须用英文male/female) +- 内容要与实体信息保持一致 +- age必须是有效的整数,gender必须是"male"或"female" +""" def _build_group_persona_prompt( self, @@ -731,17 +756,18 @@ class OasisProfileGenerator: - 立场态度(对核心话题的官方立场、面对争议的处理方式) - 特殊说明(代表的群体画像、运营习惯) - 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应) -3. age: null(机构不适用) -4. gender: null(机构不适用) -5. mbti: 可选,用于描述账号风格,如ISTJ代表严谨保守 -6. country: 国家 +3. age: 固定填30(机构账号的虚拟年龄) +4. gender: 固定填"other"(机构账号使用other表示非个人) +5. mbti: MBTI类型,用于描述账号风格,如ISTJ代表严谨保守 +6. country: 国家(使用中文,如"中国") 7. profession: 机构职能描述 8. interested_topics: 关注领域数组 重要: -- 所有字段值必须是字符串、数字或null +- 所有字段值必须是字符串或数字,不允许null值 - persona必须是一段连贯的文字描述,不要使用换行符 -- 使用中文 +- 使用中文(除了gender字段必须用英文"other") +- age必须是整数30,gender必须是字符串"other" - 机构账号发言要符合其身份定位""" def _generate_profile_rule_based( @@ -784,6 +810,10 @@ class OasisProfileGenerator: return { "bio": f"Official account for {entity_name}. News and updates.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. 
The account shares timely updates and engages with the audience on current events.", + "age": 30, # 机构虚拟年龄 + "gender": "other", # 机构使用other + "mbti": "ISTJ", # 机构风格:严谨保守 + "country": "中国", "profession": "Media", "interested_topics": ["General News", "Current Events", "Public Affairs"], } @@ -792,6 +822,10 @@ class OasisProfileGenerator: return { "bio": f"Official account of {entity_name}.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", + "age": 30, # 机构虚拟年龄 + "gender": "other", # 机构使用other + "mbti": "ISTJ", # 机构风格:严谨保守 + "country": "中国", "profession": entity_type, "interested_topics": ["Public Policy", "Community", "Official Announcements"], } @@ -1039,6 +1073,31 @@ class OasisProfileGenerator: logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)") + def _normalize_gender(self, gender: Optional[str]) -> str: + """ + 标准化gender字段为OASIS要求的英文格式 + + OASIS要求: male, female, other + """ + if not gender: + return "other" + + gender_lower = gender.lower().strip() + + # 中文映射 + gender_map = { + "男": "male", + "女": "female", + "机构": "other", + "其他": "other", + # 英文已有 + "male": "male", + "female": "female", + "other": "other", + } + + return gender_map.get(gender_lower, "other") + def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): """ 保存Reddit Profile为JSON格式 @@ -1048,26 +1107,30 @@ class OasisProfileGenerator: 2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics 我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致 + + OASIS要求所有字段都必须存在: + - age: 整数 + - gender: "male", "female", 或 "other" + - mbti: MBTI类型字符串 + - country: 国家字符串 """ data = [] for profile in profiles: # 使用详细格式(与用户示例兼容) + # 确保所有必需字段都有有效值 item = { "realname": profile.name, "username": profile.user_name, - "bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符 + "bio": profile.bio[:150] if profile.bio else f"{profile.name}", "persona": profile.persona or f"{profile.name} is a participant in social discussions.", + # OASIS必需字段 - 确保都有默认值 + "age": profile.age if profile.age else 30, + "gender": self._normalize_gender(profile.gender), + "mbti": profile.mbti if profile.mbti else "ISTJ", + "country": profile.country if profile.country else "中国", } - # 添加人设详情字段 - if profile.age: - item["age"] = profile.age - if profile.gender: - item["gender"] = profile.gender - if profile.mbti: - item["mbti"] = profile.mbti - if profile.country: - item["country"] = profile.country + # 可选字段 if profile.profession: item["profession"] = profile.profession if profile.interested_topics: @@ -1078,7 +1141,7 @@ class OasisProfileGenerator: with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) - logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)") + logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式,已标准化gender字段)") # 保留旧方法名作为别名,保持向后兼容 def save_profiles_to_json( diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index dcfaa96..d8d1e5c 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -127,12 +127,6 @@ class SimulationManager: '../../uploads/simulations' ) - # 预设脚本目录 - SCRIPTS_DIR = os.path.join( - os.path.dirname(__file__), - '../../scripts' - ) - def __init__(self): # 确保目录存在 os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) @@ -426,27 +420,8 @@ class SimulationManager: 
total=3 ) - # ========== 阶段4: 复制预设脚本 ========== - script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py", - "run_parallel_simulation.py", "action_logger.py"] - - if progress_callback: - progress_callback( - "copying_scripts", 0, - "开始准备脚本...", - current=0, - total=len(script_files) - ) - - self._copy_preset_scripts(sim_dir) - - if progress_callback: - progress_callback( - "copying_scripts", 100, - f"完成,共 {len(script_files)} 个脚本", - current=len(script_files), - total=len(script_files) - ) + # 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录 + # 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本 # 更新状态 state.status = SimulationStatus.READY @@ -466,24 +441,6 @@ class SimulationManager: self._save_simulation_state(state) raise - def _copy_preset_scripts(self, sim_dir: str): - """复制预设脚本到模拟目录""" - scripts = [ - "run_twitter_simulation.py", - "run_reddit_simulation.py", - "run_parallel_simulation.py" - ] - - for script in scripts: - src = os.path.join(self.SCRIPTS_DIR, script) - dst = os.path.join(sim_dir, script) - - if os.path.exists(src): - shutil.copy2(src, dst) - logger.debug(f"复制脚本: {script}") - else: - logger.warning(f"预设脚本不存在: {src}") - def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: """获取模拟状态""" return self._load_simulation_state(simulation_id) @@ -531,21 +488,22 @@ class SimulationManager: """获取运行说明""" sim_dir = self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") + scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) return { "simulation_dir": sim_dir, + "scripts_dir": scripts_dir, "config_file": config_path, "commands": { - "twitter": f"python run_twitter_simulation.py --config simulation_config.json", - "reddit": f"python run_reddit_simulation.py --config simulation_config.json", - "parallel": f"python run_parallel_simulation.py --config simulation_config.json", + "twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}", + "reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}", + "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}", }, "instructions": ( - f"1. 进入模拟目录: cd {sim_dir}\n" - f"2. 激活conda环境: conda activate MiroFish\n" - f"3. 运行模拟:\n" - f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n" - f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n" - f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json" + f"1. 激活conda环境: conda activate MiroFish\n" + f"2. 
运行模拟 (脚本位于 {scripts_dir}):\n" + f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" + f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" + f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" ) } diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 68d4b41..c77a083 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -182,11 +182,19 @@ class SimulationRunner: '../../uploads/simulations' ) + # 脚本目录 + SCRIPTS_DIR = os.path.join( + os.path.dirname(__file__), + '../../scripts' + ) + # 内存中的运行状态 _run_states: Dict[str, SimulationRunState] = {} _processes: Dict[str, subprocess.Popen] = {} _action_queues: Dict[str, Queue] = {} _monitor_threads: Dict[str, threading.Thread] = {} + _stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄 + _stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄 @classmethod def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: @@ -310,7 +318,7 @@ class SimulationRunner: cls._save_run_state(state) - # 确定运行哪个脚本 + # 确定运行哪个脚本(脚本位于 backend/scripts/ 目录) if platform == "twitter": script_name = "run_twitter_simulation.py" state.twitter_running = True @@ -322,7 +330,7 @@ class SimulationRunner: state.twitter_running = True state.reddit_running = True - script_path = os.path.join(sim_dir, script_name) + script_path = os.path.join(cls.SCRIPTS_DIR, script_name) if not os.path.exists(script_path): raise ValueError(f"脚本不存在: {script_path}") @@ -333,24 +341,36 @@ class SimulationRunner: # 启动模拟进程 try: - # 构建运行命令 + # 构建运行命令,使用完整路径 + action_log_path = os.path.join(sim_dir, "actions.jsonl") + cmd = [ sys.executable, # Python解释器 script_path, - "--config", "simulation_config.json", - "--action-log", "actions.jsonl", # 动作日志文件 + "--config", config_path, # 使用完整配置文件路径 + "--action-log", action_log_path, # 动作日志文件完整路径 ] - # 设置工作目录为模拟目录 + # 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞 + stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log") + stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log") + stdout_file = open(stdout_log_path, 'w', encoding='utf-8') + stderr_file = open(stderr_log_path, 'w', encoding='utf-8') + + # 设置工作目录为模拟目录(数据库等文件会生成在此) process = subprocess.Popen( cmd, cwd=sim_dir, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + stdout=stdout_file, + stderr=stderr_file, text=True, bufsize=1, ) + # 保存文件句柄以便后续关闭 + cls._stdout_files[simulation_id] = stdout_file + cls._stderr_files[simulation_id] = stderr_file + state.process_pid = process.pid state.runner_status = RunnerStatus.RUNNING cls._processes[simulation_id] = process @@ -434,8 +454,16 @@ class SimulationRunner: logger.info(f"模拟完成: {simulation_id}") else: state.runner_status = RunnerStatus.FAILED - stderr = process.stderr.read() if process.stderr else "" - state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}" + # 从 stderr 日志文件读取错误信息 + stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log") + stderr = "" + try: + if os.path.exists(stderr_log_path): + with open(stderr_log_path, 'r', encoding='utf-8') as f: + stderr = f.read() + except Exception: + pass + state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符 logger.error(f"模拟失败: {simulation_id}, error={state.error}") state.twitter_running = False @@ -449,9 +477,23 @@ class SimulationRunner: cls._save_run_state(state) finally: - # 清理 + # 清理进程资源 cls._processes.pop(simulation_id, None) 
cls._action_queues.pop(simulation_id, None) + + # 关闭日志文件句柄 + if simulation_id in cls._stdout_files: + try: + cls._stdout_files[simulation_id].close() + except Exception: + pass + cls._stdout_files.pop(simulation_id, None) + if simulation_id in cls._stderr_files: + try: + cls._stderr_files[simulation_id].close() + except Exception: + pass + cls._stderr_files.pop(simulation_id, None) @classmethod def stop_simulation(cls, simulation_id: str) -> SimulationRunState: diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index b2a3a3b..d165cd9 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -3,7 +3,8 @@ Zep实体读取与过滤服务 从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 """ -from typing import Dict, Any, List, Optional, Set +import time +from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field from zep_cloud.client import Zep @@ -13,6 +14,9 @@ from ..utils.logger import get_logger logger = get_logger('mirofish.zep_entity_reader') +# 用于泛型返回类型 +T = TypeVar('T') + @dataclass class EntityNode: @@ -80,9 +84,48 @@ class ZepEntityReader: self.client = Zep(api_key=self.api_key) + def _call_with_retry( + self, + func: Callable[[], T], + operation_name: str, + max_retries: int = 3, + initial_delay: float = 2.0 + ) -> T: + """ + 带重试机制的Zep API调用 + + Args: + func: 要执行的函数(无参数的lambda或callable) + operation_name: 操作名称,用于日志 + max_retries: 最大重试次数(默认3次,即最多尝试3次) + initial_delay: 初始延迟秒数 + + Returns: + API调用结果 + """ + last_exception = None + delay = initial_delay + + for attempt in range(max_retries): + try: + return func() + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + logger.warning( + f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, " + f"{delay:.1f}秒后重试..." 
+ ) + time.sleep(delay) + delay *= 2 # 指数退避 + else: + logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}") + + raise last_exception + def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ - 获取图谱的所有节点 + 获取图谱的所有节点(带重试机制) Args: graph_id: 图谱ID @@ -92,7 +135,11 @@ class ZepEntityReader: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id) + # 使用重试机制调用Zep API + nodes = self._call_with_retry( + func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id), + operation_name=f"获取节点(graph={graph_id})" + ) nodes_data = [] for node in nodes: @@ -109,7 +156,7 @@ class ZepEntityReader: def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ - 获取图谱的所有边 + 获取图谱的所有边(带重试机制) Args: graph_id: 图谱ID @@ -119,7 +166,11 @@ class ZepEntityReader: """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id) + # 使用重试机制调用Zep API + edges = self._call_with_retry( + func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id), + operation_name=f"获取边(graph={graph_id})" + ) edges_data = [] for edge in edges: @@ -137,7 +188,7 @@ class ZepEntityReader: def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: """ - 获取指定节点的所有相关边 + 获取指定节点的所有相关边(带重试机制) Args: node_uuid: 节点UUID @@ -146,7 +197,11 @@ class ZepEntityReader: 边列表 """ try: - edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid) + # 使用重试机制调用Zep API + edges = self._call_with_retry( + func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), + operation_name=f"获取节点边(node={node_uuid[:8]}...)" + ) edges_data = [] for edge in edges: @@ -288,7 +343,7 @@ class ZepEntityReader: entity_uuid: str ) -> Optional[EntityNode]: """ - 获取单个实体及其完整上下文(边和关联节点) + 获取单个实体及其完整上下文(边和关联节点,带重试机制) Args: graph_id: 图谱ID @@ -298,8 +353,11 @@ class ZepEntityReader: EntityNode或None """ try: - # 获取节点 - node = self.client.graph.node.get(uuid_=entity_uuid) + # 使用重试机制获取节点 + node = self._call_with_retry( + func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" + ) if not node: return None diff --git a/backend/scripts/run_parallel_simulation.py b/backend/scripts/run_parallel_simulation.py index b3f6c07..6172b25 100644 --- a/backend/scripts/run_parallel_simulation.py +++ b/backend/scripts/run_parallel_simulation.py @@ -9,13 +9,118 @@ OASIS 双平台并行模拟预设脚本 import argparse import asyncio import json +import logging import os import random import sys from datetime import datetime from typing import Dict, Any, List, Optional -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# 添加 backend 目录到路径 +# 脚本固定位于 backend/scripts/ 目录 +_scripts_dir = os.path.dirname(os.path.abspath(__file__)) +_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) +_project_root = os.path.abspath(os.path.join(_backend_dir, '..')) +sys.path.insert(0, _scripts_dir) +sys.path.insert(0, _backend_dir) + +# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) +from dotenv import load_dotenv +_env_file = os.path.join(_project_root, '.env') +if os.path.exists(_env_file): + load_dotenv(_env_file) + print(f"已加载环境配置: {_env_file}") +else: + # 尝试加载 backend/.env + _backend_env = os.path.join(_backend_dir, '.env') + if os.path.exists(_backend_env): + load_dotenv(_backend_env) + print(f"已加载环境配置: {_backend_env}") + + +class UnicodeFormatter(logging.Formatter): + """ + 自定义格式化器,将 Unicode 转义序列(如 \\uXXXX)转换为可读字符 + """ + + # 匹配 \uXXXX 形式的 Unicode 转义序列 + UNICODE_ESCAPE_PATTERN 
= None + + @classmethod + def _get_pattern(cls): + if cls.UNICODE_ESCAPE_PATTERN is None: + import re + cls.UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') + return cls.UNICODE_ESCAPE_PATTERN + + def format(self, record): + # 先获取原始格式化结果 + result = super().format(record) + # 使用正则表达式替换 Unicode 转义序列 + pattern = self._get_pattern() + + def replace_unicode(match): + try: + return chr(int(match.group(1), 16)) + except (ValueError, OverflowError): + return match.group(0) + + return pattern.sub(replace_unicode, result) + + +def setup_oasis_logging(log_dir: str): + """ + 配置 OASIS 的日志,覆盖默认的带时间戳日志文件 + + Args: + log_dir: 日志目录路径 + """ + os.makedirs(log_dir, exist_ok=True) + + # 清理旧的日志文件 + for f in os.listdir(log_dir): + old_log = os.path.join(log_dir, f) + if os.path.isfile(old_log) and f.endswith('.log'): + try: + os.remove(old_log) + except OSError: + pass + + # 创建自定义格式化器(支持 Unicode 解码) + formatter = UnicodeFormatter( + "%(levelname)s - %(asctime)s - %(name)s - %(message)s" + ) + + # 重新配置 OASIS 使用的日志器,使用固定名称(不带时间戳) + loggers_config = { + "social.agent": os.path.join(log_dir, "social.agent.log"), + "social.twitter": os.path.join(log_dir, "social.twitter.log"), + "social.rec": os.path.join(log_dir, "social.rec.log"), + "oasis.env": os.path.join(log_dir, "oasis.env.log"), + "table": os.path.join(log_dir, "table.log"), + } + + for logger_name, log_file in loggers_config.items(): + logger = logging.getLogger(logger_name) + logger.setLevel(logging.DEBUG) + # 清除 OASIS 添加的现有处理器(带时间戳的日志文件) + logger.handlers.clear() + # 添加新的文件处理器(使用 UTF-8 编码,固定文件名) + file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w') + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + # 防止日志向上传播(避免重复) + logger.propagate = False + + print(f"日志配置完成,日志目录: {log_dir}") + + +def init_logging_for_simulation(simulation_dir: str): + """初始化模拟的日志配置""" + log_dir = os.path.join(simulation_dir, "log") + setup_oasis_logging(log_dir) + from action_logger import ActionLogger @@ -74,17 +179,34 @@ def create_model(config: Dict[str, Any]): """ 创建LLM模型 - OASIS使用camel-ai的ModelFactory,配置方式: - - 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量 - - 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量 - """ - llm_model = config.get("llm_model", "gpt-4o-mini") - llm_base_url = config.get("llm_base_url", "") + 统一使用项目根目录 .env 文件中的配置(优先级最高): + - LLM_API_KEY: API密钥 + - LLM_BASE_URL: API基础URL + - LLM_MODEL_NAME: 模型名称 + + OASIS使用camel-ai的ModelFactory,需要设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量 + """ + # 优先从 .env 读取配置 + llm_api_key = os.environ.get("LLM_API_KEY", "") + llm_base_url = os.environ.get("LLM_BASE_URL", "") + llm_model = os.environ.get("LLM_MODEL_NAME", "") + + # 如果 .env 中没有,则使用 config 作为备用 + if not llm_model: + llm_model = config.get("llm_model", "gpt-4o-mini") + + # 设置 camel-ai 所需的环境变量 + if llm_api_key: + os.environ["OPENAI_API_KEY"] = llm_api_key + + if not os.environ.get("OPENAI_API_KEY"): + raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") - # 如果配置了base_url,设置环境变量 if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url + print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") + return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, @@ -453,6 +575,9 @@ async def main(): config = load_config(args.config) simulation_dir = os.path.dirname(args.config) or "." 
+ # 初始化日志配置(清理旧日志文件,使用固定名称) + init_logging_for_simulation(simulation_dir) + # 创建动作日志记录器 action_log_path = os.path.join(simulation_dir, args.action_log) action_logger = ActionLogger(action_log_path) diff --git a/backend/scripts/run_reddit_simulation.py b/backend/scripts/run_reddit_simulation.py index 59f7748..2081bd2 100644 --- a/backend/scripts/run_reddit_simulation.py +++ b/backend/scripts/run_reddit_simulation.py @@ -9,6 +9,7 @@ OASIS Reddit模拟预设脚本 import argparse import asyncio import json +import logging import os import random import sys @@ -16,7 +17,76 @@ from datetime import datetime from typing import Dict, Any, List # 添加项目路径 -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_scripts_dir = os.path.dirname(os.path.abspath(__file__)) +_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) +_project_root = os.path.abspath(os.path.join(_backend_dir, '..')) +sys.path.insert(0, _scripts_dir) +sys.path.insert(0, _backend_dir) + +# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) +from dotenv import load_dotenv +_env_file = os.path.join(_project_root, '.env') +if os.path.exists(_env_file): + load_dotenv(_env_file) +else: + _backend_env = os.path.join(_backend_dir, '.env') + if os.path.exists(_backend_env): + load_dotenv(_backend_env) + + +import re + + +class UnicodeFormatter(logging.Formatter): + """自定义格式化器,将 Unicode 转义序列转换为可读字符""" + + UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') + + def format(self, record): + result = super().format(record) + + def replace_unicode(match): + try: + return chr(int(match.group(1), 16)) + except (ValueError, OverflowError): + return match.group(0) + + return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result) + + +def setup_oasis_logging(log_dir: str): + """配置 OASIS 的日志,使用固定名称的日志文件""" + os.makedirs(log_dir, exist_ok=True) + + # 清理旧的日志文件 + for f in os.listdir(log_dir): + old_log = os.path.join(log_dir, f) + if os.path.isfile(old_log) and f.endswith('.log'): + try: + os.remove(old_log) + except OSError: + pass + + formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s") + + loggers_config = { + "social.agent": os.path.join(log_dir, "social.agent.log"), + "social.twitter": os.path.join(log_dir, "social.twitter.log"), + "social.rec": os.path.join(log_dir, "social.rec.log"), + "oasis.env": os.path.join(log_dir, "oasis.env.log"), + "table": os.path.join(log_dir, "table.log"), + } + + for logger_name, log_file in loggers_config.items(): + logger = logging.getLogger(logger_name) + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w') + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.propagate = False + try: from camel.models import ModelFactory @@ -82,19 +152,32 @@ class RedditSimulationRunner: """ 创建LLM模型 - OASIS使用camel-ai的ModelFactory,配置方式: - - 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量 - - 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量 + 统一使用项目根目录 .env 文件中的配置(优先级最高): + - LLM_API_KEY: API密钥 + - LLM_BASE_URL: API基础URL + - LLM_MODEL_NAME: 模型名称 """ - import os + # 优先从 .env 读取配置 + llm_api_key = os.environ.get("LLM_API_KEY", "") + llm_base_url = os.environ.get("LLM_BASE_URL", "") + llm_model = os.environ.get("LLM_MODEL_NAME", "") - llm_model = self.config.get("llm_model", "gpt-4o-mini") - llm_base_url = self.config.get("llm_base_url", "") + # 如果 .env 中没有,则使用 config 作为备用 + if not llm_model: + llm_model = self.config.get("llm_model", "gpt-4o-mini") + 
+ # 设置 camel-ai 所需的环境变量 + if llm_api_key: + os.environ["OPENAI_API_KEY"] = llm_api_key + + if not os.environ.get("OPENAI_API_KEY"): + raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") - # 如果配置了base_url,设置环境变量 if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url + print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") + return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, @@ -289,6 +372,10 @@ async def main(): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) + # 初始化日志配置(使用固定文件名,清理旧日志) + simulation_dir = os.path.dirname(args.config) or "." + setup_oasis_logging(os.path.join(simulation_dir, "log")) + runner = RedditSimulationRunner(args.config) await runner.run() diff --git a/backend/scripts/run_twitter_simulation.py b/backend/scripts/run_twitter_simulation.py index c5f8966..470ad6c 100644 --- a/backend/scripts/run_twitter_simulation.py +++ b/backend/scripts/run_twitter_simulation.py @@ -9,6 +9,7 @@ OASIS Twitter模拟预设脚本 import argparse import asyncio import json +import logging import os import random import sys @@ -16,7 +17,76 @@ from datetime import datetime from typing import Dict, Any, List # 添加项目路径 -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_scripts_dir = os.path.dirname(os.path.abspath(__file__)) +_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) +_project_root = os.path.abspath(os.path.join(_backend_dir, '..')) +sys.path.insert(0, _scripts_dir) +sys.path.insert(0, _backend_dir) + +# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) +from dotenv import load_dotenv +_env_file = os.path.join(_project_root, '.env') +if os.path.exists(_env_file): + load_dotenv(_env_file) +else: + _backend_env = os.path.join(_backend_dir, '.env') + if os.path.exists(_backend_env): + load_dotenv(_backend_env) + + +import re + + +class UnicodeFormatter(logging.Formatter): + """自定义格式化器,将 Unicode 转义序列转换为可读字符""" + + UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') + + def format(self, record): + result = super().format(record) + + def replace_unicode(match): + try: + return chr(int(match.group(1), 16)) + except (ValueError, OverflowError): + return match.group(0) + + return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result) + + +def setup_oasis_logging(log_dir: str): + """配置 OASIS 的日志,使用固定名称的日志文件""" + os.makedirs(log_dir, exist_ok=True) + + # 清理旧的日志文件 + for f in os.listdir(log_dir): + old_log = os.path.join(log_dir, f) + if os.path.isfile(old_log) and f.endswith('.log'): + try: + os.remove(old_log) + except OSError: + pass + + formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s") + + loggers_config = { + "social.agent": os.path.join(log_dir, "social.agent.log"), + "social.twitter": os.path.join(log_dir, "social.twitter.log"), + "social.rec": os.path.join(log_dir, "social.rec.log"), + "oasis.env": os.path.join(log_dir, "oasis.env.log"), + "table": os.path.join(log_dir, "table.log"), + } + + for logger_name, log_file in loggers_config.items(): + logger = logging.getLogger(logger_name) + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w') + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.propagate = False + try: from camel.models import ModelFactory @@ -75,21 +145,32 @@ class TwitterSimulationRunner: """ 创建LLM模型 - OASIS使用camel-ai的ModelFactory,配置方式: - - 标准OpenAI: 只需设置 
OPENAI_API_KEY 环境变量 - - 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量 - - 配置文件中的 llm_model 对应 model_type + 统一使用项目根目录 .env 文件中的配置(优先级最高): + - LLM_API_KEY: API密钥 + - LLM_BASE_URL: API基础URL + - LLM_MODEL_NAME: 模型名称 """ - import os + # 优先从 .env 读取配置 + llm_api_key = os.environ.get("LLM_API_KEY", "") + llm_base_url = os.environ.get("LLM_BASE_URL", "") + llm_model = os.environ.get("LLM_MODEL_NAME", "") - llm_model = self.config.get("llm_model", "gpt-4o-mini") - llm_base_url = self.config.get("llm_base_url", "") + # 如果 .env 中没有,则使用 config 作为备用 + if not llm_model: + llm_model = self.config.get("llm_model", "gpt-4o-mini") + + # 设置 camel-ai 所需的环境变量 + if llm_api_key: + os.environ["OPENAI_API_KEY"] = llm_api_key + + if not os.environ.get("OPENAI_API_KEY"): + raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") - # 如果配置了base_url,设置环境变量(OASIS通过环境变量读取) if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url + print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") + return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, @@ -304,6 +385,10 @@ async def main(): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) + # 初始化日志配置(使用固定文件名,清理旧日志) + simulation_dir = os.path.dirname(args.config) or "." + setup_oasis_logging(os.path.join(simulation_dir, "log")) + runner = TwitterSimulationRunner(args.config) await runner.run()
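
The exponential-backoff retry pattern that this patch adds to ZepEntityReader._call_with_retry (and mirrors inside OasisProfileGenerator's edge/node searches) can also be exercised on its own. Below is a minimal standalone sketch of the same pattern, assuming only the Python standard library; the commented-out usage at the bottom is hypothetical and simply wraps a Zep client call in a zero-argument lambda, as the patched code does.

import time
import logging
from typing import Callable, TypeVar

T = TypeVar('T')
logger = logging.getLogger("zep.retry")


def call_with_retry(func: Callable[[], T], operation_name: str,
                    max_retries: int = 3, initial_delay: float = 2.0) -> T:
    """Retry func() up to max_retries times with exponential backoff."""
    last_exception = None
    delay = initial_delay
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            last_exception = e
            if attempt < max_retries - 1:
                # Log a truncated error message, then back off before the next attempt
                logger.warning("%s attempt %d failed: %s, retrying in %.1fs",
                               operation_name, attempt + 1, str(e)[:100], delay)
                time.sleep(delay)
                delay *= 2  # exponential backoff: 2s, 4s, 8s, ...
            else:
                logger.error("%s failed after %d attempts", operation_name, max_retries)
    raise last_exception


# Hypothetical usage (mirrors get_all_nodes in zep_entity_reader.py):
# nodes = call_with_retry(
#     func=lambda: client.graph.node.get_by_graph_id(graph_id=graph_id),
#     operation_name="get nodes",
# )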