Enhance backend functionality with OASIS simulation features
- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings. - Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms. - Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management. - Implemented a robust retry mechanism for external API calls to improve system stability. - Enhanced task management model to include detailed progress tracking. - Added logging capabilities for action tracking during simulations. - Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
parent
c60e6e1089
commit
5f159f6d88
19 changed files with 7202 additions and 49 deletions
1442
backend/README.md
1442
backend/README.md
File diff suppressed because it is too large
Load diff
|
|
@ -46,8 +46,9 @@ def create_app(config_class=Config):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# 注册蓝图
|
# 注册蓝图
|
||||||
from .api import graph_bp
|
from .api import graph_bp, simulation_bp
|
||||||
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
||||||
|
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
||||||
|
|
||||||
# 健康检查
|
# 健康检查
|
||||||
@app.route('/health')
|
@app.route('/health')
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ API路由模块
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
||||||
graph_bp = Blueprint('graph', __name__)
|
graph_bp = Blueprint('graph', __name__)
|
||||||
|
simulation_bp = Blueprint('simulation', __name__)
|
||||||
|
|
||||||
from . import graph # noqa: E402, F401
|
from . import graph # noqa: E402, F401
|
||||||
|
from . import simulation # noqa: E402, F401
|
||||||
|
|
||||||
|
|
|
||||||
1330
backend/app/api/simulation.py
Normal file
1330
backend/app/api/simulation.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -41,6 +41,20 @@ class Config:
|
||||||
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
|
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
|
||||||
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
|
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
|
||||||
|
|
||||||
|
# OASIS模拟配置
|
||||||
|
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
||||||
|
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
||||||
|
|
||||||
|
# OASIS平台可用动作配置
|
||||||
|
OASIS_TWITTER_ACTIONS = [
|
||||||
|
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
||||||
|
]
|
||||||
|
OASIS_REDDIT_ACTIONS = [
|
||||||
|
'LIKE_POST', 'DISLIKE_POST', 'CREATE_POST', 'CREATE_COMMENT',
|
||||||
|
'LIKE_COMMENT', 'DISLIKE_COMMENT', 'SEARCH_POSTS', 'SEARCH_USER',
|
||||||
|
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls):
|
def validate(cls):
|
||||||
"""验证必要配置"""
|
"""验证必要配置"""
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,12 @@ class Task:
|
||||||
status: TaskStatus
|
status: TaskStatus
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
progress: int = 0 # 进度百分比 0-100
|
progress: int = 0 # 总进度百分比 0-100
|
||||||
message: str = "" # 状态消息
|
message: str = "" # 状态消息
|
||||||
result: Optional[Dict] = None # 任务结果
|
result: Optional[Dict] = None # 任务结果
|
||||||
error: Optional[str] = None # 错误信息
|
error: Optional[str] = None # 错误信息
|
||||||
metadata: Dict = field(default_factory=dict) # 额外元数据
|
metadata: Dict = field(default_factory=dict) # 额外元数据
|
||||||
|
progress_detail: Dict = field(default_factory=dict) # 详细进度信息
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典"""
|
"""转换为字典"""
|
||||||
|
|
@ -43,6 +44,7 @@ class Task:
|
||||||
"updated_at": self.updated_at.isoformat(),
|
"updated_at": self.updated_at.isoformat(),
|
||||||
"progress": self.progress,
|
"progress": self.progress,
|
||||||
"message": self.message,
|
"message": self.message,
|
||||||
|
"progress_detail": self.progress_detail,
|
||||||
"result": self.result,
|
"result": self.result,
|
||||||
"error": self.error,
|
"error": self.error,
|
||||||
"metadata": self.metadata,
|
"metadata": self.metadata,
|
||||||
|
|
@ -108,7 +110,8 @@ class TaskManager:
|
||||||
progress: Optional[int] = None,
|
progress: Optional[int] = None,
|
||||||
message: Optional[str] = None,
|
message: Optional[str] = None,
|
||||||
result: Optional[Dict] = None,
|
result: Optional[Dict] = None,
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None,
|
||||||
|
progress_detail: Optional[Dict] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
更新任务状态
|
更新任务状态
|
||||||
|
|
@ -120,6 +123,7 @@ class TaskManager:
|
||||||
message: 消息
|
message: 消息
|
||||||
result: 结果
|
result: 结果
|
||||||
error: 错误信息
|
error: 错误信息
|
||||||
|
progress_detail: 详细进度信息
|
||||||
"""
|
"""
|
||||||
with self._task_lock:
|
with self._task_lock:
|
||||||
task = self._tasks.get(task_id)
|
task = self._tasks.get(task_id)
|
||||||
|
|
@ -135,6 +139,8 @@ class TaskManager:
|
||||||
task.result = result
|
task.result = result
|
||||||
if error is not None:
|
if error is not None:
|
||||||
task.error = error
|
task.error = error
|
||||||
|
if progress_detail is not None:
|
||||||
|
task.progress_detail = progress_detail
|
||||||
|
|
||||||
def complete_task(self, task_id: str, result: Dict):
|
def complete_task(self, task_id: str, result: Dict):
|
||||||
"""标记任务完成"""
|
"""标记任务完成"""
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,47 @@
|
||||||
from .ontology_generator import OntologyGenerator
|
from .ontology_generator import OntologyGenerator
|
||||||
from .graph_builder import GraphBuilderService
|
from .graph_builder import GraphBuilderService
|
||||||
from .text_processor import TextProcessor
|
from .text_processor import TextProcessor
|
||||||
|
from .zep_entity_reader import ZepEntityReader, EntityNode, FilteredEntities
|
||||||
|
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||||
|
from .simulation_manager import SimulationManager, SimulationState, SimulationStatus
|
||||||
|
from .simulation_config_generator import (
|
||||||
|
SimulationConfigGenerator,
|
||||||
|
SimulationParameters,
|
||||||
|
AgentActivityConfig,
|
||||||
|
TimeSimulationConfig,
|
||||||
|
EventConfig,
|
||||||
|
PlatformConfig
|
||||||
|
)
|
||||||
|
from .simulation_runner import (
|
||||||
|
SimulationRunner,
|
||||||
|
SimulationRunState,
|
||||||
|
RunnerStatus,
|
||||||
|
AgentAction,
|
||||||
|
RoundSummary
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ['OntologyGenerator', 'GraphBuilderService', 'TextProcessor']
|
__all__ = [
|
||||||
|
'OntologyGenerator',
|
||||||
|
'GraphBuilderService',
|
||||||
|
'TextProcessor',
|
||||||
|
'ZepEntityReader',
|
||||||
|
'EntityNode',
|
||||||
|
'FilteredEntities',
|
||||||
|
'OasisProfileGenerator',
|
||||||
|
'OasisAgentProfile',
|
||||||
|
'SimulationManager',
|
||||||
|
'SimulationState',
|
||||||
|
'SimulationStatus',
|
||||||
|
'SimulationConfigGenerator',
|
||||||
|
'SimulationParameters',
|
||||||
|
'AgentActivityConfig',
|
||||||
|
'TimeSimulationConfig',
|
||||||
|
'EventConfig',
|
||||||
|
'PlatformConfig',
|
||||||
|
'SimulationRunner',
|
||||||
|
'SimulationRunState',
|
||||||
|
'RunnerStatus',
|
||||||
|
'AgentAction',
|
||||||
|
'RoundSummary',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
|
||||||
561
backend/app/services/oasis_profile_generator.py
Normal file
561
backend/app/services/oasis_profile_generator.py
Normal file
|
|
@ -0,0 +1,561 @@
|
||||||
|
"""
|
||||||
|
OASIS Agent Profile生成器
|
||||||
|
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from ..config import Config
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.oasis_profile')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OasisAgentProfile:
|
||||||
|
"""OASIS Agent Profile数据结构"""
|
||||||
|
# 通用字段
|
||||||
|
user_id: int
|
||||||
|
user_name: str
|
||||||
|
name: str
|
||||||
|
bio: str
|
||||||
|
persona: str
|
||||||
|
|
||||||
|
# 可选字段 - Reddit风格
|
||||||
|
karma: int = 1000
|
||||||
|
|
||||||
|
# 可选字段 - Twitter风格
|
||||||
|
friend_count: int = 100
|
||||||
|
follower_count: int = 150
|
||||||
|
statuses_count: int = 500
|
||||||
|
|
||||||
|
# 额外人设信息
|
||||||
|
age: Optional[int] = None
|
||||||
|
gender: Optional[str] = None
|
||||||
|
mbti: Optional[str] = None
|
||||||
|
country: Optional[str] = None
|
||||||
|
profession: Optional[str] = None
|
||||||
|
interested_topics: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# 来源实体信息
|
||||||
|
source_entity_uuid: Optional[str] = None
|
||||||
|
source_entity_type: Optional[str] = None
|
||||||
|
|
||||||
|
created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
|
||||||
|
|
||||||
|
def to_reddit_format(self) -> Dict[str, Any]:
|
||||||
|
"""转换为Reddit平台格式"""
|
||||||
|
profile = {
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"user_name": self.user_name,
|
||||||
|
"name": self.name,
|
||||||
|
"bio": self.bio,
|
||||||
|
"persona": self.persona,
|
||||||
|
"karma": self.karma,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加额外人设信息(如果有)
|
||||||
|
if self.age:
|
||||||
|
profile["age"] = self.age
|
||||||
|
if self.gender:
|
||||||
|
profile["gender"] = self.gender
|
||||||
|
if self.mbti:
|
||||||
|
profile["mbti"] = self.mbti
|
||||||
|
if self.country:
|
||||||
|
profile["country"] = self.country
|
||||||
|
if self.profession:
|
||||||
|
profile["profession"] = self.profession
|
||||||
|
if self.interested_topics:
|
||||||
|
profile["interested_topics"] = self.interested_topics
|
||||||
|
|
||||||
|
return profile
|
||||||
|
|
||||||
|
def to_twitter_format(self) -> Dict[str, Any]:
|
||||||
|
"""转换为Twitter平台格式"""
|
||||||
|
profile = {
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"user_name": self.user_name,
|
||||||
|
"name": self.name,
|
||||||
|
"bio": self.bio,
|
||||||
|
"persona": self.persona,
|
||||||
|
"friend_count": self.friend_count,
|
||||||
|
"follower_count": self.follower_count,
|
||||||
|
"statuses_count": self.statuses_count,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加额外人设信息
|
||||||
|
if self.age:
|
||||||
|
profile["age"] = self.age
|
||||||
|
if self.gender:
|
||||||
|
profile["gender"] = self.gender
|
||||||
|
if self.mbti:
|
||||||
|
profile["mbti"] = self.mbti
|
||||||
|
if self.country:
|
||||||
|
profile["country"] = self.country
|
||||||
|
if self.profession:
|
||||||
|
profile["profession"] = self.profession
|
||||||
|
if self.interested_topics:
|
||||||
|
profile["interested_topics"] = self.interested_topics
|
||||||
|
|
||||||
|
return profile
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为完整字典格式"""
|
||||||
|
return {
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"user_name": self.user_name,
|
||||||
|
"name": self.name,
|
||||||
|
"bio": self.bio,
|
||||||
|
"persona": self.persona,
|
||||||
|
"karma": self.karma,
|
||||||
|
"friend_count": self.friend_count,
|
||||||
|
"follower_count": self.follower_count,
|
||||||
|
"statuses_count": self.statuses_count,
|
||||||
|
"age": self.age,
|
||||||
|
"gender": self.gender,
|
||||||
|
"mbti": self.mbti,
|
||||||
|
"country": self.country,
|
||||||
|
"profession": self.profession,
|
||||||
|
"interested_topics": self.interested_topics,
|
||||||
|
"source_entity_uuid": self.source_entity_uuid,
|
||||||
|
"source_entity_type": self.source_entity_type,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OasisProfileGenerator:
|
||||||
|
"""
|
||||||
|
OASIS Profile生成器
|
||||||
|
|
||||||
|
将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile
|
||||||
|
"""
|
||||||
|
|
||||||
|
# MBTI类型列表
|
||||||
|
MBTI_TYPES = [
|
||||||
|
"INTJ", "INTP", "ENTJ", "ENTP",
|
||||||
|
"INFJ", "INFP", "ENFJ", "ENFP",
|
||||||
|
"ISTJ", "ISFJ", "ESTJ", "ESFJ",
|
||||||
|
"ISTP", "ISFP", "ESTP", "ESFP"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 常见国家列表
|
||||||
|
COUNTRIES = [
|
||||||
|
"China", "US", "UK", "Japan", "Germany", "France",
|
||||||
|
"Canada", "Australia", "Brazil", "India", "South Korea"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
):
|
||||||
|
self.api_key = api_key or Config.LLM_API_KEY
|
||||||
|
self.base_url = base_url or Config.LLM_BASE_URL
|
||||||
|
self.model_name = model_name or Config.LLM_MODEL_NAME
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("LLM_API_KEY 未配置")
|
||||||
|
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=self.api_key,
|
||||||
|
base_url=self.base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_profile_from_entity(
|
||||||
|
self,
|
||||||
|
entity: EntityNode,
|
||||||
|
user_id: int,
|
||||||
|
use_llm: bool = True
|
||||||
|
) -> OasisAgentProfile:
|
||||||
|
"""
|
||||||
|
从Zep实体生成OASIS Agent Profile
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: Zep实体节点
|
||||||
|
user_id: 用户ID(用于OASIS)
|
||||||
|
use_llm: 是否使用LLM生成详细人设
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OasisAgentProfile
|
||||||
|
"""
|
||||||
|
entity_type = entity.get_entity_type() or "Entity"
|
||||||
|
|
||||||
|
# 基础信息
|
||||||
|
name = entity.name
|
||||||
|
user_name = self._generate_username(name)
|
||||||
|
|
||||||
|
# 构建上下文信息
|
||||||
|
context = self._build_entity_context(entity)
|
||||||
|
|
||||||
|
if use_llm:
|
||||||
|
# 使用LLM生成详细人设
|
||||||
|
profile_data = self._generate_profile_with_llm(
|
||||||
|
entity_name=name,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_summary=entity.summary,
|
||||||
|
entity_attributes=entity.attributes,
|
||||||
|
context=context
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 使用规则生成基础人设
|
||||||
|
profile_data = self._generate_profile_rule_based(
|
||||||
|
entity_name=name,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_summary=entity.summary,
|
||||||
|
entity_attributes=entity.attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
return OasisAgentProfile(
|
||||||
|
user_id=user_id,
|
||||||
|
user_name=user_name,
|
||||||
|
name=name,
|
||||||
|
bio=profile_data.get("bio", f"{entity_type}: {name}"),
|
||||||
|
persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."),
|
||||||
|
karma=profile_data.get("karma", random.randint(500, 5000)),
|
||||||
|
friend_count=profile_data.get("friend_count", random.randint(50, 500)),
|
||||||
|
follower_count=profile_data.get("follower_count", random.randint(100, 1000)),
|
||||||
|
statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)),
|
||||||
|
age=profile_data.get("age"),
|
||||||
|
gender=profile_data.get("gender"),
|
||||||
|
mbti=profile_data.get("mbti"),
|
||||||
|
country=profile_data.get("country"),
|
||||||
|
profession=profile_data.get("profession"),
|
||||||
|
interested_topics=profile_data.get("interested_topics", []),
|
||||||
|
source_entity_uuid=entity.uuid,
|
||||||
|
source_entity_type=entity_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_username(self, name: str) -> str:
|
||||||
|
"""生成用户名"""
|
||||||
|
# 移除特殊字符,转换为小写
|
||||||
|
username = name.lower().replace(" ", "_")
|
||||||
|
username = ''.join(c for c in username if c.isalnum() or c == '_')
|
||||||
|
|
||||||
|
# 添加随机后缀避免重复
|
||||||
|
suffix = random.randint(100, 999)
|
||||||
|
return f"{username}_{suffix}"
|
||||||
|
|
||||||
|
def _build_entity_context(self, entity: EntityNode) -> str:
|
||||||
|
"""构建实体的上下文信息"""
|
||||||
|
context_parts = []
|
||||||
|
|
||||||
|
# 添加相关边信息
|
||||||
|
if entity.related_edges:
|
||||||
|
relationships = []
|
||||||
|
for edge in entity.related_edges[:10]: # 最多取10条
|
||||||
|
if edge.get("fact"):
|
||||||
|
relationships.append(edge["fact"])
|
||||||
|
|
||||||
|
if relationships:
|
||||||
|
context_parts.append("Related facts:\n" + "\n".join(f"- {r}" for r in relationships))
|
||||||
|
|
||||||
|
# 添加关联节点信息
|
||||||
|
if entity.related_nodes:
|
||||||
|
related_names = [n["name"] for n in entity.related_nodes[:5]]
|
||||||
|
if related_names:
|
||||||
|
context_parts.append(f"Related to: {', '.join(related_names)}")
|
||||||
|
|
||||||
|
return "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
def _generate_profile_with_llm(
|
||||||
|
self,
|
||||||
|
entity_name: str,
|
||||||
|
entity_type: str,
|
||||||
|
entity_summary: str,
|
||||||
|
entity_attributes: Dict[str, Any],
|
||||||
|
context: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""使用LLM生成详细人设"""
|
||||||
|
|
||||||
|
prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes.
|
||||||
|
|
||||||
|
Entity Information:
|
||||||
|
- Name: {entity_name}
|
||||||
|
- Type: {entity_type}
|
||||||
|
- Summary: {entity_summary}
|
||||||
|
- Attributes: {json.dumps(entity_attributes, ensure_ascii=False)}
|
||||||
|
|
||||||
|
Context:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Generate a JSON object with the following fields:
|
||||||
|
{{
|
||||||
|
"bio": "A short bio (max 150 chars) suitable for social media",
|
||||||
|
"persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns",
|
||||||
|
"age": <integer between 18-65, or null if not applicable>,
|
||||||
|
"gender": "<male/female/other, or null if not applicable>",
|
||||||
|
"mbti": "<MBTI type like INTJ, ENFP, etc., or null>",
|
||||||
|
"country": "<country name, or null>",
|
||||||
|
"profession": "<profession/occupation, or null>",
|
||||||
|
"interested_topics": ["topic1", "topic2", ...]
|
||||||
|
}}
|
||||||
|
|
||||||
|
Important:
|
||||||
|
- The profile should be consistent with the entity type and context
|
||||||
|
- Make the persona feel realistic and suitable for social media simulation
|
||||||
|
- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account)
|
||||||
|
- Return ONLY the JSON object, no additional text"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用重试机制调用LLM API
|
||||||
|
from ..utils.retry import RetryableAPIClient
|
||||||
|
|
||||||
|
retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0)
|
||||||
|
|
||||||
|
def call_llm():
|
||||||
|
return self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a profile generator for social media simulation. Generate realistic user profiles based on entity information."},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
response = retry_client.call_with_retry(call_llm)
|
||||||
|
result = json.loads(response.choices[0].message.content)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM生成人设失败(已重试): {str(e)}, 使用规则生成")
|
||||||
|
return self._generate_profile_rule_based(
|
||||||
|
entity_name, entity_type, entity_summary, entity_attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_profile_rule_based(
|
||||||
|
self,
|
||||||
|
entity_name: str,
|
||||||
|
entity_type: str,
|
||||||
|
entity_summary: str,
|
||||||
|
entity_attributes: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""使用规则生成基础人设"""
|
||||||
|
|
||||||
|
# 根据实体类型生成不同的人设
|
||||||
|
entity_type_lower = entity_type.lower()
|
||||||
|
|
||||||
|
if entity_type_lower in ["student", "alumni"]:
|
||||||
|
return {
|
||||||
|
"bio": f"{entity_type} with interests in academics and social issues.",
|
||||||
|
"persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
|
||||||
|
"age": random.randint(18, 30),
|
||||||
|
"gender": random.choice(["male", "female"]),
|
||||||
|
"mbti": random.choice(self.MBTI_TYPES),
|
||||||
|
"country": random.choice(self.COUNTRIES),
|
||||||
|
"profession": "Student",
|
||||||
|
"interested_topics": ["Education", "Social Issues", "Technology"],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif entity_type_lower in ["publicfigure", "expert", "faculty"]:
|
||||||
|
return {
|
||||||
|
"bio": f"Expert and thought leader in their field.",
|
||||||
|
"persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
|
||||||
|
"age": random.randint(35, 60),
|
||||||
|
"gender": random.choice(["male", "female"]),
|
||||||
|
"mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
|
||||||
|
"country": random.choice(self.COUNTRIES),
|
||||||
|
"profession": entity_attributes.get("occupation", "Expert"),
|
||||||
|
"interested_topics": ["Politics", "Economics", "Culture & Society"],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]:
|
||||||
|
return {
|
||||||
|
"bio": f"Official account for {entity_name}. News and updates.",
|
||||||
|
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
||||||
|
"profession": "Media",
|
||||||
|
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]:
|
||||||
|
return {
|
||||||
|
"bio": f"Official account of {entity_name}.",
|
||||||
|
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
||||||
|
"profession": entity_type,
|
||||||
|
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 默认人设
|
||||||
|
return {
|
||||||
|
"bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
|
||||||
|
"persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.",
|
||||||
|
"age": random.randint(25, 50),
|
||||||
|
"gender": random.choice(["male", "female"]),
|
||||||
|
"mbti": random.choice(self.MBTI_TYPES),
|
||||||
|
"country": random.choice(self.COUNTRIES),
|
||||||
|
"profession": entity_type,
|
||||||
|
"interested_topics": ["General", "Social Issues"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_profiles_from_entities(
|
||||||
|
self,
|
||||||
|
entities: List[EntityNode],
|
||||||
|
use_llm: bool = True,
|
||||||
|
progress_callback: Optional[callable] = None
|
||||||
|
) -> List[OasisAgentProfile]:
|
||||||
|
"""
|
||||||
|
批量从实体生成Agent Profile
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entities: 实体列表
|
||||||
|
use_llm: 是否使用LLM生成详细人设
|
||||||
|
progress_callback: 进度回调函数 (current, total, message)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent Profile列表
|
||||||
|
"""
|
||||||
|
profiles = []
|
||||||
|
total = len(entities)
|
||||||
|
|
||||||
|
for idx, entity in enumerate(entities):
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
profile = self.generate_profile_from_entity(
|
||||||
|
entity=entity,
|
||||||
|
user_id=idx,
|
||||||
|
use_llm=use_llm
|
||||||
|
)
|
||||||
|
profiles.append(profile)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
|
||||||
|
# 创建一个基础profile
|
||||||
|
profiles.append(OasisAgentProfile(
|
||||||
|
user_id=idx,
|
||||||
|
user_name=self._generate_username(entity.name),
|
||||||
|
name=entity.name,
|
||||||
|
bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}",
|
||||||
|
persona=entity.summary or f"A participant in social discussions.",
|
||||||
|
source_entity_uuid=entity.uuid,
|
||||||
|
source_entity_type=entity.get_entity_type(),
|
||||||
|
))
|
||||||
|
|
||||||
|
return profiles
|
||||||
|
|
||||||
|
def save_profiles(
|
||||||
|
self,
|
||||||
|
profiles: List[OasisAgentProfile],
|
||||||
|
file_path: str,
|
||||||
|
platform: str = "reddit"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
保存Profile到文件(根据平台选择正确格式)
|
||||||
|
|
||||||
|
OASIS平台格式要求:
|
||||||
|
- Twitter: CSV格式
|
||||||
|
- Reddit: JSON格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
profiles: Profile列表
|
||||||
|
file_path: 文件路径
|
||||||
|
platform: 平台类型 ("reddit" 或 "twitter")
|
||||||
|
"""
|
||||||
|
if platform == "twitter":
|
||||||
|
self._save_twitter_csv(profiles, file_path)
|
||||||
|
else:
|
||||||
|
self._save_reddit_json(profiles, file_path)
|
||||||
|
|
||||||
|
def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||||
|
"""
|
||||||
|
保存Twitter Profile为CSV格式
|
||||||
|
|
||||||
|
OASIS Twitter要求的CSV字段:
|
||||||
|
user_id, user_name, name, bio, friend_count, follower_count, statuses_count, created_at
|
||||||
|
"""
|
||||||
|
import csv
|
||||||
|
|
||||||
|
# 确保文件扩展名是.csv
|
||||||
|
if not file_path.endswith('.csv'):
|
||||||
|
file_path = file_path.replace('.json', '.csv')
|
||||||
|
|
||||||
|
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
|
||||||
|
# 写入表头
|
||||||
|
headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count',
|
||||||
|
'follower_count', 'statuses_count', 'created_at']
|
||||||
|
writer.writerow(headers)
|
||||||
|
|
||||||
|
# 写入数据行
|
||||||
|
for profile in profiles:
|
||||||
|
# bio需要处理换行符和逗号
|
||||||
|
bio = profile.bio.replace('\n', ' ').replace('\r', ' ')
|
||||||
|
row = [
|
||||||
|
profile.user_id,
|
||||||
|
profile.user_name,
|
||||||
|
profile.name,
|
||||||
|
bio,
|
||||||
|
profile.friend_count,
|
||||||
|
profile.follower_count,
|
||||||
|
profile.statuses_count,
|
||||||
|
profile.created_at
|
||||||
|
]
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)")
|
||||||
|
|
||||||
|
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||||
|
"""
|
||||||
|
保存Reddit Profile为JSON格式
|
||||||
|
|
||||||
|
OASIS Reddit支持两种JSON格式:
|
||||||
|
1. 基础格式: user_id, user_name, name, bio, karma, created_at
|
||||||
|
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
||||||
|
|
||||||
|
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
||||||
|
"""
|
||||||
|
data = []
|
||||||
|
for profile in profiles:
|
||||||
|
# 使用详细格式(与用户示例兼容)
|
||||||
|
item = {
|
||||||
|
"realname": profile.name,
|
||||||
|
"username": profile.user_name,
|
||||||
|
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
|
||||||
|
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加人设详情字段
|
||||||
|
if profile.age:
|
||||||
|
item["age"] = profile.age
|
||||||
|
if profile.gender:
|
||||||
|
item["gender"] = profile.gender
|
||||||
|
if profile.mbti:
|
||||||
|
item["mbti"] = profile.mbti
|
||||||
|
if profile.country:
|
||||||
|
item["country"] = profile.country
|
||||||
|
if profile.profession:
|
||||||
|
item["profession"] = profile.profession
|
||||||
|
if profile.interested_topics:
|
||||||
|
item["interested_topics"] = profile.interested_topics
|
||||||
|
|
||||||
|
data.append(item)
|
||||||
|
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
|
||||||
|
|
||||||
|
# 保留旧方法名作为别名,保持向后兼容
|
||||||
|
def save_profiles_to_json(
|
||||||
|
self,
|
||||||
|
profiles: List[OasisAgentProfile],
|
||||||
|
file_path: str,
|
||||||
|
platform: str = "reddit"
|
||||||
|
):
|
||||||
|
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||||
|
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||||
|
self.save_profiles(profiles, file_path, platform)
|
||||||
|
|
||||||
584
backend/app/services/simulation_config_generator.py
Normal file
584
backend/app/services/simulation_config_generator.py
Normal file
|
|
@ -0,0 +1,584 @@
|
||||||
|
"""
|
||||||
|
模拟配置智能生成器
|
||||||
|
使用LLM根据模拟需求、文档内容、图谱信息自动生成细致的模拟参数
|
||||||
|
实现全程自动化,无需人工设置参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from ..config import Config
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.simulation_config')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentActivityConfig:
|
||||||
|
"""单个Agent的活动配置"""
|
||||||
|
agent_id: int
|
||||||
|
entity_uuid: str
|
||||||
|
entity_name: str
|
||||||
|
entity_type: str
|
||||||
|
|
||||||
|
# 活跃度配置 (0.0-1.0)
|
||||||
|
activity_level: float = 0.5 # 整体活跃度
|
||||||
|
|
||||||
|
# 发言频率(每小时预期发言次数)
|
||||||
|
posts_per_hour: float = 1.0
|
||||||
|
comments_per_hour: float = 2.0
|
||||||
|
|
||||||
|
# 活跃时间段(24小时制,0-23)
|
||||||
|
active_hours: List[int] = field(default_factory=lambda: list(range(8, 23)))
|
||||||
|
|
||||||
|
# 响应速度(对热点事件的反应延迟,单位:模拟分钟)
|
||||||
|
response_delay_min: int = 5
|
||||||
|
response_delay_max: int = 60
|
||||||
|
|
||||||
|
# 情感倾向 (-1.0到1.0,负面到正面)
|
||||||
|
sentiment_bias: float = 0.0
|
||||||
|
|
||||||
|
# 立场(对特定话题的态度)
|
||||||
|
stance: str = "neutral" # supportive, opposing, neutral, observer
|
||||||
|
|
||||||
|
# 影响力权重(决定其发言被其他Agent看到的概率)
|
||||||
|
influence_weight: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TimeSimulationConfig:
|
||||||
|
"""时间模拟配置"""
|
||||||
|
# 模拟总时长(模拟小时数)
|
||||||
|
total_simulation_hours: int = 72 # 默认模拟72小时(3天)
|
||||||
|
|
||||||
|
# 每轮代表的时间(模拟分钟)
|
||||||
|
minutes_per_round: int = 30
|
||||||
|
|
||||||
|
# 每小时激活的Agent数量范围
|
||||||
|
agents_per_hour_min: int = 5
|
||||||
|
agents_per_hour_max: int = 20
|
||||||
|
|
||||||
|
# 高峰时段(活跃度提升)
|
||||||
|
peak_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 14, 15, 20, 21, 22])
|
||||||
|
peak_activity_multiplier: float = 1.5
|
||||||
|
|
||||||
|
# 低谷时段(活跃度降低)
|
||||||
|
off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6])
|
||||||
|
off_peak_activity_multiplier: float = 0.3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EventConfig:
|
||||||
|
"""事件配置"""
|
||||||
|
# 初始事件(模拟开始时的触发事件)
|
||||||
|
initial_posts: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
# 定时事件(在特定时间触发的事件)
|
||||||
|
scheduled_events: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
# 热点话题关键词
|
||||||
|
hot_topics: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# 舆论引导方向
|
||||||
|
narrative_direction: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PlatformConfig:
    """Platform-specific configuration."""

    platform: str  # "twitter" or "reddit"

    # Recommendation-algorithm weights (freshness / popularity / relevance).
    recency_weight: float = 0.4
    popularity_weight: float = 0.3
    relevance_weight: float = 0.3

    # Interaction count after which a post starts viral spread.
    viral_threshold: int = 10

    # Echo-chamber effect strength (how much similar views cluster), 0-1.
    echo_chamber_strength: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SimulationParameters:
    """Complete parameter set describing one simulation run."""

    # Basic identifiers.
    simulation_id: str
    project_id: str
    graph_id: str
    simulation_requirement: str

    # Time configuration.
    time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig)

    # Per-agent activity configurations.
    agent_configs: List[AgentActivityConfig] = field(default_factory=list)

    # Event configuration.
    event_config: EventConfig = field(default_factory=EventConfig)

    # Platform configurations (None when the platform is disabled).
    twitter_config: Optional[PlatformConfig] = None
    reddit_config: Optional[PlatformConfig] = None

    # LLM settings recorded for provenance.
    llm_model: str = ""
    llm_base_url: str = ""

    # Generation metadata.
    generated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    generation_reasoning: str = ""  # the LLM's explanation of its choices

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; nested dataclasses become dicts via asdict."""
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "simulation_requirement": self.simulation_requirement,
            "time_config": asdict(self.time_config),
            "agent_configs": [asdict(a) for a in self.agent_configs],
            "event_config": asdict(self.event_config),
            "twitter_config": asdict(self.twitter_config) if self.twitter_config else None,
            "reddit_config": asdict(self.reddit_config) if self.reddit_config else None,
            "llm_model": self.llm_model,
            "llm_base_url": self.llm_base_url,
            "generated_at": self.generated_at,
            "generation_reasoning": self.generation_reasoning,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string, preserving non-ASCII characters."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
|
||||||
|
|
||||||
|
|
||||||
|
class SimulationConfigGenerator:
    """
    Intelligent simulation-config generator.

    Uses an LLM to analyse the simulation requirement, the source document
    and the graph entity information, and produces a best-effort set of
    simulation parameters. Falls back to heuristics when the LLM fails.
    """

    # Maximum number of characters of context sent to the LLM.
    MAX_CONTEXT_LENGTH = 50000

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: Optional[str] = None
    ):
        # Explicit arguments win; otherwise fall back to app config.
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model_name = model_name or Config.LLM_MODEL_NAME

        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    def generate_config(
        self,
        simulation_id: str,
        project_id: str,
        graph_id: str,
        simulation_requirement: str,
        document_text: str,
        entities: List[EntityNode],
        enable_twitter: bool = True,
        enable_reddit: bool = True,
    ) -> SimulationParameters:
        """
        Generate a complete simulation configuration.

        Args:
            simulation_id: Simulation ID.
            project_id: Project ID.
            graph_id: Graph ID.
            simulation_requirement: Description of what to simulate.
            document_text: Raw source document text.
            entities: Filtered entity list.
            enable_twitter: Whether Twitter is enabled.
            enable_reddit: Whether Reddit is enabled.

        Returns:
            SimulationParameters: the full parameter set.
        """
        logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}")

        # 1. Build the context (truncated to MAX_CONTEXT_LENGTH characters).
        context = self._build_context(
            simulation_requirement=simulation_requirement,
            document_text=document_text,
            entities=entities
        )

        # 2. Ask the LLM for the configuration (falls back to defaults on failure).
        llm_result = self._generate_config_with_llm(
            context=context,
            entities=entities,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit
        )

        # 3. Materialize the SimulationParameters object.
        params = self._build_parameters(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            simulation_requirement=simulation_requirement,
            entities=entities,
            llm_result=llm_result,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit
        )

        logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置")

        return params

    def _build_context(
        self,
        simulation_requirement: str,
        document_text: str,
        entities: List[EntityNode]
    ) -> str:
        """Assemble the LLM context, truncating the document to fit the budget."""

        # Entity summary, grouped by type.
        entity_summary = self._summarize_entities(entities)

        # Fixed parts of the context.
        context_parts = [
            f"## 模拟需求\n{simulation_requirement}",
            f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}",
        ]

        current_length = sum(len(p) for p in context_parts)
        # Keep a 500-character safety margin for the document section header.
        remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500

        if remaining_length > 0 and document_text:
            doc_text = document_text[:remaining_length]
            if len(document_text) > remaining_length:
                doc_text += "\n...(文档已截断)"
            context_parts.append(f"\n## 原始文档内容\n{doc_text}")

        return "\n".join(context_parts)

    def _summarize_entities(self, entities: List[EntityNode]) -> str:
        """Produce a grouped, truncated markdown summary of the entities."""
        lines = []

        # Group entities by their declared type.
        by_type: Dict[str, List[EntityNode]] = {}
        for e in entities:
            t = e.get_entity_type() or "Unknown"
            if t not in by_type:
                by_type[t] = []
            by_type[t].append(e)

        for entity_type, type_entities in by_type.items():
            lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
            for e in type_entities[:10]:  # at most 10 entities shown per type
                summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary
                lines.append(f"- {e.name}: {summary_preview}")
            if len(type_entities) > 10:
                lines.append(f" ... 还有 {len(type_entities) - 10} 个")

        return "\n".join(lines)

    def _generate_config_with_llm(
        self,
        context: str,
        entities: List[EntityNode],
        enable_twitter: bool,
        enable_reddit: bool
    ) -> Dict[str, Any]:
        """Call the LLM to produce the config JSON; fall back to defaults on error."""

        # Entity roster handed to the LLM for per-agent configuration.
        entity_list = []
        for i, e in enumerate(entities):
            entity_list.append({
                "agent_id": i,
                "entity_uuid": e.uuid,
                "entity_name": e.name,
                "entity_type": e.get_entity_type() or "Unknown",
                "summary": e.summary[:200] if e.summary else ""
            })

        prompt = f"""你是一个社交媒体舆论模拟专家。请根据以下信息,生成详细的模拟参数配置。

{context}

## 实体列表(需要为每个实体生成活动配置)
```json
{json.dumps(entity_list, ensure_ascii=False, indent=2)}
```

## 任务
请生成一个JSON配置,包含以下部分:

1. **time_config** - 时间模拟配置
- total_simulation_hours: 模拟总时长(小时),根据事件性质决定(短期热点24-72小时,长期舆论168-336小时)
- minutes_per_round: 每轮代表的时间(分钟),建议15-60
- agents_per_hour_min/max: 每小时激活的Agent数量范围
- peak_hours: 高峰时段列表(0-23)
- off_peak_hours: 低谷时段列表

2. **agent_configs** - 每个Agent的活动配置(必须为每个实体生成)
对于每个agent_id,设置:
- activity_level: 活跃度(0.0-1.0),官方机构通常0.1-0.3,媒体0.3-0.5,个人0.5-0.9
- posts_per_hour: 每小时发帖频率,官方机构0.05-0.2,媒体0.5-2,个人0.1-1
- comments_per_hour: 每小时评论频率
- active_hours: 活跃时间段列表,官方通常工作时间,个人更分散
- response_delay_min/max: 响应延迟(模拟分钟),官方较慢(30-180),个人较快(1-30)
- sentiment_bias: 情感倾向(-1到1),根据实体立场设置
- stance: 立场(supportive/opposing/neutral/observer)
- influence_weight: 影响力权重,知名人物和媒体较高

3. **event_config** - 事件配置
- initial_posts: 初始帖子列表,包含content和poster_agent_id
- hot_topics: 热点话题关键词列表
- narrative_direction: 舆论发展方向描述

4. **platform_configs** - 平台配置(如果启用)
- viral_threshold: 病毒传播阈值
- echo_chamber_strength: 回声室效应强度(0-1)

5. **reasoning** - 你的推理说明,解释为什么这样设置参数

## 重要原则
- 官方机构(University、GovernmentAgency)发言频率低但影响力大
- 媒体(MediaOutlet)发言频率中等,传播速度快
- 个人(Student、PublicFigure)发言频率高但影响力分散
- 根据模拟需求判断各实体的立场和情感倾向
- 时间配置要符合真实社交媒体的使用规律

请返回JSON格式,不要包含markdown代码块标记。"""

        try:
            # Call the LLM API through the retry wrapper.
            from ..utils.retry import RetryableAPIClient

            retry_client = RetryableAPIClient(max_retries=3, initial_delay=2.0, max_delay=60.0)

            def call_llm():
                return self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "system",
                            "content": "你是社交媒体舆论模拟专家,擅长设计真实的模拟参数。返回纯JSON格式,不要markdown。"
                        },
                        {"role": "user", "content": prompt}
                    ],
                    response_format={"type": "json_object"},
                    temperature=0.7,
                    max_tokens=8000
                )

            response = retry_client.call_with_retry(call_llm)
            result = json.loads(response.choices[0].message.content)
            logger.info(f"LLM配置生成成功")
            return result

        except Exception as e:
            logger.error(f"LLM配置生成失败(已重试): {str(e)}")
            # Fall back to the heuristic default configuration.
            return self._generate_default_config(entities)

    def _generate_default_config(self, entities: List[EntityNode]) -> Dict[str, Any]:
        """Heuristic fallback configuration used when the LLM call fails."""
        agent_configs = []

        for i, e in enumerate(entities):
            entity_type = (e.get_entity_type() or "Unknown").lower()

            # Pick defaults by entity type: institutions post rarely but carry
            # weight; media post often; individuals are most active.
            if entity_type in ["university", "governmentagency", "ngo"]:
                config = {
                    "agent_id": i,
                    "activity_level": 0.2,
                    "posts_per_hour": 0.1,
                    "comments_per_hour": 0.05,
                    "active_hours": list(range(9, 18)),
                    "response_delay_min": 60,
                    "response_delay_max": 240,
                    "sentiment_bias": 0.0,
                    "stance": "neutral",
                    "influence_weight": 3.0
                }
            elif entity_type in ["mediaoutlet"]:
                config = {
                    "agent_id": i,
                    "activity_level": 0.6,
                    "posts_per_hour": 1.0,
                    "comments_per_hour": 0.5,
                    "active_hours": list(range(6, 24)),
                    "response_delay_min": 5,
                    "response_delay_max": 30,
                    "sentiment_bias": 0.0,
                    "stance": "observer",
                    "influence_weight": 2.5
                }
            elif entity_type in ["publicfigure", "expert"]:
                config = {
                    "agent_id": i,
                    "activity_level": 0.5,
                    "posts_per_hour": 0.3,
                    "comments_per_hour": 0.5,
                    "active_hours": list(range(8, 23)),
                    "response_delay_min": 10,
                    "response_delay_max": 60,
                    "sentiment_bias": 0.0,
                    "stance": "neutral",
                    "influence_weight": 2.0
                }
            else:  # Student, Person, etc.
                config = {
                    "agent_id": i,
                    "activity_level": 0.7,
                    "posts_per_hour": 0.5,
                    "comments_per_hour": 1.0,
                    "active_hours": list(range(7, 24)),
                    "response_delay_min": 1,
                    "response_delay_max": 20,
                    "sentiment_bias": 0.0,
                    "stance": "neutral",
                    "influence_weight": 1.0
                }

            agent_configs.append(config)

        return {
            "time_config": {
                "total_simulation_hours": 72,
                "minutes_per_round": 30,
                "agents_per_hour_min": max(1, len(entities) // 10),
                "agents_per_hour_max": max(5, len(entities) // 3),
                "peak_hours": [9, 10, 11, 14, 15, 20, 21, 22],
                "off_peak_hours": [0, 1, 2, 3, 4, 5]
            },
            "agent_configs": agent_configs,
            "event_config": {
                "initial_posts": [],
                "hot_topics": [],
                "narrative_direction": ""
            },
            "reasoning": "使用默认配置(LLM生成失败)"
        }

    def _build_parameters(
        self,
        simulation_id: str,
        project_id: str,
        graph_id: str,
        simulation_requirement: str,
        entities: List[EntityNode],
        llm_result: Dict[str, Any],
        enable_twitter: bool,
        enable_reddit: bool
    ) -> SimulationParameters:
        """Build a SimulationParameters object from the LLM (or fallback) result."""

        # Time configuration (every key individually defaulted).
        time_cfg = llm_result.get("time_config", {})
        time_config = TimeSimulationConfig(
            total_simulation_hours=time_cfg.get("total_simulation_hours", 72),
            minutes_per_round=time_cfg.get("minutes_per_round", 30),
            agents_per_hour_min=time_cfg.get("agents_per_hour_min", 5),
            agents_per_hour_max=time_cfg.get("agents_per_hour_max", 20),
            peak_hours=time_cfg.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]),
            off_peak_hours=time_cfg.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
            peak_activity_multiplier=time_cfg.get("peak_activity_multiplier", 1.5),
            off_peak_activity_multiplier=time_cfg.get("off_peak_activity_multiplier", 0.3)
        )

        # One AgentActivityConfig per entity; missing LLM entries fall back
        # to neutral defaults so every entity always gets a config.
        agent_configs = []
        llm_agent_configs = {cfg["agent_id"]: cfg for cfg in llm_result.get("agent_configs", [])}

        for i, entity in enumerate(entities):
            cfg = llm_agent_configs.get(i, {})

            agent_config = AgentActivityConfig(
                agent_id=i,
                entity_uuid=entity.uuid,
                entity_name=entity.name,
                entity_type=entity.get_entity_type() or "Unknown",
                activity_level=cfg.get("activity_level", 0.5),
                posts_per_hour=cfg.get("posts_per_hour", 0.5),
                comments_per_hour=cfg.get("comments_per_hour", 1.0),
                active_hours=cfg.get("active_hours", list(range(8, 23))),
                response_delay_min=cfg.get("response_delay_min", 5),
                response_delay_max=cfg.get("response_delay_max", 60),
                sentiment_bias=cfg.get("sentiment_bias", 0.0),
                stance=cfg.get("stance", "neutral"),
                influence_weight=cfg.get("influence_weight", 1.0)
            )
            agent_configs.append(agent_config)

        # Event configuration.
        event_cfg = llm_result.get("event_config", {})
        event_config = EventConfig(
            initial_posts=event_cfg.get("initial_posts", []),
            scheduled_events=event_cfg.get("scheduled_events", []),
            hot_topics=event_cfg.get("hot_topics", []),
            narrative_direction=event_cfg.get("narrative_direction", "")
        )

        # Platform configurations (only built for enabled platforms).
        twitter_config = None
        reddit_config = None

        platform_cfgs = llm_result.get("platform_configs", {})

        if enable_twitter:
            tw_cfg = platform_cfgs.get("twitter", {})
            twitter_config = PlatformConfig(
                platform="twitter",
                recency_weight=tw_cfg.get("recency_weight", 0.4),
                popularity_weight=tw_cfg.get("popularity_weight", 0.3),
                relevance_weight=tw_cfg.get("relevance_weight", 0.3),
                viral_threshold=tw_cfg.get("viral_threshold", 10),
                echo_chamber_strength=tw_cfg.get("echo_chamber_strength", 0.5)
            )

        if enable_reddit:
            rd_cfg = platform_cfgs.get("reddit", {})
            reddit_config = PlatformConfig(
                platform="reddit",
                recency_weight=rd_cfg.get("recency_weight", 0.3),
                popularity_weight=rd_cfg.get("popularity_weight", 0.4),
                relevance_weight=rd_cfg.get("relevance_weight", 0.3),
                viral_threshold=rd_cfg.get("viral_threshold", 15),
                echo_chamber_strength=rd_cfg.get("echo_chamber_strength", 0.6)
            )

        return SimulationParameters(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            simulation_requirement=simulation_requirement,
            time_config=time_config,
            agent_configs=agent_configs,
            event_config=event_config,
            twitter_config=twitter_config,
            reddit_config=reddit_config,
            llm_model=self.model_name,
            llm_base_url=self.base_url,
            generation_reasoning=llm_result.get("reasoning", "")
        )
