1713 lines
66 KiB
Python
1713 lines
66 KiB
Python
"""
|
||
OASIS模拟运行器
|
||
在后台运行模拟并记录每个Agent的动作,支持实时状态监控
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
import asyncio
|
||
import threading
|
||
import subprocess
|
||
import signal
|
||
import atexit
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
from queue import Queue
|
||
|
||
from ..config import Config
|
||
from ..utils.logger import get_logger
|
||
from .zep_graph_memory_updater import ZepGraphMemoryManager
|
||
from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse
|
||
|
||
logger = get_logger('mirofish.simulation_runner')
|
||
|
||
# Flag: whether the cleanup function has already been registered
|
||
_cleanup_registered = False
|
||
|
||
|
||
class RunnerStatus(str, Enum):
    """Status values of a simulation runner.

    The ``str`` mixin makes every member compare equal to its plain string
    value, so members can be written to / read back from JSON directly.
    """

    # Before / while the background process comes up
    IDLE = "idle"
    STARTING = "starting"

    # While the background process is alive
    RUNNING = "running"
    PAUSED = "paused"
    STOPPING = "stopping"

    # Terminal values
    STOPPED = "stopped"
    COMPLETED = "completed"
    FAILED = "failed"
|
||
|
||
|
||
@dataclass
class AgentAction:
    """One action performed by an agent, as recorded in an action log."""

    round_num: int
    timestamp: str
    platform: str  # twitter / reddit
    agent_id: int
    agent_name: str
    action_type: str  # CREATE_POST, LIKE_POST, etc.
    action_args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None
    success: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dict keyed by field name.

        Values are returned as-is (``action_args`` is not copied).
        """
        exported = (
            "round_num",
            "timestamp",
            "platform",
            "agent_id",
            "agent_name",
            "action_type",
            "action_args",
            "result",
            "success",
        )
        return {name: getattr(self, name) for name in exported}
|
||
|
||
|
||
@dataclass
class RoundSummary:
    """Summary of a single simulation round."""

    round_num: int
    start_time: str
    end_time: Optional[str] = None
    simulated_hour: int = 0
    twitter_actions: int = 0
    reddit_actions: int = 0
    active_agents: List[int] = field(default_factory=list)
    actions: List["AgentAction"] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the summary as a plain dict.

        Besides the scalar fields, the result carries ``actions_count`` and
        ``actions`` (each action converted through its own ``to_dict``).
        """
        scalar_fields = (
            "round_num",
            "start_time",
            "end_time",
            "simulated_hour",
            "twitter_actions",
            "reddit_actions",
            "active_agents",
        )
        summary: Dict[str, Any] = {name: getattr(self, name) for name in scalar_fields}
        summary["actions_count"] = len(self.actions)
        summary["actions"] = [entry.to_dict() for entry in self.actions]
        return summary
|
||
|
||
|
||
@dataclass
class SimulationRunState:
    """Live run state of one simulation."""

    simulation_id: str
    runner_status: RunnerStatus = RunnerStatus.IDLE

    # Progress information
    current_round: int = 0
    total_rounds: int = 0
    simulated_hours: int = 0
    total_simulation_hours: int = 0

    # Per-platform round and simulated time (used to display both platforms
    # side by side when they run in parallel)
    twitter_current_round: int = 0
    reddit_current_round: int = 0
    twitter_simulated_hours: int = 0
    reddit_simulated_hours: int = 0

    # Platform status
    twitter_running: bool = False
    reddit_running: bool = False
    twitter_actions_count: int = 0
    reddit_actions_count: int = 0

    # Platform completion (detected through the simulation_end event in
    # actions.jsonl)
    twitter_completed: bool = False
    reddit_completed: bool = False

    # Per-round summaries
    rounds: List[RoundSummary] = field(default_factory=list)

    # Most recent actions, newest first (for real-time display in the frontend)
    recent_actions: List[AgentAction] = field(default_factory=list)
    max_recent_actions: int = 50

    # Timestamps
    started_at: Optional[str] = None
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None

    # Error information
    error: Optional[str] = None

    # Process ID (used for stopping)
    process_pid: Optional[int] = None

    def add_action(self, action: AgentAction):
        """Record ``action`` as the newest recent action and bump its platform counter.

        The recent-actions list is capped at ``max_recent_actions`` entries;
        any platform other than ``"twitter"`` is counted as reddit.
        """
        history = self.recent_actions
        history.insert(0, action)
        if len(history) > self.max_recent_actions:
            self.recent_actions = history[:self.max_recent_actions]

        counter = "twitter_actions_count" if action.platform == "twitter" else "reddit_actions_count"
        setattr(self, counter, getattr(self, counter) + 1)

        self.updated_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable snapshot (without the recent actions)."""
        snapshot: Dict[str, Any] = {
            "simulation_id": self.simulation_id,
            "runner_status": self.runner_status.value,
        }

        for name in ("current_round", "total_rounds", "simulated_hours", "total_simulation_hours"):
            snapshot[name] = getattr(self, name)

        # max(..., 1) guards against division by zero when total_rounds is 0
        snapshot["progress_percent"] = round(self.current_round / max(self.total_rounds, 1) * 100, 1)

        # Per-platform rounds, time, status and counters
        per_platform = (
            "twitter_current_round",
            "reddit_current_round",
            "twitter_simulated_hours",
            "reddit_simulated_hours",
            "twitter_running",
            "reddit_running",
            "twitter_completed",
            "reddit_completed",
            "twitter_actions_count",
            "reddit_actions_count",
        )
        for name in per_platform:
            snapshot[name] = getattr(self, name)

        snapshot["total_actions_count"] = self.twitter_actions_count + self.reddit_actions_count

        for name in ("started_at", "updated_at", "completed_at", "error", "process_pid"):
            snapshot[name] = getattr(self, name)

        return snapshot

    def to_detail_dict(self) -> Dict[str, Any]:
        """Return ``to_dict()`` extended with the recent actions and the round count."""
        detail = self.to_dict()
        detail["recent_actions"] = [entry.to_dict() for entry in self.recent_actions]
        detail["rounds_count"] = len(self.rounds)
        return detail
|
||
|
||
|
||
class SimulationRunner:
    """
    Simulation runner.

    Responsibilities:
    1. Run the OASIS simulation in a background process
    2. Parse the run logs and record every agent action
    3. Provide a real-time status query interface
    4. Support pause / stop / resume operations
    """

    # Directory holding the per-simulation run state
    RUN_STATE_DIR = os.path.join(os.path.dirname(__file__), '../../uploads/simulations')

    # Directory holding the simulation scripts
    SCRIPTS_DIR = os.path.join(os.path.dirname(__file__), '../../scripts')

    # In-memory run state, keyed by simulation_id
    _run_states: Dict[str, SimulationRunState] = {}
    _processes: Dict[str, subprocess.Popen] = {}
    _action_queues: Dict[str, Queue] = {}
    _monitor_threads: Dict[str, threading.Thread] = {}
    _stdout_files: Dict[str, Any] = {}  # stdout file handles
    _stderr_files: Dict[str, Any] = {}  # stderr file handles

    # Graph memory update configuration: simulation_id -> enabled
    _graph_memory_enabled: Dict[str, bool] = {}
