Enhance backend functionality with OASIS simulation features
- Updated README.md to include the new simulation scripts and configuration details for OASIS, including the API retry mechanism and environment variable settings.
- Added simulation management and configuration generation services to streamline the simulation process across the Twitter and Reddit platforms.
- Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management.
- Implemented a robust retry mechanism for external API calls to improve system stability.
- Enhanced the task management model to include detailed progress tracking.
- Added logging capabilities for action tracking during simulations.
- Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
parent
c60e6e1089
commit
5f159f6d88
19 changed files with 7202 additions and 49 deletions
1444
backend/README.md
1444
backend/README.md
File diff suppressed because it is too large
Load diff
|
|
@ -46,8 +46,9 @@ def create_app(config_class=Config):
|
|||
return response
|
||||
|
||||
# 注册蓝图
|
||||
from .api import graph_bp
|
||||
from .api import graph_bp, simulation_bp
|
||||
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
||||
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
||||
|
||||
# 健康检查
|
||||
@app.route('/health')
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ API路由模块
|
|||
from flask import Blueprint
|
||||
|
||||
graph_bp = Blueprint('graph', __name__)
|
||||
simulation_bp = Blueprint('simulation', __name__)
|
||||
|
||||
from . import graph # noqa: E402, F401
|
||||
from . import simulation # noqa: E402, F401
|
||||
|
||||
|
|
|
|||
1330
backend/app/api/simulation.py
Normal file
1330
backend/app/api/simulation.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -41,6 +41,20 @@ class Config:
|
|||
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
|
||||
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
|
||||
|
||||
# OASIS模拟配置
|
||||
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
||||
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
||||
|
||||
# OASIS平台可用动作配置
|
||||
OASIS_TWITTER_ACTIONS = [
|
||||
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
||||
]
|
||||
OASIS_REDDIT_ACTIONS = [
|
||||
'LIKE_POST', 'DISLIKE_POST', 'CREATE_POST', 'CREATE_COMMENT',
|
||||
'LIKE_COMMENT', 'DISLIKE_COMMENT', 'SEARCH_POSTS', 'SEARCH_USER',
|
||||
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
"""验证必要配置"""
|
||||
|
|
|
|||
|
|
@ -27,11 +27,12 @@ class Task:
|
|||
status: TaskStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
progress: int = 0 # 进度百分比 0-100
|
||||
progress: int = 0 # 总进度百分比 0-100
|
||||
message: str = "" # 状态消息
|
||||
result: Optional[Dict] = None # 任务结果
|
||||
error: Optional[str] = None # 错误信息
|
||||
metadata: Dict = field(default_factory=dict) # 额外元数据
|
||||
progress_detail: Dict = field(default_factory=dict) # 详细进度信息
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
|
|
@ -43,6 +44,7 @@ class Task:
|
|||
"updated_at": self.updated_at.isoformat(),
|
||||
"progress": self.progress,
|
||||
"message": self.message,
|
||||
"progress_detail": self.progress_detail,
|
||||
"result": self.result,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
|
|
@ -108,7 +110,8 @@ class TaskManager:
|
|||
progress: Optional[int] = None,
|
||||
message: Optional[str] = None,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None
|
||||
error: Optional[str] = None,
|
||||
progress_detail: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
更新任务状态
|
||||
|
|
@ -120,6 +123,7 @@ class TaskManager:
|
|||
message: 消息
|
||||
result: 结果
|
||||
error: 错误信息
|
||||
progress_detail: 详细进度信息
|
||||
"""
|
||||
with self._task_lock:
|
||||
task = self._tasks.get(task_id)
|
||||
|
|
@ -135,6 +139,8 @@ class TaskManager:
|
|||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if progress_detail is not None:
|
||||
task.progress_detail = progress_detail
|
||||
|
||||
def complete_task(self, task_id: str, result: Dict):
|
||||
"""标记任务完成"""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,47 @@
|
|||
from .ontology_generator import OntologyGenerator
|
||||
from .graph_builder import GraphBuilderService
|
||||
from .text_processor import TextProcessor
|
||||
from .zep_entity_reader import ZepEntityReader, EntityNode, FilteredEntities
|
||||
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||
from .simulation_manager import SimulationManager, SimulationState, SimulationStatus
|
||||
from .simulation_config_generator import (
|
||||
SimulationConfigGenerator,
|
||||
SimulationParameters,
|
||||
AgentActivityConfig,
|
||||
TimeSimulationConfig,
|
||||
EventConfig,
|
||||
PlatformConfig
|
||||
)
|
||||
from .simulation_runner import (
|
||||
SimulationRunner,
|
||||
SimulationRunState,
|
||||
RunnerStatus,
|
||||
AgentAction,
|
||||
RoundSummary
|
||||
)
|
||||
|
||||
__all__ = ['OntologyGenerator', 'GraphBuilderService', 'TextProcessor']
|
||||
__all__ = [
|
||||
'OntologyGenerator',
|
||||
'GraphBuilderService',
|
||||
'TextProcessor',
|
||||
'ZepEntityReader',
|
||||
'EntityNode',
|
||||
'FilteredEntities',
|
||||
'OasisProfileGenerator',
|
||||
'OasisAgentProfile',
|
||||
'SimulationManager',
|
||||
'SimulationState',
|
||||
'SimulationStatus',
|
||||
'SimulationConfigGenerator',
|
||||
'SimulationParameters',
|
||||
'AgentActivityConfig',
|
||||
'TimeSimulationConfig',
|
||||
'EventConfig',
|
||||
'PlatformConfig',
|
||||
'SimulationRunner',
|
||||
'SimulationRunState',
|
||||
'RunnerStatus',
|
||||
'AgentAction',
|
||||
'RoundSummary',
|
||||
]
|
||||
|
||||
|
|
|
|||
561
backend/app/services/oasis_profile_generator.py
Normal file
561
backend/app/services/oasis_profile_generator.py
Normal file
|
|
@ -0,0 +1,561 @@
|
|||
"""
|
||||
OASIS Agent Profile生成器
|
||||
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
logger = get_logger('mirofish.oasis_profile')
|
||||
|
||||
|
||||
@dataclass
class OasisAgentProfile:
    """OASIS agent profile.

    Carries the fields required by the OASIS simulation platform for both
    Reddit-style and Twitter-style agents, optional persona details, and
    provenance information linking the profile back to its source Zep entity.
    """
    # Common fields (required on every platform)
    user_id: int
    user_name: str
    name: str
    bio: str
    persona: str

    # Optional fields - Reddit style
    karma: int = 1000

    # Optional fields - Twitter style
    friend_count: int = 100
    follower_count: int = 150
    statuses_count: int = 500

    # Extra persona details (None / empty list means "not specified")
    age: Optional[int] = None
    gender: Optional[str] = None
    mbti: Optional[str] = None
    country: Optional[str] = None
    profession: Optional[str] = None
    interested_topics: List[str] = field(default_factory=list)

    # Provenance: the Zep entity this profile was generated from
    source_entity_uuid: Optional[str] = None
    source_entity_type: Optional[str] = None

    created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))

    # Optional persona attributes shared by every export format below
    # (unannotated on purpose: not a dataclass field).
    _PERSONA_KEYS = ("age", "gender", "mbti", "country", "profession", "interested_topics")

    def _add_persona_details(self, profile: Dict[str, Any]) -> Dict[str, Any]:
        """Copy each truthy optional persona field into *profile* (in place).

        Truthiness (rather than ``is not None``) is used deliberately so that
        empty strings/lists are treated the same as "not specified", matching
        the original export behavior of both platform formats.
        """
        for key in self._PERSONA_KEYS:
            value = getattr(self, key)
            if value:
                profile[key] = value
        return profile

    def to_reddit_format(self) -> Dict[str, Any]:
        """Return the profile in OASIS Reddit format (karma-based)."""
        profile = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "created_at": self.created_at,
        }
        # Optional persona details are only included when set.
        return self._add_persona_details(profile)

    def to_twitter_format(self) -> Dict[str, Any]:
        """Return the profile in OASIS Twitter format (follower-count based)."""
        profile = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "created_at": self.created_at,
        }
        # Optional persona details are only included when set.
        return self._add_persona_details(profile)

    def to_dict(self) -> Dict[str, Any]:
        """Return the complete profile as a dictionary (all fields, even unset)."""
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "age": self.age,
            "gender": self.gender,
            "mbti": self.mbti,
            "country": self.country,
            "profession": self.profession,
            "interested_topics": self.interested_topics,
            "source_entity_uuid": self.source_entity_uuid,
            "source_entity_type": self.source_entity_type,
            "created_at": self.created_at,
        }
|
||||
|
||||
|
||||
class OasisProfileGenerator:
    """
    OASIS profile generator.

    Converts entities from the Zep graph into the agent profiles required by
    the OASIS simulation platform.
    """

    # The 16 MBTI personality types, used when picking a random persona.
    MBTI_TYPES = [
        "INTJ", "INTP", "ENTJ", "ENTP",
        "INFJ", "INFP", "ENFJ", "ENFP",
        "ISTJ", "ISFJ", "ESTJ", "ESFJ",
        "ISTP", "ISFP", "ESTP", "ESFP"
    ]

    # Common countries used when a nationality must be chosen at random.
    COUNTRIES = [
        "China", "US", "UK", "Japan", "Germany", "France",
        "Canada", "Australia", "Brazil", "India", "South Korea"
    ]

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: Optional[str] = None
    ):
        """Create a generator backed by an OpenAI-compatible LLM endpoint.

        Args:
            api_key: LLM API key; falls back to Config.LLM_API_KEY.
            base_url: API base URL; falls back to Config.LLM_BASE_URL.
            model_name: Model identifier; falls back to Config.LLM_MODEL_NAME.

        Raises:
            ValueError: if no API key is configured.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model_name = model_name or Config.LLM_MODEL_NAME

        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
|
||||
|
||||
    def generate_profile_from_entity(
        self,
        entity: EntityNode,
        user_id: int,
        use_llm: bool = True
    ) -> OasisAgentProfile:
        """
        Generate an OASIS agent profile from a Zep entity.

        Args:
            entity: Zep entity node.
            user_id: User id assigned to the agent (used by OASIS).
            use_llm: Whether to use the LLM for a detailed persona; when
                False a rule-based fallback is used instead.

        Returns:
            OasisAgentProfile
        """
        entity_type = entity.get_entity_type() or "Entity"

        # Basic identity derived from the entity itself.
        name = entity.name
        user_name = self._generate_username(name)

        # Context assembled from the entity's edges/neighbours (LLM prompt input).
        context = self._build_entity_context(entity)

        if use_llm:
            # Use the LLM to produce a detailed persona.
            profile_data = self._generate_profile_with_llm(
                entity_name=name,
                entity_type=entity_type,
                entity_summary=entity.summary,
                entity_attributes=entity.attributes,
                context=context
            )
        else:
            # Rule-based basic persona (no LLM call).
            profile_data = self._generate_profile_rule_based(
                entity_name=name,
                entity_type=entity_type,
                entity_summary=entity.summary,
                entity_attributes=entity.attributes
            )

        # Missing persona fields fall back to randomized but plausible
        # defaults so every profile is complete even if the LLM omits values.
        return OasisAgentProfile(
            user_id=user_id,
            user_name=user_name,
            name=name,
            bio=profile_data.get("bio", f"{entity_type}: {name}"),
            persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."),
            karma=profile_data.get("karma", random.randint(500, 5000)),
            friend_count=profile_data.get("friend_count", random.randint(50, 500)),
            follower_count=profile_data.get("follower_count", random.randint(100, 1000)),
            statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)),
            age=profile_data.get("age"),
            gender=profile_data.get("gender"),
            mbti=profile_data.get("mbti"),
            country=profile_data.get("country"),
            profession=profile_data.get("profession"),
            interested_topics=profile_data.get("interested_topics", []),
            source_entity_uuid=entity.uuid,
            source_entity_type=entity_type,
        )
|
||||
|
||||
def _generate_username(self, name: str) -> str:
|
||||
"""生成用户名"""
|
||||
# 移除特殊字符,转换为小写
|
||||
username = name.lower().replace(" ", "_")
|
||||
username = ''.join(c for c in username if c.isalnum() or c == '_')
|
||||
|
||||
# 添加随机后缀避免重复
|
||||
suffix = random.randint(100, 999)
|
||||
return f"{username}_{suffix}"
|
||||
|
||||
def _build_entity_context(self, entity: EntityNode) -> str:
|
||||
"""构建实体的上下文信息"""
|
||||
context_parts = []
|
||||
|
||||
# 添加相关边信息
|
||||
if entity.related_edges:
|
||||
relationships = []
|
||||
for edge in entity.related_edges[:10]: # 最多取10条
|
||||
if edge.get("fact"):
|
||||
relationships.append(edge["fact"])
|
||||
|
||||
if relationships:
|
||||
context_parts.append("Related facts:\n" + "\n".join(f"- {r}" for r in relationships))
|
||||
|
||||
# 添加关联节点信息
|
||||
if entity.related_nodes:
|
||||
related_names = [n["name"] for n in entity.related_nodes[:5]]
|
||||
if related_names:
|
||||
context_parts.append(f"Related to: {', '.join(related_names)}")
|
||||
|
||||
return "\n\n".join(context_parts)
|
||||
|
||||
    def _generate_profile_with_llm(
        self,
        entity_name: str,
        entity_type: str,
        entity_summary: str,
        entity_attributes: Dict[str, Any],
        context: str
    ) -> Dict[str, Any]:
        """Generate a detailed persona with the LLM.

        Builds a prompt from the entity's name/type/summary/attributes plus the
        relationship context, asks the chat model for a JSON object, and parses
        it. On any failure (after retries) it falls back to the rule-based
        generator so profile generation never hard-fails.

        Args:
            entity_name: Display name of the entity.
            entity_type: Entity type label (e.g. "Student", "MediaOutlet").
            entity_summary: Free-text summary from the graph.
            entity_attributes: Raw attribute dict from the graph.
            context: Pre-built relationship context (see _build_entity_context).

        Returns:
            Dict with persona fields (bio, persona, age, gender, mbti, country,
            profession, interested_topics) as produced by the LLM or fallback.
        """

        prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes.

Entity Information:
- Name: {entity_name}
- Type: {entity_type}
- Summary: {entity_summary}
- Attributes: {json.dumps(entity_attributes, ensure_ascii=False)}

Context:
{context}

Generate a JSON object with the following fields:
{{
    "bio": "A short bio (max 150 chars) suitable for social media",
    "persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns",
    "age": <integer between 18-65, or null if not applicable>,
    "gender": "<male/female/other, or null if not applicable>",
    "mbti": "<MBTI type like INTJ, ENFP, etc., or null>",
    "country": "<country name, or null>",
    "profession": "<profession/occupation, or null>",
    "interested_topics": ["topic1", "topic2", ...]
}}

Important:
- The profile should be consistent with the entity type and context
- Make the persona feel realistic and suitable for social media simulation
- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account)
- Return ONLY the JSON object, no additional text"""

        try:
            # Call the LLM through the shared retry wrapper for resilience
            # against transient API failures.
            from ..utils.retry import RetryableAPIClient

            retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0)

            def call_llm():
                return self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {"role": "system", "content": "You are a profile generator for social media simulation. Generate realistic user profiles based on entity information."},
                        {"role": "user", "content": prompt}
                    ],
                    response_format={"type": "json_object"},
                    temperature=0.7
                )

            response = retry_client.call_with_retry(call_llm)
            result = json.loads(response.choices[0].message.content)
            return result

        except Exception as e:
            # Fall back to the deterministic rule-based persona rather than
            # propagating LLM/parse errors to the caller.
            logger.warning(f"LLM生成人设失败(已重试): {str(e)}, 使用规则生成")
            return self._generate_profile_rule_based(
                entity_name, entity_type, entity_summary, entity_attributes
            )