|
||||||
|
|
||||||
|
|
||||||
546
backend/app/services/simulation_manager.py
Normal file
546
backend/app/services/simulation_manager.py
Normal file
|
|
@ -0,0 +1,546 @@
|
||||||
|
"""
|
||||||
|
OASIS模拟管理器
|
||||||
|
管理Twitter和Reddit双平台并行模拟
|
||||||
|
使用预设脚本 + LLM智能生成配置参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from ..config import Config
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from .zep_entity_reader import ZepEntityReader, FilteredEntities
|
||||||
|
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||||
|
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.simulation')
|
||||||
|
|
||||||
|
|
||||||
|
class SimulationStatus(str, Enum):
    """Lifecycle states of a simulation."""
    CREATED = "created"        # record created, nothing prepared yet
    PREPARING = "preparing"    # preparation pipeline in progress
    READY = "ready"            # profiles/config/scripts prepared
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformType(str, Enum):
    """Supported social-media platform types."""
    TWITTER = "twitter"
    REDDIT = "reddit"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SimulationState:
    """Persistent runtime state of one simulation."""
    simulation_id: str
    project_id: str
    graph_id: str

    # Platform enablement flags.
    enable_twitter: bool = True
    enable_reddit: bool = True

    # Lifecycle status.
    status: SimulationStatus = SimulationStatus.CREATED

    # Data produced during the preparation phase.
    entities_count: int = 0
    profiles_count: int = 0
    entity_types: List[str] = field(default_factory=list)

    # Config-generation info.
    config_generated: bool = False
    config_reasoning: str = ""

    # Runtime data.
    current_round: int = 0
    twitter_status: str = "not_started"
    reddit_status: str = "not_started"

    # Timestamps (ISO-8601 strings).
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    # Last error message, if any.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Full state dict (internal use / persistence)."""
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "enable_twitter": self.enable_twitter,
            "enable_reddit": self.enable_reddit,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": self.entity_types,
            "config_generated": self.config_generated,
            "config_reasoning": self.config_reasoning,
            "current_round": self.current_round,
            "twitter_status": self.twitter_status,
            "reddit_status": self.reddit_status,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "error": self.error,
        }

    def to_simple_dict(self) -> Dict[str, Any]:
        """Reduced state dict (returned by the API)."""
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": self.entity_types,
            "config_generated": self.config_generated,
            "error": self.error,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class SimulationManager:
