diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py
index ac6252a..0722a46 100644
--- a/backend/app/api/simulation.py
+++ b/backend/app/api/simulation.py
@@ -803,6 +803,116 @@ def get_simulation_profiles(simulation_id: str):
         }), 500
 
 
+@simulation_bp.route('/<simulation_id>/profiles/realtime', methods=['GET'])
+def get_simulation_profiles_realtime(simulation_id: str):
+    """
+    实时获取模拟的Agent Profile(用于在生成过程中实时查看进度)
+
+    与 /profiles 接口的区别:
+    - 直接读取文件,不经过 SimulationManager
+    - 适用于生成过程中的实时查看
+    - 返回额外的元数据(如文件修改时间、是否正在生成等)
+
+    Query参数:
+        platform: 平台类型(reddit/twitter,默认reddit)
+
+    返回:
+        {
+            "success": true,
+            "data": {
+                "simulation_id": "sim_xxxx",
+                "platform": "reddit",
+                "count": 15,
+                "total_expected": 93,  // 预期总数(如果有)
+                "is_generating": true,  // 是否正在生成
+                "file_exists": true,
+                "file_modified_at": "2025-12-04T18:20:00",
+                "profiles": [...]
+            }
+        }
+    """
+    import json
+    import csv
+    from datetime import datetime
+
+    try:
+        platform = request.args.get('platform', 'reddit')
+
+        # 获取模拟目录
+        sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
+
+        if not os.path.exists(sim_dir):
+            return jsonify({
+                "success": False,
+                "error": f"模拟不存在: {simulation_id}"
+            }), 404
+
+        # 确定文件路径
+        if platform == "reddit":
+            profiles_file = os.path.join(sim_dir, "reddit_profiles.json")
+        else:
+            profiles_file = os.path.join(sim_dir, "twitter_profiles.csv")
+
+        # 检查文件是否存在
+        file_exists = os.path.exists(profiles_file)
+        profiles = []
+        file_modified_at = None
+
+        if file_exists:
+            # 获取文件修改时间
+            file_stat = os.stat(profiles_file)
+            file_modified_at = datetime.fromtimestamp(file_stat.st_mtime).isoformat()
+
+            try:
+                if platform == "reddit":
+                    with open(profiles_file, 'r', encoding='utf-8') as f:
+                        profiles = json.load(f)
+                else:
+                    with open(profiles_file, 'r', encoding='utf-8') as f:
+                        reader = csv.DictReader(f)
+                        profiles = list(reader)
+            except (json.JSONDecodeError, Exception) as e:
+                logger.warning(f"读取 profiles 文件失败(可能正在写入中): {e}")
+                profiles = []
+
+        # 检查是否正在生成(通过 state.json 判断)
+        is_generating = False
+        total_expected = None
+
+        state_file = os.path.join(sim_dir, "state.json")
+        if os.path.exists(state_file):
+            try:
+                with open(state_file, 'r', encoding='utf-8') as f:
+                    state_data = json.load(f)
+                status = state_data.get("status", "")
+                is_generating = status == "preparing"
+                total_expected = state_data.get("entities_count")
+            except Exception:
+                pass
+
+        return jsonify({
+            "success": True,
+            "data": {
+                "simulation_id": simulation_id,
+                "platform": platform,
+                "count": len(profiles),
+                "total_expected": total_expected,
+                "is_generating": is_generating,
+                "file_exists": file_exists,
+                "file_modified_at": file_modified_at,
+                "profiles": profiles
+            }
+        })
+
+    except Exception as e:
+        logger.error(f"实时获取Profile失败: {str(e)}")
+        return jsonify({
+            "success": False,
+            "error": str(e),
+            "traceback": traceback.format_exc()
+        }), 500
+
+
 @simulation_bp.route('/<simulation_id>/config', methods=['GET'])
 def get_simulation_config(simulation_id: str):
     """
diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py
index 00a5f4a..662ffdc 100644
--- a/backend/app/services/oasis_profile_generator.py
+++ b/backend/app/services/oasis_profile_generator.py
@@ -853,7 +853,9 @@ class OasisProfileGenerator:
         use_llm: bool = True,
         progress_callback: Optional[callable] = None,
         graph_id: Optional[str] = None,
-        parallel_count: int = 5
+        parallel_count: int = 5,
+        realtime_output_path: Optional[str] = None,
+        output_platform: str = "reddit"
     ) -> List[OasisAgentProfile]:
         """
         批量从实体生成Agent Profile(支持并行生成)
