diff --git a/backend/README.md b/backend/README.md index 3ef6179..72ece0c 100644 --- a/backend/README.md +++ b/backend/README.md @@ -672,6 +672,12 @@ Flask后端 模拟脚本 > **注意**: 所有Interview接口的参数都通过请求体(JSON)传递,包括simulation_id。 > > **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。 +> +> **Prompt自动优化**: 系统会自动在用户提供的prompt前添加说明前缀,避免Agent调用工具: +> ``` +> 原始prompt: "武汉大学发布撤销处分通告后你有什么看法" +> 优化后: "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:武汉大学发布撤销处分通告后你有什么看法" +> ``` #### 1. 采访单个Agent @@ -860,11 +866,11 @@ Flask后端 模拟脚本 | 参数 | 类型 | 必填 | 默认值 | 说明 | |------|------|------|--------|------| | simulation_id | String | 是 | - | 模拟ID | -| platform | String | 否 | reddit | 平台类型(reddit/twitter) | -| agent_id | Integer | 否 | - | 过滤Agent ID | +| platform | String | 否 | null | 平台类型(reddit/twitter),不指定则返回两个平台的所有历史 | +| agent_id | Integer | 否 | - | 只获取该Agent的采访历史 | | limit | Integer | 否 | 100 | 返回数量限制 | -**返回示例**: +**返回示例(不指定platform,返回双平台历史)**: ```json { "success": true, @@ -875,7 +881,14 @@ Flask后端 模拟脚本 "agent_id": 0, "response": "我认为...", "prompt": "你对这件事有什么看法?", - "timestamp": "2025-12-08T10:00:00", + "timestamp": "2025-12-08T10:00:02", + "platform": "twitter" + }, + { + "agent_id": 0, + "response": "从Reddit角度来看...", + "prompt": "你对这件事有什么看法?", + "timestamp": "2025-12-08T10:00:01", "platform": "reddit" }, ... diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 55f210d..f3d6140 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -19,6 +19,29 @@ from ..models.project import ProjectManager logger = get_logger('mirofish.api.simulation') +# Interview prompt 优化前缀 +# 添加此前缀可以避免Agent调用工具,直接用文本回复 +INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:" + + +def optimize_interview_prompt(prompt: str) -> str: + """ + 优化Interview提问,添加前缀避免Agent调用工具 + + Args: + prompt: 原始提问 + + Returns: + 优化后的提问 + """ + if not prompt: + return prompt + # 避免重复添加前缀 + if prompt.startswith(INTERVIEW_PROMPT_PREFIX): + return prompt + return f"{INTERVIEW_PROMPT_PREFIX}{prompt}" + + # ============== 实体读取接口 ============== @simulation_bp.route('/entities/', methods=['GET']) @@ -1819,14 +1842,17 @@ def interview_agent(): "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" }), 400 + # 优化prompt,添加前缀避免Agent调用工具 + optimized_prompt = optimize_interview_prompt(prompt) + result = SimulationRunner.interview_agent( simulation_id=simulation_id, agent_id=agent_id, - prompt=prompt, + prompt=optimized_prompt, platform=platform, timeout=timeout ) - + return jsonify({ "success": result.get("success", False), "data": result @@ -1951,9 +1977,16 @@ def interview_agents_batch(): "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" }), 400 + # 优化每个采访项的prompt,添加前缀避免Agent调用工具 + optimized_interviews = [] + for interview in interviews: + optimized_interview = interview.copy() + optimized_interview['prompt'] = optimize_interview_prompt(interview.get('prompt', '')) + optimized_interviews.append(optimized_interview) + result = SimulationRunner.interview_agents_batch( simulation_id=simulation_id, - interviews=interviews, + interviews=optimized_interviews, platform=platform, timeout=timeout ) @@ -2051,9 +2084,12 @@ def interview_all_agents(): "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" }), 400 + # 优化prompt,添加前缀避免Agent调用工具 + optimized_prompt = optimize_interview_prompt(prompt) + result = SimulationRunner.interview_all_agents( simulation_id=simulation_id, - prompt=prompt, + prompt=optimized_prompt, platform=platform, timeout=timeout ) @@ -2094,8 +2130,9 @@ def get_interview_history(): 请求(JSON): { "simulation_id": "sim_xxxx", // 必填,模拟ID - "platform": "reddit", // 可选,平台类型(reddit/twitter),默认reddit - "agent_id": 0, // 可选,过滤Agent ID + "platform": "reddit", // 可选,平台类型(reddit/twitter) + // 不指定则返回两个平台的所有历史 + "agent_id": 0, // 可选,只获取该Agent的采访历史 "limit": 100 // 可选,返回数量,默认100 } @@ -2121,7 +2158,7 @@ def get_interview_history(): data = request.get_json() or {} simulation_id = data.get('simulation_id') - platform = data.get('platform', 'reddit') + platform = data.get('platform') # 不指定则返回两个平台的历史 agent_id = data.get('agent_id') limit = data.get('limit', 100) diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index fd12c79..eda1d7a 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -1279,30 +1279,16 @@ class SimulationRunner: } @classmethod - def get_interview_history( + def _get_interview_history_from_db( cls, - simulation_id: str, - platform: str = "reddit", + db_path: str, + platform_name: str, agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: - """ - 获取Interview历史记录(从数据库读取) - - Args: - simulation_id: 模拟ID - platform: 平台类型(reddit/twitter) - agent_id: 过滤Agent ID(可选) - limit: 返回数量限制 - - Returns: - Interview历史记录列表 - """ + """从单个数据库获取Interview历史""" import sqlite3 - sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) - db_path = os.path.join(sim_dir, f"{platform}_simulation.db") - if not os.path.exists(db_path): return [] @@ -1312,8 +1298,6 @@ class SimulationRunner: conn = sqlite3.connect(db_path) cursor = conn.cursor() - # 构建查询 - # 注意:ActionType.INTERVIEW.value 应该是字符串形式 if agent_id is not None: cursor.execute(""" SELECT user_id, info, created_at @@ -1342,13 +1326,66 @@ class SimulationRunner: "response": info.get("response", info), "prompt": info.get("prompt", ""), "timestamp": created_at, - "platform": platform + "platform": platform_name }) conn.close() except Exception as e: - logger.error(f"读取Interview历史失败: {e}") + logger.error(f"读取Interview历史失败 ({platform_name}): {e}") + + return results + + @classmethod + def get_interview_history( + cls, + simulation_id: str, + platform: str = None, + agent_id: Optional[int] = None, + limit: int = 100 + ) -> List[Dict[str, Any]]: + """ + 获取Interview历史记录(从数据库读取) + + Args: + simulation_id: 模拟ID + platform: 平台类型(reddit/twitter/None) + - "reddit": 只获取Reddit平台的历史 + - "twitter": 只获取Twitter平台的历史 + - None: 获取两个平台的所有历史 + agent_id: 指定Agent ID(可选,只获取该Agent的历史) + limit: 每个平台返回数量限制 + + Returns: + Interview历史记录列表 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + + results = [] + + # 确定要查询的平台 + if platform in ("reddit", "twitter"): + platforms = [platform] + else: + # 不指定platform时,查询两个平台 + platforms = ["twitter", "reddit"] + + for p in platforms: + db_path = os.path.join(sim_dir, f"{p}_simulation.db") + platform_results = cls._get_interview_history_from_db( + db_path=db_path, + platform_name=p, + agent_id=agent_id, + limit=limit + ) + results.extend(platform_results) + + # 按时间降序排序 + results.sort(key=lambda x: x.get("timestamp", ""), reverse=True) + + # 如果查询了多个平台,限制总数 + if len(platforms) > 1 and len(results) > limit: + results = results[:limit] return results