|
||||||
|
"""
|
||||||
|
模拟管理器
|
||||||
|
|
||||||
|
核心功能:
|
||||||
|
1. 从Zep图谱读取实体并过滤
|
||||||
|
2. 生成OASIS Agent Profile
|
||||||
|
3. 使用LLM智能生成模拟配置参数
|
||||||
|
4. 准备预设脚本所需的所有文件
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 模拟数据存储目录
|
||||||
|
SIMULATION_DATA_DIR = os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
'../../uploads/simulations'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预设脚本目录
|
||||||
|
SCRIPTS_DIR = os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
'../../scripts'
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
    """Create the manager, making sure the storage directory exists."""
    # In-memory cache of simulation states, keyed by simulation_id.
    self._simulations: Dict[str, SimulationState] = {}

    # Ensure the on-disk storage root is present.
    os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
def _get_simulation_dir(self, simulation_id: str) -> str:
    """Return the per-simulation data directory, creating it on demand."""
    path = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
    os.makedirs(path, exist_ok=True)
    return path
|
||||||
|
|
||||||
|
def _save_simulation_state(self, state: SimulationState):
    """Persist *state* to <sim_dir>/state.json and refresh the memory cache."""
    # Touch the modification timestamp before serializing.
    state.updated_at = datetime.now().isoformat()

    target = os.path.join(self._get_simulation_dir(state.simulation_id), "state.json")
    with open(target, 'w', encoding='utf-8') as fh:
        json.dump(state.to_dict(), fh, ensure_ascii=False, indent=2)

    self._simulations[state.simulation_id] = state
|
||||||
|
|
||||||
|
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
    """Load a simulation state: memory cache first, then state.json on disk.

    Returns None when no state file exists for *simulation_id*.
    NOTE(review): _get_simulation_dir creates the directory as a side
    effect, so probing a nonexistent simulation leaves an empty dir behind.
    """
    if simulation_id in self._simulations:
        return self._simulations[simulation_id]

    sim_dir = self._get_simulation_dir(simulation_id)
    state_file = os.path.join(sim_dir, "state.json")

    if not os.path.exists(state_file):
        return None

    with open(state_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Rebuild the dataclass field by field, defaulting anything missing
    # so older state files still load.
    state = SimulationState(
        simulation_id=simulation_id,
        project_id=data.get("project_id", ""),
        graph_id=data.get("graph_id", ""),
        enable_twitter=data.get("enable_twitter", True),
        enable_reddit=data.get("enable_reddit", True),
        status=SimulationStatus(data.get("status", "created")),
        entities_count=data.get("entities_count", 0),
        profiles_count=data.get("profiles_count", 0),
        entity_types=data.get("entity_types", []),
        config_generated=data.get("config_generated", False),
        config_reasoning=data.get("config_reasoning", ""),
        current_round=data.get("current_round", 0),
        twitter_status=data.get("twitter_status", "not_started"),
        reddit_status=data.get("reddit_status", "not_started"),
        created_at=data.get("created_at", datetime.now().isoformat()),
        updated_at=data.get("updated_at", datetime.now().isoformat()),
        error=data.get("error"),
    )

    self._simulations[simulation_id] = state
    return state
|
||||||
|
|
||||||
|
def create_simulation(
    self,
    project_id: str,
    graph_id: str,
    enable_twitter: bool = True,
    enable_reddit: bool = True,
) -> SimulationState:
    """Create a new simulation record and persist its initial state.

    Args:
        project_id: Project ID.
        graph_id: Zep graph ID.
        enable_twitter: Whether the Twitter simulation is enabled.
        enable_reddit: Whether the Reddit simulation is enabled.

    Returns:
        The freshly created SimulationState (status CREATED).
    """
    import uuid

    new_id = f"sim_{uuid.uuid4().hex[:12]}"
    new_state = SimulationState(
        simulation_id=new_id,
        project_id=project_id,
        graph_id=graph_id,
        enable_twitter=enable_twitter,
        enable_reddit=enable_reddit,
        status=SimulationStatus.CREATED,
    )
    self._save_simulation_state(new_state)

    logger.info(f"创建模拟: {new_id}, project={project_id}, graph={graph_id}")
    return new_state
|
||||||
|
|
||||||
|
def prepare_simulation(
    self,
    simulation_id: str,
    simulation_requirement: str,
    document_text: str,
    defined_entity_types: Optional[List[str]] = None,
    use_llm_for_profiles: bool = True,
    progress_callback: Optional[callable] = None
) -> SimulationState:
    """
    Prepare the simulation environment (fully automated).

    Steps:
    1. Read and filter entities from the Zep graph.
    2. Generate an OASIS agent profile per entity (optionally LLM-enhanced).
    3. Use the LLM to generate the simulation configuration
       (timing, activity levels, posting frequencies, ...).
    4. Save the config and profile files.
    5. Copy the preset runner scripts into the simulation directory.

    Args:
        simulation_id: Simulation ID.
        simulation_requirement: Requirement description (fed to the LLM).
        document_text: Raw source document (LLM background context).
        defined_entity_types: Optional whitelist of entity types.
        use_llm_for_profiles: Whether to LLM-enhance the profiles.
        progress_callback: Progress callback (stage, progress, message, ...).

    Returns:
        SimulationState (READY on success, FAILED otherwise).

    Raises:
        ValueError: if the simulation does not exist.
        Exception: any preparation failure is recorded on the state, then re-raised.
    """
    state = self._load_simulation_state(simulation_id)
    if not state:
        raise ValueError(f"模拟不存在: {simulation_id}")

    try:
        state.status = SimulationStatus.PREPARING
        self._save_simulation_state(state)

        sim_dir = self._get_simulation_dir(simulation_id)

        # ========== Stage 1: read and filter entities ==========
        if progress_callback:
            progress_callback("reading", 0, "正在连接Zep图谱...")

        reader = ZepEntityReader()

        if progress_callback:
            progress_callback("reading", 30, "正在读取节点数据...")

        filtered = reader.filter_defined_entities(
            graph_id=state.graph_id,
            defined_entity_types=defined_entity_types,
            enrich_with_edges=True
        )

        state.entities_count = filtered.filtered_count
        state.entity_types = list(filtered.entity_types)

        if progress_callback:
            progress_callback(
                "reading", 100,
                f"完成,共 {filtered.filtered_count} 个实体",
                current=filtered.filtered_count,
                total=filtered.filtered_count
            )

        # No usable entities → fail early without raising.
        if filtered.filtered_count == 0:
            state.status = SimulationStatus.FAILED
            state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
            self._save_simulation_state(state)
            return state

        # ========== Stage 2: generate agent profiles ==========
        total_entities = len(filtered.entities)

        if progress_callback:
            progress_callback(
                "generating_profiles", 0,
                "开始生成...",
                current=0,
                total=total_entities
            )

        generator = OasisProfileGenerator()

        # Adapter forwarding per-profile progress to the caller's callback.
        def profile_progress(current, total, msg):
            if progress_callback:
                progress_callback(
                    "generating_profiles",
                    int(current / total * 100),
                    msg,
                    current=current,
                    total=total,
                    item_name=msg
                )

        profiles = generator.generate_profiles_from_entities(
            entities=filtered.entities,
            use_llm=use_llm_for_profiles,
            progress_callback=profile_progress
        )

        state.profiles_count = len(profiles)

        # Save profile files (Twitter uses CSV, Reddit uses JSON).
        if progress_callback:
            progress_callback(
                "generating_profiles", 95,
                "保存Profile文件...",
                current=total_entities,
                total=total_entities
            )

        if state.enable_reddit:
            generator.save_profiles(
                profiles=profiles,
                file_path=os.path.join(sim_dir, "reddit_profiles.json"),
                platform="reddit"
            )

        if state.enable_twitter:
            # Twitter profiles must be CSV — an OASIS requirement.
            generator.save_profiles(
                profiles=profiles,
                file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
                platform="twitter"
            )

        if progress_callback:
            progress_callback(
                "generating_profiles", 100,
                f"完成,共 {len(profiles)} 个Profile",
                current=len(profiles),
                total=len(profiles)
            )

        # ========== Stage 3: LLM-generated simulation config ==========
        if progress_callback:
            progress_callback(
                "generating_config", 0,
                "正在分析模拟需求...",
                current=0,
                total=3
            )

        config_generator = SimulationConfigGenerator()

        if progress_callback:
            progress_callback(
                "generating_config", 30,
                "正在调用LLM生成配置...",
                current=1,
                total=3
            )

        sim_params = config_generator.generate_config(
            simulation_id=simulation_id,
            project_id=state.project_id,
            graph_id=state.graph_id,
            simulation_requirement=simulation_requirement,
            document_text=document_text,
            entities=filtered.entities,
            enable_twitter=state.enable_twitter,
            enable_reddit=state.enable_reddit
        )

        if progress_callback:
            progress_callback(
                "generating_config", 70,
                "正在保存配置文件...",
                current=2,
                total=3
            )

        # Persist the generated configuration.
        config_path = os.path.join(sim_dir, "simulation_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(sim_params.to_json())

        state.config_generated = True
        state.config_reasoning = sim_params.generation_reasoning

        if progress_callback:
            progress_callback(
                "generating_config", 100,
                "配置生成完成",
                current=3,
                total=3
            )

        # ========== Stage 4: copy preset scripts ==========
        # NOTE(review): this list names 4 scripts but _copy_preset_scripts
        # copies only 3 (no action_logger.py) — progress count may be off; confirm.
        script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
                        "run_parallel_simulation.py", "action_logger.py"]

        if progress_callback:
            progress_callback(
                "copying_scripts", 0,
                "开始准备脚本...",
                current=0,
                total=len(script_files)
            )

        self._copy_preset_scripts(sim_dir)

        if progress_callback:
            progress_callback(
                "copying_scripts", 100,
                f"完成,共 {len(script_files)} 个脚本",
                current=len(script_files),
                total=len(script_files)
            )

        # Mark the simulation ready to run.
        state.status = SimulationStatus.READY
        self._save_simulation_state(state)

        logger.info(f"模拟准备完成: {simulation_id}, "
                    f"entities={state.entities_count}, profiles={state.profiles_count}")

        return state

    except Exception as e:
        # Record the failure on the state before propagating.
        logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        state.status = SimulationStatus.FAILED
        state.error = str(e)
        self._save_simulation_state(state)
        raise
|
||||||
|
|
||||||
|
def _copy_preset_scripts(self, sim_dir: str):
    """Copy the preset runner scripts into the simulation directory.

    Args:
        sim_dir: Target simulation directory.

    Missing source scripts are logged as warnings rather than raising, so
    a partially provisioned scripts directory does not abort preparation.
    """
    # Keep this list in sync with script_files in prepare_simulation.
    # action_logger.py was previously missing here even though the prepare
    # step advertises it and the run scripts need it to write actions.jsonl.
    scripts = [
        "run_twitter_simulation.py",
        "run_reddit_simulation.py",
        "run_parallel_simulation.py",
        "action_logger.py",
    ]

    for script in scripts:
        src = os.path.join(self.SCRIPTS_DIR, script)
        dst = os.path.join(sim_dir, script)

        if os.path.exists(src):
            shutil.copy2(src, dst)
            logger.debug(f"复制脚本: {script}")
        else:
            logger.warning(f"预设脚本不存在: {src}")
|
||||||
|
|
||||||
|
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
    """Return the persisted state for a simulation, or None if unknown."""
    state = self._load_simulation_state(simulation_id)
    return state
|
||||||
|
|
||||||
|
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
    """List all persisted simulations, optionally filtered by project id."""
    if not os.path.exists(self.SIMULATION_DATA_DIR):
        return []

    # Each entry in the data directory is one simulation id; unloadable
    # entries come back as None and are dropped.
    loaded = (
        self._load_simulation_state(sim_id)
        for sim_id in os.listdir(self.SIMULATION_DATA_DIR)
    )
    return [
        state for state in loaded
        if state and (project_id is None or state.project_id == project_id)
    ]
|
||||||
|
|
||||||
|
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
    """Return the generated agent profiles for one platform.

    Args:
        simulation_id: Simulation identifier.
        platform: Platform whose profile file to read (default "reddit").

    Returns:
        The parsed profile list, or [] when no profile file exists yet.

    Raises:
        ValueError: If the simulation is unknown.
    """
    if not self._load_simulation_state(simulation_id):
        raise ValueError(f"模拟不存在: {simulation_id}")

    profile_path = os.path.join(
        self._get_simulation_dir(simulation_id),
        f"{platform}_profiles.json",
    )

    if not os.path.exists(profile_path):
        return []

    with open(profile_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
||||||
|
|
||||||
|
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
    """Load simulation_config.json for a simulation; None when absent."""
    config_path = os.path.join(
        self._get_simulation_dir(simulation_id),
        "simulation_config.json",
    )

    if not os.path.exists(config_path):
        return None

    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
||||||
|
|
||||||
|
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
    """Describe how to launch this simulation from the command line."""
    sim_dir = self._get_simulation_dir(simulation_id)

    commands = {
        "twitter": "python run_twitter_simulation.py --config simulation_config.json",
        "reddit": "python run_reddit_simulation.py --config simulation_config.json",
        "parallel": "python run_parallel_simulation.py --config simulation_config.json",
    }
    steps = (
        f"1. 进入模拟目录: cd {sim_dir}\n"
        "2. 激活conda环境: conda activate MiroFish\n"
        "3. 运行模拟:\n"
        " - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
        " - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
        " - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
    )

    return {
        "simulation_dir": sim_dir,
        "config_file": os.path.join(sim_dir, "simulation_config.json"),
        "commands": commands,
        "instructions": steps,
    }
|
||||||
670
backend/app/services/simulation_runner.py
Normal file
670
backend/app/services/simulation_runner.py
Normal file
|
|
@ -0,0 +1,670 @@
|
||||||
|
"""
|
||||||
|
OASIS模拟运行器
|
||||||
|
在后台运行模拟并记录每个Agent的动作,支持实时状态监控
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import subprocess
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
from ..config import Config
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.simulation_runner')
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerStatus(str, Enum):
    """Lifecycle states of a simulation runner process."""

    IDLE = "idle"            # no run started yet
    STARTING = "starting"    # subprocess being launched
    RUNNING = "running"      # subprocess alive and producing actions
    PAUSED = "paused"        # run temporarily suspended
    STOPPING = "stopping"    # stop requested, waiting for subprocess exit
    STOPPED = "stopped"      # stopped on user request
    COMPLETED = "completed"  # subprocess exited with code 0
    FAILED = "failed"        # launch failed or subprocess crashed
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentAction:
    """One recorded action taken by an agent during a simulation round."""
    round_num: int    # simulation round in which the action happened
    timestamp: str    # time string as logged (ISO-8601 in practice)
    platform: str     # "twitter" or "reddit"
    agent_id: int
    agent_name: str
    action_type: str  # e.g. CREATE_POST, LIKE_POST
    action_args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None
    success: bool = True

    # Field order here defines the key order of the dict built by to_dict().
    _EXPORT_FIELDS = (
        "round_num", "timestamp", "platform", "agent_id", "agent_name",
        "action_type", "action_args", "result", "success",
    )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict suitable for JSON output."""
        return {name: getattr(self, name) for name in self._EXPORT_FIELDS}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RoundSummary:
    """Aggregated view of everything that happened in one simulation round."""
    round_num: int
    start_time: str
    end_time: Optional[str] = None
    simulated_hour: int = 0
    twitter_actions: int = 0
    reddit_actions: int = 0
    active_agents: List[int] = field(default_factory=list)
    actions: List[AgentAction] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; contained actions are expanded too."""
        payload = {
            "round_num": self.round_num,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "simulated_hour": self.simulated_hour,
            "twitter_actions": self.twitter_actions,
            "reddit_actions": self.reddit_actions,
            "active_agents": self.active_agents,
        }
        payload["actions_count"] = len(self.actions)
        payload["actions"] = [entry.to_dict() for entry in self.actions]
        return payload
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SimulationRunState:
    """Live run state of one simulation (mirrors run_state.json on disk)."""
    simulation_id: str
    runner_status: RunnerStatus = RunnerStatus.IDLE

    # Progress information
    current_round: int = 0
    total_rounds: int = 0
    simulated_hours: int = 0
    total_simulation_hours: int = 0

    # Per-platform status
    twitter_running: bool = False
    reddit_running: bool = False
    twitter_actions_count: int = 0
    reddit_actions_count: int = 0

    # Per-round summaries
    rounds: List[RoundSummary] = field(default_factory=list)

    # Most recent actions (newest first, for live display in the frontend)
    recent_actions: List[AgentAction] = field(default_factory=list)
    max_recent_actions: int = 50

    # Timestamps
    started_at: Optional[str] = None
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None

    # Error message (set when the run fails)
    error: Optional[str] = None

    # Subprocess PID (used to stop the run)
    process_pid: Optional[int] = None

    def add_action(self, action: AgentAction):
        """Prepend an action to recent_actions and update the counters.

        Keeps at most max_recent_actions entries; bumps the per-platform
        counter (any platform other than "twitter" counts as reddit) and
        refreshes updated_at.
        """
        self.recent_actions.insert(0, action)
        if len(self.recent_actions) > self.max_recent_actions:
            self.recent_actions = self.recent_actions[:self.max_recent_actions]

        if action.platform == "twitter":
            self.twitter_actions_count += 1
        else:
            self.reddit_actions_count += 1

        self.updated_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the summary fields (no action lists)."""
        return {
            "simulation_id": self.simulation_id,
            "runner_status": self.runner_status.value,
            "current_round": self.current_round,
            "total_rounds": self.total_rounds,
            "simulated_hours": self.simulated_hours,
            "total_simulation_hours": self.total_simulation_hours,
            # max(..., 1) guards against division by zero before any round ran
            "progress_percent": round(self.current_round / max(self.total_rounds, 1) * 100, 1),
            "twitter_running": self.twitter_running,
            "reddit_running": self.reddit_running,
            "twitter_actions_count": self.twitter_actions_count,
            "reddit_actions_count": self.reddit_actions_count,
            "total_actions_count": self.twitter_actions_count + self.reddit_actions_count,
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "completed_at": self.completed_at,
            "error": self.error,
            "process_pid": self.process_pid,
        }

    def to_detail_dict(self) -> Dict[str, Any]:
        """Like to_dict(), plus the recent actions and the round count."""
        result = self.to_dict()
        result["recent_actions"] = [a.to_dict() for a in self.recent_actions]
        result["rounds_count"] = len(self.rounds)
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class SimulationRunner:
    """
    Simulation runner.

    Responsibilities:
    1. Run the OASIS simulation in a background subprocess
    2. Parse the run log and record every agent action
    3. Provide a real-time status query interface
    4. Support pause / stop / resume operations
    """

    # Directory where per-simulation run state is stored
    RUN_STATE_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )

    # In-memory bookkeeping, keyed by simulation_id. Class-level, so it is
    # shared by all callers in this process (lost on restart; run_state.json
    # is the persistent copy).
    _run_states: Dict[str, SimulationRunState] = {}
    _processes: Dict[str, subprocess.Popen] = {}
    _action_queues: Dict[str, Queue] = {}
    _monitor_threads: Dict[str, threading.Thread] = {}
|
||||||
|
|
||||||
|
@classmethod
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
    """Return the cached run state, falling back to the on-disk copy."""
    cached = cls._run_states.get(simulation_id)
    if cached is not None:
        return cached

    # Not in memory — try to hydrate from run_state.json.
    loaded = cls._load_run_state(simulation_id)
    if loaded is not None:
        cls._run_states[simulation_id] = loaded
    return loaded
|
||||||
|
|
||||||
|
@classmethod
def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
    """Load a run state from <RUN_STATE_DIR>/<simulation_id>/run_state.json.

    Returns None when the file is missing or unparsable (parse failures are
    logged). Only the scalar fields and the recent-actions window are
    restored; per-round summaries are not persisted round-trip.
    """
    state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
    if not os.path.exists(state_file):
        return None

    try:
        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        state = SimulationRunState(
            simulation_id=simulation_id,
            runner_status=RunnerStatus(data.get("runner_status", "idle")),
            current_round=data.get("current_round", 0),
            total_rounds=data.get("total_rounds", 0),
            simulated_hours=data.get("simulated_hours", 0),
            total_simulation_hours=data.get("total_simulation_hours", 0),
            twitter_running=data.get("twitter_running", False),
            reddit_running=data.get("reddit_running", False),
            twitter_actions_count=data.get("twitter_actions_count", 0),
            reddit_actions_count=data.get("reddit_actions_count", 0),
            started_at=data.get("started_at"),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            completed_at=data.get("completed_at"),
            error=data.get("error"),
            process_pid=data.get("process_pid"),
        )

        # Restore the recent-actions window (persisted newest first)
        actions_data = data.get("recent_actions", [])
        for a in actions_data:
            state.recent_actions.append(AgentAction(
                round_num=a.get("round_num", 0),
                timestamp=a.get("timestamp", ""),
                platform=a.get("platform", ""),
                agent_id=a.get("agent_id", 0),
                agent_name=a.get("agent_name", ""),
                action_type=a.get("action_type", ""),
                action_args=a.get("action_args", {}),
                result=a.get("result"),
                success=a.get("success", True),
            ))

        return state
    except Exception as e:
        logger.error(f"加载运行状态失败: {str(e)}")
        return None
|
||||||
|
|
||||||
|
@classmethod
def _save_run_state(cls, state: SimulationRunState):
    """Persist a run state to run_state.json and refresh the memory cache."""
    sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
    os.makedirs(sim_dir, exist_ok=True)

    payload = state.to_detail_dict()
    target = os.path.join(sim_dir, "run_state.json")
    with open(target, 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)

    cls._run_states[state.simulation_id] = state
|
||||||
|
|
||||||
|
@classmethod
def start_simulation(
    cls,
    simulation_id: str,
    platform: str = "parallel"  # twitter / reddit / parallel
) -> SimulationRunState:
    """
    Start a simulation subprocess.

    Args:
        simulation_id: Simulation id (must already be prepared).
        platform: Which runner to launch (twitter/reddit/parallel).

    Returns:
        SimulationRunState

    Raises:
        ValueError: If already running, not yet prepared, or the runner
            script is missing from the simulation directory.
    """
    # Refuse to double-start the same simulation.
    existing = cls.get_run_state(simulation_id)
    if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]:
        raise ValueError(f"模拟已在运行中: {simulation_id}")

    # Load the config written by the prepare step.
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    config_path = os.path.join(sim_dir, "simulation_config.json")

    if not os.path.exists(config_path):
        raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Derive the total round count from the configured time window.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = int(total_hours * 60 / minutes_per_round)

    state = SimulationRunState(
        simulation_id=simulation_id,
        runner_status=RunnerStatus.STARTING,
        total_rounds=total_rounds,
        total_simulation_hours=total_hours,
        started_at=datetime.now().isoformat(),
    )

    cls._save_run_state(state)

    # Pick the runner script for the requested platform.
    if platform == "twitter":
        script_name = "run_twitter_simulation.py"
        state.twitter_running = True
    elif platform == "reddit":
        script_name = "run_reddit_simulation.py"
        state.reddit_running = True
    else:
        script_name = "run_parallel_simulation.py"
        state.twitter_running = True
        state.reddit_running = True

    script_path = os.path.join(sim_dir, script_name)

    if not os.path.exists(script_path):
        raise ValueError(f"脚本不存在: {script_path}")

    # Queue for action events.
    # NOTE(review): nothing visible here ever consumes these queues —
    # confirm they are still needed.
    action_queue = Queue()
    cls._action_queues[simulation_id] = action_queue

    # Launch the simulation subprocess.
    try:
        # Build the command line; relative paths resolve against cwd=sim_dir.
        cmd = [
            sys.executable,  # current Python interpreter
            script_path,
            "--config", "simulation_config.json",
            "--action-log", "actions.jsonl",  # action log tailed by the monitor thread
        ]

        # Working directory is the simulation directory.
        # NOTE(review): stdout/stderr PIPEs are never drained while the
        # child runs — a chatty child could fill the pipe buffer and block;
        # confirm the runner scripts stay quiet or redirect output to files.
        process = subprocess.Popen(
            cmd,
            cwd=sim_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        )

        state.process_pid = process.pid
        state.runner_status = RunnerStatus.RUNNING
        cls._processes[simulation_id] = process
        cls._save_run_state(state)

        # Background daemon thread tails actions.jsonl and tracks progress.
        monitor_thread = threading.Thread(
            target=cls._monitor_simulation,
            args=(simulation_id,),
            daemon=True
        )
        monitor_thread.start()
        cls._monitor_threads[simulation_id] = monitor_thread

        logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}")

    except Exception as e:
        state.runner_status = RunnerStatus.FAILED
        state.error = str(e)
        cls._save_run_state(state)
        raise

    return state
|
||||||
|
|
||||||
|
@classmethod
def _monitor_simulation(cls, simulation_id: str):
    """Tail actions.jsonl while the subprocess runs and track progress.

    Runs in a daemon thread started by start_simulation. Polls once per
    second: appends newly written actions to the run state, advances the
    round counter, and persists the state. When the subprocess exits it
    records COMPLETED (exit code 0) or FAILED, then clears the per-run
    bookkeeping.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    actions_log = os.path.join(sim_dir, "actions.jsonl")

    process = cls._processes.get(simulation_id)
    state = cls.get_run_state(simulation_id)

    if not process or not state:
        return

    # Byte offset of the last parsed position in actions.jsonl.
    last_position = 0

    try:
        while process.poll() is None:  # subprocess still alive
            # Parse any lines appended since the previous poll.
            if os.path.exists(actions_log):
                with open(actions_log, 'r', encoding='utf-8') as f:
                    f.seek(last_position)
                    for line in f:
                        line = line.strip()
                        if line:
                            try:
                                action_data = json.loads(line)
                                action = AgentAction(
                                    round_num=action_data.get("round", 0),
                                    timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                                    platform=action_data.get("platform", "unknown"),
                                    agent_id=action_data.get("agent_id", 0),
                                    agent_name=action_data.get("agent_name", ""),
                                    action_type=action_data.get("action_type", ""),
                                    action_args=action_data.get("action_args", {}),
                                    result=action_data.get("result"),
                                    success=action_data.get("success", True),
                                )
                                state.add_action(action)

                                # Advance the round counter monotonically.
                                if action.round_num > state.current_round:
                                    state.current_round = action.round_num

                            except json.JSONDecodeError:
                                # Skip partially written / malformed lines.
                                pass
                    last_position = f.tell()

            # Persist progress, then wait before the next poll.
            cls._save_run_state(state)
            time.sleep(1)  # poll once per second

        # Subprocess has exited.
        exit_code = process.returncode

        if exit_code == 0:
            state.runner_status = RunnerStatus.COMPLETED
            state.completed_at = datetime.now().isoformat()
            logger.info(f"模拟完成: {simulation_id}")
        else:
            state.runner_status = RunnerStatus.FAILED
            # stderr is only read after exit; first 500 chars are kept.
            stderr = process.stderr.read() if process.stderr else ""
            state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
            logger.error(f"模拟失败: {simulation_id}, error={state.error}")

        state.twitter_running = False
        state.reddit_running = False
        cls._save_run_state(state)

    except Exception as e:
        logger.error(f"监控线程异常: {simulation_id}, error={str(e)}")
        state.runner_status = RunnerStatus.FAILED
        state.error = str(e)
        cls._save_run_state(state)

    finally:
        # Drop per-run bookkeeping regardless of outcome.
        cls._processes.pop(simulation_id, None)
        cls._action_queues.pop(simulation_id, None)