@@ -864,6 +866,8 @@ class OasisProfileGenerator:
             progress_callback: 进度回调函数 (current, total, message)
             graph_id: 图谱ID,用于Zep检索获取更丰富上下文
             parallel_count: 并行生成数量,默认5
+            realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次)
+            output_platform: 输出平台格式 ("reddit" 或 "twitter")
 
         Returns:
             Agent Profile列表
@@ -880,6 +884,37 @@ class OasisProfileGenerator:
         completed_count = [0]  # 使用列表以便在闭包中修改
         lock = Lock()
 
+        # 实时写入文件的辅助函数
+        def save_profiles_realtime():
+            """实时保存已生成的 profiles 到文件"""
+            if not realtime_output_path:
+                return
+
+            with lock:
+                # 过滤出已生成的 profiles
+                existing_profiles = [p for p in profiles if p is not None]
+                if not existing_profiles:
+                    return
+
+                try:
+                    if output_platform == "reddit":
+                        # Reddit JSON 格式
+                        profiles_data = [p.to_reddit_format() for p in existing_profiles]
+                        with open(realtime_output_path, 'w', encoding='utf-8') as f:
+                            json.dump(profiles_data, f, ensure_ascii=False, indent=2)
+                    else:
+                        # Twitter CSV 格式
+                        import csv
+                        profiles_data = [p.to_twitter_format() for p in existing_profiles]
+                        if profiles_data:
+                            fieldnames = list(profiles_data[0].keys())
+                            with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f:
+                                writer = csv.DictWriter(f, fieldnames=fieldnames)
+                                writer.writeheader()
+                                writer.writerows(profiles_data)
+                except Exception as e:
+                    logger.warning(f"实时保存 profiles 失败: {e}")
+
         def generate_single_profile(idx: int, entity: EntityNode) -> tuple:
             """生成单个profile的工作函数"""
             entity_type = entity.get_entity_type() or "Entity"
@@ -936,6 +971,9 @@ class OasisProfileGenerator:
                     completed_count[0] += 1
                     current = completed_count[0]
 
+                    # 实时写入文件
+                    save_profiles_realtime()
+
                     if progress_callback:
                         progress_callback(
                             current,
@@ -961,6 +999,8 @@ class OasisProfileGenerator:
                     source_entity_uuid=entity.uuid,
                     source_entity_type=entity_type,
                 )
+                # 实时写入文件(即使是备用人设)
+                save_profiles_realtime()
 
         print(f"\n{'='*60}")
         print(f"人设生成完成!共生成 {len([p for p in profiles if p])} 个Agent")
diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py
index a75f00e..7fd09b2 100644
--- a/backend/app/services/simulation_config_generator.py
+++ b/backend/app/services/simulation_config_generator.py
@@ -292,7 +292,7 @@ class SimulationConfigGenerator:
 
         # ========== 步骤2: 生成事件配置 ==========
         report_progress(2, "生成事件配置和热点话题...")
-        event_config_result = self._generate_event_config(context, simulation_requirement)
+        event_config_result = self._generate_event_config(context, simulation_requirement, entities)
         event_config = self._parse_event_config(event_config_result)
         reasoning_parts.append(f"事件配置: {event_config_result.get('reasoning', '成功')}")
 
