Refactor simulation management and enhance logging capabilities

- Updated simulation preparation checks to exclude script files from the required files list, improving clarity on file management.
- Implemented a robust retry mechanism for Zep API calls in the ZepEntityReader service, enhancing reliability.
- Enhanced logging in simulation scripts to provide clearer insights into the simulation process and errors.
- Updated simulation runner to manage stdout and stderr logs more effectively, ensuring better error tracking.
- Improved profile generation to standardize gender fields and ensure all required fields are populated correctly.
This commit is contained in:
666ghj 2025-12-02 14:25:53 +08:00
parent af5c235695
commit 3cc5e3f479
8 changed files with 595 additions and 165 deletions

View file

@ -221,6 +221,8 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
1. state.json 存在且 status "ready" 1. state.json 存在且 status "ready"
2. 必要文件存在reddit_profiles.json, twitter_profiles.csv, simulation_config.json 2. 必要文件存在reddit_profiles.json, twitter_profiles.csv, simulation_config.json
注意运行脚本(run_*.py)保留在 backend/scripts/ 目录不再复制到模拟目录
Args: Args:
simulation_id: 模拟ID simulation_id: 模拟ID
@ -236,15 +238,12 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
if not os.path.exists(simulation_dir): if not os.path.exists(simulation_dir):
return False, {"reason": "模拟目录不存在"} return False, {"reason": "模拟目录不存在"}
# 必要文件列表 # 必要文件列表(不包括脚本,脚本位于 backend/scripts/
required_files = [ required_files = [
"state.json", "state.json",
"simulation_config.json", "simulation_config.json",
"reddit_profiles.json", "reddit_profiles.json",
"twitter_profiles.csv", "twitter_profiles.csv"
"run_reddit_simulation.py",
"run_twitter_simulation.py",
"run_parallel_simulation.py"
] ]
# 检查文件是否存在 # 检查文件是否存在
@ -272,9 +271,13 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
state_data = json.load(f) state_data = json.load(f)
status = state_data.get("status", "") status = state_data.get("status", "")
config_generated = state_data.get("config_generated", False)
# 详细日志
logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}")
# 如果状态是ready或preparing已有文件认为准备完成 # 如果状态是ready或preparing已有文件认为准备完成
if status in ["ready", "preparing"] and state_data.get("config_generated"): if status in ["ready", "preparing"] and config_generated:
# 获取文件统计信息 # 获取文件统计信息
profiles_file = os.path.join(simulation_dir, "reddit_profiles.json") profiles_file = os.path.join(simulation_dir, "reddit_profiles.json")
config_file = os.path.join(simulation_dir, "simulation_config.json") config_file = os.path.join(simulation_dir, "simulation_config.json")
@ -298,21 +301,23 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
except Exception as e: except Exception as e:
logger.warning(f"自动更新状态失败: {e}") logger.warning(f"自动更新状态失败: {e}")
logger.info(f"模拟 {simulation_id} 检测结果: 已准备完成 (status={status}, config_generated={config_generated})")
return True, { return True, {
"status": status, "status": status,
"entities_count": state_data.get("entities_count", 0), "entities_count": state_data.get("entities_count", 0),
"profiles_count": profiles_count, "profiles_count": profiles_count,
"entity_types": state_data.get("entity_types", []), "entity_types": state_data.get("entity_types", []),
"config_generated": state_data.get("config_generated", False), "config_generated": config_generated,
"created_at": state_data.get("created_at"), "created_at": state_data.get("created_at"),
"updated_at": state_data.get("updated_at"), "updated_at": state_data.get("updated_at"),
"existing_files": existing_files "existing_files": existing_files
} }
else: else:
logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})")
return False, { return False, {
"reason": f"状态不是ready: {status}", "reason": f"状态不是ready或config_generated为false: status={status}, config_generated={config_generated}",
"status": status, "status": status,
"config_generated": state_data.get("config_generated", False) "config_generated": config_generated
} }
except Exception as e: except Exception as e:
@ -386,10 +391,13 @@ def prepare_simulation():
# 检查是否强制重新生成 # 检查是否强制重新生成
force_regenerate = data.get('force_regenerate', False) force_regenerate = data.get('force_regenerate', False)
logger.info(f"开始处理 /prepare 请求: simulation_id={simulation_id}, force_regenerate={force_regenerate}")
# 检查是否已经准备完成(避免重复生成) # 检查是否已经准备完成(避免重复生成)
if not force_regenerate: if not force_regenerate:
logger.debug(f"检查模拟 {simulation_id} 是否已准备完成...")
is_prepared, prepare_info = _check_simulation_prepared(simulation_id) is_prepared, prepare_info = _check_simulation_prepared(simulation_id)
logger.debug(f"检查结果: is_prepared={is_prepared}, prepare_info={prepare_info}")
if is_prepared: if is_prepared:
logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成") logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成")
return jsonify({ return jsonify({
@ -402,6 +410,8 @@ def prepare_simulation():
"prepare_info": prepare_info "prepare_info": prepare_info
} }
}) })
else:
logger.info(f"模拟 {simulation_id} 未准备完成,将启动准备任务")
# 从项目获取必要信息 # 从项目获取必要信息
project = ProjectManager.get_project(state.project_id) project = ProjectManager.get_project(state.project_id)
@ -850,25 +860,27 @@ def download_simulation_config(simulation_id: str):
}), 500 }), 500
@simulation_bp.route('/<simulation_id>/script/<script_name>/download', methods=['GET']) @simulation_bp.route('/script/<script_name>/download', methods=['GET'])
def download_simulation_script(simulation_id: str, script_name: str): def download_simulation_script(script_name: str):
""" """
下载模拟脚本文件 下载模拟运行脚本文件通用脚本位于 backend/scripts/
script_name可选值 script_name可选值
- run_twitter_simulation.py - run_twitter_simulation.py
- run_reddit_simulation.py - run_reddit_simulation.py
- run_parallel_simulation.py - run_parallel_simulation.py
- action_logger.py
""" """
try: try:
manager = SimulationManager() # 脚本位于 backend/scripts/ 目录
sim_dir = manager._get_simulation_dir(simulation_id) scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
# 验证脚本名称 # 验证脚本名称
allowed_scripts = [ allowed_scripts = [
"run_twitter_simulation.py", "run_twitter_simulation.py",
"run_reddit_simulation.py", "run_reddit_simulation.py",
"run_parallel_simulation.py" "run_parallel_simulation.py",
"action_logger.py"
] ]
if script_name not in allowed_scripts: if script_name not in allowed_scripts:
@ -877,12 +889,12 @@ def download_simulation_script(simulation_id: str, script_name: str):
"error": f"未知脚本: {script_name},可选: {allowed_scripts}" "error": f"未知脚本: {script_name},可选: {allowed_scripts}"
}), 400 }), 400
script_path = os.path.join(sim_dir, script_name) script_path = os.path.join(scripts_dir, script_name)
if not os.path.exists(script_path): if not os.path.exists(script_path):
return jsonify({ return jsonify({
"success": False, "success": False,
"error": "脚本文件不存在,请先调用 /prepare 接口" "error": f"脚本文件不存在: {script_name}"
}), 404 }), 404
return send_file( return send_file(

View file

@ -10,6 +10,7 @@ OASIS Agent Profile生成器
import json import json
import random import random
import time
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
@ -315,32 +316,54 @@ class OasisProfileGenerator:
comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景" comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景"
def search_edges(): def search_edges():
"""搜索边(事实/关系)""" """搜索边(事实/关系)- 带重试机制"""
try: max_retries = 3
return self.zep_client.graph.search( last_exception = None
query=comprehensive_query, delay = 2.0
graph_id=self.graph_id,
limit=30, for attempt in range(max_retries):
scope="edges", try:
reranker="rrf" return self.zep_client.graph.search(
) query=comprehensive_query,
except Exception as e: graph_id=self.graph_id,
logger.debug(f"Zep边搜索失败: {e}") limit=30,
return None scope="edges",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
def search_nodes(): def search_nodes():
"""搜索节点(实体摘要)""" """搜索节点(实体摘要)- 带重试机制"""
try: max_retries = 3
return self.zep_client.graph.search( last_exception = None
query=comprehensive_query, delay = 2.0
graph_id=self.graph_id,
limit=20, for attempt in range(max_retries):
scope="nodes", try:
reranker="rrf" return self.zep_client.graph.search(
) query=comprehensive_query,
except Exception as e: graph_id=self.graph_id,
logger.debug(f"Zep节点搜索失败: {e}") limit=20,
return None scope="nodes",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
try: try:
# 并行执行edges和nodes搜索 # 并行执行edges和nodes搜索
@ -684,18 +707,20 @@ class OasisProfileGenerator:
- 立场观点对话题的态度可能被激怒/感动的内容 - 立场观点对话题的态度可能被激怒/感动的内容
- 独特特征口头禅特殊经历个人爱好 - 独特特征口头禅特殊经历个人爱好
- 个人记忆人设的重要部分要介绍这个个体与事件的关联以及这个个体在事件中的已有动作与反应 - 个人记忆人设的重要部分要介绍这个个体与事件的关联以及这个个体在事件中的已有动作与反应
3. age: 年龄数字 3. age: 年龄数字必须是整数
4. gender: 性别/ 4. gender: 性别必须是英文: "male" "female"
5. mbti: MBTI类型 5. mbti: MBTI类型如INTJENFP等
6. country: 国家 6. country: 国家使用中文"中国"
7. profession: 职业 7. profession: 职业
8. interested_topics: 感兴趣话题数组 8. interested_topics: 感兴趣话题数组
重要: 重要:
- 所有字段值必须是字符串或数字不要使用换行符 - 所有字段值必须是字符串或数字不要使用换行符
- persona必须是一段连贯的文字描述 - persona必须是一段连贯的文字描述
- 使用中文 - 使用中文除了gender字段必须用英文male/female
- 内容要与实体信息保持一致""" - 内容要与实体信息保持一致
- age必须是有效的整数gender必须是"male""female"
"""
def _build_group_persona_prompt( def _build_group_persona_prompt(
self, self,
@ -731,17 +756,18 @@ class OasisProfileGenerator:
- 立场态度对核心话题的官方立场面对争议的处理方式 - 立场态度对核心话题的官方立场面对争议的处理方式
- 特殊说明代表的群体画像运营习惯 - 特殊说明代表的群体画像运营习惯
- 机构记忆机构人设的重要部分要介绍这个机构与事件的关联以及这个机构在事件中的已有动作与反应 - 机构记忆机构人设的重要部分要介绍这个机构与事件的关联以及这个机构在事件中的已有动作与反应
3. age: null机构不适用 3. age: 固定填30机构账号的虚拟年龄
4. gender: null机构不适用 4. gender: 固定填"other"机构账号使用other表示非个人
5. mbti: 可选用于描述账号风格如ISTJ代表严谨保守 5. mbti: MBTI类型用于描述账号风格如ISTJ代表严谨保守
6. country: 国家 6. country: 国家使用中文"中国"
7. profession: 机构职能描述 7. profession: 机构职能描述
8. interested_topics: 关注领域数组 8. interested_topics: 关注领域数组
重要: 重要:
- 所有字段值必须是字符串数字或null - 所有字段值必须是字符串或数字不允许null值
- persona必须是一段连贯的文字描述不要使用换行符 - persona必须是一段连贯的文字描述不要使用换行符
- 使用中文 - 使用中文除了gender字段必须用英文"other"
- age必须是整数30gender必须是字符串"other"
- 机构账号发言要符合其身份定位""" - 机构账号发言要符合其身份定位"""
def _generate_profile_rule_based( def _generate_profile_rule_based(
@ -784,6 +810,10 @@ class OasisProfileGenerator:
return { return {
"bio": f"Official account for {entity_name}. News and updates.", "bio": f"Official account for {entity_name}. News and updates.",
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
"age": 30, # 机构虚拟年龄
"gender": "other", # 机构使用other
"mbti": "ISTJ", # 机构风格:严谨保守
"country": "中国",
"profession": "Media", "profession": "Media",
"interested_topics": ["General News", "Current Events", "Public Affairs"], "interested_topics": ["General News", "Current Events", "Public Affairs"],
} }
@ -792,6 +822,10 @@ class OasisProfileGenerator:
return { return {
"bio": f"Official account of {entity_name}.", "bio": f"Official account of {entity_name}.",
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
"age": 30, # 机构虚拟年龄
"gender": "other", # 机构使用other
"mbti": "ISTJ", # 机构风格:严谨保守
"country": "中国",
"profession": entity_type, "profession": entity_type,
"interested_topics": ["Public Policy", "Community", "Official Announcements"], "interested_topics": ["Public Policy", "Community", "Official Announcements"],
} }
@ -1039,6 +1073,31 @@ class OasisProfileGenerator:
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)") logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)")
def _normalize_gender(self, gender: Optional[str]) -> str:
"""
标准化gender字段为OASIS要求的英文格式
OASIS要求: male, female, other
"""
if not gender:
return "other"
gender_lower = gender.lower().strip()
# 中文映射
gender_map = {
"": "male",
"": "female",
"机构": "other",
"其他": "other",
# 英文已有
"male": "male",
"female": "female",
"other": "other",
}
return gender_map.get(gender_lower, "other")
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
""" """
保存Reddit Profile为JSON格式 保存Reddit Profile为JSON格式
@ -1048,26 +1107,30 @@ class OasisProfileGenerator:
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics 2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
我们使用详细格式与用户示例数据(36个简单人设.json)保持一致 我们使用详细格式与用户示例数据(36个简单人设.json)保持一致
OASIS要求所有字段都必须存在
- age: 整数
- gender: "male", "female", "other"
- mbti: MBTI类型字符串
- country: 国家字符串
""" """
data = [] data = []
for profile in profiles: for profile in profiles:
# 使用详细格式(与用户示例兼容) # 使用详细格式(与用户示例兼容)
# 确保所有必需字段都有有效值
item = { item = {
"realname": profile.name, "realname": profile.name,
"username": profile.user_name, "username": profile.user_name,
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符 "bio": profile.bio[:150] if profile.bio else f"{profile.name}",
"persona": profile.persona or f"{profile.name} is a participant in social discussions.", "persona": profile.persona or f"{profile.name} is a participant in social discussions.",
# OASIS必需字段 - 确保都有默认值
"age": profile.age if profile.age else 30,
"gender": self._normalize_gender(profile.gender),
"mbti": profile.mbti if profile.mbti else "ISTJ",
"country": profile.country if profile.country else "中国",
} }
# 添加人设详情字段 # 可选字段
if profile.age:
item["age"] = profile.age
if profile.gender:
item["gender"] = profile.gender
if profile.mbti:
item["mbti"] = profile.mbti
if profile.country:
item["country"] = profile.country
if profile.profession: if profile.profession:
item["profession"] = profile.profession item["profession"] = profile.profession
if profile.interested_topics: if profile.interested_topics:
@ -1078,7 +1141,7 @@ class OasisProfileGenerator:
with open(file_path, 'w', encoding='utf-8') as f: with open(file_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)") logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式已标准化gender字段)")
# 保留旧方法名作为别名,保持向后兼容 # 保留旧方法名作为别名,保持向后兼容
def save_profiles_to_json( def save_profiles_to_json(

View file

@ -127,12 +127,6 @@ class SimulationManager:
'../../uploads/simulations' '../../uploads/simulations'
) )
# 预设脚本目录
SCRIPTS_DIR = os.path.join(
os.path.dirname(__file__),
'../../scripts'
)
def __init__(self): def __init__(self):
# 确保目录存在 # 确保目录存在
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
@ -426,27 +420,8 @@ class SimulationManager:
total=3 total=3
) )
# ========== 阶段4: 复制预设脚本 ========== # 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py", # 启动模拟时simulation_runner 会从 scripts/ 目录运行脚本
"run_parallel_simulation.py", "action_logger.py"]
if progress_callback:
progress_callback(
"copying_scripts", 0,
"开始准备脚本...",
current=0,
total=len(script_files)
)
self._copy_preset_scripts(sim_dir)
if progress_callback:
progress_callback(
"copying_scripts", 100,
f"完成,共 {len(script_files)} 个脚本",
current=len(script_files),
total=len(script_files)
)
# 更新状态 # 更新状态
state.status = SimulationStatus.READY state.status = SimulationStatus.READY
@ -466,24 +441,6 @@ class SimulationManager:
self._save_simulation_state(state) self._save_simulation_state(state)
raise raise
def _copy_preset_scripts(self, sim_dir: str):
"""复制预设脚本到模拟目录"""
scripts = [
"run_twitter_simulation.py",
"run_reddit_simulation.py",
"run_parallel_simulation.py"
]
for script in scripts:
src = os.path.join(self.SCRIPTS_DIR, script)
dst = os.path.join(sim_dir, script)
if os.path.exists(src):
shutil.copy2(src, dst)
logger.debug(f"复制脚本: {script}")
else:
logger.warning(f"预设脚本不存在: {src}")
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
"""获取模拟状态""" """获取模拟状态"""
return self._load_simulation_state(simulation_id) return self._load_simulation_state(simulation_id)
@ -531,21 +488,22 @@ class SimulationManager:
"""获取运行说明""" """获取运行说明"""
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
return { return {
"simulation_dir": sim_dir, "simulation_dir": sim_dir,
"scripts_dir": scripts_dir,
"config_file": config_path, "config_file": config_path,
"commands": { "commands": {
"twitter": f"python run_twitter_simulation.py --config simulation_config.json", "twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
"reddit": f"python run_reddit_simulation.py --config simulation_config.json", "reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
"parallel": f"python run_parallel_simulation.py --config simulation_config.json", "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
}, },
"instructions": ( "instructions": (
f"1. 进入模拟目录: cd {sim_dir}\n" f"1. 激活conda环境: conda activate MiroFish\n"
f"2. 激活conda环境: conda activate MiroFish\n" f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
f"3. 运行模拟:\n" f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n" f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n" f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
) )
} }

View file

@ -182,11 +182,19 @@ class SimulationRunner:
'../../uploads/simulations' '../../uploads/simulations'
) )
# 脚本目录
SCRIPTS_DIR = os.path.join(
os.path.dirname(__file__),
'../../scripts'
)
# 内存中的运行状态 # 内存中的运行状态
_run_states: Dict[str, SimulationRunState] = {} _run_states: Dict[str, SimulationRunState] = {}
_processes: Dict[str, subprocess.Popen] = {} _processes: Dict[str, subprocess.Popen] = {}
_action_queues: Dict[str, Queue] = {} _action_queues: Dict[str, Queue] = {}
_monitor_threads: Dict[str, threading.Thread] = {} _monitor_threads: Dict[str, threading.Thread] = {}
_stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄
_stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄
@classmethod @classmethod
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
@ -310,7 +318,7 @@ class SimulationRunner:
cls._save_run_state(state) cls._save_run_state(state)
# 确定运行哪个脚本 # 确定运行哪个脚本(脚本位于 backend/scripts/ 目录)
if platform == "twitter": if platform == "twitter":
script_name = "run_twitter_simulation.py" script_name = "run_twitter_simulation.py"
state.twitter_running = True state.twitter_running = True
@ -322,7 +330,7 @@ class SimulationRunner:
state.twitter_running = True state.twitter_running = True
state.reddit_running = True state.reddit_running = True
script_path = os.path.join(sim_dir, script_name) script_path = os.path.join(cls.SCRIPTS_DIR, script_name)
if not os.path.exists(script_path): if not os.path.exists(script_path):
raise ValueError(f"脚本不存在: {script_path}") raise ValueError(f"脚本不存在: {script_path}")
@ -333,24 +341,36 @@ class SimulationRunner:
# 启动模拟进程 # 启动模拟进程
try: try:
# 构建运行命令 # 构建运行命令,使用完整路径
action_log_path = os.path.join(sim_dir, "actions.jsonl")
cmd = [ cmd = [
sys.executable, # Python解释器 sys.executable, # Python解释器
script_path, script_path,
"--config", "simulation_config.json", "--config", config_path, # 使用完整配置文件路径
"--action-log", "actions.jsonl", # 动作日志文件 "--action-log", action_log_path, # 动作日志文件完整路径
] ]
# 设置工作目录为模拟目录 # 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
# 设置工作目录为模拟目录(数据库等文件会生成在此)
process = subprocess.Popen( process = subprocess.Popen(
cmd, cmd,
cwd=sim_dir, cwd=sim_dir,
stdout=subprocess.PIPE, stdout=stdout_file,
stderr=subprocess.PIPE, stderr=stderr_file,
text=True, text=True,
bufsize=1, bufsize=1,
) )
# 保存文件句柄以便后续关闭
cls._stdout_files[simulation_id] = stdout_file
cls._stderr_files[simulation_id] = stderr_file
state.process_pid = process.pid state.process_pid = process.pid
state.runner_status = RunnerStatus.RUNNING state.runner_status = RunnerStatus.RUNNING
cls._processes[simulation_id] = process cls._processes[simulation_id] = process
@ -434,8 +454,16 @@ class SimulationRunner:
logger.info(f"模拟完成: {simulation_id}") logger.info(f"模拟完成: {simulation_id}")
else: else:
state.runner_status = RunnerStatus.FAILED state.runner_status = RunnerStatus.FAILED
stderr = process.stderr.read() if process.stderr else "" # 从 stderr 日志文件读取错误信息
state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}" stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stderr = ""
try:
if os.path.exists(stderr_log_path):
with open(stderr_log_path, 'r', encoding='utf-8') as f:
stderr = f.read()
except Exception:
pass
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
logger.error(f"模拟失败: {simulation_id}, error={state.error}") logger.error(f"模拟失败: {simulation_id}, error={state.error}")
state.twitter_running = False state.twitter_running = False
@ -449,10 +477,24 @@ class SimulationRunner:
cls._save_run_state(state) cls._save_run_state(state)
finally: finally:
# 清理 # 清理进程资源
cls._processes.pop(simulation_id, None) cls._processes.pop(simulation_id, None)
cls._action_queues.pop(simulation_id, None) cls._action_queues.pop(simulation_id, None)
# 关闭日志文件句柄
if simulation_id in cls._stdout_files:
try:
cls._stdout_files[simulation_id].close()
except Exception:
pass
cls._stdout_files.pop(simulation_id, None)
if simulation_id in cls._stderr_files:
try:
cls._stderr_files[simulation_id].close()
except Exception:
pass
cls._stderr_files.pop(simulation_id, None)
@classmethod @classmethod
def stop_simulation(cls, simulation_id: str) -> SimulationRunState: def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
"""停止模拟""" """停止模拟"""