|
||||||
|
|
||||||
|
@classmethod
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
    """Stop a running (or paused) simulation, force-killing if needed.

    Raises:
        ValueError: If the simulation is unknown or not currently stoppable.
    """
    state = cls.get_run_state(simulation_id)
    if not state:
        raise ValueError(f"模拟不存在: {simulation_id}")

    stoppable = (RunnerStatus.RUNNING, RunnerStatus.PAUSED)
    if state.runner_status not in stoppable:
        raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}")

    state.runner_status = RunnerStatus.STOPPING
    cls._save_run_state(state)

    # Ask the subprocess to exit; escalate to kill after a 10 s grace period.
    process = cls._processes.get(simulation_id)
    if process:
        process.terminate()
        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            process.kill()

    state.runner_status = RunnerStatus.STOPPED
    state.twitter_running = False
    state.reddit_running = False
    state.completed_at = datetime.now().isoformat()
    cls._save_run_state(state)

    logger.info(f"模拟已停止: {simulation_id}")
    return state
|
||||||
|
|
||||||
|
@classmethod
def get_actions(
    cls,
    simulation_id: str,
    limit: int = 100,
    offset: int = 0,
    platform: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List[AgentAction]:
    """
    Read the action history from actions.jsonl.

    Args:
        simulation_id: Simulation id.
        limit: Maximum number of actions to return.
        offset: Number of actions to skip (applied after ordering).
        platform: Keep only this platform, when given.
        agent_id: Keep only this agent, when given.
        round_num: Keep only this round, when given.

    Returns:
        Matching actions in reverse file order (newest first, since the log
        is appended chronologically); [] when no log file exists.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    actions_log = os.path.join(sim_dir, "actions.jsonl")

    if not os.path.exists(actions_log):
        return []

    actions = []

    with open(actions_log, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)

                # Apply filters before building the dataclass.
                if platform and data.get("platform") != platform:
                    continue
                if agent_id is not None and data.get("agent_id") != agent_id:
                    continue
                if round_num is not None and data.get("round") != round_num:
                    continue

                actions.append(AgentAction(
                    round_num=data.get("round", 0),
                    timestamp=data.get("timestamp", ""),
                    platform=data.get("platform", ""),
                    agent_id=data.get("agent_id", 0),
                    agent_name=data.get("agent_name", ""),
                    action_type=data.get("action_type", ""),
                    action_args=data.get("action_args", {}),
                    result=data.get("result"),
                    success=data.get("success", True),
                ))

            except json.JSONDecodeError:
                # Skip malformed lines.
                continue

    # Newest first (file order is chronological).
    actions.reverse()

    # Paginate.
    return actions[offset:offset + limit]
|
||||||
|
|
||||||
|
@classmethod
def get_timeline(
    cls,
    simulation_id: str,
    start_round: int = 0,
    end_round: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Get the simulation timeline (actions aggregated per round).

    Args:
        simulation_id: Simulation id.
        start_round: First round to include.
        end_round: Last round to include (inclusive); None = no upper bound.

    Returns:
        One summary dict per round, sorted by round number.
    """
    actions = cls.get_actions(simulation_id, limit=10000)

    # Group actions by round number.
    rounds: Dict[int, Dict[str, Any]] = {}

    # get_actions returns newest-first; iterate in chronological order so
    # first_action_time/last_action_time land on the right actions (they
    # were previously swapped).
    for action in reversed(actions):  # oldest -> newest
        round_num = action.round_num

        if round_num < start_round:
            continue
        if end_round is not None and round_num > end_round:
            continue

        if round_num not in rounds:
            rounds[round_num] = {
                "round_num": round_num,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "active_agents": set(),
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        r = rounds[round_num]

        if action.platform == "twitter":
            r["twitter_actions"] += 1
        else:
            r["reddit_actions"] += 1

        r["active_agents"].add(action.agent_id)
        r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1
        r["last_action_time"] = action.timestamp

    # Flatten into a list sorted by round number.
    result = []
    for round_num in sorted(rounds.keys()):
        r = rounds[round_num]
        result.append({
            "round_num": round_num,
            "twitter_actions": r["twitter_actions"],
            "reddit_actions": r["reddit_actions"],
            "total_actions": r["twitter_actions"] + r["reddit_actions"],
            "active_agents_count": len(r["active_agents"]),
            "active_agents": list(r["active_agents"]),
            "action_types": r["action_types"],
            "first_action_time": r["first_action_time"],
            "last_action_time": r["last_action_time"],
        })

    return result
|
||||||
|
|
||||||
|
@classmethod
def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
    """
    Aggregate per-agent statistics from the action history.

    Returns:
        One dict per agent (total/per-platform/per-type action counts plus
        first and last action timestamps), sorted by total actions desc.
    """
    actions = cls.get_actions(simulation_id, limit=10000)

    agent_stats: Dict[int, Dict[str, Any]] = {}

    # get_actions returns newest-first; walk in chronological order so
    # first_action_time/last_action_time are assigned correctly (they
    # were previously swapped).
    for action in reversed(actions):  # oldest -> newest
        agent_id = action.agent_id

        if agent_id not in agent_stats:
            agent_stats[agent_id] = {
                "agent_id": agent_id,
                "agent_name": action.agent_name,
                "total_actions": 0,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        stats = agent_stats[agent_id]
        stats["total_actions"] += 1

        if action.platform == "twitter":
            stats["twitter_actions"] += 1
        else:
            stats["reddit_actions"] += 1

        stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1
        stats["last_action_time"] = action.timestamp

    # Most active agents first.
    result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)

    return result
|
||||||
|
|
||||||
386
backend/app/services/zep_entity_reader.py
Normal file
386
backend/app/services/zep_entity_reader.py
Normal file
|
|
@ -0,0 +1,386 @@
|
||||||
|
"""
|
||||||
|
Zep实体读取与过滤服务
|
||||||
|
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Optional, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from zep_cloud.client import Zep
|
||||||
|
|
||||||
|
from ..config import Config
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.zep_entity_reader')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EntityNode:
    """An entity node read from the Zep graph, plus its local context."""
    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges touching this node
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Neighbouring nodes connected through those edges
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    # Key order of the dict produced by to_dict().
    _EXPORT_FIELDS = (
        "uuid", "name", "labels", "summary", "attributes",
        "related_edges", "related_nodes",
    )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict suitable for JSON output."""
        return {key: getattr(self, key) for key in self._EXPORT_FIELDS}

    def get_entity_type(self) -> Optional[str]:
        """Return the first custom label, ignoring the default Entity/Node."""
        return next(
            (label for label in self.labels if label not in ["Entity", "Node"]),
            None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FilteredEntities:
    """Result of filtering a graph's nodes down to typed entities."""
    entities: List[EntityNode]  # nodes that carry a custom type label
    entity_types: Set[str]      # distinct custom types encountered
    total_count: int            # number of nodes inspected
    filtered_count: int         # number of nodes kept

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for JSON output; entity_types becomes a list."""
        payload = {"entities": [node.to_dict() for node in self.entities]}
        payload["entity_types"] = list(self.entity_types)
        payload["total_count"] = self.total_count
        payload["filtered_count"] = self.filtered_count
        return payload
|
||||||
|
|
||||||
|
|
||||||
|
class ZepEntityReader:
|
||||||
|
"""
|
||||||
|
Zep实体读取与过滤服务
|
||||||
|
|
||||||
|
主要功能:
|
||||||
|
1. 从Zep图谱读取所有节点
|
||||||
|
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
||||||
|
3. 获取每个实体的相关边和关联节点信息
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: Optional[str] = None):
    """Create a reader; falls back to Config.ZEP_API_KEY when no key given.

    Raises:
        ValueError: If no API key is available from either source.
    """
    key = api_key or Config.ZEP_API_KEY
    if not key:
        raise ValueError("ZEP_API_KEY 未配置")

    self.api_key = key
    self.client = Zep(api_key=key)
|
||||||
|
|
||||||
|
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
    """
    Fetch every node of a graph.

    Args:
        graph_id: Graph identifier.

    Returns:
        A list of plain dicts (uuid, name, labels, summary, attributes).
    """
    logger.info(f"获取图谱 {graph_id} 的所有节点...")

    raw_nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)

    # Zep SDK versions differ on the uuid attribute name (uuid_ vs uuid),
    # so probe both; None-ish fields are normalized to empty values.
    nodes_data = [
        {
            "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
            "name": node.name or "",
            "labels": node.labels or [],
            "summary": node.summary or "",
            "attributes": node.attributes or {},
        }
        for node in raw_nodes
    ]

    logger.info(f"共获取 {len(nodes_data)} 个节点")
    return nodes_data
|
||||||
|
|
||||||
|
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
    """
    Fetch every edge of a graph.

    Args:
        graph_id: Graph identifier.

    Returns:
        A list of plain dicts (uuid, name, fact, endpoint uuids, attributes).
    """
    logger.info(f"获取图谱 {graph_id} 的所有边...")

    raw_edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)

    # Zep SDK versions differ on the uuid attribute name (uuid_ vs uuid).
    edges_data = [
        {
            "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
            "name": edge.name or "",
            "fact": edge.fact or "",
            "source_node_uuid": edge.source_node_uuid,
            "target_node_uuid": edge.target_node_uuid,
            "attributes": edge.attributes or {},
        }
        for edge in raw_edges
    ]

    logger.info(f"共获取 {len(edges_data)} 条边")
    return edges_data