|
||||
|
||||
def _generate_profile_rule_based(
|
||||
self,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
entity_summary: str,
|
||||
entity_attributes: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""使用规则生成基础人设"""
|
||||
|
||||
# 根据实体类型生成不同的人设
|
||||
entity_type_lower = entity_type.lower()
|
||||
|
||||
if entity_type_lower in ["student", "alumni"]:
|
||||
return {
|
||||
"bio": f"{entity_type} with interests in academics and social issues.",
|
||||
"persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
|
||||
"age": random.randint(18, 30),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(self.MBTI_TYPES),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": "Student",
|
||||
"interested_topics": ["Education", "Social Issues", "Technology"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["publicfigure", "expert", "faculty"]:
|
||||
return {
|
||||
"bio": f"Expert and thought leader in their field.",
|
||||
"persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
|
||||
"age": random.randint(35, 60),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": entity_attributes.get("occupation", "Expert"),
|
||||
"interested_topics": ["Politics", "Economics", "Culture & Society"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]:
|
||||
return {
|
||||
"bio": f"Official account for {entity_name}. News and updates.",
|
||||
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
||||
"profession": "Media",
|
||||
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]:
|
||||
return {
|
||||
"bio": f"Official account of {entity_name}.",
|
||||
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
||||
}
|
||||
|
||||
else:
|
||||
# 默认人设
|
||||
return {
|
||||
"bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
|
||||
"persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.",
|
||||
"age": random.randint(25, 50),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(self.MBTI_TYPES),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["General", "Social Issues"],
|
||||
}
|
||||
|
||||
    def generate_profiles_from_entities(
        self,
        entities: List[EntityNode],
        use_llm: bool = True,
        progress_callback: Optional[callable] = None
    ) -> List[OasisAgentProfile]:
        """
        Generate agent profiles for a batch of entities.

        Args:
            entities: Entity list.
            use_llm: Whether to use the LLM for detailed personas.
            progress_callback: Progress callback invoked as
                (current, total, message) before each entity is processed.

        Returns:
            List of agent profiles, one per entity. A minimal fallback profile
            is substituted for any entity whose generation raises, so the
            batch never aborts part-way.
        """
        profiles = []
        total = len(entities)

        for idx, entity in enumerate(entities):
            if progress_callback:
                progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...")

            try:
                profile = self.generate_profile_from_entity(
                    entity=entity,
                    user_id=idx,
                    use_llm=use_llm
                )
                profiles.append(profile)

            except Exception as e:
                logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
                # Create a minimal fallback profile for this entity.
                profiles.append(OasisAgentProfile(
                    user_id=idx,
                    user_name=self._generate_username(entity.name),
                    name=entity.name,
                    bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}",
                    persona=entity.summary or f"A participant in social discussions.",
                    source_entity_uuid=entity.uuid,
                    source_entity_type=entity.get_entity_type(),
                ))

        return profiles
|
||||
|
||||
def save_profiles(
|
||||
self,
|
||||
profiles: List[OasisAgentProfile],
|
||||
file_path: str,
|
||||
platform: str = "reddit"
|
||||
):
|
||||
"""
|
||||
保存Profile到文件(根据平台选择正确格式)
|
||||
|
||||
OASIS平台格式要求:
|
||||
- Twitter: CSV格式
|
||||
- Reddit: JSON格式
|
||||
|
||||
Args:
|
||||
profiles: Profile列表
|
||||
file_path: 文件路径
|
||||
platform: 平台类型 ("reddit" 或 "twitter")
|
||||
"""
|
||||
if platform == "twitter":
|
||||
self._save_twitter_csv(profiles, file_path)
|
||||
else:
|
||||
self._save_reddit_json(profiles, file_path)
|
||||
|
||||
def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Twitter Profile为CSV格式
|
||||
|
||||
OASIS Twitter要求的CSV字段:
|
||||
user_id, user_name, name, bio, friend_count, follower_count, statuses_count, created_at
|
||||
"""
|
||||
import csv
|
||||
|
||||
# 确保文件扩展名是.csv
|
||||
if not file_path.endswith('.csv'):
|
||||
file_path = file_path.replace('.json', '.csv')
|
||||
|
||||
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
|
||||
# 写入表头
|
||||
headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count',
|
||||
'follower_count', 'statuses_count', 'created_at']
|
||||
writer.writerow(headers)
|
||||
|
||||
# 写入数据行
|
||||
for profile in profiles:
|
||||
# bio需要处理换行符和逗号
|
||||
bio = profile.bio.replace('\n', ' ').replace('\r', ' ')
|
||||
row = [
|
||||
profile.user_id,
|
||||
profile.user_name,
|
||||
profile.name,
|
||||
bio,
|
||||
profile.friend_count,
|
||||
profile.follower_count,
|
||||
profile.statuses_count,
|
||||
profile.created_at
|
||||
]
|
||||
writer.writerow(row)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)")
|
||||
|
||||
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Reddit Profile为JSON格式
|
||||
|
||||
OASIS Reddit支持两种JSON格式:
|
||||
1. 基础格式: user_id, user_name, name, bio, karma, created_at
|
||||
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
||||
|
||||
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
||||
"""
|
||||
data = []
|
||||
for profile in profiles:
|
||||
# 使用详细格式(与用户示例兼容)
|
||||
item = {
|
||||
"realname": profile.name,
|
||||
"username": profile.user_name,
|
||||
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
|
||||
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
||||
}
|
||||
|
||||
# 添加人设详情字段
|
||||
if profile.age:
|
||||
item["age"] = profile.age
|
||||
if profile.gender:
|
||||
item["gender"] = profile.gender
|
||||
if profile.mbti:
|
||||
item["mbti"] = profile.mbti
|
||||
if profile.country:
|
||||
item["country"] = profile.country
|
||||
if profile.profession:
|
||||
item["profession"] = profile.profession
|
||||
if profile.interested_topics:
|
||||
item["interested_topics"] = profile.interested_topics
|
||||
|
||||
data.append(item)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
|
||||
|
||||
    # Old method name kept as an alias for backward compatibility.
    def save_profiles_to_json(
        self,
        profiles: List[OasisAgentProfile],
        file_path: str,
        platform: str = "reddit"
    ):
        """[Deprecated] Use the save_profiles() method instead."""
        logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
        self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
584
backend/app/services/simulation_config_generator.py
Normal file
584
backend/app/services/simulation_config_generator.py
Normal file
|
|
@ -0,0 +1,584 @@
|
|||
"""
|
||||
模拟配置智能生成器
|
||||
使用LLM根据模拟需求、文档内容、图谱信息自动生成细致的模拟参数
|
||||
实现全程自动化,无需人工设置参数
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
logger = get_logger('mirofish.simulation_config')
|
||||
|
||||
|
||||
@dataclass
class AgentActivityConfig:
    """Activity configuration for a single agent."""
    agent_id: int
    entity_uuid: str
    entity_name: str
    entity_type: str

    # Overall activity level (0.0-1.0)
    activity_level: float = 0.5  # overall activity

    # Posting frequency (expected posts per simulated hour)
    posts_per_hour: float = 1.0
    comments_per_hour: float = 2.0

    # Hours of day in which the agent is active (24h clock, 0-23)
    active_hours: List[int] = field(default_factory=lambda: list(range(8, 23)))

    # Reaction delay to trending events, in simulated minutes
    response_delay_min: int = 5
    response_delay_max: int = 60

    # Sentiment bias (-1.0 negative through 1.0 positive)
    sentiment_bias: float = 0.0

    # Stance towards the focal topic
    stance: str = "neutral"  # supportive, opposing, neutral, observer

    # Influence weight (probability that other agents see this agent's posts)
    influence_weight: float = 1.0
|
||||
|
||||
|
||||
@dataclass
class TimeSimulationConfig:
    """Simulated-time configuration."""
    # Total simulated duration in hours
    total_simulation_hours: int = 72  # default: 72 hours (3 days)

    # Simulated minutes represented by each round
    minutes_per_round: int = 30

    # Range of agents activated per simulated hour
    agents_per_hour_min: int = 5
    agents_per_hour_max: int = 20

    # Peak hours (activity boosted by the multiplier below)
    peak_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 14, 15, 20, 21, 22])
    peak_activity_multiplier: float = 1.5

    # Off-peak hours (activity reduced by the multiplier below)
    off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6])
    off_peak_activity_multiplier: float = 0.3
|
||||
|
||||
|
||||
@dataclass
class EventConfig:
    """Event configuration."""
    # Seed posts published when the simulation starts
    initial_posts: List[Dict[str, Any]] = field(default_factory=list)

    # Events triggered at specific simulated times
    scheduled_events: List[Dict[str, Any]] = field(default_factory=list)

    # Trending-topic keywords
    hot_topics: List[str] = field(default_factory=list)

    # Desired direction of the public-opinion narrative
    narrative_direction: str = ""
|
||||
|
||||
|
||||
@dataclass
class PlatformConfig:
    """Platform-specific configuration."""
    platform: str  # twitter or reddit

    # Recommendation-algorithm weights
    recency_weight: float = 0.4  # freshness of content
    popularity_weight: float = 0.3  # popularity
    relevance_weight: float = 0.3  # relevance

    # Viral threshold (interactions required before a post spreads widely)
    viral_threshold: int = 10

    # Echo-chamber strength (how strongly similar opinions cluster)
    echo_chamber_strength: float = 0.5
|
||||
|
||||
|
||||
@dataclass
class SimulationParameters:
    """Complete simulation parameter set."""
    # Identity / provenance
    simulation_id: str
    project_id: str
    graph_id: str
    simulation_requirement: str

    # Simulated-time settings
    time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig)

    # Per-agent activity settings
    agent_configs: List[AgentActivityConfig] = field(default_factory=list)

    # Event settings
    event_config: EventConfig = field(default_factory=EventConfig)

    # Platform-specific settings (None when the platform is disabled)
    twitter_config: Optional[PlatformConfig] = None
    reddit_config: Optional[PlatformConfig] = None

    # LLM settings
    llm_model: str = ""
    llm_base_url: str = ""

    # Generation metadata
    generated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    generation_reasoning: str = ""  # LLM's reasoning behind the chosen values

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the whole parameter set into a plain dictionary."""

        def platform_dict(cfg: Optional[PlatformConfig]) -> Optional[Dict[str, Any]]:
            # Optional platform configs serialise to None when absent.
            return asdict(cfg) if cfg else None

        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "simulation_requirement": self.simulation_requirement,
            "time_config": asdict(self.time_config),
            "agent_configs": [asdict(cfg) for cfg in self.agent_configs],
            "event_config": asdict(self.event_config),
            "twitter_config": platform_dict(self.twitter_config),
            "reddit_config": platform_dict(self.reddit_config),
            "llm_model": self.llm_model,
            "llm_base_url": self.llm_base_url,
            "generated_at": self.generated_at,
            "generation_reasoning": self.generation_reasoning,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialise to a JSON string, preserving non-ASCII characters."""
        payload = self.to_dict()
        return json.dumps(payload, ensure_ascii=False, indent=indent)
|
||||
|
||||
|
||||
class SimulationConfigGenerator:
    """
    Intelligent simulation-configuration generator.

    Uses an LLM to analyse the simulation requirement, the source document
    and the graph entity information, then produces the best simulation
    parameter configuration automatically.
    """

    # Maximum number of characters of context handed to the LLM
    MAX_CONTEXT_LENGTH = 50000

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: Optional[str] = None
    ):
        """Create a generator backed by an OpenAI-compatible LLM endpoint.

        Args:
            api_key: LLM API key; falls back to Config.LLM_API_KEY.
            base_url: API base URL; falls back to Config.LLM_BASE_URL.
            model_name: Model identifier; falls back to Config.LLM_MODEL_NAME.

        Raises:
            ValueError: if no API key is configured.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model_name = model_name or Config.LLM_MODEL_NAME

        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
|
||||
|
||||
    def generate_config(
        self,
        simulation_id: str,
        project_id: str,
        graph_id: str,
        simulation_requirement: str,
        document_text: str,
        entities: List[EntityNode],
        enable_twitter: bool = True,
        enable_reddit: bool = True,
    ) -> SimulationParameters:
        """
        Generate a complete simulation configuration automatically.

        Args:
            simulation_id: Simulation id.
            project_id: Project id.
            graph_id: Graph id.
            simulation_requirement: Natural-language simulation requirement.
            document_text: Raw source-document content.
            entities: Filtered entity list.
            enable_twitter: Whether the Twitter platform is enabled.
            enable_reddit: Whether the Reddit platform is enabled.

        Returns:
            SimulationParameters: the complete simulation parameter set.
        """
        logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}")

        # 1. Build the LLM context (truncated to MAX_CONTEXT_LENGTH chars).
        context = self._build_context(
            simulation_requirement=simulation_requirement,
            document_text=document_text,
            entities=entities
        )

        # 2. Ask the LLM to generate the configuration values.
        llm_result = self._generate_config_with_llm(
            context=context,
            entities=entities,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit
        )

        # 3. Assemble the SimulationParameters object from the LLM output.
        params = self._build_parameters(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            simulation_requirement=simulation_requirement,
            entities=entities,
            llm_result=llm_result,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit
        )

        logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置")

        return params
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
simulation_requirement: str,
|
||||
document_text: str,
|
||||
entities: List[EntityNode]
|
||||
) -> str:
|
||||
"""构建LLM上下文,截断到最大长度"""
|
||||
|
||||
# 实体摘要
|
||||
entity_summary = self._summarize_entities(entities)
|
||||
|
||||
# 构建上下文
|
||||
context_parts = [
|
||||
f"## 模拟需求\n{simulation_requirement}",
|
||||
f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}",
|
||||
]
|
||||
|
||||
current_length = sum(len(p) for p in context_parts)
|
||||
remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量
|
||||
|
||||
if remaining_length > 0 and document_text:
|
||||
doc_text = document_text[:remaining_length]
|
||||
if len(document_text) > remaining_length:
|
||||
doc_text += "\n...(文档已截断)"
|
||||
context_parts.append(f"\n## 原始文档内容\n{doc_text}")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def _summarize_entities(self, entities: List[EntityNode]) -> str:
|
||||
"""生成实体摘要"""
|
||||
lines = []
|
||||
|
||||
# 按类型分组
|
||||
by_type: Dict[str, List[EntityNode]] = {}
|
||||
for e in entities:
|
||||
t = e.get_entity_type() or "Unknown"
|
||||
if t not in by_type:
|
||||
by_type[t] = []
|
||||
by_type[t].append(e)
|
||||
|
||||
for entity_type, type_entities in by_type.items():
|
||||
lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
|
||||
for e in type_entities[:10]: # 每类最多显示10个
|
||||
summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary
|
||||
lines.append(f"- {e.name}: {summary_preview}")
|
||||
if len(type_entities) > 10:
|
||||
lines.append(f" ... 还有 {len(type_entities) - 10} 个")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
    def _generate_config_with_llm(
        self,
        context: str,
        entities: List[EntityNode],
        enable_twitter: bool,
        enable_reddit: bool
    ) -> Dict[str, Any]:
        """Ask the LLM for a complete simulation configuration.

        Args:
            context: Prompt context built by _build_context (requirement,
                entity summary, truncated document).
            entities: Filtered entity list; the list index becomes agent_id.
            enable_twitter: Whether a Twitter platform config is requested.
            enable_reddit: Whether a Reddit platform config is requested.

        Returns:
            Parsed JSON dict from the LLM (time_config / agent_configs /
            event_config / platform_configs / reasoning).  Falls back to
            _generate_default_config on any API or parsing failure.
        """

        # Entity roster embedded in the prompt; agent_id is the list index.
        entity_list = []
        for i, e in enumerate(entities):
            entity_list.append({
                "agent_id": i,
                "entity_uuid": e.uuid,
                "entity_name": e.name,
                "entity_type": e.get_entity_type() or "Unknown",
                # Trim summaries so the roster stays compact in the prompt.
                "summary": e.summary[:200] if e.summary else ""
            })

        prompt = f"""你是一个社交媒体舆论模拟专家。请根据以下信息,生成详细的模拟参数配置。

