From 1042d50306d5d6802570a26fd221d862d21956f5 Mon Sep 17 00:00:00 2001 From: 666ghj <670939375@qq.com> Date: Mon, 8 Dec 2025 15:55:39 +0800 Subject: [PATCH] Implement Interview feature for agent interactions in simulations - Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability. --- backend/README.md | 468 ++++++++++++++++- backend/app/api/simulation.py | 565 +++++++++++++++++++++ backend/app/services/__init__.py | 14 + backend/app/services/simulation_ipc.py | 394 ++++++++++++++ backend/app/services/simulation_runner.py | 364 ++++++++++++- backend/scripts/run_parallel_simulation.py | 556 ++++++++++++++++++-- backend/scripts/run_reddit_simulation.py | 335 +++++++++++- backend/scripts/run_twitter_simulation.py | 337 +++++++++++- 8 files changed, 2963 insertions(+), 70 deletions(-) create mode 100644 backend/app/services/simulation_ipc.py diff --git a/backend/README.md b/backend/README.md index 5eaffbe..3ef6179 100644 --- a/backend/README.md +++ b/backend/README.md @@ -73,6 +73,11 @@ 启动模拟 → 运行OASIS脚本 → 实时监控 → 记录动作 → (可选)更新Zep图谱记忆 → 状态查询 ``` +4. **Interview采访流程**: + ``` + 模拟完成 → 环境进入等待模式 → 发送Interview命令 → Agent回答 → 获取结果 → (可选)关闭环境 + ``` + --- ## 技术栈 @@ -152,6 +157,7 @@ backend/ │ ├── simulation_config_generator.py # 配置生成 │ ├── simulation_manager.py # 模拟管理 │ ├── simulation_runner.py # 模拟运行 + │ ├── simulation_ipc.py # 模拟IPC通信(Interview功能) │ └── zep_graph_memory_updater.py # 图谱记忆动态更新 └── utils/ # 工具类 ├── __init__.py @@ -211,12 +217,43 @@ backend/ 4. 解析动作日志(actions.jsonl) 5. (可选)将Agent活动实时更新到Zep图谱 6. 实时更新运行状态 -7. 支持停止/暂停/恢复 +7. 模拟完成后进入等待命令模式 +8. 支持停止/暂停/恢复 **核心服务**: - `SimulationRunner`: 模拟运行器 - `ZepGraphMemoryUpdater`: 图谱记忆动态更新器 +### 4. Agent采访(Interview)模块 + +**功能**: 在模拟完成后对Agent进行采访 + +**特点**: +- **模拟状态持久化**: 模拟完成后环境不立即关闭,进入等待命令模式 +- **IPC通信机制**: 通过文件系统在Flask后端和模拟脚本之间通信 +- **单个采访**: 对指定Agent提问并获取回答 +- **批量采访**: 同时对多个Agent提不同问题 +- **全局采访**: 使用相同问题采访所有Agent +- **采访历史**: 从数据库读取所有Interview记录 + +**核心服务**: +- `SimulationIPCClient`: IPC客户端(Flask端使用) +- `SimulationIPCServer`: IPC服务器(模拟脚本端使用) + +**工作原理**: +``` +Flask后端 模拟脚本 + │ │ + │ 写入命令文件 │ + │ ─────────────────────────→│ + │ │ 轮询命令目录 + │ │ 执行Interview + │ │ 写入响应文件 + │←───────────────────────── │ + │ 读取响应文件 │ + │ │ +``` + --- ## API接口文档 @@ -630,6 +667,290 @@ backend/ --- +### Interview 采访接口 + +> **注意**: 所有Interview接口的参数都通过请求体(JSON)传递,包括simulation_id。 +> +> **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。 + +#### 1. 采访单个Agent + +**接口**: `POST /api/simulation/interview` + +**请求参数**: +```json +{ + "simulation_id": "sim_xxxx", + "agent_id": 0, + "prompt": "你对这件事有什么看法?", + "platform": "reddit", + "timeout": 60 +} +``` + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| simulation_id | String | 是 | - | 模拟ID | +| agent_id | Integer | 是 | - | Agent ID | +| prompt | String | 是 | - | 采访问题 | +| platform | String | 否 | null | 指定平台(twitter/reddit),不指定则双平台同时采访 | +| timeout | Integer | 否 | 60 | 超时时间(秒) | + +**返回示例(指定单平台)**: +```json +{ + "success": true, + "data": { + "success": true, + "agent_id": 0, + "prompt": "你对这件事有什么看法?", + "result": { + "agent_id": 0, + "response": "我认为这件事反映了...", + "platform": "reddit", + "timestamp": "2025-12-08T10:00:00" + }, + "timestamp": "2025-12-08T10:00:01" + } +} +``` + +**返回示例(不指定platform,双平台模式)**: +```json +{ + "success": true, + "data": { + "success": true, + "agent_id": 0, + "prompt": "你对这件事有什么看法?", + "result": { + "agent_id": 0, + "prompt": "你对这件事有什么看法?", + "platforms": { + "twitter": { + "agent_id": 0, + "response": "从Twitter视角来看...", + "platform": "twitter", + "timestamp": "2025-12-08T10:00:00" + }, + "reddit": { + "agent_id": 0, + "response": "作为Reddit用户,我认为...", + "platform": "reddit", + "timestamp": "2025-12-08T10:00:00" + } + } + }, + "timestamp": "2025-12-08T10:00:01" + } +} +``` + +**注意**: 此功能需要模拟环境处于运行状态(完成模拟循环后进入等待命令模式) + +--- + +#### 2. 批量采访多个Agent + +**接口**: `POST /api/simulation/interview/batch` + +**请求参数**: +```json +{ + "simulation_id": "sim_xxxx", + "interviews": [ + {"agent_id": 0, "prompt": "你对A有什么看法?", "platform": "twitter"}, + {"agent_id": 1, "prompt": "你对B有什么看法?", "platform": "reddit"}, + {"agent_id": 2, "prompt": "你对C有什么看法?"} + ], + "platform": "reddit", + "timeout": 120 +} +``` + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| simulation_id | String | 是 | - | 模拟ID | +| interviews | Array | 是 | - | 采访列表,每项包含agent_id、prompt和可选的platform | +| platform | String | 否 | null | 默认平台(被每项的platform覆盖),不指定则双平台同时采访 | +| timeout | Integer | 否 | 120 | 超时时间(秒) | + +**返回示例**: +```json +{ + "success": true, + "data": { + "success": true, + "interviews_count": 3, + "result": { + "interviews_count": 6, + "results": { + "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"}, + "reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"}, + "twitter_2": {"agent_id": 2, "response": "...", "platform": "twitter"}, + "reddit_2": {"agent_id": 2, "response": "...", "platform": "reddit"} + } + }, + "timestamp": "2025-12-08T10:00:01" + } +} +``` + +--- + +#### 3. 全局采访(采访所有Agent) + +**接口**: `POST /api/simulation/interview/all` + +**请求参数**: +```json +{ + "simulation_id": "sim_xxxx", + "prompt": "你对这件事整体有什么看法?", + "platform": "reddit", + "timeout": 180 +} +``` + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| simulation_id | String | 是 | - | 模拟ID | +| prompt | String | 是 | - | 采访问题(所有Agent使用相同问题) | +| platform | String | 否 | null | 指定平台(twitter/reddit),不指定则双平台同时采访 | +| timeout | Integer | 否 | 180 | 超时时间(秒) | + +**返回示例**: +```json +{ + "success": true, + "data": { + "success": true, + "interviews_count": 50, + "result": { + "interviews_count": 100, + "results": { + "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"}, + "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"}, + "twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"}, + "reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"}, + ... + } + }, + "timestamp": "2025-12-08T10:00:01" + } +} +``` + +--- + +#### 4. 获取Interview历史 + +**接口**: `POST /api/simulation/interview/history` + +**请求参数**: +```json +{ + "simulation_id": "sim_xxxx", + "platform": "reddit", + "agent_id": 0, + "limit": 100 +} +``` + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| simulation_id | String | 是 | - | 模拟ID | +| platform | String | 否 | reddit | 平台类型(reddit/twitter) | +| agent_id | Integer | 否 | - | 过滤Agent ID | +| limit | Integer | 否 | 100 | 返回数量限制 | + +**返回示例**: +```json +{ + "success": true, + "data": { + "count": 10, + "history": [ + { + "agent_id": 0, + "response": "我认为...", + "prompt": "你对这件事有什么看法?", + "timestamp": "2025-12-08T10:00:00", + "platform": "reddit" + }, + ... + ] + } +} +``` + +--- + +#### 5. 获取模拟环境状态 + +**接口**: `POST /api/simulation/env-status` + +**请求参数**: +```json +{ + "simulation_id": "sim_xxxx" +} +``` + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| simulation_id | String | 是 | - | 模拟ID | + +**返回示例**: +```json +{ + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "env_alive": true, + "twitter_available": true, + "reddit_available": true, + "message": "环境正在运行,可以接收Interview命令" + } +} +``` + +--- + +#### 6. 关闭模拟环境 + +**接口**: `POST /api/simulation/close-env` + +**请求参数**: +```json +{ + "simulation_id": "sim_10b494550540", + "timeout": 30 +} +``` + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| simulation_id | String | 是 | - | 模拟ID | +| timeout | Integer | 否 | 30 | 超时时间(秒) | + +**返回示例**: +```json +{ + "success": true, + "data": { + "success": true, + "message": "环境关闭命令已发送", + "result": {"message": "环境即将关闭"}, + "timestamp": "2025-12-08T10:00:01" + } +} +``` + +**注意**: 此接口与 `/stop` 不同: +- `/stop`: 强制终止模拟进程 +- `/close-env`: 优雅地关闭环境,让模拟进程正常退出 + #### 6. 获取运行状态 **接口**: `GET /api/simulation/{simulation_id}/run-status` @@ -1395,6 +1716,92 @@ POST /api/simulation/start --- +### 9. SimulationIPCClient/Server (IPC通信模块) + +**文件**: `app/services/simulation_ipc.py` + +**功能**: 实现Flask后端与模拟脚本之间的进程间通信 + +**核心类**: + +```python +class SimulationIPCClient: + """IPC客户端(Flask端使用)""" + + def send_interview(agent_id: int, prompt: str, timeout: float) -> IPCResponse: + """发送单个Agent采访命令""" + + def send_batch_interview(interviews: List[Dict], timeout: float) -> IPCResponse: + """发送批量采访命令""" + + def send_close_env(timeout: float) -> IPCResponse: + """发送关闭环境命令""" + + def check_env_alive() -> bool: + """检查模拟环境是否存活""" +``` + +```python +class SimulationIPCServer: + """IPC服务器(模拟脚本端使用)""" + + def poll_commands() -> Optional[IPCCommand]: + """轮询获取待处理命令""" + + def send_response(response: IPCResponse): + """发送响应""" +``` + +**命令类型**: + +| 命令类型 | 说明 | +|----------|------| +| interview | 单个Agent采访 | +| batch_interview | 批量采访 | +| close_env | 关闭环境 | + +**文件结构**: + +``` +uploads/simulations/sim_xxx/ +├── ipc_commands/ # 命令文件目录 +│ └── {command_id}.json # 待处理命令 +├── ipc_responses/ # 响应文件目录 +│ └── {command_id}.json # 命令响应 +└── env_status.json # 环境状态文件 +``` + +**使用示例**: + +```python +# Flask端发送Interview命令 +from app.services import SimulationRunner + +# 单个采访 +result = SimulationRunner.interview_agent( + simulation_id="sim_xxx", + agent_id=0, + prompt="你对这件事有什么看法?" +) + +# 批量采访 +result = SimulationRunner.interview_agents_batch( + simulation_id="sim_xxx", + interviews=[ + {"agent_id": 0, "prompt": "问题A"}, + {"agent_id": 1, "prompt": "问题B"} + ] +) + +# 全局采访 +result = SimulationRunner.interview_all_agents( + simulation_id="sim_xxx", + prompt="你认为事件会如何发展?" +) +``` + +--- + ## 工具类 ### 1. FileParser (文件解析器) @@ -1661,7 +2068,35 @@ curl -X POST http://localhost:5001/api/simulation/start \ # Step 8: 实时查询运行状态 curl http://localhost:5001/api/simulation/{sim_xxx}/run-status -# Step 9: 停止模拟 +# Step 9: 检查环境状态(模拟完成后环境会进入等待命令模式) +curl http://localhost:5001/api/simulation/{sim_xxx}/env-status + +# Step 10: 采访单个Agent +curl -X POST http://localhost:5001/api/simulation/{sim_xxx}/interview \ + -H "Content-Type: application/json" \ + -d '{ + "agent_id": 0, + "prompt": "你对这件事有什么看法?" + }' + +# Step 11: 全局采访(采访所有Agent) +curl -X POST http://localhost:5001/api/simulation/{sim_xxx}/interview/all \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "你认为事件的后续发展会如何?" + }' + +# Step 12: 获取Interview历史 +curl http://localhost:5001/api/simulation/{sim_xxx}/interview/history + +# Step 13: 关闭模拟环境(优雅退出) +curl -X POST http://localhost:5001/api/simulation/close-env \ + -H "Content-Type: application/json" \ + -d '{ + "simulation_id": "sim_xxx" + }' + +# 或者强制停止模拟 curl -X POST http://localhost:5001/api/simulation/stop \ -H "Content-Type: application/json" \ -d '{ @@ -1830,6 +2265,31 @@ MIT License --- -**最后更新**: 2025-12-05 -**版本**: v1.1.0 +**最后更新**: 2025-12-08 +**版本**: v1.2.0 + +### 更新日志 + +**v1.2.0 (2025-12-08)**: +- 新增 Interview 采访功能 + - 支持单个Agent采访 + - 支持批量采访多个Agent + - 支持全局采访(所有Agent使用相同问题) + - 支持获取Interview历史记录 +- 新增模拟状态持久化 + - 模拟完成后环境不立即关闭,进入等待命令模式 + - 支持优雅关闭环境命令 +- 新增 IPC 通信机制 + - Flask后端与模拟脚本之间的进程间通信 + - 基于文件系统的命令/响应模式 + +**v1.1.0 (2025-12-05)**: +- 新增图谱记忆动态更新功能 +- 支持 max_rounds 参数限制模拟轮数 + +**v1.0.0**: +- 初始版本发布 +- 支持知识图谱构建 +- 支持Agent人设生成 +- 支持双平台模拟 diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 06f5d40..55f210d 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -1723,3 +1723,568 @@ def get_simulation_comments(simulation_id: str): "error": str(e), "traceback": traceback.format_exc() }), 500 + + +# ============== Interview 采访接口 ============== + +@simulation_bp.route('/interview', methods=['POST']) +def interview_agent(): + """ + 采访单个Agent + + 注意:此功能需要模拟环境处于运行状态(完成模拟循环后进入等待命令模式) + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "agent_id": 0, // 必填,Agent ID + "prompt": "你对这件事有什么看法?", // 必填,采访问题 + "platform": "twitter", // 可选,指定平台(twitter/reddit) + // 不指定时:双平台模拟同时采访两个平台 + "timeout": 60 // 可选,超时时间(秒),默认60 + } + + 返回(不指定platform,双平台模式): + { + "success": true, + "data": { + "agent_id": 0, + "prompt": "你对这件事有什么看法?", + "result": { + "agent_id": 0, + "prompt": "...", + "platforms": { + "twitter": {"agent_id": 0, "response": "...", "platform": "twitter"}, + "reddit": {"agent_id": 0, "response": "...", "platform": "reddit"} + } + }, + "timestamp": "2025-12-08T10:00:01" + } + } + + 返回(指定platform): + { + "success": true, + "data": { + "agent_id": 0, + "prompt": "你对这件事有什么看法?", + "result": { + "agent_id": 0, + "response": "我认为...", + "platform": "twitter", + "timestamp": "2025-12-08T10:00:00" + }, + "timestamp": "2025-12-08T10:00:01" + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + agent_id = data.get('agent_id') + prompt = data.get('prompt') + platform = data.get('platform') # 可选:twitter/reddit/None + timeout = data.get('timeout', 60) + + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + if agent_id is None: + return jsonify({ + "success": False, + "error": "请提供 agent_id" + }), 400 + + if not prompt: + return jsonify({ + "success": False, + "error": "请提供 prompt(采访问题)" + }), 400 + + # 验证platform参数 + if platform and platform not in ("twitter", "reddit"): + return jsonify({ + "success": False, + "error": "platform 参数只能是 'twitter' 或 'reddit'" + }), 400 + + # 检查环境状态 + if not SimulationRunner.check_env_alive(simulation_id): + return jsonify({ + "success": False, + "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" + }), 400 + + result = SimulationRunner.interview_agent( + simulation_id=simulation_id, + agent_id=agent_id, + prompt=prompt, + platform=platform, + timeout=timeout + ) + + return jsonify({ + "success": result.get("success", False), + "data": result + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 400 + + except TimeoutError as e: + return jsonify({ + "success": False, + "error": f"等待Interview响应超时: {str(e)}" + }), 504 + + except Exception as e: + logger.error(f"Interview失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/interview/batch', methods=['POST']) +def interview_agents_batch(): + """ + 批量采访多个Agent + + 注意:此功能需要模拟环境处于运行状态 + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "interviews": [ // 必填,采访列表 + { + "agent_id": 0, + "prompt": "你对A有什么看法?", + "platform": "twitter" // 可选,指定该Agent的采访平台 + }, + { + "agent_id": 1, + "prompt": "你对B有什么看法?" // 不指定platform则使用默认值 + } + ], + "platform": "reddit", // 可选,默认平台(被每项的platform覆盖) + // 不指定时:双平台模拟每个Agent同时采访两个平台 + "timeout": 120 // 可选,超时时间(秒),默认120 + } + + 返回: + { + "success": true, + "data": { + "interviews_count": 2, + "result": { + "interviews_count": 4, + "results": { + "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"}, + "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"}, + "twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"}, + "reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"} + } + }, + "timestamp": "2025-12-08T10:00:01" + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + interviews = data.get('interviews') + platform = data.get('platform') # 可选:twitter/reddit/None + timeout = data.get('timeout', 120) + + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + if not interviews or not isinstance(interviews, list): + return jsonify({ + "success": False, + "error": "请提供 interviews(采访列表)" + }), 400 + + # 验证platform参数 + if platform and platform not in ("twitter", "reddit"): + return jsonify({ + "success": False, + "error": "platform 参数只能是 'twitter' 或 'reddit'" + }), 400 + + # 验证每个采访项 + for i, interview in enumerate(interviews): + if 'agent_id' not in interview: + return jsonify({ + "success": False, + "error": f"采访列表第{i+1}项缺少 agent_id" + }), 400 + if 'prompt' not in interview: + return jsonify({ + "success": False, + "error": f"采访列表第{i+1}项缺少 prompt" + }), 400 + # 验证每项的platform(如果有) + item_platform = interview.get('platform') + if item_platform and item_platform not in ("twitter", "reddit"): + return jsonify({ + "success": False, + "error": f"采访列表第{i+1}项的platform只能是 'twitter' 或 'reddit'" + }), 400 + + # 检查环境状态 + if not SimulationRunner.check_env_alive(simulation_id): + return jsonify({ + "success": False, + "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" + }), 400 + + result = SimulationRunner.interview_agents_batch( + simulation_id=simulation_id, + interviews=interviews, + platform=platform, + timeout=timeout + ) + + return jsonify({ + "success": result.get("success", False), + "data": result + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 400 + + except TimeoutError as e: + return jsonify({ + "success": False, + "error": f"等待批量Interview响应超时: {str(e)}" + }), 504 + + except Exception as e: + logger.error(f"批量Interview失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/interview/all', methods=['POST']) +def interview_all_agents(): + """ + 全局采访 - 使用相同问题采访所有Agent + + 注意:此功能需要模拟环境处于运行状态 + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "prompt": "你对这件事整体有什么看法?", // 必填,采访问题(所有Agent使用相同问题) + "platform": "reddit", // 可选,指定平台(twitter/reddit) + // 不指定时:双平台模拟每个Agent同时采访两个平台 + "timeout": 180 // 可选,超时时间(秒),默认180 + } + + 返回: + { + "success": true, + "data": { + "interviews_count": 50, + "result": { + "interviews_count": 100, + "results": { + "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"}, + "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"}, + ... + } + }, + "timestamp": "2025-12-08T10:00:01" + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + prompt = data.get('prompt') + platform = data.get('platform') # 可选:twitter/reddit/None + timeout = data.get('timeout', 180) + + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + if not prompt: + return jsonify({ + "success": False, + "error": "请提供 prompt(采访问题)" + }), 400 + + # 验证platform参数 + if platform and platform not in ("twitter", "reddit"): + return jsonify({ + "success": False, + "error": "platform 参数只能是 'twitter' 或 'reddit'" + }), 400 + + # 检查环境状态 + if not SimulationRunner.check_env_alive(simulation_id): + return jsonify({ + "success": False, + "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" + }), 400 + + result = SimulationRunner.interview_all_agents( + simulation_id=simulation_id, + prompt=prompt, + platform=platform, + timeout=timeout + ) + + return jsonify({ + "success": result.get("success", False), + "data": result + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 400 + + except TimeoutError as e: + return jsonify({ + "success": False, + "error": f"等待全局Interview响应超时: {str(e)}" + }), 504 + + except Exception as e: + logger.error(f"全局Interview失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/interview/history', methods=['POST']) +def get_interview_history(): + """ + 获取Interview历史记录 + + 从模拟数据库中读取所有Interview记录 + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "platform": "reddit", // 可选,平台类型(reddit/twitter),默认reddit + "agent_id": 0, // 可选,过滤Agent ID + "limit": 100 // 可选,返回数量,默认100 + } + + 返回: + { + "success": true, + "data": { + "count": 10, + "history": [ + { + "agent_id": 0, + "response": "我认为...", + "prompt": "你对这件事有什么看法?", + "timestamp": "2025-12-08T10:00:00", + "platform": "reddit" + }, + ... + ] + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + platform = data.get('platform', 'reddit') + agent_id = data.get('agent_id') + limit = data.get('limit', 100) + + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + history = SimulationRunner.get_interview_history( + simulation_id=simulation_id, + platform=platform, + agent_id=agent_id, + limit=limit + ) + + return jsonify({ + "success": True, + "data": { + "count": len(history), + "history": history + } + }) + + except Exception as e: + logger.error(f"获取Interview历史失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/env-status', methods=['POST']) +def get_env_status(): + """ + 获取模拟环境状态 + + 检查模拟环境是否存活(可以接收Interview命令) + + 请求(JSON): + { + "simulation_id": "sim_xxxx" // 必填,模拟ID + } + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "env_alive": true, + "twitter_available": true, + "reddit_available": true, + "message": "环境正在运行,可以接收Interview命令" + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + env_alive = SimulationRunner.check_env_alive(simulation_id) + + # 获取更详细的状态信息 + env_status = SimulationRunner.get_env_status_detail(simulation_id) + + if env_alive: + message = "环境正在运行,可以接收Interview命令" + else: + message = "环境未运行或已关闭" + + return jsonify({ + "success": True, + "data": { + "simulation_id": simulation_id, + "env_alive": env_alive, + "twitter_available": env_status.get("twitter_available", False), + "reddit_available": env_status.get("reddit_available", False), + "message": message + } + }) + + except Exception as e: + logger.error(f"获取环境状态失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/close-env', methods=['POST']) +def close_simulation_env(): + """ + 关闭模拟环境 + + 向模拟发送关闭环境命令,使其优雅退出等待命令模式。 + + 注意:这不同于 /stop 接口,/stop 会强制终止进程, + 而此接口会让模拟优雅地关闭环境并退出。 + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "timeout": 30 // 可选,超时时间(秒),默认30 + } + + 返回: + { + "success": true, + "data": { + "message": "环境关闭命令已发送", + "result": {...}, + "timestamp": "2025-12-08T10:00:01" + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + timeout = data.get('timeout', 30) + + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + result = SimulationRunner.close_simulation_env( + simulation_id=simulation_id, + timeout=timeout + ) + + # 更新模拟状态 + manager = SimulationManager() + state = manager.get_simulation(simulation_id) + if state: + state.status = SimulationStatus.COMPLETED + manager._save_simulation_state(state) + + return jsonify({ + "success": result.get("success", False), + "data": result + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 400 + + except Exception as e: + logger.error(f"关闭环境失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index b4dda02..8db85d8 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -28,6 +28,14 @@ from .zep_graph_memory_updater import ( ZepGraphMemoryManager, AgentActivity ) +from .simulation_ipc import ( + SimulationIPCClient, + SimulationIPCServer, + IPCCommand, + IPCResponse, + CommandType, + CommandStatus +) __all__ = [ 'OntologyGenerator', @@ -55,5 +63,11 @@ __all__ = [ 'ZepGraphMemoryUpdater', 'ZepGraphMemoryManager', 'AgentActivity', + 'SimulationIPCClient', + 'SimulationIPCServer', + 'IPCCommand', + 'IPCResponse', + 'CommandType', + 'CommandStatus', ] diff --git a/backend/app/services/simulation_ipc.py b/backend/app/services/simulation_ipc.py new file mode 100644 index 0000000..9d70d0b --- /dev/null +++ b/backend/app/services/simulation_ipc.py @@ -0,0 +1,394 @@ +""" +模拟IPC通信模块 +用于Flask后端和模拟脚本之间的进程间通信 + +通过文件系统实现简单的命令/响应模式: +1. Flask写入命令到 commands/ 目录 +2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录 +3. Flask轮询响应目录获取结果 +""" + +import os +import json +import time +import uuid +from typing import Dict, Any, Optional, List +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum + +from ..utils.logger import get_logger + +logger = get_logger('mirofish.simulation_ipc') + + +class CommandType(str, Enum): + """命令类型""" + INTERVIEW = "interview" # 单个Agent采访 + BATCH_INTERVIEW = "batch_interview" # 批量采访 + CLOSE_ENV = "close_env" # 关闭环境 + + +class CommandStatus(str, Enum): + """命令状态""" + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class IPCCommand: + """IPC命令""" + command_id: str + command_type: CommandType + args: Dict[str, Any] + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) + + def to_dict(self) -> Dict[str, Any]: + return { + "command_id": self.command_id, + "command_type": self.command_type.value, + "args": self.args, + "timestamp": self.timestamp + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand': + return cls( + command_id=data["command_id"], + command_type=CommandType(data["command_type"]), + args=data.get("args", {}), + timestamp=data.get("timestamp", datetime.now().isoformat()) + ) + + +@dataclass +class IPCResponse: + """IPC响应""" + command_id: str + status: CommandStatus + result: Optional[Dict[str, Any]] = None + error: Optional[str] = None + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) + + def to_dict(self) -> Dict[str, Any]: + return { + "command_id": self.command_id, + "status": self.status.value, + "result": self.result, + "error": self.error, + "timestamp": self.timestamp + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse': + return cls( + command_id=data["command_id"], + status=CommandStatus(data["status"]), + result=data.get("result"), + error=data.get("error"), + timestamp=data.get("timestamp", datetime.now().isoformat()) + ) + + +class SimulationIPCClient: + """ + 模拟IPC客户端(Flask端使用) + + 用于向模拟进程发送命令并等待响应 + """ + + def __init__(self, simulation_dir: str): + """ + 初始化IPC客户端 + + Args: + simulation_dir: 模拟数据目录 + """ + self.simulation_dir = simulation_dir + self.commands_dir = os.path.join(simulation_dir, "ipc_commands") + self.responses_dir = os.path.join(simulation_dir, "ipc_responses") + + # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) + os.makedirs(self.responses_dir, exist_ok=True) + + def send_command( + self, + command_type: CommandType, + args: Dict[str, Any], + timeout: float = 60.0, + poll_interval: float = 0.5 + ) -> IPCResponse: + """ + 发送命令并等待响应 + + Args: + command_type: 命令类型 + args: 命令参数 + timeout: 超时时间(秒) + poll_interval: 轮询间隔(秒) + + Returns: + IPCResponse + + Raises: + TimeoutError: 等待响应超时 + """ + command_id = str(uuid.uuid4()) + command = IPCCommand( + command_id=command_id, + command_type=command_type, + args=args + ) + + # 写入命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") + with open(command_file, 'w', encoding='utf-8') as f: + json.dump(command.to_dict(), f, ensure_ascii=False, indent=2) + + logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}") + + # 等待响应 + response_file = os.path.join(self.responses_dir, f"{command_id}.json") + start_time = time.time() + + while time.time() - start_time < timeout: + if os.path.exists(response_file): + try: + with open(response_file, 'r', encoding='utf-8') as f: + response_data = json.load(f) + response = IPCResponse.from_dict(response_data) + + # 清理命令和响应文件 + try: + os.remove(command_file) + os.remove(response_file) + except OSError: + pass + + logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}") + return response + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"解析响应失败: {e}") + + time.sleep(poll_interval) + + # 超时 + logger.error(f"等待IPC响应超时: command_id={command_id}") + + # 清理命令文件 + try: + os.remove(command_file) + except OSError: + pass + + raise TimeoutError(f"等待命令响应超时 ({timeout}秒)") + + def send_interview( + self, + agent_id: int, + prompt: str, + platform: str = None, + timeout: float = 60.0 + ) -> IPCResponse: + """ + 发送单个Agent采访命令 + + Args: + agent_id: Agent ID + prompt: 采访问题 + platform: 指定平台(可选) + - "twitter": 只采访Twitter平台 + - "reddit": 只采访Reddit平台 + - None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台 + timeout: 超时时间 + + Returns: + IPCResponse,result字段包含采访结果 + """ + args = { + "agent_id": agent_id, + "prompt": prompt + } + if platform: + args["platform"] = platform + + return self.send_command( + command_type=CommandType.INTERVIEW, + args=args, + timeout=timeout + ) + + def send_batch_interview( + self, + interviews: List[Dict[str, Any]], + platform: str = None, + timeout: float = 120.0 + ) -> IPCResponse: + """ + 发送批量采访命令 + + Args: + interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} + platform: 默认平台(可选,会被每个采访项的platform覆盖) + - "twitter": 默认只采访Twitter平台 + - "reddit": 默认只采访Reddit平台 + - None: 双平台模拟时每个Agent同时采访两个平台 + timeout: 超时时间 + + Returns: + IPCResponse,result字段包含所有采访结果 + """ + args = {"interviews": interviews} + if platform: + args["platform"] = platform + + return self.send_command( + command_type=CommandType.BATCH_INTERVIEW, + args=args, + timeout=timeout + ) + + def send_close_env(self, timeout: float = 30.0) -> IPCResponse: + """ + 发送关闭环境命令 + + Args: + timeout: 超时时间 + + Returns: + IPCResponse + """ + return self.send_command( + command_type=CommandType.CLOSE_ENV, + args={}, + timeout=timeout + ) + + def check_env_alive(self) -> bool: + """ + 检查模拟环境是否存活 + + 通过检查 env_status.json 文件来判断 + """ + status_file = os.path.join(self.simulation_dir, "env_status.json") + if not os.path.exists(status_file): + return False + + try: + with open(status_file, 'r', encoding='utf-8') as f: + status = json.load(f) + return status.get("status") == "alive" + except (json.JSONDecodeError, OSError): + return False + + +class SimulationIPCServer: + """ + 模拟IPC服务器(模拟脚本端使用) + + 轮询命令目录,执行命令并返回响应 + """ + + def __init__(self, simulation_dir: str): + """ + 初始化IPC服务器 + + Args: + simulation_dir: 模拟数据目录 + """ + self.simulation_dir = simulation_dir + self.commands_dir = os.path.join(simulation_dir, "ipc_commands") + self.responses_dir = os.path.join(simulation_dir, "ipc_responses") + + # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) + os.makedirs(self.responses_dir, exist_ok=True) + + # 环境状态 + self._running = False + + def start(self): + """标记服务器为运行状态""" + self._running = True + self._update_env_status("alive") + + def stop(self): + """标记服务器为停止状态""" + self._running = False + self._update_env_status("stopped") + + def _update_env_status(self, status: str): + """更新环境状态文件""" + status_file = os.path.join(self.simulation_dir, "env_status.json") + with open(status_file, 'w', encoding='utf-8') as f: + json.dump({ + "status": status, + "timestamp": datetime.now().isoformat() + }, f, ensure_ascii=False, indent=2) + + def poll_commands(self) -> Optional[IPCCommand]: + """ + 轮询命令目录,返回第一个待处理的命令 + + Returns: + IPCCommand 或 None + """ + if not os.path.exists(self.commands_dir): + return None + + # 按时间排序获取命令文件 + command_files = [] + for filename in os.listdir(self.commands_dir): + if filename.endswith('.json'): + filepath = os.path.join(self.commands_dir, filename) + command_files.append((filepath, os.path.getmtime(filepath))) + + command_files.sort(key=lambda x: x[1]) + + for filepath, _ in command_files: + try: + with open(filepath, 'r', encoding='utf-8') as f: + data = json.load(f) + return IPCCommand.from_dict(data) + except (json.JSONDecodeError, KeyError, OSError) as e: + logger.warning(f"读取命令文件失败: {filepath}, {e}") + continue + + return None + + def send_response(self, response: IPCResponse): + """ + 发送响应 + + Args: + response: IPC响应 + """ + response_file = os.path.join(self.responses_dir, f"{response.command_id}.json") + with open(response_file, 'w', encoding='utf-8') as f: + json.dump(response.to_dict(), f, ensure_ascii=False, indent=2) + + # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{response.command_id}.json") + try: + os.remove(command_file) + except OSError: + pass + + def send_success(self, command_id: str, result: Dict[str, Any]): + """发送成功响应""" + self.send_response(IPCResponse( + command_id=command_id, + status=CommandStatus.COMPLETED, + result=result + )) + + def send_error(self, command_id: str, error: str): + """发送错误响应""" + self.send_response(IPCResponse( + command_id=command_id, + status=CommandStatus.FAILED, + error=error + )) diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 125eede..fd12c79 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -12,7 +12,7 @@ import threading import subprocess import signal import atexit -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Union from dataclasses import dataclass, field from datetime import datetime from enum import Enum @@ -21,6 +21,7 @@ from queue import Queue from ..config import Config from ..utils.logger import get_logger from .zep_graph_memory_updater import ZepGraphMemoryManager +from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse logger = get_logger('mirofish.simulation_runner') @@ -989,4 +990,365 @@ class SimulationRunner: if process.poll() is None: running.append(sim_id) return running + + # ============== Interview 功能 ============== + + @classmethod + def check_env_alive(cls, simulation_id: str) -> bool: + """ + 检查模拟环境是否存活(可以接收Interview命令) + + Args: + simulation_id: 模拟ID + + Returns: + True 表示环境存活,False 表示环境已关闭 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + if not os.path.exists(sim_dir): + return False + + ipc_client = SimulationIPCClient(sim_dir) + return ipc_client.check_env_alive() + + @classmethod + def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]: + """ + 获取模拟环境的详细状态信息 + + Args: + simulation_id: 模拟ID + + Returns: + 状态详情字典,包含 status, twitter_available, reddit_available, timestamp + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + status_file = os.path.join(sim_dir, "env_status.json") + + default_status = { + "status": "stopped", + "twitter_available": False, + "reddit_available": False, + "timestamp": None + } + + if not os.path.exists(status_file): + return default_status + + try: + with open(status_file, 'r', encoding='utf-8') as f: + status = json.load(f) + return { + "status": status.get("status", "stopped"), + "twitter_available": status.get("twitter_available", False), + "reddit_available": status.get("reddit_available", False), + "timestamp": status.get("timestamp") + } + except (json.JSONDecodeError, OSError): + return default_status + + @classmethod + def interview_agent( + cls, + simulation_id: str, + agent_id: int, + prompt: str, + platform: str = None, + timeout: float = 60.0 + ) -> Dict[str, Any]: + """ + 采访单个Agent + + Args: + simulation_id: 模拟ID + agent_id: Agent ID + prompt: 采访问题 + platform: 指定平台(可选) + - "twitter": 只采访Twitter平台 + - "reddit": 只采访Reddit平台 + - None: 双平台模拟时同时采访两个平台,返回整合结果 + timeout: 超时时间(秒) + + Returns: + 采访结果字典 + + Raises: + ValueError: 模拟不存在或环境未运行 + TimeoutError: 等待响应超时 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + if not os.path.exists(sim_dir): + raise ValueError(f"模拟不存在: {simulation_id}") + + ipc_client = SimulationIPCClient(sim_dir) + + if not ipc_client.check_env_alive(): + raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") + + logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}") + + response = ipc_client.send_interview( + agent_id=agent_id, + prompt=prompt, + platform=platform, + timeout=timeout + ) + + if response.status.value == "completed": + return { + "success": True, + "agent_id": agent_id, + "prompt": prompt, + "result": response.result, + "timestamp": response.timestamp + } + else: + return { + "success": False, + "agent_id": agent_id, + "prompt": prompt, + "error": response.error, + "timestamp": response.timestamp + } + + @classmethod + def interview_agents_batch( + cls, + simulation_id: str, + interviews: List[Dict[str, Any]], + platform: str = None, + timeout: float = 120.0 + ) -> Dict[str, Any]: + """ + 批量采访多个Agent + + Args: + simulation_id: 模拟ID + interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} + platform: 默认平台(可选,会被每个采访项的platform覆盖) + - "twitter": 默认只采访Twitter平台 + - "reddit": 默认只采访Reddit平台 + - None: 双平台模拟时每个Agent同时采访两个平台 + timeout: 超时时间(秒) + + Returns: + 批量采访结果字典 + + Raises: + ValueError: 模拟不存在或环境未运行 + TimeoutError: 等待响应超时 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + if not os.path.exists(sim_dir): + raise ValueError(f"模拟不存在: {simulation_id}") + + ipc_client = SimulationIPCClient(sim_dir) + + if not ipc_client.check_env_alive(): + raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") + + logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}") + + response = ipc_client.send_batch_interview( + interviews=interviews, + platform=platform, + timeout=timeout + ) + + if response.status.value == "completed": + return { + "success": True, + "interviews_count": len(interviews), + "result": response.result, + "timestamp": response.timestamp + } + else: + return { + "success": False, + "interviews_count": len(interviews), + "error": response.error, + "timestamp": response.timestamp + } + + @classmethod + def interview_all_agents( + cls, + simulation_id: str, + prompt: str, + platform: str = None, + timeout: float = 180.0 + ) -> Dict[str, Any]: + """ + 采访所有Agent(全局采访) + + 使用相同的问题采访模拟中的所有Agent + + Args: + simulation_id: 模拟ID + prompt: 采访问题(所有Agent使用相同问题) + platform: 指定平台(可选) + - "twitter": 只采访Twitter平台 + - "reddit": 只采访Reddit平台 + - None: 双平台模拟时每个Agent同时采访两个平台 + timeout: 超时时间(秒) + + Returns: + 全局采访结果字典 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + if not os.path.exists(sim_dir): + raise ValueError(f"模拟不存在: {simulation_id}") + + # 从配置文件获取所有Agent信息 + config_path = os.path.join(sim_dir, "simulation_config.json") + if not os.path.exists(config_path): + raise ValueError(f"模拟配置不存在: {simulation_id}") + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + agent_configs = config.get("agent_configs", []) + if not agent_configs: + raise ValueError(f"模拟配置中没有Agent: {simulation_id}") + + # 构建批量采访列表 + interviews = [] + for agent_config in agent_configs: + agent_id = agent_config.get("agent_id") + if agent_id is not None: + interviews.append({ + "agent_id": agent_id, + "prompt": prompt + }) + + logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}") + + return cls.interview_agents_batch( + simulation_id=simulation_id, + interviews=interviews, + platform=platform, + timeout=timeout + ) + + @classmethod + def close_simulation_env( + cls, + simulation_id: str, + timeout: float = 30.0 + ) -> Dict[str, Any]: + """ + 关闭模拟环境(而不是停止模拟进程) + + 向模拟发送关闭环境命令,使其优雅退出等待命令模式 + + Args: + simulation_id: 模拟ID + timeout: 超时时间(秒) + + Returns: + 操作结果字典 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + if not os.path.exists(sim_dir): + raise ValueError(f"模拟不存在: {simulation_id}") + + ipc_client = SimulationIPCClient(sim_dir) + + if not ipc_client.check_env_alive(): + return { + "success": True, + "message": "环境已经关闭" + } + + logger.info(f"发送关闭环境命令: simulation_id={simulation_id}") + + try: + response = ipc_client.send_close_env(timeout=timeout) + + return { + "success": response.status.value == "completed", + "message": "环境关闭命令已发送", + "result": response.result, + "timestamp": response.timestamp + } + except TimeoutError: + # 超时可能是因为环境正在关闭 + return { + "success": True, + "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)" + } + + @classmethod + def get_interview_history( + cls, + simulation_id: str, + platform: str = "reddit", + agent_id: Optional[int] = None, + limit: int = 100 + ) -> List[Dict[str, Any]]: + """ + 获取Interview历史记录(从数据库读取) + + Args: + simulation_id: 模拟ID + platform: 平台类型(reddit/twitter) + agent_id: 过滤Agent ID(可选) + limit: 返回数量限制 + + Returns: + Interview历史记录列表 + """ + import sqlite3 + + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + db_path = os.path.join(sim_dir, f"{platform}_simulation.db") + + if not os.path.exists(db_path): + return [] + + results = [] + + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # 构建查询 + # 注意:ActionType.INTERVIEW.value 应该是字符串形式 + if agent_id is not None: + cursor.execute(""" + SELECT user_id, info, created_at + FROM trace + WHERE action = 'interview' AND user_id = ? + ORDER BY created_at DESC + LIMIT ? + """, (agent_id, limit)) + else: + cursor.execute(""" + SELECT user_id, info, created_at + FROM trace + WHERE action = 'interview' + ORDER BY created_at DESC + LIMIT ? + """, (limit,)) + + for user_id, info_json, created_at in cursor.fetchall(): + try: + info = json.loads(info_json) if info_json else {} + except json.JSONDecodeError: + info = {"raw": info_json} + + results.append({ + "agent_id": user_id, + "response": info.get("response", info), + "prompt": info.get("prompt", ""), + "timestamp": created_at, + "platform": platform + }) + + conn.close() + + except Exception as e: + logger.error(f"读取Interview历史失败: {e}") + + return results diff --git a/backend/scripts/run_parallel_simulation.py b/backend/scripts/run_parallel_simulation.py index 1702656..d00dc17 100644 --- a/backend/scripts/run_parallel_simulation.py +++ b/backend/scripts/run_parallel_simulation.py @@ -2,8 +2,18 @@ OASIS 双平台并行模拟预设脚本 同时运行Twitter和Reddit模拟,读取相同的配置文件 +功能特性: +- 双平台(Twitter + Reddit)并行模拟 +- 完成模拟后不立即关闭环境,进入等待命令模式 +- 支持通过IPC接收Interview命令 +- 支持单个Agent采访和批量采访 +- 支持远程关闭环境命令 + 使用方式: python run_parallel_simulation.py --config simulation_config.json + python run_parallel_simulation.py --config simulation_config.json --no-wait # 完成后立即关闭 + python run_parallel_simulation.py --config simulation_config.json --twitter-only + python run_parallel_simulation.py --config simulation_config.json --reddit-only 日志结构: sim_xxx/ @@ -119,7 +129,7 @@ except ImportError as e: sys.exit(1) -# Twitter可用动作 +# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) TWITTER_ACTIONS = [ ActionType.CREATE_POST, ActionType.LIKE_POST, @@ -129,7 +139,7 @@ TWITTER_ACTIONS = [ ActionType.QUOTE_POST, ] -# Reddit可用动作 +# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) REDDIT_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, @@ -147,6 +157,405 @@ REDDIT_ACTIONS = [ ] +# IPC相关常量 +IPC_COMMANDS_DIR = "ipc_commands" +IPC_RESPONSES_DIR = "ipc_responses" +ENV_STATUS_FILE = "env_status.json" + +class CommandType: + """命令类型常量""" + INTERVIEW = "interview" + BATCH_INTERVIEW = "batch_interview" + CLOSE_ENV = "close_env" + + +class ParallelIPCHandler: + """ + 双平台IPC命令处理器 + + 管理两个平台的环境,处理Interview命令 + """ + + def __init__( + self, + simulation_dir: str, + twitter_env=None, + twitter_agent_graph=None, + reddit_env=None, + reddit_agent_graph=None + ): + self.simulation_dir = simulation_dir + self.twitter_env = twitter_env + self.twitter_agent_graph = twitter_agent_graph + self.reddit_env = reddit_env + self.reddit_agent_graph = reddit_agent_graph + + self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR) + self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) + self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) + + # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) + os.makedirs(self.responses_dir, exist_ok=True) + + def update_status(self, status: str): + """更新环境状态""" + with open(self.status_file, 'w', encoding='utf-8') as f: + json.dump({ + "status": status, + "twitter_available": self.twitter_env is not None, + "reddit_available": self.reddit_env is not None, + "timestamp": datetime.now().isoformat() + }, f, ensure_ascii=False, indent=2) + + def poll_command(self) -> Optional[Dict[str, Any]]: + """轮询获取待处理命令""" + if not os.path.exists(self.commands_dir): + return None + + # 获取命令文件(按时间排序) + command_files = [] + for filename in os.listdir(self.commands_dir): + if filename.endswith('.json'): + filepath = os.path.join(self.commands_dir, filename) + command_files.append((filepath, os.path.getmtime(filepath))) + + command_files.sort(key=lambda x: x[1]) + + for filepath, _ in command_files: + try: + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + continue + + return None + + def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): + """发送响应""" + response = { + "command_id": command_id, + "status": status, + "result": result, + "error": error, + "timestamp": datetime.now().isoformat() + } + + response_file = os.path.join(self.responses_dir, f"{command_id}.json") + with open(response_file, 'w', encoding='utf-8') as f: + json.dump(response, f, ensure_ascii=False, indent=2) + + # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") + try: + os.remove(command_file) + except OSError: + pass + + def _get_env_and_graph(self, platform: str): + """ + 获取指定平台的环境和agent_graph + + Args: + platform: 平台名称 ("twitter" 或 "reddit") + + Returns: + (env, agent_graph, platform_name) 或 (None, None, None) + """ + if platform == "twitter" and self.twitter_env: + return self.twitter_env, self.twitter_agent_graph, "twitter" + elif platform == "reddit" and self.reddit_env: + return self.reddit_env, self.reddit_agent_graph, "reddit" + else: + return None, None, None + + async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]: + """ + 在单个平台上执行Interview + + Returns: + 包含结果的字典,或包含error的字典 + """ + env, agent_graph, actual_platform = self._get_env_and_graph(platform) + + if not env or not agent_graph: + return {"platform": platform, "error": f"{platform}平台不可用"} + + try: + agent = agent_graph.get_agent(agent_id) + interview_action = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + actions = {agent: interview_action} + await env.step(actions) + + result = self._get_interview_result(agent_id, actual_platform) + result["platform"] = actual_platform + return result + + except Exception as e: + return {"platform": platform, "error": str(e)} + + async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool: + """ + 处理单个Agent采访命令 + + Args: + command_id: 命令ID + agent_id: Agent ID + prompt: 采访问题 + platform: 指定平台(可选) + - "twitter": 只采访Twitter平台 + - "reddit": 只采访Reddit平台 + - None/不指定: 同时采访两个平台,返回整合结果 + + Returns: + True 表示成功,False 表示失败 + """ + # 如果指定了平台,只采访该平台 + if platform in ("twitter", "reddit"): + result = await self._interview_single_platform(agent_id, prompt, platform) + + if "error" in result: + self.send_response(command_id, "failed", error=result["error"]) + print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}") + return False + else: + self.send_response(command_id, "completed", result=result) + print(f" Interview完成: agent_id={agent_id}, platform={platform}") + return True + + # 未指定平台:同时采访两个平台 + if not self.twitter_env and not self.reddit_env: + self.send_response(command_id, "failed", error="没有可用的模拟环境") + return False + + results = { + "agent_id": agent_id, + "prompt": prompt, + "platforms": {} + } + success_count = 0 + + # 并行采访两个平台 + tasks = [] + platforms_to_interview = [] + + if self.twitter_env: + tasks.append(self._interview_single_platform(agent_id, prompt, "twitter")) + platforms_to_interview.append("twitter") + + if self.reddit_env: + tasks.append(self._interview_single_platform(agent_id, prompt, "reddit")) + platforms_to_interview.append("reddit") + + # 并行执行 + platform_results = await asyncio.gather(*tasks) + + for platform_name, platform_result in zip(platforms_to_interview, platform_results): + results["platforms"][platform_name] = platform_result + if "error" not in platform_result: + success_count += 1 + + if success_count > 0: + self.send_response(command_id, "completed", result=results) + print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}") + return True + else: + errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()] + self.send_response(command_id, "failed", error="; ".join(errors)) + print(f" Interview失败: agent_id={agent_id}, 所有平台都失败") + return False + + async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool: + """ + 处理批量采访命令 + + Args: + command_id: 命令ID + interviews: [{"agent_id": int, "prompt": str, "platform": str(optional)}, ...] + platform: 默认平台(可被每个interview项覆盖) + - "twitter": 只采访Twitter平台 + - "reddit": 只采访Reddit平台 + - None/不指定: 每个Agent同时采访两个平台 + """ + # 按平台分组 + twitter_interviews = [] + reddit_interviews = [] + both_platforms_interviews = [] # 需要同时采访两个平台的 + + for interview in interviews: + item_platform = interview.get("platform", platform) + if item_platform == "twitter": + twitter_interviews.append(interview) + elif item_platform == "reddit": + reddit_interviews.append(interview) + else: + # 未指定平台:两个平台都采访 + both_platforms_interviews.append(interview) + + # 把 both_platforms_interviews 拆分到两个平台 + if both_platforms_interviews: + if self.twitter_env: + twitter_interviews.extend(both_platforms_interviews) + if self.reddit_env: + reddit_interviews.extend(both_platforms_interviews) + + results = {} + + # 处理Twitter平台的采访 + if twitter_interviews and self.twitter_env: + try: + twitter_actions = {} + for interview in twitter_interviews: + agent_id = interview.get("agent_id") + prompt = interview.get("prompt", "") + try: + agent = self.twitter_agent_graph.get_agent(agent_id) + twitter_actions[agent] = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + except Exception as e: + print(f" 警告: 无法获取Twitter Agent {agent_id}: {e}") + + if twitter_actions: + await self.twitter_env.step(twitter_actions) + + for interview in twitter_interviews: + agent_id = interview.get("agent_id") + result = self._get_interview_result(agent_id, "twitter") + result["platform"] = "twitter" + results[f"twitter_{agent_id}"] = result + except Exception as e: + print(f" Twitter批量Interview失败: {e}") + + # 处理Reddit平台的采访 + if reddit_interviews and self.reddit_env: + try: + reddit_actions = {} + for interview in reddit_interviews: + agent_id = interview.get("agent_id") + prompt = interview.get("prompt", "") + try: + agent = self.reddit_agent_graph.get_agent(agent_id) + reddit_actions[agent] = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + except Exception as e: + print(f" 警告: 无法获取Reddit Agent {agent_id}: {e}") + + if reddit_actions: + await self.reddit_env.step(reddit_actions) + + for interview in reddit_interviews: + agent_id = interview.get("agent_id") + result = self._get_interview_result(agent_id, "reddit") + result["platform"] = "reddit" + results[f"reddit_{agent_id}"] = result + except Exception as e: + print(f" Reddit批量Interview失败: {e}") + + if results: + self.send_response(command_id, "completed", result={ + "interviews_count": len(results), + "results": results + }) + print(f" 批量Interview完成: {len(results)} 个Agent") + return True + else: + self.send_response(command_id, "failed", error="没有成功的采访") + return False + + def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: + """从数据库获取最新的Interview结果""" + db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db") + + result = { + "agent_id": agent_id, + "response": None, + "timestamp": None + } + + if not os.path.exists(db_path): + return result + + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # 查询最新的Interview记录 + cursor.execute(""" + SELECT user_id, info, created_at + FROM trace + WHERE action = ? AND user_id = ? + ORDER BY created_at DESC + LIMIT 1 + """, (ActionType.INTERVIEW.value, agent_id)) + + row = cursor.fetchone() + if row: + user_id, info_json, created_at = row + try: + info = json.loads(info_json) if info_json else {} + result["response"] = info.get("response", info) + result["timestamp"] = created_at + except json.JSONDecodeError: + result["response"] = info_json + + conn.close() + + except Exception as e: + print(f" 读取Interview结果失败: {e}") + + return result + + async def process_commands(self) -> bool: + """ + 处理所有待处理命令 + + Returns: + True 表示继续运行,False 表示应该退出 + """ + command = self.poll_command() + if not command: + return True + + command_id = command.get("command_id") + command_type = command.get("command_type") + args = command.get("args", {}) + + print(f"\n收到IPC命令: {command_type}, id={command_id}") + + if command_type == CommandType.INTERVIEW: + await self.handle_interview( + command_id, + args.get("agent_id", 0), + args.get("prompt", ""), + args.get("platform") + ) + return True + + elif command_type == CommandType.BATCH_INTERVIEW: + await self.handle_batch_interview( + command_id, + args.get("interviews", []), + args.get("platform") + ) + return True + + elif command_type == CommandType.CLOSE_ENV: + print("收到关闭环境命令") + self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) + return False + + else: + self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") + return True + + def load_config(config_path: str) -> Dict[str, Any]: """加载配置文件""" with open(config_path, 'r', encoding='utf-8') as f: @@ -398,13 +807,21 @@ def get_active_agents_for_round( return active_agents +class PlatformSimulation: + """平台模拟结果容器""" + def __init__(self): + self.env = None + self.agent_graph = None + self.total_actions = 0 + + async def run_twitter_simulation( config: Dict[str, Any], simulation_dir: str, action_logger: Optional[PlatformActionLogger] = None, main_logger: Optional[SimulationLogManager] = None, max_rounds: Optional[int] = None -): +) -> PlatformSimulation: """运行Twitter模拟 Args: @@ -413,7 +830,12 @@ async def run_twitter_simulation( action_logger: 动作日志记录器 main_logger: 主日志管理器 max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) + + Returns: + PlatformSimulation: 包含env和agent_graph的结果对象 """ + result = PlatformSimulation() + def log_info(msg): if main_logger: main_logger.info(f"[Twitter] {msg}") @@ -428,9 +850,9 @@ async def run_twitter_simulation( profile_path = os.path.join(simulation_dir, "twitter_profiles.csv") if not os.path.exists(profile_path): log_info(f"错误: Profile文件不存在: {profile_path}") - return + return result - agent_graph = await generate_twitter_agent_graph( + result.agent_graph = await generate_twitter_agent_graph( profile_path=profile_path, model=model, available_actions=TWITTER_ACTIONS, @@ -439,7 +861,7 @@ async def run_twitter_simulation( # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X) agent_names = get_agent_names_from_config(config) # 如果配置中没有某个 agent,则使用 OASIS 的默认名称 - for agent_id, agent in agent_graph.get_agents(): + for agent_id, agent in result.agent_graph.get_agents(): if agent_id not in agent_names: agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') @@ -447,14 +869,14 @@ async def run_twitter_simulation( if os.path.exists(db_path): os.remove(db_path) - env = oasis.make( - agent_graph=agent_graph, + result.env = oasis.make( + agent_graph=result.agent_graph, platform=oasis.DefaultPlatformType.TWITTER, database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) - await env.reset() + await result.env.reset() log_info("环境已启动") if action_logger: @@ -478,7 +900,7 @@ async def run_twitter_simulation( agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: - agent = env.agent_graph.get_agent(agent_id) + agent = result.env.agent_graph.get_agent(agent_id) initial_actions[agent] = ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} @@ -498,7 +920,7 @@ async def run_twitter_simulation( pass if initial_actions: - await env.step(initial_actions) + await result.env.step(initial_actions) log_info(f"已发布 {len(initial_actions)} 条初始帖子") # 记录 round 0 结束 @@ -526,7 +948,7 @@ async def run_twitter_simulation( simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = get_active_agents_for_round( - env, config, simulated_hour, round_num + result.env, config, simulated_hour, round_num ) # 无论是否有活跃agent,都记录round开始 @@ -540,7 +962,7 @@ async def run_twitter_simulation( continue actions = {agent: LLMAction() for _, agent in active_agents} - await env.step(actions) + await result.env.step(actions) # 从数据库获取实际执行的动作并记录 actual_actions, last_rowid = fetch_new_actions_from_db( @@ -567,13 +989,16 @@ async def run_twitter_simulation( progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") - await env.close() + # 注意:不关闭环境,保留给Interview使用 if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) + result.total_actions = total_actions elapsed = (datetime.now() - start_time).total_seconds() - log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}") + log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}") + + return result async def run_reddit_simulation( @@ -582,7 +1007,7 @@ async def run_reddit_simulation( action_logger: Optional[PlatformActionLogger] = None, main_logger: Optional[SimulationLogManager] = None, max_rounds: Optional[int] = None -): +) -> PlatformSimulation: """运行Reddit模拟 Args: @@ -591,7 +1016,12 @@ async def run_reddit_simulation( action_logger: 动作日志记录器 main_logger: 主日志管理器 max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) + + Returns: + PlatformSimulation: 包含env和agent_graph的结果对象 """ + result = PlatformSimulation() + def log_info(msg): if main_logger: main_logger.info(f"[Reddit] {msg}") @@ -605,9 +1035,9 @@ async def run_reddit_simulation( profile_path = os.path.join(simulation_dir, "reddit_profiles.json") if not os.path.exists(profile_path): log_info(f"错误: Profile文件不存在: {profile_path}") - return + return result - agent_graph = await generate_reddit_agent_graph( + result.agent_graph = await generate_reddit_agent_graph( profile_path=profile_path, model=model, available_actions=REDDIT_ACTIONS, @@ -616,7 +1046,7 @@ async def run_reddit_simulation( # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X) agent_names = get_agent_names_from_config(config) # 如果配置中没有某个 agent,则使用 OASIS 的默认名称 - for agent_id, agent in agent_graph.get_agents(): + for agent_id, agent in result.agent_graph.get_agents(): if agent_id not in agent_names: agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') @@ -624,14 +1054,14 @@ async def run_reddit_simulation( if os.path.exists(db_path): os.remove(db_path) - env = oasis.make( - agent_graph=agent_graph, + result.env = oasis.make( + agent_graph=result.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) - await env.reset() + await result.env.reset() log_info("环境已启动") if action_logger: @@ -655,7 +1085,7 @@ async def run_reddit_simulation( agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: - agent = env.agent_graph.get_agent(agent_id) + agent = result.env.agent_graph.get_agent(agent_id) if agent in initial_actions: if not isinstance(initial_actions[agent], list): initial_actions[agent] = [initial_actions[agent]] @@ -683,7 +1113,7 @@ async def run_reddit_simulation( pass if initial_actions: - await env.step(initial_actions) + await result.env.step(initial_actions) log_info(f"已发布 {len(initial_actions)} 条初始帖子") # 记录 round 0 结束 @@ -711,7 +1141,7 @@ async def run_reddit_simulation( simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = get_active_agents_for_round( - env, config, simulated_hour, round_num + result.env, config, simulated_hour, round_num ) # 无论是否有活跃agent,都记录round开始 @@ -725,7 +1155,7 @@ async def run_reddit_simulation( continue actions = {agent: LLMAction() for _, agent in active_agents} - await env.step(actions) + await result.env.step(actions) # 从数据库获取实际执行的动作并记录 actual_actions, last_rowid = fetch_new_actions_from_db( @@ -752,13 +1182,16 @@ async def run_reddit_simulation( progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") - await env.close() + # 注意:不关闭环境,保留给Interview使用 if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) + result.total_actions = total_actions elapsed = (datetime.now() - start_time).total_seconds() - log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}") + log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}") + + return result async def main(): @@ -785,6 +1218,12 @@ async def main(): default=None, help='最大模拟轮数(可选,用于截断过长的模拟)' ) + parser.add_argument( + '--no-wait', + action='store_true', + default=False, + help='模拟完成后立即关闭环境,不进入等待命令模式' + ) args = parser.parse_args() @@ -794,6 +1233,7 @@ async def main(): config = load_config(args.config) simulation_dir = os.path.dirname(args.config) or "." + wait_for_commands = not args.no_wait # 初始化日志配置(禁用 OASIS 日志,清理旧文件) init_logging_for_simulation(simulation_dir) @@ -807,6 +1247,7 @@ async def main(): log_manager.info("OASIS 双平台并行模拟") log_manager.info(f"配置文件: {args.config}") log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}") + log_manager.info(f"等待命令模式: {'启用' if wait_for_commands else '禁用'}") log_manager.info("=" * 60) time_config = config.get("time_config", {}) @@ -832,20 +1273,70 @@ async def main(): start_time = datetime.now() + # 存储两个平台的模拟结果 + twitter_result: Optional[PlatformSimulation] = None + reddit_result: Optional[PlatformSimulation] = None + if args.twitter_only: - await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds) + twitter_result = await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds) elif args.reddit_only: - await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds) + reddit_result = await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds) else: # 并行运行(每个平台使用独立的日志记录器) - await asyncio.gather( + results = await asyncio.gather( run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds), run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds), ) + twitter_result, reddit_result = results total_elapsed = (datetime.now() - start_time).total_seconds() log_manager.info("=" * 60) - log_manager.info(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒") + log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}秒") + + # 是否进入等待命令模式 + if wait_for_commands: + log_manager.info("") + log_manager.info("=" * 60) + log_manager.info("进入等待命令模式 - 环境保持运行") + log_manager.info("支持的命令: interview, batch_interview, close_env") + log_manager.info("=" * 60) + + # 创建IPC处理器 + ipc_handler = ParallelIPCHandler( + simulation_dir=simulation_dir, + twitter_env=twitter_result.env if twitter_result else None, + twitter_agent_graph=twitter_result.agent_graph if twitter_result else None, + reddit_env=reddit_result.env if reddit_result else None, + reddit_agent_graph=reddit_result.agent_graph if reddit_result else None + ) + ipc_handler.update_status("alive") + + # 等待命令循环 + try: + while True: + should_continue = await ipc_handler.process_commands() + if not should_continue: + break + await asyncio.sleep(0.5) # 轮询间隔 + except KeyboardInterrupt: + print("\n收到中断信号") + except Exception as e: + print(f"\n命令处理出错: {e}") + + log_manager.info("\n关闭环境...") + ipc_handler.update_status("stopped") + + # 关闭环境 + if twitter_result and twitter_result.env: + await twitter_result.env.close() + log_manager.info("[Twitter] 环境已关闭") + + if reddit_result and reddit_result.env: + await reddit_result.env.close() + log_manager.info("[Reddit] 环境已关闭") + + log_manager.info("=" * 60) + log_manager.info(f"全部完成!") log_manager.info(f"日志文件:") log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}") log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}") @@ -855,4 +1346,3 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) - diff --git a/backend/scripts/run_reddit_simulation.py b/backend/scripts/run_reddit_simulation.py index 1ac62fe..2fa073e 100644 --- a/backend/scripts/run_reddit_simulation.py +++ b/backend/scripts/run_reddit_simulation.py @@ -2,8 +2,15 @@ OASIS Reddit模拟预设脚本 此脚本读取配置文件中的参数来执行模拟,实现全程自动化 +功能特性: +- 完成模拟后不立即关闭环境,进入等待命令模式 +- 支持通过IPC接收Interview命令 +- 支持单个Agent采访和批量采访 +- 支持远程关闭环境命令 + 使用方式: python run_reddit_simulation.py --config /path/to/simulation_config.json + python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 """ import argparse @@ -13,8 +20,9 @@ import logging import os import random import sys +import sqlite3 from datetime import datetime -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional # 添加项目路径 _scripts_dir = os.path.dirname(os.path.abspath(__file__)) @@ -118,10 +126,261 @@ except ImportError as e: sys.exit(1) +# IPC相关常量 +IPC_COMMANDS_DIR = "ipc_commands" +IPC_RESPONSES_DIR = "ipc_responses" +ENV_STATUS_FILE = "env_status.json" + +class CommandType: + """命令类型常量""" + INTERVIEW = "interview" + BATCH_INTERVIEW = "batch_interview" + CLOSE_ENV = "close_env" + + +class IPCHandler: + """IPC命令处理器""" + + def __init__(self, simulation_dir: str, env, agent_graph): + self.simulation_dir = simulation_dir + self.env = env + self.agent_graph = agent_graph + self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR) + self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) + self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) + self._running = True + + # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) + os.makedirs(self.responses_dir, exist_ok=True) + + def update_status(self, status: str): + """更新环境状态""" + with open(self.status_file, 'w', encoding='utf-8') as f: + json.dump({ + "status": status, + "timestamp": datetime.now().isoformat() + }, f, ensure_ascii=False, indent=2) + + def poll_command(self) -> Optional[Dict[str, Any]]: + """轮询获取待处理命令""" + if not os.path.exists(self.commands_dir): + return None + + # 获取命令文件(按时间排序) + command_files = [] + for filename in os.listdir(self.commands_dir): + if filename.endswith('.json'): + filepath = os.path.join(self.commands_dir, filename) + command_files.append((filepath, os.path.getmtime(filepath))) + + command_files.sort(key=lambda x: x[1]) + + for filepath, _ in command_files: + try: + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + continue + + return None + + def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): + """发送响应""" + response = { + "command_id": command_id, + "status": status, + "result": result, + "error": error, + "timestamp": datetime.now().isoformat() + } + + response_file = os.path.join(self.responses_dir, f"{command_id}.json") + with open(response_file, 'w', encoding='utf-8') as f: + json.dump(response, f, ensure_ascii=False, indent=2) + + # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") + try: + os.remove(command_file) + except OSError: + pass + + async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: + """ + 处理单个Agent采访命令 + + Returns: + True 表示成功,False 表示失败 + """ + try: + # 获取Agent + agent = self.agent_graph.get_agent(agent_id) + + # 创建Interview动作 + interview_action = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + + # 执行Interview + actions = {agent: interview_action} + await self.env.step(actions) + + # 从数据库获取结果 + result = self._get_interview_result(agent_id) + + self.send_response(command_id, "completed", result=result) + print(f" Interview完成: agent_id={agent_id}") + return True + + except Exception as e: + error_msg = str(e) + print(f" Interview失败: agent_id={agent_id}, error={error_msg}") + self.send_response(command_id, "failed", error=error_msg) + return False + + async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: + """ + 处理批量采访命令 + + Args: + interviews: [{"agent_id": int, "prompt": str}, ...] + """ + try: + # 构建动作字典 + actions = {} + agent_prompts = {} # 记录每个agent的prompt + + for interview in interviews: + agent_id = interview.get("agent_id") + prompt = interview.get("prompt", "") + + try: + agent = self.agent_graph.get_agent(agent_id) + actions[agent] = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + agent_prompts[agent_id] = prompt + except Exception as e: + print(f" 警告: 无法获取Agent {agent_id}: {e}") + + if not actions: + self.send_response(command_id, "failed", error="没有有效的Agent") + return False + + # 执行批量Interview + await self.env.step(actions) + + # 获取所有结果 + results = {} + for agent_id in agent_prompts.keys(): + result = self._get_interview_result(agent_id) + results[agent_id] = result + + self.send_response(command_id, "completed", result={ + "interviews_count": len(results), + "results": results + }) + print(f" 批量Interview完成: {len(results)} 个Agent") + return True + + except Exception as e: + error_msg = str(e) + print(f" 批量Interview失败: {error_msg}") + self.send_response(command_id, "failed", error=error_msg) + return False + + def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: + """从数据库获取最新的Interview结果""" + db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") + + result = { + "agent_id": agent_id, + "response": None, + "timestamp": None + } + + if not os.path.exists(db_path): + return result + + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # 查询最新的Interview记录 + cursor.execute(""" + SELECT user_id, info, created_at + FROM trace + WHERE action = ? AND user_id = ? + ORDER BY created_at DESC + LIMIT 1 + """, (ActionType.INTERVIEW.value, agent_id)) + + row = cursor.fetchone() + if row: + user_id, info_json, created_at = row + try: + info = json.loads(info_json) if info_json else {} + result["response"] = info.get("response", info) + result["timestamp"] = created_at + except json.JSONDecodeError: + result["response"] = info_json + + conn.close() + + except Exception as e: + print(f" 读取Interview结果失败: {e}") + + return result + + async def process_commands(self) -> bool: + """ + 处理所有待处理命令 + + Returns: + True 表示继续运行,False 表示应该退出 + """ + command = self.poll_command() + if not command: + return True + + command_id = command.get("command_id") + command_type = command.get("command_type") + args = command.get("args", {}) + + print(f"\n收到IPC命令: {command_type}, id={command_id}") + + if command_type == CommandType.INTERVIEW: + await self.handle_interview( + command_id, + args.get("agent_id", 0), + args.get("prompt", "") + ) + return True + + elif command_type == CommandType.BATCH_INTERVIEW: + await self.handle_batch_interview( + command_id, + args.get("interviews", []) + ) + return True + + elif command_type == CommandType.CLOSE_ENV: + print("收到关闭环境命令") + self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) + return False + + else: + self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") + return True + + class RedditSimulationRunner: """Reddit模拟运行器""" - # Reddit可用动作 + # Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) AVAILABLE_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, @@ -138,16 +397,21 @@ class RedditSimulationRunner: ActionType.MUTE, ] - def __init__(self, config_path: str): + def __init__(self, config_path: str, wait_for_commands: bool = True): """ 初始化模拟运行器 Args: config_path: 配置文件路径 (simulation_config.json) + wait_for_commands: 模拟完成后是否等待命令(默认True) """ self.config_path = config_path self.config = self._load_config() self.simulation_dir = os.path.dirname(config_path) + self.wait_for_commands = wait_for_commands + self.env = None + self.agent_graph = None + self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: """加载配置文件""" @@ -261,6 +525,7 @@ class RedditSimulationRunner: print("OASIS Reddit模拟") print(f"配置文件: {self.config_path}") print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") + print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) time_config = self.config.get("time_config", {}) @@ -292,7 +557,7 @@ class RedditSimulationRunner: print(f"错误: Profile文件不存在: {profile_path}") return - agent_graph = await generate_reddit_agent_graph( + self.agent_graph = await generate_reddit_agent_graph( profile_path=profile_path, model=model, available_actions=self.AVAILABLE_ACTIONS, @@ -304,16 +569,20 @@ class RedditSimulationRunner: print(f"已删除旧数据库: {db_path}") print("创建OASIS环境...") - env = oasis.make( - agent_graph=agent_graph, + self.env = oasis.make( + agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) - await env.reset() + await self.env.reset() print("环境初始化完成\n") + # 初始化IPC处理器 + self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) + self.ipc_handler.update_status("running") + # 执行初始事件 event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) @@ -325,7 +594,7 @@ class RedditSimulationRunner: agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: - agent = env.agent_graph.get_agent(agent_id) + agent = self.env.agent_graph.get_agent(agent_id) if agent in initial_actions: if not isinstance(initial_actions[agent], list): initial_actions[agent] = [initial_actions[agent]] @@ -342,7 +611,7 @@ class RedditSimulationRunner: print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") if initial_actions: - await env.step(initial_actions) + await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") # 主模拟循环 @@ -355,7 +624,7 @@ class RedditSimulationRunner: simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = self._get_active_agents_for_round( - env, simulated_hour, round_num + self.env, simulated_hour, round_num ) if not active_agents: @@ -366,7 +635,7 @@ class RedditSimulationRunner: for _, agent in active_agents } - await env.step(actions) + await self.env.step(actions) if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() @@ -376,12 +645,39 @@ class RedditSimulationRunner: f"- {len(active_agents)} agents active " f"- elapsed: {elapsed:.1f}s") - await env.close() - total_elapsed = (datetime.now() - start_time).total_seconds() - print(f"\n模拟完成!") + print(f"\n模拟循环完成!") print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") + + # 是否进入等待命令模式 + if self.wait_for_commands: + print("\n" + "=" * 60) + print("进入等待命令模式 - 环境保持运行") + print("支持的命令: interview, batch_interview, close_env") + print("=" * 60) + + self.ipc_handler.update_status("alive") + + # 等待命令循环 + try: + while True: + should_continue = await self.ipc_handler.process_commands() + if not should_continue: + break + await asyncio.sleep(0.5) # 轮询间隔 + except KeyboardInterrupt: + print("\n收到中断信号") + except Exception as e: + print(f"\n命令处理出错: {e}") + + print("\n关闭环境...") + + # 关闭环境 + self.ipc_handler.update_status("stopped") + await self.env.close() + + print("环境已关闭") print("=" * 60) @@ -399,6 +695,12 @@ async def main(): default=None, help='最大模拟轮数(可选,用于截断过长的模拟)' ) + parser.add_argument( + '--no-wait', + action='store_true', + default=False, + help='模拟完成后立即关闭环境,不进入等待命令模式' + ) args = parser.parse_args() @@ -410,7 +712,10 @@ async def main(): simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) - runner = RedditSimulationRunner(args.config) + runner = RedditSimulationRunner( + config_path=args.config, + wait_for_commands=not args.no_wait + ) await runner.run(max_rounds=args.max_rounds) diff --git a/backend/scripts/run_twitter_simulation.py b/backend/scripts/run_twitter_simulation.py index 1d4ffda..c2a0f1f 100644 --- a/backend/scripts/run_twitter_simulation.py +++ b/backend/scripts/run_twitter_simulation.py @@ -2,8 +2,15 @@ OASIS Twitter模拟预设脚本 此脚本读取配置文件中的参数来执行模拟,实现全程自动化 +功能特性: +- 完成模拟后不立即关闭环境,进入等待命令模式 +- 支持通过IPC接收Interview命令 +- 支持单个Agent采访和批量采访 +- 支持远程关闭环境命令 + 使用方式: python run_twitter_simulation.py --config /path/to/simulation_config.json + python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 """ import argparse @@ -13,8 +20,9 @@ import logging import os import random import sys +import sqlite3 from datetime import datetime -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional # 添加项目路径 _scripts_dir = os.path.dirname(os.path.abspath(__file__)) @@ -118,10 +126,261 @@ except ImportError as e: sys.exit(1) +# IPC相关常量 +IPC_COMMANDS_DIR = "ipc_commands" +IPC_RESPONSES_DIR = "ipc_responses" +ENV_STATUS_FILE = "env_status.json" + +class CommandType: + """命令类型常量""" + INTERVIEW = "interview" + BATCH_INTERVIEW = "batch_interview" + CLOSE_ENV = "close_env" + + +class IPCHandler: + """IPC命令处理器""" + + def __init__(self, simulation_dir: str, env, agent_graph): + self.simulation_dir = simulation_dir + self.env = env + self.agent_graph = agent_graph + self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR) + self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) + self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) + self._running = True + + # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) + os.makedirs(self.responses_dir, exist_ok=True) + + def update_status(self, status: str): + """更新环境状态""" + with open(self.status_file, 'w', encoding='utf-8') as f: + json.dump({ + "status": status, + "timestamp": datetime.now().isoformat() + }, f, ensure_ascii=False, indent=2) + + def poll_command(self) -> Optional[Dict[str, Any]]: + """轮询获取待处理命令""" + if not os.path.exists(self.commands_dir): + return None + + # 获取命令文件(按时间排序) + command_files = [] + for filename in os.listdir(self.commands_dir): + if filename.endswith('.json'): + filepath = os.path.join(self.commands_dir, filename) + command_files.append((filepath, os.path.getmtime(filepath))) + + command_files.sort(key=lambda x: x[1]) + + for filepath, _ in command_files: + try: + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + continue + + return None + + def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): + """发送响应""" + response = { + "command_id": command_id, + "status": status, + "result": result, + "error": error, + "timestamp": datetime.now().isoformat() + } + + response_file = os.path.join(self.responses_dir, f"{command_id}.json") + with open(response_file, 'w', encoding='utf-8') as f: + json.dump(response, f, ensure_ascii=False, indent=2) + + # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") + try: + os.remove(command_file) + except OSError: + pass + + async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: + """ + 处理单个Agent采访命令 + + Returns: + True 表示成功,False 表示失败 + """ + try: + # 获取Agent + agent = self.agent_graph.get_agent(agent_id) + + # 创建Interview动作 + interview_action = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + + # 执行Interview + actions = {agent: interview_action} + await self.env.step(actions) + + # 从数据库获取结果 + result = self._get_interview_result(agent_id) + + self.send_response(command_id, "completed", result=result) + print(f" Interview完成: agent_id={agent_id}") + return True + + except Exception as e: + error_msg = str(e) + print(f" Interview失败: agent_id={agent_id}, error={error_msg}") + self.send_response(command_id, "failed", error=error_msg) + return False + + async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: + """ + 处理批量采访命令 + + Args: + interviews: [{"agent_id": int, "prompt": str}, ...] + """ + try: + # 构建动作字典 + actions = {} + agent_prompts = {} # 记录每个agent的prompt + + for interview in interviews: + agent_id = interview.get("agent_id") + prompt = interview.get("prompt", "") + + try: + agent = self.agent_graph.get_agent(agent_id) + actions[agent] = ManualAction( + action_type=ActionType.INTERVIEW, + action_args={"prompt": prompt} + ) + agent_prompts[agent_id] = prompt + except Exception as e: + print(f" 警告: 无法获取Agent {agent_id}: {e}") + + if not actions: + self.send_response(command_id, "failed", error="没有有效的Agent") + return False + + # 执行批量Interview + await self.env.step(actions) + + # 获取所有结果 + results = {} + for agent_id in agent_prompts.keys(): + result = self._get_interview_result(agent_id) + results[agent_id] = result + + self.send_response(command_id, "completed", result={ + "interviews_count": len(results), + "results": results + }) + print(f" 批量Interview完成: {len(results)} 个Agent") + return True + + except Exception as e: + error_msg = str(e) + print(f" 批量Interview失败: {error_msg}") + self.send_response(command_id, "failed", error=error_msg) + return False + + def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: + """从数据库获取最新的Interview结果""" + db_path = os.path.join(self.simulation_dir, "twitter_simulation.db") + + result = { + "agent_id": agent_id, + "response": None, + "timestamp": None + } + + if not os.path.exists(db_path): + return result + + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # 查询最新的Interview记录 + cursor.execute(""" + SELECT user_id, info, created_at + FROM trace + WHERE action = ? AND user_id = ? + ORDER BY created_at DESC + LIMIT 1 + """, (ActionType.INTERVIEW.value, agent_id)) + + row = cursor.fetchone() + if row: + user_id, info_json, created_at = row + try: + info = json.loads(info_json) if info_json else {} + result["response"] = info.get("response", info) + result["timestamp"] = created_at + except json.JSONDecodeError: + result["response"] = info_json + + conn.close() + + except Exception as e: + print(f" 读取Interview结果失败: {e}") + + return result + + async def process_commands(self) -> bool: + """ + 处理所有待处理命令 + + Returns: + True 表示继续运行,False 表示应该退出 + """ + command = self.poll_command() + if not command: + return True + + command_id = command.get("command_id") + command_type = command.get("command_type") + args = command.get("args", {}) + + print(f"\n收到IPC命令: {command_type}, id={command_id}") + + if command_type == CommandType.INTERVIEW: + await self.handle_interview( + command_id, + args.get("agent_id", 0), + args.get("prompt", "") + ) + return True + + elif command_type == CommandType.BATCH_INTERVIEW: + await self.handle_batch_interview( + command_id, + args.get("interviews", []) + ) + return True + + elif command_type == CommandType.CLOSE_ENV: + print("收到关闭环境命令") + self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) + return False + + else: + self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") + return True + + class TwitterSimulationRunner: """Twitter模拟运行器""" - # Twitter可用动作 + # Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) AVAILABLE_ACTIONS = [ ActionType.CREATE_POST, ActionType.LIKE_POST, @@ -131,16 +390,21 @@ class TwitterSimulationRunner: ActionType.QUOTE_POST, ] - def __init__(self, config_path: str): + def __init__(self, config_path: str, wait_for_commands: bool = True): """ 初始化模拟运行器 Args: config_path: 配置文件路径 (simulation_config.json) + wait_for_commands: 模拟完成后是否等待命令(默认True) """ self.config_path = config_path self.config = self._load_config() self.simulation_dir = os.path.dirname(config_path) + self.wait_for_commands = wait_for_commands + self.env = None + self.agent_graph = None + self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: """加载配置文件""" @@ -269,6 +533,7 @@ class TwitterSimulationRunner: print("OASIS Twitter模拟") print(f"配置文件: {self.config_path}") print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") + print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) # 加载时间配置 @@ -305,7 +570,7 @@ class TwitterSimulationRunner: print(f"错误: Profile文件不存在: {profile_path}") return - agent_graph = await generate_twitter_agent_graph( + self.agent_graph = await generate_twitter_agent_graph( profile_path=profile_path, model=model, available_actions=self.AVAILABLE_ACTIONS, @@ -319,16 +584,20 @@ class TwitterSimulationRunner: # 创建环境 print("创建OASIS环境...") - env = oasis.make( - agent_graph=agent_graph, + self.env = oasis.make( + agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.TWITTER, database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) - await env.reset() + await self.env.reset() print("环境初始化完成\n") + # 初始化IPC处理器 + self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) + self.ipc_handler.update_status("running") + # 执行初始事件 event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) @@ -340,7 +609,7 @@ class TwitterSimulationRunner: agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: - agent = env.agent_graph.get_agent(agent_id) + agent = self.env.agent_graph.get_agent(agent_id) initial_actions[agent] = ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} @@ -349,7 +618,7 @@ class TwitterSimulationRunner: print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") if initial_actions: - await env.step(initial_actions) + await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") # 主模拟循环 @@ -364,7 +633,7 @@ class TwitterSimulationRunner: # 获取本轮激活的Agent active_agents = self._get_active_agents_for_round( - env, simulated_hour, round_num + self.env, simulated_hour, round_num ) if not active_agents: @@ -377,7 +646,7 @@ class TwitterSimulationRunner: } # 执行动作 - await env.step(actions) + await self.env.step(actions) # 打印进度 if (round_num + 1) % 10 == 0 or round_num == 0: @@ -388,13 +657,39 @@ class TwitterSimulationRunner: f"- {len(active_agents)} agents active " f"- elapsed: {elapsed:.1f}s") - # 关闭环境 - await env.close() - total_elapsed = (datetime.now() - start_time).total_seconds() - print(f"\n模拟完成!") + print(f"\n模拟循环完成!") print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") + + # 是否进入等待命令模式 + if self.wait_for_commands: + print("\n" + "=" * 60) + print("进入等待命令模式 - 环境保持运行") + print("支持的命令: interview, batch_interview, close_env") + print("=" * 60) + + self.ipc_handler.update_status("alive") + + # 等待命令循环 + try: + while True: + should_continue = await self.ipc_handler.process_commands() + if not should_continue: + break + await asyncio.sleep(0.5) # 轮询间隔 + except KeyboardInterrupt: + print("\n收到中断信号") + except Exception as e: + print(f"\n命令处理出错: {e}") + + print("\n关闭环境...") + + # 关闭环境 + self.ipc_handler.update_status("stopped") + await self.env.close() + + print("环境已关闭") print("=" * 60) @@ -412,6 +707,12 @@ async def main(): default=None, help='最大模拟轮数(可选,用于截断过长的模拟)' ) + parser.add_argument( + '--no-wait', + action='store_true', + default=False, + help='模拟完成后立即关闭环境,不进入等待命令模式' + ) args = parser.parse_args() @@ -423,10 +724,12 @@ async def main(): simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) - runner = TwitterSimulationRunner(args.config) + runner = TwitterSimulationRunner( + config_path=args.config, + wait_for_commands=not args.no_wait + ) await runner.run(max_rounds=args.max_rounds) if __name__ == "__main__": asyncio.run(main()) -