|
||
|
||
@classmethod
def get_run_state(cls, simulation_id: str) -> Optional["SimulationRunState"]:
    """Return the run state for ``simulation_id``, or ``None`` if there is none.

    The in-memory cache is consulted first; on a miss the state is loaded
    from its file and, when found, cached for later calls.
    """
    try:
        return cls._run_states[simulation_id]
    except KeyError:
        pass

    # Cache miss: fall back to the persisted state file
    loaded = cls._load_run_state(simulation_id)
    if loaded:
        cls._run_states[simulation_id] = loaded
    return loaded
|
||
|
||
@classmethod
def _load_run_state(cls, simulation_id: str) -> Optional["SimulationRunState"]:
    """Load the run state of ``simulation_id`` from its ``run_state.json``.

    Returns ``None`` when the file does not exist or when anything goes
    wrong while reading / rebuilding the state (the error is logged).
    Missing keys fall back to the same defaults as a fresh state.
    """
    state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
    if not os.path.exists(state_file):
        return None

    # Keys grouped by the default used when they are absent from the file
    zero_default = (
        "current_round",
        "total_rounds",
        "simulated_hours",
        "total_simulation_hours",
        # Per-platform rounds and time
        "twitter_current_round",
        "reddit_current_round",
        "twitter_simulated_hours",
        "reddit_simulated_hours",
        "twitter_actions_count",
        "reddit_actions_count",
    )
    false_default = (
        "twitter_running",
        "reddit_running",
        "twitter_completed",
        "reddit_completed",
    )
    none_default = ("started_at", "completed_at", "error", "process_pid")

    try:
        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        restored = {name: data.get(name, 0) for name in zero_default}
        restored.update((name, data.get(name, False)) for name in false_default)
        restored.update((name, data.get(name)) for name in none_default)

        state = SimulationRunState(
            simulation_id=simulation_id,
            runner_status=RunnerStatus(data.get("runner_status", "idle")),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            **restored,
        )

        # Restore the recent actions in their stored order
        state.recent_actions.extend(
            AgentAction(
                round_num=item.get("round_num", 0),
                timestamp=item.get("timestamp", ""),
                platform=item.get("platform", ""),
                agent_id=item.get("agent_id", 0),
                agent_name=item.get("agent_name", ""),
                action_type=item.get("action_type", ""),
                action_args=item.get("action_args", {}),
                result=item.get("result"),
                success=item.get("success", True),
            )
            for item in data.get("recent_actions", [])
        )

        return state
    except Exception as e:
        logger.error(f"加载运行状态失败: {str(e)}")
        return None
|
||
|
||
@classmethod
def _save_run_state(cls, state: "SimulationRunState"):
    """Write ``state`` to its ``run_state.json`` and refresh the in-memory cache.

    The simulation directory is created when missing. The file holds the
    detailed dict of the state (including recent actions).
    """
    target_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
    os.makedirs(target_dir, exist_ok=True)

    payload = state.to_detail_dict()
    with open(os.path.join(target_dir, "run_state.json"), 'w', encoding='utf-8') as out:
        # ensure_ascii=False keeps non-ASCII text readable in the file
        json.dump(payload, out, ensure_ascii=False, indent=2)

    cls._run_states[state.simulation_id] = state
|
||
|
||
@classmethod
def start_simulation(
    cls,
    simulation_id: str,
    platform: str = "parallel",  # twitter / reddit / parallel
    max_rounds: Optional[int] = None,  # optional cap used to truncate long simulations
    enable_graph_memory_update: bool = False,  # push activity to the Zep graph
    graph_id: Optional[str] = None,  # Zep graph ID (required when graph updates are enabled)
) -> "SimulationRunState":
    """
    Start a simulation as a background process.

    All arguments are validated before any state is persisted, so a
    rejected call never leaves a stale STARTING state behind that would
    block later starts.

    Args:
        simulation_id: Simulation ID
        platform: Platform to run (twitter/reddit/parallel); any other value
            is treated as parallel
        max_rounds: Optional maximum number of rounds, used to truncate an
            overly long simulation (ignored unless > 0)
        enable_graph_memory_update: Whether agent activity is pushed to the
            Zep graph while the simulation runs
        graph_id: Zep graph ID (required when graph updates are enabled)

    Returns:
        SimulationRunState of the started run

    Raises:
        ValueError: the simulation is already running/starting, its config
            file is missing, ``graph_id`` is missing while graph updates are
            enabled, or the platform script does not exist.
        Exception: any error raised while launching the process is re-raised
            after the state has been saved as FAILED.
    """
    # Refuse to start twice
    existing = cls.get_run_state(simulation_id)
    if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]:
        raise ValueError(f"模拟已在运行中: {simulation_id}")

    # Load the simulation config
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    config_path = os.path.join(sim_dir, "simulation_config.json")

    if not os.path.exists(config_path):
        raise ValueError("模拟配置不存在,请先调用 /prepare 接口")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # --- Validate everything that can be rejected up front -----------------
    # These checks must run BEFORE the STARTING state is saved; otherwise a
    # failed check leaves "starting" on disk and every retry is refused.
    if enable_graph_memory_update and not graph_id:
        raise ValueError("启用图谱记忆更新时必须提供 graph_id")

    # Pick the script to run (scripts live in the backend/scripts/ directory)
    if platform == "twitter":
        script_name = "run_twitter_simulation.py"
        run_twitter, run_reddit = True, False
    elif platform == "reddit":
        script_name = "run_reddit_simulation.py"
        run_twitter, run_reddit = False, True
    else:
        script_name = "run_parallel_simulation.py"
        run_twitter, run_reddit = True, True

    script_path = os.path.join(cls.SCRIPTS_DIR, script_name)
    if not os.path.exists(script_path):
        raise ValueError(f"脚本不存在: {script_path}")

    # --- Initialise the run state ------------------------------------------
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = int(total_hours * 60 / minutes_per_round)

    # Truncate when a positive maximum was given
    limit_rounds = max_rounds is not None and max_rounds > 0
    if limit_rounds:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

    state = SimulationRunState(
        simulation_id=simulation_id,
        runner_status=RunnerStatus.STARTING,
        total_rounds=total_rounds,
        total_simulation_hours=total_hours,
        started_at=datetime.now().isoformat(),
    )
    cls._save_run_state(state)

    # Create the graph memory updater when requested. A failure here only
    # disables graph updates; it does not abort the simulation.
    cls._graph_memory_enabled[simulation_id] = False
    if enable_graph_memory_update:
        try:
            ZepGraphMemoryManager.create_updater(simulation_id, graph_id)
            cls._graph_memory_enabled[simulation_id] = True
            logger.info(f"已启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}")
        except Exception as e:
            logger.error(f"创建图谱记忆更新器失败: {e}")

    state.twitter_running = run_twitter
    state.reddit_running = run_reddit

    # Action queue for this run
    cls._action_queues[simulation_id] = Queue()

    # --- Launch the simulation process -------------------------------------
    main_log_file = None
    process = None
    try:
        # Log layout produced by the scripts:
        #   twitter/actions.jsonl - Twitter action log
        #   reddit/actions.jsonl  - Reddit action log
        #   simulation.log        - main process log
        cmd = [
            sys.executable,  # Python interpreter
            script_path,
            "--config", config_path,  # full path of the config file
        ]
        if limit_rounds:
            cmd.extend(["--max-rounds", str(max_rounds)])

        # Send output to a file: a pipe nobody drains would eventually fill
        # up and block the child process.
        main_log_path = os.path.join(sim_dir, "simulation.log")
        main_log_file = open(main_log_path, 'w', encoding='utf-8')

        # Force UTF-8 in the child so that third-party code (such as OASIS)
        # that opens files without an explicit encoding also works on Windows.
        env = os.environ.copy()
        env['PYTHONUTF8'] = '1'  # Python 3.7+: open() defaults to UTF-8
        env['PYTHONIOENCODING'] = 'utf-8'  # stdout/stderr use UTF-8

        # The working directory is the simulation directory (databases and
        # similar files are created there). start_new_session=True gives the
        # child its own process group so the whole group can be terminated
        # with os.killpg.
        process = subprocess.Popen(
            cmd,
            cwd=sim_dir,
            stdout=main_log_file,
            stderr=subprocess.STDOUT,  # stderr goes to the same file
            text=True,
            encoding='utf-8',
            bufsize=1,
            env=env,
            start_new_session=True,
        )

        # Keep the handle so it can be closed once the process has ended
        cls._stdout_files[simulation_id] = main_log_file
        cls._stderr_files[simulation_id] = None  # no separate stderr file

        state.process_pid = process.pid
        state.runner_status = RunnerStatus.RUNNING
        cls._processes[simulation_id] = process
        cls._save_run_state(state)

        # Start the monitor thread
        monitor_thread = threading.Thread(
            target=cls._monitor_simulation,
            args=(simulation_id,),
            daemon=True
        )
        monitor_thread.start()
        cls._monitor_threads[simulation_id] = monitor_thread

        logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}")

    except Exception as e:
        if process is None:
            # The process never started, so no monitor thread will ever
            # release what was prepared for it: do it here.
            if main_log_file is not None:
                main_log_file.close()
            cls._action_queues.pop(simulation_id, None)
            if cls._graph_memory_enabled.pop(simulation_id, False):
                try:
                    ZepGraphMemoryManager.stop_updater(simulation_id)
                except Exception as stop_error:
                    logger.error(f"停止图谱记忆更新器失败: {stop_error}")

        state.runner_status = RunnerStatus.FAILED
        state.error = str(e)
        cls._save_run_state(state)
        raise

    return state