@@ -318,6 +318,12 @@ class SimulationConfigGenerator:
 
         reasoning_parts.append(f"Agent配置: 成功生成 {len(all_agent_configs)} 个")
 
+        # ========== 为初始帖子分配发布者 Agent ==========
+        logger.info("为初始帖子分配合适的发布者 Agent...")
+        event_config = self._assign_initial_post_agents(event_config, all_agent_configs)
+        assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None])
+        reasoning_parts.append(f"初始帖子分配: {assigned_count} 个帖子已分配发布者")
+
         # ========== 最后一步: 生成平台配置 ==========
         report_progress(total_steps, "生成平台配置...")
         twitter_config = None
@@ -583,32 +589,63 @@ class SimulationConfigGenerator:
             peak_activity_multiplier=1.5
         )
 
-    def _generate_event_config(self, context: str, simulation_requirement: str) -> Dict[str, Any]:
+    def _generate_event_config(
+        self,
+        context: str,
+        simulation_requirement: str,
+        entities: List[EntityNode]
+    ) -> Dict[str, Any]:
         """生成事件配置"""
+
+        # 获取可用的实体类型列表,供 LLM 参考
+        entity_types_available = list(set(
+            e.get_entity_type() or "Unknown" for e in entities
+        ))
+
+        # 为每种类型列出代表性实体名称
+        type_examples = {}
+        for e in entities:
+            etype = e.get_entity_type() or "Unknown"
+            if etype not in type_examples:
+                type_examples[etype] = []
+            if len(type_examples[etype]) < 3:
+                type_examples[etype].append(e.name)
+
+        type_info = "\n".join([
+            f"- {t}: {', '.join(examples)}"
+            for t, examples in type_examples.items()
+        ])
+
         prompt = f"""基于以下模拟需求,生成事件配置。
 
 模拟需求:
 {simulation_requirement}
 
 {context[:3000]}
 
+## 可用实体类型及示例
+{type_info}
+
 ## 任务
 请生成事件配置JSON:
 - 提取热点话题关键词
 - 描述舆论发展方向
-- 设计初始帖子内容
+- 设计初始帖子内容,**每个帖子必须指定 poster_type(发布者类型)**
+
+**重要**: poster_type 必须从上面的"可用实体类型"中选择,这样初始帖子才能分配给合适的 Agent 发布。
+例如:官方声明应由 Official/University 类型发布,新闻由 MediaOutlet 发布,学生观点由 Student 发布。
 
 返回JSON格式(不要markdown):
 {{
     "hot_topics": ["关键词1", "关键词2", ...],
     "narrative_direction": "<舆论发展方向描述>",
     "initial_posts": [
-        {{"content": "帖子内容", "poster_type": "MediaOutlet"}},
+        {{"content": "帖子内容", "poster_type": "实体类型(必须从可用类型中选择)"}},
         ...
     ],
     "reasoning": "<简要说明>"
 }}"""
 
-        system_prompt = "你是舆论分析专家。返回纯JSON格式。"
+        system_prompt = "你是舆论分析专家。返回纯JSON格式。注意 poster_type 必须精确匹配可用实体类型。"
 
         try:
             return self._call_llm_with_retry(prompt, system_prompt)
@@ -630,6 +667,91 @@ class SimulationConfigGenerator:
             narrative_direction=result.get("narrative_direction", "")
         )
 
+    def _assign_initial_post_agents(
+        self,
+        event_config: EventConfig,
+        agent_configs: List[AgentActivityConfig]
+    ) -> EventConfig:
+        """
+        为初始帖子分配合适的发布者 Agent
+
+        根据每个帖子的 poster_type 匹配最合适的 agent_id
+        """
+        if not event_config.initial_posts:
+            return event_config
+
+        # 按实体类型建立 agent 索引
+        agents_by_type: Dict[str, List[AgentActivityConfig]] = {}
+        for agent in agent_configs:
+            etype = agent.entity_type.lower()
+            if etype not in agents_by_type:
+                agents_by_type[etype] = []
+            agents_by_type[etype].append(agent)
+
+        # 类型映射表(处理 LLM 可能输出的不同格式)
+        type_aliases = {
+            "official": ["official", "university", "governmentagency", "government"],
+            "university": ["university", "official"],
+            "mediaoutlet": ["mediaoutlet", "media"],
+            "student": ["student", "person"],
+            "professor": ["professor", "expert", "teacher"],
+            "alumni": ["alumni", "person"],
+            "organization": ["organization", "ngo", "company", "group"],
+            "person": ["person", "student", "alumni"],
+        }
+
+        # 记录每种类型已使用的 agent 索引,避免重复使用同一个 agent
+        used_indices: Dict[str, int] = {}
+
+        updated_posts = []
+        for post in event_config.initial_posts:
+            poster_type = post.get("poster_type", "").lower()
+            content = post.get("content", "")
+
+            # 尝试找到匹配的 agent
+            matched_agent_id = None
+
+            # 1. 直接匹配
+            if poster_type in agents_by_type:
+                agents = agents_by_type[poster_type]
+                idx = used_indices.get(poster_type, 0) % len(agents)
+                matched_agent_id = agents[idx].agent_id
+                used_indices[poster_type] = idx + 1
+            else:
+                # 2. 使用别名匹配
+                for alias_key, aliases in type_aliases.items():
+                    if poster_type in aliases or alias_key == poster_type:
+                        for alias in aliases:
+                            if alias in agents_by_type:
+                                agents = agents_by_type[alias]
+                                idx = used_indices.get(alias, 0) % len(agents)
+                                matched_agent_id = agents[idx].agent_id
+                                used_indices[alias] = idx + 1
+                                break
+                        if matched_agent_id is not None:
+                            break
+
+            # 3. 如果仍未找到,使用影响力最高的 agent
+            if matched_agent_id is None:
+                logger.warning(f"未找到类型 '{poster_type}' 的匹配 Agent,使用影响力最高的 Agent")
+                if agent_configs:
+                    # 按影响力排序,选择影响力最高的
+                    sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True)
+                    matched_agent_id = sorted_agents[0].agent_id
+                else:
+                    matched_agent_id = 0
+
+            updated_posts.append({
+                "content": content,
+                "poster_type": post.get("poster_type", "Unknown"),
+                "poster_agent_id": matched_agent_id
+            })
+
+            logger.info(f"初始帖子分配: poster_type='{poster_type}' -> agent_id={matched_agent_id}")
+
+        event_config.initial_posts = updated_posts
+        return event_config
+
     def _generate_agent_configs_batch(
         self,
         context: str,
diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py
index d8d1e5c..31dfab6 100644
--- a/backend/app/services/simulation_manager.py
+++ b/backend/app/services/simulation_manager.py
@@ -324,17 +324,30 @@ class SimulationManager:
                     item_name=msg
                 )
 
+            # 设置实时保存的文件路径(优先使用 Reddit JSON 格式)
+            realtime_output_path = None
+            realtime_platform = "reddit"
+            if state.enable_reddit:
+                realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json")
+                realtime_platform = "reddit"
+            elif state.enable_twitter:
+                realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
+                realtime_platform = "twitter"
+
             profiles = generator.generate_profiles_from_entities(
                 entities=filtered.entities,
                 use_llm=use_llm_for_profiles,
                 progress_callback=profile_progress,
                 graph_id=state.graph_id,  # 传入graph_id用于Zep检索
-                parallel_count=parallel_profile_count  # 并行生成数量
+                parallel_count=parallel_profile_count,  # 并行生成数量
+                realtime_output_path=realtime_output_path,  # 实时保存路径
+                output_platform=realtime_platform  # 输出格式
             )
             state.profiles_count = len(profiles)
 
             # 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式)
+            # Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性
             if progress_callback:
                 progress_callback(
                     "generating_profiles",
                     95,