{context}

## 实体列表(需要为每个实体生成活动配置)
```json
{json.dumps(entity_list, ensure_ascii=False, indent=2)}
```

## 任务
请生成一个JSON配置,包含以下部分:

1. **time_config** - 时间模拟配置
- total_simulation_hours: 模拟总时长(小时),根据事件性质决定(短期热点24-72小时,长期舆论168-336小时)
- minutes_per_round: 每轮代表的时间(分钟),建议15-60
- agents_per_hour_min/max: 每小时激活的Agent数量范围
- peak_hours: 高峰时段列表(0-23)
- off_peak_hours: 低谷时段列表

2. **agent_configs** - 每个Agent的活动配置(必须为每个实体生成)
对于每个agent_id,设置:
- activity_level: 活跃度(0.0-1.0),官方机构通常0.1-0.3,媒体0.3-0.5,个人0.5-0.9
- posts_per_hour: 每小时发帖频率,官方机构0.05-0.2,媒体0.5-2,个人0.1-1
- comments_per_hour: 每小时评论频率
- active_hours: 活跃时间段列表,官方通常工作时间,个人更分散
- response_delay_min/max: 响应延迟(模拟分钟),官方较慢(30-180),个人较快(1-30)
- sentiment_bias: 情感倾向(-1到1),根据实体立场设置
- stance: 立场(supportive/opposing/neutral/observer)
- influence_weight: 影响力权重,知名人物和媒体较高

3. **event_config** - 事件配置
- initial_posts: 初始帖子列表,包含content和poster_agent_id
- hot_topics: 热点话题关键词列表
- narrative_direction: 舆论发展方向描述

4. **platform_configs** - 平台配置(如果启用)
- viral_threshold: 病毒传播阈值
- echo_chamber_strength: 回声室效应强度(0-1)

5. **reasoning** - 你的推理说明,解释为什么这样设置参数

## 重要原则
- 官方机构(University、GovernmentAgency)发言频率低但影响力大
- 媒体(MediaOutlet)发言频率中等,传播速度快
- 个人(Student、PublicFigure)发言频率高但影响力分散
- 根据模拟需求判断各实体的立场和情感倾向
- 时间配置要符合真实社交媒体的使用规律

请返回JSON格式,不要包含markdown代码块标记。"""

        try:
            # Route the LLM call through the shared retry helper
            # (exponential backoff: 3 attempts, 2s initial, 60s cap).
            from ..utils.retry import RetryableAPIClient

            retry_client = RetryableAPIClient(max_retries=3, initial_delay=2.0, max_delay=60.0)

            def call_llm():
                # response_format forces a JSON object so json.loads below
                # can parse the message content directly.
                return self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "system",
                            "content": "你是社交媒体舆论模拟专家,擅长设计真实的模拟参数。返回纯JSON格式,不要markdown。"
                        },
                        {"role": "user", "content": prompt}
                    ],
                    response_format={"type": "json_object"},
                    temperature=0.7,
                    max_tokens=8000
                )

            response = retry_client.call_with_retry(call_llm)
            result = json.loads(response.choices[0].message.content)
            logger.info(f"LLM配置生成成功")
            return result

        except Exception as e:
            logger.error(f"LLM配置生成失败(已重试): {str(e)}")
            # Fall back to the static, type-based default configuration.
            return self._generate_default_config(entities)
|
||||
|
||||
def _generate_default_config(self, entities: List[EntityNode]) -> Dict[str, Any]:
|
||||
"""生成默认配置(LLM失败时的fallback)"""
|
||||
agent_configs = []
|
||||
|
||||
for i, e in enumerate(entities):
|
||||
entity_type = (e.get_entity_type() or "Unknown").lower()
|
||||
|
||||
# 根据实体类型设置默认参数
|
||||
if entity_type in ["university", "governmentagency", "ngo"]:
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.2,
|
||||
"posts_per_hour": 0.1,
|
||||
"comments_per_hour": 0.05,
|
||||
"active_hours": list(range(9, 18)),
|
||||
"response_delay_min": 60,
|
||||
"response_delay_max": 240,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "neutral",
|
||||
"influence_weight": 3.0
|
||||
}
|
||||
elif entity_type in ["mediaoutlet"]:
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.6,
|
||||
"posts_per_hour": 1.0,
|
||||
"comments_per_hour": 0.5,
|
||||
"active_hours": list(range(6, 24)),
|
||||
"response_delay_min": 5,
|
||||
"response_delay_max": 30,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "observer",
|
||||
"influence_weight": 2.5
|
||||
}
|
||||
elif entity_type in ["publicfigure", "expert"]:
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.5,
|
||||
"posts_per_hour": 0.3,
|
||||
"comments_per_hour": 0.5,
|
||||
"active_hours": list(range(8, 23)),
|
||||
"response_delay_min": 10,
|
||||
"response_delay_max": 60,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "neutral",
|
||||
"influence_weight": 2.0
|
||||
}
|
||||
else: # Student, Person, etc.
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.7,
|
||||
"posts_per_hour": 0.5,
|
||||
"comments_per_hour": 1.0,
|
||||
"active_hours": list(range(7, 24)),
|
||||
"response_delay_min": 1,
|
||||
"response_delay_max": 20,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "neutral",
|
||||
"influence_weight": 1.0
|
||||
}
|
||||
|
||||
agent_configs.append(config)
|
||||
|
||||
return {
|
||||
"time_config": {
|
||||
"total_simulation_hours": 72,
|
||||
"minutes_per_round": 30,
|
||||
"agents_per_hour_min": max(1, len(entities) // 10),
|
||||
"agents_per_hour_max": max(5, len(entities) // 3),
|
||||
"peak_hours": [9, 10, 11, 14, 15, 20, 21, 22],
|
||||
"off_peak_hours": [0, 1, 2, 3, 4, 5]
|
||||
},
|
||||
"agent_configs": agent_configs,
|
||||
"event_config": {
|
||||
"initial_posts": [],
|
||||
"hot_topics": [],
|
||||
"narrative_direction": ""
|
||||
},
|
||||
"reasoning": "使用默认配置(LLM生成失败)"
|
||||
}
|
||||
|
||||
    def _build_parameters(
        self,
        simulation_id: str,
        project_id: str,
        graph_id: str,
        simulation_requirement: str,
        entities: List[EntityNode],
        llm_result: Dict[str, Any],
        enable_twitter: bool,
        enable_reddit: bool
    ) -> SimulationParameters:
        """Materialize the LLM's raw dict into a SimulationParameters object.

        Every field is read with dict.get() and a default, so a partial or
        malformed LLM result still produces a usable configuration.

        Args:
            simulation_id: Simulation identifier.
            project_id: Project identifier.
            graph_id: Graph identifier.
            simulation_requirement: Original requirement text (stored as-is).
            entities: Filtered entities; agent_id is the list index.
            llm_result: Raw dict from the LLM (or the default fallback).
            enable_twitter: Build a Twitter PlatformConfig when True.
            enable_reddit: Build a Reddit PlatformConfig when True.

        Returns:
            SimulationParameters with time/agent/event/platform sections.
        """

        # --- Time configuration ---
        time_cfg = llm_result.get("time_config", {})
        time_config = TimeSimulationConfig(
            total_simulation_hours=time_cfg.get("total_simulation_hours", 72),
            minutes_per_round=time_cfg.get("minutes_per_round", 30),
            agents_per_hour_min=time_cfg.get("agents_per_hour_min", 5),
            agents_per_hour_max=time_cfg.get("agents_per_hour_max", 20),
            peak_hours=time_cfg.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]),
            off_peak_hours=time_cfg.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
            peak_activity_multiplier=time_cfg.get("peak_activity_multiplier", 1.5),
            off_peak_activity_multiplier=time_cfg.get("off_peak_activity_multiplier", 0.3)
        )

        # --- Per-agent configuration ---
        # Index the LLM's agent configs by agent_id for O(1) lookup; entities
        # missing from the LLM output fall back to the defaults below.
        agent_configs = []
        llm_agent_configs = {cfg["agent_id"]: cfg for cfg in llm_result.get("agent_configs", [])}

        for i, entity in enumerate(entities):
            cfg = llm_agent_configs.get(i, {})

            agent_config = AgentActivityConfig(
                agent_id=i,
                entity_uuid=entity.uuid,
                entity_name=entity.name,
                entity_type=entity.get_entity_type() or "Unknown",
                activity_level=cfg.get("activity_level", 0.5),
                posts_per_hour=cfg.get("posts_per_hour", 0.5),
                comments_per_hour=cfg.get("comments_per_hour", 1.0),
                active_hours=cfg.get("active_hours", list(range(8, 23))),
                response_delay_min=cfg.get("response_delay_min", 5),
                response_delay_max=cfg.get("response_delay_max", 60),
                sentiment_bias=cfg.get("sentiment_bias", 0.0),
                stance=cfg.get("stance", "neutral"),
                influence_weight=cfg.get("influence_weight", 1.0)
            )
            agent_configs.append(agent_config)

        # --- Event configuration ---
        event_cfg = llm_result.get("event_config", {})
        event_config = EventConfig(
            initial_posts=event_cfg.get("initial_posts", []),
            scheduled_events=event_cfg.get("scheduled_events", []),
            hot_topics=event_cfg.get("hot_topics", []),
            narrative_direction=event_cfg.get("narrative_direction", "")
        )

        # --- Platform configurations (only built for enabled platforms) ---
        twitter_config = None
        reddit_config = None

        platform_cfgs = llm_result.get("platform_configs", {})

        if enable_twitter:
            tw_cfg = platform_cfgs.get("twitter", {})
            twitter_config = PlatformConfig(
                platform="twitter",
                recency_weight=tw_cfg.get("recency_weight", 0.4),
                popularity_weight=tw_cfg.get("popularity_weight", 0.3),
                relevance_weight=tw_cfg.get("relevance_weight", 0.3),
                viral_threshold=tw_cfg.get("viral_threshold", 10),
                echo_chamber_strength=tw_cfg.get("echo_chamber_strength", 0.5)
            )

        if enable_reddit:
            rd_cfg = platform_cfgs.get("reddit", {})
            reddit_config = PlatformConfig(
                platform="reddit",
                recency_weight=rd_cfg.get("recency_weight", 0.3),
                popularity_weight=rd_cfg.get("popularity_weight", 0.4),
                relevance_weight=rd_cfg.get("relevance_weight", 0.3),
                viral_threshold=rd_cfg.get("viral_threshold", 15),
                echo_chamber_strength=rd_cfg.get("echo_chamber_strength", 0.6)
            )

        return SimulationParameters(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            simulation_requirement=simulation_requirement,
            time_config=time_config,
            agent_configs=agent_configs,
            event_config=event_config,
            twitter_config=twitter_config,
            reddit_config=reddit_config,
            llm_model=self.model_name,
            llm_base_url=self.base_url,
            generation_reasoning=llm_result.get("reasoning", "")
        )
|
||||
|
||||
|
||||
546
backend/app/services/simulation_manager.py
Normal file
546
backend/app/services/simulation_manager.py
Normal file
|
|
@ -0,0 +1,546 @@
|
|||
"""
|
||||
OASIS模拟管理器
|
||||
管理Twitter和Reddit双平台并行模拟
|
||||
使用预设脚本 + LLM智能生成配置参数
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import ZepEntityReader, FilteredEntities
|
||||
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
|
||||
|
||||
logger = get_logger('mirofish.simulation')
|
||||
|
||||
|
||||
class SimulationStatus(str, Enum):
    """Lifecycle states of a simulation (str-valued for JSON round-tripping)."""
    CREATED = "created"        # record exists; nothing prepared yet
    PREPARING = "preparing"    # entity reading / profile & config generation in progress
    READY = "ready"            # all files prepared; can be launched
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"          # see SimulationState.error for the reason
|
||||
|
||||
|
||||
class PlatformType(str, Enum):
    """Social platforms supported by the OASIS simulation."""
    TWITTER = "twitter"
    REDDIT = "reddit"
|
||||
|
||||
@dataclass
class SimulationState:
    """Persistent state of one simulation, serialized to state.json."""
    simulation_id: str
    project_id: str
    graph_id: str

    # Which platforms this simulation targets.
    enable_twitter: bool = True
    enable_reddit: bool = True

    # Lifecycle status (see SimulationStatus).
    status: SimulationStatus = SimulationStatus.CREATED

    # Preparation-phase results.
    entities_count: int = 0
    profiles_count: int = 0
    entity_types: List[str] = field(default_factory=list)

    # LLM configuration-generation results.
    config_generated: bool = False
    config_reasoning: str = ""

    # Runtime progress.
    current_round: int = 0
    twitter_status: str = "not_started"
    reddit_status: str = "not_started"

    # ISO timestamps.
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    # Failure reason, if any.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Full state dict (internal persistence format)."""
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "enable_twitter": self.enable_twitter,
            "enable_reddit": self.enable_reddit,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": self.entity_types,
            "config_generated": self.config_generated,
            "config_reasoning": self.config_reasoning,
            "current_round": self.current_round,
            "twitter_status": self.twitter_status,
            "reddit_status": self.reddit_status,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "error": self.error,
        }

    def to_simple_dict(self) -> Dict[str, Any]:
        """Trimmed state dict for API responses (subset of to_dict)."""
        full = self.to_dict()
        return {key: full[key] for key in (
            "simulation_id", "project_id", "graph_id", "status",
            "entities_count", "profiles_count", "entity_types",
            "config_generated", "error",
        )}