View file

@ -3,7 +3,8 @@ Zep实体读取与过滤服务
从Zep图谱中读取节点筛选出符合预定义实体类型的节点 从Zep图谱中读取节点筛选出符合预定义实体类型的节点
""" """
from typing import Dict, Any, List, Optional, Set import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from zep_cloud.client import Zep from zep_cloud.client import Zep
@ -13,6 +14,9 @@ from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_entity_reader') logger = get_logger('mirofish.zep_entity_reader')
# 用于泛型返回类型
T = TypeVar('T')
@dataclass @dataclass
class EntityNode: class EntityNode:
@ -80,9 +84,48 @@ class ZepEntityReader:
self.client = Zep(api_key=self.api_key) self.client = Zep(api_key=self.api_key)
def _call_with_retry(
self,
func: Callable[[], T],
operation_name: str,
max_retries: int = 3,
initial_delay: float = 2.0
) -> T:
"""
带重试机制的Zep API调用
Args:
func: 要执行的函数无参数的lambda或callable
operation_name: 操作名称用于日志
max_retries: 最大重试次数默认3次即最多尝试3次
initial_delay: 初始延迟秒数
Returns:
API调用结果
"""
last_exception = None
delay = initial_delay
for attempt in range(max_retries):
try:
return func()
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.warning(
f"Zep {operation_name}{attempt + 1} 次尝试失败: {str(e)[:100]}, "
f"{delay:.1f}秒后重试..."
)
time.sleep(delay)
delay *= 2 # 指数退避
else:
logger.error(f"Zep {operation_name}{max_retries} 次尝试后仍失败: {str(e)}")
raise last_exception
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
""" """
获取图谱的所有节点 获取图谱的所有节点带重试机制
Args: Args:
graph_id: 图谱ID graph_id: 图谱ID
@ -92,7 +135,11 @@ class ZepEntityReader:
""" """
logger.info(f"获取图谱 {graph_id} 的所有节点...") logger.info(f"获取图谱 {graph_id} 的所有节点...")
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id) # 使用重试机制调用Zep API
nodes = self._call_with_retry(
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取节点(graph={graph_id})"
)
nodes_data = [] nodes_data = []
for node in nodes: for node in nodes:
@ -109,7 +156,7 @@ class ZepEntityReader:
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
""" """
获取图谱的所有边 获取图谱的所有边带重试机制
Args: Args:
graph_id: 图谱ID graph_id: 图谱ID
@ -119,7 +166,11 @@ class ZepEntityReader:
""" """
logger.info(f"获取图谱 {graph_id} 的所有边...") logger.info(f"获取图谱 {graph_id} 的所有边...")
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id) # 使用重试机制调用Zep API
edges = self._call_with_retry(
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取边(graph={graph_id})"
)
edges_data = [] edges_data = []
for edge in edges: for edge in edges:
@ -137,7 +188,7 @@ class ZepEntityReader:
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
""" """
获取指定节点的所有相关边 获取指定节点的所有相关边带重试机制
Args: Args:
node_uuid: 节点UUID node_uuid: 节点UUID
@ -146,7 +197,11 @@ class ZepEntityReader:
边列表 边列表
""" """
try: try:
edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid) # 使用重试机制调用Zep API
edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
)
edges_data = [] edges_data = []
for edge in edges: for edge in edges:
@ -288,7 +343,7 @@ class ZepEntityReader:
entity_uuid: str entity_uuid: str
) -> Optional[EntityNode]: ) -> Optional[EntityNode]:
""" """
获取单个实体及其完整上下文边和关联节点 获取单个实体及其完整上下文边和关联节点带重试机制
Args: Args:
graph_id: 图谱ID graph_id: 图谱ID
@ -298,8 +353,11 @@ class ZepEntityReader:
EntityNode或None EntityNode或None
""" """
try: try:
# 获取节点 # 使用重试机制获取节点
node = self.client.graph.node.get(uuid_=entity_uuid) node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
)
if not node: if not node:
return None return None