|
||||||
|
|
||||||
|
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
    """
    Fetch every edge attached to one node.

    Args:
        node_uuid: Node UUID.

    Returns:
        A list of edge dicts; [] when the API call fails (the failure is
        logged as a warning rather than raised).
    """
    try:
        raw_edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)

        # Zep SDK versions differ on the uuid attribute name (uuid_ vs uuid).
        return [
            {
                "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
                "name": edge.name or "",
                "fact": edge.fact or "",
                "source_node_uuid": edge.source_node_uuid,
                "target_node_uuid": edge.target_node_uuid,
                "attributes": edge.attributes or {},
            }
            for edge in raw_edges
        ]
    except Exception as e:
        logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
        return []
|
||||||
|
|
||||||
|
def filter_defined_entities(
    self,
    graph_id: str,
    defined_entity_types: Optional[List[str]] = None,
    enrich_with_edges: bool = True
) -> FilteredEntities:
    """
    Select the nodes of a graph that match a predefined entity type.

    Filtering rule:
    - A node whose labels contain only the default "Entity" (and "Node")
      labels matches no predefined type and is skipped.
    - A node carrying at least one label other than "Entity"/"Node" matches
      a predefined type and is kept.

    Args:
        graph_id: Graph ID.
        defined_entity_types: Optional whitelist of entity types; when given,
            only nodes carrying one of these labels are kept.
        enrich_with_edges: Whether to attach each kept entity's related edges
            and neighbour-node summaries.

    Returns:
        FilteredEntities: The filtered entity collection with counts and the
        set of entity types that were found.
    """
    logger.info(f"开始筛选图谱 {graph_id} 的实体...")

    # Fetch every node of the graph once.
    all_nodes = self.get_all_nodes(graph_id)
    total_count = len(all_nodes)

    # Fetch all edges up front so enrichment below needs no extra API calls.
    all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []

    # Node UUID -> node dict, for O(1) neighbour lookups during enrichment.
    node_map = {n["uuid"]: n for n in all_nodes}

    # Accumulators for the kept entities and the distinct types seen.
    filtered_entities = []
    entity_types_found = set()

    for node in all_nodes:
        labels = node.get("labels", [])

        # A node qualifies only if it carries a label beyond the defaults.
        custom_labels = [l for l in labels if l not in ["Entity", "Node"]]

        if not custom_labels:
            # Only default labels -> not a predefined entity type; skip.
            continue

        # When a whitelist was supplied, the node must carry one of its types.
        if defined_entity_types:
            matching_labels = [l for l in custom_labels if l in defined_entity_types]
            if not matching_labels:
                continue
            entity_type = matching_labels[0]
        else:
            # No whitelist: the first custom label is taken as the type.
            entity_type = custom_labels[0]

        entity_types_found.add(entity_type)

        # Wrap the raw node dict in an EntityNode.
        entity = EntityNode(
            uuid=node["uuid"],
            name=node["name"],
            labels=labels,
            summary=node["summary"],
            attributes=node["attributes"],
        )

        # Optionally attach related edges and neighbour-node info.
        if enrich_with_edges:
            related_edges = []
            related_node_uuids = set()

            for edge in all_edges:
                if edge["source_node_uuid"] == node["uuid"]:
                    related_edges.append({
                        "direction": "outgoing",
                        "edge_name": edge["name"],
                        "fact": edge["fact"],
                        "target_node_uuid": edge["target_node_uuid"],
                    })
                    related_node_uuids.add(edge["target_node_uuid"])
                elif edge["target_node_uuid"] == node["uuid"]:
                    related_edges.append({
                        "direction": "incoming",
                        "edge_name": edge["name"],
                        "fact": edge["fact"],
                        "source_node_uuid": edge["source_node_uuid"],
                    })
                    related_node_uuids.add(edge["source_node_uuid"])

            entity.related_edges = related_edges

            # Basic info for each neighbour that exists in this graph.
            related_nodes = []
            for related_uuid in related_node_uuids:
                if related_uuid in node_map:
                    related_node = node_map[related_uuid]
                    related_nodes.append({
                        "uuid": related_node["uuid"],
                        "name": related_node["name"],
                        "labels": related_node["labels"],
                        "summary": related_node.get("summary", ""),
                    })

            entity.related_nodes = related_nodes

        filtered_entities.append(entity)

    logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                f"实体类型: {entity_types_found}")

    return FilteredEntities(
        entities=filtered_entities,
        entity_types=entity_types_found,
        total_count=total_count,
        filtered_count=len(filtered_entities),
    )