|
||||
|
||||
|
||||
class SimulationManager:
|
||||
"""
|
||||
模拟管理器
|
||||
|
||||
核心功能:
|
||||
1. 从Zep图谱读取实体并过滤
|
||||
2. 生成OASIS Agent Profile
|
||||
3. 使用LLM智能生成模拟配置参数
|
||||
4. 准备预设脚本所需的所有文件
|
||||
"""
|
||||
|
||||
# 模拟数据存储目录
|
||||
SIMULATION_DATA_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../uploads/simulations'
|
||||
)
|
||||
|
||||
# 预设脚本目录
|
||||
SCRIPTS_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../scripts'
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
# 确保目录存在
|
||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||
|
||||
# 内存中的模拟状态缓存
|
||||
self._simulations: Dict[str, SimulationState] = {}
|
||||
|
||||
def _get_simulation_dir(self, simulation_id: str) -> str:
|
||||
"""获取模拟数据目录"""
|
||||
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
||||
os.makedirs(sim_dir, exist_ok=True)
|
||||
return sim_dir
|
||||
|
||||
def _save_simulation_state(self, state: SimulationState):
|
||||
"""保存模拟状态到文件"""
|
||||
sim_dir = self._get_simulation_dir(state.simulation_id)
|
||||
state_file = os.path.join(sim_dir, "state.json")
|
||||
|
||||
state.updated_at = datetime.now().isoformat()
|
||||
|
||||
with open(state_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
self._simulations[state.simulation_id] = state
|
||||
|
||||
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""从文件加载模拟状态"""
|
||||
if simulation_id in self._simulations:
|
||||
return self._simulations[simulation_id]
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
state_file = os.path.join(sim_dir, "state.json")
|
||||
|
||||
if not os.path.exists(state_file):
|
||||
return None
|
||||
|
||||
with open(state_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
state = SimulationState(
|
||||
simulation_id=simulation_id,
|
||||
project_id=data.get("project_id", ""),
|
||||
graph_id=data.get("graph_id", ""),
|
||||
enable_twitter=data.get("enable_twitter", True),
|
||||
enable_reddit=data.get("enable_reddit", True),
|
||||
status=SimulationStatus(data.get("status", "created")),
|
||||
entities_count=data.get("entities_count", 0),
|
||||
profiles_count=data.get("profiles_count", 0),
|
||||
entity_types=data.get("entity_types", []),
|
||||
config_generated=data.get("config_generated", False),
|
||||
config_reasoning=data.get("config_reasoning", ""),
|
||||
current_round=data.get("current_round", 0),
|
||||
twitter_status=data.get("twitter_status", "not_started"),
|
||||
reddit_status=data.get("reddit_status", "not_started"),
|
||||
created_at=data.get("created_at", datetime.now().isoformat()),
|
||||
updated_at=data.get("updated_at", datetime.now().isoformat()),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
self._simulations[simulation_id] = state
|
||||
return state
|
||||
|
||||
def create_simulation(
|
||||
self,
|
||||
project_id: str,
|
||||
graph_id: str,
|
||||
enable_twitter: bool = True,
|
||||
enable_reddit: bool = True,
|
||||
) -> SimulationState:
|
||||
"""
|
||||
创建新的模拟
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
graph_id: Zep图谱ID
|
||||
enable_twitter: 是否启用Twitter模拟
|
||||
enable_reddit: 是否启用Reddit模拟
|
||||
|
||||
Returns:
|
||||
SimulationState
|
||||
"""
|
||||
import uuid
|
||||
simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
state = SimulationState(
|
||||
simulation_id=simulation_id,
|
||||
project_id=project_id,
|
||||
graph_id=graph_id,
|
||||
enable_twitter=enable_twitter,
|
||||
enable_reddit=enable_reddit,
|
||||
status=SimulationStatus.CREATED,
|
||||
)
|
||||
|
||||
self._save_simulation_state(state)
|
||||
logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}")
|
||||
|
||||
return state
|
||||
|
||||
    def prepare_simulation(
        self,
        simulation_id: str,
        simulation_requirement: str,
        document_text: str,
        defined_entity_types: Optional[List[str]] = None,
        use_llm_for_profiles: bool = True,
        progress_callback: Optional[callable] = None
    ) -> SimulationState:
        """Prepare the simulation environment end to end.

        Stages (each reported through progress_callback):
          1. "reading"             - read & filter entities from the Zep graph
          2. "generating_profiles" - build one OASIS agent profile per entity
          3. "generating_config"   - LLM-generate simulation parameters
          4. "copying_scripts"     - copy the preset runner scripts

        Args:
            simulation_id: Id of a previously created simulation.
            simulation_requirement: Requirement text fed to the LLM.
            document_text: Raw source document for LLM background context.
            defined_entity_types: Optional whitelist of entity types.
            use_llm_for_profiles: Enrich profiles with LLM-written personas.
            progress_callback: Called as callback(stage, percent, message,
                **extras) — extras include current=/total=/item_name= keyword
                args, so the callable must accept keyword arguments.

        Returns:
            The updated SimulationState (READY on success; FAILED with
            .error set when no entities are found).

        Raises:
            ValueError: If the simulation id is unknown.
            Exception: Any stage failure is recorded on the state
                (status=FAILED) and then re-raised.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")

        try:
            state.status = SimulationStatus.PREPARING
            self._save_simulation_state(state)

            sim_dir = self._get_simulation_dir(simulation_id)

            # ========== Stage 1: read & filter graph entities ==========
            if progress_callback:
                progress_callback("reading", 0, "正在连接Zep图谱...")

            reader = ZepEntityReader()

            if progress_callback:
                progress_callback("reading", 30, "正在读取节点数据...")

            filtered = reader.filter_defined_entities(
                graph_id=state.graph_id,
                defined_entity_types=defined_entity_types,
                enrich_with_edges=True
            )

            state.entities_count = filtered.filtered_count
            state.entity_types = list(filtered.entity_types)

            if progress_callback:
                progress_callback(
                    "reading", 100,
                    f"完成,共 {filtered.filtered_count} 个实体",
                    current=filtered.filtered_count,
                    total=filtered.filtered_count
                )

            # No usable entities: mark FAILED and return early (no raise,
            # so the caller gets the explanatory state back).
            if filtered.filtered_count == 0:
                state.status = SimulationStatus.FAILED
                state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
                self._save_simulation_state(state)
                return state

            # ========== Stage 2: generate agent profiles ==========
            total_entities = len(filtered.entities)

            if progress_callback:
                progress_callback(
                    "generating_profiles", 0,
                    "开始生成...",
                    current=0,
                    total=total_entities
                )

            generator = OasisProfileGenerator()

            # Adapter: forwards the generator's (current, total, msg)
            # progress into the stage-based progress_callback protocol.
            def profile_progress(current, total, msg):
                if progress_callback:
                    progress_callback(
                        "generating_profiles",
                        int(current / total * 100),
                        msg,
                        current=current,
                        total=total,
                        item_name=msg
                    )

            profiles = generator.generate_profiles_from_entities(
                entities=filtered.entities,
                use_llm=use_llm_for_profiles,
                progress_callback=profile_progress
            )

            state.profiles_count = len(profiles)

            # Save profile files.  NOTE: Twitter requires CSV, Reddit JSON.
            if progress_callback:
                progress_callback(
                    "generating_profiles", 95,
                    "保存Profile文件...",
                    current=total_entities,
                    total=total_entities
                )

            if state.enable_reddit:
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "reddit_profiles.json"),
                    platform="reddit"
                )

            if state.enable_twitter:
                # Twitter profiles must be CSV — an OASIS requirement.
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
                    platform="twitter"
                )

            if progress_callback:
                progress_callback(
                    "generating_profiles", 100,
                    f"完成,共 {len(profiles)} 个Profile",
                    current=len(profiles),
                    total=len(profiles)
                )

            # ========== Stage 3: LLM-generate simulation config ==========
            if progress_callback:
                progress_callback(
                    "generating_config", 0,
                    "正在分析模拟需求...",
                    current=0,
                    total=3
                )

            config_generator = SimulationConfigGenerator()

            if progress_callback:
                progress_callback(
                    "generating_config", 30,
                    "正在调用LLM生成配置...",
                    current=1,
                    total=3
                )

            sim_params = config_generator.generate_config(
                simulation_id=simulation_id,
                project_id=state.project_id,
                graph_id=state.graph_id,
                simulation_requirement=simulation_requirement,
                document_text=document_text,
                entities=filtered.entities,
                enable_twitter=state.enable_twitter,
                enable_reddit=state.enable_reddit
            )

            if progress_callback:
                progress_callback(
                    "generating_config", 70,
                    "正在保存配置文件...",
                    current=2,
                    total=3
                )

            # Persist the generated config for the preset scripts to consume.
            config_path = os.path.join(sim_dir, "simulation_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(sim_params.to_json())

            state.config_generated = True
            state.config_reasoning = sim_params.generation_reasoning

            if progress_callback:
                progress_callback(
                    "generating_config", 100,
                    "配置生成完成",
                    current=3,
                    total=3
                )

            # ========== Stage 4: copy preset scripts ==========
            script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
                            "run_parallel_simulation.py", "action_logger.py"]

            if progress_callback:
                progress_callback(
                    "copying_scripts", 0,
                    "开始准备脚本...",
                    current=0,
                    total=len(script_files)
                )

            self._copy_preset_scripts(sim_dir)

            if progress_callback:
                progress_callback(
                    "copying_scripts", 100,
                    f"完成,共 {len(script_files)} 个脚本",
                    current=len(script_files),
                    total=len(script_files)
                )

            # All stages done — mark READY and persist.
            state.status = SimulationStatus.READY
            self._save_simulation_state(state)

            logger.info(f"模拟准备完成: {simulation_id}, "
                        f"entities={state.entities_count}, profiles={state.profiles_count}")

            return state

        except Exception as e:
            logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            # Record the failure on the state before re-raising so the
            # caller and any later status query both see it.
            state.status = SimulationStatus.FAILED
            state.error = str(e)
            self._save_simulation_state(state)
            raise
|
||||
|
||||
def _copy_preset_scripts(self, sim_dir: str):
|
||||
"""复制预设脚本到模拟目录"""
|
||||
scripts = [
|
||||
"run_twitter_simulation.py",
|
||||
"run_reddit_simulation.py",
|
||||
"run_parallel_simulation.py"
|
||||
]
|
||||
|
||||
for script in scripts:
|
||||
src = os.path.join(self.SCRIPTS_DIR, script)
|
||||
dst = os.path.join(sim_dir, script)
|
||||
|
||||
if os.path.exists(src):
|
||||
shutil.copy2(src, dst)
|
||||
logger.debug(f"复制脚本: {script}")
|
||||
else:
|
||||
logger.warning(f"预设脚本不存在: {src}")
|
||||
|
||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""获取模拟状态"""
|
||||
return self._load_simulation_state(simulation_id)
|
||||
|
||||
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
||||
"""列出所有模拟"""
|
||||
simulations = []
|
||||
|
||||
if os.path.exists(self.SIMULATION_DATA_DIR):
|
||||
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
||||
state = self._load_simulation_state(sim_id)
|
||||
if state:
|
||||
if project_id is None or state.project_id == project_id:
|
||||
simulations.append(state)
|
||||
|
||||
return simulations
|
||||
|
||||
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
||||
"""获取模拟的Agent Profile"""
|
||||
state = self._load_simulation_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
|
||||
|
||||
if not os.path.exists(profile_path):
|
||||
return []
|
||||
|
||||
with open(profile_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取模拟配置"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
||||
"""获取运行说明"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
|
||||
return {
|
||||
"simulation_dir": sim_dir,
|
||||
"config_file": config_path,
|
||||
"commands": {
|
||||
"twitter": f"python run_twitter_simulation.py --config simulation_config.json",
|
||||
"reddit": f"python run_reddit_simulation.py --config simulation_config.json",
|
||||
"parallel": f"python run_parallel_simulation.py --config simulation_config.json",
|
||||
},
|
||||
"instructions": (
|
||||
f"1. 进入模拟目录: cd {sim_dir}\n"
|
||||
f"2. 激活conda环境: conda activate MiroFish\n"
|
||||
f"3. 运行模拟:\n"
|
||||
f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
|
||||
f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
|
||||
f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
|
||||
)
|
||||
}
|
||||
670
backend/app/services/simulation_runner.py
Normal file
670
backend/app/services/simulation_runner.py
Normal file
|
|
@ -0,0 +1,670 @@
|
|||
"""
|
||||
OASIS模拟运行器
|
||||
在后台运行模拟并记录每个Agent的动作,支持实时状态监控
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
import subprocess
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from queue import Queue
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
|
||||
class RunnerStatus(str, Enum):
    """Lifecycle states of the background simulation runner."""
    IDLE = "idle"
    STARTING = "starting"
    RUNNING = "running"
    PAUSED = "paused"
    STOPPING = "stopping"     # stop requested, shutting down
    STOPPED = "stopped"       # stopped before completion
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
class AgentAction:
    """A single action taken by one agent during a simulation round."""
    round_num: int
    timestamp: str
    platform: str       # "twitter" or "reddit"
    agent_id: int
    agent_name: str
    action_type: str    # e.g. CREATE_POST, LIKE_POST
    action_args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None
    success: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict for JSON/state output."""
        field_names = (
            "round_num", "timestamp", "platform", "agent_id", "agent_name",
            "action_type", "action_args", "result", "success",
        )
        return {name: getattr(self, name) for name in field_names}