View file

@ -9,13 +9,118 @@ OASIS 双平台并行模拟预设脚本
import argparse import argparse
import asyncio import asyncio
import json import json
import logging
import os import os
import random import random
import sys import sys
from datetime import datetime from datetime import datetime
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 添加 backend 目录到路径
# 脚本固定位于 backend/scripts/ 目录
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置)
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
load_dotenv(_env_file)
print(f"已加载环境配置: {_env_file}")
else:
# 尝试加载 backend/.env
_backend_env = os.path.join(_backend_dir, '.env')
if os.path.exists(_backend_env):
load_dotenv(_backend_env)
print(f"已加载环境配置: {_backend_env}")
class UnicodeFormatter(logging.Formatter):
"""
自定义格式化器 Unicode 转义序列 \\uXXXX转换为可读字符
"""
# 匹配 \uXXXX 形式的 Unicode 转义序列
UNICODE_ESCAPE_PATTERN = None
@classmethod
def _get_pattern(cls):
if cls.UNICODE_ESCAPE_PATTERN is None:
import re
cls.UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
return cls.UNICODE_ESCAPE_PATTERN
def format(self, record):
# 先获取原始格式化结果
result = super().format(record)
# 使用正则表达式替换 Unicode 转义序列
pattern = self._get_pattern()
def replace_unicode(match):
try:
return chr(int(match.group(1), 16))
except (ValueError, OverflowError):
return match.group(0)
return pattern.sub(replace_unicode, result)
def setup_oasis_logging(log_dir: str):
"""
配置 OASIS 的日志覆盖默认的带时间戳日志文件
Args:
log_dir: 日志目录路径
"""
os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件
for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'):
try:
os.remove(old_log)
except OSError:
pass
# 创建自定义格式化器(支持 Unicode 解码)
formatter = UnicodeFormatter(
"%(levelname)s - %(asctime)s - %(name)s - %(message)s"
)
# 重新配置 OASIS 使用的日志器,使用固定名称(不带时间戳)
loggers_config = {
"social.agent": os.path.join(log_dir, "social.agent.log"),
"social.twitter": os.path.join(log_dir, "social.twitter.log"),
"social.rec": os.path.join(log_dir, "social.rec.log"),
"oasis.env": os.path.join(log_dir, "oasis.env.log"),
"table": os.path.join(log_dir, "table.log"),
}
for logger_name, log_file in loggers_config.items():
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG)
# 清除 OASIS 添加的现有处理器(带时间戳的日志文件)
logger.handlers.clear()
# 添加新的文件处理器(使用 UTF-8 编码,固定文件名)
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 防止日志向上传播(避免重复)
logger.propagate = False
print(f"日志配置完成,日志目录: {log_dir}")
def init_logging_for_simulation(simulation_dir: str):
"""初始化模拟的日志配置"""
log_dir = os.path.join(simulation_dir, "log")
setup_oasis_logging(log_dir)
from action_logger import ActionLogger from action_logger import ActionLogger
@ -74,17 +179,34 @@ def create_model(config: Dict[str, Any]):
""" """
创建LLM模型 创建LLM模型
OASIS使用camel-ai的ModelFactory配置方式 统一使用项目根目录 .env 文件中的配置优先级最高
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量 - LLM_API_KEY: API密钥
- 自定义API: 设置 OPENAI_API_KEY OPENAI_API_BASE_URL 环境变量 - LLM_BASE_URL: API基础URL
""" - LLM_MODEL_NAME: 模型名称
llm_model = config.get("llm_model", "gpt-4o-mini")
llm_base_url = config.get("llm_base_url", "") OASIS使用camel-ai的ModelFactory需要设置 OPENAI_API_KEY OPENAI_API_BASE_URL 环境变量
"""
# 优先从 .env 读取配置
llm_api_key = os.environ.get("LLM_API_KEY", "")
llm_base_url = os.environ.get("LLM_BASE_URL", "")
llm_model = os.environ.get("LLM_MODEL_NAME", "")
# 如果 .env 中没有,则使用 config 作为备用
if not llm_model:
llm_model = config.get("llm_model", "gpt-4o-mini")
# 设置 camel-ai 所需的环境变量
if llm_api_key:
os.environ["OPENAI_API_KEY"] = llm_api_key
if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
# 如果配置了base_url设置环境变量
if llm_base_url: if llm_base_url:
os.environ["OPENAI_API_BASE_URL"] = llm_base_url os.environ["OPENAI_API_BASE_URL"] = llm_base_url
print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")
return ModelFactory.create( return ModelFactory.create(
model_platform=ModelPlatformType.OPENAI, model_platform=ModelPlatformType.OPENAI,
model_type=llm_model, model_type=llm_model,
@ -453,6 +575,9 @@ async def main():
config = load_config(args.config) config = load_config(args.config)
simulation_dir = os.path.dirname(args.config) or "." simulation_dir = os.path.dirname(args.config) or "."
# 初始化日志配置(清理旧日志文件,使用固定名称)
init_logging_for_simulation(simulation_dir)
# 创建动作日志记录器 # 创建动作日志记录器
action_log_path = os.path.join(simulation_dir, args.action_log) action_log_path = os.path.join(simulation_dir, args.action_log)
action_logger = ActionLogger(action_log_path) action_logger = ActionLogger(action_log_path)

View file

@ -9,6 +9,7 @@ OASIS Reddit模拟预设脚本
import argparse import argparse
import asyncio import asyncio
import json import json
import logging
import os import os
import random import random
import sys import sys
@ -16,7 +17,76 @@ from datetime import datetime
from typing import Dict, Any, List from typing import Dict, Any, List
# 添加项目路径 # 添加项目路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) _scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置)
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
load_dotenv(_env_file)
else:
_backend_env = os.path.join(_backend_dir, '.env')
if os.path.exists(_backend_env):
load_dotenv(_backend_env)
import re
class UnicodeFormatter(logging.Formatter):
"""自定义格式化器,将 Unicode 转义序列转换为可读字符"""
UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
def format(self, record):
result = super().format(record)
def replace_unicode(match):
try:
return chr(int(match.group(1), 16))
except (ValueError, OverflowError):
return match.group(0)
return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result)
def setup_oasis_logging(log_dir: str):
"""配置 OASIS 的日志,使用固定名称的日志文件"""
os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件
for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'):
try:
os.remove(old_log)
except OSError:
pass
formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")
loggers_config = {
"social.agent": os.path.join(log_dir, "social.agent.log"),
"social.twitter": os.path.join(log_dir, "social.twitter.log"),
"social.rec": os.path.join(log_dir, "social.rec.log"),
"oasis.env": os.path.join(log_dir, "oasis.env.log"),
"table": os.path.join(log_dir, "table.log"),
}
for logger_name, log_file in loggers_config.items():
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.propagate = False
try: try:
from camel.models import ModelFactory from camel.models import ModelFactory
@ -82,19 +152,32 @@ class RedditSimulationRunner:
""" """
创建LLM模型 创建LLM模型
OASIS使用camel-ai的ModelFactory配置方式 统一使用项目根目录 .env 文件中的配置优先级最高
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量 - LLM_API_KEY: API密钥
- 自定义API: 设置 OPENAI_API_KEY OPENAI_API_BASE_URL 环境变量 - LLM_BASE_URL: API基础URL
- LLM_MODEL_NAME: 模型名称
""" """
import os # 优先从 .env 读取配置
llm_api_key = os.environ.get("LLM_API_KEY", "")
llm_base_url = os.environ.get("LLM_BASE_URL", "")
llm_model = os.environ.get("LLM_MODEL_NAME", "")
llm_model = self.config.get("llm_model", "gpt-4o-mini") # 如果 .env 中没有,则使用 config 作为备用
llm_base_url = self.config.get("llm_base_url", "") if not llm_model:
llm_model = self.config.get("llm_model", "gpt-4o-mini")
# 设置 camel-ai 所需的环境变量
if llm_api_key:
os.environ["OPENAI_API_KEY"] = llm_api_key
if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
# 如果配置了base_url设置环境变量
if llm_base_url: if llm_base_url:
os.environ["OPENAI_API_BASE_URL"] = llm_base_url os.environ["OPENAI_API_BASE_URL"] = llm_base_url
print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")
return ModelFactory.create( return ModelFactory.create(
model_platform=ModelPlatformType.OPENAI, model_platform=ModelPlatformType.OPENAI,
model_type=llm_model, model_type=llm_model,
@ -289,6 +372,10 @@ async def main():
print(f"错误: 配置文件不存在: {args.config}") print(f"错误: 配置文件不存在: {args.config}")
sys.exit(1) sys.exit(1)
# 初始化日志配置(使用固定文件名,清理旧日志)
simulation_dir = os.path.dirname(args.config) or "."
setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = RedditSimulationRunner(args.config) runner = RedditSimulationRunner(args.config)
await runner.run() await runner.run()

View file

@ -9,6 +9,7 @@ OASIS Twitter模拟预设脚本
import argparse import argparse
import asyncio import asyncio
import json import json
import logging
import os import os
import random import random
import sys import sys
@ -16,7 +17,76 @@ from datetime import datetime
from typing import Dict, Any, List from typing import Dict, Any, List
# 添加项目路径 # 添加项目路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) _scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置)
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
load_dotenv(_env_file)
else:
_backend_env = os.path.join(_backend_dir, '.env')
if os.path.exists(_backend_env):
load_dotenv(_backend_env)
import re
class UnicodeFormatter(logging.Formatter):
"""自定义格式化器,将 Unicode 转义序列转换为可读字符"""
UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
def format(self, record):
result = super().format(record)
def replace_unicode(match):
try:
return chr(int(match.group(1), 16))
except (ValueError, OverflowError):
return match.group(0)
return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result)
def setup_oasis_logging(log_dir: str):
"""配置 OASIS 的日志,使用固定名称的日志文件"""
os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件
for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'):
try:
os.remove(old_log)
except OSError:
pass
formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")
loggers_config = {
"social.agent": os.path.join(log_dir, "social.agent.log"),
"social.twitter": os.path.join(log_dir, "social.twitter.log"),
"social.rec": os.path.join(log_dir, "social.rec.log"),
"oasis.env": os.path.join(log_dir, "oasis.env.log"),
"table": os.path.join(log_dir, "table.log"),
}
for logger_name, log_file in loggers_config.items():
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.propagate = False
try: try:
from camel.models import ModelFactory from camel.models import ModelFactory
@ -75,21 +145,32 @@ class TwitterSimulationRunner:
""" """
创建LLM模型 创建LLM模型
OASIS使用camel-ai的ModelFactory配置方式 统一使用项目根目录 .env 文件中的配置优先级最高
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量 - LLM_API_KEY: API密钥
- 自定义API: 设置 OPENAI_API_KEY OPENAI_API_BASE_URL 环境变量 - LLM_BASE_URL: API基础URL
- LLM_MODEL_NAME: 模型名称
配置文件中的 llm_model 对应 model_type
""" """
import os # 优先从 .env 读取配置
llm_api_key = os.environ.get("LLM_API_KEY", "")
llm_base_url = os.environ.get("LLM_BASE_URL", "")
llm_model = os.environ.get("LLM_MODEL_NAME", "")
llm_model = self.config.get("llm_model", "gpt-4o-mini") # 如果 .env 中没有,则使用 config 作为备用
llm_base_url = self.config.get("llm_base_url", "") if not llm_model:
llm_model = self.config.get("llm_model", "gpt-4o-mini")
# 设置 camel-ai 所需的环境变量
if llm_api_key:
os.environ["OPENAI_API_KEY"] = llm_api_key
if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
# 如果配置了base_url设置环境变量OASIS通过环境变量读取
if llm_base_url: if llm_base_url:
os.environ["OPENAI_API_BASE_URL"] = llm_base_url os.environ["OPENAI_API_BASE_URL"] = llm_base_url
print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")
return ModelFactory.create( return ModelFactory.create(
model_platform=ModelPlatformType.OPENAI, model_platform=ModelPlatformType.OPENAI,
model_type=llm_model, model_type=llm_model,
@ -304,6 +385,10 @@ async def main():
print(f"错误: 配置文件不存在: {args.config}") print(f"错误: 配置文件不存在: {args.config}")
sys.exit(1) sys.exit(1)
# 初始化日志配置(使用固定文件名,清理旧日志)
simulation_dir = os.path.dirname(args.config) or "."
setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = TwitterSimulationRunner(args.config) runner = TwitterSimulationRunner(args.config)
await runner.run() await runner.run()