|
||
|
||
@classmethod
def _monitor_simulation(cls, simulation_id: str):
    """Monitor the simulation process and fold its action logs into the run state.

    Runs in a background thread. While the process is alive, both platform
    action logs are polled every 2 seconds and the state is saved. Once the
    process has ended, the logs are read one last time and the final status
    is recorded:

    * a run that is being / has been stopped through ``stop_simulation``
      keeps its STOPPING/STOPPED status (the killed process exits with a
      non-zero code, which must not be reported as a failure);
    * exit code 0 -> COMPLETED;
    * any other exit code -> FAILED, with the tail of ``simulation.log`` as
      the error text.

    If a newer run with the same ``simulation_id`` has been registered in
    the meantime, this (stale) monitor leaves the shared state and
    registries of the newer run untouched.

    Returns immediately when no process or no run state is registered.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)

    # Per-platform action logs
    twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl")
    reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl")

    process = cls._processes.get(simulation_id)
    state = cls.get_run_state(simulation_id)

    if not process or not state:
        return

    # Remember which log handle belongs to this run
    log_handle = cls._stdout_files.get(simulation_id)

    def superseded() -> bool:
        """True when a different process is now registered under this id."""
        current = cls._processes.get(simulation_id)
        return current is not None and current is not process

    twitter_position = 0
    reddit_position = 0

    try:
        while process.poll() is None:  # process still running
            if os.path.exists(twitter_actions_log):
                twitter_position = cls._read_action_log(
                    twitter_actions_log, twitter_position, state, "twitter"
                )

            if os.path.exists(reddit_actions_log):
                reddit_position = cls._read_action_log(
                    reddit_actions_log, reddit_position, state, "reddit"
                )

            cls._save_run_state(state)
            time.sleep(2)

        if superseded():
            # A restarted run owns the state file and registries now
            return

        # The process has ended: pick up whatever was written last
        if os.path.exists(twitter_actions_log):
            cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter")
        if os.path.exists(reddit_actions_log):
            cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit")

        exit_code = process.returncode

        if state.runner_status in (RunnerStatus.STOPPING, RunnerStatus.STOPPED):
            # stop_simulation decides the final status of a user stop
            logger.info(f"模拟已被停止,监控结束: {simulation_id}, exit_code={exit_code}")
        elif exit_code == 0:
            state.runner_status = RunnerStatus.COMPLETED
            state.completed_at = datetime.now().isoformat()
            logger.info(f"模拟完成: {simulation_id}")
        else:
            state.runner_status = RunnerStatus.FAILED
            # Use the end of the main log as the error text
            main_log_path = os.path.join(sim_dir, "simulation.log")
            error_info = ""
            try:
                if os.path.exists(main_log_path):
                    # Only the tail is needed: 2000 characters are at most
                    # 8000 UTF-8 bytes (+4 so a character cut at the window
                    # start cannot reach into the kept tail).
                    with open(main_log_path, 'rb') as f:
                        f.seek(0, os.SEEK_END)
                        f.seek(max(f.tell() - (4 * 2000 + 4), 0))
                        error_info = f.read().decode('utf-8', errors='replace')[-2000:]
            except OSError as read_error:
                # Best effort: the failure is still reported without log text
                logger.warning(f"读取主日志失败: {main_log_path}, error={read_error}")
            state.error = f"进程退出码: {exit_code}, 错误: {error_info}"
            logger.error(f"模拟失败: {simulation_id}, error={state.error}")

        state.twitter_running = False
        state.reddit_running = False
        cls._save_run_state(state)

    except Exception as e:
        logger.error(f"监控线程异常: {simulation_id}, error={str(e)}")
        if not superseded():
            state.runner_status = RunnerStatus.FAILED
            state.error = str(e)
            cls._save_run_state(state)

    finally:
        if superseded():
            # Only release what belongs to this (old) run
            if log_handle is not None and cls._stdout_files.get(simulation_id) is not log_handle:
                try:
                    log_handle.close()
                except Exception:
                    pass  # best-effort close of an already finished run
        else:
            # Stop the graph memory updater
            if cls._graph_memory_enabled.get(simulation_id, False):
                try:
                    ZepGraphMemoryManager.stop_updater(simulation_id)
                    logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}")
                except Exception as e:
                    logger.error(f"停止图谱记忆更新器失败: {e}")
                cls._graph_memory_enabled.pop(simulation_id, None)

            # Release process resources
            cls._processes.pop(simulation_id, None)
            cls._action_queues.pop(simulation_id, None)
            if cls._monitor_threads.get(simulation_id) is threading.current_thread():
                cls._monitor_threads.pop(simulation_id, None)

            # Close the log file handles (best effort)
            for registry in (cls._stdout_files, cls._stderr_files):
                handle = registry.pop(simulation_id, None)
                if handle:
                    try:
                        handle.close()
                    except Exception:
                        pass
|
||
|
||
@classmethod
def _read_action_log(
    cls,
    log_path: str,
    position: int,
    state: "SimulationRunState",
    platform: str
) -> int:
    """
    Read new entries from an action log file and apply them to ``state``.

    The file is read in binary mode so that ``position`` is a plain byte
    offset. Only fully handled lines advance the offset:

    * a last line that is still being written (no trailing newline and not
      yet valid JSON) is left for the next call instead of being dropped;
    * lines that are invalid JSON or not JSON objects are skipped;
    * when an error interrupts the read, the offset reached so far is
      returned, so entries that were already applied are not counted twice.

    Args:
        log_path: Path of the log file
        position: Byte offset where the previous read stopped
        state: Run state object to update
        platform: Platform name (twitter/reddit)

    Returns:
        The new read offset
    """
    # Is graph memory updating enabled for this simulation?
    graph_updater = None
    if cls._graph_memory_enabled.get(state.simulation_id, False):
        graph_updater = ZepGraphMemoryManager.get_updater(state.simulation_id)

    consumed = position
    try:
        with open(log_path, 'rb') as f:
            f.seek(position)
            for raw_line in f:
                text = raw_line.decode('utf-8', errors='replace').strip()
                action_data = None
                if text:
                    try:
                        action_data = json.loads(text)
                    except json.JSONDecodeError:
                        if not raw_line.endswith(b"\n"):
                            # Half-written line: retry on the next poll
                            break

                consumed += len(raw_line)
                if not isinstance(action_data, dict):
                    continue

                # Event entries
                if "event_type" in action_data:
                    event_type = action_data.get("event_type")

                    # simulation_end marks the platform as completed
                    if event_type == "simulation_end":
                        if platform == "twitter":
                            state.twitter_completed = True
                            state.twitter_running = False
                            logger.info(f"Twitter 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}")
                        elif platform == "reddit":
                            state.reddit_completed = True
                            state.reddit_running = False
                            logger.info(f"Reddit 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}")

                        # The run is complete once every enabled platform
                        # has completed
                        if cls._check_all_platforms_completed(state):
                            state.runner_status = RunnerStatus.COMPLETED
                            state.completed_at = datetime.now().isoformat()
                            logger.info(f"所有平台模拟已完成: {state.simulation_id}")

                    # round_end carries the round progress
                    elif event_type == "round_end":
                        round_num = action_data.get("round", 0)
                        simulated_hours = action_data.get("simulated_hours", 0)

                        # Per-platform round and time
                        if platform == "twitter":
                            if round_num > state.twitter_current_round:
                                state.twitter_current_round = round_num
                            state.twitter_simulated_hours = simulated_hours
                        elif platform == "reddit":
                            if round_num > state.reddit_current_round:
                                state.reddit_current_round = round_num
                            state.reddit_simulated_hours = simulated_hours

                        # Overall round / time: the maximum of both platforms
                        if round_num > state.current_round:
                            state.current_round = round_num
                        state.simulated_hours = max(state.twitter_simulated_hours, state.reddit_simulated_hours)

                    continue

                action = AgentAction(
                    round_num=action_data.get("round", 0),
                    timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                    platform=platform,
                    agent_id=action_data.get("agent_id", 0),
                    agent_name=action_data.get("agent_name", ""),
                    action_type=action_data.get("action_type", ""),
                    action_args=action_data.get("action_args", {}),
                    result=action_data.get("result"),
                    success=action_data.get("success", True),
                )
                state.add_action(action)

                # Update the overall round
                if action.round_num and action.round_num > state.current_round:
                    state.current_round = action.round_num

                # Forward the activity to Zep when graph updates are enabled.
                # A failing updater must not abort the log read.
                if graph_updater:
                    try:
                        graph_updater.add_activity_from_dict(action_data, platform)
                    except Exception as update_error:
                        logger.warning(f"图谱记忆更新失败: {state.simulation_id}, error={update_error}")

    except Exception as e:
        logger.warning(f"读取动作日志失败: {log_path}, error={e}")

    return consumed
|
||
|
||
@classmethod
def _check_all_platforms_completed(cls, state: "SimulationRunState") -> bool:
    """
    Check whether every enabled platform has completed its simulation.

    A platform counts as enabled when its ``actions.jsonl`` exists or when
    the state still marks it as running. The running flag matters when
    both platforms run in parallel: one platform may finish before the
    other has written its first log line, and the run must not be
    reported as completed in that window.

    Returns:
        True if at least one platform log exists and no enabled platform
        is still incomplete
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
    platforms = (
        ("twitter", state.twitter_running, state.twitter_completed),
        ("reddit", state.reddit_running, state.reddit_completed),
    )

    any_log_present = False
    for name, running, completed in platforms:
        has_log = os.path.exists(os.path.join(sim_dir, name, "actions.jsonl"))
        if (has_log or running) and not completed:
            return False
        any_log_present = any_log_present or has_log

    # At least one platform was enabled and all enabled ones completed
    return any_log_present
