Implement Interview feature for agent interactions in simulations
- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
parent
29bff9ca27
commit
1042d50306
8 changed files with 2963 additions and 70 deletions
|
|
@ -73,6 +73,11 @@
|
|||
启动模拟 → 运行OASIS脚本 → 实时监控 → 记录动作 → (可选)更新Zep图谱记忆 → 状态查询
|
||||
```
|
||||
|
||||
4. **Interview采访流程**:
|
||||
```
|
||||
模拟完成 → 环境进入等待模式 → 发送Interview命令 → Agent回答 → 获取结果 → (可选)关闭环境
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 技术栈
|
||||
|
|
@ -152,6 +157,7 @@ backend/
|
|||
│ ├── simulation_config_generator.py # 配置生成
|
||||
│ ├── simulation_manager.py # 模拟管理
|
||||
│ ├── simulation_runner.py # 模拟运行
|
||||
│ ├── simulation_ipc.py # 模拟IPC通信(Interview功能)
|
||||
│ └── zep_graph_memory_updater.py # 图谱记忆动态更新
|
||||
└── utils/ # 工具类
|
||||
├── __init__.py
|
||||
|
|
@ -211,12 +217,43 @@ backend/
|
|||
4. 解析动作日志(actions.jsonl)
|
||||
5. (可选)将Agent活动实时更新到Zep图谱
|
||||
6. 实时更新运行状态
|
||||
7. 支持停止/暂停/恢复
|
||||
7. 模拟完成后进入等待命令模式
|
||||
8. 支持停止/暂停/恢复
|
||||
|
||||
**核心服务**:
|
||||
- `SimulationRunner`: 模拟运行器
|
||||
- `ZepGraphMemoryUpdater`: 图谱记忆动态更新器
|
||||
|
||||
### 4. Agent采访(Interview)模块
|
||||
|
||||
**功能**: 在模拟完成后对Agent进行采访
|
||||
|
||||
**特点**:
|
||||
- **模拟状态持久化**: 模拟完成后环境不立即关闭,进入等待命令模式
|
||||
- **IPC通信机制**: 通过文件系统在Flask后端和模拟脚本之间通信
|
||||
- **单个采访**: 对指定Agent提问并获取回答
|
||||
- **批量采访**: 同时对多个Agent提不同问题
|
||||
- **全局采访**: 使用相同问题采访所有Agent
|
||||
- **采访历史**: 从数据库读取所有Interview记录
|
||||
|
||||
**核心服务**:
|
||||
- `SimulationIPCClient`: IPC客户端(Flask端使用)
|
||||
- `SimulationIPCServer`: IPC服务器(模拟脚本端使用)
|
||||
|
||||
**工作原理**:
|
||||
```
|
||||
Flask后端 模拟脚本
|
||||
│ │
|
||||
│ 写入命令文件 │
|
||||
│ ─────────────────────────→│
|
||||
│ │ 轮询命令目录
|
||||
│ │ 执行Interview
|
||||
│ │ 写入响应文件
|
||||
│←───────────────────────── │
|
||||
│ 读取响应文件 │
|
||||
│ │
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API接口文档
|
||||
|
|
@ -630,6 +667,290 @@ backend/
|
|||
|
||||
---
|
||||
|
||||
### Interview 采访接口
|
||||
|
||||
> **注意**: 所有Interview接口的参数都通过请求体(JSON)传递,包括simulation_id。
|
||||
>
|
||||
> **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。
|
||||
|
||||
#### 1. 采访单个Agent
|
||||
|
||||
**接口**: `POST /api/simulation/interview`
|
||||
|
||||
**请求参数**:
|
||||
```json
|
||||
{
|
||||
"simulation_id": "sim_xxxx",
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"platform": "reddit",
|
||||
"timeout": 60
|
||||
}
|
||||
```
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
| agent_id | Integer | 是 | - | Agent ID |
|
||||
| prompt | String | 是 | - | 采访问题 |
|
||||
| platform | String | 否 | null | 指定平台(twitter/reddit),不指定则双平台同时采访 |
|
||||
| timeout | Integer | 否 | 60 | 超时时间(秒) |
|
||||
|
||||
**返回示例(指定单平台)**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"success": true,
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"result": {
|
||||
"agent_id": 0,
|
||||
"response": "我认为这件事反映了...",
|
||||
"platform": "reddit",
|
||||
"timestamp": "2025-12-08T10:00:00"
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**返回示例(不指定platform,双平台模式)**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"success": true,
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"result": {
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"platforms": {
|
||||
"twitter": {
|
||||
"agent_id": 0,
|
||||
"response": "从Twitter视角来看...",
|
||||
"platform": "twitter",
|
||||
"timestamp": "2025-12-08T10:00:00"
|
||||
},
|
||||
"reddit": {
|
||||
"agent_id": 0,
|
||||
"response": "作为Reddit用户,我认为...",
|
||||
"platform": "reddit",
|
||||
"timestamp": "2025-12-08T10:00:00"
|
||||
}
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**注意**: 此功能需要模拟环境处于运行状态(完成模拟循环后进入等待命令模式)
|
||||
|
||||
---
|
||||
|
||||
#### 2. 批量采访多个Agent
|
||||
|
||||
**接口**: `POST /api/simulation/interview/batch`
|
||||
|
||||
**请求参数**:
|
||||
```json
|
||||
{
|
||||
"simulation_id": "sim_xxxx",
|
||||
"interviews": [
|
||||
{"agent_id": 0, "prompt": "你对A有什么看法?", "platform": "twitter"},
|
||||
{"agent_id": 1, "prompt": "你对B有什么看法?", "platform": "reddit"},
|
||||
{"agent_id": 2, "prompt": "你对C有什么看法?"}
|
||||
],
|
||||
"platform": "reddit",
|
||||
"timeout": 120
|
||||
}
|
||||
```
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
| interviews | Array | 是 | - | 采访列表,每项包含agent_id、prompt和可选的platform |
|
||||
| platform | String | 否 | null | 默认平台(被每项的platform覆盖),不指定则双平台同时采访 |
|
||||
| timeout | Integer | 否 | 120 | 超时时间(秒) |
|
||||
|
||||
**返回示例**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"success": true,
|
||||
"interviews_count": 3,
|
||||
"result": {
|
||||
"interviews_count": 6,
|
||||
"results": {
|
||||
"twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
|
||||
"reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"},
|
||||
"twitter_2": {"agent_id": 2, "response": "...", "platform": "twitter"},
|
||||
"reddit_2": {"agent_id": 2, "response": "...", "platform": "reddit"}
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 3. 全局采访(采访所有Agent)
|
||||
|
||||
**接口**: `POST /api/simulation/interview/all`
|
||||
|
||||
**请求参数**:
|
||||
```json
|
||||
{
|
||||
"simulation_id": "sim_xxxx",
|
||||
"prompt": "你对这件事整体有什么看法?",
|
||||
"platform": "reddit",
|
||||
"timeout": 180
|
||||
}
|
||||
```
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
| prompt | String | 是 | - | 采访问题(所有Agent使用相同问题) |
|
||||
| platform | String | 否 | null | 指定平台(twitter/reddit),不指定则双平台同时采访 |
|
||||
| timeout | Integer | 否 | 180 | 超时时间(秒) |
|
||||
|
||||
**返回示例**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"success": true,
|
||||
"interviews_count": 50,
|
||||
"result": {
|
||||
"interviews_count": 100,
|
||||
"results": {
|
||||
"twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
|
||||
"reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"},
|
||||
"twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"},
|
||||
"reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"},
|
||||
...
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 4. 获取Interview历史
|
||||
|
||||
**接口**: `POST /api/simulation/interview/history`
|
||||
|
||||
**请求参数**:
|
||||
```json
|
||||
{
|
||||
"simulation_id": "sim_xxxx",
|
||||
"platform": "reddit",
|
||||
"agent_id": 0,
|
||||
"limit": 100
|
||||
}
|
||||
```
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
| platform | String | 否 | reddit | 平台类型(reddit/twitter) |
|
||||
| agent_id | Integer | 否 | - | 过滤Agent ID |
|
||||
| limit | Integer | 否 | 100 | 返回数量限制 |
|
||||
|
||||
**返回示例**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"count": 10,
|
||||
"history": [
|
||||
{
|
||||
"agent_id": 0,
|
||||
"response": "我认为...",
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"timestamp": "2025-12-08T10:00:00",
|
||||
"platform": "reddit"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 5. 获取模拟环境状态
|
||||
|
||||
**接口**: `POST /api/simulation/env-status`
|
||||
|
||||
**请求参数**:
|
||||
```json
|
||||
{
|
||||
"simulation_id": "sim_xxxx"
|
||||
}
|
||||
```
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
|
||||
**返回示例**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"simulation_id": "sim_xxxx",
|
||||
"env_alive": true,
|
||||
"twitter_available": true,
|
||||
"reddit_available": true,
|
||||
"message": "环境正在运行,可以接收Interview命令"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 6. 关闭模拟环境
|
||||
|
||||
**接口**: `POST /api/simulation/close-env`
|
||||
|
||||
**请求参数**:
|
||||
```json
|
||||
{
|
||||
"simulation_id": "sim_10b494550540",
|
||||
"timeout": 30
|
||||
}
|
||||
```
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
| timeout | Integer | 否 | 30 | 超时时间(秒) |
|
||||
|
||||
**返回示例**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"success": true,
|
||||
"message": "环境关闭命令已发送",
|
||||
"result": {"message": "环境即将关闭"},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**注意**: 此接口与 `/stop` 不同:
|
||||
- `/stop`: 强制终止模拟进程
|
||||
- `/close-env`: 优雅地关闭环境,让模拟进程正常退出
|
||||
|
||||
#### 6. 获取运行状态
|
||||
|
||||
**接口**: `GET /api/simulation/{simulation_id}/run-status`
|
||||
|
|
@ -1395,6 +1716,92 @@ POST /api/simulation/start
|
|||
|
||||
---
|
||||
|
||||
### 9. SimulationIPCClient/Server (IPC通信模块)
|
||||
|
||||
**文件**: `app/services/simulation_ipc.py`
|
||||
|
||||
**功能**: 实现Flask后端与模拟脚本之间的进程间通信
|
||||
|
||||
**核心类**:
|
||||
|
||||
```python
|
||||
class SimulationIPCClient:
|
||||
"""IPC客户端(Flask端使用)"""
|
||||
|
||||
def send_interview(agent_id: int, prompt: str, timeout: float) -> IPCResponse:
|
||||
"""发送单个Agent采访命令"""
|
||||
|
||||
def send_batch_interview(interviews: List[Dict], timeout: float) -> IPCResponse:
|
||||
"""发送批量采访命令"""
|
||||
|
||||
def send_close_env(timeout: float) -> IPCResponse:
|
||||
"""发送关闭环境命令"""
|
||||
|
||||
def check_env_alive() -> bool:
|
||||
"""检查模拟环境是否存活"""
|
||||
```
|
||||
|
||||
```python
|
||||
class SimulationIPCServer:
|
||||
"""IPC服务器(模拟脚本端使用)"""
|
||||
|
||||
def poll_commands() -> Optional[IPCCommand]:
|
||||
"""轮询获取待处理命令"""
|
||||
|
||||
def send_response(response: IPCResponse):
|
||||
"""发送响应"""
|
||||
```
|
||||
|
||||
**命令类型**:
|
||||
|
||||
| 命令类型 | 说明 |
|
||||
|----------|------|
|
||||
| interview | 单个Agent采访 |
|
||||
| batch_interview | 批量采访 |
|
||||
| close_env | 关闭环境 |
|
||||
|
||||
**文件结构**:
|
||||
|
||||
```
|
||||
uploads/simulations/sim_xxx/
|
||||
├── ipc_commands/ # 命令文件目录
|
||||
│ └── {command_id}.json # 待处理命令
|
||||
├── ipc_responses/ # 响应文件目录
|
||||
│ └── {command_id}.json # 命令响应
|
||||
└── env_status.json # 环境状态文件
|
||||
```
|
||||
|
||||
**使用示例**:
|
||||
|
||||
```python
|
||||
# Flask端发送Interview命令
|
||||
from app.services import SimulationRunner
|
||||
|
||||
# 单个采访
|
||||
result = SimulationRunner.interview_agent(
|
||||
simulation_id="sim_xxx",
|
||||
agent_id=0,
|
||||
prompt="你对这件事有什么看法?"
|
||||
)
|
||||
|
||||
# 批量采访
|
||||
result = SimulationRunner.interview_agents_batch(
|
||||
simulation_id="sim_xxx",
|
||||
interviews=[
|
||||
{"agent_id": 0, "prompt": "问题A"},
|
||||
{"agent_id": 1, "prompt": "问题B"}
|
||||
]
|
||||
)
|
||||
|
||||
# 全局采访
|
||||
result = SimulationRunner.interview_all_agents(
|
||||
simulation_id="sim_xxx",
|
||||
prompt="你认为事件会如何发展?"
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 工具类
|
||||
|
||||
### 1. FileParser (文件解析器)
|
||||
|
|
@ -1661,7 +2068,35 @@ curl -X POST http://localhost:5001/api/simulation/start \
|
|||
# Step 8: 实时查询运行状态
|
||||
curl http://localhost:5001/api/simulation/{sim_xxx}/run-status
|
||||
|
||||
# Step 9: 停止模拟
|
||||
# Step 9: 检查环境状态(模拟完成后环境会进入等待命令模式)
|
||||
curl http://localhost:5001/api/simulation/{sim_xxx}/env-status
|
||||
|
||||
# Step 10: 采访单个Agent
|
||||
curl -X POST http://localhost:5001/api/simulation/{sim_xxx}/interview \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?"
|
||||
}'
|
||||
|
||||
# Step 11: 全局采访(采访所有Agent)
|
||||
curl -X POST http://localhost:5001/api/simulation/{sim_xxx}/interview/all \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "你认为事件的后续发展会如何?"
|
||||
}'
|
||||
|
||||
# Step 12: 获取Interview历史
|
||||
curl http://localhost:5001/api/simulation/{sim_xxx}/interview/history
|
||||
|
||||
# Step 13: 关闭模拟环境(优雅退出)
|
||||
curl -X POST http://localhost:5001/api/simulation/close-env \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"simulation_id": "sim_xxx"
|
||||
}'
|
||||
|
||||
# 或者强制停止模拟
|
||||
curl -X POST http://localhost:5001/api/simulation/stop \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
|
|
@ -1830,6 +2265,31 @@ MIT License
|
|||
|
||||
---
|
||||
|
||||
**最后更新**: 2025-12-05
|
||||
**版本**: v1.1.0
|
||||
**最后更新**: 2025-12-08
|
||||
**版本**: v1.2.0
|
||||
|
||||
### 更新日志
|
||||
|
||||
**v1.2.0 (2025-12-08)**:
|
||||
- 新增 Interview 采访功能
|
||||
- 支持单个Agent采访
|
||||
- 支持批量采访多个Agent
|
||||
- 支持全局采访(所有Agent使用相同问题)
|
||||
- 支持获取Interview历史记录
|
||||
- 新增模拟状态持久化
|
||||
- 模拟完成后环境不立即关闭,进入等待命令模式
|
||||
- 支持优雅关闭环境命令
|
||||
- 新增 IPC 通信机制
|
||||
- Flask后端与模拟脚本之间的进程间通信
|
||||
- 基于文件系统的命令/响应模式
|
||||
|
||||
**v1.1.0 (2025-12-05)**:
|
||||
- 新增图谱记忆动态更新功能
|
||||
- 支持 max_rounds 参数限制模拟轮数
|
||||
|
||||
**v1.0.0**:
|
||||
- 初始版本发布
|
||||
- 支持知识图谱构建
|
||||
- 支持Agent人设生成
|
||||
- 支持双平台模拟
|
||||
|
||||
|
|
|
|||
|
|
@ -1723,3 +1723,568 @@ def get_simulation_comments(simulation_id: str):
|
|||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
||||
|
||||
# ============== Interview 采访接口 ==============
|
||||
|
||||
@simulation_bp.route('/interview', methods=['POST'])
|
||||
def interview_agent():
|
||||
"""
|
||||
采访单个Agent
|
||||
|
||||
注意:此功能需要模拟环境处于运行状态(完成模拟循环后进入等待命令模式)
|
||||
|
||||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"agent_id": 0, // 必填,Agent ID
|
||||
"prompt": "你对这件事有什么看法?", // 必填,采访问题
|
||||
"platform": "twitter", // 可选,指定平台(twitter/reddit)
|
||||
// 不指定时:双平台模拟同时采访两个平台
|
||||
"timeout": 60 // 可选,超时时间(秒),默认60
|
||||
}
|
||||
|
||||
返回(不指定platform,双平台模式):
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"result": {
|
||||
"agent_id": 0,
|
||||
"prompt": "...",
|
||||
"platforms": {
|
||||
"twitter": {"agent_id": 0, "response": "...", "platform": "twitter"},
|
||||
"reddit": {"agent_id": 0, "response": "...", "platform": "reddit"}
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
|
||||
返回(指定platform):
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"agent_id": 0,
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"result": {
|
||||
"agent_id": 0,
|
||||
"response": "我认为...",
|
||||
"platform": "twitter",
|
||||
"timestamp": "2025-12-08T10:00:00"
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
agent_id = data.get('agent_id')
|
||||
prompt = data.get('prompt')
|
||||
platform = data.get('platform') # 可选:twitter/reddit/None
|
||||
timeout = data.get('timeout', 60)
|
||||
|
||||
if not simulation_id:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 simulation_id"
|
||||
}), 400
|
||||
|
||||
if agent_id is None:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 agent_id"
|
||||
}), 400
|
||||
|
||||
if not prompt:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 prompt(采访问题)"
|
||||
}), 400
|
||||
|
||||
# 验证platform参数
|
||||
if platform and platform not in ("twitter", "reddit"):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "platform 参数只能是 'twitter' 或 'reddit'"
|
||||
}), 400
|
||||
|
||||
# 检查环境状态
|
||||
if not SimulationRunner.check_env_alive(simulation_id):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||
}), 400
|
||||
|
||||
result = SimulationRunner.interview_agent(
|
||||
simulation_id=simulation_id,
|
||||
agent_id=agent_id,
|
||||
prompt=prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return jsonify({
|
||||
"success": result.get("success", False),
|
||||
"data": result
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}), 400
|
||||
|
||||
except TimeoutError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"等待Interview响应超时: {str(e)}"
|
||||
}), 504
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Interview失败: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/interview/batch', methods=['POST'])
|
||||
def interview_agents_batch():
|
||||
"""
|
||||
批量采访多个Agent
|
||||
|
||||
注意:此功能需要模拟环境处于运行状态
|
||||
|
||||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"interviews": [ // 必填,采访列表
|
||||
{
|
||||
"agent_id": 0,
|
||||
"prompt": "你对A有什么看法?",
|
||||
"platform": "twitter" // 可选,指定该Agent的采访平台
|
||||
},
|
||||
{
|
||||
"agent_id": 1,
|
||||
"prompt": "你对B有什么看法?" // 不指定platform则使用默认值
|
||||
}
|
||||
],
|
||||
"platform": "reddit", // 可选,默认平台(被每项的platform覆盖)
|
||||
// 不指定时:双平台模拟每个Agent同时采访两个平台
|
||||
"timeout": 120 // 可选,超时时间(秒),默认120
|
||||
}
|
||||
|
||||
返回:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"interviews_count": 2,
|
||||
"result": {
|
||||
"interviews_count": 4,
|
||||
"results": {
|
||||
"twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
|
||||
"reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"},
|
||||
"twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"},
|
||||
"reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"}
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
interviews = data.get('interviews')
|
||||
platform = data.get('platform') # 可选:twitter/reddit/None
|
||||
timeout = data.get('timeout', 120)
|
||||
|
||||
if not simulation_id:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 simulation_id"
|
||||
}), 400
|
||||
|
||||
if not interviews or not isinstance(interviews, list):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 interviews(采访列表)"
|
||||
}), 400
|
||||
|
||||
# 验证platform参数
|
||||
if platform and platform not in ("twitter", "reddit"):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "platform 参数只能是 'twitter' 或 'reddit'"
|
||||
}), 400
|
||||
|
||||
# 验证每个采访项
|
||||
for i, interview in enumerate(interviews):
|
||||
if 'agent_id' not in interview:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"采访列表第{i+1}项缺少 agent_id"
|
||||
}), 400
|
||||
if 'prompt' not in interview:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"采访列表第{i+1}项缺少 prompt"
|
||||
}), 400
|
||||
# 验证每项的platform(如果有)
|
||||
item_platform = interview.get('platform')
|
||||
if item_platform and item_platform not in ("twitter", "reddit"):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"采访列表第{i+1}项的platform只能是 'twitter' 或 'reddit'"
|
||||
}), 400
|
||||
|
||||
# 检查环境状态
|
||||
if not SimulationRunner.check_env_alive(simulation_id):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||
}), 400
|
||||
|
||||
result = SimulationRunner.interview_agents_batch(
|
||||
simulation_id=simulation_id,
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return jsonify({
|
||||
"success": result.get("success", False),
|
||||
"data": result
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}), 400
|
||||
|
||||
except TimeoutError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"等待批量Interview响应超时: {str(e)}"
|
||||
}), 504
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量Interview失败: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/interview/all', methods=['POST'])
|
||||
def interview_all_agents():
|
||||
"""
|
||||
全局采访 - 使用相同问题采访所有Agent
|
||||
|
||||
注意:此功能需要模拟环境处于运行状态
|
||||
|
||||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"prompt": "你对这件事整体有什么看法?", // 必填,采访问题(所有Agent使用相同问题)
|
||||
"platform": "reddit", // 可选,指定平台(twitter/reddit)
|
||||
// 不指定时:双平台模拟每个Agent同时采访两个平台
|
||||
"timeout": 180 // 可选,超时时间(秒),默认180
|
||||
}
|
||||
|
||||
返回:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"interviews_count": 50,
|
||||
"result": {
|
||||
"interviews_count": 100,
|
||||
"results": {
|
||||
"twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
|
||||
"reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"},
|
||||
...
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
prompt = data.get('prompt')
|
||||
platform = data.get('platform') # 可选:twitter/reddit/None
|
||||
timeout = data.get('timeout', 180)
|
||||
|
||||
if not simulation_id:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 simulation_id"
|
||||
}), 400
|
||||
|
||||
if not prompt:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 prompt(采访问题)"
|
||||
}), 400
|
||||
|
||||
# 验证platform参数
|
||||
if platform and platform not in ("twitter", "reddit"):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "platform 参数只能是 'twitter' 或 'reddit'"
|
||||
}), 400
|
||||
|
||||
# 检查环境状态
|
||||
if not SimulationRunner.check_env_alive(simulation_id):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||
}), 400
|
||||
|
||||
result = SimulationRunner.interview_all_agents(
|
||||
simulation_id=simulation_id,
|
||||
prompt=prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return jsonify({
|
||||
"success": result.get("success", False),
|
||||
"data": result
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}), 400
|
||||
|
||||
except TimeoutError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"等待全局Interview响应超时: {str(e)}"
|
||||
}), 504
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"全局Interview失败: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/interview/history', methods=['POST'])
|
||||
def get_interview_history():
|
||||
"""
|
||||
获取Interview历史记录
|
||||
|
||||
从模拟数据库中读取所有Interview记录
|
||||
|
||||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"platform": "reddit", // 可选,平台类型(reddit/twitter),默认reddit
|
||||
"agent_id": 0, // 可选,过滤Agent ID
|
||||
"limit": 100 // 可选,返回数量,默认100
|
||||
}
|
||||
|
||||
返回:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"count": 10,
|
||||
"history": [
|
||||
{
|
||||
"agent_id": 0,
|
||||
"response": "我认为...",
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"timestamp": "2025-12-08T10:00:00",
|
||||
"platform": "reddit"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
platform = data.get('platform', 'reddit')
|
||||
agent_id = data.get('agent_id')
|
||||
limit = data.get('limit', 100)
|
||||
|
||||
if not simulation_id:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 simulation_id"
|
||||
}), 400
|
||||
|
||||
history = SimulationRunner.get_interview_history(
|
||||
simulation_id=simulation_id,
|
||||
platform=platform,
|
||||
agent_id=agent_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": {
|
||||
"count": len(history),
|
||||
"history": history
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Interview历史失败: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/env-status', methods=['POST'])
|
||||
def get_env_status():
|
||||
"""
|
||||
获取模拟环境状态
|
||||
|
||||
检查模拟环境是否存活(可以接收Interview命令)
|
||||
|
||||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx" // 必填,模拟ID
|
||||
}
|
||||
|
||||
返回:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"simulation_id": "sim_xxxx",
|
||||
"env_alive": true,
|
||||
"twitter_available": true,
|
||||
"reddit_available": true,
|
||||
"message": "环境正在运行,可以接收Interview命令"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
|
||||
if not simulation_id:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 simulation_id"
|
||||
}), 400
|
||||
|
||||
env_alive = SimulationRunner.check_env_alive(simulation_id)
|
||||
|
||||
# 获取更详细的状态信息
|
||||
env_status = SimulationRunner.get_env_status_detail(simulation_id)
|
||||
|
||||
if env_alive:
|
||||
message = "环境正在运行,可以接收Interview命令"
|
||||
else:
|
||||
message = "环境未运行或已关闭"
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": {
|
||||
"simulation_id": simulation_id,
|
||||
"env_alive": env_alive,
|
||||
"twitter_available": env_status.get("twitter_available", False),
|
||||
"reddit_available": env_status.get("reddit_available", False),
|
||||
"message": message
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取环境状态失败: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/close-env', methods=['POST'])
|
||||
def close_simulation_env():
|
||||
"""
|
||||
关闭模拟环境
|
||||
|
||||
向模拟发送关闭环境命令,使其优雅退出等待命令模式。
|
||||
|
||||
注意:这不同于 /stop 接口,/stop 会强制终止进程,
|
||||
而此接口会让模拟优雅地关闭环境并退出。
|
||||
|
||||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"timeout": 30 // 可选,超时时间(秒),默认30
|
||||
}
|
||||
|
||||
返回:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"message": "环境关闭命令已发送",
|
||||
"result": {...},
|
||||
"timestamp": "2025-12-08T10:00:01"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
timeout = data.get('timeout', 30)
|
||||
|
||||
if not simulation_id:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "请提供 simulation_id"
|
||||
}), 400
|
||||
|
||||
result = SimulationRunner.close_simulation_env(
|
||||
simulation_id=simulation_id,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 更新模拟状态
|
||||
manager = SimulationManager()
|
||||
state = manager.get_simulation(simulation_id)
|
||||
if state:
|
||||
state.status = SimulationStatus.COMPLETED
|
||||
manager._save_simulation_state(state)
|
||||
|
||||
return jsonify({
|
||||
"success": result.get("success", False),
|
||||
"data": result
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}), 400
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关闭环境失败: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}), 500
|
||||
|
|
|
|||
|
|
@ -28,6 +28,14 @@ from .zep_graph_memory_updater import (
|
|||
ZepGraphMemoryManager,
|
||||
AgentActivity
|
||||
)
|
||||
from .simulation_ipc import (
|
||||
SimulationIPCClient,
|
||||
SimulationIPCServer,
|
||||
IPCCommand,
|
||||
IPCResponse,
|
||||
CommandType,
|
||||
CommandStatus
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'OntologyGenerator',
|
||||
|
|
@ -55,5 +63,11 @@ __all__ = [
|
|||
'ZepGraphMemoryUpdater',
|
||||
'ZepGraphMemoryManager',
|
||||
'AgentActivity',
|
||||
'SimulationIPCClient',
|
||||
'SimulationIPCServer',
|
||||
'IPCCommand',
|
||||
'IPCResponse',
|
||||
'CommandType',
|
||||
'CommandStatus',
|
||||
]
|
||||
|
||||
|
|
|
|||
394
backend/app/services/simulation_ipc.py
Normal file
394
backend/app/services/simulation_ipc.py
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
"""
|
||||
模拟IPC通信模块
|
||||
用于Flask后端和模拟脚本之间的进程间通信
|
||||
|
||||
通过文件系统实现简单的命令/响应模式:
|
||||
1. Flask写入命令到 commands/ 目录
|
||||
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
|
||||
3. Flask轮询响应目录获取结果
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.simulation_ipc')
|
||||
|
||||
|
||||
class CommandType(str, Enum):
|
||||
"""命令类型"""
|
||||
INTERVIEW = "interview" # 单个Agent采访
|
||||
BATCH_INTERVIEW = "batch_interview" # 批量采访
|
||||
CLOSE_ENV = "close_env" # 关闭环境
|
||||
|
||||
|
||||
class CommandStatus(str, Enum):
|
||||
"""命令状态"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCCommand:
|
||||
"""IPC命令"""
|
||||
command_id: str
|
||||
command_type: CommandType
|
||||
args: Dict[str, Any]
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"command_id": self.command_id,
|
||||
"command_type": self.command_type.value,
|
||||
"args": self.args,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
|
||||
return cls(
|
||||
command_id=data["command_id"],
|
||||
command_type=CommandType(data["command_type"]),
|
||||
args=data.get("args", {}),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCResponse:
|
||||
"""IPC响应"""
|
||||
command_id: str
|
||||
status: CommandStatus
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"command_id": self.command_id,
|
||||
"status": self.status.value,
|
||||
"result": self.result,
|
||||
"error": self.error,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse':
|
||||
return cls(
|
||||
command_id=data["command_id"],
|
||||
status=CommandStatus(data["status"]),
|
||||
result=data.get("result"),
|
||||
error=data.get("error"),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
class SimulationIPCClient:
|
||||
"""
|
||||
模拟IPC客户端(Flask端使用)
|
||||
|
||||
用于向模拟进程发送命令并等待响应
|
||||
"""
|
||||
|
||||
def __init__(self, simulation_dir: str):
|
||||
"""
|
||||
初始化IPC客户端
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟数据目录
|
||||
"""
|
||||
self.simulation_dir = simulation_dir
|
||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def send_command(
|
||||
self,
|
||||
command_type: CommandType,
|
||||
args: Dict[str, Any],
|
||||
timeout: float = 60.0,
|
||||
poll_interval: float = 0.5
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送命令并等待响应
|
||||
|
||||
Args:
|
||||
command_type: 命令类型
|
||||
args: 命令参数
|
||||
timeout: 超时时间(秒)
|
||||
poll_interval: 轮询间隔(秒)
|
||||
|
||||
Returns:
|
||||
IPCResponse
|
||||
|
||||
Raises:
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
command_id = str(uuid.uuid4())
|
||||
command = IPCCommand(
|
||||
command_id=command_id,
|
||||
command_type=command_type,
|
||||
args=args
|
||||
)
|
||||
|
||||
# 写入命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
with open(command_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")
|
||||
|
||||
# 等待响应
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
if os.path.exists(response_file):
|
||||
try:
|
||||
with open(response_file, 'r', encoding='utf-8') as f:
|
||||
response_data = json.load(f)
|
||||
response = IPCResponse.from_dict(response_data)
|
||||
|
||||
# 清理命令和响应文件
|
||||
try:
|
||||
os.remove(command_file)
|
||||
os.remove(response_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
|
||||
return response
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"解析响应失败: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
# 超时
|
||||
logger.error(f"等待IPC响应超时: command_id={command_id}")
|
||||
|
||||
# 清理命令文件
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
|
||||
|
||||
def send_interview(
|
||||
self,
|
||||
agent_id: int,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 60.0
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送单个Agent采访命令
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
IPCResponse,result字段包含采访结果
|
||||
"""
|
||||
args = {
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt
|
||||
}
|
||||
if platform:
|
||||
args["platform"] = platform
|
||||
|
||||
return self.send_command(
|
||||
command_type=CommandType.INTERVIEW,
|
||||
args=args,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def send_batch_interview(
|
||||
self,
|
||||
interviews: List[Dict[str, Any]],
|
||||
platform: str = None,
|
||||
timeout: float = 120.0
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||||
- "twitter": 默认只采访Twitter平台
|
||||
- "reddit": 默认只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
IPCResponse,result字段包含所有采访结果
|
||||
"""
|
||||
args = {"interviews": interviews}
|
||||
if platform:
|
||||
args["platform"] = platform
|
||||
|
||||
return self.send_command(
|
||||
command_type=CommandType.BATCH_INTERVIEW,
|
||||
args=args,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
||||
"""
|
||||
发送关闭环境命令
|
||||
|
||||
Args:
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
IPCResponse
|
||||
"""
|
||||
return self.send_command(
|
||||
command_type=CommandType.CLOSE_ENV,
|
||||
args={},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def check_env_alive(self) -> bool:
|
||||
"""
|
||||
检查模拟环境是否存活
|
||||
|
||||
通过检查 env_status.json 文件来判断
|
||||
"""
|
||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||
if not os.path.exists(status_file):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(status_file, 'r', encoding='utf-8') as f:
|
||||
status = json.load(f)
|
||||
return status.get("status") == "alive"
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
class SimulationIPCServer:
|
||||
"""
|
||||
模拟IPC服务器(模拟脚本端使用)
|
||||
|
||||
轮询命令目录,执行命令并返回响应
|
||||
"""
|
||||
|
||||
def __init__(self, simulation_dir: str):
|
||||
"""
|
||||
初始化IPC服务器
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟数据目录
|
||||
"""
|
||||
self.simulation_dir = simulation_dir
|
||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
# 环境状态
|
||||
self._running = False
|
||||
|
||||
def start(self):
|
||||
"""标记服务器为运行状态"""
|
||||
self._running = True
|
||||
self._update_env_status("alive")
|
||||
|
||||
def stop(self):
|
||||
"""标记服务器为停止状态"""
|
||||
self._running = False
|
||||
self._update_env_status("stopped")
|
||||
|
||||
def _update_env_status(self, status: str):
|
||||
"""更新环境状态文件"""
|
||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||
with open(status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_commands(self) -> Optional[IPCCommand]:
|
||||
"""
|
||||
轮询命令目录,返回第一个待处理的命令
|
||||
|
||||
Returns:
|
||||
IPCCommand 或 None
|
||||
"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 按时间排序获取命令文件
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return IPCCommand.from_dict(data)
|
||||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||||
logger.warning(f"读取命令文件失败: {filepath}, {e}")
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, response: IPCResponse):
|
||||
"""
|
||||
发送响应
|
||||
|
||||
Args:
|
||||
response: IPC响应
|
||||
"""
|
||||
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def send_success(self, command_id: str, result: Dict[str, Any]):
|
||||
"""发送成功响应"""
|
||||
self.send_response(IPCResponse(
|
||||
command_id=command_id,
|
||||
status=CommandStatus.COMPLETED,
|
||||
result=result
|
||||
))
|
||||
|
||||
def send_error(self, command_id: str, error: str):
|
||||
"""发送错误响应"""
|
||||
self.send_response(IPCResponse(
|
||||
command_id=command_id,
|
||||
status=CommandStatus.FAILED,
|
||||
error=error
|
||||
))
|
||||
|
|
@ -12,7 +12,7 @@ import threading
|
|||
import subprocess
|
||||
import signal
|
||||
import atexit
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
|
@ -21,6 +21,7 @@ from queue import Queue
|
|||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_graph_memory_updater import ZepGraphMemoryManager
|
||||
from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
|
|
@ -989,4 +990,365 @@ class SimulationRunner:
|
|||
if process.poll() is None:
|
||||
running.append(sim_id)
|
||||
return running
|
||||
|
||||
# ============== Interview 功能 ==============
|
||||
|
||||
@classmethod
|
||||
def check_env_alive(cls, simulation_id: str) -> bool:
|
||||
"""
|
||||
检查模拟环境是否存活(可以接收Interview命令)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
True 表示环境存活,False 表示环境已关闭
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
return False
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
return ipc_client.check_env_alive()
|
||||
|
||||
@classmethod
|
||||
def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取模拟环境的详细状态信息
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
状态详情字典,包含 status, twitter_available, reddit_available, timestamp
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
status_file = os.path.join(sim_dir, "env_status.json")
|
||||
|
||||
default_status = {
|
||||
"status": "stopped",
|
||||
"twitter_available": False,
|
||||
"reddit_available": False,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(status_file):
|
||||
return default_status
|
||||
|
||||
try:
|
||||
with open(status_file, 'r', encoding='utf-8') as f:
|
||||
status = json.load(f)
|
||||
return {
|
||||
"status": status.get("status", "stopped"),
|
||||
"twitter_available": status.get("twitter_available", False),
|
||||
"reddit_available": status.get("reddit_available", False),
|
||||
"timestamp": status.get("timestamp")
|
||||
}
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return default_status
|
||||
|
||||
@classmethod
|
||||
def interview_agent(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
agent_id: int,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 60.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
采访单个Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时同时采访两个平台,返回整合结果
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
采访结果字典
|
||||
|
||||
Raises:
|
||||
ValueError: 模拟不存在或环境未运行
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")
|
||||
|
||||
logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}")
|
||||
|
||||
response = ipc_client.send_interview(
|
||||
agent_id=agent_id,
|
||||
prompt=prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if response.status.value == "completed":
|
||||
return {
|
||||
"success": True,
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"error": response.error,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def interview_agents_batch(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
interviews: List[Dict[str, Any]],
|
||||
platform: str = None,
|
||||
timeout: float = 120.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量采访多个Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||||
- "twitter": 默认只采访Twitter平台
|
||||
- "reddit": 默认只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
批量采访结果字典
|
||||
|
||||
Raises:
|
||||
ValueError: 模拟不存在或环境未运行
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")
|
||||
|
||||
logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}")
|
||||
|
||||
response = ipc_client.send_batch_interview(
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if response.status.value == "completed":
|
||||
return {
|
||||
"success": True,
|
||||
"interviews_count": len(interviews),
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"interviews_count": len(interviews),
|
||||
"error": response.error,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def interview_all_agents(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 180.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
采访所有Agent(全局采访)
|
||||
|
||||
使用相同的问题采访模拟中的所有Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
prompt: 采访问题(所有Agent使用相同问题)
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
全局采访结果字典
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
# 从配置文件获取所有Agent信息
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(f"模拟配置不存在: {simulation_id}")
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
agent_configs = config.get("agent_configs", [])
|
||||
if not agent_configs:
|
||||
raise ValueError(f"模拟配置中没有Agent: {simulation_id}")
|
||||
|
||||
# 构建批量采访列表
|
||||
interviews = []
|
||||
for agent_config in agent_configs:
|
||||
agent_id = agent_config.get("agent_id")
|
||||
if agent_id is not None:
|
||||
interviews.append({
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt
|
||||
})
|
||||
|
||||
logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}")
|
||||
|
||||
return cls.interview_agents_batch(
|
||||
simulation_id=simulation_id,
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def close_simulation_env(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
timeout: float = 30.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
关闭模拟环境(而不是停止模拟进程)
|
||||
|
||||
向模拟发送关闭环境命令,使其优雅退出等待命令模式
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
操作结果字典
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
return {
|
||||
"success": True,
|
||||
"message": "环境已经关闭"
|
||||
}
|
||||
|
||||
logger.info(f"发送关闭环境命令: simulation_id={simulation_id}")
|
||||
|
||||
try:
|
||||
response = ipc_client.send_close_env(timeout=timeout)
|
||||
|
||||
return {
|
||||
"success": response.status.value == "completed",
|
||||
"message": "环境关闭命令已发送",
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
except TimeoutError:
|
||||
# 超时可能是因为环境正在关闭
|
||||
return {
|
||||
"success": True,
|
||||
"message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = "reddit",
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter)
|
||||
agent_id: 过滤Agent ID(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return []
|
||||
|
||||
results = []
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 构建查询
|
||||
# 注意:ActionType.INTERVIEW.value 应该是字符串形式
|
||||
if agent_id is not None:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = 'interview' AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""", (agent_id, limit))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = 'interview'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
for user_id, info_json, created_at in cursor.fetchall():
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
except json.JSONDecodeError:
|
||||
info = {"raw": info_json}
|
||||
|
||||
results.append({
|
||||
"agent_id": user_id,
|
||||
"response": info.get("response", info),
|
||||
"prompt": info.get("prompt", ""),
|
||||
"timestamp": created_at,
|
||||
"platform": platform
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取Interview历史失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,18 @@
|
|||
OASIS 双平台并行模拟预设脚本
|
||||
同时运行Twitter和Reddit模拟,读取相同的配置文件
|
||||
|
||||
功能特性:
|
||||
- 双平台(Twitter + Reddit)并行模拟
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_parallel_simulation.py --config simulation_config.json
|
||||
python run_parallel_simulation.py --config simulation_config.json --no-wait # 完成后立即关闭
|
||||
python run_parallel_simulation.py --config simulation_config.json --twitter-only
|
||||
python run_parallel_simulation.py --config simulation_config.json --reddit-only
|
||||
|
||||
日志结构:
|
||||
sim_xxx/
|
||||
|
|
@ -119,7 +129,7 @@ except ImportError as e:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
# Twitter可用动作
|
||||
# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
TWITTER_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
|
|
@ -129,7 +139,7 @@ TWITTER_ACTIONS = [
|
|||
ActionType.QUOTE_POST,
|
||||
]
|
||||
|
||||
# Reddit可用动作
|
||||
# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
REDDIT_ACTIONS = [
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.DISLIKE_POST,
|
||||
|
|
@ -147,6 +157,405 @@ REDDIT_ACTIONS = [
|
|||
]
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class ParallelIPCHandler:
|
||||
"""
|
||||
双平台IPC命令处理器
|
||||
|
||||
管理两个平台的环境,处理Interview命令
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
simulation_dir: str,
|
||||
twitter_env=None,
|
||||
twitter_agent_graph=None,
|
||||
reddit_env=None,
|
||||
reddit_agent_graph=None
|
||||
):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.twitter_env = twitter_env
|
||||
self.twitter_agent_graph = twitter_agent_graph
|
||||
self.reddit_env = reddit_env
|
||||
self.reddit_agent_graph = reddit_agent_graph
|
||||
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"twitter_available": self.twitter_env is not None,
|
||||
"reddit_available": self.reddit_env is not None,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_env_and_graph(self, platform: str):
|
||||
"""
|
||||
获取指定平台的环境和agent_graph
|
||||
|
||||
Args:
|
||||
platform: 平台名称 ("twitter" 或 "reddit")
|
||||
|
||||
Returns:
|
||||
(env, agent_graph, platform_name) 或 (None, None, None)
|
||||
"""
|
||||
if platform == "twitter" and self.twitter_env:
|
||||
return self.twitter_env, self.twitter_agent_graph, "twitter"
|
||||
elif platform == "reddit" and self.reddit_env:
|
||||
return self.reddit_env, self.reddit_agent_graph, "reddit"
|
||||
else:
|
||||
return None, None, None
|
||||
|
||||
async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]:
|
||||
"""
|
||||
在单个平台上执行Interview
|
||||
|
||||
Returns:
|
||||
包含结果的字典,或包含error的字典
|
||||
"""
|
||||
env, agent_graph, actual_platform = self._get_env_and_graph(platform)
|
||||
|
||||
if not env or not agent_graph:
|
||||
return {"platform": platform, "error": f"{platform}平台不可用"}
|
||||
|
||||
try:
|
||||
agent = agent_graph.get_agent(agent_id)
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
actions = {agent: interview_action}
|
||||
await env.step(actions)
|
||||
|
||||
result = self._get_interview_result(agent_id, actual_platform)
|
||||
result["platform"] = actual_platform
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {"platform": platform, "error": str(e)}
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Args:
|
||||
command_id: 命令ID
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None/不指定: 同时采访两个平台,返回整合结果
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
# 如果指定了平台,只采访该平台
|
||||
if platform in ("twitter", "reddit"):
|
||||
result = await self._interview_single_platform(agent_id, prompt, platform)
|
||||
|
||||
if "error" in result:
|
||||
self.send_response(command_id, "failed", error=result["error"])
|
||||
print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}")
|
||||
return False
|
||||
else:
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}, platform={platform}")
|
||||
return True
|
||||
|
||||
# 未指定平台:同时采访两个平台
|
||||
if not self.twitter_env and not self.reddit_env:
|
||||
self.send_response(command_id, "failed", error="没有可用的模拟环境")
|
||||
return False
|
||||
|
||||
results = {
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"platforms": {}
|
||||
}
|
||||
success_count = 0
|
||||
|
||||
# 并行采访两个平台
|
||||
tasks = []
|
||||
platforms_to_interview = []
|
||||
|
||||
if self.twitter_env:
|
||||
tasks.append(self._interview_single_platform(agent_id, prompt, "twitter"))
|
||||
platforms_to_interview.append("twitter")
|
||||
|
||||
if self.reddit_env:
|
||||
tasks.append(self._interview_single_platform(agent_id, prompt, "reddit"))
|
||||
platforms_to_interview.append("reddit")
|
||||
|
||||
# 并行执行
|
||||
platform_results = await asyncio.gather(*tasks)
|
||||
|
||||
for platform_name, platform_result in zip(platforms_to_interview, platform_results):
|
||||
results["platforms"][platform_name] = platform_result
|
||||
if "error" not in platform_result:
|
||||
success_count += 1
|
||||
|
||||
if success_count > 0:
|
||||
self.send_response(command_id, "completed", result=results)
|
||||
print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}")
|
||||
return True
|
||||
else:
|
||||
errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()]
|
||||
self.send_response(command_id, "failed", error="; ".join(errors))
|
||||
print(f" Interview失败: agent_id={agent_id}, 所有平台都失败")
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
command_id: 命令ID
|
||||
interviews: [{"agent_id": int, "prompt": str, "platform": str(optional)}, ...]
|
||||
platform: 默认平台(可被每个interview项覆盖)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None/不指定: 每个Agent同时采访两个平台
|
||||
"""
|
||||
# 按平台分组
|
||||
twitter_interviews = []
|
||||
reddit_interviews = []
|
||||
both_platforms_interviews = [] # 需要同时采访两个平台的
|
||||
|
||||
for interview in interviews:
|
||||
item_platform = interview.get("platform", platform)
|
||||
if item_platform == "twitter":
|
||||
twitter_interviews.append(interview)
|
||||
elif item_platform == "reddit":
|
||||
reddit_interviews.append(interview)
|
||||
else:
|
||||
# 未指定平台:两个平台都采访
|
||||
both_platforms_interviews.append(interview)
|
||||
|
||||
# 把 both_platforms_interviews 拆分到两个平台
|
||||
if both_platforms_interviews:
|
||||
if self.twitter_env:
|
||||
twitter_interviews.extend(both_platforms_interviews)
|
||||
if self.reddit_env:
|
||||
reddit_interviews.extend(both_platforms_interviews)
|
||||
|
||||
results = {}
|
||||
|
||||
# 处理Twitter平台的采访
|
||||
if twitter_interviews and self.twitter_env:
|
||||
try:
|
||||
twitter_actions = {}
|
||||
for interview in twitter_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
try:
|
||||
agent = self.twitter_agent_graph.get_agent(agent_id)
|
||||
twitter_actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Twitter Agent {agent_id}: {e}")
|
||||
|
||||
if twitter_actions:
|
||||
await self.twitter_env.step(twitter_actions)
|
||||
|
||||
for interview in twitter_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
result = self._get_interview_result(agent_id, "twitter")
|
||||
result["platform"] = "twitter"
|
||||
results[f"twitter_{agent_id}"] = result
|
||||
except Exception as e:
|
||||
print(f" Twitter批量Interview失败: {e}")
|
||||
|
||||
# 处理Reddit平台的采访
|
||||
if reddit_interviews and self.reddit_env:
|
||||
try:
|
||||
reddit_actions = {}
|
||||
for interview in reddit_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
try:
|
||||
agent = self.reddit_agent_graph.get_agent(agent_id)
|
||||
reddit_actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Reddit Agent {agent_id}: {e}")
|
||||
|
||||
if reddit_actions:
|
||||
await self.reddit_env.step(reddit_actions)
|
||||
|
||||
for interview in reddit_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
result = self._get_interview_result(agent_id, "reddit")
|
||||
result["platform"] = "reddit"
|
||||
results[f"reddit_{agent_id}"] = result
|
||||
except Exception as e:
|
||||
print(f" Reddit批量Interview失败: {e}")
|
||||
|
||||
if results:
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
else:
|
||||
self.send_response(command_id, "failed", error="没有成功的采访")
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", ""),
|
||||
args.get("platform")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", []),
|
||||
args.get("platform")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
|
|
@ -398,13 +807,21 @@ def get_active_agents_for_round(
|
|||
return active_agents
|
||||
|
||||
|
||||
class PlatformSimulation:
|
||||
"""平台模拟结果容器"""
|
||||
def __init__(self):
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.total_actions = 0
|
||||
|
||||
|
||||
async def run_twitter_simulation(
|
||||
config: Dict[str, Any],
|
||||
simulation_dir: str,
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None,
|
||||
max_rounds: Optional[int] = None
|
||||
):
|
||||
) -> PlatformSimulation:
|
||||
"""运行Twitter模拟
|
||||
|
||||
Args:
|
||||
|
|
@ -413,7 +830,12 @@ async def run_twitter_simulation(
|
|||
action_logger: 动作日志记录器
|
||||
main_logger: 主日志管理器
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
|
||||
Returns:
|
||||
PlatformSimulation: 包含env和agent_graph的结果对象
|
||||
"""
|
||||
result = PlatformSimulation()
|
||||
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Twitter] {msg}")
|
||||
|
|
@ -428,9 +850,9 @@ async def run_twitter_simulation(
|
|||
profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
|
||||
if not os.path.exists(profile_path):
|
||||
log_info(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
return result
|
||||
|
||||
agent_graph = await generate_twitter_agent_graph(
|
||||
result.agent_graph = await generate_twitter_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=TWITTER_ACTIONS,
|
||||
|
|
@ -439,7 +861,7 @@ async def run_twitter_simulation(
|
|||
# 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X)
|
||||
agent_names = get_agent_names_from_config(config)
|
||||
# 如果配置中没有某个 agent,则使用 OASIS 的默认名称
|
||||
for agent_id, agent in agent_graph.get_agents():
|
||||
for agent_id, agent in result.agent_graph.get_agents():
|
||||
if agent_id not in agent_names:
|
||||
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
||||
|
||||
|
|
@ -447,14 +869,14 @@ async def run_twitter_simulation(
|
|||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
result.env = oasis.make(
|
||||
agent_graph=result.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.TWITTER,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await result.env.reset()
|
||||
log_info("环境已启动")
|
||||
|
||||
if action_logger:
|
||||
|
|
@ -478,7 +900,7 @@ async def run_twitter_simulation(
|
|||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = result.env.agent_graph.get_agent(agent_id)
|
||||
initial_actions[agent] = ManualAction(
|
||||
action_type=ActionType.CREATE_POST,
|
||||
action_args={"content": content}
|
||||
|
|
@ -498,7 +920,7 @@ async def run_twitter_simulation(
|
|||
pass
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await result.env.step(initial_actions)
|
||||
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 记录 round 0 结束
|
||||
|
|
@ -526,7 +948,7 @@ async def run_twitter_simulation(
|
|||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
active_agents = get_active_agents_for_round(
|
||||
env, config, simulated_hour, round_num
|
||||
result.env, config, simulated_hour, round_num
|
||||
)
|
||||
|
||||
# 无论是否有活跃agent,都记录round开始
|
||||
|
|
@ -540,7 +962,7 @@ async def run_twitter_simulation(
|
|||
continue
|
||||
|
||||
actions = {agent: LLMAction() for _, agent in active_agents}
|
||||
await env.step(actions)
|
||||
await result.env.step(actions)
|
||||
|
||||
# 从数据库获取实际执行的动作并记录
|
||||
actual_actions, last_rowid = fetch_new_actions_from_db(
|
||||
|
|
@ -567,13 +989,16 @@ async def run_twitter_simulation(
|
|||
progress = (round_num + 1) / total_rounds * 100
|
||||
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
|
||||
await env.close()
|
||||
# 注意:不关闭环境,保留给Interview使用
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_end(total_rounds, total_actions)
|
||||
|
||||
result.total_actions = total_actions
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def run_reddit_simulation(
|
||||
|
|
@ -582,7 +1007,7 @@ async def run_reddit_simulation(
|
|||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None,
|
||||
max_rounds: Optional[int] = None
|
||||
):
|
||||
) -> PlatformSimulation:
|
||||
"""运行Reddit模拟
|
||||
|
||||
Args:
|
||||
|
|
@ -591,7 +1016,12 @@ async def run_reddit_simulation(
|
|||
action_logger: 动作日志记录器
|
||||
main_logger: 主日志管理器
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
|
||||
Returns:
|
||||
PlatformSimulation: 包含env和agent_graph的结果对象
|
||||
"""
|
||||
result = PlatformSimulation()
|
||||
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Reddit] {msg}")
|
||||
|
|
@ -605,9 +1035,9 @@ async def run_reddit_simulation(
|
|||
profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
|
||||
if not os.path.exists(profile_path):
|
||||
log_info(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
return result
|
||||
|
||||
agent_graph = await generate_reddit_agent_graph(
|
||||
result.agent_graph = await generate_reddit_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=REDDIT_ACTIONS,
|
||||
|
|
@ -616,7 +1046,7 @@ async def run_reddit_simulation(
|
|||
# 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X)
|
||||
agent_names = get_agent_names_from_config(config)
|
||||
# 如果配置中没有某个 agent,则使用 OASIS 的默认名称
|
||||
for agent_id, agent in agent_graph.get_agents():
|
||||
for agent_id, agent in result.agent_graph.get_agents():
|
||||
if agent_id not in agent_names:
|
||||
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
||||
|
||||
|
|
@ -624,14 +1054,14 @@ async def run_reddit_simulation(
|
|||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
result.env = oasis.make(
|
||||
agent_graph=result.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.REDDIT,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await result.env.reset()
|
||||
log_info("环境已启动")
|
||||
|
||||
if action_logger:
|
||||
|
|
@ -655,7 +1085,7 @@ async def run_reddit_simulation(
|
|||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = result.env.agent_graph.get_agent(agent_id)
|
||||
if agent in initial_actions:
|
||||
if not isinstance(initial_actions[agent], list):
|
||||
initial_actions[agent] = [initial_actions[agent]]
|
||||
|
|
@ -683,7 +1113,7 @@ async def run_reddit_simulation(
|
|||
pass
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await result.env.step(initial_actions)
|
||||
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 记录 round 0 结束
|
||||
|
|
@ -711,7 +1141,7 @@ async def run_reddit_simulation(
|
|||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
active_agents = get_active_agents_for_round(
|
||||
env, config, simulated_hour, round_num
|
||||
result.env, config, simulated_hour, round_num
|
||||
)
|
||||
|
||||
# 无论是否有活跃agent,都记录round开始
|
||||
|
|
@ -725,7 +1155,7 @@ async def run_reddit_simulation(
|
|||
continue
|
||||
|
||||
actions = {agent: LLMAction() for _, agent in active_agents}
|
||||
await env.step(actions)
|
||||
await result.env.step(actions)
|
||||
|
||||
# 从数据库获取实际执行的动作并记录
|
||||
actual_actions, last_rowid = fetch_new_actions_from_db(
|
||||
|
|
@ -752,13 +1182,16 @@ async def run_reddit_simulation(
|
|||
progress = (round_num + 1) / total_rounds * 100
|
||||
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
|
||||
await env.close()
|
||||
# 注意:不关闭环境,保留给Interview使用
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_end(total_rounds, total_actions)
|
||||
|
||||
result.total_actions = total_actions
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -785,6 +1218,12 @@ async def main():
|
|||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -794,6 +1233,7 @@ async def main():
|
|||
|
||||
config = load_config(args.config)
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
wait_for_commands = not args.no_wait
|
||||
|
||||
# 初始化日志配置(禁用 OASIS 日志,清理旧文件)
|
||||
init_logging_for_simulation(simulation_dir)
|
||||
|
|
@ -807,6 +1247,7 @@ async def main():
|
|||
log_manager.info("OASIS 双平台并行模拟")
|
||||
log_manager.info(f"配置文件: {args.config}")
|
||||
log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}")
|
||||
log_manager.info(f"等待命令模式: {'启用' if wait_for_commands else '禁用'}")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
time_config = config.get("time_config", {})
|
||||
|
|
@ -832,20 +1273,70 @@ async def main():
|
|||
|
||||
start_time = datetime.now()
|
||||
|
||||
# 存储两个平台的模拟结果
|
||||
twitter_result: Optional[PlatformSimulation] = None
|
||||
reddit_result: Optional[PlatformSimulation] = None
|
||||
|
||||
if args.twitter_only:
|
||||
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
|
||||
twitter_result = await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
|
||||
elif args.reddit_only:
|
||||
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
|
||||
reddit_result = await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
|
||||
else:
|
||||
# 并行运行(每个平台使用独立的日志记录器)
|
||||
await asyncio.gather(
|
||||
results = await asyncio.gather(
|
||||
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
|
||||
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
|
||||
)
|
||||
twitter_result, reddit_result = results
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
|
||||
log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}秒")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if wait_for_commands:
|
||||
log_manager.info("")
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info("进入等待命令模式 - 环境保持运行")
|
||||
log_manager.info("支持的命令: interview, batch_interview, close_env")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
# 创建IPC处理器
|
||||
ipc_handler = ParallelIPCHandler(
|
||||
simulation_dir=simulation_dir,
|
||||
twitter_env=twitter_result.env if twitter_result else None,
|
||||
twitter_agent_graph=twitter_result.agent_graph if twitter_result else None,
|
||||
reddit_env=reddit_result.env if reddit_result else None,
|
||||
reddit_agent_graph=reddit_result.agent_graph if reddit_result else None
|
||||
)
|
||||
ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
log_manager.info("\n关闭环境...")
|
||||
ipc_handler.update_status("stopped")
|
||||
|
||||
# 关闭环境
|
||||
if twitter_result and twitter_result.env:
|
||||
await twitter_result.env.close()
|
||||
log_manager.info("[Twitter] 环境已关闭")
|
||||
|
||||
if reddit_result and reddit_result.env:
|
||||
await reddit_result.env.close()
|
||||
log_manager.info("[Reddit] 环境已关闭")
|
||||
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info(f"全部完成!")
|
||||
log_manager.info(f"日志文件:")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}")
|
||||
|
|
@ -855,4 +1346,3 @@ async def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,15 @@
|
|||
OASIS Reddit模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -13,8 +20,9 @@ import logging
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 添加项目路径
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
|
@ -118,10 +126,261 @@ except ImportError as e:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.env = env
|
||||
self.agent_graph = agent_graph
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
|
||||
try:
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
agent_prompts[agent_id] = prompt
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Agent {agent_id}: {e}")
|
||||
|
||||
if not actions:
|
||||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
results[agent_id] = result
|
||||
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" 批量Interview失败: {error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, "reddit_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", "")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", [])
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
class RedditSimulationRunner:
|
||||
"""Reddit模拟运行器"""
|
||||
|
||||
# Reddit可用动作
|
||||
# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.DISLIKE_POST,
|
||||
|
|
@ -138,16 +397,21 @@ class RedditSimulationRunner:
|
|||
ActionType.MUTE,
|
||||
]
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
self.simulation_dir = os.path.dirname(config_path)
|
||||
self.wait_for_commands = wait_for_commands
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
|
|
@ -261,6 +525,7 @@ class RedditSimulationRunner:
|
|||
print("OASIS Reddit模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
|
||||
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
|
||||
print("=" * 60)
|
||||
|
||||
time_config = self.config.get("time_config", {})
|
||||
|
|
@ -292,7 +557,7 @@ class RedditSimulationRunner:
|
|||
print(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_reddit_agent_graph(
|
||||
self.agent_graph = await generate_reddit_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=self.AVAILABLE_ACTIONS,
|
||||
|
|
@ -304,16 +569,20 @@ class RedditSimulationRunner:
|
|||
print(f"已删除旧数据库: {db_path}")
|
||||
|
||||
print("创建OASIS环境...")
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
self.env = oasis.make(
|
||||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.REDDIT,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
|
|
@ -325,7 +594,7 @@ class RedditSimulationRunner:
|
|||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = self.env.agent_graph.get_agent(agent_id)
|
||||
if agent in initial_actions:
|
||||
if not isinstance(initial_actions[agent], list):
|
||||
initial_actions[agent] = [initial_actions[agent]]
|
||||
|
|
@ -342,7 +611,7 @@ class RedditSimulationRunner:
|
|||
print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
|
|
@ -355,7 +624,7 @@ class RedditSimulationRunner:
|
|||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
active_agents = self._get_active_agents_for_round(
|
||||
env, simulated_hour, round_num
|
||||
self.env, simulated_hour, round_num
|
||||
)
|
||||
|
||||
if not active_agents:
|
||||
|
|
@ -366,7 +635,7 @@ class RedditSimulationRunner:
|
|||
for _, agent in active_agents
|
||||
}
|
||||
|
||||
await env.step(actions)
|
||||
await self.env.step(actions)
|
||||
|
||||
if (round_num + 1) % 10 == 0 or round_num == 0:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
|
|
@ -376,12 +645,39 @@ class RedditSimulationRunner:
|
|||
f"- {len(active_agents)} agents active "
|
||||
f"- elapsed: {elapsed:.1f}s")
|
||||
|
||||
await env.close()
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"\n模拟完成!")
|
||||
print(f"\n模拟循环完成!")
|
||||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
print("支持的命令: interview, batch_interview, close_env")
|
||||
print("=" * 60)
|
||||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
print("环境已关闭")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
|
|
@ -399,6 +695,12 @@ async def main():
|
|||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -410,7 +712,10 @@ async def main():
|
|||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = RedditSimulationRunner(args.config)
|
||||
runner = RedditSimulationRunner(
|
||||
config_path=args.config,
|
||||
wait_for_commands=not args.no_wait
|
||||
)
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,15 @@
|
|||
OASIS Twitter模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -13,8 +20,9 @@ import logging
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 添加项目路径
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
|
@ -118,10 +126,261 @@ except ImportError as e:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.env = env
|
||||
self.agent_graph = agent_graph
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
|
||||
try:
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
agent_prompts[agent_id] = prompt
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Agent {agent_id}: {e}")
|
||||
|
||||
if not actions:
|
||||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
results[agent_id] = result
|
||||
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" 批量Interview失败: {error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, "twitter_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", "")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", [])
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
class TwitterSimulationRunner:
|
||||
"""Twitter模拟运行器"""
|
||||
|
||||
# Twitter可用动作
|
||||
# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
|
|
@ -131,16 +390,21 @@ class TwitterSimulationRunner:
|
|||
ActionType.QUOTE_POST,
|
||||
]
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
self.simulation_dir = os.path.dirname(config_path)
|
||||
self.wait_for_commands = wait_for_commands
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
|
|
@ -269,6 +533,7 @@ class TwitterSimulationRunner:
|
|||
print("OASIS Twitter模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
|
||||
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
|
||||
print("=" * 60)
|
||||
|
||||
# 加载时间配置
|
||||
|
|
@ -305,7 +570,7 @@ class TwitterSimulationRunner:
|
|||
print(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_twitter_agent_graph(
|
||||
self.agent_graph = await generate_twitter_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=self.AVAILABLE_ACTIONS,
|
||||
|
|
@ -319,16 +584,20 @@ class TwitterSimulationRunner:
|
|||
|
||||
# 创建环境
|
||||
print("创建OASIS环境...")
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
self.env = oasis.make(
|
||||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.TWITTER,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
|
|
@ -340,7 +609,7 @@ class TwitterSimulationRunner:
|
|||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = self.env.agent_graph.get_agent(agent_id)
|
||||
initial_actions[agent] = ManualAction(
|
||||
action_type=ActionType.CREATE_POST,
|
||||
action_args={"content": content}
|
||||
|
|
@ -349,7 +618,7 @@ class TwitterSimulationRunner:
|
|||
print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
|
|
@ -364,7 +633,7 @@ class TwitterSimulationRunner:
|
|||
|
||||
# 获取本轮激活的Agent
|
||||
active_agents = self._get_active_agents_for_round(
|
||||
env, simulated_hour, round_num
|
||||
self.env, simulated_hour, round_num
|
||||
)
|
||||
|
||||
if not active_agents:
|
||||
|
|
@ -377,7 +646,7 @@ class TwitterSimulationRunner:
|
|||
}
|
||||
|
||||
# 执行动作
|
||||
await env.step(actions)
|
||||
await self.env.step(actions)
|
||||
|
||||
# 打印进度
|
||||
if (round_num + 1) % 10 == 0 or round_num == 0:
|
||||
|
|
@ -388,13 +657,39 @@ class TwitterSimulationRunner:
|
|||
f"- {len(active_agents)} agents active "
|
||||
f"- elapsed: {elapsed:.1f}s")
|
||||
|
||||
# 关闭环境
|
||||
await env.close()
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"\n模拟完成!")
|
||||
print(f"\n模拟循环完成!")
|
||||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
print("支持的命令: interview, batch_interview, close_env")
|
||||
print("=" * 60)
|
||||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
print("环境已关闭")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
|
|
@ -412,6 +707,12 @@ async def main():
|
|||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -423,10 +724,12 @@ async def main():
|
|||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = TwitterSimulationRunner(args.config)
|
||||
runner = TwitterSimulationRunner(
|
||||
config_path=args.config,
|
||||
wait_for_commands=not args.no_wait
|
||||
)
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue