"""
|
||
OASIS模拟管理器
|
||
管理Twitter和Reddit双平台并行模拟
|
||
使用预设脚本 + LLM智能生成配置参数
|
||
"""
|
||
|
||
import os
import json
import shutil
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum

from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import ZepEntityReader, FilteredEntities, EntityNode
from .llm_graph_builder import LLMGraphBuilderService
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters

logger = get_logger('mirofish.simulation')


class SimulationStatus(str, Enum):
    """Simulation status."""
    CREATED = "created"
    PREPARING = "preparing"
    READY = "ready"
    RUNNING = "running"
    PAUSED = "paused"
    STOPPED = "stopped"      # simulation was stopped manually
    COMPLETED = "completed"  # simulation finished naturally
    FAILED = "failed"
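
    # Transitions driven by this module: CREATED -> PREPARING -> READY,
    # or FAILED on error. The remaining statuses (RUNNING, PAUSED, STOPPED,
    # COMPLETED) are not set here; they belong to the run phase (simulation_runner).

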
class PlatformType(str, Enum):
    """Platform type."""
    TWITTER = "twitter"
    REDDIT = "reddit"


@dataclass
class SimulationState:
    """Simulation state."""
    simulation_id: str
    project_id: str
    graph_id: str

    # Platform enablement
    enable_twitter: bool = True
    enable_reddit: bool = True

    # Status
    status: SimulationStatus = SimulationStatus.CREATED

    # Preparation-stage data
    entities_count: int = 0
    profiles_count: int = 0
    entity_types: List[str] = field(default_factory=list)

    # Config generation info
    config_generated: bool = False
    config_reasoning: str = ""

    # Runtime data
    current_round: int = 0
    twitter_status: str = "not_started"
    reddit_status: str = "not_started"

    # Timestamps
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    # Error info
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Full state dict (for internal use)."""
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "enable_twitter": self.enable_twitter,
            "enable_reddit": self.enable_reddit,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": self.entity_types,
            "config_generated": self.config_generated,
            "config_reasoning": self.config_reasoning,
            "current_round": self.current_round,
            "twitter_status": self.twitter_status,
            "reddit_status": self.reddit_status,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "error": self.error,
        }

    def to_simple_dict(self) -> Dict[str, Any]:
        """Simplified state dict (returned by the API)."""
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": self.entity_types,
            "config_generated": self.config_generated,
            "error": self.error,
        }


class SimulationManager:
    """
    Simulation manager

    Core responsibilities:
    1. Read entities from the Zep graph and filter them
    2. Generate OASIS agent profiles
    3. Use an LLM to generate simulation configuration parameters
    4. Prepare all files required by the preset scripts
    """

    # Simulation data storage directory
    SIMULATION_DATA_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )

    def __init__(self):
        # Ensure the storage directory exists
        os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)

        # In-memory cache of simulation states
        self._simulations: Dict[str, SimulationState] = {}

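    # On-disk layout produced by the methods below (all paths relative to
    # SIMULATION_DATA_DIR/<simulation_id>/):
    #   state.json              -- persisted SimulationState (see to_dict)
    #   reddit_profiles.json    -- agent profiles in Reddit/JSON format
    #   twitter_profiles.csv    -- agent profiles in Twitter/CSV format (OASIS requires CSV)
    #   simulation_config.json  -- LLM-generated simulation parameters
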
    def _get_simulation_dir(self, simulation_id: str) -> str:
        """Get the data directory for a simulation, creating it if needed."""
        sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
        os.makedirs(sim_dir, exist_ok=True)
        return sim_dir

    def _save_simulation_state(self, state: SimulationState):
        """Persist simulation state to disk."""
        sim_dir = self._get_simulation_dir(state.simulation_id)
        state_file = os.path.join(sim_dir, "state.json")

        state.updated_at = datetime.now().isoformat()

        with open(state_file, 'w', encoding='utf-8') as f:
            json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)

        self._simulations[state.simulation_id] = state

    def _filter_entities_from_data(
        self,
        graph_data: Dict[str, Any],
        defined_entity_types: Optional[List[str]] = None
    ) -> FilteredEntities:
        """Filter entities from disk-stored graph data (no Zep needed)."""
        nodes = graph_data.get("nodes", [])
        edges = graph_data.get("edges", [])
        total_count = len(nodes)

        # Build node UUID map for edge enrichment
        node_map = {n["uuid"]: n for n in nodes}

        filtered_entities = []
        entity_types_found = set()

        for node in nodes:
            labels = node.get("labels", [])
            meaningful_labels = [l for l in labels if l not in ("Entity", "Node")]
            if not meaningful_labels:
                continue

            entity_type = meaningful_labels[0]
            if defined_entity_types and entity_type not in defined_entity_types:
                continue

            entity_types_found.add(entity_type)

            # Find related edges
            related_edges = []
            related_nodes = []
            node_uuid = node.get("uuid", "")
            for edge in edges:
                if edge.get("source_node_uuid") == node_uuid or edge.get("target_node_uuid") == node_uuid:
                    related_edges.append({
                        "uuid": edge.get("uuid", ""),
                        "name": edge.get("name", ""),
                        "fact": edge.get("fact", ""),
                        "source_node_uuid": edge.get("source_node_uuid", ""),
                        "target_node_uuid": edge.get("target_node_uuid", ""),
                        "source_node_name": edge.get("source_node_name", ""),
                        "target_node_name": edge.get("target_node_name", ""),
                    })
                    # Add related node
                    other_uuid = (edge.get("target_node_uuid") if edge.get("source_node_uuid") == node_uuid
                                  else edge.get("source_node_uuid"))
                    other_node = node_map.get(other_uuid)
                    if other_node:
                        related_nodes.append({
                            "uuid": other_node.get("uuid", ""),
                            "name": other_node.get("name", ""),
                            "labels": other_node.get("labels", []),
                        })

            filtered_entities.append(EntityNode(
                uuid=node_uuid,
                name=node.get("name", ""),
                labels=labels,
                summary=node.get("summary", ""),
                attributes=node.get("attributes", {}),
                related_edges=related_edges,
                related_nodes=related_nodes,
            ))

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

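    # Sketch of the `graph_data` structure consumed above, inferred from the keys
    # this method reads (the authoritative schema lives in LLMGraphBuilderService):
    #   {
    #     "nodes": [{"uuid", "name", "labels", "summary", "attributes"}, ...],
    #     "edges": [{"uuid", "name", "fact",
    #                "source_node_uuid", "target_node_uuid",
    #                "source_node_name", "target_node_name"}, ...]
    #   }
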
    def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
        """Load simulation state from disk (with an in-memory cache)."""
        if simulation_id in self._simulations:
            return self._simulations[simulation_id]

        sim_dir = self._get_simulation_dir(simulation_id)
        state_file = os.path.join(sim_dir, "state.json")

        if not os.path.exists(state_file):
            return None

        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        state = SimulationState(
            simulation_id=simulation_id,
            project_id=data.get("project_id", ""),
            graph_id=data.get("graph_id", ""),
            enable_twitter=data.get("enable_twitter", True),
            enable_reddit=data.get("enable_reddit", True),
            status=SimulationStatus(data.get("status", "created")),
            entities_count=data.get("entities_count", 0),
            profiles_count=data.get("profiles_count", 0),
            entity_types=data.get("entity_types", []),
            config_generated=data.get("config_generated", False),
            config_reasoning=data.get("config_reasoning", ""),
            current_round=data.get("current_round", 0),
            twitter_status=data.get("twitter_status", "not_started"),
            reddit_status=data.get("reddit_status", "not_started"),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            error=data.get("error"),
        )

        self._simulations[simulation_id] = state
        return state

    def create_simulation(
        self,
        project_id: str,
        graph_id: str,
        enable_twitter: bool = True,
        enable_reddit: bool = True,
    ) -> SimulationState:
        """
        Create a new simulation.

        Args:
            project_id: Project ID
            graph_id: Zep graph ID
            enable_twitter: Whether to enable the Twitter simulation
            enable_reddit: Whether to enable the Reddit simulation

        Returns:
            SimulationState
        """
        import uuid
        simulation_id = f"sim_{uuid.uuid4().hex[:12]}"

        state = SimulationState(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit,
            status=SimulationStatus.CREATED,
        )

        self._save_simulation_state(state)
        logger.info(f"Created simulation: {simulation_id}, project={project_id}, graph={graph_id}")

        return state

    def prepare_simulation(
        self,
        simulation_id: str,
        simulation_requirement: str,
        document_text: str,
        defined_entity_types: Optional[List[str]] = None,
        use_llm_for_profiles: bool = True,
        progress_callback: Optional[callable] = None,
        parallel_profile_count: int = 3
    ) -> SimulationState:
        """
        Prepare the simulation environment (fully automated).

        Steps:
        1. Read and filter entities from the Zep graph
        2. Generate an OASIS agent profile for each entity (optional LLM enrichment, parallelizable)
        3. Use an LLM to generate simulation configuration parameters (timing, activity levels, posting frequency, etc.)
        4. Save the config file and the profile files
        (The run scripts stay in backend/scripts/ and are no longer copied into the simulation directory.)

        Args:
            simulation_id: Simulation ID
            simulation_requirement: Description of the simulation requirements (used by the LLM to generate the config)
            document_text: Original document content (gives the LLM background context)
            defined_entity_types: Predefined entity types (optional)
            use_llm_for_profiles: Whether to use an LLM to generate detailed personas
            progress_callback: Progress callback (stage, progress, message, **kwargs);
                some stages also pass current/total/item_name as keyword arguments
            parallel_profile_count: Number of profiles to generate in parallel (default 3)

        Returns:
            SimulationState
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"Simulation not found: {simulation_id}")

        try:
            state.status = SimulationStatus.PREPARING
            self._save_simulation_state(state)

            sim_dir = self._get_simulation_dir(simulation_id)

            # ========== Stage 1: read and filter entities ==========
            if progress_callback:
                progress_callback("reading", 0, "Reading graph data...")

            # Try loading graph data from disk first (LLM-built graphs)
            from ..models.project import ProjectManager
            disk_graph_data = None
            all_projects = ProjectManager.list_projects()
            for proj in all_projects:
                if proj.graph_id == state.graph_id:
                    project_dir = ProjectManager._get_project_dir(proj.project_id)
                    disk_graph_data = LLMGraphBuilderService.load_graph_data(project_dir)
                    break

            if progress_callback:
                progress_callback("reading", 30, "Filtering entities...")

            if disk_graph_data:
                # Build FilteredEntities from disk data
                filtered = self._filter_entities_from_data(
                    disk_graph_data, defined_entity_types
                )
            else:
                # Fall back to Zep
                reader = ZepEntityReader()
                filtered = reader.filter_defined_entities(
                    graph_id=state.graph_id,
                    defined_entity_types=defined_entity_types,
                    enrich_with_edges=True
                )

            state.entities_count = filtered.filtered_count
            state.entity_types = list(filtered.entity_types)

            if progress_callback:
                progress_callback(
                    "reading", 100,
                    f"Done, {filtered.filtered_count} entities found",
                    current=filtered.filtered_count,
                    total=filtered.filtered_count
                )

            if filtered.filtered_count == 0:
                state.status = SimulationStatus.FAILED
                state.error = "No matching entities found. Please check that the graph was built correctly"
                self._save_simulation_state(state)
                return state

            # ========== Stage 2: generate agent profiles ==========
            total_entities = len(filtered.entities)

            if progress_callback:
                progress_callback(
                    "generating_profiles", 0,
                    "Starting generation...",
                    current=0,
                    total=total_entities
                )

            # Pass graph_id to enable Zep retrieval for richer context
            generator = OasisProfileGenerator(graph_id=state.graph_id)

            def profile_progress(current, total, msg):
                if progress_callback:
                    progress_callback(
                        "generating_profiles",
                        int(current / total * 100),
                        msg,
                        current=current,
                        total=total,
                        item_name=msg
                    )

            # Path for incremental (real-time) saving; prefer the Reddit JSON format
            realtime_output_path = None
            realtime_platform = "reddit"
            if state.enable_reddit:
                realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json")
                realtime_platform = "reddit"
            elif state.enable_twitter:
                realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
                realtime_platform = "twitter"

            profiles = generator.generate_profiles_from_entities(
                entities=filtered.entities,
                use_llm=use_llm_for_profiles,
                progress_callback=profile_progress,
                graph_id=state.graph_id,  # graph_id enables Zep retrieval
                parallel_count=parallel_profile_count,  # number of profiles generated in parallel
                realtime_output_path=realtime_output_path,  # incremental save path
                output_platform=realtime_platform  # output format
            )

            state.profiles_count = len(profiles)

            # Save profile files (note: Twitter uses CSV, Reddit uses JSON)
            # Reddit profiles were already saved incrementally during generation;
            # saving again here guarantees the file is complete
            if progress_callback:
                progress_callback(
                    "generating_profiles", 95,
                    "Saving profile files...",
                    current=total_entities,
                    total=total_entities
                )

            if state.enable_reddit:
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "reddit_profiles.json"),
                    platform="reddit"
                )

            if state.enable_twitter:
                # Twitter profiles must be CSV; this is required by OASIS
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
                    platform="twitter"
                )

            if progress_callback:
                progress_callback(
                    "generating_profiles", 100,
                    f"Done, {len(profiles)} profiles generated",
                    current=len(profiles),
                    total=len(profiles)
                )

            # ========== Stage 3: generate the simulation config with the LLM ==========
            if progress_callback:
                progress_callback(
                    "generating_config", 0,
                    "Analyzing simulation requirements...",
                    current=0,
                    total=3
                )

            config_generator = SimulationConfigGenerator()

            if progress_callback:
                progress_callback(
                    "generating_config", 30,
                    "Calling LLM to generate config...",
                    current=1,
                    total=3
                )

            sim_params = config_generator.generate_config(
                simulation_id=simulation_id,
                project_id=state.project_id,
                graph_id=state.graph_id,
                simulation_requirement=simulation_requirement,
                document_text=document_text,
                entities=filtered.entities,
                enable_twitter=state.enable_twitter,
                enable_reddit=state.enable_reddit
            )

            if progress_callback:
                progress_callback(
                    "generating_config", 70,
                    "Saving config files...",
                    current=2,
                    total=3
                )

            # Save the config file
            config_path = os.path.join(sim_dir, "simulation_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(sim_params.to_json())

            state.config_generated = True
            state.config_reasoning = sim_params.generation_reasoning

            if progress_callback:
                progress_callback(
                    "generating_config", 100,
                    "Config generation complete",
                    current=3,
                    total=3
                )

            # Note: the run scripts stay in backend/scripts/ and are not copied into the simulation directory.
            # When the simulation is started, simulation_runner runs them from the scripts/ directory.

            # Update state
            state.status = SimulationStatus.READY
            self._save_simulation_state(state)

            logger.info(f"Simulation preparation complete: {simulation_id}, "
                        f"entities={state.entities_count}, profiles={state.profiles_count}")

            return state

        except Exception as e:
            logger.error(f"Simulation preparation failed: {simulation_id}, error={str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            state.status = SimulationStatus.FAILED
            state.error = str(e)
            self._save_simulation_state(state)
            raise

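    # Minimal sketch of a compatible progress callback for prepare_simulation
    # (illustrative only; the name `on_progress` is not part of this module).
    # It must accept extra keyword arguments, because some stages pass
    # current/total/item_name in addition to (stage, progress, message):
    #
    #   def on_progress(stage: str, progress: int, message: str, **kwargs):
    #       logger.info(f"[{stage}] {progress}% {message} "
    #                   f"({kwargs.get('current')}/{kwargs.get('total')})")
    #
    #   manager.prepare_simulation(sim_id, requirement, document_text,
    #                              progress_callback=on_progress)
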
    def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
        """Get simulation state."""
        return self._load_simulation_state(simulation_id)

    def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
        """List all simulations, optionally filtered by project."""
        simulations = []

        if os.path.exists(self.SIMULATION_DATA_DIR):
            for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
                # Skip hidden files (e.g. .DS_Store) and anything that is not a directory
                sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
                if sim_id.startswith('.') or not os.path.isdir(sim_path):
                    continue

                state = self._load_simulation_state(sim_id)
                if state:
                    if project_id is None or state.project_id == project_id:
                        simulations.append(state)

        return simulations

    def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
        """Get the agent profiles of a simulation."""
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"Simulation not found: {simulation_id}")

        sim_dir = self._get_simulation_dir(simulation_id)
        # Note: only JSON profile files are looked up here; Twitter profiles are
        # saved as CSV (twitter_profiles.csv), so platform="twitter" returns [].
        profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")

        if not os.path.exists(profile_path):
            return []

        with open(profile_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
        """Get the generated simulation config."""
        sim_dir = self._get_simulation_dir(simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")

        if not os.path.exists(config_path):
            return None

        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_run_instructions(self, simulation_id: str) -> Dict[str, Any]:
        """Get instructions for running the simulation."""
        sim_dir = self._get_simulation_dir(simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")
        scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))

        return {
            "simulation_dir": sim_dir,
            "scripts_dir": scripts_dir,
            "config_file": config_path,
            "commands": {
                "twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
                "reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
                "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
            },
            "instructions": (
                f"1. Activate the conda environment: conda activate MiroFish\n"
                f"2. Run the simulation (scripts are in {scripts_dir}):\n"
                f"   - Twitter only: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
                f"   - Reddit only: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
                f"   - Both platforms in parallel: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
            )
        }
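

# Minimal usage sketch (illustrative only; the literal IDs and texts below are
# placeholders, and Zep/LLM credentials are assumed to be configured elsewhere
# in the project):
#
#   manager = SimulationManager()
#   state = manager.create_simulation(project_id="proj_123", graph_id="graph_abc")
#   state = manager.prepare_simulation(
#       simulation_id=state.simulation_id,
#       simulation_requirement="Simulate public reaction to the announcement",
#       document_text=open("source_document.txt", encoding="utf-8").read(),
#   )
#   if state.status == SimulationStatus.READY:
#       print(manager.get_run_instructions(state.simulation_id)["instructions"])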