|
||
|
||
@classmethod
def stop_simulation(cls, simulation_id: str) -> "SimulationRunState":
    """Stop a running (or paused) simulation and return its updated state.

    The state is first saved as STOPPING. If the process is still alive,
    its whole process group receives SIGTERM, followed by SIGKILL when it
    has not exited after 10 seconds; if group termination fails, the
    process itself is terminated / killed instead. Finally the state is
    saved as STOPPED and the graph memory updater (if enabled) is stopped.

    Raises:
        ValueError: no run state exists for ``simulation_id``, or the
            simulation is neither RUNNING nor PAUSED.
    """
    state = cls.get_run_state(simulation_id)
    if not state:
        raise ValueError(f"模拟不存在: {simulation_id}")

    stoppable = (RunnerStatus.RUNNING, RunnerStatus.PAUSED)
    if state.runner_status not in stoppable:
        raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}")

    state.runner_status = RunnerStatus.STOPPING
    cls._save_run_state(state)

    # Terminate the process
    proc = cls._processes.get(simulation_id)
    if proc and proc.poll() is None:
        try:
            # Terminate the whole process group (including all children).
            # Because the process was started with start_new_session=True,
            # the group ID equals the PID of the main process.
            group_id = os.getpgid(proc.pid)
            logger.info(f"终止进程组: simulation={simulation_id}, pgid={group_id}")

            # Ask politely first
            os.killpg(group_id, signal.SIGTERM)

            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Still alive after 10 seconds: force it
                logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}")
                os.killpg(group_id, signal.SIGKILL)
                proc.wait(timeout=5)

        except ProcessLookupError:
            # The process is already gone
            pass
        except Exception as e:
            logger.error(f"终止进程组失败: {simulation_id}, error={e}")
            # Fall back to terminating the process directly
            try:
                proc.terminate()
                proc.wait(timeout=5)
            except Exception:
                proc.kill()

    state.runner_status = RunnerStatus.STOPPED
    state.twitter_running = False
    state.reddit_running = False
    state.completed_at = datetime.now().isoformat()
    cls._save_run_state(state)

    # Stop the graph memory updater
    if cls._graph_memory_enabled.get(simulation_id, False):
        try:
            ZepGraphMemoryManager.stop_updater(simulation_id)
            logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}")
        except Exception as e:
            logger.error(f"停止图谱记忆更新器失败: {e}")
        cls._graph_memory_enabled.pop(simulation_id, None)

    logger.info(f"模拟已停止: {simulation_id}")
    return state