|
||||||
|
|
||||||
|
def get_entity_with_context(
    self,
    graph_id: str,
    entity_uuid: str
) -> Optional[EntityNode]:
    """
    Fetch a single entity together with its full context (edges and
    neighbouring nodes).

    Args:
        graph_id: Graph ID.
        entity_uuid: Entity UUID.

    Returns:
        EntityNode with ``related_edges``/``related_nodes`` populated, or
        None when the node is missing or any lookup fails.
    """
    try:
        # Fetch the node itself.
        node = self.client.graph.node.get(uuid_=entity_uuid)

        if not node:
            return None

        # Edges attached to this node.
        edges = self.get_node_edges(entity_uuid)

        # All graph nodes, used only to resolve neighbour details.
        all_nodes = self.get_all_nodes(graph_id)
        node_map = {n["uuid"]: n for n in all_nodes}

        # Classify each edge relative to this entity and collect neighbours.
        related_edges = []
        related_node_uuids = set()

        for edge in edges:
            if edge["source_node_uuid"] == entity_uuid:
                related_edges.append({
                    "direction": "outgoing",
                    "edge_name": edge["name"],
                    "fact": edge["fact"],
                    "target_node_uuid": edge["target_node_uuid"],
                })
                related_node_uuids.add(edge["target_node_uuid"])
            else:
                # NOTE(review): any edge not sourced at this entity is treated
                # as incoming — assumes get_node_edges only returns edges that
                # touch the entity; confirm against the SDK contract.
                related_edges.append({
                    "direction": "incoming",
                    "edge_name": edge["name"],
                    "fact": edge["fact"],
                    "source_node_uuid": edge["source_node_uuid"],
                })
                related_node_uuids.add(edge["source_node_uuid"])

        # Basic info for each neighbour present in this graph.
        related_nodes = []
        for related_uuid in related_node_uuids:
            if related_uuid in node_map:
                related_node = node_map[related_uuid]
                related_nodes.append({
                    "uuid": related_node["uuid"],
                    "name": related_node["name"],
                    "labels": related_node["labels"],
                    "summary": related_node.get("summary", ""),
                })

        return EntityNode(
            # The SDK exposes either `uuid_` or `uuid` depending on version.
            uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
            name=node.name or "",
            labels=node.labels or [],
            summary=node.summary or "",
            attributes=node.attributes or {},
            related_edges=related_edges,
            related_nodes=related_nodes,
        )

    except Exception as e:
        logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
        return None
|
||||||
|
|
||||||
|
def get_entities_by_type(
    self,
    graph_id: str,
    entity_type: str,
    enrich_with_edges: bool = True
) -> List[EntityNode]:
    """
    Convenience wrapper: fetch every entity of one predefined type.

    Args:
        graph_id: Graph ID.
        entity_type: Entity type label (e.g. "Student", "PublicFigure").
        enrich_with_edges: Whether to attach related-edge information.

    Returns:
        List of matching entities.
    """
    # Delegate to the general filter with a single-type whitelist.
    filtered = self.filter_defined_entities(
        graph_id=graph_id,
        defined_entity_types=[entity_type],
        enrich_with_edges=enrich_with_edges
    )
    return filtered.entities
|
||||||
|
|
||||||
|
|
||||||
238
backend/app/utils/retry.py
Normal file
238
backend/app/utils/retry.py
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
"""
|
||||||
|
API调用重试机制
|
||||||
|
用于处理LLM等外部API调用的重试逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import functools
|
||||||
|
from typing import Callable, Any, Optional, Type, Tuple
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.retry')
|
||||||
|
|
||||||
|
|
||||||
|
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Decorator that retries a synchronous callable with exponential backoff.

    Args:
        max_retries: Maximum number of retries (the call is attempted
            ``max_retries + 1`` times in total).
        initial_delay: First delay between attempts, in seconds.
        max_delay: Upper bound on the delay, in seconds.
        backoff_factor: Multiplier applied to the delay after each failure.
        jitter: Whether to randomise each delay (0.5x-1.5x).
        exceptions: Exception types that trigger a retry.
        on_retry: Optional callback invoked as ``on_retry(exc, retry_count)``.

    Usage:
        @retry_with_backoff(max_retries=3)
        def call_llm_api():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            delay = initial_delay
            attempt = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    # Out of retries: log and propagate the last error.
                    if attempt == max_retries:
                        logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
                        raise

                    # Cap the delay, then optionally jitter it.
                    current_delay = min(delay, max_delay)
                    if jitter:
                        current_delay = current_delay * (0.5 + random.random())

                    logger.warning(
                        f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
                        f"{current_delay:.1f}秒后重试..."
                    )

                    if on_retry:
                        on_retry(e, attempt + 1)

                    time.sleep(current_delay)
                    delay *= backoff_factor
                    attempt += 1

        return wrapper
    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def retry_with_backoff_async(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Async counterpart of ``retry_with_backoff``: retries a coroutine function
    with exponential backoff, sleeping via ``asyncio.sleep``.
    """
    import asyncio

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            delay = initial_delay
            attempt = 0
            while True:
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    # Out of retries: log and propagate the last error.
                    if attempt == max_retries:
                        logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
                        raise

                    # Cap the delay, then optionally jitter it.
                    current_delay = min(delay, max_delay)
                    if jitter:
                        current_delay = current_delay * (0.5 + random.random())

                    logger.warning(
                        f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
                        f"{current_delay:.1f}秒后重试..."
                    )

                    if on_retry:
                        on_retry(e, attempt + 1)

                    await asyncio.sleep(current_delay)
                    delay *= backoff_factor
                    attempt += 1

        return wrapper
    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class RetryableAPIClient:
    """
    Wraps arbitrary callables with a shared exponential-backoff retry policy,
    for API clients that cannot be decorated at definition time.
    """

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        backoff_factor: float = 2.0
    ):
        # Retry-policy knobs shared by every call made through this client.
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor

    def call_with_retry(
        self,
        func: Callable,
        *args,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        **kwargs
    ) -> Any:
        """
        Invoke ``func`` and retry it on failure.

        Args:
            func: Callable to invoke.
            *args: Positional arguments for ``func``.
            exceptions: Exception types that trigger a retry.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            Whatever ``func`` returns.
        """
        delay = self.initial_delay
        attempt = 0
        while True:
            try:
                return func(*args, **kwargs)
            except exceptions as e:
                # Out of retries: log and propagate the last error.
                if attempt == self.max_retries:
                    logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(e)}")
                    raise

                # Cap the delay and always apply jitter (0.5x-1.5x).
                current_delay = min(delay, self.max_delay) * (0.5 + random.random())

                logger.warning(
                    f"API调用第 {attempt + 1} 次尝试失败: {str(e)}, "
                    f"{current_delay:.1f}秒后重试..."
                )

                time.sleep(current_delay)
                delay *= self.backoff_factor
                attempt += 1

    def call_batch_with_retry(
        self,
        items: list,
        process_func: Callable,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        continue_on_failure: bool = True
    ) -> Tuple[list, list]:
        """
        Process a batch of items, retrying each failing item independently.

        Args:
            items: Items to process.
            process_func: Callable receiving a single item.
            exceptions: Exception types that trigger a retry.
            continue_on_failure: Whether to keep processing the remaining
                items after one item exhausts its retries.

        Returns:
            Tuple of (successful results, failure records with index/item/error).
        """
        results = []
        failures = []

        for idx, item in enumerate(items):
            try:
                results.append(
                    self.call_with_retry(process_func, item, exceptions=exceptions)
                )
            except Exception as e:
                logger.error(f"处理第 {idx + 1} 项失败: {str(e)}")
                failures.append({
                    "index": idx,
                    "item": item,
                    "error": str(e)
                })
                # Abort the whole batch when best-effort mode is off.
                if not continue_on_failure:
                    raise

        return results, failures
