Implement real-time profile retrieval and saving in simulation API
- Added a new endpoint to retrieve real-time agent profiles during simulation, allowing users to monitor progress without going through the SimulationManager.
- Enhanced the profile generation process to support real-time saving of generated profiles to the per-platform file formats (JSON for Reddit, CSV for Twitter).
- Updated the simulation configuration generator to assign appropriate agents to initial posts based on their types, improving the relevance of generated content.
- Improved error handling and logging for better traceability during profile generation and retrieval.
parent 39253b3213
commit 88676e8207
4 changed files with 292 additions and 7 deletions
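For context on how the new realtime endpoint is meant to be consumed, the sketch below polls it while profile generation is running. It is a minimal illustration, not part of this commit; the base URL, blueprint prefix, and polling interval are assumptions.

import time
import requests

# Hypothetical base URL and blueprint prefix; adjust to the actual deployment.
BASE_URL = "http://localhost:5000/api/simulation"

def poll_profiles(simulation_id: str, platform: str = "reddit", interval: float = 3.0):
    """Poll the realtime profiles endpoint until generation finishes."""
    while True:
        resp = requests.get(
            f"{BASE_URL}/{simulation_id}/profiles/realtime",
            params={"platform": platform},
            timeout=10,
        )
        resp.raise_for_status()
        data = resp.json()["data"]
        print(f"{data['count']}/{data.get('total_expected') or '?'} profiles "
              f"(generating={data['is_generating']})")
        if not data["is_generating"]:
            return data["profiles"]
        time.sleep(interval)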
@@ -803,6 +803,116 @@ def get_simulation_profiles(simulation_id: str):
         }), 500


+@simulation_bp.route('/<simulation_id>/profiles/realtime', methods=['GET'])
+def get_simulation_profiles_realtime(simulation_id: str):
+    """
+    Fetch a simulation's agent profiles in real time (for monitoring progress while they are being generated).
+
+    Differences from the /profiles endpoint:
+    - Reads the files directly instead of going through the SimulationManager
+    - Intended for live inspection while generation is still running
+    - Returns extra metadata (file modification time, whether generation is in progress, etc.)
+
+    Query parameters:
+        platform: platform type (reddit/twitter, default reddit)
+
+    Returns:
+        {
+            "success": true,
+            "data": {
+                "simulation_id": "sim_xxxx",
+                "platform": "reddit",
+                "count": 15,
+                "total_expected": 93,      // expected total (if known)
+                "is_generating": true,     // whether generation is in progress
+                "file_exists": true,
+                "file_modified_at": "2025-12-04T18:20:00",
+                "profiles": [...]
+            }
+        }
+    """
+    import json
+    import csv
+    from datetime import datetime
+
+    try:
+        platform = request.args.get('platform', 'reddit')
+
+        # Locate the simulation directory
+        sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
+        if not os.path.exists(sim_dir):
+            return jsonify({
+                "success": False,
+                "error": f"Simulation not found: {simulation_id}"
+            }), 404
+
+        # Determine the profiles file path
+        if platform == "reddit":
+            profiles_file = os.path.join(sim_dir, "reddit_profiles.json")
+        else:
+            profiles_file = os.path.join(sim_dir, "twitter_profiles.csv")
+
+        # Check whether the file exists
+        file_exists = os.path.exists(profiles_file)
+        profiles = []
+        file_modified_at = None
+
+        if file_exists:
+            # Get the file modification time
+            file_stat = os.stat(profiles_file)
+            file_modified_at = datetime.fromtimestamp(file_stat.st_mtime).isoformat()
+
+            try:
+                if platform == "reddit":
+                    with open(profiles_file, 'r', encoding='utf-8') as f:
+                        profiles = json.load(f)
+                else:
+                    with open(profiles_file, 'r', encoding='utf-8') as f:
+                        reader = csv.DictReader(f)
+                        profiles = list(reader)
+            except (json.JSONDecodeError, Exception) as e:
+                logger.warning(f"Failed to read profiles file (it may still be being written): {e}")
+                profiles = []
+
+        # Check whether generation is still in progress (based on state.json)
+        is_generating = False
+        total_expected = None
+
+        state_file = os.path.join(sim_dir, "state.json")
+        if os.path.exists(state_file):
+            try:
+                with open(state_file, 'r', encoding='utf-8') as f:
+                    state_data = json.load(f)
+                status = state_data.get("status", "")
+                is_generating = status == "preparing"
+                total_expected = state_data.get("entities_count")
+            except Exception:
+                pass
+
+        return jsonify({
+            "success": True,
+            "data": {
+                "simulation_id": simulation_id,
+                "platform": platform,
+                "count": len(profiles),
+                "total_expected": total_expected,
+                "is_generating": is_generating,
+                "file_exists": file_exists,
+                "file_modified_at": file_modified_at,
+                "profiles": profiles
+            }
+        })
+
+    except Exception as e:
+        logger.error(f"Realtime profile retrieval failed: {str(e)}")
+        return jsonify({
+            "success": False,
+            "error": str(e),
+            "traceback": traceback.format_exc()
+        }), 500
+
+
 @simulation_bp.route('/<simulation_id>/config', methods=['GET'])
 def get_simulation_config(simulation_id: str):
     """
@@ -853,7 +853,9 @@ class OasisProfileGenerator:
         use_llm: bool = True,
         progress_callback: Optional[callable] = None,
         graph_id: Optional[str] = None,
-        parallel_count: int = 5
+        parallel_count: int = 5,
+        realtime_output_path: Optional[str] = None,
+        output_platform: str = "reddit"
     ) -> List[OasisAgentProfile]:
         """
         Generate agent profiles from entities in batch (supports parallel generation)
@@ -864,6 +866,8 @@ class OasisProfileGenerator:
             progress_callback: progress callback function (current, total, message)
             graph_id: graph ID, used for Zep retrieval to obtain richer context
             parallel_count: number of profiles generated in parallel, default 5
+            realtime_output_path: file path for realtime writes (if provided, the file is rewritten after each profile is generated)
+            output_platform: output platform format ("reddit" or "twitter")

         Returns:
             List of agent profiles
@@ -880,6 +884,37 @@ class OasisProfileGenerator:
         completed_count = [0]  # use a list so it can be mutated inside the closure
         lock = Lock()

+        # Helper for writing profiles to disk in real time
+        def save_profiles_realtime():
+            """Save the profiles generated so far to the output file."""
+            if not realtime_output_path:
+                return
+
+            with lock:
+                # Keep only the profiles that have already been generated
+                existing_profiles = [p for p in profiles if p is not None]
+                if not existing_profiles:
+                    return
+
+                try:
+                    if output_platform == "reddit":
+                        # Reddit JSON format
+                        profiles_data = [p.to_reddit_format() for p in existing_profiles]
+                        with open(realtime_output_path, 'w', encoding='utf-8') as f:
+                            json.dump(profiles_data, f, ensure_ascii=False, indent=2)
+                    else:
+                        # Twitter CSV format
+                        import csv
+                        profiles_data = [p.to_twitter_format() for p in existing_profiles]
+                        if profiles_data:
+                            fieldnames = list(profiles_data[0].keys())
+                            with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f:
+                                writer = csv.DictWriter(f, fieldnames=fieldnames)
+                                writer.writeheader()
+                                writer.writerows(profiles_data)
+                except Exception as e:
+                    logger.warning(f"Realtime profile save failed: {e}")
+
         def generate_single_profile(idx: int, entity: EntityNode) -> tuple:
             """Worker function that generates a single profile"""
             entity_type = entity.get_entity_type() or "Entity"
@@ -936,6 +971,9 @@ class OasisProfileGenerator:
             completed_count[0] += 1
             current = completed_count[0]

+            # Write profiles to disk in real time
+            save_profiles_realtime()
+
             if progress_callback:
                 progress_callback(
                     current,
@@ -961,6 +999,8 @@ class OasisProfileGenerator:
                     source_entity_uuid=entity.uuid,
                     source_entity_type=entity_type,
                 )
+                # Write to disk in real time (even for fallback personas)
+                save_profiles_realtime()

         print(f"\n{'='*60}")
         print(f"Profile generation complete! Generated {len([p for p in profiles if p])} agents")
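The realtime-save pieces above (the locked helper, the per-completion call in the worker, and the call in the fallback branch) combine into a simple write-behind pattern: rewrite the whole output file every time another profile finishes. The sketch below is a minimal standalone illustration of that pattern with placeholder dicts instead of OasisAgentProfile objects; it is not code from this commit.

import json
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

def generate_all(entities, output_path, parallel_count=5):
    """Generate items in parallel and rewrite the output file after each completion."""
    profiles = [None] * len(entities)
    lock = Lock()

    def save_realtime():
        # Rewrite the whole file with whatever has been generated so far.
        with lock:
            done = [p for p in profiles if p is not None]
            if done:
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(done, f, ensure_ascii=False, indent=2)

    def worker(idx, entity):
        profiles[idx] = {"name": str(entity)}  # placeholder for real profile generation
        save_realtime()

    with ThreadPoolExecutor(max_workers=parallel_count) as pool:
        for idx, entity in enumerate(entities):
            pool.submit(worker, idx, entity)

    return [p for p in profiles if p is not None]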
@@ -292,7 +292,7 @@ class SimulationConfigGenerator:

         # ========== Step 2: generate the event configuration ==========
         report_progress(2, "Generating event configuration and hot topics...")
-        event_config_result = self._generate_event_config(context, simulation_requirement)
+        event_config_result = self._generate_event_config(context, simulation_requirement, entities)
         event_config = self._parse_event_config(event_config_result)
         reasoning_parts.append(f"Event config: {event_config_result.get('reasoning', 'success')}")

@@ -318,6 +318,12 @@ class SimulationConfigGenerator:

         reasoning_parts.append(f"Agent configs: successfully generated {len(all_agent_configs)}")

+        # ========== Assign poster agents to the initial posts ==========
+        logger.info("Assigning suitable poster agents to the initial posts...")
+        event_config = self._assign_initial_post_agents(event_config, all_agent_configs)
+        assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None])
+        reasoning_parts.append(f"Initial post assignment: {assigned_count} posts have a poster assigned")
+
         # ========== Final step: generate the platform configuration ==========
         report_progress(total_steps, "Generating platform configuration...")
         twitter_config = None
@@ -583,32 +589,63 @@ class SimulationConfigGenerator:
             peak_activity_multiplier=1.5
         )

-    def _generate_event_config(self, context: str, simulation_requirement: str) -> Dict[str, Any]:
+    def _generate_event_config(
+        self,
+        context: str,
+        simulation_requirement: str,
+        entities: List[EntityNode]
+    ) -> Dict[str, Any]:
         """Generate the event configuration"""
+
+        # Collect the available entity types for the LLM to reference
+        entity_types_available = list(set(
+            e.get_entity_type() or "Unknown" for e in entities
+        ))
+
+        # List a few representative entity names for each type
+        type_examples = {}
+        for e in entities:
+            etype = e.get_entity_type() or "Unknown"
+            if etype not in type_examples:
+                type_examples[etype] = []
+            if len(type_examples[etype]) < 3:
+                type_examples[etype].append(e.name)
+
+        type_info = "\n".join([
+            f"- {t}: {', '.join(examples)}"
+            for t, examples in type_examples.items()
+        ])
+
         prompt = f"""Based on the following simulation requirement, generate the event configuration.

 Simulation requirement: {simulation_requirement}

 {context[:3000]}

+## Available entity types and examples
+{type_info}
+
 ## Task
 Generate the event configuration JSON:
 - Extract hot-topic keywords
 - Describe the direction of public opinion
-- Design the content of the initial posts
+- Design the content of the initial posts, **each post must specify a poster_type (the type of its poster)**
+
+**Important**: poster_type must be chosen from the "available entity types" above so that each initial post can be assigned to a suitable agent for publishing.
+For example: official statements should be posted by Official/University types, news by MediaOutlet, and student opinions by Student.

 Return JSON (no markdown):
 {{
     "hot_topics": ["keyword1", "keyword2", ...],
     "narrative_direction": "<description of how public opinion develops>",
     "initial_posts": [
-        {{"content": "post content", "poster_type": "MediaOutlet"}},
+        {{"content": "post content", "poster_type": "entity type (must be chosen from the available types)"}},
         ...
     ],
     "reasoning": "<brief explanation>"
 }}"""

-        system_prompt = "You are a public opinion analysis expert. Return pure JSON."
+        system_prompt = "You are a public opinion analysis expert. Return pure JSON. Note that poster_type must exactly match one of the available entity types."

         try:
             return self._call_llm_with_retry(prompt, system_prompt)
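For reference, a response that satisfies this prompt would look roughly like the following hypothetical dict (values invented for illustration; the poster_type values must match the entity types advertised in the prompt):

# Hypothetical LLM response for _generate_event_config, shown as a Python dict.
example_event_config_result = {
    "hot_topics": ["tuition reform", "student protest"],
    "narrative_direction": "Initial criticism of the policy, gradually balanced by official clarifications.",
    "initial_posts": [
        {"content": "The university announces a revised tuition policy.", "poster_type": "University"},
        {"content": "Breaking: students plan a campus sit-in.", "poster_type": "MediaOutlet"},
        {"content": "As a student, I feel this change is unfair.", "poster_type": "Student"},
    ],
    "reasoning": "Topics and posts derived from the supplied context.",
}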
@@ -630,6 +667,91 @@ class SimulationConfigGenerator:
             narrative_direction=result.get("narrative_direction", "")
         )

+    def _assign_initial_post_agents(
+        self,
+        event_config: EventConfig,
+        agent_configs: List[AgentActivityConfig]
+    ) -> EventConfig:
+        """
+        Assign a suitable poster agent to each initial post
+
+        Matches the most appropriate agent_id based on each post's poster_type
+        """
+        if not event_config.initial_posts:
+            return event_config
+
+        # Index the agents by entity type
+        agents_by_type: Dict[str, List[AgentActivityConfig]] = {}
+        for agent in agent_configs:
+            etype = agent.entity_type.lower()
+            if etype not in agents_by_type:
+                agents_by_type[etype] = []
+            agents_by_type[etype].append(agent)
+
+        # Type alias table (handles the different spellings the LLM may output)
+        type_aliases = {
+            "official": ["official", "university", "governmentagency", "government"],
+            "university": ["university", "official"],
+            "mediaoutlet": ["mediaoutlet", "media"],
+            "student": ["student", "person"],
+            "professor": ["professor", "expert", "teacher"],
+            "alumni": ["alumni", "person"],
+            "organization": ["organization", "ngo", "company", "group"],
+            "person": ["person", "student", "alumni"],
+        }
+
+        # Track the next agent index used per type, so the same agent is not reused
+        used_indices: Dict[str, int] = {}
+
+        updated_posts = []
+        for post in event_config.initial_posts:
+            poster_type = post.get("poster_type", "").lower()
+            content = post.get("content", "")
+
+            # Try to find a matching agent
+            matched_agent_id = None
+
+            # 1. Direct match
+            if poster_type in agents_by_type:
+                agents = agents_by_type[poster_type]
+                idx = used_indices.get(poster_type, 0) % len(agents)
+                matched_agent_id = agents[idx].agent_id
+                used_indices[poster_type] = idx + 1
+            else:
+                # 2. Match via aliases
+                for alias_key, aliases in type_aliases.items():
+                    if poster_type in aliases or alias_key == poster_type:
+                        for alias in aliases:
+                            if alias in agents_by_type:
+                                agents = agents_by_type[alias]
+                                idx = used_indices.get(alias, 0) % len(agents)
+                                matched_agent_id = agents[idx].agent_id
+                                used_indices[alias] = idx + 1
+                                break
+                        if matched_agent_id is not None:
+                            break
+
+            # 3. If still not found, fall back to the most influential agent
+            if matched_agent_id is None:
+                logger.warning(f"No agent matches type '{poster_type}', falling back to the most influential agent")
+                if agent_configs:
+                    # Sort by influence weight and pick the highest
+                    sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True)
+                    matched_agent_id = sorted_agents[0].agent_id
+                else:
+                    matched_agent_id = 0
+
+            updated_posts.append({
+                "content": content,
+                "poster_type": post.get("poster_type", "Unknown"),
+                "poster_agent_id": matched_agent_id
+            })
+
+            logger.info(f"Initial post assignment: poster_type='{poster_type}' -> agent_id={matched_agent_id}")
+
+        event_config.initial_posts = updated_posts
+        return event_config
+
     def _generate_agent_configs_batch(
         self,
         context: str,
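To make the matching strategy above concrete, here is a small worked example. It uses plain dicts in place of AgentActivityConfig objects and invented names, so it only illustrates the direct-match, alias, round-robin, and fallback behaviour; it is not code from this commit.

# Simplified stand-in data; the real method receives EventConfig / AgentActivityConfig objects.
agents = [
    {"agent_id": 0, "entity_type": "University", "influence_weight": 0.9},
    {"agent_id": 1, "entity_type": "MediaOutlet", "influence_weight": 0.7},
    {"agent_id": 2, "entity_type": "Student", "influence_weight": 0.3},
    {"agent_id": 3, "entity_type": "Student", "influence_weight": 0.2},
]
posts = [
    {"content": "Official statement", "poster_type": "Official"},     # alias match -> University, agent_id 0
    {"content": "Breaking news", "poster_type": "MediaOutlet"},       # direct match -> agent_id 1
    {"content": "Student take #1", "poster_type": "Student"},         # round-robin -> agent_id 2
    {"content": "Student take #2", "poster_type": "Student"},         # round-robin -> agent_id 3
    {"content": "Unattributed rumor", "poster_type": "Ghost"},        # no match -> most influential, agent_id 0
]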
@@ -324,17 +324,30 @@ class SimulationManager:
                 item_name=msg
             )

+        # Set the realtime output file path (Reddit JSON format takes priority)
+        realtime_output_path = None
+        realtime_platform = "reddit"
+        if state.enable_reddit:
+            realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json")
+            realtime_platform = "reddit"
+        elif state.enable_twitter:
+            realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
+            realtime_platform = "twitter"
+
         profiles = generator.generate_profiles_from_entities(
             entities=filtered.entities,
             use_llm=use_llm_for_profiles,
             progress_callback=profile_progress,
             graph_id=state.graph_id,  # pass graph_id for Zep retrieval
-            parallel_count=parallel_profile_count  # number of profiles generated in parallel
+            parallel_count=parallel_profile_count,  # number of profiles generated in parallel
+            realtime_output_path=realtime_output_path,  # realtime save path
+            output_platform=realtime_platform  # output format
         )

         state.profiles_count = len(profiles)

         # Save the profile files (note: Twitter uses CSV, Reddit uses JSON)
+        # Reddit profiles were already saved in real time during generation; save again here to ensure completeness
         if progress_callback:
             progress_callback(
                 "generating_profiles", 95,