- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
394 lines
12 KiB
Python
394 lines
12 KiB
Python
"""
|
||
模拟IPC通信模块
|
||
用于Flask后端和模拟脚本之间的进程间通信
|
||
|
||
通过文件系统实现简单的命令/响应模式:
|
||
1. Flask写入命令到 commands/ 目录
|
||
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
|
||
3. Flask轮询响应目录获取结果
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import uuid
|
||
from typing import Dict, Any, Optional, List
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
|
||
from ..utils.logger import get_logger
|
||
|
||
logger = get_logger('mirofish.simulation_ipc')
|
||
|
||
|
||
class CommandType(str, Enum):
|
||
"""命令类型"""
|
||
INTERVIEW = "interview" # 单个Agent采访
|
||
BATCH_INTERVIEW = "batch_interview" # 批量采访
|
||
CLOSE_ENV = "close_env" # 关闭环境
|
||
|
||
|
||
class CommandStatus(str, Enum):
|
||
"""命令状态"""
|
||
PENDING = "pending"
|
||
PROCESSING = "processing"
|
||
COMPLETED = "completed"
|
||
FAILED = "failed"
|
||
|
||
|
||
@dataclass
|
||
class IPCCommand:
|
||
"""IPC命令"""
|
||
command_id: str
|
||
command_type: CommandType
|
||
args: Dict[str, Any]
|
||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
return {
|
||
"command_id": self.command_id,
|
||
"command_type": self.command_type.value,
|
||
"args": self.args,
|
||
"timestamp": self.timestamp
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
|
||
return cls(
|
||
command_id=data["command_id"],
|
||
command_type=CommandType(data["command_type"]),
|
||
args=data.get("args", {}),
|
||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||
)
|
||
|
||
|
||
@dataclass
|
||
class IPCResponse:
|
||
"""IPC响应"""
|
||
command_id: str
|
||
status: CommandStatus
|
||
result: Optional[Dict[str, Any]] = None
|
||
error: Optional[str] = None
|
||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
return {
|
||
"command_id": self.command_id,
|
||
"status": self.status.value,
|
||
"result": self.result,
|
||
"error": self.error,
|
||
"timestamp": self.timestamp
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse':
|
||
return cls(
|
||
command_id=data["command_id"],
|
||
status=CommandStatus(data["status"]),
|
||
result=data.get("result"),
|
||
error=data.get("error"),
|
||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||
)
|
||
|
||
|
||
class SimulationIPCClient:
|
||
"""
|
||
模拟IPC客户端(Flask端使用)
|
||
|
||
用于向模拟进程发送命令并等待响应
|
||
"""
|
||
|
||
def __init__(self, simulation_dir: str):
|
||
"""
|
||
初始化IPC客户端
|
||
|
||
Args:
|
||
simulation_dir: 模拟数据目录
|
||
"""
|
||
self.simulation_dir = simulation_dir
|
||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||
|
||
# 确保目录存在
|
||
os.makedirs(self.commands_dir, exist_ok=True)
|
||
os.makedirs(self.responses_dir, exist_ok=True)
|
||
|
||
def send_command(
|
||
self,
|
||
command_type: CommandType,
|
||
args: Dict[str, Any],
|
||
timeout: float = 60.0,
|
||
poll_interval: float = 0.5
|
||
) -> IPCResponse:
|
||
"""
|
||
发送命令并等待响应
|
||
|
||
Args:
|
||
command_type: 命令类型
|
||
args: 命令参数
|
||
timeout: 超时时间(秒)
|
||
poll_interval: 轮询间隔(秒)
|
||
|
||
Returns:
|
||
IPCResponse
|
||
|
||
Raises:
|
||
TimeoutError: 等待响应超时
|
||
"""
|
||
command_id = str(uuid.uuid4())
|
||
command = IPCCommand(
|
||
command_id=command_id,
|
||
command_type=command_type,
|
||
args=args
|
||
)
|
||
|
||
# 写入命令文件
|
||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||
with open(command_file, 'w', encoding='utf-8') as f:
|
||
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
||
|
||
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")
|
||
|
||
# 等待响应
|
||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||
start_time = time.time()
|
||
|
||
while time.time() - start_time < timeout:
|
||
if os.path.exists(response_file):
|
||
try:
|
||
with open(response_file, 'r', encoding='utf-8') as f:
|
||
response_data = json.load(f)
|
||
response = IPCResponse.from_dict(response_data)
|
||
|
||
# 清理命令和响应文件
|
||
try:
|
||
os.remove(command_file)
|
||
os.remove(response_file)
|
||
except OSError:
|
||
pass
|
||
|
||
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
|
||
return response
|
||
except (json.JSONDecodeError, KeyError) as e:
|
||
logger.warning(f"解析响应失败: {e}")
|
||
|
||
time.sleep(poll_interval)
|
||
|
||
# 超时
|
||
logger.error(f"等待IPC响应超时: command_id={command_id}")
|
||
|
||
# 清理命令文件
|
||
try:
|
||
os.remove(command_file)
|
||
except OSError:
|
||
pass
|
||
|
||
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
|
||
|
||
def send_interview(
|
||
self,
|
||
agent_id: int,
|
||
prompt: str,
|
||
platform: str = None,
|
||
timeout: float = 60.0
|
||
) -> IPCResponse:
|
||
"""
|
||
发送单个Agent采访命令
|
||
|
||
Args:
|
||
agent_id: Agent ID
|
||
prompt: 采访问题
|
||
platform: 指定平台(可选)
|
||
- "twitter": 只采访Twitter平台
|
||
- "reddit": 只采访Reddit平台
|
||
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
|
||
timeout: 超时时间
|
||
|
||
Returns:
|
||
IPCResponse,result字段包含采访结果
|
||
"""
|
||
args = {
|
||
"agent_id": agent_id,
|
||
"prompt": prompt
|
||
}
|
||
if platform:
|
||
args["platform"] = platform
|
||
|
||
return self.send_command(
|
||
command_type=CommandType.INTERVIEW,
|
||
args=args,
|
||
timeout=timeout
|
||
)
|
||
|
||
def send_batch_interview(
|
||
self,
|
||
interviews: List[Dict[str, Any]],
|
||
platform: str = None,
|
||
timeout: float = 120.0
|
||
) -> IPCResponse:
|
||
"""
|
||
发送批量采访命令
|
||
|
||
Args:
|
||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||
- "twitter": 默认只采访Twitter平台
|
||
- "reddit": 默认只采访Reddit平台
|
||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||
timeout: 超时时间
|
||
|
||
Returns:
|
||
IPCResponse,result字段包含所有采访结果
|
||
"""
|
||
args = {"interviews": interviews}
|
||
if platform:
|
||
args["platform"] = platform
|
||
|
||
return self.send_command(
|
||
command_type=CommandType.BATCH_INTERVIEW,
|
||
args=args,
|
||
timeout=timeout
|
||
)
|
||
|
||
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
||
"""
|
||
发送关闭环境命令
|
||
|
||
Args:
|
||
timeout: 超时时间
|
||
|
||
Returns:
|
||
IPCResponse
|
||
"""
|
||
return self.send_command(
|
||
command_type=CommandType.CLOSE_ENV,
|
||
args={},
|
||
timeout=timeout
|
||
)
|
||
|
||
def check_env_alive(self) -> bool:
|
||
"""
|
||
检查模拟环境是否存活
|
||
|
||
通过检查 env_status.json 文件来判断
|
||
"""
|
||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||
if not os.path.exists(status_file):
|
||
return False
|
||
|
||
try:
|
||
with open(status_file, 'r', encoding='utf-8') as f:
|
||
status = json.load(f)
|
||
return status.get("status") == "alive"
|
||
except (json.JSONDecodeError, OSError):
|
||
return False
|
||
|
||
|
||
class SimulationIPCServer:
|
||
"""
|
||
模拟IPC服务器(模拟脚本端使用)
|
||
|
||
轮询命令目录,执行命令并返回响应
|
||
"""
|
||
|
||
def __init__(self, simulation_dir: str):
|
||
"""
|
||
初始化IPC服务器
|
||
|
||
Args:
|
||
simulation_dir: 模拟数据目录
|
||
"""
|
||
self.simulation_dir = simulation_dir
|
||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||
|
||
# 确保目录存在
|
||
os.makedirs(self.commands_dir, exist_ok=True)
|
||
os.makedirs(self.responses_dir, exist_ok=True)
|
||
|
||
# 环境状态
|
||
self._running = False
|
||
|
||
def start(self):
|
||
"""标记服务器为运行状态"""
|
||
self._running = True
|
||
self._update_env_status("alive")
|
||
|
||
def stop(self):
|
||
"""标记服务器为停止状态"""
|
||
self._running = False
|
||
self._update_env_status("stopped")
|
||
|
||
def _update_env_status(self, status: str):
|
||
"""更新环境状态文件"""
|
||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||
with open(status_file, 'w', encoding='utf-8') as f:
|
||
json.dump({
|
||
"status": status,
|
||
"timestamp": datetime.now().isoformat()
|
||
}, f, ensure_ascii=False, indent=2)
|
||
|
||
def poll_commands(self) -> Optional[IPCCommand]:
|
||
"""
|
||
轮询命令目录,返回第一个待处理的命令
|
||
|
||
Returns:
|
||
IPCCommand 或 None
|
||
"""
|
||
if not os.path.exists(self.commands_dir):
|
||
return None
|
||
|
||
# 按时间排序获取命令文件
|
||
command_files = []
|
||
for filename in os.listdir(self.commands_dir):
|
||
if filename.endswith('.json'):
|
||
filepath = os.path.join(self.commands_dir, filename)
|
||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||
|
||
command_files.sort(key=lambda x: x[1])
|
||
|
||
for filepath, _ in command_files:
|
||
try:
|
||
with open(filepath, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
return IPCCommand.from_dict(data)
|
||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||
logger.warning(f"读取命令文件失败: {filepath}, {e}")
|
||
continue
|
||
|
||
return None
|
||
|
||
def send_response(self, response: IPCResponse):
|
||
"""
|
||
发送响应
|
||
|
||
Args:
|
||
response: IPC响应
|
||
"""
|
||
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
||
with open(response_file, 'w', encoding='utf-8') as f:
|
||
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
||
|
||
# 删除命令文件
|
||
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
||
try:
|
||
os.remove(command_file)
|
||
except OSError:
|
||
pass
|
||
|
||
def send_success(self, command_id: str, result: Dict[str, Any]):
|
||
"""发送成功响应"""
|
||
self.send_response(IPCResponse(
|
||
command_id=command_id,
|
||
status=CommandStatus.COMPLETED,
|
||
result=result
|
||
))
|
||
|
||
def send_error(self, command_id: str, error: str):
|
||
"""发送错误响应"""
|
||
self.send_response(IPCResponse(
|
||
command_id=command_id,
|
||
status=CommandStatus.FAILED,
|
||
error=error
|
||
))
|