|
||||||
|
|
||||||
|
|
@ -20,3 +20,7 @@ pydantic>=2.0.0
|
||||||
# 文件处理
|
# 文件处理
|
||||||
werkzeug>=3.0.0
|
werkzeug>=3.0.0
|
||||||
|
|
||||||
|
# OASIS社交媒体模拟框架
|
||||||
|
oasis-ai>=0.1.0
|
||||||
|
camel-ai>=0.2.0
|
||||||
|
|
||||||
|
|
|
||||||
138
backend/scripts/action_logger.py
Normal file
138
backend/scripts/action_logger.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""
|
||||||
|
动作日志记录器
|
||||||
|
用于记录OASIS模拟中每个Agent的动作,供后端监控使用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ActionLogger:
    """
    Action logger for OASIS simulations.

    Appends one JSON object per line (JSONL) to ``log_path`` describing agent
    actions and round/simulation lifecycle events, so the backend can monitor
    a running simulation by tailing the file.
    """

    def __init__(self, log_path: str):
        """
        Initialise the logger.

        Args:
            log_path: Path of the log file (.jsonl format).
        """
        self.log_path = log_path
        self._ensure_dir()

    def _ensure_dir(self):
        """Create the log file's parent directory if needed."""
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    def _append(self, entry: Dict[str, Any]) -> None:
        """Append one entry as a single JSONL line (shared by all log_* methods)."""
        # Open in append mode per write so concurrent readers always see
        # complete lines and no file handle is held between events.
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    def log_action(
        self,
        round_num: int,
        platform: str,
        agent_id: int,
        agent_name: str,
        action_type: str,
        action_args: Optional[Dict[str, Any]] = None,
        result: Optional[str] = None,
        success: bool = True
    ):
        """
        Record a single agent action.

        Args:
            round_num: Round number.
            platform: Platform (twitter/reddit).
            agent_id: Agent ID.
            agent_name: Agent display name.
            action_type: Action type.
            action_args: Action arguments.
            result: Execution result.
            success: Whether the action succeeded.
        """
        self._append({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "agent_id": agent_id,
            "agent_name": agent_name,
            "action_type": action_type,
            "action_args": action_args or {},
            "result": result,
            "success": success,
        })

    def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
        """Record the start of a simulation round."""
        self._append({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_start",
            "simulated_hour": simulated_hour,
        })

    def log_round_end(self, round_num: int, actions_count: int, platform: str):
        """Record the end of a simulation round."""
        self._append({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_end",
            "actions_count": actions_count,
        })

    def log_simulation_start(self, platform: str, config: Dict[str, Any]):
        """Record the start of a simulation."""
        self._append({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            # NOTE(review): assumes 30-minute rounds (2 rounds per simulated
            # hour) — confirm against the runner's minutes_per_round setting.
            "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2,
            "agents_count": len(config.get("agent_configs", [])),
        })

    def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
        """Record the end of a simulation."""
        self._append({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "simulation_end",
            "total_rounds": total_rounds,
            "total_actions": total_actions,
        })
|
||||||
|
|
||||||
|
|
||||||
|
# 全局日志实例(可选)
|
||||||
|
# Module-level singleton so callers can share one log file (optional).
_global_logger: Optional[ActionLogger] = None


def get_logger(log_path: Optional[str] = None) -> ActionLogger:
    """
    Return the shared ActionLogger instance.

    Passing ``log_path`` rebinds the singleton to that file; otherwise the
    existing instance (or a default ``actions.jsonl`` logger) is returned.
    """
    global _global_logger

    if log_path:
        _global_logger = ActionLogger(log_path)
    elif _global_logger is None:
        # Lazily create a default logger on first use.
        _global_logger = ActionLogger("actions.jsonl")

    return _global_logger