|
||||
|
||||
|
||||
@dataclass
class RoundSummary:
    """Aggregated statistics for one simulation round."""
    round_num: int
    start_time: str
    end_time: Optional[str] = None
    simulated_hour: int = 0          # simulated hour of day this round covered
    twitter_actions: int = 0
    reddit_actions: int = 0
    active_agents: List[int] = field(default_factory=list)
    actions: List[AgentAction] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; actions are embedded via their own to_dict()."""
        serialized = {name: getattr(self, name) for name in (
            "round_num", "start_time", "end_time", "simulated_hour",
            "twitter_actions", "reddit_actions", "active_agents",
        )}
        serialized["actions_count"] = len(self.actions)
        serialized["actions"] = [action.to_dict() for action in self.actions]
        return serialized
|
||||
|
||||
|
||||
@dataclass
class SimulationRunState:
    """Live, in-memory run state of one simulation."""
    simulation_id: str
    runner_status: RunnerStatus = RunnerStatus.IDLE

    # Progress.
    current_round: int = 0
    total_rounds: int = 0
    simulated_hours: int = 0
    total_simulation_hours: int = 0

    # Per-platform status and action tallies.
    twitter_running: bool = False
    reddit_running: bool = False
    twitter_actions_count: int = 0
    reddit_actions_count: int = 0

    # Per-round summaries.
    rounds: List[RoundSummary] = field(default_factory=list)

    # Newest-first window of actions for live frontend display.
    recent_actions: List[AgentAction] = field(default_factory=list)
    max_recent_actions: int = 50

    # Timestamps (ISO strings).
    started_at: Optional[str] = None
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None

    # Failure reason, if any.
    error: Optional[str] = None

    # PID of the background process (used to stop it).
    process_pid: Optional[int] = None

    def add_action(self, action: AgentAction):
        """Record an action: prepend to the recent window, bump the
        platform counter, and touch updated_at."""
        self.recent_actions.insert(0, action)
        if len(self.recent_actions) > self.max_recent_actions:
            del self.recent_actions[self.max_recent_actions:]

        if action.platform == "twitter":
            self.twitter_actions_count += 1
        else:
            self.reddit_actions_count += 1

        self.updated_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Summary dict for status polling (no per-action detail)."""
        progress = round(self.current_round / max(self.total_rounds, 1) * 100, 1)
        total_actions = self.twitter_actions_count + self.reddit_actions_count
        return {
            "simulation_id": self.simulation_id,
            "runner_status": self.runner_status.value,
            "current_round": self.current_round,
            "total_rounds": self.total_rounds,
            "simulated_hours": self.simulated_hours,
            "total_simulation_hours": self.total_simulation_hours,
            "progress_percent": progress,
            "twitter_running": self.twitter_running,
            "reddit_running": self.reddit_running,
            "twitter_actions_count": self.twitter_actions_count,
            "reddit_actions_count": self.reddit_actions_count,
            "total_actions_count": total_actions,
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "completed_at": self.completed_at,
            "error": self.error,
            "process_pid": self.process_pid,
        }

    def to_detail_dict(self) -> Dict[str, Any]:
        """to_dict() plus the recent-action window and round count."""
        detail = self.to_dict()
        detail["recent_actions"] = [action.to_dict() for action in self.recent_actions]
        detail["rounds_count"] = len(self.rounds)
        return detail
|
||||
|
||||
|
||||
class SimulationRunner:
    """
    Simulation runner.

    Responsibilities:
    1. Run the OASIS simulation in a background process
    2. Parse the run log and record every agent action
    3. Provide a real-time status query interface
    4. Support pause/stop/resume operations
    """

    # Directory where per-simulation run state is persisted
    RUN_STATE_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )

    # In-memory registries, shared at class level across requests.
    # NOTE(review): these are mutated from both request handlers and the
    # monitor thread without a lock — presumably safe enough under the GIL
    # for dict get/set, but worth confirming.
    _run_states: Dict[str, SimulationRunState] = {}
    _processes: Dict[str, subprocess.Popen] = {}
    _action_queues: Dict[str, Queue] = {}
    _monitor_threads: Dict[str, threading.Thread] = {}

    @classmethod
    def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
        """Return the run state for a simulation, loading from disk on a cache miss."""
        if simulation_id in cls._run_states:
            return cls._run_states[simulation_id]

        # Fall back to the persisted state file
        state = cls._load_run_state(simulation_id)
        if state:
            cls._run_states[simulation_id] = state
        return state

    @classmethod
    def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
        """Load the persisted run state from run_state.json; return None if absent/corrupt."""
        state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
        if not os.path.exists(state_file):
            return None

        try:
            with open(state_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            state = SimulationRunState(
                simulation_id=simulation_id,
                runner_status=RunnerStatus(data.get("runner_status", "idle")),
                current_round=data.get("current_round", 0),
                total_rounds=data.get("total_rounds", 0),
                simulated_hours=data.get("simulated_hours", 0),
                total_simulation_hours=data.get("total_simulation_hours", 0),
                twitter_running=data.get("twitter_running", False),
                reddit_running=data.get("reddit_running", False),
                twitter_actions_count=data.get("twitter_actions_count", 0),
                reddit_actions_count=data.get("reddit_actions_count", 0),
                started_at=data.get("started_at"),
                updated_at=data.get("updated_at", datetime.now().isoformat()),
                completed_at=data.get("completed_at"),
                error=data.get("error"),
                process_pid=data.get("process_pid"),
            )

            # Rehydrate the recent-actions list
            actions_data = data.get("recent_actions", [])
            for a in actions_data:
                state.recent_actions.append(AgentAction(
                    round_num=a.get("round_num", 0),
                    timestamp=a.get("timestamp", ""),
                    platform=a.get("platform", ""),
                    agent_id=a.get("agent_id", 0),
                    agent_name=a.get("agent_name", ""),
                    action_type=a.get("action_type", ""),
                    action_args=a.get("action_args", {}),
                    result=a.get("result"),
                    success=a.get("success", True),
                ))

            return state
        except Exception as e:
            logger.error(f"加载运行状态失败: {str(e)}")
            return None

    @classmethod
    def _save_run_state(cls, state: SimulationRunState):
        """Persist the run state to run_state.json and refresh the in-memory cache."""
        sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
        os.makedirs(sim_dir, exist_ok=True)
        state_file = os.path.join(sim_dir, "run_state.json")

        data = state.to_detail_dict()

        with open(state_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        cls._run_states[state.simulation_id] = state

    @classmethod
    def start_simulation(
        cls,
        simulation_id: str,
        platform: str = "parallel"  # twitter / reddit / parallel
    ) -> SimulationRunState:
        """
        Start a simulation in a background subprocess.

        Args:
            simulation_id: simulation ID
            platform: which platform(s) to run (twitter/reddit/parallel)

        Returns:
            SimulationRunState

        Raises:
            ValueError: if the simulation is already running, the config file
                is missing, or the run script does not exist.
        """
        # Refuse to start if the simulation is already running/starting
        existing = cls.get_run_state(simulation_id)
        if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]:
            raise ValueError(f"模拟已在运行中: {simulation_id}")

        # Load the simulation configuration
        sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")

        if not os.path.exists(config_path):
            raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口")

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Derive the round count from the configured time window
        time_config = config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)
        total_rounds = int(total_hours * 60 / minutes_per_round)

        state = SimulationRunState(
            simulation_id=simulation_id,
            runner_status=RunnerStatus.STARTING,
            total_rounds=total_rounds,
            total_simulation_hours=total_hours,
            started_at=datetime.now().isoformat(),
        )

        cls._save_run_state(state)

        # Pick which runner script to launch
        if platform == "twitter":
            script_name = "run_twitter_simulation.py"
            state.twitter_running = True
        elif platform == "reddit":
            script_name = "run_reddit_simulation.py"
            state.reddit_running = True
        else:
            script_name = "run_parallel_simulation.py"
            state.twitter_running = True
            state.reddit_running = True

        script_path = os.path.join(sim_dir, script_name)

        if not os.path.exists(script_path):
            raise ValueError(f"脚本不存在: {script_path}")

        # Create the action queue (reserved for future use by the monitor)
        action_queue = Queue()
        cls._action_queues[simulation_id] = action_queue

        # Launch the simulation process
        try:
            # Build the command line
            cmd = [
                sys.executable,  # Python interpreter
                script_path,
                "--config", "simulation_config.json",
                "--action-log", "actions.jsonl",  # action log file
            ]

            # Run with the simulation directory as the working directory.
            # NOTE(review): stdout/stderr are PIPEs that are never drained
            # while the child runs; a chatty child can fill the pipe buffer
            # and block. Consider redirecting to log files instead.
            process = subprocess.Popen(
                cmd,
                cwd=sim_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )

            state.process_pid = process.pid
            state.runner_status = RunnerStatus.RUNNING
            cls._processes[simulation_id] = process
            cls._save_run_state(state)

            # Start the monitor thread (daemon: dies with the server process)
            monitor_thread = threading.Thread(
                target=cls._monitor_simulation,
                args=(simulation_id,),
                daemon=True
            )
            monitor_thread.start()
            cls._monitor_threads[simulation_id] = monitor_thread

            logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}")

        except Exception as e:
            state.runner_status = RunnerStatus.FAILED
            state.error = str(e)
            cls._save_run_state(state)
            raise

        return state

    @classmethod
    def _monitor_simulation(cls, simulation_id: str):
        """Monitor the simulation process and tail-parse its action log (runs in a daemon thread)."""
        sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
        actions_log = os.path.join(sim_dir, "actions.jsonl")

        process = cls._processes.get(simulation_id)
        state = cls.get_run_state(simulation_id)

        if not process or not state:
            return

        # Byte offset of the last line already consumed from the log
        last_position = 0

        try:
            while process.poll() is None:  # process still running
                # Read newly appended action-log lines
                if os.path.exists(actions_log):
                    with open(actions_log, 'r', encoding='utf-8') as f:
                        f.seek(last_position)
                        for line in f:
                            line = line.strip()
                            if line:
                                try:
                                    action_data = json.loads(line)
                                    action = AgentAction(
                                        round_num=action_data.get("round", 0),
                                        timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                                        platform=action_data.get("platform", "unknown"),
                                        agent_id=action_data.get("agent_id", 0),
                                        agent_name=action_data.get("agent_name", ""),
                                        action_type=action_data.get("action_type", ""),
                                        action_args=action_data.get("action_args", {}),
                                        result=action_data.get("result"),
                                        success=action_data.get("success", True),
                                    )
                                    state.add_action(action)

                                    # Advance the round counter
                                    if action.round_num > state.current_round:
                                        state.current_round = action.round_num

                                except json.JSONDecodeError:
                                    # Partially written / malformed line — skip silently
                                    pass
                        last_position = f.tell()

                # Periodically persist the state
                cls._save_run_state(state)
                time.sleep(1)  # poll once a second

            # Process exited
            exit_code = process.returncode

            if exit_code == 0:
                state.runner_status = RunnerStatus.COMPLETED
                state.completed_at = datetime.now().isoformat()
                logger.info(f"模拟完成: {simulation_id}")
            else:
                state.runner_status = RunnerStatus.FAILED
                # Safe to read stderr here: the child has exited
                stderr = process.stderr.read() if process.stderr else ""
                state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
                logger.error(f"模拟失败: {simulation_id}, error={state.error}")

            state.twitter_running = False
            state.reddit_running = False
            cls._save_run_state(state)

        except Exception as e:
            logger.error(f"监控线程异常: {simulation_id}, error={str(e)}")
            state.runner_status = RunnerStatus.FAILED
            state.error = str(e)
            cls._save_run_state(state)

        finally:
            # Clean up registries; the monitor-thread entry is left in place
            cls._processes.pop(simulation_id, None)
            cls._action_queues.pop(simulation_id, None)

    @classmethod
    def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
        """Stop a running (or paused) simulation; terminates, then kills after 10s."""
        state = cls.get_run_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")

        if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]:
            raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}")

        state.runner_status = RunnerStatus.STOPPING
        cls._save_run_state(state)

        # Terminate the child process
        process = cls._processes.get(simulation_id)
        if process:
            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # NOTE(review): no wait() after kill() — the zombie is reaped
                # by the monitor thread's poll(), presumably; verify.
                process.kill()

        state.runner_status = RunnerStatus.STOPPED
        state.twitter_running = False
        state.reddit_running = False
        state.completed_at = datetime.now().isoformat()
        cls._save_run_state(state)

        logger.info(f"模拟已停止: {simulation_id}")
        return state

    @classmethod
    def get_actions(
        cls,
        simulation_id: str,
        limit: int = 100,
        offset: int = 0,
        platform: Optional[str] = None,
        agent_id: Optional[int] = None,
        round_num: Optional[int] = None
    ) -> List[AgentAction]:
        """
        Return action history from actions.jsonl, newest first.

        Args:
            simulation_id: simulation ID
            limit: maximum number of actions to return
            offset: pagination offset
            platform: optional platform filter
            agent_id: optional agent filter
            round_num: optional round filter

        Returns:
            List of AgentAction (empty if the log file does not exist)
        """
        sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
        actions_log = os.path.join(sim_dir, "actions.jsonl")

        if not os.path.exists(actions_log):
            return []

        actions = []

        with open(actions_log, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)

                    # Apply filters before building the object
                    if platform and data.get("platform") != platform:
                        continue
                    if agent_id is not None and data.get("agent_id") != agent_id:
                        continue
                    if round_num is not None and data.get("round") != round_num:
                        continue

                    actions.append(AgentAction(
                        round_num=data.get("round", 0),
                        timestamp=data.get("timestamp", ""),
                        platform=data.get("platform", ""),
                        agent_id=data.get("agent_id", 0),
                        agent_name=data.get("agent_name", ""),
                        action_type=data.get("action_type", ""),
                        action_args=data.get("action_args", {}),
                        result=data.get("result"),
                        success=data.get("success", True),
                    ))

                except json.JSONDecodeError:
                    continue

        # Newest first (the log is append-only, so reversing file order works)
        actions.reverse()

        # Paginate
        return actions[offset:offset + limit]

    @classmethod
    def get_timeline(
        cls,
        simulation_id: str,
        start_round: int = 0,
        end_round: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Return a per-round aggregated timeline of the simulation.

        Args:
            simulation_id: simulation ID
            start_round: first round to include
            end_round: last round to include (inclusive; None = no upper bound)

        Returns:
            One summary dict per round, sorted by round number.
        """
        # NOTE(review): hard cap of 10000 actions; rounds beyond that are dropped.
        actions = cls.get_actions(simulation_id, limit=10000)

        # Group by round number
        rounds: Dict[int, Dict[str, Any]] = {}

        for action in actions:
            round_num = action.round_num

            if round_num < start_round:
                continue
            if end_round is not None and round_num > end_round:
                continue

            if round_num not in rounds:
                rounds[round_num] = {
                    "round_num": round_num,
                    "twitter_actions": 0,
                    "reddit_actions": 0,
                    "active_agents": set(),
                    "action_types": {},
                    "first_action_time": action.timestamp,
                    "last_action_time": action.timestamp,
                }

            r = rounds[round_num]

            if action.platform == "twitter":
                r["twitter_actions"] += 1
            else:
                r["reddit_actions"] += 1

            r["active_agents"].add(action.agent_id)
            r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1
            r["last_action_time"] = action.timestamp

        # Convert to a JSON-friendly sorted list
        result = []
        for round_num in sorted(rounds.keys()):
            r = rounds[round_num]
            result.append({
                "round_num": round_num,
                "twitter_actions": r["twitter_actions"],
                "reddit_actions": r["reddit_actions"],
                "total_actions": r["twitter_actions"] + r["reddit_actions"],
                "active_agents_count": len(r["active_agents"]),
                "active_agents": list(r["active_agents"]),
                "action_types": r["action_types"],
                "first_action_time": r["first_action_time"],
                "last_action_time": r["last_action_time"],
            })

        return result

    @classmethod
    def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
        """
        Return per-agent statistics, sorted by total action count descending.

        Returns:
            List of per-agent stats dicts.
        """
        actions = cls.get_actions(simulation_id, limit=10000)

        agent_stats: Dict[int, Dict[str, Any]] = {}

        for action in actions:
            agent_id = action.agent_id

            if agent_id not in agent_stats:
                agent_stats[agent_id] = {
                    "agent_id": agent_id,
                    "agent_name": action.agent_name,
                    "total_actions": 0,
                    "twitter_actions": 0,
                    "reddit_actions": 0,
                    "action_types": {},
                    "first_action_time": action.timestamp,
                    "last_action_time": action.timestamp,
                }

            stats = agent_stats[agent_id]
            stats["total_actions"] += 1

            if action.platform == "twitter":
                stats["twitter_actions"] += 1
            else:
                stats["reddit_actions"] += 1

            stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1
            stats["last_action_time"] = action.timestamp

        # Sort by total action count, busiest agents first
        result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)

        return result
|
||||
|
||||
386
backend/app/services/zep_entity_reader.py
Normal file
386
backend/app/services/zep_entity_reader.py
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
"""
|
||||
Zep实体读取与过滤服务
|
||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
|
||||
@dataclass
class EntityNode:
    """A graph entity node together with its surrounding edges and neighbour nodes."""
    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges touching this node
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Neighbour nodes reachable through those edges
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Plain-dict representation suitable for JSON serialization."""
        keys = (
            "uuid", "name", "labels", "summary", "attributes",
            "related_edges", "related_nodes",
        )
        return {key: getattr(self, key) for key in keys}

    def get_entity_type(self) -> Optional[str]:
        """Return the first non-default label (ignoring 'Entity'/'Node'), or None."""
        return next(
            (label for label in self.labels if label not in ("Entity", "Node")),
            None,
        )
|
||||
|
||||
|
||||
@dataclass
class FilteredEntities:
    """Result bundle produced by entity filtering."""
    entities: List[EntityNode]
    entity_types: Set[str]
    total_count: int
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize: type set becomes a list, each entity via its own to_dict()."""
        payload: Dict[str, Any] = {
            "entities": [entity.to_dict() for entity in self.entities],
            "entity_types": list(self.entity_types),
        }
        payload["total_count"] = self.total_count
        payload["filtered_count"] = self.filtered_count
        return payload
|
||||
|
||||
|
||||
class ZepEntityReader:
    """
    Zep entity reading and filtering service.

    Main features:
    1. Read all nodes from a Zep graph
    2. Keep only nodes matching predefined entity types (labels beyond the bare "Entity")
    3. Attach each entity's related edges and neighbour nodes
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create a Zep client using the given key or Config.ZEP_API_KEY.

        Raises:
            ValueError: if no API key is available.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")

        self.client = Zep(api_key=self.api_key)

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Fetch every node of a graph.

        Args:
            graph_id: graph ID

        Returns:
            List of plain node dicts (uuid/name/labels/summary/attributes).
        """
        logger.info(f"获取图谱 {graph_id} 的所有节点...")

        nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)

        nodes_data = []
        for node in nodes:
            nodes_data.append({
                # SDK versions differ on the uuid attribute name — try both
                "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
                "name": node.name or "",
                "labels": node.labels or [],
                "summary": node.summary or "",
                "attributes": node.attributes or {},
            })

        logger.info(f"共获取 {len(nodes_data)} 个节点")
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Fetch every edge of a graph.

        Args:
            graph_id: graph ID

        Returns:
            List of plain edge dicts (uuid/name/fact/source/target/attributes).
        """
        logger.info(f"获取图谱 {graph_id} 的所有边...")

        edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)

        edges_data = []
        for edge in edges:
            edges_data.append({
                "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
                "name": edge.name or "",
                "fact": edge.fact or "",
                "source_node_uuid": edge.source_node_uuid,
                "target_node_uuid": edge.target_node_uuid,
                "attributes": edge.attributes or {},
            })

        logger.info(f"共获取 {len(edges_data)} 条边")
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """
        Fetch all edges attached to one node; returns [] on any API failure.

        Args:
            node_uuid: node UUID

        Returns:
            List of plain edge dicts.
        """
        try:
            edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)

            edges_data = []
            for edge in edges:
                edges_data.append({
                    "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
                    "name": edge.name or "",
                    "fact": edge.fact or "",
                    "source_node_uuid": edge.source_node_uuid,
                    "target_node_uuid": edge.target_node_uuid,
                    "attributes": edge.attributes or {},
                })

            return edges_data
        except Exception as e:
            # Best-effort: a missing edge list should not abort the caller
            logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """
        Keep only nodes that match predefined entity types.

        Filtering rule:
        - Nodes whose labels are only the default "Entity" do not match any
          predefined type and are skipped.
        - Nodes with any label besides "Entity"/"Node" match and are kept.

        Args:
            graph_id: graph ID
            defined_entity_types: optional whitelist of entity types to keep
            enrich_with_edges: whether to attach each entity's related edges

        Returns:
            FilteredEntities: the filtered entity collection.
        """
        logger.info(f"开始筛选图谱 {graph_id} 的实体...")

        # Fetch all nodes
        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        # Fetch all edges up front (used for relation lookups below)
        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []

        # Map node UUID -> node dict for O(1) neighbour lookups
        node_map = {n["uuid"]: n for n in all_nodes}

        # Collect matching entities
        filtered_entities = []
        entity_types_found = set()

        for node in all_nodes:
            labels = node.get("labels", [])

            # Must carry at least one label beyond the defaults
            custom_labels = [l for l in labels if l not in ["Entity", "Node"]]

            if not custom_labels:
                # Only default labels — skip
                continue

            # If a whitelist was given, require a match
            if defined_entity_types:
                matching_labels = [l for l in custom_labels if l in defined_entity_types]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]

            entity_types_found.add(entity_type)

            # Build the entity object
            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )

            # Attach related edges and neighbour nodes.
            # NOTE(review): this is O(nodes * edges); fine for small graphs,
            # consider pre-indexing edges by endpoint for large ones.
            if enrich_with_edges:
                related_edges = []
                related_node_uuids = set()

                for edge in all_edges:
                    if edge["source_node_uuid"] == node["uuid"]:
                        related_edges.append({
                            "direction": "outgoing",
                            "edge_name": edge["name"],
                            "fact": edge["fact"],
                            "target_node_uuid": edge["target_node_uuid"],
                        })
                        related_node_uuids.add(edge["target_node_uuid"])
                    elif edge["target_node_uuid"] == node["uuid"]:
                        related_edges.append({
                            "direction": "incoming",
                            "edge_name": edge["name"],
                            "fact": edge["fact"],
                            "source_node_uuid": edge["source_node_uuid"],
                        })
                        related_node_uuids.add(edge["source_node_uuid"])

                entity.related_edges = related_edges

                # Basic info for every neighbour we saw
                related_nodes = []
                for related_uuid in related_node_uuids:
                    if related_uuid in node_map:
                        related_node = node_map[related_uuid]
                        related_nodes.append({
                            "uuid": related_node["uuid"],
                            "name": related_node["name"],
                            "labels": related_node["labels"],
                            "summary": related_node.get("summary", ""),
                        })

                entity.related_nodes = related_nodes

            filtered_entities.append(entity)

        logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                    f"实体类型: {entity_types_found}")

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """
        Fetch one entity with its full context (edges and neighbour nodes).

        Args:
            graph_id: graph ID
            entity_uuid: entity UUID

        Returns:
            EntityNode, or None on failure / unknown UUID.
        """
        try:
            # Fetch the node itself
            node = self.client.graph.node.get(uuid_=entity_uuid)

            if not node:
                return None

            # Fetch the node's edges
            edges = self.get_node_edges(entity_uuid)

            # Fetch all nodes so neighbours can be resolved by UUID
            all_nodes = self.get_all_nodes(graph_id)
            node_map = {n["uuid"]: n for n in all_nodes}

            # Split edges into outgoing/incoming relative to this node
            related_edges = []
            related_node_uuids = set()

            for edge in edges:
                if edge["source_node_uuid"] == entity_uuid:
                    related_edges.append({
                        "direction": "outgoing",
                        "edge_name": edge["name"],
                        "fact": edge["fact"],
                        "target_node_uuid": edge["target_node_uuid"],
                    })
                    related_node_uuids.add(edge["target_node_uuid"])
                else:
                    related_edges.append({
                        "direction": "incoming",
                        "edge_name": edge["name"],
                        "fact": edge["fact"],
                        "source_node_uuid": edge["source_node_uuid"],
                    })
                    related_node_uuids.add(edge["source_node_uuid"])

            # Resolve neighbour node info
            related_nodes = []
            for related_uuid in related_node_uuids:
                if related_uuid in node_map:
                    related_node = node_map[related_uuid]
                    related_nodes.append({
                        "uuid": related_node["uuid"],
                        "name": related_node["name"],
                        "labels": related_node["labels"],
                        "summary": related_node.get("summary", ""),
                    })

            return EntityNode(
                uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
                name=node.name or "",
                labels=node.labels or [],
                summary=node.summary or "",
                attributes=node.attributes or {},
                related_edges=related_edges,
                related_nodes=related_nodes,
            )

        except Exception as e:
            logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """
        Fetch all entities of one type.

        Args:
            graph_id: graph ID
            entity_type: entity type (e.g. "Student", "PublicFigure")
            enrich_with_edges: whether to attach related-edge info

        Returns:
            List of matching entities.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities
|
||||
|
||||
|
||||
238
backend/app/utils/retry.py
Normal file
238
backend/app/utils/retry.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""
|
||||
API调用重试机制
|
||||
用于处理LLM等外部API调用的重试逻辑
|
||||
"""
|
||||
|
||||
import time
|
||||
import random
|
||||
import functools
|
||||
from typing import Callable, Any, Optional, Type, Tuple
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.retry')
|
||||
|
||||
|
||||
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Decorator that retries a callable with exponential backoff.

    Args:
        max_retries: maximum number of retries after the first attempt
        initial_delay: initial delay in seconds
        max_delay: delay ceiling in seconds
        backoff_factor: multiplier applied to the delay after each failure
        jitter: whether to randomize each delay (0.5x–1.5x)
        exceptions: exception types that trigger a retry
        on_retry: optional callback invoked as (exception, retry_count)

    Usage:
        @retry_with_backoff(max_retries=3)
        def call_llm_api():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            delay = initial_delay
            attempt = 0

            while True:
                try:
                    return func(*args, **kwargs)

                except exceptions as e:
                    # Out of retries: log and propagate the last failure
                    if attempt >= max_retries:
                        logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
                        raise

                    # Compute the next sleep, capped and optionally jittered
                    current_delay = min(delay, max_delay)
                    if jitter:
                        current_delay = current_delay * (0.5 + random.random())

                    logger.warning(
                        f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
                        f"{current_delay:.1f}秒后重试..."
                    )

                    if on_retry:
                        on_retry(e, attempt + 1)

                    time.sleep(current_delay)
                    delay *= backoff_factor
                    attempt += 1

        return wrapper
    return decorator
|
||||
|
||||
|
||||
def retry_with_backoff_async(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Async variant of the exponential-backoff retry decorator.

    Same parameters as retry_with_backoff; sleeps via asyncio.sleep so the
    event loop is not blocked between attempts.
    """
    import asyncio

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            delay = initial_delay
            attempt = 0

            while True:
                try:
                    return await func(*args, **kwargs)

                except exceptions as e:
                    # Out of retries: log and propagate the last failure
                    if attempt >= max_retries:
                        logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
                        raise

                    current_delay = min(delay, max_delay)
                    if jitter:
                        current_delay = current_delay * (0.5 + random.random())

                    logger.warning(
                        f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
                        f"{current_delay:.1f}秒后重试..."
                    )

                    if on_retry:
                        on_retry(e, attempt + 1)

                    await asyncio.sleep(current_delay)
                    delay *= backoff_factor
                    attempt += 1

        return wrapper
    return decorator
|
||||
|
||||
|
||||
class RetryableAPIClient:
    """Wrapper that retries external API calls with jittered exponential backoff.

    Unlike the decorator helpers, retry parameters are carried by the instance
    so one configured client can be reused across many call sites.
    """

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        backoff_factor: float = 2.0
    ):
        # Retry policy shared by every call made through this client.
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor

    def call_with_retry(
        self,
        func: Callable,
        *args,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        **kwargs
    ) -> Any:
        """Invoke *func* and retry on failure per the client's policy.

        Args:
            func: Callable to invoke.
            *args: Positional arguments for *func*.
            exceptions: Exception types that trigger a retry.
            **kwargs: Keyword arguments for *func*.

        Returns:
            Whatever *func* returns on its first successful invocation.

        Raises:
            The last caught exception once ``max_retries`` is exhausted.
        """
        attempt = 0
        base_delay = self.initial_delay
        while True:
            try:
                return func(*args, **kwargs)
            except exceptions as err:
                if attempt == self.max_retries:
                    logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(err)}")
                    raise

                # Jitter is always applied here (unlike the decorator's opt-in flag).
                wait = min(base_delay, self.max_delay)
                wait = wait * (0.5 + random.random())

                logger.warning(
                    f"API调用第 {attempt + 1} 次尝试失败: {str(err)}, "
                    f"{wait:.1f}秒后重试..."
                )

                time.sleep(wait)
                base_delay *= self.backoff_factor
                attempt += 1

    def call_batch_with_retry(
        self,
        items: list,
        process_func: Callable,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        continue_on_failure: bool = True
    ) -> Tuple[list, list]:
        """Process each item with per-item retries, collecting failures.

        Args:
            items: Items to process.
            process_func: Callable applied to each single item.
            exceptions: Exception types that trigger per-item retries.
            continue_on_failure: When False, re-raise on the first item that
                exhausts its retries instead of moving on.

        Returns:
            A ``(successes, failures)`` pair, where each failure is a dict
            with ``index``, ``item`` and ``error`` keys.
        """
        successes: list = []
        failed: list = []

        for position, element in enumerate(items):
            try:
                successes.append(
                    self.call_with_retry(
                        process_func,
                        element,
                        exceptions=exceptions
                    )
                )
            except Exception as err:
                logger.error(f"处理第 {position + 1} 项失败: {str(err)}")
                failed.append({
                    "index": position,
                    "item": element,
                    "error": str(err)
                })
                if not continue_on_failure:
                    raise

        return successes, failed