|
||
|
||
@classmethod
def _read_actions_from_file(
    cls,
    file_path: str,
    default_platform: Optional[str] = None,
    platform_filter: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List["AgentAction"]:
    """
    Read the agent actions stored in a single action log file.

    Event entries (anything carrying ``event_type``), entries without an
    ``agent_id``, blank lines and lines that are not valid JSON are skipped.
    A missing file yields an empty list.

    Args:
        file_path: Path of the action log file
        default_platform: Platform used when an entry has no platform field
        platform_filter: Keep only entries of this platform
        agent_id: Keep only entries of this agent ID
        round_num: Keep only entries of this round

    Returns:
        The matching actions, in file order
    """
    if not os.path.exists(file_path):
        return []

    collected = []

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue

            try:
                record = json.loads(text)
            except json.JSONDecodeError:
                continue

            # Not an agent action: events (simulation_start, round_start,
            # round_end, ...) and anything without an agent_id
            if "event_type" in record or "agent_id" not in record:
                continue

            # The entry's own platform wins; otherwise use the default
            source = record.get("platform") or default_platform or ""

            # Filters
            if platform_filter and source != platform_filter:
                continue
            if agent_id is not None and record.get("agent_id") != agent_id:
                continue
            if round_num is not None and record.get("round") != round_num:
                continue

            collected.append(AgentAction(
                round_num=record.get("round", 0),
                timestamp=record.get("timestamp", ""),
                platform=source,
                agent_id=record.get("agent_id", 0),
                agent_name=record.get("agent_name", ""),
                action_type=record.get("action_type", ""),
                action_args=record.get("action_args", {}),
                result=record.get("result"),
                success=record.get("success", True),
            ))

    return collected
|
||
|
||
@classmethod
def get_all_actions(
    cls,
    simulation_id: str,
    platform: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List["AgentAction"]:
    """
    Return the complete action history of all platforms (no pagination).

    The per-platform logs (``twitter/actions.jsonl``, ``reddit/actions.jsonl``)
    are read first; entries without a platform field get the platform of
    the file they came from. When these yield nothing, the old single-file
    layout (``actions.jsonl`` in the simulation directory) is tried.

    Args:
        simulation_id: Simulation ID
        platform: Platform filter (twitter/reddit)
        agent_id: Agent filter
        round_num: Round filter

    Returns:
        All matching actions, sorted by timestamp with the newest first
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    filters = {
        "platform_filter": platform,
        "agent_id": agent_id,
        "round_num": round_num,
    }

    actions = []
    for source in ("twitter", "reddit"):
        if platform and platform != source:
            continue
        actions.extend(cls._read_actions_from_file(
            os.path.join(sim_dir, source, "actions.jsonl"),
            default_platform=source,  # fills in the platform field
            **filters
        ))

    # Nothing in the per-platform files: try the old single-file layout,
    # whose entries are expected to carry their own platform field
    if not actions:
        actions = cls._read_actions_from_file(
            os.path.join(sim_dir, "actions.jsonl"),
            default_platform=None,
            **filters
        )

    # Newest first
    actions.sort(key=lambda entry: entry.timestamp, reverse=True)

    return actions
|
||
|
||
@classmethod
def get_actions(
    cls,
    simulation_id: str,
    limit: int = 100,
    offset: int = 0,
    platform: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List["AgentAction"]:
    """
    Return one page of the action history.

    Args:
        simulation_id: Simulation ID
        limit: Maximum number of actions to return
        offset: Number of actions to skip
        platform: Platform filter
        agent_id: Agent filter
        round_num: Round filter

    Returns:
        The requested slice of the filtered action list
    """
    matched = cls.get_all_actions(
        simulation_id=simulation_id,
        platform=platform,
        agent_id=agent_id,
        round_num=round_num
    )

    # Pagination
    page_end = offset + limit
    return matched[offset:page_end]
|
||
|
||
@classmethod
def get_timeline(
    cls,
    simulation_id: str,
    start_round: int = 0,
    end_round: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Return the simulation timeline, aggregated per round.

    The complete action history is used (no pagination cap), and
    ``first_action_time`` / ``last_action_time`` are the smallest / largest
    timestamp of the round regardless of the order the actions arrive in.

    Args:
        simulation_id: Simulation ID
        start_round: First round to include
        end_round: Last round to include (``None`` for no upper bound)

    Returns:
        One summary dict per round, ordered by round number
    """
    actions = cls.get_all_actions(simulation_id)

    # Group by round
    rounds: Dict[int, Dict[str, Any]] = {}

    for action in actions:
        round_num = action.round_num

        if round_num < start_round:
            continue
        if end_round is not None and round_num > end_round:
            continue

        r = rounds.get(round_num)
        if r is None:
            r = rounds[round_num] = {
                "round_num": round_num,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "active_agents": set(),
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        # Any platform other than twitter is counted as reddit
        if action.platform == "twitter":
            r["twitter_actions"] += 1
        else:
            r["reddit_actions"] += 1

        r["active_agents"].add(action.agent_id)
        r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1

        # The history is sorted newest first, so the bounds are tracked
        # explicitly instead of relying on iteration order
        r["first_action_time"] = min(r["first_action_time"], action.timestamp)
        r["last_action_time"] = max(r["last_action_time"], action.timestamp)

    # Convert to a list ordered by round
    return [
        {
            "round_num": round_num,
            "twitter_actions": r["twitter_actions"],
            "reddit_actions": r["reddit_actions"],
            "total_actions": r["twitter_actions"] + r["reddit_actions"],
            "active_agents_count": len(r["active_agents"]),
            "active_agents": list(r["active_agents"]),
            "action_types": r["action_types"],
            "first_action_time": r["first_action_time"],
            "last_action_time": r["last_action_time"],
        }
        for round_num, r in sorted(rounds.items())
    ]
|
||
|
||
@classmethod
def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
    """
    Return per-agent statistics for a simulation.

    The complete action history is used (no pagination cap), and
    ``first_action_time`` / ``last_action_time`` are the smallest / largest
    timestamp of the agent regardless of the order the actions arrive in.

    Returns:
        One statistics dict per agent, sorted by total action count in
        descending order
    """
    actions = cls.get_all_actions(simulation_id)

    agent_stats: Dict[int, Dict[str, Any]] = {}

    for action in actions:
        agent_id = action.agent_id

        stats = agent_stats.get(agent_id)
        if stats is None:
            stats = agent_stats[agent_id] = {
                "agent_id": agent_id,
                "agent_name": action.agent_name,
                "total_actions": 0,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        stats["total_actions"] += 1

        # Any platform other than twitter is counted as reddit
        if action.platform == "twitter":
            stats["twitter_actions"] += 1
        else:
            stats["reddit_actions"] += 1

        stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1

        # The history is sorted newest first, so the bounds are tracked
        # explicitly instead of relying on iteration order
        stats["first_action_time"] = min(stats["first_action_time"], action.timestamp)
        stats["last_action_time"] = max(stats["last_action_time"], action.timestamp)

    # Most active agents first
    return sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
|
||
|
||
@classmethod
def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]:
    """
    Remove the run logs of a simulation (used to force a fresh restart).

    The following files inside the simulation directory are deleted:
    - run_state.json
    - simulation.log
    - stdout.log / stderr.log
    - twitter_simulation.db / reddit_simulation.db (simulation databases)
    - env_status.json (environment status)
    - twitter/actions.jsonl
    - reddit/actions.jsonl

    No other file is touched, so simulation_config.json and the profile
    files are kept. The cached in-memory run state of the simulation is
    dropped as well.

    Args:
        simulation_id: Simulation ID.

    Returns:
        ``{"success": True, "message": ...}`` when the simulation directory
        does not exist; otherwise a dict with ``success`` (True when nothing
        failed), ``cleaned_files`` (relative names that were removed) and
        ``errors`` (list of messages, or None when there were none).
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)

    if not os.path.exists(sim_dir):
        return {"success": True, "message": "模拟目录不存在,无需清理"}

    cleaned_files = []
    errors = []

    # Names relative to sim_dir, always written with "/" so the reported
    # names are identical on every platform.
    targets = [
        "run_state.json",
        "simulation.log",
        "stdout.log",
        "stderr.log",
        "twitter_simulation.db",  # Twitter platform database
        "reddit_simulation.db",   # Reddit platform database
        "env_status.json",        # environment status file
        "twitter/actions.jsonl",  # per-platform action logs
        "reddit/actions.jsonl",
    ]

    for rel_name in targets:
        file_path = os.path.join(sim_dir, *rel_name.split("/"))
        # Try the removal directly instead of checking existence first:
        # a file that vanishes in between is simply "nothing to clean".
        try:
            os.remove(file_path)
        except FileNotFoundError:
            continue
        except OSError as e:
            errors.append(f"删除 {rel_name} 失败: {str(e)}")
        else:
            cleaned_files.append(rel_name)

    # Drop the cached in-memory run state.
    cls._run_states.pop(simulation_id, None)

    logger.info(f"清理模拟日志完成: {simulation_id}, 删除文件: {cleaned_files}")

    return {
        "success": len(errors) == 0,
        "cleaned_files": cleaned_files,
        "errors": errors if errors else None
    }
|
||
|
||
# Guards against running the shutdown cleanup more than once.
_cleanup_done = False

@classmethod
def cleanup_all_simulations(cls):
    """
    Terminate every tracked simulation process and release its resources.

    Intended to run when the server shuts down so that no child process is
    left behind. The method runs at most once per process (guarded by
    ``_cleanup_done``) and returns silently when neither a process nor a
    graph memory updater is registered.

    For each process that is still running it:
    1. sends SIGTERM to the process group and escalates to SIGKILL after
       5 seconds; if process-group signalling fails or is not available on
       this platform it falls back to ``terminate()`` / ``kill()`` on the
       process itself;
    2. marks the run state as STOPPED and saves it;
    3. sets ``status`` to ``stopped`` in the simulation's state.json.

    Finally the stdout/stderr log handles are closed and the in-memory
    process and action-queue registries are cleared.
    """
    if cls._cleanup_done:
        return
    cls._cleanup_done = True

    # Nothing registered: return without logging so that processes which
    # never ran a simulation do not print useless messages on exit.
    if not cls._processes and not cls._graph_memory_enabled:
        return

    logger.info("正在清理所有模拟进程...")

    # Stop the graph memory updaters first (stop_all logs on its own).
    try:
        ZepGraphMemoryManager.stop_all()
    except Exception as e:
        logger.error(f"停止图谱记忆更新器失败: {e}")
    cls._graph_memory_enabled.clear()

    # Iterate over a snapshot so the registry may change underneath us.
    for simulation_id, process in list(cls._processes.items()):
        try:
            if process.poll() is not None:
                continue  # already exited, nothing to terminate or update

            logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}")

            try:
                # Signal the whole process group so grandchildren die too.
                pgid = os.getpgid(process.pid)
                os.killpg(pgid, signal.SIGTERM)

                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}")
                    os.killpg(pgid, signal.SIGKILL)
                    process.wait(timeout=5)

            except (OSError, AttributeError):
                # OSError: the process/group is already gone or cannot be
                # signalled. AttributeError: os.getpgid / os.killpg do not
                # exist on this platform (they are Unix-only). Either way,
                # fall back to terminating the process directly.
                try:
                    process.terminate()
                    process.wait(timeout=3)
                except Exception:
                    process.kill()

            # Update run_state.json
            state = cls.get_run_state(simulation_id)
            if state:
                state.runner_status = RunnerStatus.STOPPED
                state.twitter_running = False
                state.reddit_running = False
                state.completed_at = datetime.now().isoformat()
                state.error = "服务器关闭,模拟被终止"
                cls._save_run_state(state)

            # Also set the status in state.json to "stopped". Failures here
            # are only logged; they must not abort the remaining cleanup.
            try:
                sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
                state_file = os.path.join(sim_dir, "state.json")
                logger.info(f"尝试更新 state.json: {state_file}")
                if os.path.exists(state_file):
                    with open(state_file, 'r', encoding='utf-8') as f:
                        state_data = json.load(f)
                    state_data['status'] = 'stopped'
                    state_data['updated_at'] = datetime.now().isoformat()
                    with open(state_file, 'w', encoding='utf-8') as f:
                        json.dump(state_data, f, indent=2, ensure_ascii=False)
                    logger.info(f"已更新 state.json 状态为 stopped: {simulation_id}")
                else:
                    logger.warning(f"state.json 不存在: {state_file}")
            except Exception as state_err:
                logger.warning(f"更新 state.json 失败: {simulation_id}, error={state_err}")

        except Exception as e:
            logger.error(f"清理进程失败: {simulation_id}, error={e}")

    # Close the stdout/stderr log handles. This is best effort during
    # shutdown: a handle that fails to close is deliberately ignored.
    for handle_registry in (cls._stdout_files, cls._stderr_files):
        for file_handle in list(handle_registry.values()):
            try:
                if file_handle:
                    file_handle.close()
            except Exception:
                pass
        handle_registry.clear()

    # Clear the in-memory registries.
    cls._processes.clear()
    cls._action_queues.clear()

    logger.info("模拟进程清理完成")
|
||
|
||
@classmethod
def register_cleanup(cls):
    """
    Install the hooks that clean up simulation processes on shutdown.

    Meant to be called when the Flask application starts. It registers
    ``cleanup_all_simulations`` with ``atexit`` and, when possible, wraps
    the SIGTERM and SIGINT handlers so the cleanup runs before the
    previously installed handler. Repeated calls are no-ops.
    """
    global _cleanup_registered

    if _cleanup_registered:
        return

    # With the Flask debug reloader two processes exist; only the reloader
    # child (WERKZEUG_RUN_MAIN == 'true') actually runs the application.
    # Outside debug mode the variable is absent and registration proceeds.
    run_main_flag = os.environ.get('WERKZEUG_RUN_MAIN')
    in_reloader_child = run_main_flag == 'true'
    debug_active = os.environ.get('FLASK_DEBUG') == '1' or run_main_flag is not None

    if debug_active and not in_reloader_child:
        # Debug-mode parent process: mark as registered and install nothing.
        _cleanup_registered = True
        return

    # Remember the handlers that were active before ours is installed.
    previous_handlers = {
        signal.SIGINT: signal.getsignal(signal.SIGINT),
        signal.SIGTERM: signal.getsignal(signal.SIGTERM),
    }

    def cleanup_handler(signum=None, frame=None):
        """Signal handler: clean up simulations, then chain to the old handler."""
        # Only log when there is actually something to clean up.
        if cls._processes or cls._graph_memory_enabled:
            logger.info(f"收到信号 {signum},开始清理...")
        cls.cleanup_all_simulations()

        # Hand over to the previous handler so Flask can exit normally.
        chained = previous_handlers.get(signum)
        if not callable(chained):
            # The previous handler is not callable (e.g. SIG_DFL) or the
            # signal is not one we wrapped.
            raise KeyboardInterrupt
        chained(signum, frame)

    # atexit acts as a fallback when no signal handler can be installed.
    atexit.register(cls.cleanup_all_simulations)

    try:
        # SIGTERM is the default signal of `kill`; SIGINT is Ctrl+C.
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, cleanup_handler)
    except ValueError:
        # signal.signal() refused because we are not in the main thread.
        logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit")

    _cleanup_registered = True
|
||
|
||
@classmethod
def get_running_simulations(cls) -> List[str]:
    """
    Return the IDs of all simulations whose process is still running.

    A process counts as running while its ``poll()`` returns None.
    """
    return [
        sim_id
        for sim_id, process in cls._processes.items()
        if process.poll() is None
    ]
|
||
|
||
# ============== Interview 功能 ==============
|
||
|
||
@classmethod
def check_env_alive(cls, simulation_id: str) -> bool:
    """
    Tell whether the simulation environment is alive (able to receive
    Interview commands).

    Args:
        simulation_id: Simulation ID.

    Returns:
        False when the simulation directory does not exist; otherwise the
        result of the IPC client's own ``check_env_alive()``.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if os.path.exists(sim_dir):
        return SimulationIPCClient(sim_dir).check_env_alive()
    return False
|
||
|
||
@classmethod
def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]:
    """
    Read the detailed status of a simulation environment.

    The status is loaded from ``env_status.json`` in the simulation
    directory. A missing, unreadable, malformed or non-object file yields
    the default "stopped" status instead of raising.

    Args:
        simulation_id: Simulation ID.

    Returns:
        Dict with the keys status, twitter_available, reddit_available and
        timestamp. Keys missing from the file take their default value.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    status_file = os.path.join(sim_dir, "env_status.json")

    default_status = {
        "status": "stopped",
        "twitter_available": False,
        "reddit_available": False,
        "timestamp": None
    }

    # Open directly rather than checking existence first, so a file that
    # disappears or is half-written in between is handled the same way.
    # ValueError covers both JSONDecodeError and UnicodeDecodeError.
    try:
        with open(status_file, 'r', encoding='utf-8') as f:
            status = json.load(f)
    except (OSError, ValueError):
        return default_status

    # Valid JSON that is not an object (list, string, ...) has no fields.
    if not isinstance(status, dict):
        return default_status

    return {key: status.get(key, fallback) for key, fallback in default_status.items()}
|
||
|
||
@classmethod
def interview_agent(
    cls,
    simulation_id: str,
    agent_id: int,
    prompt: str,
    platform: str = None,
    timeout: float = 60.0
) -> Dict[str, Any]:
    """
    Interview a single agent.

    Args:
        simulation_id: Simulation ID.
        agent_id: Agent ID.
        prompt: Interview question.
        platform: Target platform (optional).
            - "twitter": interview on the Twitter platform only
            - "reddit": interview on the Reddit platform only
            - None: in a dual-platform simulation both platforms are
              interviewed and a combined result is returned
        timeout: Timeout in seconds.

    Returns:
        Dict with success, agent_id, prompt, timestamp and either
        ``result`` (on success) or ``error`` (on failure).

    Raises:
        ValueError: The simulation does not exist or its environment is
            not running.
        TimeoutError: Waiting for the response timed out.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    ipc_client = SimulationIPCClient(sim_dir)
    env_alive = ipc_client.check_env_alive()
    if not env_alive:
        raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")

    logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}")

    response = ipc_client.send_interview(
        agent_id=agent_id,
        prompt=prompt,
        platform=platform,
        timeout=timeout
    )

    # Both outcomes share the same envelope; only the payload key differs.
    completed = response.status.value == "completed"
    outcome = {
        "success": completed,
        "agent_id": agent_id,
        "prompt": prompt,
    }
    if completed:
        outcome["result"] = response.result
    else:
        outcome["error"] = response.error
    outcome["timestamp"] = response.timestamp
    return outcome
|
||
|
||
@classmethod
def interview_agents_batch(
    cls,
    simulation_id: str,
    interviews: List[Dict[str, Any]],
    platform: str = None,
    timeout: float = 120.0
) -> Dict[str, Any]:
    """
    Interview several agents in one batch.

    Args:
        simulation_id: Simulation ID.
        interviews: List of interview items, each of the form
            {"agent_id": int, "prompt": str, "platform": str (optional)}.
        platform: Default platform (optional; overridden by the platform
            of an individual interview item).
            - "twitter": interview on the Twitter platform only by default
            - "reddit": interview on the Reddit platform only by default
            - None: in a dual-platform simulation every agent is
              interviewed on both platforms
        timeout: Timeout in seconds.

    Returns:
        Dict with success, interviews_count, timestamp and either
        ``result`` (on success) or ``error`` (on failure).

    Raises:
        ValueError: The simulation does not exist or its environment is
            not running.
        TimeoutError: Waiting for the response timed out.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    ipc_client = SimulationIPCClient(sim_dir)
    env_alive = ipc_client.check_env_alive()
    if not env_alive:
        raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")

    logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}")

    response = ipc_client.send_batch_interview(
        interviews=interviews,
        platform=platform,
        timeout=timeout
    )

    # Both outcomes share the same envelope; only the payload key differs.
    completed = response.status.value == "completed"
    outcome = {
        "success": completed,
        "interviews_count": len(interviews),
    }
    if completed:
        outcome["result"] = response.result
    else:
        outcome["error"] = response.error
    outcome["timestamp"] = response.timestamp
    return outcome
|
||
|
||
@classmethod
def interview_all_agents(
    cls,
    simulation_id: str,
    prompt: str,
    platform: str = None,
    timeout: float = 180.0
) -> Dict[str, Any]:
    """
    Interview every agent of a simulation (global interview).

    The same question is sent to all agents listed under
    ``agent_configs`` in the simulation's simulation_config.json; entries
    without an ``agent_id`` are skipped. The request is delegated to
    ``interview_agents_batch``.

    Args:
        simulation_id: Simulation ID.
        prompt: Interview question (identical for all agents).
        platform: Target platform (optional).
            - "twitter": interview on the Twitter platform only
            - "reddit": interview on the Reddit platform only
            - None: in a dual-platform simulation every agent is
              interviewed on both platforms
        timeout: Timeout in seconds.

    Returns:
        The result dict of the batch interview.

    Raises:
        ValueError: The simulation directory or its configuration file
            does not exist, or the configuration lists no agents.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    # The agent list comes from the simulation's configuration file.
    config_path = os.path.join(sim_dir, "simulation_config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"模拟配置不存在: {simulation_id}")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    agent_configs = config.get("agent_configs", [])
    if not agent_configs:
        raise ValueError(f"模拟配置中没有Agent: {simulation_id}")

    # One interview item per configured agent that carries an ID.
    configured_ids = (entry.get("agent_id") for entry in agent_configs)
    interviews = [
        {"agent_id": current_id, "prompt": prompt}
        for current_id in configured_ids
        if current_id is not None
    ]

    logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}")

    return cls.interview_agents_batch(
        simulation_id=simulation_id,
        interviews=interviews,
        platform=platform,
        timeout=timeout
    )
|
||
|
||
@classmethod
def close_simulation_env(
    cls,
    simulation_id: str,
    timeout: float = 30.0
) -> Dict[str, Any]:
    """
    Close the simulation environment (as opposed to stopping the
    simulation process).

    A close-environment command is sent to the simulation so that it can
    leave its command-waiting mode gracefully.

    Args:
        simulation_id: Simulation ID.
        timeout: Timeout in seconds.

    Returns:
        Result dict. ``success`` is True when the environment was already
        closed, when the command completed, or when waiting for the
        response timed out (the environment may be shutting down).

    Raises:
        ValueError: The simulation directory does not exist.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    ipc_client = SimulationIPCClient(sim_dir)

    env_alive = ipc_client.check_env_alive()
    if not env_alive:
        return {"success": True, "message": "环境已经关闭"}

    logger.info(f"发送关闭环境命令: simulation_id={simulation_id}")

    try:
        response = ipc_client.send_close_env(timeout=timeout)
    except TimeoutError:
        # A timeout may simply mean the environment is already shutting down.
        return {
            "success": True,
            "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)"
        }

    completed = response.status.value == "completed"
    return {
        "success": completed,
        "message": "环境关闭命令已发送",
        "result": response.result,
        "timestamp": response.timestamp
    }
|
||
|
||
@classmethod
def _get_interview_history_from_db(
    cls,
    db_path: str,
    platform_name: str,
    agent_id: Optional[int] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """
    Read the Interview history from a single platform database.

    Rows are taken from the ``trace`` table where ``action`` is
    'interview', newest first (ordered by ``created_at``).

    Args:
        db_path: Path of the SQLite database file.
        platform_name: Platform label stored in every returned record.
        agent_id: When given, only rows of this ``user_id`` are returned.
        limit: Maximum number of rows to return.

    Returns:
        List of dicts with agent_id, response, prompt, timestamp and
        platform. Empty when the database file does not exist or cannot
        be read (the error is logged).
    """
    import sqlite3
    from contextlib import closing

    if not os.path.exists(db_path):
        return []

    # Build one parameterized query instead of two near-identical copies.
    query = "SELECT user_id, info, created_at FROM trace WHERE action = 'interview'"
    params: List[Any] = []
    if agent_id is not None:
        query += " AND user_id = ?"
        params.append(agent_id)
    query += " ORDER BY created_at DESC LIMIT ?"
    params.append(limit)

    results: List[Dict[str, Any]] = []

    try:
        # closing() guarantees the connection is released even when the
        # query fails; sqlite3's own context manager does not close it.
        with closing(sqlite3.connect(db_path)) as conn:
            rows = conn.execute(query, params).fetchall()
    except sqlite3.Error as e:
        logger.error(f"读取Interview历史失败 ({platform_name}): {e}")
        return results

    for user_id, info_json, created_at in rows:
        try:
            info = json.loads(info_json) if info_json else {}
        except (json.JSONDecodeError, TypeError):
            info = {"raw": info_json}

        if isinstance(info, dict):
            # Without a "response" field the whole payload is the response.
            response = info.get("response", info)
            prompt = info.get("prompt", "")
        else:
            # Valid JSON that is not an object (string, list, ...): keep it
            # as the response instead of failing the entire read.
            response = info
            prompt = ""

        results.append({
            "agent_id": user_id,
            "response": response,
            "prompt": prompt,
            "timestamp": created_at,
            "platform": platform_name
        })

    return results
|
||
|
||
@classmethod
def get_interview_history(
    cls,
    simulation_id: str,
    platform: str = None,
    agent_id: Optional[int] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """
    Get the Interview history of a simulation (read from the databases).

    Args:
        simulation_id: Simulation ID.
        platform: Platform type (reddit/twitter/None).
            - "reddit": only the history of the Reddit platform
            - "twitter": only the history of the Twitter platform
            - None (or any other value): history of both platforms
        agent_id: Agent ID (optional; restricts the history to that agent).
        limit: Maximum number of records read per platform. When both
            platforms are queried the merged list is also cut to this size.

    Returns:
        Interview history records, newest first. Records without a
        timestamp sort last.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)

    # Decide which platform databases to query.
    if platform in ("reddit", "twitter"):
        platforms = [platform]
    else:
        platforms = ["twitter", "reddit"]

    results: List[Dict[str, Any]] = []
    for p in platforms:
        results.extend(
            cls._get_interview_history_from_db(
                db_path=os.path.join(sim_dir, f"{p}_simulation.db"),
                platform_name=p,
                agent_id=agent_id,
                limit=limit
            )
        )

    def _sort_key(record: Dict[str, Any]) -> Any:
        # The "timestamp" key is normally present but may hold None;
        # map that to "" so it can be compared with timestamp strings.
        stamp = record.get("timestamp")
        return "" if stamp is None else stamp

    # Newest first.
    results.sort(key=_sort_key, reverse=True)

    # Each platform already honoured the limit; cap the merged total too.
    if len(platforms) > 1 and len(results) > limit:
        results = results[:limit]

    return results
|
||
|