|
||||||
|
|
||||||
503
backend/scripts/run_parallel_simulation.py
Normal file
503
backend/scripts/run_parallel_simulation.py
Normal file
|
|
@ -0,0 +1,503 @@
|
||||||
|
"""
|
||||||
|
OASIS 双平台并行模拟预设脚本
|
||||||
|
同时运行Twitter和Reddit模拟,读取相同的配置文件
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
python run_parallel_simulation.py --config simulation_config.json [--action-log actions.jsonl]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from action_logger import ActionLogger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from camel.models import ModelFactory
|
||||||
|
from camel.types import ModelPlatformType
|
||||||
|
import oasis
|
||||||
|
from oasis import (
|
||||||
|
ActionType,
|
||||||
|
LLMAction,
|
||||||
|
ManualAction,
|
||||||
|
generate_twitter_agent_graph,
|
||||||
|
generate_reddit_agent_graph
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"错误: 缺少依赖 {e}")
|
||||||
|
print("请先安装: pip install oasis-ai camel-ai")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# Actions available to Twitter agents each round.
TWITTER_ACTIONS = [
    ActionType.CREATE_POST,
    ActionType.LIKE_POST,
    ActionType.REPOST,
    ActionType.FOLLOW,
    ActionType.DO_NOTHING,
    ActionType.QUOTE_POST,
]

# Actions available to Reddit agents each round.
REDDIT_ACTIONS = [
    ActionType.LIKE_POST,
    ActionType.DISLIKE_POST,
    ActionType.CREATE_POST,
    ActionType.CREATE_COMMENT,
    ActionType.LIKE_COMMENT,
    ActionType.DISLIKE_COMMENT,
    ActionType.SEARCH_POSTS,
    ActionType.SEARCH_USER,
    ActionType.TREND,
    ActionType.REFRESH,
    ActionType.DO_NOTHING,
    ActionType.FOLLOW,
    ActionType.MUTE,
]
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path: str) -> Dict[str, Any]:
    """Read the simulation configuration JSON file and return it as a dict."""
    with open(config_path, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(config: Dict[str, Any]):
    """
    Build the camel-ai LLM model used by the agents.

    OASIS relies on camel-ai's ModelFactory; credentials are taken from
    environment variables:
    - standard OpenAI: set OPENAI_API_KEY only
    - custom endpoint: set OPENAI_API_KEY plus OPENAI_API_BASE_URL
    """
    model_name = config.get("llm_model", "gpt-4o-mini")
    base_url = config.get("llm_base_url", "")

    # Point the OpenAI-compatible client at a custom endpoint when configured.
    if base_url:
        os.environ["OPENAI_API_BASE_URL"] = base_url

    return ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=model_name,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_agents_for_round(
    env,
    config: Dict[str, Any],
    current_hour: int,
    round_num: int
) -> List:
    """
    Decide which agents act this round based on the simulated hour and config.

    Activity scales with time of day (peak / off-peak multipliers), then each
    agent is admitted probabilistically by its own activity_level, and finally
    a random sample of the admitted candidates is drawn.

    Returns:
        List of (agent_id, agent) tuples for the agents that act this round.
    """
    time_config = config.get("time_config", {})
    agent_configs = config.get("agent_configs", [])

    # Baseline number of agents activated per round.
    base_min = time_config.get("agents_per_hour_min", 5)
    base_max = time_config.get("agents_per_hour_max", 20)

    peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
    off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

    # Scale activity up during peak hours and down off-peak.
    if current_hour in peak_hours:
        multiplier = time_config.get("peak_activity_multiplier", 1.5)
    elif current_hour in off_peak_hours:
        multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
    else:
        multiplier = 1.0

    target_count = int(random.uniform(base_min, base_max) * multiplier)

    # First pass: agents awake this hour that pass their activity-level roll.
    candidates = []
    for cfg in agent_configs:
        agent_id = cfg.get("agent_id", 0)
        active_hours = cfg.get("active_hours", list(range(8, 23)))
        activity_level = cfg.get("activity_level", 0.5)

        if current_hour not in active_hours:
            continue

        if random.random() < activity_level:
            candidates.append(agent_id)

    # Second pass: cap the candidates at the round's target count.
    selected_ids = random.sample(
        candidates,
        min(target_count, len(candidates))
    ) if candidates else []

    # Resolve ids to live agent objects; ids unknown to the graph are dropped.
    active_agents = []
    for agent_id in selected_ids:
        try:
            agent = env.agent_graph.get_agent(agent_id)
            active_agents.append((agent_id, agent))
        except Exception:
            pass

    return active_agents
|
||||||
|
|
||||||
|
|
||||||
|
async def run_twitter_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """
    Run the Twitter-side OASIS simulation described by *config*.

    Loads agent profiles from ``twitter_profiles.csv`` in *simulation_dir*,
    seeds the platform with the configured initial posts, then steps the
    environment round by round, letting time-of-day rules pick the active
    agents.

    Args:
        config: Parsed simulation_config.json.
        simulation_dir: Directory holding profiles and the output database.
        action_logger: Optional JSONL logger for monitoring.
    """
    print("[Twitter] 初始化...")

    model = create_model(config)

    # OASIS Twitter expects CSV-format profiles.
    profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
    if not os.path.exists(profile_path):
        print(f"[Twitter] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=TWITTER_ACTIONS,
    )

    # Map agent id -> display name for action logging.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Start from a fresh database for every run.
    db_path = os.path.join(simulation_dir, "twitter_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
    )

    await env.reset()
    print("[Twitter] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start("twitter", config)

    total_actions = 0

    # Seed the platform with the configured initial posts.
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                new_action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # FIX: aggregate multiple posts from the same agent into a
                # list instead of overwriting the previous ManualAction —
                # keeps behavior consistent with run_reddit_simulation.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(new_action)
                else:
                    initial_actions[agent] = new_action

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        platform="twitter",
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                    total_actions += 1
            except Exception:
                # Unknown agent ids in the config are skipped silently.
                pass

        if initial_actions:
            await env.step(initial_actions)
            print(f"[Twitter] 已发布 {len(initial_actions)} 条初始帖子")

    # Main simulation loop: one env.step per simulated time slice.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()

    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )

        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, "twitter")

        # Every active agent takes one LLM-driven action this round.
        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        # Record one logical action per active agent.
        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform="twitter",
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), "twitter")

        # Progress report every 20 rounds.
        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[Twitter] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()

    if action_logger:
        action_logger.log_simulation_end("twitter", total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[Twitter] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_reddit_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """
    Run the Reddit-side OASIS simulation described by *config*.

    Loads agent profiles from ``reddit_profiles.json`` in *simulation_dir*,
    seeds the platform with the configured initial posts, then steps the
    environment round by round, letting time-of-day rules pick the active
    agents.

    Args:
        config: Parsed simulation_config.json.
        simulation_dir: Directory holding profiles and the output database.
        action_logger: Optional JSONL logger for monitoring.
    """
    print("[Reddit] 初始化...")

    model = create_model(config)

    # OASIS Reddit expects JSON-format profiles.
    profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
    if not os.path.exists(profile_path):
        print(f"[Reddit] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_reddit_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=REDDIT_ACTIONS,
    )

    # Map agent id -> display name for action logging.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Start from a fresh database for every run.
    db_path = os.path.join(simulation_dir, "reddit_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.REDDIT,
        database_path=db_path,
    )

    await env.reset()
    print("[Reddit] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start("reddit", config)

    total_actions = 0

    # Seed the platform with the configured initial posts.
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                # Aggregate multiple posts from the same agent into a list so
                # no configured post gets overwritten.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    ))
                else:
                    initial_actions[agent] = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        platform="reddit",
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                    total_actions += 1
            except Exception:
                # Unknown agent ids in the config are skipped silently.
                pass

        if initial_actions:
            await env.step(initial_actions)
            print(f"[Reddit] 已发布 {len(initial_actions)} 条初始帖子")

    # Main simulation loop: one env.step per simulated time slice.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()

    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )

        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, "reddit")

        # Every active agent takes one LLM-driven action this round.
        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        # Record one logical action per active agent.
        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform="reddit",
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), "reddit")

        # Progress report every 20 rounds.
        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[Reddit] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()

    if action_logger:
        action_logger.log_simulation_end("reddit", total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[Reddit] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point for the dual-platform (Twitter + Reddit) parallel simulation.

    Parses arguments, validates the config path, builds a single ActionLogger
    shared by both platform runs, prints a run summary, then runs one platform
    or both concurrently via asyncio.gather.
    """
    parser = argparse.ArgumentParser(description='OASIS双平台并行模拟')
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    parser.add_argument(
        '--twitter-only',
        action='store_true',
        help='只运行Twitter模拟'
    )
    parser.add_argument(
        '--reddit-only',
        action='store_true',
        help='只运行Reddit模拟'
    )
    parser.add_argument(
        '--action-log',
        type=str,
        default='actions.jsonl',
        help='动作日志文件路径 (默认: actions.jsonl)'
    )

    args = parser.parse_args()

    # Fail fast with a non-zero exit code when the config file is missing.
    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    config = load_config(args.config)
    # Outputs (logs, DBs) live next to the config file; "." when path has no dir part.
    simulation_dir = os.path.dirname(args.config) or "."

    # Create the action logger, shared by both platform simulations.
    action_log_path = os.path.join(simulation_dir, args.action_log)
    action_logger = ActionLogger(action_log_path)

    print("=" * 60)
    print("OASIS 双平台并行模拟")
    print(f"配置文件: {args.config}")
    print(f"模拟ID: {config.get('simulation_id', 'unknown')}")
    print(f"动作日志: {action_log_path}")
    print("=" * 60)

    time_config = config.get("time_config", {})
    print(f"\n模拟参数:")
    print(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
    print(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
    print(f" - Agent数量: {len(config.get('agent_configs', []))}")

    # Echo the LLM's config-generation reasoning, truncated to 500 chars.
    reasoning = config.get("generation_reasoning", "")
    if reasoning:
        print(f"\nLLM配置推理:")
        print(f" {reasoning[:500]}..." if len(reasoning) > 500 else f" {reasoning}")

    print("\n" + "=" * 60)

    start_time = datetime.now()

    if args.twitter_only:
        await run_twitter_simulation(config, simulation_dir, action_logger)
    elif args.reddit_only:
        await run_reddit_simulation(config, simulation_dir, action_logger)
    else:
        # Run both platforms concurrently, sharing the same action_logger.
        await asyncio.gather(
            run_twitter_simulation(config, simulation_dir, action_logger),
            run_reddit_simulation(config, simulation_dir, action_logger),
        )

    total_elapsed = (datetime.now() - start_time).total_seconds()
    print("\n" + "=" * 60)
    print(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
    print(f"动作日志已保存到: {action_log_path}")
    print("=" * 60)
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    asyncio.run(main())
|
||||||
|
|
||||||
298
backend/scripts/run_reddit_simulation.py
Normal file
298
backend/scripts/run_reddit_simulation.py
Normal file
|
|
@ -0,0 +1,298 @@
|
||||||
|
"""
|
||||||
|
OASIS Reddit模拟预设脚本
|
||||||
|
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
python run_reddit_simulation.py --config /path/to/simulation_config.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
# 添加项目路径
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from camel.models import ModelFactory
|
||||||
|
from camel.types import ModelPlatformType
|
||||||
|
import oasis
|
||||||
|
from oasis import (
|
||||||
|
ActionType,
|
||||||
|
LLMAction,
|
||||||
|
ManualAction,
|
||||||
|
generate_reddit_agent_graph
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"错误: 缺少依赖 {e}")
|
||||||
|
print("请先安装: pip install oasis-ai camel-ai")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
class RedditSimulationRunner:
    """Runs an automated OASIS Reddit simulation driven by a JSON config file.

    All inputs/outputs (profiles, SQLite DB) are resolved relative to the
    directory that contains the config file.
    """

    # Actions agents may take on the Reddit platform.
    AVAILABLE_ACTIONS = [
        ActionType.LIKE_POST,
        ActionType.DISLIKE_POST,
        ActionType.CREATE_POST,
        ActionType.CREATE_COMMENT,
        ActionType.LIKE_COMMENT,
        ActionType.DISLIKE_COMMENT,
        ActionType.SEARCH_POSTS,
        ActionType.SEARCH_USER,
        ActionType.TREND,
        ActionType.REFRESH,
        ActionType.DO_NOTHING,
        ActionType.FOLLOW,
        ActionType.MUTE,
    ]

    def __init__(self, config_path: str):
        """
        Initialize the runner.

        Args:
            config_path: path to the configuration file (simulation_config.json)
        """
        self.config_path = config_path
        self.config = self._load_config()
        # Profiles and DB live next to the config file.
        self.simulation_dir = os.path.dirname(config_path)

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the JSON configuration file."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Path of the agent profile file (OASIS Reddit expects JSON)."""
        return os.path.join(self.simulation_dir, "reddit_profiles.json")

    def _get_db_path(self) -> str:
        """Path of the SQLite database OASIS writes simulation state into."""
        return os.path.join(self.simulation_dir, "reddit_simulation.db")

    def _create_model(self):
        """
        Create the LLM model via camel-ai's ModelFactory.

        - Standard OpenAI: only the OPENAI_API_KEY env var is required.
        - Custom endpoint: set OPENAI_API_KEY and OPENAI_API_BASE_URL.
        """
        import os

        llm_model = self.config.get("llm_model", "gpt-4o-mini")
        llm_base_url = self.config.get("llm_base_url", "")

        # OASIS reads a custom endpoint from the environment, so export it here.
        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url

        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """
        Decide which agents are active this round based on the simulated hour
        and each agent's configured schedule/activity level.

        Returns a list of (agent_id, agent) tuples.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])

        # Base number of agents activated per round, before the hour multiplier.
        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)

        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

        # Scale activity up during peak hours and down off-peak.
        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0

        target_count = int(random.uniform(base_min, base_max) * multiplier)

        # Each agent is a candidate only when the hour is in its active window,
        # then passes a probability check based on its activity level.
        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)

            if current_hour not in active_hours:
                continue

            if random.random() < activity_level:
                candidates.append(agent_id)

        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []

        # Resolve ids to agent objects; ids missing from the graph are skipped.
        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                pass

        return active_agents

    async def run(self):
        """Run the full Reddit simulation: setup, seed posts, round loop, teardown."""
        print("=" * 60)
        print("OASIS Reddit模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print("=" * 60)

        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)
        # Each round advances the simulated clock by minutes_per_round.
        total_rounds = (total_hours * 60) // minutes_per_round

        print(f"\n模拟参数:")
        print(f" - 总模拟时长: {total_hours}小时")
        print(f" - 每轮时间: {minutes_per_round}分钟")
        print(f" - 总轮数: {total_rounds}")
        print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")

        print("\n初始化LLM模型...")
        model = self._create_model()

        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return

        agent_graph = await generate_reddit_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )

        # Start from a clean database so reruns don't mix state.
        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)
            print(f"已删除旧数据库: {db_path}")

        print("创建OASIS环境...")
        env = oasis.make(
            agent_graph=agent_graph,
            platform=oasis.DefaultPlatformType.REDDIT,
            database_path=db_path,
        )

        await env.reset()
        print("环境初始化完成\n")

        # Seed the platform with configured initial posts (if any).
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])

        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            initial_actions = {}
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = env.agent_graph.get_agent(agent_id)
                    # Multiple posts by one agent are accumulated into a list
                    # so a later post doesn't overwrite an earlier one.
                    if agent in initial_actions:
                        if not isinstance(initial_actions[agent], list):
                            initial_actions[agent] = [initial_actions[agent]]
                        initial_actions[agent].append(ManualAction(
                            action_type=ActionType.CREATE_POST,
                            action_args={"content": content}
                        ))
                    else:
                        initial_actions[agent] = ManualAction(
                            action_type=ActionType.CREATE_POST,
                            action_args={"content": content}
                        )
                except Exception as e:
                    print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")

            if initial_actions:
                await env.step(initial_actions)
                print(f" 已发布 {len(initial_actions)} 条初始帖子")

        # Main simulation loop: one env.step per round with LLM-driven actions.
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            # Derive the simulated wall-clock from the round index.
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            active_agents = self._get_active_agents_for_round(
                env, simulated_hour, round_num
            )

            if not active_agents:
                continue

            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }

            await env.step(actions)

            # Progress line on the first round and every 10th round after.
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        await env.close()

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print(f"\n模拟完成!")
        print(f" - 总耗时: {total_elapsed:.1f}秒")
        print(f" - 数据库: {db_path}")
        print("=" * 60)
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: parse arguments, validate the config path, run the simulation."""
    parser = argparse.ArgumentParser(description='OASIS Reddit模拟')
    parser.add_argument('--config', type=str, required=True,
                        help='配置文件路径 (simulation_config.json)')
    args = parser.parse_args()

    # Abort with a non-zero status when the config file cannot be found.
    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    await RedditSimulationRunner(args.config).run()
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    asyncio.run(main())
|
||||||
|
|
||||||
313
backend/scripts/run_twitter_simulation.py
Normal file
313
backend/scripts/run_twitter_simulation.py
Normal file
|
|
@ -0,0 +1,313 @@
|
||||||
|
"""
|
||||||
|
OASIS Twitter模拟预设脚本
|
||||||
|
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
python run_twitter_simulation.py --config /path/to/simulation_config.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
# 添加项目路径
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from camel.models import ModelFactory
|
||||||
|
from camel.types import ModelPlatformType
|
||||||
|
import oasis
|
||||||
|
from oasis import (
|
||||||
|
ActionType,
|
||||||
|
LLMAction,
|
||||||
|
ManualAction,
|
||||||
|
generate_twitter_agent_graph
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"错误: 缺少依赖 {e}")
|
||||||
|
print("请先安装: pip install oasis-ai camel-ai")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
class TwitterSimulationRunner:
    """Runs an automated OASIS Twitter simulation driven by a JSON config file.

    All inputs/outputs (profiles, SQLite DB) are resolved relative to the
    directory that contains the config file.
    """

    # Actions agents may take on the Twitter platform.
    AVAILABLE_ACTIONS = [
        ActionType.CREATE_POST,
        ActionType.LIKE_POST,
        ActionType.REPOST,
        ActionType.FOLLOW,
        ActionType.DO_NOTHING,
        ActionType.QUOTE_POST,
    ]

    def __init__(self, config_path: str):
        """
        Initialize the runner.

        Args:
            config_path: path to the configuration file (simulation_config.json)
        """
        self.config_path = config_path
        self.config = self._load_config()
        # Profiles and DB live next to the config file.
        self.simulation_dir = os.path.dirname(config_path)

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the JSON configuration file."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Path of the agent profile file (OASIS Twitter expects CSV)."""
        return os.path.join(self.simulation_dir, "twitter_profiles.csv")

    def _get_db_path(self) -> str:
        """Path of the SQLite database OASIS writes simulation state into."""
        return os.path.join(self.simulation_dir, "twitter_simulation.db")

    def _create_model(self):
        """
        Create the LLM model via camel-ai's ModelFactory.

        - Standard OpenAI: only the OPENAI_API_KEY env var is required.
        - Custom endpoint: set OPENAI_API_KEY and OPENAI_API_BASE_URL.

        The config key ``llm_model`` maps to ModelFactory's ``model_type``.
        """
        llm_model = self.config.get("llm_model", "gpt-4o-mini")
        llm_base_url = self.config.get("llm_base_url", "")

        # OASIS reads a custom endpoint from the environment, so export it here.
        # (Uses the module-level `os` import; no local re-import needed.)
        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url

        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """
        Decide which agents are active this round based on the simulated hour
        and each agent's configured schedule/activity level.

        Args:
            env: the OASIS environment
            current_hour: simulated hour of day (0-23)
            round_num: current round index

        Returns:
            List of (agent_id, agent) tuples for the activated agents.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])

        # Base number of agents activated per round, before the hour multiplier.
        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)

        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

        # Scale activity up during peak hours and down off-peak.
        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0

        target_count = int(random.uniform(base_min, base_max) * multiplier)

        # Each agent is a candidate only when the hour is in its active window,
        # then passes a probability check based on its activity level.
        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)

            if current_hour not in active_hours:
                continue

            if random.random() < activity_level:
                candidates.append(agent_id)

        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []

        # Resolve ids to agent objects; ids missing from the graph are skipped.
        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                pass

        return active_agents

    async def run(self):
        """Run the full Twitter simulation: setup, seed posts, round loop, teardown."""
        print("=" * 60)
        print("OASIS Twitter模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print("=" * 60)

        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)

        # Each round advances the simulated clock by minutes_per_round.
        total_rounds = (total_hours * 60) // minutes_per_round

        print(f"\n模拟参数:")
        print(f" - 总模拟时长: {total_hours}小时")
        print(f" - 每轮时间: {minutes_per_round}分钟")
        print(f" - 总轮数: {total_rounds}")
        print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")

        print("\n初始化LLM模型...")
        model = self._create_model()

        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return

        agent_graph = await generate_twitter_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )

        # Start from a clean database so reruns don't mix state.
        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)
            print(f"已删除旧数据库: {db_path}")

        print("创建OASIS环境...")
        env = oasis.make(
            agent_graph=agent_graph,
            platform=oasis.DefaultPlatformType.TWITTER,
            database_path=db_path,
        )

        await env.reset()
        print("环境初始化完成\n")

        # Seed the platform with configured initial posts (if any).
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])

        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            initial_actions = {}
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = env.agent_graph.get_agent(agent_id)
                    action = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )
                    # BUGFIX: previously a second post by the same agent
                    # overwrote the first. Accumulate into a list instead,
                    # matching the Reddit runner's behavior.
                    if agent in initial_actions:
                        if not isinstance(initial_actions[agent], list):
                            initial_actions[agent] = [initial_actions[agent]]
                        initial_actions[agent].append(action)
                    else:
                        initial_actions[agent] = action
                except Exception as e:
                    print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")

            if initial_actions:
                await env.step(initial_actions)
                print(f" 已发布 {len(initial_actions)} 条初始帖子")

        # Main simulation loop: one env.step per round with LLM-driven actions.
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            # Derive the simulated wall-clock from the round index.
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            # Pick this round's active agents.
            active_agents = self._get_active_agents_for_round(
                env, simulated_hour, round_num
            )

            if not active_agents:
                continue

            # Every active agent acts via LLM decision-making.
            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }

            await env.step(actions)

            # Progress line on the first round and every 10th round after.
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        # Release platform resources and flush the database.
        await env.close()

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print(f"\n模拟完成!")
        print(f" - 总耗时: {total_elapsed:.1f}秒")
        print(f" - 数据库: {db_path}")
        print("=" * 60)
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: parse arguments, validate the config path, run the simulation."""
    parser = argparse.ArgumentParser(description='OASIS Twitter模拟')
    parser.add_argument('--config', type=str, required=True,
                        help='配置文件路径 (simulation_config.json)')
    args = parser.parse_args()

    # Abort with a non-zero status when the config file cannot be found.
    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    await TwitterSimulationRunner(args.config).run()
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    asyncio.run(main())
|
||||||
|
|
||||||
166
backend/scripts/test_profile_format.py
Normal file
166
backend/scripts/test_profile_format.py
Normal file
|
|
@ -0,0 +1,166 @@
|
||||||
|
"""
|
||||||
|
测试Profile格式生成是否符合OASIS要求
|
||||||
|
验证:
|
||||||
|
1. Twitter Profile生成CSV格式
|
||||||
|
2. Reddit Profile生成JSON详细格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
# 添加项目路径
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_formats():
    """Build two sample OasisAgentProfile objects, write them via the
    generator's CSV (Twitter) and JSON (Reddit) serializers into a temp
    directory, then read the files back and report missing required fields.
    """
    print("=" * 60)
    print("OASIS Profile格式测试")
    print("=" * 60)

    # Sample profiles: one individual user, one institutional account.
    test_profiles = [
        OasisAgentProfile(
            user_id=0,
            user_name="test_user_123",
            name="Test User",
            bio="A test user for validation",
            persona="Test User is an enthusiastic participant in social discussions.",
            karma=1500,
            friend_count=100,
            follower_count=200,
            statuses_count=500,
            age=25,
            gender="male",
            mbti="INTJ",
            country="China",
            profession="Student",
            interested_topics=["Technology", "Education"],
            source_entity_uuid="test-uuid-123",
            source_entity_type="Student",
        ),
        OasisAgentProfile(
            user_id=1,
            user_name="org_official_456",
            name="Official Organization",
            bio="Official account for Organization",
            persona="This is an official institutional account that communicates official positions.",
            karma=5000,
            friend_count=50,
            follower_count=10000,
            statuses_count=200,
            profession="Organization",
            interested_topics=["Public Policy", "Announcements"],
            source_entity_uuid="test-uuid-456",
            source_entity_type="University",
        ),
    ]

    # __new__ skips __init__ deliberately: only the serializer methods are
    # exercised here, no generator state is needed.
    generator = OasisProfileGenerator.__new__(OasisProfileGenerator)

    # Write into a temp dir that is cleaned up automatically.
    with tempfile.TemporaryDirectory() as temp_dir:
        twitter_path = os.path.join(temp_dir, "twitter_profiles.csv")
        reddit_path = os.path.join(temp_dir, "reddit_profiles.json")

        # --- Twitter CSV format ---
        print("\n1. 测试Twitter Profile (CSV格式)")
        print("-" * 40)
        generator._save_twitter_csv(test_profiles, twitter_path)

        # Read back and inspect the written CSV.
        with open(twitter_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            rows = list(reader)

        print(f" 文件: {twitter_path}")
        print(f" 行数: {len(rows)}")
        print(f" 表头: {list(rows[0].keys())}")
        print(f"\n 示例数据 (第1行):")
        for key, value in rows[0].items():
            print(f" {key}: {value}")

        # Columns OASIS requires in a Twitter profile CSV.
        required_twitter_fields = ['user_id', 'user_name', 'name', 'bio',
            'friend_count', 'follower_count', 'statuses_count', 'created_at']
        missing = set(required_twitter_fields) - set(rows[0].keys())
        if missing:
            print(f"\n [错误] 缺少字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")

        # --- Reddit JSON (detailed) format ---
        print("\n2. 测试Reddit Profile (JSON详细格式)")
        print("-" * 40)
        generator._save_reddit_json(test_profiles, reddit_path)

        # Read back and inspect the written JSON.
        with open(reddit_path, 'r', encoding='utf-8') as f:
            reddit_data = json.load(f)

        print(f" 文件: {reddit_path}")
        print(f" 条目数: {len(reddit_data)}")
        print(f" 字段: {list(reddit_data[0].keys())}")
        print(f"\n 示例数据 (第1条):")
        print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4))

        # Fields OASIS requires (and optionally supports) in a Reddit profile.
        required_reddit_fields = ['realname', 'username', 'bio', 'persona']
        optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics']

        missing = set(required_reddit_fields) - set(reddit_data[0].keys())
        if missing:
            print(f"\n [错误] 缺少必需字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")

        present_optional = set(optional_reddit_fields) & set(reddit_data[0].keys())
        print(f" [信息] 可选字段: {present_optional}")

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)
||||||
|
|
||||||
|
|
||||||
|
def show_expected_formats():
    """Print reference examples of the profile formats OASIS expects:
    a Twitter CSV sample and a detailed Reddit JSON sample.
    """
    header = "=" * 60
    print("\n" + header)
    print("OASIS 期望的Profile格式参考")
    print(header)

    # Twitter profiles: flat CSV with one row per agent.
    print("\n1. Twitter Profile (CSV格式)")
    print("-" * 40)
    csv_sample = """user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at
0,user0,User Zero,I am user zero with interests in technology.,100,150,500,2023-01-01
1,user1,User One,Tech enthusiast and coffee lover.,200,250,1000,2023-01-02"""
    print(csv_sample)

    # Reddit profiles: JSON array of detailed per-agent objects.
    print("\n2. Reddit Profile (JSON详细格式)")
    print("-" * 40)
    json_sample = [
        {
            "realname": "James Miller",
            "username": "millerhospitality",
            "bio": "Passionate about hospitality & tourism.",
            "persona": "James is a seasoned professional in the Hospitality & Tourism industry...",
            "age": 40,
            "gender": "male",
            "mbti": "ESTJ",
            "country": "UK",
            "profession": "Hospitality & Tourism",
            "interested_topics": ["Economics", "Business"]
        }
    ]
    print(json.dumps(json_sample, ensure_ascii=False, indent=2))
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the format validation first, then print the reference formats.
    test_profile_formats()
    show_expected_formats()
|
||||||
|
show_expected_formats()
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in a new issue