|
||||
|
||||
|
|
@ -20,3 +20,7 @@ pydantic>=2.0.0
|
|||
# 文件处理
|
||||
werkzeug>=3.0.0
|
||||
|
||||
# OASIS社交媒体模拟框架
|
||||
oasis-ai>=0.1.0
|
||||
camel-ai>=0.2.0
|
||||
|
||||
|
|
|
|||
138
backend/scripts/action_logger.py
Normal file
138
backend/scripts/action_logger.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
动作日志记录器
|
||||
用于记录OASIS模拟中每个Agent的动作,供后端监控使用
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class ActionLogger:
    """Append-only JSONL logger for agent actions during an OASIS simulation.

    Every record is written as one JSON object per line to ``log_path`` so the
    backend can tail the file to monitor agent activity while a simulation runs.
    """

    def __init__(self, log_path: str):
        """
        Args:
            log_path: Path of the log file (.jsonl format); the parent
                directory is created on demand.
        """
        self.log_path = log_path
        self._ensure_dir()

    def _ensure_dir(self):
        """Create the log file's parent directory if it does not exist."""
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    def _append(self, entry: Dict[str, Any]) -> None:
        """Serialize *entry* as a single JSON line and append it to the log.

        Centralizes the open/serialize/write boilerplate that was previously
        repeated by every ``log_*`` method.
        """
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    def log_action(
        self,
        round_num: int,
        platform: str,
        agent_id: int,
        agent_name: str,
        action_type: str,
        action_args: Optional[Dict[str, Any]] = None,
        result: Optional[str] = None,
        success: bool = True
    ):
        """Record a single agent action.

        Args:
            round_num: Simulation round number.
            platform: Platform name (twitter/reddit).
            agent_id: Agent ID.
            agent_name: Agent display name.
            action_type: Action type identifier.
            action_args: Action arguments, if any.
            result: Execution result, if any.
            success: Whether the action succeeded.
        """
        self._append({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "agent_id": agent_id,
            "agent_name": agent_name,
            "action_type": action_type,
            "action_args": action_args or {},
            "result": result,
            "success": success,
        })

    def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
        """Record the start of a simulation round."""
        self._append({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_start",
            "simulated_hour": simulated_hour,
        })

    def log_round_end(self, round_num: int, actions_count: int, platform: str):
        """Record the end of a simulation round."""
        self._append({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_end",
            "actions_count": actions_count,
        })

    def log_simulation_start(self, platform: str, config: Dict[str, Any]):
        """Record the start of a simulation.

        NOTE(review): ``total_rounds`` assumes 30-minute rounds
        (``total_simulation_hours * 2``) — confirm against the runner's
        ``minutes_per_round`` setting.
        """
        self._append({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "simulation_start",
            "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2,
            "agents_count": len(config.get("agent_configs", [])),
        })

    def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
        """Record the end of a simulation."""
        self._append({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "simulation_end",
            "total_rounds": total_rounds,
            "total_actions": total_actions,
        })
|
||||
|
||||
|
||||
# Module-level singleton logger, created lazily on first use.
_global_logger: Optional[ActionLogger] = None


def get_logger(log_path: Optional[str] = None) -> ActionLogger:
    """Return the shared ActionLogger instance.

    Passing a truthy *log_path* rebinds the singleton to that file; otherwise
    the existing instance is reused (falling back to ``actions.jsonl`` in the
    current directory on first call).
    """
    global _global_logger

    if log_path:
        _global_logger = ActionLogger(log_path)
    elif _global_logger is None:
        _global_logger = ActionLogger("actions.jsonl")

    return _global_logger
|
||||
|
||||
503
backend/scripts/run_parallel_simulation.py
Normal file
503
backend/scripts/run_parallel_simulation.py
Normal file
|
|
@ -0,0 +1,503 @@
|
|||
"""
|
||||
OASIS 双平台并行模拟预设脚本
|
||||
同时运行Twitter和Reddit模拟,读取相同的配置文件
|
||||
|
||||
使用方式:
|
||||
python run_parallel_simulation.py --config simulation_config.json [--action-log actions.jsonl]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from action_logger import ActionLogger
|
||||
|
||||
try:
|
||||
from camel.models import ModelFactory
|
||||
from camel.types import ModelPlatformType
|
||||
import oasis
|
||||
from oasis import (
|
||||
ActionType,
|
||||
LLMAction,
|
||||
ManualAction,
|
||||
generate_twitter_agent_graph,
|
||||
generate_reddit_agent_graph
|
||||
)
|
||||
except ImportError as e:
|
||||
print(f"错误: 缺少依赖 {e}")
|
||||
print("请先安装: pip install oasis-ai camel-ai")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Twitter可用动作
|
||||
TWITTER_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.REPOST,
|
||||
ActionType.FOLLOW,
|
||||
ActionType.DO_NOTHING,
|
||||
ActionType.QUOTE_POST,
|
||||
]
|
||||
|
||||
# Reddit可用动作
|
||||
REDDIT_ACTIONS = [
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.DISLIKE_POST,
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.CREATE_COMMENT,
|
||||
ActionType.LIKE_COMMENT,
|
||||
ActionType.DISLIKE_COMMENT,
|
||||
ActionType.SEARCH_POSTS,
|
||||
ActionType.SEARCH_USER,
|
||||
ActionType.TREND,
|
||||
ActionType.REFRESH,
|
||||
ActionType.DO_NOTHING,
|
||||
ActionType.FOLLOW,
|
||||
ActionType.MUTE,
|
||||
]
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Dict[str, Any]:
    """Read and parse the simulation configuration JSON file."""
    with open(config_path, encoding='utf-8') as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def create_model(config: Dict[str, Any]):
    """Build the camel-ai LLM model shared by all agents.

    OASIS uses camel-ai's ModelFactory, which is configured through the
    environment:
    - plain OpenAI: set only the OPENAI_API_KEY environment variable
    - custom endpoint: set both OPENAI_API_KEY and OPENAI_API_BASE_URL
    """
    base_url = config.get("llm_base_url", "")

    # A configured base URL is propagated via the environment, since the
    # factory reads the endpoint from there rather than from arguments.
    if base_url:
        os.environ["OPENAI_API_BASE_URL"] = base_url

    return ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=config.get("llm_model", "gpt-4o-mini"),
    )
|
||||
|
||||
|
||||
def get_active_agents_for_round(
    env,
    config: Dict[str, Any],
    current_hour: int,
    round_num: int
) -> List:
    """Choose which agents act this round, based on hour-of-day activity.

    A per-round quota is drawn between the configured min/max and scaled by
    a peak / off-peak multiplier; each agent whose active hours include
    *current_hour* then passes an individual activity-level coin flip, and
    the quota is sampled from the survivors.

    Returns:
        A list of ``(agent_id, agent)`` pairs; agents that cannot be resolved
        from ``env.agent_graph`` are silently dropped (best-effort).
    """
    time_cfg = config.get("time_config", {})
    agent_cfgs = config.get("agent_configs", [])

    lo = time_cfg.get("agents_per_hour_min", 5)
    hi = time_cfg.get("agents_per_hour_max", 20)

    peak = time_cfg.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
    off_peak = time_cfg.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

    if current_hour in peak:
        factor = time_cfg.get("peak_activity_multiplier", 1.5)
    elif current_hour in off_peak:
        factor = time_cfg.get("off_peak_activity_multiplier", 0.3)
    else:
        factor = 1.0

    quota = int(random.uniform(lo, hi) * factor)

    # Agents awake at this hour that pass their activity-level coin flip.
    eligible = []
    for cfg in agent_cfgs:
        aid = cfg.get("agent_id", 0)
        awake_hours = cfg.get("active_hours", list(range(8, 23)))
        activity = cfg.get("activity_level", 0.5)

        if current_hour not in awake_hours:
            continue

        if random.random() < activity:
            eligible.append(aid)

    chosen = random.sample(
        eligible,
        min(quota, len(eligible))
    ) if eligible else []

    picked = []
    for aid in chosen:
        try:
            picked.append((aid, env.agent_graph.get_agent(aid)))
        except Exception:
            # Best-effort: skip ids the graph cannot resolve.
            pass

    return picked
|
||||
|
||||
|
||||
async def run_twitter_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Run the Twitter half of the parallel simulation.

    Loads agent profiles, builds the OASIS Twitter environment, replays the
    configured initial posts, then steps the environment round by round
    until the configured simulated duration is reached.

    Args:
        config: Parsed simulation_config.json contents.
        simulation_dir: Directory holding profiles and the SQLite DB.
        action_logger: Optional ActionLogger receiving per-action records.
    """
    print("[Twitter] 初始化...")

    model = create_model(config)

    # OASIS Twitter expects profiles in CSV format.
    profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
    if not os.path.exists(profile_path):
        print(f"[Twitter] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=TWITTER_ACTIONS,
    )

    # Map agent ids to display names for logging.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Start from a fresh database on every run.
    db_path = os.path.join(simulation_dir, "twitter_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
    )

    await env.reset()
    print("[Twitter] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start("twitter", config)

    total_actions = 0

    # Seed the timeline with the configured initial posts (round 0).
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                initial_actions[agent] = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )

                if action_logger:
                    # Truncate long content in the log to keep lines small.
                    action_logger.log_action(
                        round_num=0,
                        platform="twitter",
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                total_actions += 1
            except Exception:
                # Best-effort: skip posts whose poster id cannot be resolved.
                pass

        if initial_actions:
            await env.step(initial_actions)
            print(f"[Twitter] 已发布 {len(initial_actions)} 条初始帖子")

    # Main simulation loop: one round == minutes_per_round of simulated time.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()

    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )

        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, "twitter")

        # Every selected agent takes one LLM-decided action this round.
        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        # Log one generic LLM_ACTION record per active agent; the concrete
        # action chosen by the LLM is not surfaced here.
        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform="twitter",
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), "twitter")

        # Progress report every 20 rounds.
        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[Twitter] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()

    if action_logger:
        action_logger.log_simulation_end("twitter", total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[Twitter] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
|
||||
async def run_reddit_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Run the Reddit half of the parallel simulation.

    Mirrors run_twitter_simulation but uses JSON profiles, the Reddit
    platform type, and merges multiple initial posts by the same agent into
    a list of ManualActions.

    Args:
        config: Parsed simulation_config.json contents.
        simulation_dir: Directory holding profiles and the SQLite DB.
        action_logger: Optional ActionLogger receiving per-action records.
    """
    print("[Reddit] 初始化...")

    model = create_model(config)

    # OASIS Reddit expects profiles in JSON format.
    profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
    if not os.path.exists(profile_path):
        print(f"[Reddit] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_reddit_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=REDDIT_ACTIONS,
    )

    # Map agent ids to display names for logging.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Start from a fresh database on every run.
    db_path = os.path.join(simulation_dir, "reddit_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.REDDIT,
        database_path=db_path,
    )

    await env.reset()
    print("[Reddit] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start("reddit", config)

    total_actions = 0

    # Seed the board with the configured initial posts (round 0).
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                # An agent may author several initial posts; promote its
                # entry to a list of ManualActions on the second post.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    ))
                else:
                    initial_actions[agent] = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )

                if action_logger:
                    # Truncate long content in the log to keep lines small.
                    action_logger.log_action(
                        round_num=0,
                        platform="reddit",
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                total_actions += 1
            except Exception:
                # Best-effort: skip posts whose poster id cannot be resolved.
                pass

        if initial_actions:
            await env.step(initial_actions)
            print(f"[Reddit] 已发布 {len(initial_actions)} 条初始帖子")

    # Main simulation loop: one round == minutes_per_round of simulated time.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()

    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )

        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, "reddit")

        # Every selected agent takes one LLM-decided action this round.
        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        # Log one generic LLM_ACTION record per active agent.
        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform="reddit",
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), "reddit")

        # Progress report every 20 rounds.
        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[Reddit] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()

    if action_logger:
        action_logger.log_simulation_end("reddit", total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[Reddit] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse arguments, load the config, and run one or
    both platform simulations (in parallel by default) sharing a single
    action log."""
    parser = argparse.ArgumentParser(description='OASIS双平台并行模拟')
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    parser.add_argument(
        '--twitter-only',
        action='store_true',
        help='只运行Twitter模拟'
    )
    parser.add_argument(
        '--reddit-only',
        action='store_true',
        help='只运行Reddit模拟'
    )
    parser.add_argument(
        '--action-log',
        type=str,
        default='actions.jsonl',
        help='动作日志文件路径 (默认: actions.jsonl)'
    )

    args = parser.parse_args()

    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    config = load_config(args.config)
    # Profiles/DB/log files live alongside the config file.
    simulation_dir = os.path.dirname(args.config) or "."

    # One action logger is shared by both platform runs.
    action_log_path = os.path.join(simulation_dir, args.action_log)
    action_logger = ActionLogger(action_log_path)

    print("=" * 60)
    print("OASIS 双平台并行模拟")
    print(f"配置文件: {args.config}")
    print(f"模拟ID: {config.get('simulation_id', 'unknown')}")
    print(f"动作日志: {action_log_path}")
    print("=" * 60)

    time_config = config.get("time_config", {})
    print(f"\n模拟参数:")
    print(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
    print(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
    print(f" - Agent数量: {len(config.get('agent_configs', []))}")

    # Echo the LLM-generated reasoning behind this configuration, if present.
    reasoning = config.get("generation_reasoning", "")
    if reasoning:
        print(f"\nLLM配置推理:")
        print(f" {reasoning[:500]}..." if len(reasoning) > 500 else f" {reasoning}")

    print("\n" + "=" * 60)

    start_time = datetime.now()

    if args.twitter_only:
        await run_twitter_simulation(config, simulation_dir, action_logger)
    elif args.reddit_only:
        await run_reddit_simulation(config, simulation_dir, action_logger)
    else:
        # Run both platforms concurrently, sharing the same action_logger.
        await asyncio.gather(
            run_twitter_simulation(config, simulation_dir, action_logger),
            run_reddit_simulation(config, simulation_dir, action_logger),
        )

    total_elapsed = (datetime.now() - start_time).total_seconds()
    print("\n" + "=" * 60)
    print(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
    print(f"动作日志已保存到: {action_log_path}")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
298
backend/scripts/run_reddit_simulation.py
Normal file
298
backend/scripts/run_reddit_simulation.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
"""
|
||||
OASIS Reddit模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
使用方式:
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
try:
|
||||
from camel.models import ModelFactory
|
||||
from camel.types import ModelPlatformType
|
||||
import oasis
|
||||
from oasis import (
|
||||
ActionType,
|
||||
LLMAction,
|
||||
ManualAction,
|
||||
generate_reddit_agent_graph
|
||||
)
|
||||
except ImportError as e:
|
||||
print(f"错误: 缺少依赖 {e}")
|
||||
print("请先安装: pip install oasis-ai camel-ai")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class RedditSimulationRunner:
    """Standalone Reddit simulation runner driven by a JSON config file."""

    # Actions the LLM may choose from on the Reddit platform.
    AVAILABLE_ACTIONS = [
        ActionType.LIKE_POST,
        ActionType.DISLIKE_POST,
        ActionType.CREATE_POST,
        ActionType.CREATE_COMMENT,
        ActionType.LIKE_COMMENT,
        ActionType.DISLIKE_COMMENT,
        ActionType.SEARCH_POSTS,
        ActionType.SEARCH_USER,
        ActionType.TREND,
        ActionType.REFRESH,
        ActionType.DO_NOTHING,
        ActionType.FOLLOW,
        ActionType.MUTE,
    ]

    def __init__(self, config_path: str):
        """Initialize the runner.

        Args:
            config_path: Path to the configuration file
                (simulation_config.json); profiles and the database are
                expected in the same directory.
        """
        self.config_path = config_path
        self.config = self._load_config()
        self.simulation_dir = os.path.dirname(config_path)

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the configuration file."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Return the agent profile path (OASIS Reddit uses JSON profiles)."""
        return os.path.join(self.simulation_dir, "reddit_profiles.json")

    def _get_db_path(self) -> str:
        """Return the SQLite database path for this simulation."""
        return os.path.join(self.simulation_dir, "reddit_simulation.db")

    def _create_model(self):
        """Create the LLM model via camel-ai's ModelFactory.

        Configuration is environment-driven:
        - plain OpenAI: set only the OPENAI_API_KEY environment variable
        - custom endpoint: set OPENAI_API_KEY and OPENAI_API_BASE_URL
        """
        import os

        llm_model = self.config.get("llm_model", "gpt-4o-mini")
        llm_base_url = self.config.get("llm_base_url", "")

        # Propagate a configured base URL through the environment, since
        # the factory reads the endpoint from there.
        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url

        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """Decide which agents are active this round.

        Draws a per-round quota between the configured min/max scaled by a
        peak / off-peak multiplier, filters agents by their active hours and
        an activity-level coin flip, then samples the quota from survivors.
        Returns a list of (agent_id, agent) pairs.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])

        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)

        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0

        target_count = int(random.uniform(base_min, base_max) * multiplier)

        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)

            if current_hour not in active_hours:
                continue

            if random.random() < activity_level:
                candidates.append(agent_id)

        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []

        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                # Best-effort: skip ids the graph cannot resolve.
                pass

        return active_agents

    async def run(self):
        """Run the full Reddit simulation from config to completion."""
        print("=" * 60)
        print("OASIS Reddit模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print("=" * 60)

        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)
        total_rounds = (total_hours * 60) // minutes_per_round

        print(f"\n模拟参数:")
        print(f" - 总模拟时长: {total_hours}小时")
        print(f" - 每轮时间: {minutes_per_round}分钟")
        print(f" - 总轮数: {total_rounds}")
        print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")

        print("\n初始化LLM模型...")
        model = self._create_model()

        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return

        agent_graph = await generate_reddit_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )

        # Start from a fresh database on every run.
        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)
            print(f"已删除旧数据库: {db_path}")

        print("创建OASIS环境...")
        env = oasis.make(
            agent_graph=agent_graph,
            platform=oasis.DefaultPlatformType.REDDIT,
            database_path=db_path,
        )

        await env.reset()
        print("环境初始化完成\n")

        # Seed the board with the configured initial posts.
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])

        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            initial_actions = {}
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = env.agent_graph.get_agent(agent_id)
                    # An agent may author several initial posts; promote its
                    # entry to a list of ManualActions on the second post.
                    if agent in initial_actions:
                        if not isinstance(initial_actions[agent], list):
                            initial_actions[agent] = [initial_actions[agent]]
                        initial_actions[agent].append(ManualAction(
                            action_type=ActionType.CREATE_POST,
                            action_args={"content": content}
                        ))
                    else:
                        initial_actions[agent] = ManualAction(
                            action_type=ActionType.CREATE_POST,
                            action_args={"content": content}
                        )
                except Exception as e:
                    print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")

            if initial_actions:
                await env.step(initial_actions)
                print(f" 已发布 {len(initial_actions)} 条初始帖子")

        # Main loop: one round == minutes_per_round of simulated time.
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            active_agents = self._get_active_agents_for_round(
                env, simulated_hour, round_num
            )

            if not active_agents:
                continue

            # Every selected agent takes one LLM-decided action this round.
            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }

            await env.step(actions)

            # Progress report on the first round and every 10th thereafter.
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        await env.close()

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print(f"\n模拟完成!")
        print(f" - 总耗时: {total_elapsed:.1f}秒")
        print(f" - 数据库: {db_path}")
        print("=" * 60)
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse --config, validate it, and run the simulation."""
    parser = argparse.ArgumentParser(description='OASIS Reddit模拟')
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    args = parser.parse_args()

    # Bail out early when the config file is missing.
    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    await RedditSimulationRunner(args.config).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
313
backend/scripts/run_twitter_simulation.py
Normal file
313
backend/scripts/run_twitter_simulation.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""
|
||||
OASIS Twitter模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
使用方式:
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
try:
|
||||
from camel.models import ModelFactory
|
||||
from camel.types import ModelPlatformType
|
||||
import oasis
|
||||
from oasis import (
|
||||
ActionType,
|
||||
LLMAction,
|
||||
ManualAction,
|
||||
generate_twitter_agent_graph
|
||||
)
|
||||
except ImportError as e:
|
||||
print(f"错误: 缺少依赖 {e}")
|
||||
print("请先安装: pip install oasis-ai camel-ai")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class TwitterSimulationRunner:
    """Twitter simulation runner: drives an OASIS Twitter environment from a JSON config file."""

    # Action types agents may take on the Twitter platform; passed to the
    # agent-graph builder so the LLM only chooses among these.
    AVAILABLE_ACTIONS = [
        ActionType.CREATE_POST,
        ActionType.LIKE_POST,
        ActionType.REPOST,
        ActionType.FOLLOW,
        ActionType.DO_NOTHING,
        ActionType.QUOTE_POST,
    ]

    def __init__(self, config_path: str):
        """
        Initialize the simulation runner.

        Args:
            config_path: Path to the configuration file (simulation_config.json).
                         Profile CSV and SQLite DB are resolved relative to its directory.
        """
        self.config_path = config_path
        self.config = self._load_config()
        # All simulation artifacts live next to the config file.
        self.simulation_dir = os.path.dirname(config_path)

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the JSON configuration file."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Return the agent-profile file path (OASIS Twitter expects CSV format)."""
        return os.path.join(self.simulation_dir, "twitter_profiles.csv")

    def _get_db_path(self) -> str:
        """Return the SQLite database path used by the OASIS environment."""
        return os.path.join(self.simulation_dir, "twitter_simulation.db")

    def _create_model(self):
        """
        Create the LLM model via camel-ai's ModelFactory.

        Configuration:
        - Standard OpenAI: only the OPENAI_API_KEY environment variable is needed.
        - Custom endpoint: set OPENAI_API_KEY and OPENAI_API_BASE_URL.

        The config key ``llm_model`` maps to ``model_type``.
        """
        import os  # NOTE(review): redundant — `os` is already imported at module level

        llm_model = self.config.get("llm_model", "gpt-4o-mini")
        llm_base_url = self.config.get("llm_base_url", "")

        # If a base URL is configured, export it as an environment variable
        # (OASIS/camel reads the endpoint from the environment).
        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url

        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """
        Decide which agents are activated this round, based on time-of-day and per-agent config.

        Args:
            env: OASIS environment (used to resolve agent ids to agent objects).
            current_hour: Current simulated hour (0-23).
            round_num: Current round number (currently unused in the selection logic).

        Returns:
            List of ``(agent_id, agent)`` tuples for the activated agents.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])

        # Base activation count range per hour.
        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)

        # Scale activity by time of day (peak vs off-peak hours).
        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0

        target_count = int(random.uniform(base_min, base_max) * multiplier)

        # Build the candidate pool: an agent is a candidate only if the current
        # hour is in its active window AND a Bernoulli draw on its activity
        # level succeeds.
        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)

            # Skip agents outside their configured active hours.
            if current_hour not in active_hours:
                continue

            # Probabilistic activation by configured activity level.
            if random.random() < activity_level:
                candidates.append(agent_id)

        # Randomly pick up to target_count agents from the candidate pool.
        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []

        # Resolve ids to agent objects; silently drop ids the graph doesn't know.
        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                pass

        return active_agents

    async def run(self):
        """Run the full Twitter simulation: setup, seed events, main loop, teardown."""
        print("=" * 60)
        print("OASIS Twitter模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print("=" * 60)

        # Load timing configuration.
        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)

        # Total number of rounds (integer division: a trailing partial round is dropped).
        total_rounds = (total_hours * 60) // minutes_per_round

        print(f"\n模拟参数:")
        print(f" - 总模拟时长: {total_hours}小时")
        print(f" - 每轮时间: {minutes_per_round}分钟")
        print(f" - 总轮数: {total_rounds}")
        print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")

        # Create the LLM model.
        print("\n初始化LLM模型...")
        model = self._create_model()

        # Load the agent graph from the profile CSV.
        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return

        agent_graph = await generate_twitter_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )

        # Remove any stale database so each run starts from a clean state.
        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)
            print(f"已删除旧数据库: {db_path}")

        # Create the OASIS environment.
        print("创建OASIS环境...")
        env = oasis.make(
            agent_graph=agent_graph,
            platform=oasis.DefaultPlatformType.TWITTER,
            database_path=db_path,
        )

        await env.reset()
        print("环境初始化完成\n")

        # Seed the simulation with configured initial posts (manual actions).
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])

        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            initial_actions = {}
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = env.agent_graph.get_agent(agent_id)
                    # NOTE: dict is keyed by agent object, so at most one
                    # initial post per agent survives here.
                    initial_actions[agent] = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )
                except Exception as e:
                    print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")

            if initial_actions:
                await env.step(initial_actions)
                print(f" 已发布 {len(initial_actions)} 条初始帖子")

        # Main simulation loop.
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            # Derive the simulated clock (day / hour) from the round index.
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            # Pick the agents to activate this round.
            active_agents = self._get_active_agents_for_round(
                env, simulated_hour, round_num
            )

            if not active_agents:
                continue

            # Every activated agent takes an LLM-decided action.
            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }

            # Execute the round.
            await env.step(actions)

            # Progress report on the first round and every 10th round.
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        # Tear down the environment.
        await env.close()

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print(f"\n模拟完成!")
        print(f" - 总耗时: {total_elapsed:.1f}秒")
        print(f" - 数据库: {db_path}")
        print("=" * 60)
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse arguments, validate the config path, then run the simulation."""
    arg_parser = argparse.ArgumentParser(description='OASIS Twitter模拟')
    arg_parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)',
    )
    cli_args = arg_parser.parse_args()

    # Bail out early with a clear message when the config file is missing.
    if not os.path.exists(cli_args.config):
        print(f"错误: 配置文件不存在: {cli_args.config}")
        sys.exit(1)

    await TwitterSimulationRunner(cli_args.config).run()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
166
backend/scripts/test_profile_format.py
Normal file
166
backend/scripts/test_profile_format.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""
|
||||
测试Profile格式生成是否符合OASIS要求
|
||||
验证:
|
||||
1. Twitter Profile生成CSV格式
|
||||
2. Reddit Profile生成JSON详细格式
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import csv
|
||||
import tempfile
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||
|
||||
|
||||
def test_profile_formats():
    """Generate sample profiles and verify the Twitter CSV and Reddit JSON output formats."""
    print("=" * 60)
    print("OASIS Profile格式测试")
    print("=" * 60)

    # Build two representative profiles: an individual user and an
    # institutional account (the latter omits the demographic fields).
    test_profiles = [
        OasisAgentProfile(
            user_id=0,
            user_name="test_user_123",
            name="Test User",
            bio="A test user for validation",
            persona="Test User is an enthusiastic participant in social discussions.",
            karma=1500,
            friend_count=100,
            follower_count=200,
            statuses_count=500,
            age=25,
            gender="male",
            mbti="INTJ",
            country="China",
            profession="Student",
            interested_topics=["Technology", "Education"],
            source_entity_uuid="test-uuid-123",
            source_entity_type="Student",
        ),
        OasisAgentProfile(
            user_id=1,
            user_name="org_official_456",
            name="Official Organization",
            bio="Official account for Organization",
            persona="This is an official institutional account that communicates official positions.",
            karma=5000,
            friend_count=50,
            follower_count=10000,
            statuses_count=200,
            profession="Organization",
            interested_topics=["Public Policy", "Announcements"],
            source_entity_uuid="test-uuid-456",
            source_entity_type="University",
        ),
    ]

    # Bare instance via __new__: skips __init__ on purpose so only the
    # private save methods are exercised, without whatever setup __init__
    # performs (presumably DB/graph access — TODO confirm).
    generator = OasisProfileGenerator.__new__(OasisProfileGenerator)

    # Write into a throwaway directory that is cleaned up automatically.
    with tempfile.TemporaryDirectory() as temp_dir:
        twitter_path = os.path.join(temp_dir, "twitter_profiles.csv")
        reddit_path = os.path.join(temp_dir, "reddit_profiles.json")

        # --- Twitter CSV format ---
        print("\n1. 测试Twitter Profile (CSV格式)")
        print("-" * 40)
        generator._save_twitter_csv(test_profiles, twitter_path)

        # Read the CSV back and inspect it.
        with open(twitter_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            rows = list(reader)

        print(f" 文件: {twitter_path}")
        print(f" 行数: {len(rows)}")
        print(f" 表头: {list(rows[0].keys())}")
        print(f"\n 示例数据 (第1行):")
        for key, value in rows[0].items():
            print(f" {key}: {value}")

        # Verify the columns OASIS requires are present.
        required_twitter_fields = ['user_id', 'user_name', 'name', 'bio',
                                   'friend_count', 'follower_count', 'statuses_count', 'created_at']
        missing = set(required_twitter_fields) - set(rows[0].keys())
        if missing:
            print(f"\n [错误] 缺少字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")

        # --- Reddit JSON (detailed) format ---
        print("\n2. 测试Reddit Profile (JSON详细格式)")
        print("-" * 40)
        generator._save_reddit_json(test_profiles, reddit_path)

        # Read the JSON back and inspect it.
        with open(reddit_path, 'r', encoding='utf-8') as f:
            reddit_data = json.load(f)

        print(f" 文件: {reddit_path}")
        print(f" 条目数: {len(reddit_data)}")
        print(f" 字段: {list(reddit_data[0].keys())}")
        print(f"\n 示例数据 (第1条):")
        print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4))

        # Verify required and report optional fields of the detailed format.
        required_reddit_fields = ['realname', 'username', 'bio', 'persona']
        optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics']

        missing = set(required_reddit_fields) - set(reddit_data[0].keys())
        if missing:
            print(f"\n [错误] 缺少必需字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")

        present_optional = set(optional_reddit_fields) & set(reddit_data[0].keys())
        print(f" [信息] 可选字段: {present_optional}")

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)
|
||||
|
||||
|
||||
def show_expected_formats():
    """Print reference examples of the profile formats that OASIS expects."""
    heavy_rule = "=" * 60
    light_rule = "-" * 40

    print("\n" + heavy_rule)
    print("OASIS 期望的Profile格式参考")
    print(heavy_rule)

    # Example 1: Twitter profiles as flat CSV rows.
    print("\n1. Twitter Profile (CSV格式)")
    print(light_rule)
    csv_sample = "\n".join([
        "user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at",
        "0,user0,User Zero,I am user zero with interests in technology.,100,150,500,2023-01-01",
        "1,user1,User One,Tech enthusiast and coffee lover.,200,250,1000,2023-01-02",
    ])
    print(csv_sample)

    # Example 2: Reddit profiles as a JSON array of detailed records.
    print("\n2. Reddit Profile (JSON详细格式)")
    print(light_rule)
    json_sample = [
        {
            "realname": "James Miller",
            "username": "millerhospitality",
            "bio": "Passionate about hospitality & tourism.",
            "persona": "James is a seasoned professional in the Hospitality & Tourism industry...",
            "age": 40,
            "gender": "male",
            "mbti": "ESTJ",
            "country": "UK",
            "profession": "Hospitality & Tourism",
            "interested_topics": ["Economics", "Business"],
        }
    ]
    print(json.dumps(json_sample, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
# Script entry point: run the format checks, then print the reference formats.
if __name__ == "__main__":
    test_profile_formats()
    show_expected_formats()
|
||||
|
||||
|
||||
Loading…
Reference in a new issue