Refactor simulation management and enhance logging capabilities
- Updated simulation preparation checks to exclude script files from the required files list, improving clarity on file management.
- Implemented a robust retry mechanism for Zep API calls in the ZepEntityReader service, enhancing reliability.
- Enhanced logging in simulation scripts to provide clearer insights into the simulation process and errors.
- Updated simulation runner to manage stdout and stderr logs more effectively, ensuring better error tracking.
- Improved profile generation to standardize gender fields and ensure all required fields are populated correctly.
This commit is contained in:
parent
af5c235695
commit
3cc5e3f479
8 changed files with 595 additions and 165 deletions
|
|
@ -221,6 +221,8 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
|
||||||
1. state.json 存在且 status 为 "ready"
|
1. state.json 存在且 status 为 "ready"
|
||||||
2. 必要文件存在:reddit_profiles.json, twitter_profiles.csv, simulation_config.json
|
2. 必要文件存在:reddit_profiles.json, twitter_profiles.csv, simulation_config.json
|
||||||
|
|
||||||
|
注意:运行脚本(run_*.py)保留在 backend/scripts/ 目录,不再复制到模拟目录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_id: 模拟ID
|
simulation_id: 模拟ID
|
||||||
|
|
||||||
|
|
@ -236,15 +238,12 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
|
||||||
if not os.path.exists(simulation_dir):
|
if not os.path.exists(simulation_dir):
|
||||||
return False, {"reason": "模拟目录不存在"}
|
return False, {"reason": "模拟目录不存在"}
|
||||||
|
|
||||||
# 必要文件列表
|
# 必要文件列表(不包括脚本,脚本位于 backend/scripts/)
|
||||||
required_files = [
|
required_files = [
|
||||||
"state.json",
|
"state.json",
|
||||||
"simulation_config.json",
|
"simulation_config.json",
|
||||||
"reddit_profiles.json",
|
"reddit_profiles.json",
|
||||||
"twitter_profiles.csv",
|
"twitter_profiles.csv"
|
||||||
"run_reddit_simulation.py",
|
|
||||||
"run_twitter_simulation.py",
|
|
||||||
"run_parallel_simulation.py"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# 检查文件是否存在
|
# 检查文件是否存在
|
||||||
|
|
@ -272,9 +271,13 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
|
||||||
state_data = json.load(f)
|
state_data = json.load(f)
|
||||||
|
|
||||||
status = state_data.get("status", "")
|
status = state_data.get("status", "")
|
||||||
|
config_generated = state_data.get("config_generated", False)
|
||||||
|
|
||||||
|
# 详细日志
|
||||||
|
logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}")
|
||||||
|
|
||||||
# 如果状态是ready或preparing(已有文件),认为准备完成
|
# 如果状态是ready或preparing(已有文件),认为准备完成
|
||||||
if status in ["ready", "preparing"] and state_data.get("config_generated"):
|
if status in ["ready", "preparing"] and config_generated:
|
||||||
# 获取文件统计信息
|
# 获取文件统计信息
|
||||||
profiles_file = os.path.join(simulation_dir, "reddit_profiles.json")
|
profiles_file = os.path.join(simulation_dir, "reddit_profiles.json")
|
||||||
config_file = os.path.join(simulation_dir, "simulation_config.json")
|
config_file = os.path.join(simulation_dir, "simulation_config.json")
|
||||||
|
|
@ -298,21 +301,23 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"自动更新状态失败: {e}")
|
logger.warning(f"自动更新状态失败: {e}")
|
||||||
|
|
||||||
|
logger.info(f"模拟 {simulation_id} 检测结果: 已准备完成 (status={status}, config_generated={config_generated})")
|
||||||
return True, {
|
return True, {
|
||||||
"status": status,
|
"status": status,
|
||||||
"entities_count": state_data.get("entities_count", 0),
|
"entities_count": state_data.get("entities_count", 0),
|
||||||
"profiles_count": profiles_count,
|
"profiles_count": profiles_count,
|
||||||
"entity_types": state_data.get("entity_types", []),
|
"entity_types": state_data.get("entity_types", []),
|
||||||
"config_generated": state_data.get("config_generated", False),
|
"config_generated": config_generated,
|
||||||
"created_at": state_data.get("created_at"),
|
"created_at": state_data.get("created_at"),
|
||||||
"updated_at": state_data.get("updated_at"),
|
"updated_at": state_data.get("updated_at"),
|
||||||
"existing_files": existing_files
|
"existing_files": existing_files
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})")
|
||||||
return False, {
|
return False, {
|
||||||
"reason": f"状态不是ready: {status}",
|
"reason": f"状态不是ready或config_generated为false: status={status}, config_generated={config_generated}",
|
||||||
"status": status,
|
"status": status,
|
||||||
"config_generated": state_data.get("config_generated", False)
|
"config_generated": config_generated
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -386,10 +391,13 @@ def prepare_simulation():
|
||||||
|
|
||||||
# 检查是否强制重新生成
|
# 检查是否强制重新生成
|
||||||
force_regenerate = data.get('force_regenerate', False)
|
force_regenerate = data.get('force_regenerate', False)
|
||||||
|
logger.info(f"开始处理 /prepare 请求: simulation_id={simulation_id}, force_regenerate={force_regenerate}")
|
||||||
|
|
||||||
# 检查是否已经准备完成(避免重复生成)
|
# 检查是否已经准备完成(避免重复生成)
|
||||||
if not force_regenerate:
|
if not force_regenerate:
|
||||||
|
logger.debug(f"检查模拟 {simulation_id} 是否已准备完成...")
|
||||||
is_prepared, prepare_info = _check_simulation_prepared(simulation_id)
|
is_prepared, prepare_info = _check_simulation_prepared(simulation_id)
|
||||||
|
logger.debug(f"检查结果: is_prepared={is_prepared}, prepare_info={prepare_info}")
|
||||||
if is_prepared:
|
if is_prepared:
|
||||||
logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成")
|
logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -402,6 +410,8 @@ def prepare_simulation():
|
||||||
"prepare_info": prepare_info
|
"prepare_info": prepare_info
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
else:
|
||||||
|
logger.info(f"模拟 {simulation_id} 未准备完成,将启动准备任务")
|
||||||
|
|
||||||
# 从项目获取必要信息
|
# 从项目获取必要信息
|
||||||
project = ProjectManager.get_project(state.project_id)
|
project = ProjectManager.get_project(state.project_id)
|
||||||
|
|
@ -850,25 +860,27 @@ def download_simulation_config(simulation_id: str):
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
@simulation_bp.route('/<simulation_id>/script/<script_name>/download', methods=['GET'])
|
@simulation_bp.route('/script/<script_name>/download', methods=['GET'])
|
||||||
def download_simulation_script(simulation_id: str, script_name: str):
|
def download_simulation_script(script_name: str):
|
||||||
"""
|
"""
|
||||||
下载模拟脚本文件
|
下载模拟运行脚本文件(通用脚本,位于 backend/scripts/)
|
||||||
|
|
||||||
script_name可选值:
|
script_name可选值:
|
||||||
- run_twitter_simulation.py
|
- run_twitter_simulation.py
|
||||||
- run_reddit_simulation.py
|
- run_reddit_simulation.py
|
||||||
- run_parallel_simulation.py
|
- run_parallel_simulation.py
|
||||||
|
- action_logger.py
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
manager = SimulationManager()
|
# 脚本位于 backend/scripts/ 目录
|
||||||
sim_dir = manager._get_simulation_dir(simulation_id)
|
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
||||||
|
|
||||||
# 验证脚本名称
|
# 验证脚本名称
|
||||||
allowed_scripts = [
|
allowed_scripts = [
|
||||||
"run_twitter_simulation.py",
|
"run_twitter_simulation.py",
|
||||||
"run_reddit_simulation.py",
|
"run_reddit_simulation.py",
|
||||||
"run_parallel_simulation.py"
|
"run_parallel_simulation.py",
|
||||||
|
"action_logger.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
if script_name not in allowed_scripts:
|
if script_name not in allowed_scripts:
|
||||||
|
|
@ -877,12 +889,12 @@ def download_simulation_script(simulation_id: str, script_name: str):
|
||||||
"error": f"未知脚本: {script_name},可选: {allowed_scripts}"
|
"error": f"未知脚本: {script_name},可选: {allowed_scripts}"
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
script_path = os.path.join(sim_dir, script_name)
|
script_path = os.path.join(scripts_dir, script_name)
|
||||||
|
|
||||||
if not os.path.exists(script_path):
|
if not os.path.exists(script_path):
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "脚本文件不存在,请先调用 /prepare 接口"
|
"error": f"脚本文件不存在: {script_name}"
|
||||||
}), 404
|
}), 404
|
||||||
|
|
||||||
return send_file(
|
return send_file(
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ OASIS Agent Profile生成器
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -315,32 +316,54 @@ class OasisProfileGenerator:
|
||||||
comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景"
|
comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景"
|
||||||
|
|
||||||
def search_edges():
|
def search_edges():
|
||||||
"""搜索边(事实/关系)"""
|
"""搜索边(事实/关系)- 带重试机制"""
|
||||||
try:
|
max_retries = 3
|
||||||
return self.zep_client.graph.search(
|
last_exception = None
|
||||||
query=comprehensive_query,
|
delay = 2.0
|
||||||
graph_id=self.graph_id,
|
|
||||||
limit=30,
|
for attempt in range(max_retries):
|
||||||
scope="edges",
|
try:
|
||||||
reranker="rrf"
|
return self.zep_client.graph.search(
|
||||||
)
|
query=comprehensive_query,
|
||||||
except Exception as e:
|
graph_id=self.graph_id,
|
||||||
logger.debug(f"Zep边搜索失败: {e}")
|
limit=30,
|
||||||
return None
|
scope="edges",
|
||||||
|
reranker="rrf"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
||||||
|
time.sleep(delay)
|
||||||
|
delay *= 2
|
||||||
|
else:
|
||||||
|
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def search_nodes():
|
def search_nodes():
|
||||||
"""搜索节点(实体摘要)"""
|
"""搜索节点(实体摘要)- 带重试机制"""
|
||||||
try:
|
max_retries = 3
|
||||||
return self.zep_client.graph.search(
|
last_exception = None
|
||||||
query=comprehensive_query,
|
delay = 2.0
|
||||||
graph_id=self.graph_id,
|
|
||||||
limit=20,
|
for attempt in range(max_retries):
|
||||||
scope="nodes",
|
try:
|
||||||
reranker="rrf"
|
return self.zep_client.graph.search(
|
||||||
)
|
query=comprehensive_query,
|
||||||
except Exception as e:
|
graph_id=self.graph_id,
|
||||||
logger.debug(f"Zep节点搜索失败: {e}")
|
limit=20,
|
||||||
return None
|
scope="nodes",
|
||||||
|
reranker="rrf"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
||||||
|
time.sleep(delay)
|
||||||
|
delay *= 2
|
||||||
|
else:
|
||||||
|
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 并行执行edges和nodes搜索
|
# 并行执行edges和nodes搜索
|
||||||
|
|
@ -684,18 +707,20 @@ class OasisProfileGenerator:
|
||||||
- 立场观点(对话题的态度、可能被激怒/感动的内容)
|
- 立场观点(对话题的态度、可能被激怒/感动的内容)
|
||||||
- 独特特征(口头禅、特殊经历、个人爱好)
|
- 独特特征(口头禅、特殊经历、个人爱好)
|
||||||
- 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应)
|
- 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应)
|
||||||
3. age: 年龄数字
|
3. age: 年龄数字(必须是整数)
|
||||||
4. gender: 性别(男/女)
|
4. gender: 性别,必须是英文: "male" 或 "female"
|
||||||
5. mbti: MBTI类型
|
5. mbti: MBTI类型(如INTJ、ENFP等)
|
||||||
6. country: 国家
|
6. country: 国家(使用中文,如"中国")
|
||||||
7. profession: 职业
|
7. profession: 职业
|
||||||
8. interested_topics: 感兴趣话题数组
|
8. interested_topics: 感兴趣话题数组
|
||||||
|
|
||||||
重要:
|
重要:
|
||||||
- 所有字段值必须是字符串或数字,不要使用换行符
|
- 所有字段值必须是字符串或数字,不要使用换行符
|
||||||
- persona必须是一段连贯的文字描述
|
- persona必须是一段连贯的文字描述
|
||||||
- 使用中文
|
- 使用中文(除了gender字段必须用英文male/female)
|
||||||
- 内容要与实体信息保持一致"""
|
- 内容要与实体信息保持一致
|
||||||
|
- age必须是有效的整数,gender必须是"male"或"female"
|
||||||
|
"""
|
||||||
|
|
||||||
def _build_group_persona_prompt(
|
def _build_group_persona_prompt(
|
||||||
self,
|
self,
|
||||||
|
|
@ -731,17 +756,18 @@ class OasisProfileGenerator:
|
||||||
- 立场态度(对核心话题的官方立场、面对争议的处理方式)
|
- 立场态度(对核心话题的官方立场、面对争议的处理方式)
|
||||||
- 特殊说明(代表的群体画像、运营习惯)
|
- 特殊说明(代表的群体画像、运营习惯)
|
||||||
- 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应)
|
- 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应)
|
||||||
3. age: null(机构不适用)
|
3. age: 固定填30(机构账号的虚拟年龄)
|
||||||
4. gender: null(机构不适用)
|
4. gender: 固定填"other"(机构账号使用other表示非个人)
|
||||||
5. mbti: 可选,用于描述账号风格,如ISTJ代表严谨保守
|
5. mbti: MBTI类型,用于描述账号风格,如ISTJ代表严谨保守
|
||||||
6. country: 国家
|
6. country: 国家(使用中文,如"中国")
|
||||||
7. profession: 机构职能描述
|
7. profession: 机构职能描述
|
||||||
8. interested_topics: 关注领域数组
|
8. interested_topics: 关注领域数组
|
||||||
|
|
||||||
重要:
|
重要:
|
||||||
- 所有字段值必须是字符串、数字或null
|
- 所有字段值必须是字符串或数字,不允许null值
|
||||||
- persona必须是一段连贯的文字描述,不要使用换行符
|
- persona必须是一段连贯的文字描述,不要使用换行符
|
||||||
- 使用中文
|
- 使用中文(除了gender字段必须用英文"other")
|
||||||
|
- age必须是整数30,gender必须是字符串"other"
|
||||||
- 机构账号发言要符合其身份定位"""
|
- 机构账号发言要符合其身份定位"""
|
||||||
|
|
||||||
def _generate_profile_rule_based(
|
def _generate_profile_rule_based(
|
||||||
|
|
@ -784,6 +810,10 @@ class OasisProfileGenerator:
|
||||||
return {
|
return {
|
||||||
"bio": f"Official account for {entity_name}. News and updates.",
|
"bio": f"Official account for {entity_name}. News and updates.",
|
||||||
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
||||||
|
"age": 30, # 机构虚拟年龄
|
||||||
|
"gender": "other", # 机构使用other
|
||||||
|
"mbti": "ISTJ", # 机构风格:严谨保守
|
||||||
|
"country": "中国",
|
||||||
"profession": "Media",
|
"profession": "Media",
|
||||||
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
||||||
}
|
}
|
||||||
|
|
@ -792,6 +822,10 @@ class OasisProfileGenerator:
|
||||||
return {
|
return {
|
||||||
"bio": f"Official account of {entity_name}.",
|
"bio": f"Official account of {entity_name}.",
|
||||||
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
||||||
|
"age": 30, # 机构虚拟年龄
|
||||||
|
"gender": "other", # 机构使用other
|
||||||
|
"mbti": "ISTJ", # 机构风格:严谨保守
|
||||||
|
"country": "中国",
|
||||||
"profession": entity_type,
|
"profession": entity_type,
|
||||||
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
||||||
}
|
}
|
||||||
|
|
@ -1039,6 +1073,31 @@ class OasisProfileGenerator:
|
||||||
|
|
||||||
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)")
|
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)")
|
||||||
|
|
||||||
|
def _normalize_gender(self, gender: Optional[str]) -> str:
|
||||||
|
"""
|
||||||
|
标准化gender字段为OASIS要求的英文格式
|
||||||
|
|
||||||
|
OASIS要求: male, female, other
|
||||||
|
"""
|
||||||
|
if not gender:
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
gender_lower = gender.lower().strip()
|
||||||
|
|
||||||
|
# 中文映射
|
||||||
|
gender_map = {
|
||||||
|
"男": "male",
|
||||||
|
"女": "female",
|
||||||
|
"机构": "other",
|
||||||
|
"其他": "other",
|
||||||
|
# 英文已有
|
||||||
|
"male": "male",
|
||||||
|
"female": "female",
|
||||||
|
"other": "other",
|
||||||
|
}
|
||||||
|
|
||||||
|
return gender_map.get(gender_lower, "other")
|
||||||
|
|
||||||
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||||
"""
|
"""
|
||||||
保存Reddit Profile为JSON格式
|
保存Reddit Profile为JSON格式
|
||||||
|
|
@ -1048,26 +1107,30 @@ class OasisProfileGenerator:
|
||||||
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
||||||
|
|
||||||
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
||||||
|
|
||||||
|
OASIS要求所有字段都必须存在:
|
||||||
|
- age: 整数
|
||||||
|
- gender: "male", "female", 或 "other"
|
||||||
|
- mbti: MBTI类型字符串
|
||||||
|
- country: 国家字符串
|
||||||
"""
|
"""
|
||||||
data = []
|
data = []
|
||||||
for profile in profiles:
|
for profile in profiles:
|
||||||
# 使用详细格式(与用户示例兼容)
|
# 使用详细格式(与用户示例兼容)
|
||||||
|
# 确保所有必需字段都有有效值
|
||||||
item = {
|
item = {
|
||||||
"realname": profile.name,
|
"realname": profile.name,
|
||||||
"username": profile.user_name,
|
"username": profile.user_name,
|
||||||
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
|
"bio": profile.bio[:150] if profile.bio else f"{profile.name}",
|
||||||
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
||||||
|
# OASIS必需字段 - 确保都有默认值
|
||||||
|
"age": profile.age if profile.age else 30,
|
||||||
|
"gender": self._normalize_gender(profile.gender),
|
||||||
|
"mbti": profile.mbti if profile.mbti else "ISTJ",
|
||||||
|
"country": profile.country if profile.country else "中国",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加人设详情字段
|
# 可选字段
|
||||||
if profile.age:
|
|
||||||
item["age"] = profile.age
|
|
||||||
if profile.gender:
|
|
||||||
item["gender"] = profile.gender
|
|
||||||
if profile.mbti:
|
|
||||||
item["mbti"] = profile.mbti
|
|
||||||
if profile.country:
|
|
||||||
item["country"] = profile.country
|
|
||||||
if profile.profession:
|
if profile.profession:
|
||||||
item["profession"] = profile.profession
|
item["profession"] = profile.profession
|
||||||
if profile.interested_topics:
|
if profile.interested_topics:
|
||||||
|
|
@ -1078,7 +1141,7 @@ class OasisProfileGenerator:
|
||||||
with open(file_path, 'w', encoding='utf-8') as f:
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
|
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式,已标准化gender字段)")
|
||||||
|
|
||||||
# 保留旧方法名作为别名,保持向后兼容
|
# 保留旧方法名作为别名,保持向后兼容
|
||||||
def save_profiles_to_json(
|
def save_profiles_to_json(
|
||||||
|
|
|
||||||
|
|
@ -127,12 +127,6 @@ class SimulationManager:
|
||||||
'../../uploads/simulations'
|
'../../uploads/simulations'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 预设脚本目录
|
|
||||||
SCRIPTS_DIR = os.path.join(
|
|
||||||
os.path.dirname(__file__),
|
|
||||||
'../../scripts'
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 确保目录存在
|
# 确保目录存在
|
||||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||||
|
|
@ -426,27 +420,8 @@ class SimulationManager:
|
||||||
total=3
|
total=3
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========== 阶段4: 复制预设脚本 ==========
|
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
|
||||||
script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
|
# 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本
|
||||||
"run_parallel_simulation.py", "action_logger.py"]
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(
|
|
||||||
"copying_scripts", 0,
|
|
||||||
"开始准备脚本...",
|
|
||||||
current=0,
|
|
||||||
total=len(script_files)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._copy_preset_scripts(sim_dir)
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(
|
|
||||||
"copying_scripts", 100,
|
|
||||||
f"完成,共 {len(script_files)} 个脚本",
|
|
||||||
current=len(script_files),
|
|
||||||
total=len(script_files)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新状态
|
# 更新状态
|
||||||
state.status = SimulationStatus.READY
|
state.status = SimulationStatus.READY
|
||||||
|
|
@ -466,24 +441,6 @@ class SimulationManager:
|
||||||
self._save_simulation_state(state)
|
self._save_simulation_state(state)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _copy_preset_scripts(self, sim_dir: str):
|
|
||||||
"""复制预设脚本到模拟目录"""
|
|
||||||
scripts = [
|
|
||||||
"run_twitter_simulation.py",
|
|
||||||
"run_reddit_simulation.py",
|
|
||||||
"run_parallel_simulation.py"
|
|
||||||
]
|
|
||||||
|
|
||||||
for script in scripts:
|
|
||||||
src = os.path.join(self.SCRIPTS_DIR, script)
|
|
||||||
dst = os.path.join(sim_dir, script)
|
|
||||||
|
|
||||||
if os.path.exists(src):
|
|
||||||
shutil.copy2(src, dst)
|
|
||||||
logger.debug(f"复制脚本: {script}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"预设脚本不存在: {src}")
|
|
||||||
|
|
||||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||||
"""获取模拟状态"""
|
"""获取模拟状态"""
|
||||||
return self._load_simulation_state(simulation_id)
|
return self._load_simulation_state(simulation_id)
|
||||||
|
|
@ -531,21 +488,22 @@ class SimulationManager:
|
||||||
"""获取运行说明"""
|
"""获取运行说明"""
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
|
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"simulation_dir": sim_dir,
|
"simulation_dir": sim_dir,
|
||||||
|
"scripts_dir": scripts_dir,
|
||||||
"config_file": config_path,
|
"config_file": config_path,
|
||||||
"commands": {
|
"commands": {
|
||||||
"twitter": f"python run_twitter_simulation.py --config simulation_config.json",
|
"twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
|
||||||
"reddit": f"python run_reddit_simulation.py --config simulation_config.json",
|
"reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
|
||||||
"parallel": f"python run_parallel_simulation.py --config simulation_config.json",
|
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
|
||||||
},
|
},
|
||||||
"instructions": (
|
"instructions": (
|
||||||
f"1. 进入模拟目录: cd {sim_dir}\n"
|
f"1. 激活conda环境: conda activate MiroFish\n"
|
||||||
f"2. 激活conda环境: conda activate MiroFish\n"
|
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
|
||||||
f"3. 运行模拟:\n"
|
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
|
||||||
f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
|
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
|
||||||
f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
|
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
|
||||||
f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -182,11 +182,19 @@ class SimulationRunner:
|
||||||
'../../uploads/simulations'
|
'../../uploads/simulations'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 脚本目录
|
||||||
|
SCRIPTS_DIR = os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
'../../scripts'
|
||||||
|
)
|
||||||
|
|
||||||
# 内存中的运行状态
|
# 内存中的运行状态
|
||||||
_run_states: Dict[str, SimulationRunState] = {}
|
_run_states: Dict[str, SimulationRunState] = {}
|
||||||
_processes: Dict[str, subprocess.Popen] = {}
|
_processes: Dict[str, subprocess.Popen] = {}
|
||||||
_action_queues: Dict[str, Queue] = {}
|
_action_queues: Dict[str, Queue] = {}
|
||||||
_monitor_threads: Dict[str, threading.Thread] = {}
|
_monitor_threads: Dict[str, threading.Thread] = {}
|
||||||
|
_stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄
|
||||||
|
_stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
|
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
|
||||||
|
|
@ -310,7 +318,7 @@ class SimulationRunner:
|
||||||
|
|
||||||
cls._save_run_state(state)
|
cls._save_run_state(state)
|
||||||
|
|
||||||
# 确定运行哪个脚本
|
# 确定运行哪个脚本(脚本位于 backend/scripts/ 目录)
|
||||||
if platform == "twitter":
|
if platform == "twitter":
|
||||||
script_name = "run_twitter_simulation.py"
|
script_name = "run_twitter_simulation.py"
|
||||||
state.twitter_running = True
|
state.twitter_running = True
|
||||||
|
|
@ -322,7 +330,7 @@ class SimulationRunner:
|
||||||
state.twitter_running = True
|
state.twitter_running = True
|
||||||
state.reddit_running = True
|
state.reddit_running = True
|
||||||
|
|
||||||
script_path = os.path.join(sim_dir, script_name)
|
script_path = os.path.join(cls.SCRIPTS_DIR, script_name)
|
||||||
|
|
||||||
if not os.path.exists(script_path):
|
if not os.path.exists(script_path):
|
||||||
raise ValueError(f"脚本不存在: {script_path}")
|
raise ValueError(f"脚本不存在: {script_path}")
|
||||||
|
|
@ -333,24 +341,36 @@ class SimulationRunner:
|
||||||
|
|
||||||
# 启动模拟进程
|
# 启动模拟进程
|
||||||
try:
|
try:
|
||||||
# 构建运行命令
|
# 构建运行命令,使用完整路径
|
||||||
|
action_log_path = os.path.join(sim_dir, "actions.jsonl")
|
||||||
|
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable, # Python解释器
|
sys.executable, # Python解释器
|
||||||
script_path,
|
script_path,
|
||||||
"--config", "simulation_config.json",
|
"--config", config_path, # 使用完整配置文件路径
|
||||||
"--action-log", "actions.jsonl", # 动作日志文件
|
"--action-log", action_log_path, # 动作日志文件完整路径
|
||||||
]
|
]
|
||||||
|
|
||||||
# 设置工作目录为模拟目录
|
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||||
|
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
|
||||||
|
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||||
|
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
|
||||||
|
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
|
||||||
|
|
||||||
|
# 设置工作目录为模拟目录(数据库等文件会生成在此)
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
cwd=sim_dir,
|
cwd=sim_dir,
|
||||||
stdout=subprocess.PIPE,
|
stdout=stdout_file,
|
||||||
stderr=subprocess.PIPE,
|
stderr=stderr_file,
|
||||||
text=True,
|
text=True,
|
||||||
bufsize=1,
|
bufsize=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 保存文件句柄以便后续关闭
|
||||||
|
cls._stdout_files[simulation_id] = stdout_file
|
||||||
|
cls._stderr_files[simulation_id] = stderr_file
|
||||||
|
|
||||||
state.process_pid = process.pid
|
state.process_pid = process.pid
|
||||||
state.runner_status = RunnerStatus.RUNNING
|
state.runner_status = RunnerStatus.RUNNING
|
||||||
cls._processes[simulation_id] = process
|
cls._processes[simulation_id] = process
|
||||||
|
|
@ -434,8 +454,16 @@ class SimulationRunner:
|
||||||
logger.info(f"模拟完成: {simulation_id}")
|
logger.info(f"模拟完成: {simulation_id}")
|
||||||
else:
|
else:
|
||||||
state.runner_status = RunnerStatus.FAILED
|
state.runner_status = RunnerStatus.FAILED
|
||||||
stderr = process.stderr.read() if process.stderr else ""
|
# 从 stderr 日志文件读取错误信息
|
||||||
state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
|
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||||
|
stderr = ""
|
||||||
|
try:
|
||||||
|
if os.path.exists(stderr_log_path):
|
||||||
|
with open(stderr_log_path, 'r', encoding='utf-8') as f:
|
||||||
|
stderr = f.read()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
|
||||||
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
|
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
|
||||||
|
|
||||||
state.twitter_running = False
|
state.twitter_running = False
|
||||||
|
|
@ -449,10 +477,24 @@ class SimulationRunner:
|
||||||
cls._save_run_state(state)
|
cls._save_run_state(state)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理
|
# 清理进程资源
|
||||||
cls._processes.pop(simulation_id, None)
|
cls._processes.pop(simulation_id, None)
|
||||||
cls._action_queues.pop(simulation_id, None)
|
cls._action_queues.pop(simulation_id, None)
|
||||||
|
|
||||||
|
# 关闭日志文件句柄
|
||||||
|
if simulation_id in cls._stdout_files:
|
||||||
|
try:
|
||||||
|
cls._stdout_files[simulation_id].close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
cls._stdout_files.pop(simulation_id, None)
|
||||||
|
if simulation_id in cls._stderr_files:
|
||||||
|
try:
|
||||||
|
cls._stderr_files[simulation_id].close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
cls._stderr_files.pop(simulation_id, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
|
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
|
||||||
"""停止模拟"""
|
"""停止模拟"""
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ Zep实体读取与过滤服务
|
||||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional, Set
|
import time
|
||||||
|
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from zep_cloud.client import Zep
|
from zep_cloud.client import Zep
|
||||||
|
|
@ -13,6 +14,9 @@ from ..utils.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger('mirofish.zep_entity_reader')
|
logger = get_logger('mirofish.zep_entity_reader')
|
||||||
|
|
||||||
|
# 用于泛型返回类型
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityNode:
|
class EntityNode:
|
||||||
|
|
@ -80,9 +84,48 @@ class ZepEntityReader:
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
self.client = Zep(api_key=self.api_key)
|
||||||
|
|
||||||
|
def _call_with_retry(
|
||||||
|
self,
|
||||||
|
func: Callable[[], T],
|
||||||
|
operation_name: str,
|
||||||
|
max_retries: int = 3,
|
||||||
|
initial_delay: float = 2.0
|
||||||
|
) -> T:
|
||||||
|
"""
|
||||||
|
带重试机制的Zep API调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: 要执行的函数(无参数的lambda或callable)
|
||||||
|
operation_name: 操作名称,用于日志
|
||||||
|
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
||||||
|
initial_delay: 初始延迟秒数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API调用结果
|
||||||
|
"""
|
||||||
|
last_exception = None
|
||||||
|
delay = initial_delay
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning(
|
||||||
|
f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
|
||||||
|
f"{delay:.1f}秒后重试..."
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
delay *= 2 # 指数退避
|
||||||
|
else:
|
||||||
|
logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
|
||||||
|
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取图谱的所有节点
|
获取图谱的所有节点(带重试机制)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: 图谱ID
|
||||||
|
|
@ -92,7 +135,11 @@ class ZepEntityReader:
|
||||||
"""
|
"""
|
||||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||||
|
|
||||||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
# 使用重试机制调用Zep API
|
||||||
|
nodes = self._call_with_retry(
|
||||||
|
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
|
||||||
|
operation_name=f"获取节点(graph={graph_id})"
|
||||||
|
)
|
||||||
|
|
||||||
nodes_data = []
|
nodes_data = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
|
|
@ -109,7 +156,7 @@ class ZepEntityReader:
|
||||||
|
|
||||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取图谱的所有边
|
获取图谱的所有边(带重试机制)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: 图谱ID
|
||||||
|
|
@ -119,7 +166,11 @@ class ZepEntityReader:
|
||||||
"""
|
"""
|
||||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||||
|
|
||||||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
# 使用重试机制调用Zep API
|
||||||
|
edges = self._call_with_retry(
|
||||||
|
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
|
||||||
|
operation_name=f"获取边(graph={graph_id})"
|
||||||
|
)
|
||||||
|
|
||||||
edges_data = []
|
edges_data = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
|
|
@ -137,7 +188,7 @@ class ZepEntityReader:
|
||||||
|
|
||||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定节点的所有相关边
|
获取指定节点的所有相关边(带重试机制)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_uuid: 节点UUID
|
node_uuid: 节点UUID
|
||||||
|
|
@ -146,7 +197,11 @@ class ZepEntityReader:
|
||||||
边列表
|
边列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
|
# 使用重试机制调用Zep API
|
||||||
|
edges = self._call_with_retry(
|
||||||
|
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||||
|
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||||
|
)
|
||||||
|
|
||||||
edges_data = []
|
edges_data = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
|
|
@ -288,7 +343,7 @@ class ZepEntityReader:
|
||||||
entity_uuid: str
|
entity_uuid: str
|
||||||
) -> Optional[EntityNode]:
|
) -> Optional[EntityNode]:
|
||||||
"""
|
"""
|
||||||
获取单个实体及其完整上下文(边和关联节点)
|
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: 图谱ID
|
||||||
|
|
@ -298,8 +353,11 @@ class ZepEntityReader:
|
||||||
EntityNode或None
|
EntityNode或None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取节点
|
# 使用重试机制获取节点
|
||||||
node = self.client.graph.node.get(uuid_=entity_uuid)
|
node = self._call_with_retry(
|
||||||
|
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||||
|
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||||
|
)
|
||||||
|
|
||||||
if not node:
|
if not node:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -9,13 +9,118 @@ OASIS 双平台并行模拟预设脚本
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
# Add the backend directory to the import path.
# This script always lives in backend/scripts/, so the backend directory and
# the project root are derived from this file's own location.
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
# Both inserts target index 0, so _backend_dir ends up ahead of _scripts_dir.
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)

# Load the .env file from the project root (holds LLM_API_KEY and related settings).
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
    load_dotenv(_env_file)
    print(f"已加载环境配置: {_env_file}")
else:
    # Fall back to backend/.env; if neither file exists nothing is loaded.
    _backend_env = os.path.join(_backend_dir, '.env')
    if os.path.exists(_backend_env):
        load_dotenv(_backend_env)
        print(f"已加载环境配置: {_backend_env}")
|
||||||
|
|
||||||
|
|
||||||
|
class UnicodeFormatter(logging.Formatter):
    """
    Formatter that turns literal ``\\uXXXX`` escape text in the formatted
    record back into the characters it stands for.

    Two escapes that form a UTF-16 surrogate pair (for example an emoji
    written as ``\\ud83d\\ude00``) are combined into the single character they
    encode. An escape that names a lone surrogate is left as text, because a
    lone surrogate cannot be encoded as UTF-8 when the record is written.
    """

    # Matches one literal \uXXXX escape; compiled lazily by _get_pattern().
    UNICODE_ESCAPE_PATTERN = None
    # Matches a high-surrogate escape directly followed by a low-surrogate
    # escape; compiled lazily by _get_surrogate_pair_pattern().
    _SURROGATE_PAIR_PATTERN = None

    @classmethod
    def _get_pattern(cls):
        """Return the compiled single-escape pattern, compiling it on first use."""
        if cls.UNICODE_ESCAPE_PATTERN is None:
            import re
            cls.UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
        return cls.UNICODE_ESCAPE_PATTERN

    @classmethod
    def _get_surrogate_pair_pattern(cls):
        """Return the compiled surrogate-pair pattern, compiling it on first use."""
        if cls._SURROGATE_PAIR_PATTERN is None:
            import re
            cls._SURROGATE_PAIR_PATTERN = re.compile(
                r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
            )
        return cls._SURROGATE_PAIR_PATTERN

    def format(self, record):
        """Format ``record`` normally, then decode the escape text in the result."""
        result = super().format(record)

        def join_surrogates(match):
            # Standard UTF-16 decoding of a high/low surrogate pair.
            high = int(match.group(1), 16) - 0xD800
            low = int(match.group(2), 16) - 0xDC00
            return chr(0x10000 + (high << 10) + low)

        def replace_unicode(match):
            try:
                code_point = int(match.group(1), 16)
            except (ValueError, OverflowError):
                return match.group(0)
            if 0xD800 <= code_point <= 0xDFFF:
                # Lone surrogate: keep the escape text so the line stays encodable.
                return match.group(0)
            return chr(code_point)

        # Pairs first, so their halves are never decoded individually.
        result = self._get_surrogate_pair_pattern().sub(join_surrogates, result)
        return self._get_pattern().sub(replace_unicode, result)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_oasis_logging(log_dir: str):
    """
    Route the OASIS loggers to fixed-name log files inside ``log_dir``.

    Each logger listed below ends up with exactly one UTF-8 ``FileHandler``
    writing to ``<log_dir>/<logger name>.log``. Handlers that were attached
    before (the original code comments attribute them to OASIS and its
    timestamped log files) are detached and closed, and existing ``*.log``
    files in ``log_dir`` are removed.

    Args:
        log_dir: Directory for the log files; created if it does not exist.
    """
    os.makedirs(log_dir, exist_ok=True)

    # Fixed (non-timestamped) file name for every logger that is reconfigured.
    loggers_config = {
        "social.agent": os.path.join(log_dir, "social.agent.log"),
        "social.twitter": os.path.join(log_dir, "social.twitter.log"),
        "social.rec": os.path.join(log_dir, "social.rec.log"),
        "oasis.env": os.path.join(log_dir, "oasis.env.log"),
        "table": os.path.join(log_dir, "table.log"),
    }

    # Detach AND close the existing handlers before touching any files.
    # Merely clearing the handler list would leave their files open (a
    # descriptor leak, and on some platforms an open file cannot be deleted).
    for logger_name in loggers_config:
        stale_logger = logging.getLogger(logger_name)
        for handler in list(stale_logger.handlers):
            stale_logger.removeHandler(handler)
            try:
                handler.close()
            except (OSError, ValueError):
                # Best effort: a handler that fails to flush/close must not
                # prevent the simulation from starting.
                pass

    # Remove log files left over from earlier runs (best effort).
    for f in os.listdir(log_dir):
        old_log = os.path.join(log_dir, f)
        if os.path.isfile(old_log) and f.endswith('.log'):
            try:
                os.remove(old_log)
            except OSError:
                pass

    # Formatter that also decodes \uXXXX escape text into readable characters.
    formatter = UnicodeFormatter(
        "%(levelname)s - %(asctime)s - %(name)s - %(message)s"
    )

    for logger_name, log_file in loggers_config.items():
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.DEBUG)
        # New file handler: UTF-8, fixed file name, truncated on open.
        file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        # Do not propagate to ancestor loggers (avoids duplicate output).
        logger.propagate = False

    print(f"日志配置完成,日志目录: {log_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def init_logging_for_simulation(simulation_dir: str):
    """Set up OASIS logging in the ``log`` sub-directory of ``simulation_dir``."""
    setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||||
|
|
||||||
|
|
||||||
from action_logger import ActionLogger
|
from action_logger import ActionLogger
|
||||||
|
|
||||||
|
|
@ -74,17 +179,34 @@ def create_model(config: Dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
创建LLM模型
|
创建LLM模型
|
||||||
|
|
||||||
OASIS使用camel-ai的ModelFactory,配置方式:
|
统一使用项目根目录 .env 文件中的配置(优先级最高):
|
||||||
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量
|
- LLM_API_KEY: API密钥
|
||||||
- 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量
|
- LLM_BASE_URL: API基础URL
|
||||||
"""
|
- LLM_MODEL_NAME: 模型名称
|
||||||
llm_model = config.get("llm_model", "gpt-4o-mini")
|
|
||||||
llm_base_url = config.get("llm_base_url", "")
|
OASIS使用camel-ai的ModelFactory,需要设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量
|
||||||
|
"""
|
||||||
|
# 优先从 .env 读取配置
|
||||||
|
llm_api_key = os.environ.get("LLM_API_KEY", "")
|
||||||
|
llm_base_url = os.environ.get("LLM_BASE_URL", "")
|
||||||
|
llm_model = os.environ.get("LLM_MODEL_NAME", "")
|
||||||
|
|
||||||
|
# 如果 .env 中没有,则使用 config 作为备用
|
||||||
|
if not llm_model:
|
||||||
|
llm_model = config.get("llm_model", "gpt-4o-mini")
|
||||||
|
|
||||||
|
# 设置 camel-ai 所需的环境变量
|
||||||
|
if llm_api_key:
|
||||||
|
os.environ["OPENAI_API_KEY"] = llm_api_key
|
||||||
|
|
||||||
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
|
||||||
|
|
||||||
# 如果配置了base_url,设置环境变量
|
|
||||||
if llm_base_url:
|
if llm_base_url:
|
||||||
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
|
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
|
||||||
|
|
||||||
|
print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")
|
||||||
|
|
||||||
return ModelFactory.create(
|
return ModelFactory.create(
|
||||||
model_platform=ModelPlatformType.OPENAI,
|
model_platform=ModelPlatformType.OPENAI,
|
||||||
model_type=llm_model,
|
model_type=llm_model,
|
||||||
|
|
@ -453,6 +575,9 @@ async def main():
|
||||||
config = load_config(args.config)
|
config = load_config(args.config)
|
||||||
simulation_dir = os.path.dirname(args.config) or "."
|
simulation_dir = os.path.dirname(args.config) or "."
|
||||||
|
|
||||||
|
# 初始化日志配置(清理旧日志文件,使用固定名称)
|
||||||
|
init_logging_for_simulation(simulation_dir)
|
||||||
|
|
||||||
# 创建动作日志记录器
|
# 创建动作日志记录器
|
||||||
action_log_path = os.path.join(simulation_dir, args.action_log)
|
action_log_path = os.path.join(simulation_dir, args.action_log)
|
||||||
action_logger = ActionLogger(action_log_path)
|
action_logger = ActionLogger(action_log_path)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ OASIS Reddit模拟预设脚本
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -16,7 +17,76 @@ from datetime import datetime
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
# 添加项目路径
|
# 添加项目路径
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
# Directory layout is derived from this file's own location:
# scripts directory -> its parent (backend) -> that directory's parent (project root).
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
# Both inserts target index 0, so _backend_dir ends up ahead of _scripts_dir.
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)

# Load the .env file from the project root (holds LLM_API_KEY and related settings).
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
    load_dotenv(_env_file)
else:
    # Fall back to a .env file in the backend directory; if neither exists
    # nothing is loaded.
    _backend_env = os.path.join(_backend_dir, '.env')
    if os.path.exists(_backend_env):
        load_dotenv(_backend_env)
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class UnicodeFormatter(logging.Formatter):
    """
    Formatter that turns literal ``\\uXXXX`` escape text in the formatted
    record back into the characters it stands for.

    Two escapes that form a UTF-16 surrogate pair (for example an emoji
    written as ``\\ud83d\\ude00``) are combined into the single character they
    encode. An escape that names a lone surrogate is left as text, because a
    lone surrogate cannot be encoded as UTF-8 when the record is written.
    """

    # Matches one literal \uXXXX escape.
    UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
    # Matches a high-surrogate escape directly followed by a low-surrogate escape.
    _SURROGATE_PAIR_PATTERN = re.compile(
        r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
    )

    def format(self, record):
        """Format ``record`` normally, then decode the escape text in the result."""
        result = super().format(record)

        def join_surrogates(match):
            # Standard UTF-16 decoding of a high/low surrogate pair.
            high = int(match.group(1), 16) - 0xD800
            low = int(match.group(2), 16) - 0xDC00
            return chr(0x10000 + (high << 10) + low)

        def replace_unicode(match):
            try:
                code_point = int(match.group(1), 16)
            except (ValueError, OverflowError):
                return match.group(0)
            if 0xD800 <= code_point <= 0xDFFF:
                # Lone surrogate: keep the escape text so the line stays encodable.
                return match.group(0)
            return chr(code_point)

        # Pairs first, so their halves are never decoded individually.
        result = self._SURROGATE_PAIR_PATTERN.sub(join_surrogates, result)
        return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_oasis_logging(log_dir: str):
    """
    Route the OASIS loggers to fixed-name log files inside ``log_dir``.

    Each logger listed below ends up with exactly one UTF-8 ``FileHandler``
    writing to ``<log_dir>/<logger name>.log``. Handlers that were attached
    before are detached and closed, and existing ``*.log`` files in
    ``log_dir`` are removed.

    Args:
        log_dir: Directory for the log files; created if it does not exist.
    """
    os.makedirs(log_dir, exist_ok=True)

    # Fixed (non-timestamped) file name for every logger that is reconfigured.
    loggers_config = {
        "social.agent": os.path.join(log_dir, "social.agent.log"),
        "social.twitter": os.path.join(log_dir, "social.twitter.log"),
        "social.rec": os.path.join(log_dir, "social.rec.log"),
        "oasis.env": os.path.join(log_dir, "oasis.env.log"),
        "table": os.path.join(log_dir, "table.log"),
    }

    # Detach AND close the existing handlers before touching any files.
    # Merely clearing the handler list would leave their files open (a
    # descriptor leak, and on some platforms an open file cannot be deleted).
    for logger_name in loggers_config:
        stale_logger = logging.getLogger(logger_name)
        for handler in list(stale_logger.handlers):
            stale_logger.removeHandler(handler)
            try:
                handler.close()
            except (OSError, ValueError):
                # Best effort: a handler that fails to flush/close must not
                # prevent the simulation from starting.
                pass

    # Remove log files left over from earlier runs (best effort).
    for f in os.listdir(log_dir):
        old_log = os.path.join(log_dir, f)
        if os.path.isfile(old_log) and f.endswith('.log'):
            try:
                os.remove(old_log)
            except OSError:
                pass

    formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")

    for logger_name, log_file in loggers_config.items():
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.DEBUG)
        # New file handler: UTF-8, fixed file name, truncated on open.
        file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        # Do not propagate to ancestor loggers (avoids duplicate output).
        logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from camel.models import ModelFactory
|
from camel.models import ModelFactory
|
||||||
|
|
@ -82,19 +152,32 @@ class RedditSimulationRunner:
|
||||||
"""
|
"""
|
||||||
创建LLM模型
|
创建LLM模型
|
||||||
|
|
||||||
OASIS使用camel-ai的ModelFactory,配置方式:
|
统一使用项目根目录 .env 文件中的配置(优先级最高):
|
||||||
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量
|
- LLM_API_KEY: API密钥
|
||||||
- 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量
|
- LLM_BASE_URL: API基础URL
|
||||||
|
- LLM_MODEL_NAME: 模型名称
|
||||||
"""
|
"""
|
||||||
import os
|
# 优先从 .env 读取配置
|
||||||
|
llm_api_key = os.environ.get("LLM_API_KEY", "")
|
||||||
|
llm_base_url = os.environ.get("LLM_BASE_URL", "")
|
||||||
|
llm_model = os.environ.get("LLM_MODEL_NAME", "")
|
||||||
|
|
||||||
llm_model = self.config.get("llm_model", "gpt-4o-mini")
|
# 如果 .env 中没有,则使用 config 作为备用
|
||||||
llm_base_url = self.config.get("llm_base_url", "")
|
if not llm_model:
|
||||||
|
llm_model = self.config.get("llm_model", "gpt-4o-mini")
|
||||||
|
|
||||||
|
# 设置 camel-ai 所需的环境变量
|
||||||
|
if llm_api_key:
|
||||||
|
os.environ["OPENAI_API_KEY"] = llm_api_key
|
||||||
|
|
||||||
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
|
||||||
|
|
||||||
# 如果配置了base_url,设置环境变量
|
|
||||||
if llm_base_url:
|
if llm_base_url:
|
||||||
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
|
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
|
||||||
|
|
||||||
|
print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")
|
||||||
|
|
||||||
return ModelFactory.create(
|
return ModelFactory.create(
|
||||||
model_platform=ModelPlatformType.OPENAI,
|
model_platform=ModelPlatformType.OPENAI,
|
||||||
model_type=llm_model,
|
model_type=llm_model,
|
||||||
|
|
@ -289,6 +372,10 @@ async def main():
|
||||||
print(f"错误: 配置文件不存在: {args.config}")
|
print(f"错误: 配置文件不存在: {args.config}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 初始化日志配置(使用固定文件名,清理旧日志)
|
||||||
|
simulation_dir = os.path.dirname(args.config) or "."
|
||||||
|
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||||
|
|
||||||
runner = RedditSimulationRunner(args.config)
|
runner = RedditSimulationRunner(args.config)
|
||||||
await runner.run()
|
await runner.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ OASIS Twitter模拟预设脚本
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -16,7 +17,76 @@ from datetime import datetime
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
# 添加项目路径
|
# 添加项目路径
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
# Directory layout is derived from this file's own location:
# scripts directory -> its parent (backend) -> that directory's parent (project root).
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
# Both inserts target index 0, so _backend_dir ends up ahead of _scripts_dir.
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)

# Load the .env file from the project root (holds LLM_API_KEY and related settings).
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
    load_dotenv(_env_file)
else:
    # Fall back to a .env file in the backend directory; if neither exists
    # nothing is loaded.
    _backend_env = os.path.join(_backend_dir, '.env')
    if os.path.exists(_backend_env):
        load_dotenv(_backend_env)
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class UnicodeFormatter(logging.Formatter):
    """
    Formatter that turns literal ``\\uXXXX`` escape text in the formatted
    record back into the characters it stands for.

    Two escapes that form a UTF-16 surrogate pair (for example an emoji
    written as ``\\ud83d\\ude00``) are combined into the single character they
    encode. An escape that names a lone surrogate is left as text, because a
    lone surrogate cannot be encoded as UTF-8 when the record is written.
    """

    # Matches one literal \uXXXX escape.
    UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
    # Matches a high-surrogate escape directly followed by a low-surrogate escape.
    _SURROGATE_PAIR_PATTERN = re.compile(
        r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
    )

    def format(self, record):
        """Format ``record`` normally, then decode the escape text in the result."""
        result = super().format(record)

        def join_surrogates(match):
            # Standard UTF-16 decoding of a high/low surrogate pair.
            high = int(match.group(1), 16) - 0xD800
            low = int(match.group(2), 16) - 0xDC00
            return chr(0x10000 + (high << 10) + low)

        def replace_unicode(match):
            try:
                code_point = int(match.group(1), 16)
            except (ValueError, OverflowError):
                return match.group(0)
            if 0xD800 <= code_point <= 0xDFFF:
                # Lone surrogate: keep the escape text so the line stays encodable.
                return match.group(0)
            return chr(code_point)

        # Pairs first, so their halves are never decoded individually.
        result = self._SURROGATE_PAIR_PATTERN.sub(join_surrogates, result)
        return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_oasis_logging(log_dir: str):
    """
    Route the OASIS loggers to fixed-name log files inside ``log_dir``.

    Each logger listed below ends up with exactly one UTF-8 ``FileHandler``
    writing to ``<log_dir>/<logger name>.log``. Handlers that were attached
    before are detached and closed, and existing ``*.log`` files in
    ``log_dir`` are removed.

    Args:
        log_dir: Directory for the log files; created if it does not exist.
    """
    os.makedirs(log_dir, exist_ok=True)

    # Fixed (non-timestamped) file name for every logger that is reconfigured.
    loggers_config = {
        "social.agent": os.path.join(log_dir, "social.agent.log"),
        "social.twitter": os.path.join(log_dir, "social.twitter.log"),
        "social.rec": os.path.join(log_dir, "social.rec.log"),
        "oasis.env": os.path.join(log_dir, "oasis.env.log"),
        "table": os.path.join(log_dir, "table.log"),
    }

    # Detach AND close the existing handlers before touching any files.
    # Merely clearing the handler list would leave their files open (a
    # descriptor leak, and on some platforms an open file cannot be deleted).
    for logger_name in loggers_config:
        stale_logger = logging.getLogger(logger_name)
        for handler in list(stale_logger.handlers):
            stale_logger.removeHandler(handler)
            try:
                handler.close()
            except (OSError, ValueError):
                # Best effort: a handler that fails to flush/close must not
                # prevent the simulation from starting.
                pass

    # Remove log files left over from earlier runs (best effort).
    for f in os.listdir(log_dir):
        old_log = os.path.join(log_dir, f)
        if os.path.isfile(old_log) and f.endswith('.log'):
            try:
                os.remove(old_log)
            except OSError:
                pass

    formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")

    for logger_name, log_file in loggers_config.items():
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.DEBUG)
        # New file handler: UTF-8, fixed file name, truncated on open.
        file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        # Do not propagate to ancestor loggers (avoids duplicate output).
        logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from camel.models import ModelFactory
|
from camel.models import ModelFactory
|
||||||
|
|
@ -75,21 +145,32 @@ class TwitterSimulationRunner:
|
||||||
"""
|
"""
|
||||||
创建LLM模型
|
创建LLM模型
|
||||||
|
|
||||||
OASIS使用camel-ai的ModelFactory,配置方式:
|
统一使用项目根目录 .env 文件中的配置(优先级最高):
|
||||||
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量
|
- LLM_API_KEY: API密钥
|
||||||
- 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量
|
- LLM_BASE_URL: API基础URL
|
||||||
|
- LLM_MODEL_NAME: 模型名称
|
||||||
配置文件中的 llm_model 对应 model_type
|
|
||||||
"""
|
"""
|
||||||
import os
|
# 优先从 .env 读取配置
|
||||||
|
llm_api_key = os.environ.get("LLM_API_KEY", "")
|
||||||
|
llm_base_url = os.environ.get("LLM_BASE_URL", "")
|
||||||
|
llm_model = os.environ.get("LLM_MODEL_NAME", "")
|
||||||
|
|
||||||
llm_model = self.config.get("llm_model", "gpt-4o-mini")
|
# 如果 .env 中没有,则使用 config 作为备用
|
||||||
llm_base_url = self.config.get("llm_base_url", "")
|
if not llm_model:
|
||||||
|
llm_model = self.config.get("llm_model", "gpt-4o-mini")
|
||||||
|
|
||||||
|
# 设置 camel-ai 所需的环境变量
|
||||||
|
if llm_api_key:
|
||||||
|
os.environ["OPENAI_API_KEY"] = llm_api_key
|
||||||
|
|
||||||
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
|
||||||
|
|
||||||
# 如果配置了base_url,设置环境变量(OASIS通过环境变量读取)
|
|
||||||
if llm_base_url:
|
if llm_base_url:
|
||||||
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
|
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
|
||||||
|
|
||||||
|
print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")
|
||||||
|
|
||||||
return ModelFactory.create(
|
return ModelFactory.create(
|
||||||
model_platform=ModelPlatformType.OPENAI,
|
model_platform=ModelPlatformType.OPENAI,
|
||||||
model_type=llm_model,
|
model_type=llm_model,
|
||||||
|
|
@ -304,6 +385,10 @@ async def main():
|
||||||
print(f"错误: 配置文件不存在: {args.config}")
|
print(f"错误: 配置文件不存在: {args.config}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 初始化日志配置(使用固定文件名,清理旧日志)
|
||||||
|
simulation_dir = os.path.dirname(args.config) or "."
|
||||||
|
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||||
|
|
||||||
runner = TwitterSimulationRunner(args.config)
|
runner = TwitterSimulationRunner(args.config)
|
||||||
await runner.run()
|
await runner.run()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue