Enhance backend functionality with OASIS simulation features

- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings.
- Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms.
- Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management.
- Implemented a robust retry mechanism for external API calls to improve system stability.
- Enhanced task management model to include detailed progress tracking.
- Added logging capabilities for action tracking during simulations.
- Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
666ghj 2025-12-01 15:03:44 +08:00
parent c60e6e1089
commit 5f159f6d88
19 changed files with 7202 additions and 49 deletions

File diff suppressed because it is too large Load diff

View file

@ -46,8 +46,9 @@ def create_app(config_class=Config):
return response
# 注册蓝图
from .api import graph_bp
from .api import graph_bp, simulation_bp
app.register_blueprint(graph_bp, url_prefix='/api/graph')
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
# 健康检查
@app.route('/health')

View file

@ -5,6 +5,8 @@ API路由模块
from flask import Blueprint
graph_bp = Blueprint('graph', __name__)
simulation_bp = Blueprint('simulation', __name__)
from . import graph # noqa: E402, F401
from . import simulation # noqa: E402, F401

File diff suppressed because it is too large Load diff

View file

@ -41,6 +41,20 @@ class Config:
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
# OASIS模拟配置
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
# OASIS平台可用动作配置
OASIS_TWITTER_ACTIONS = [
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
]
OASIS_REDDIT_ACTIONS = [
'LIKE_POST', 'DISLIKE_POST', 'CREATE_POST', 'CREATE_COMMENT',
'LIKE_COMMENT', 'DISLIKE_COMMENT', 'SEARCH_POSTS', 'SEARCH_USER',
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
]
@classmethod
def validate(cls):
"""验证必要配置"""

View file

@ -27,11 +27,12 @@ class Task:
status: TaskStatus
created_at: datetime
updated_at: datetime
progress: int = 0 # 进度百分比 0-100
progress: int = 0 # 进度百分比 0-100
message: str = "" # 状态消息
result: Optional[Dict] = None # 任务结果
error: Optional[str] = None # 错误信息
metadata: Dict = field(default_factory=dict) # 额外元数据
progress_detail: Dict = field(default_factory=dict) # 详细进度信息
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
@ -43,6 +44,7 @@ class Task:
"updated_at": self.updated_at.isoformat(),
"progress": self.progress,
"message": self.message,
"progress_detail": self.progress_detail,
"result": self.result,
"error": self.error,
"metadata": self.metadata,
@ -108,7 +110,8 @@ class TaskManager:
progress: Optional[int] = None,
message: Optional[str] = None,
result: Optional[Dict] = None,
error: Optional[str] = None
error: Optional[str] = None,
progress_detail: Optional[Dict] = None
):
"""
更新任务状态
@ -120,6 +123,7 @@ class TaskManager:
message: 消息
result: 结果
error: 错误信息
progress_detail: 详细进度信息
"""
with self._task_lock:
task = self._tasks.get(task_id)
@ -135,6 +139,8 @@ class TaskManager:
task.result = result
if error is not None:
task.error = error
if progress_detail is not None:
task.progress_detail = progress_detail
def complete_task(self, task_id: str, result: Dict):
"""标记任务完成"""

View file

@ -5,6 +5,47 @@
from .ontology_generator import OntologyGenerator
from .graph_builder import GraphBuilderService
from .text_processor import TextProcessor
from .zep_entity_reader import ZepEntityReader, EntityNode, FilteredEntities
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
from .simulation_manager import SimulationManager, SimulationState, SimulationStatus
from .simulation_config_generator import (
SimulationConfigGenerator,
SimulationParameters,
AgentActivityConfig,
TimeSimulationConfig,
EventConfig,
PlatformConfig
)
from .simulation_runner import (
SimulationRunner,
SimulationRunState,
RunnerStatus,
AgentAction,
RoundSummary
)
__all__ = ['OntologyGenerator', 'GraphBuilderService', 'TextProcessor']
__all__ = [
'OntologyGenerator',
'GraphBuilderService',
'TextProcessor',
'ZepEntityReader',
'EntityNode',
'FilteredEntities',
'OasisProfileGenerator',
'OasisAgentProfile',
'SimulationManager',
'SimulationState',
'SimulationStatus',
'SimulationConfigGenerator',
'SimulationParameters',
'AgentActivityConfig',
'TimeSimulationConfig',
'EventConfig',
'PlatformConfig',
'SimulationRunner',
'SimulationRunState',
'RunnerStatus',
'AgentAction',
'RoundSummary',
]

View file

@ -0,0 +1,561 @@
"""
OASIS Agent Profile生成器
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
"""
import json
import random
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import EntityNode, ZepEntityReader
logger = get_logger('mirofish.oasis_profile')
@dataclass
class OasisAgentProfile:
    """OASIS agent profile data structure.

    Carries the common identity fields plus platform-specific metrics
    (Reddit karma, Twitter friend/follower/status counts) and optional
    persona details derived from a Zep graph entity.
    """
    # Common fields
    user_id: int
    user_name: str
    name: str
    bio: str
    persona: str
    # Optional field - Reddit style
    karma: int = 1000
    # Optional fields - Twitter style
    friend_count: int = 100
    follower_count: int = 150
    statuses_count: int = 500
    # Extra persona details (all optional)
    age: Optional[int] = None
    gender: Optional[str] = None
    mbti: Optional[str] = None
    country: Optional[str] = None
    profession: Optional[str] = None
    interested_topics: List[str] = field(default_factory=list)
    # Source entity provenance
    source_entity_uuid: Optional[str] = None
    source_entity_type: Optional[str] = None
    created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))

    def _persona_extras(self) -> Dict[str, Any]:
        """Collect the optional persona fields that are actually set.

        Shared by to_reddit_format/to_twitter_format so the inclusion
        rules cannot drift apart. Truthiness (not ``is None``) is used
        deliberately: empty strings/lists are omitted, matching the
        original per-platform behavior.
        """
        extras: Dict[str, Any] = {}
        for key in ("age", "gender", "mbti", "country",
                    "profession", "interested_topics"):
            value = getattr(self, key)
            if value:
                extras[key] = value
        return extras

    def to_reddit_format(self) -> Dict[str, Any]:
        """Serialize to the Reddit platform profile format."""
        profile = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "created_at": self.created_at,
        }
        # Append whichever extra persona details are present
        profile.update(self._persona_extras())
        return profile

    def to_twitter_format(self) -> Dict[str, Any]:
        """Serialize to the Twitter platform profile format."""
        profile = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "created_at": self.created_at,
        }
        # Append whichever extra persona details are present
        profile.update(self._persona_extras())
        return profile

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field to a dict, including unset optionals."""
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "age": self.age,
            "gender": self.gender,
            "mbti": self.mbti,
            "country": self.country,
            "profession": self.profession,
            "interested_topics": self.interested_topics,
            "source_entity_uuid": self.source_entity_uuid,
            "source_entity_type": self.source_entity_type,
            "created_at": self.created_at,
        }
class OasisProfileGenerator:
    """
    OASIS profile generator.

    Converts entities read from a Zep knowledge graph into the agent
    profiles required by the OASIS simulation platform. Profiles can be
    produced either via the configured LLM (rich personas) or via simple
    type-based rules, which also serve as the fallback when the LLM call
    fails.
    """

    # MBTI personality types, sampled for rule-based personas
    MBTI_TYPES = [
        "INTJ", "INTP", "ENTJ", "ENTP",
        "INFJ", "INFP", "ENFJ", "ENFP",
        "ISTJ", "ISFJ", "ESTJ", "ESFJ",
        "ISTP", "ISFP", "ESTP", "ESFP"
    ]

    # Common countries, sampled for rule-based personas
    COUNTRIES = [
        "China", "US", "UK", "Japan", "Germany", "France",
        "Canada", "Australia", "Brazil", "India", "South Korea"
    ]

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: Optional[str] = None
    ):
        """Initialize the generator; unset arguments fall back to Config.

        Raises:
            ValueError: if no LLM API key is available.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model_name = model_name or Config.LLM_MODEL_NAME
        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    def generate_profile_from_entity(
        self,
        entity: EntityNode,
        user_id: int,
        use_llm: bool = True
    ) -> OasisAgentProfile:
        """
        Generate an OASIS agent profile from a Zep entity.

        Args:
            entity: Zep entity node.
            user_id: User id assigned for OASIS.
            use_llm: Whether to use the LLM to generate a detailed persona.

        Returns:
            OasisAgentProfile
        """
        entity_type = entity.get_entity_type() or "Entity"
        # Base identity
        name = entity.name
        user_name = self._generate_username(name)
        # Context assembled from the entity's graph neighborhood
        context = self._build_entity_context(entity)
        if use_llm:
            # Rich persona via the LLM
            profile_data = self._generate_profile_with_llm(
                entity_name=name,
                entity_type=entity_type,
                entity_summary=entity.summary,
                entity_attributes=entity.attributes,
                context=context
            )
        else:
            # Basic persona via type-based rules
            profile_data = self._generate_profile_rule_based(
                entity_name=name,
                entity_type=entity_type,
                entity_summary=entity.summary,
                entity_attributes=entity.attributes
            )
        # Missing fields fall back to sensible randomized defaults
        return OasisAgentProfile(
            user_id=user_id,
            user_name=user_name,
            name=name,
            bio=profile_data.get("bio", f"{entity_type}: {name}"),
            persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."),
            karma=profile_data.get("karma", random.randint(500, 5000)),
            friend_count=profile_data.get("friend_count", random.randint(50, 500)),
            follower_count=profile_data.get("follower_count", random.randint(100, 1000)),
            statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)),
            age=profile_data.get("age"),
            gender=profile_data.get("gender"),
            mbti=profile_data.get("mbti"),
            country=profile_data.get("country"),
            profession=profile_data.get("profession"),
            interested_topics=profile_data.get("interested_topics", []),
            source_entity_uuid=entity.uuid,
            source_entity_type=entity_type,
        )

    def _generate_username(self, name: str) -> str:
        """Derive a social-media username from an entity name."""
        # Lowercase, spaces to underscores, drop remaining special chars
        username = name.lower().replace(" ", "_")
        username = ''.join(c for c in username if c.isalnum() or c == '_')
        # Random numeric suffix to reduce the chance of duplicates
        suffix = random.randint(100, 999)
        return f"{username}_{suffix}"

    def _build_entity_context(self, entity: EntityNode) -> str:
        """Assemble context text (edge facts + neighbor names) for an entity."""
        context_parts = []
        # Facts carried on related edges
        if entity.related_edges:
            relationships = []
            for edge in entity.related_edges[:10]:  # cap at 10 edges
                if edge.get("fact"):
                    relationships.append(edge["fact"])
            if relationships:
                context_parts.append("Related facts:\n" + "\n".join(f"- {r}" for r in relationships))
        # Names of neighboring nodes
        if entity.related_nodes:
            related_names = [n["name"] for n in entity.related_nodes[:5]]
            if related_names:
                context_parts.append(f"Related to: {', '.join(related_names)}")
        return "\n\n".join(context_parts)

    def _generate_profile_with_llm(
        self,
        entity_name: str,
        entity_type: str,
        entity_summary: str,
        entity_attributes: Dict[str, Any],
        context: str
    ) -> Dict[str, Any]:
        """Generate a detailed persona via the LLM.

        Falls back to rule-based generation if the (retried) call fails.
        """
        prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes.
Entity Information:
- Name: {entity_name}
- Type: {entity_type}
- Summary: {entity_summary}
- Attributes: {json.dumps(entity_attributes, ensure_ascii=False)}
Context:
{context}
Generate a JSON object with the following fields:
{{
"bio": "A short bio (max 150 chars) suitable for social media",
"persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns",
"age": <integer between 18-65, or null if not applicable>,
"gender": "<male/female/other, or null if not applicable>",
"mbti": "<MBTI type like INTJ, ENFP, etc., or null>",
"country": "<country name, or null>",
"profession": "<profession/occupation, or null>",
"interested_topics": ["topic1", "topic2", ...]
}}
Important:
- The profile should be consistent with the entity type and context
- Make the persona feel realistic and suitable for social media simulation
- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account)
- Return ONLY the JSON object, no additional text"""
        try:
            # Call the LLM through the shared retry helper for resilience
            from ..utils.retry import RetryableAPIClient
            retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0)

            def call_llm():
                return self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {"role": "system", "content": "You are a profile generator for social media simulation. Generate realistic user profiles based on entity information."},
                        {"role": "user", "content": prompt}
                    ],
                    response_format={"type": "json_object"},
                    temperature=0.7
                )

            response = retry_client.call_with_retry(call_llm)
            result = json.loads(response.choices[0].message.content)
            return result
        except Exception as e:
            # Retries exhausted — degrade gracefully to rule-based persona
            logger.warning(f"LLM生成人设失败已重试: {str(e)}, 使用规则生成")
            return self._generate_profile_rule_based(
                entity_name, entity_type, entity_summary, entity_attributes
            )

    def _generate_profile_rule_based(
        self,
        entity_name: str,
        entity_type: str,
        entity_summary: str,
        entity_attributes: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Generate a basic persona from type-based rules (no LLM)."""
        # Branch on the entity type to pick an appropriate persona template
        entity_type_lower = entity_type.lower()
        if entity_type_lower in ["student", "alumni"]:
            return {
                "bio": f"{entity_type} with interests in academics and social issues.",
                "persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
                "age": random.randint(18, 30),
                "gender": random.choice(["male", "female"]),
                "mbti": random.choice(self.MBTI_TYPES),
                "country": random.choice(self.COUNTRIES),
                "profession": "Student",
                "interested_topics": ["Education", "Social Issues", "Technology"],
            }
        elif entity_type_lower in ["publicfigure", "expert", "faculty"]:
            return {
                "bio": f"Expert and thought leader in their field.",
                "persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
                "age": random.randint(35, 60),
                "gender": random.choice(["male", "female"]),
                "mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
                "country": random.choice(self.COUNTRIES),
                "profession": entity_attributes.get("occupation", "Expert"),
                "interested_topics": ["Politics", "Economics", "Culture & Society"],
            }
        elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]:
            return {
                "bio": f"Official account for {entity_name}. News and updates.",
                "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
                "profession": "Media",
                "interested_topics": ["General News", "Current Events", "Public Affairs"],
            }
        elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]:
            return {
                "bio": f"Official account of {entity_name}.",
                "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
                "profession": entity_type,
                "interested_topics": ["Public Policy", "Community", "Official Announcements"],
            }
        else:
            # Default persona for any unrecognized entity type
            return {
                "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
                "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.",
                "age": random.randint(25, 50),
                "gender": random.choice(["male", "female"]),
                "mbti": random.choice(self.MBTI_TYPES),
                "country": random.choice(self.COUNTRIES),
                "profession": entity_type,
                "interested_topics": ["General", "Social Issues"],
            }

    def generate_profiles_from_entities(
        self,
        entities: List[EntityNode],
        use_llm: bool = True,
        progress_callback: Optional[callable] = None
    ) -> List[OasisAgentProfile]:
        """
        Batch-generate agent profiles from a list of entities.

        A per-entity failure does not abort the batch: a minimal fallback
        profile is emitted instead, so the output always has one profile
        per input entity (user_id == input index).

        Args:
            entities: Entity list.
            use_llm: Whether to use the LLM to generate detailed personas.
            progress_callback: Progress callback (current, total, message).

        Returns:
            List of agent profiles.
        """
        profiles = []
        total = len(entities)
        for idx, entity in enumerate(entities):
            if progress_callback:
                progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...")
            try:
                profile = self.generate_profile_from_entity(
                    entity=entity,
                    user_id=idx,
                    use_llm=use_llm
                )
                profiles.append(profile)
            except Exception as e:
                logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
                # Emit a minimal fallback profile so the batch stays complete
                profiles.append(OasisAgentProfile(
                    user_id=idx,
                    user_name=self._generate_username(entity.name),
                    name=entity.name,
                    bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}",
                    persona=entity.summary or f"A participant in social discussions.",
                    source_entity_uuid=entity.uuid,
                    source_entity_type=entity.get_entity_type(),
                ))
        return profiles

    def save_profiles(
        self,
        profiles: List[OasisAgentProfile],
        file_path: str,
        platform: str = "reddit"
    ):
        """
        Save profiles to a file, choosing the correct format per platform.

        OASIS platform format requirements:
        - Twitter: CSV
        - Reddit: JSON

        Args:
            profiles: Profile list.
            file_path: Output file path.
            platform: Platform type ("reddit" or "twitter").
        """
        if platform == "twitter":
            self._save_twitter_csv(profiles, file_path)
        else:
            self._save_reddit_json(profiles, file_path)

    def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str):
        """
        Save Twitter profiles as CSV.

        CSV columns required by OASIS Twitter:
        user_id, user_name, name, bio, friend_count, follower_count, statuses_count, created_at
        """
        import csv
        # Make sure the file extension is .csv
        if not file_path.endswith('.csv'):
            file_path = file_path.replace('.json', '.csv')
        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            # Header row
            headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count',
                       'follower_count', 'statuses_count', 'created_at']
            writer.writerow(headers)
            # Data rows
            for profile in profiles:
                # Bio must not contain raw newlines (CSV quoting handles commas)
                bio = profile.bio.replace('\n', ' ').replace('\r', ' ')
                row = [
                    profile.user_id,
                    profile.user_name,
                    profile.name,
                    bio,
                    profile.friend_count,
                    profile.follower_count,
                    profile.statuses_count,
                    profile.created_at
                ]
                writer.writerow(row)
        logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)")

    def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
        """
        Save Reddit profiles in JSON format.

        OASIS Reddit supports two JSON layouts:
        1. Basic: user_id, user_name, name, bio, karma, created_at
        2. Detailed: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics

        The detailed layout is used here, to stay consistent with the
        user's sample data file (36个简单人设.json).
        """
        data = []
        for profile in profiles:
            # Detailed layout (compatible with the user's sample data)
            item = {
                "realname": profile.name,
                "username": profile.user_name,
                "bio": profile.bio[:150] if profile.bio else "",  # OASIS caps bio at 150 chars
                "persona": profile.persona or f"{profile.name} is a participant in social discussions.",
            }
            # Optional persona detail fields (only included when set)
            if profile.age:
                item["age"] = profile.age
            if profile.gender:
                item["gender"] = profile.gender
            if profile.mbti:
                item["mbti"] = profile.mbti
            if profile.country:
                item["country"] = profile.country
            if profile.profession:
                item["profession"] = profile.profession
            if profile.interested_topics:
                item["interested_topics"] = profile.interested_topics
            data.append(item)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")

    # Old method name kept as an alias for backward compatibility
    def save_profiles_to_json(
        self,
        profiles: List[OasisAgentProfile],
        file_path: str,
        platform: str = "reddit"
    ):
        """[Deprecated] Use save_profiles() instead."""
        logger.warning("save_profiles_to_json已废弃请使用save_profiles方法")
        self.save_profiles(profiles, file_path, platform)

View file

@ -0,0 +1,584 @@
"""
模拟配置智能生成器
使用LLM根据模拟需求文档内容图谱信息自动生成细致的模拟参数
实现全程自动化无需人工设置参数
"""
import json
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
from openai import OpenAI
from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import EntityNode, ZepEntityReader
logger = get_logger('mirofish.simulation_config')
@dataclass
class AgentActivityConfig:
    """Activity configuration for a single agent."""
    agent_id: int
    entity_uuid: str
    entity_name: str
    entity_type: str
    # Overall activity level (0.0-1.0)
    activity_level: float = 0.5
    # Expected speaking frequency (actions per simulated hour)
    posts_per_hour: float = 1.0
    comments_per_hour: float = 2.0
    # Active hours, 24-hour clock (0-23)
    active_hours: List[int] = field(default_factory=lambda: list(range(8, 23)))
    # Reaction delay to trending events, in simulated minutes
    response_delay_min: int = 5
    response_delay_max: int = 60
    # Sentiment bias (-1.0 negative to 1.0 positive)
    sentiment_bias: float = 0.0
    # Stance toward the topic: supportive, opposing, neutral, observer
    stance: str = "neutral"
    # Influence weight: how likely this agent's posts are seen by others
    influence_weight: float = 1.0
@dataclass
class TimeSimulationConfig:
    """Time-related simulation configuration."""
    # Total simulated duration in hours (default 72h = 3 days)
    total_simulation_hours: int = 72
    # Simulated minutes represented by each round
    minutes_per_round: int = 30
    # Range of agents activated per simulated hour
    agents_per_hour_min: int = 5
    agents_per_hour_max: int = 20
    # Peak hours (activity boosted by the multiplier below)
    peak_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 14, 15, 20, 21, 22])
    peak_activity_multiplier: float = 1.5
    # Off-peak hours (activity reduced by the multiplier below)
    off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6])
    off_peak_activity_multiplier: float = 0.3
@dataclass
class EventConfig:
    """Event configuration for the simulation."""
    # Seed posts published when the simulation starts
    initial_posts: List[Dict[str, Any]] = field(default_factory=list)
    # Events triggered at specific simulated times
    scheduled_events: List[Dict[str, Any]] = field(default_factory=list)
    # Trending-topic keywords
    hot_topics: List[str] = field(default_factory=list)
    # Intended direction of the public-opinion narrative
    narrative_direction: str = ""
@dataclass
class PlatformConfig:
    """Platform-specific configuration."""
    platform: str  # "twitter" or "reddit"
    # Recommendation algorithm weights
    recency_weight: float = 0.4  # freshness of content
    popularity_weight: float = 0.3  # engagement/heat
    relevance_weight: float = 0.3  # topical relevance
    # Number of interactions after which a post spreads virally
    viral_threshold: int = 10
    # Echo-chamber strength: how strongly similar opinions cluster (0-1)
    echo_chamber_strength: float = 0.5
@dataclass
class SimulationParameters:
    """Complete set of parameters driving one simulation run."""
    # Basic identifiers
    simulation_id: str
    project_id: str
    graph_id: str
    simulation_requirement: str
    # Time configuration
    time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig)
    # Per-agent activity configurations
    agent_configs: List[AgentActivityConfig] = field(default_factory=list)
    # Event configuration
    event_config: EventConfig = field(default_factory=EventConfig)
    # Platform configurations (None when the platform is disabled)
    twitter_config: Optional[PlatformConfig] = None
    reddit_config: Optional[PlatformConfig] = None
    # LLM configuration
    llm_model: str = ""
    llm_base_url: str = ""
    # Generation metadata
    generated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    generation_reasoning: str = ""  # reasoning notes emitted by the LLM

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all parameters to a plain dict.

        ``asdict`` recurses into the nested dataclasses (and lists of
        dataclasses) and preserves field declaration order, producing the
        same structure as building the mapping field by field; ``None``
        platform configs stay ``None``.
        """
        return asdict(self)

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string (non-ASCII preserved)."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
class SimulationConfigGenerator:
"""
模拟配置智能生成器
使用LLM分析模拟需求文档内容图谱实体信息
自动生成最佳的模拟参数配置
"""
# 上下文最大字符数
MAX_CONTEXT_LENGTH = 50000
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model_name: Optional[str] = None
):
self.api_key = api_key or Config.LLM_API_KEY
self.base_url = base_url or Config.LLM_BASE_URL
self.model_name = model_name or Config.LLM_MODEL_NAME
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
def generate_config(
self,
simulation_id: str,
project_id: str,
graph_id: str,
simulation_requirement: str,
document_text: str,
entities: List[EntityNode],
enable_twitter: bool = True,
enable_reddit: bool = True,
) -> SimulationParameters:
"""
智能生成完整的模拟配置
Args:
simulation_id: 模拟ID
project_id: 项目ID
graph_id: 图谱ID
simulation_requirement: 模拟需求描述
document_text: 原始文档内容
entities: 过滤后的实体列表
enable_twitter: 是否启用Twitter
enable_reddit: 是否启用Reddit
Returns:
SimulationParameters: 完整的模拟参数
"""
logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}")
# 1. 构建上下文信息截断到50000字符
context = self._build_context(
simulation_requirement=simulation_requirement,
document_text=document_text,
entities=entities
)
# 2. 调用LLM生成配置
llm_result = self._generate_config_with_llm(
context=context,
entities=entities,
enable_twitter=enable_twitter,
enable_reddit=enable_reddit
)
# 3. 构建SimulationParameters对象
params = self._build_parameters(
simulation_id=simulation_id,
project_id=project_id,
graph_id=graph_id,
simulation_requirement=simulation_requirement,
entities=entities,
llm_result=llm_result,
enable_twitter=enable_twitter,
enable_reddit=enable_reddit
)
logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置")
return params
def _build_context(
self,
simulation_requirement: str,
document_text: str,
entities: List[EntityNode]
) -> str:
"""构建LLM上下文截断到最大长度"""
# 实体摘要
entity_summary = self._summarize_entities(entities)
# 构建上下文
context_parts = [
f"## 模拟需求\n{simulation_requirement}",
f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}",
]
current_length = sum(len(p) for p in context_parts)
remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量
if remaining_length > 0 and document_text:
doc_text = document_text[:remaining_length]
if len(document_text) > remaining_length:
doc_text += "\n...(文档已截断)"
context_parts.append(f"\n## 原始文档内容\n{doc_text}")
return "\n".join(context_parts)
def _summarize_entities(self, entities: List[EntityNode]) -> str:
"""生成实体摘要"""
lines = []
# 按类型分组
by_type: Dict[str, List[EntityNode]] = {}
for e in entities:
t = e.get_entity_type() or "Unknown"
if t not in by_type:
by_type[t] = []
by_type[t].append(e)
for entity_type, type_entities in by_type.items():
lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
for e in type_entities[:10]: # 每类最多显示10个
summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary
lines.append(f"- {e.name}: {summary_preview}")
if len(type_entities) > 10:
lines.append(f" ... 还有 {len(type_entities) - 10}")
return "\n".join(lines)
def _generate_config_with_llm(
self,
context: str,
entities: List[EntityNode],
enable_twitter: bool,
enable_reddit: bool
) -> Dict[str, Any]:
"""调用LLM生成配置"""
# 构建实体列表用于Agent配置
entity_list = []
for i, e in enumerate(entities):
entity_list.append({
"agent_id": i,
"entity_uuid": e.uuid,
"entity_name": e.name,
"entity_type": e.get_entity_type() or "Unknown",
"summary": e.summary[:200] if e.summary else ""
})
prompt = f"""你是一个社交媒体舆论模拟专家。请根据以下信息,生成详细的模拟参数配置。
{context}
## 实体列表(需要为每个实体生成活动配置)
```json
{json.dumps(entity_list, ensure_ascii=False, indent=2)}
```
## 任务
请生成一个JSON配置包含以下部分
1. **time_config** - 时间模拟配置
- total_simulation_hours: 模拟总时长小时根据事件性质决定短期热点24-72小时长期舆论168-336小时
- minutes_per_round: 每轮代表的时间分钟建议15-60
- agents_per_hour_min/max: 每小时激活的Agent数量范围
- peak_hours: 高峰时段列表0-23
- off_peak_hours: 低谷时段列表
2. **agent_configs** - 每个Agent的活动配置必须为每个实体生成
对于每个agent_id设置
- activity_level: 活跃度(0.0-1.0)官方机构通常0.1-0.3媒体0.3-0.5个人0.5-0.9
- posts_per_hour: 每小时发帖频率官方机构0.05-0.2媒体0.5-2个人0.1-1
- comments_per_hour: 每小时评论频率
- active_hours: 活跃时间段列表官方通常工作时间个人更分散
- response_delay_min/max: 响应延迟模拟分钟官方较慢(30-180)个人较快(1-30)
- sentiment_bias: 情感倾向(-1到1)根据实体立场设置
- stance: 立场(supportive/opposing/neutral/observer)
- influence_weight: 影响力权重知名人物和媒体较高
3. **event_config** - 事件配置
- initial_posts: 初始帖子列表包含content和poster_agent_id
- hot_topics: 热点话题关键词列表
- narrative_direction: 舆论发展方向描述
4. **platform_configs** - 平台配置如果启用
- viral_threshold: 病毒传播阈值
- echo_chamber_strength: 回声室效应强度(0-1)
5. **reasoning** - 你的推理说明解释为什么这样设置参数
## 重要原则
- 官方机构UniversityGovernmentAgency发言频率低但影响力大
- 媒体MediaOutlet发言频率中等传播速度快
- 个人StudentPublicFigure发言频率高但影响力分散
- 根据模拟需求判断各实体的立场和情感倾向
- 时间配置要符合真实社交媒体的使用规律
请返回JSON格式不要包含markdown代码块标记"""
try:
# 使用重试机制调用LLM API
from ..utils.retry import RetryableAPIClient
retry_client = RetryableAPIClient(max_retries=3, initial_delay=2.0, max_delay=60.0)
def call_llm():
return self.client.chat.completions.create(
model=self.model_name,
messages=[
{
"role": "system",
"content": "你是社交媒体舆论模拟专家擅长设计真实的模拟参数。返回纯JSON格式不要markdown。"
},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7,
max_tokens=8000
)
response = retry_client.call_with_retry(call_llm)
result = json.loads(response.choices[0].message.content)
logger.info(f"LLM配置生成成功")
return result
except Exception as e:
logger.error(f"LLM配置生成失败已重试: {str(e)}")
# 返回默认配置
return self._generate_default_config(entities)
def _generate_default_config(self, entities: List[EntityNode]) -> Dict[str, Any]:
"""生成默认配置LLM失败时的fallback"""
agent_configs = []
for i, e in enumerate(entities):
entity_type = (e.get_entity_type() or "Unknown").lower()
# 根据实体类型设置默认参数
if entity_type in ["university", "governmentagency", "ngo"]:
config = {
"agent_id": i,
"activity_level": 0.2,
"posts_per_hour": 0.1,
"comments_per_hour": 0.05,
"active_hours": list(range(9, 18)),
"response_delay_min": 60,
"response_delay_max": 240,
"sentiment_bias": 0.0,
"stance": "neutral",
"influence_weight": 3.0
}
elif entity_type in ["mediaoutlet"]:
config = {
"agent_id": i,
"activity_level": 0.6,
"posts_per_hour": 1.0,
"comments_per_hour": 0.5,
"active_hours": list(range(6, 24)),
"response_delay_min": 5,
"response_delay_max": 30,
"sentiment_bias": 0.0,
"stance": "observer",
"influence_weight": 2.5
}
elif entity_type in ["publicfigure", "expert"]:
config = {
"agent_id": i,
"activity_level": 0.5,
"posts_per_hour": 0.3,
"comments_per_hour": 0.5,
"active_hours": list(range(8, 23)),
"response_delay_min": 10,
"response_delay_max": 60,
"sentiment_bias": 0.0,
"stance": "neutral",
"influence_weight": 2.0
}
else: # Student, Person, etc.
config = {
"agent_id": i,
"activity_level": 0.7,
"posts_per_hour": 0.5,
"comments_per_hour": 1.0,
"active_hours": list(range(7, 24)),
"response_delay_min": 1,
"response_delay_max": 20,
"sentiment_bias": 0.0,
"stance": "neutral",
"influence_weight": 1.0
}
agent_configs.append(config)
return {
"time_config": {
"total_simulation_hours": 72,
"minutes_per_round": 30,
"agents_per_hour_min": max(1, len(entities) // 10),
"agents_per_hour_max": max(5, len(entities) // 3),
"peak_hours": [9, 10, 11, 14, 15, 20, 21, 22],
"off_peak_hours": [0, 1, 2, 3, 4, 5]
},
"agent_configs": agent_configs,
"event_config": {
"initial_posts": [],
"hot_topics": [],
"narrative_direction": ""
},
"reasoning": "使用默认配置LLM生成失败"
}
def _build_parameters(
self,
simulation_id: str,
project_id: str,
graph_id: str,
simulation_requirement: str,
entities: List[EntityNode],
llm_result: Dict[str, Any],
enable_twitter: bool,
enable_reddit: bool
) -> SimulationParameters:
"""根据LLM结果构建SimulationParameters对象"""
# 时间配置
time_cfg = llm_result.get("time_config", {})
time_config = TimeSimulationConfig(
total_simulation_hours=time_cfg.get("total_simulation_hours", 72),
minutes_per_round=time_cfg.get("minutes_per_round", 30),
agents_per_hour_min=time_cfg.get("agents_per_hour_min", 5),
agents_per_hour_max=time_cfg.get("agents_per_hour_max", 20),
peak_hours=time_cfg.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]),
off_peak_hours=time_cfg.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
peak_activity_multiplier=time_cfg.get("peak_activity_multiplier", 1.5),
off_peak_activity_multiplier=time_cfg.get("off_peak_activity_multiplier", 0.3)
)
# Agent配置
agent_configs = []
llm_agent_configs = {cfg["agent_id"]: cfg for cfg in llm_result.get("agent_configs", [])}
for i, entity in enumerate(entities):
cfg = llm_agent_configs.get(i, {})
agent_config = AgentActivityConfig(
agent_id=i,
entity_uuid=entity.uuid,
entity_name=entity.name,
entity_type=entity.get_entity_type() or "Unknown",
activity_level=cfg.get("activity_level", 0.5),
posts_per_hour=cfg.get("posts_per_hour", 0.5),
comments_per_hour=cfg.get("comments_per_hour", 1.0),
active_hours=cfg.get("active_hours", list(range(8, 23))),
response_delay_min=cfg.get("response_delay_min", 5),
response_delay_max=cfg.get("response_delay_max", 60),
sentiment_bias=cfg.get("sentiment_bias", 0.0),
stance=cfg.get("stance", "neutral"),
influence_weight=cfg.get("influence_weight", 1.0)
)
agent_configs.append(agent_config)
# 事件配置
event_cfg = llm_result.get("event_config", {})
event_config = EventConfig(
initial_posts=event_cfg.get("initial_posts", []),
scheduled_events=event_cfg.get("scheduled_events", []),
hot_topics=event_cfg.get("hot_topics", []),
narrative_direction=event_cfg.get("narrative_direction", "")
)
# 平台配置
twitter_config = None
reddit_config = None
platform_cfgs = llm_result.get("platform_configs", {})
if enable_twitter:
tw_cfg = platform_cfgs.get("twitter", {})
twitter_config = PlatformConfig(
platform="twitter",
recency_weight=tw_cfg.get("recency_weight", 0.4),
popularity_weight=tw_cfg.get("popularity_weight", 0.3),
relevance_weight=tw_cfg.get("relevance_weight", 0.3),
viral_threshold=tw_cfg.get("viral_threshold", 10),
echo_chamber_strength=tw_cfg.get("echo_chamber_strength", 0.5)
)
if enable_reddit:
rd_cfg = platform_cfgs.get("reddit", {})
reddit_config = PlatformConfig(
platform="reddit",
recency_weight=rd_cfg.get("recency_weight", 0.3),
popularity_weight=rd_cfg.get("popularity_weight", 0.4),
relevance_weight=rd_cfg.get("relevance_weight", 0.3),
viral_threshold=rd_cfg.get("viral_threshold", 15),
echo_chamber_strength=rd_cfg.get("echo_chamber_strength", 0.6)
)
return SimulationParameters(
simulation_id=simulation_id,
project_id=project_id,
graph_id=graph_id,
simulation_requirement=simulation_requirement,
time_config=time_config,
agent_configs=agent_configs,
event_config=event_config,
twitter_config=twitter_config,
reddit_config=reddit_config,
llm_model=self.model_name,
llm_base_url=self.base_url,
generation_reasoning=llm_result.get("reasoning", "")
)

View file

@ -0,0 +1,546 @@
"""
OASIS模拟管理器
管理Twitter和Reddit双平台并行模拟
使用预设脚本 + LLM智能生成配置参数
"""
import os
import json
import shutil
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import ZepEntityReader, FilteredEntities
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
logger = get_logger('mirofish.simulation')
class SimulationStatus(str, Enum):
    """Lifecycle states of a simulation (str-valued so it JSON-serializes)."""
    CREATED = "created"        # record exists, nothing prepared yet
    PREPARING = "preparing"    # entity/profile/config generation in progress
    READY = "ready"            # all files generated, simulation can be started
    RUNNING = "running"        # simulation is executing
    PAUSED = "paused"          # temporarily halted
    COMPLETED = "completed"    # finished
    FAILED = "failed"          # aborted; details in the state's error field
class PlatformType(str, Enum):
    """Supported simulation platforms (str-valued so it JSON-serializes)."""
    TWITTER = "twitter"
    REDDIT = "reddit"
@dataclass
class SimulationState:
    """Persisted preparation/lifecycle state of one simulation."""
    simulation_id: str
    project_id: str
    graph_id: str
    # Which platforms this simulation targets
    enable_twitter: bool = True
    enable_reddit: bool = True
    # Lifecycle status
    status: SimulationStatus = SimulationStatus.CREATED
    # Results of the preparation phase
    entities_count: int = 0
    profiles_count: int = 0
    entity_types: List[str] = field(default_factory=list)
    # LLM config-generation results
    config_generated: bool = False
    config_reasoning: str = ""
    # Runtime progress
    current_round: int = 0
    twitter_status: str = "not_started"
    reddit_status: str = "not_started"
    # Timestamps
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    # Last error message, if any
    error: Optional[str] = None
    def to_dict(self) -> Dict[str, Any]:
        """Full state dict (internal persistence format)."""
        keys = ("simulation_id", "project_id", "graph_id", "enable_twitter",
                "enable_reddit", "status", "entities_count", "profiles_count",
                "entity_types", "config_generated", "config_reasoning",
                "current_round", "twitter_status", "reddit_status",
                "created_at", "updated_at", "error")
        # status is stored as its plain string value, not the enum member.
        return {key: (self.status.value if key == "status" else getattr(self, key))
                for key in keys}
    def to_simple_dict(self) -> Dict[str, Any]:
        """Trimmed state dict returned by the API."""
        keys = ("simulation_id", "project_id", "graph_id", "status",
                "entities_count", "profiles_count", "entity_types",
                "config_generated", "error")
        return {key: (self.status.value if key == "status" else getattr(self, key))
                for key in keys}
class SimulationManager:
    """
    Simulation manager.

    Core responsibilities:
    1. Read and filter entities from the Zep graph.
    2. Generate OASIS agent profiles.
    3. Use the LLM to generate simulation configuration parameters.
    4. Prepare all files required by the preset run scripts.
    """
    # Root directory for per-simulation data
    SIMULATION_DATA_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )
    # Directory holding the preset run scripts
    SCRIPTS_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../scripts'
    )
    def __init__(self):
        # Make sure the data root exists
        os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
        # In-memory cache of simulation states, keyed by simulation_id
        self._simulations: Dict[str, SimulationState] = {}
    def _get_simulation_dir(self, simulation_id: str) -> str:
        """Return the data directory for a simulation, creating it if absent.

        NOTE(review): the creation side effect means even read-only lookups
        (e.g. get_simulation on an unknown id) create an empty directory.
        """
        sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
        os.makedirs(sim_dir, exist_ok=True)
        return sim_dir
    def _save_simulation_state(self, state: SimulationState):
        """Persist a state to state.json (bumping updated_at) and cache it."""
        sim_dir = self._get_simulation_dir(state.simulation_id)
        state_file = os.path.join(sim_dir, "state.json")
        state.updated_at = datetime.now().isoformat()
        with open(state_file, 'w', encoding='utf-8') as f:
            json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
        self._simulations[state.simulation_id] = state
    def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
        """Load a simulation state from file, preferring the in-memory cache."""
        if simulation_id in self._simulations:
            return self._simulations[simulation_id]
        sim_dir = self._get_simulation_dir(simulation_id)
        state_file = os.path.join(sim_dir, "state.json")
        if not os.path.exists(state_file):
            return None
        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        state = SimulationState(
            simulation_id=simulation_id,
            project_id=data.get("project_id", ""),
            graph_id=data.get("graph_id", ""),
            enable_twitter=data.get("enable_twitter", True),
            enable_reddit=data.get("enable_reddit", True),
            status=SimulationStatus(data.get("status", "created")),
            entities_count=data.get("entities_count", 0),
            profiles_count=data.get("profiles_count", 0),
            entity_types=data.get("entity_types", []),
            config_generated=data.get("config_generated", False),
            config_reasoning=data.get("config_reasoning", ""),
            current_round=data.get("current_round", 0),
            twitter_status=data.get("twitter_status", "not_started"),
            reddit_status=data.get("reddit_status", "not_started"),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            error=data.get("error"),
        )
        self._simulations[simulation_id] = state
        return state
    def create_simulation(
        self,
        project_id: str,
        graph_id: str,
        enable_twitter: bool = True,
        enable_reddit: bool = True,
    ) -> SimulationState:
        """
        Create a new simulation record.

        Args:
            project_id: Project ID.
            graph_id: Zep graph ID.
            enable_twitter: Whether to run the Twitter simulation.
            enable_reddit: Whether to run the Reddit simulation.

        Returns:
            SimulationState
        """
        import uuid
        simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
        state = SimulationState(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit,
            status=SimulationStatus.CREATED,
        )
        self._save_simulation_state(state)
        logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}")
        return state
    def prepare_simulation(
        self,
        simulation_id: str,
        simulation_requirement: str,
        document_text: str,
        defined_entity_types: Optional[List[str]] = None,
        use_llm_for_profiles: bool = True,
        progress_callback: Optional[callable] = None
    ) -> SimulationState:
        """
        Prepare the simulation environment (fully automated).

        Steps:
        1. Read and filter entities from the Zep graph.
        2. Generate an OASIS agent profile per entity (optionally LLM-enhanced).
        3. Use the LLM to generate simulation parameters (timing, activity,
           posting frequency, ...).
        4. Save the config and profile files.
        5. Copy the preset run scripts into the simulation directory.

        Args:
            simulation_id: Simulation ID.
            simulation_requirement: Requirement description fed to the LLM.
            document_text: Original document text (background for the LLM).
            defined_entity_types: Optional whitelist of entity types.
            use_llm_for_profiles: Whether to LLM-generate detailed personas.
            progress_callback: Progress callback (stage, progress, message).

        Returns:
            SimulationState

        Raises:
            ValueError: if the simulation does not exist.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")
        try:
            state.status = SimulationStatus.PREPARING
            self._save_simulation_state(state)
            sim_dir = self._get_simulation_dir(simulation_id)
            # ========== Stage 1: read and filter entities ==========
            if progress_callback:
                progress_callback("reading", 0, "正在连接Zep图谱...")
            reader = ZepEntityReader()
            if progress_callback:
                progress_callback("reading", 30, "正在读取节点数据...")
            filtered = reader.filter_defined_entities(
                graph_id=state.graph_id,
                defined_entity_types=defined_entity_types,
                enrich_with_edges=True
            )
            state.entities_count = filtered.filtered_count
            state.entity_types = list(filtered.entity_types)
            if progress_callback:
                progress_callback(
                    "reading", 100,
                    f"完成,共 {filtered.filtered_count} 个实体",
                    current=filtered.filtered_count,
                    total=filtered.filtered_count
                )
            # Nothing to simulate without entities: fail early.
            if filtered.filtered_count == 0:
                state.status = SimulationStatus.FAILED
                state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
                self._save_simulation_state(state)
                return state
            # ========== Stage 2: generate agent profiles ==========
            total_entities = len(filtered.entities)
            if progress_callback:
                progress_callback(
                    "generating_profiles", 0,
                    "开始生成...",
                    current=0,
                    total=total_entities
                )
            generator = OasisProfileGenerator()
            def profile_progress(current, total, msg):
                # Forward per-profile progress to the caller's callback.
                if progress_callback:
                    progress_callback(
                        "generating_profiles",
                        int(current / total * 100),
                        msg,
                        current=current,
                        total=total,
                        item_name=msg
                    )
            profiles = generator.generate_profiles_from_entities(
                entities=filtered.entities,
                use_llm=use_llm_for_profiles,
                progress_callback=profile_progress
            )
            state.profiles_count = len(profiles)
            # Persist profiles (note: Twitter expects CSV, Reddit expects JSON)
            if progress_callback:
                progress_callback(
                    "generating_profiles", 95,
                    "保存Profile文件...",
                    current=total_entities,
                    total=total_entities
                )
            if state.enable_reddit:
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "reddit_profiles.json"),
                    platform="reddit"
                )
            if state.enable_twitter:
                # Twitter uses CSV format (required by OASIS)
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
                    platform="twitter"
                )
            if progress_callback:
                progress_callback(
                    "generating_profiles", 100,
                    f"完成,共 {len(profiles)} 个Profile",
                    current=len(profiles),
                    total=len(profiles)
                )
            # ========== Stage 3: LLM-generated simulation config ==========
            if progress_callback:
                progress_callback(
                    "generating_config", 0,
                    "正在分析模拟需求...",
                    current=0,
                    total=3
                )
            config_generator = SimulationConfigGenerator()
            if progress_callback:
                progress_callback(
                    "generating_config", 30,
                    "正在调用LLM生成配置...",
                    current=1,
                    total=3
                )
            sim_params = config_generator.generate_config(
                simulation_id=simulation_id,
                project_id=state.project_id,
                graph_id=state.graph_id,
                simulation_requirement=simulation_requirement,
                document_text=document_text,
                entities=filtered.entities,
                enable_twitter=state.enable_twitter,
                enable_reddit=state.enable_reddit
            )
            if progress_callback:
                progress_callback(
                    "generating_config", 70,
                    "正在保存配置文件...",
                    current=2,
                    total=3
                )
            # Write the generated config to disk
            config_path = os.path.join(sim_dir, "simulation_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(sim_params.to_json())
            state.config_generated = True
            state.config_reasoning = sim_params.generation_reasoning
            if progress_callback:
                progress_callback(
                    "generating_config", 100,
                    "配置生成完成",
                    current=3,
                    total=3
                )
            # ========== Stage 4: copy preset scripts ==========
            # NOTE(review): this list includes action_logger.py but
            # _copy_preset_scripts copies only three scripts — confirm which
            # set is intended; the count is only used for progress reporting.
            script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
                           "run_parallel_simulation.py", "action_logger.py"]
            if progress_callback:
                progress_callback(
                    "copying_scripts", 0,
                    "开始准备脚本...",
                    current=0,
                    total=len(script_files)
                )
            self._copy_preset_scripts(sim_dir)
            if progress_callback:
                progress_callback(
                    "copying_scripts", 100,
                    f"完成,共 {len(script_files)} 个脚本",
                    current=len(script_files),
                    total=len(script_files)
                )
            # Everything generated: mark ready to run
            state.status = SimulationStatus.READY
            self._save_simulation_state(state)
            logger.info(f"模拟准备完成: {simulation_id}, "
                       f"entities={state.entities_count}, profiles={state.profiles_count}")
            return state
        except Exception as e:
            logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            state.status = SimulationStatus.FAILED
            state.error = str(e)
            self._save_simulation_state(state)
            raise
    def _copy_preset_scripts(self, sim_dir: str):
        """Copy the preset run scripts into the simulation directory."""
        scripts = [
            "run_twitter_simulation.py",
            "run_reddit_simulation.py",
            "run_parallel_simulation.py"
        ]
        for script in scripts:
            src = os.path.join(self.SCRIPTS_DIR, script)
            dst = os.path.join(sim_dir, script)
            if os.path.exists(src):
                shutil.copy2(src, dst)
                logger.debug(f"复制脚本: {script}")
            else:
                # Missing scripts are tolerated; the run step will fail later.
                logger.warning(f"预设脚本不存在: {src}")
    def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
        """Return a simulation's state, or None if it does not exist."""
        return self._load_simulation_state(simulation_id)
    def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
        """List simulations, optionally filtered by project_id.

        NOTE(review): every directory entry is treated as a simulation id;
        stray files under the data dir could trip _get_simulation_dir — verify.
        """
        simulations = []
        if os.path.exists(self.SIMULATION_DATA_DIR):
            for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
                state = self._load_simulation_state(sim_id)
                if state:
                    if project_id is None or state.project_id == project_id:
                        simulations.append(state)
        return simulations
    def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
        """Return the saved agent profiles for a platform.

        NOTE(review): Twitter profiles are saved as CSV (see
        prepare_simulation), so this JSON lookup returns [] for
        platform="twitter" — confirm intended behavior.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")
        sim_dir = self._get_simulation_dir(simulation_id)
        profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
        if not os.path.exists(profile_path):
            return []
        with open(profile_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
        """Return the generated simulation config, or None if not generated."""
        sim_dir = self._get_simulation_dir(simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")
        if not os.path.exists(config_path):
            return None
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
        """Return manual run instructions: directory, config path, commands."""
        sim_dir = self._get_simulation_dir(simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")
        return {
            "simulation_dir": sim_dir,
            "config_file": config_path,
            "commands": {
                "twitter": f"python run_twitter_simulation.py --config simulation_config.json",
                "reddit": f"python run_reddit_simulation.py --config simulation_config.json",
                "parallel": f"python run_parallel_simulation.py --config simulation_config.json",
            },
            "instructions": (
                f"1. 进入模拟目录: cd {sim_dir}\n"
                f"2. 激活conda环境: conda activate MiroFish\n"
                f"3. 运行模拟:\n"
                f"   - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
                f"   - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
                f"   - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
            )
        }

View file

@ -0,0 +1,670 @@
"""
OASIS模拟运行器
在后台运行模拟并记录每个Agent的动作支持实时状态监控
"""
import os
import sys
import json
import time
import asyncio
import threading
import subprocess
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from queue import Queue
from ..config import Config
from ..utils.logger import get_logger
logger = get_logger('mirofish.simulation_runner')
class RunnerStatus(str, Enum):
    """Lifecycle states of a simulation run (str-valued so it JSON-serializes)."""
    IDLE = "idle"            # no run has started
    STARTING = "starting"    # process launch in progress
    RUNNING = "running"      # subprocess executing
    PAUSED = "paused"        # temporarily halted
    STOPPING = "stopping"    # stop requested, awaiting termination
    STOPPED = "stopped"      # terminated by a stop request
    COMPLETED = "completed"  # subprocess exited with code 0
    FAILED = "failed"        # subprocess exited non-zero or monitor error
@dataclass
class AgentAction:
    """One logged action performed by an agent during a simulation round."""
    round_num: int
    timestamp: str
    platform: str  # "twitter" or "reddit"
    agent_id: int
    agent_name: str
    action_type: str  # e.g. CREATE_POST, LIKE_POST
    action_args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None
    success: bool = True
    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field into a plain dict, in declaration order."""
        names = ("round_num", "timestamp", "platform", "agent_id", "agent_name",
                 "action_type", "action_args", "result", "success")
        return {name: getattr(self, name) for name in names}
@dataclass
class RoundSummary:
    """Aggregate statistics for one simulation round."""
    round_num: int
    start_time: str
    end_time: Optional[str] = None
    simulated_hour: int = 0
    twitter_actions: int = 0
    reddit_actions: int = 0
    active_agents: List[int] = field(default_factory=list)
    actions: List[AgentAction] = field(default_factory=list)
    def to_dict(self) -> Dict[str, Any]:
        """Serialize the round, expanding each action via its own to_dict()."""
        plain = ("round_num", "start_time", "end_time", "simulated_hour",
                 "twitter_actions", "reddit_actions", "active_agents")
        summary = {name: getattr(self, name) for name in plain}
        summary["actions_count"] = len(self.actions)
        summary["actions"] = [action.to_dict() for action in self.actions]
        return summary
@dataclass
class SimulationRunState:
    """Live run state of a simulation (updated by the monitor thread)."""
    simulation_id: str
    runner_status: RunnerStatus = RunnerStatus.IDLE
    # Progress
    current_round: int = 0
    total_rounds: int = 0
    simulated_hours: int = 0
    total_simulation_hours: int = 0
    # Per-platform flags and counters
    twitter_running: bool = False
    reddit_running: bool = False
    twitter_actions_count: int = 0
    reddit_actions_count: int = 0
    # Per-round summaries
    rounds: List[RoundSummary] = field(default_factory=list)
    # Most recent actions, newest first (kept for live frontend display)
    recent_actions: List[AgentAction] = field(default_factory=list)
    max_recent_actions: int = 50
    # Timestamps
    started_at: Optional[str] = None
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None
    # Failure details
    error: Optional[str] = None
    # PID of the simulation subprocess (used for stopping it)
    process_pid: Optional[int] = None
    def add_action(self, action: AgentAction):
        """Record a new action: prepend to the recent list, bump counters."""
        self.recent_actions.insert(0, action)
        # Keep only the newest max_recent_actions entries.
        del self.recent_actions[self.max_recent_actions:]
        if action.platform == "twitter":
            self.twitter_actions_count += 1
        else:
            self.reddit_actions_count += 1
        self.updated_at = datetime.now().isoformat()
    def to_dict(self) -> Dict[str, Any]:
        """Summary dict without the per-action payload."""
        progress = round(self.current_round / max(self.total_rounds, 1) * 100, 1)
        total_actions = self.twitter_actions_count + self.reddit_actions_count
        snapshot: Dict[str, Any] = {
            "simulation_id": self.simulation_id,
            "runner_status": self.runner_status.value,
        }
        for name in ("current_round", "total_rounds", "simulated_hours",
                     "total_simulation_hours"):
            snapshot[name] = getattr(self, name)
        snapshot["progress_percent"] = progress
        for name in ("twitter_running", "reddit_running",
                     "twitter_actions_count", "reddit_actions_count"):
            snapshot[name] = getattr(self, name)
        snapshot["total_actions_count"] = total_actions
        for name in ("started_at", "updated_at", "completed_at", "error",
                     "process_pid"):
            snapshot[name] = getattr(self, name)
        return snapshot
    def to_detail_dict(self) -> Dict[str, Any]:
        """Like to_dict(), plus the recent-action payload and round count."""
        detail = self.to_dict()
        detail["recent_actions"] = [action.to_dict() for action in self.recent_actions]
        detail["rounds_count"] = len(self.rounds)
        return detail
class SimulationRunner:
    """
    Simulation runner.

    Responsibilities:
    1. Run the OASIS simulation in a background subprocess.
    2. Parse the action log, recording every agent action.
    3. Expose a live status-query interface.
    4. Support pause/stop/resume operations.
       NOTE(review): only stop is implemented in this class — pause/resume
       appear in the status enum but have no entry points here; confirm.
    """
    # Directory where per-simulation run state is stored
    RUN_STATE_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )
    # Process-wide registries for live runs (class-level, shared)
    _run_states: Dict[str, SimulationRunState] = {}
    _processes: Dict[str, subprocess.Popen] = {}
    _action_queues: Dict[str, Queue] = {}
    _monitor_threads: Dict[str, threading.Thread] = {}
    @classmethod
    def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
        """Return the run state, loading it from disk on a cache miss."""
        if simulation_id in cls._run_states:
            return cls._run_states[simulation_id]
        # Fall back to the persisted copy
        state = cls._load_run_state(simulation_id)
        if state:
            cls._run_states[simulation_id] = state
        return state
    @classmethod
    def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
        """Load a run state from run_state.json; None if absent or unreadable."""
        state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
        if not os.path.exists(state_file):
            return None
        try:
            with open(state_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            state = SimulationRunState(
                simulation_id=simulation_id,
                runner_status=RunnerStatus(data.get("runner_status", "idle")),
                current_round=data.get("current_round", 0),
                total_rounds=data.get("total_rounds", 0),
                simulated_hours=data.get("simulated_hours", 0),
                total_simulation_hours=data.get("total_simulation_hours", 0),
                twitter_running=data.get("twitter_running", False),
                reddit_running=data.get("reddit_running", False),
                twitter_actions_count=data.get("twitter_actions_count", 0),
                reddit_actions_count=data.get("reddit_actions_count", 0),
                started_at=data.get("started_at"),
                updated_at=data.get("updated_at", datetime.now().isoformat()),
                completed_at=data.get("completed_at"),
                error=data.get("error"),
                process_pid=data.get("process_pid"),
            )
            # Rehydrate the recent-actions list
            actions_data = data.get("recent_actions", [])
            for a in actions_data:
                state.recent_actions.append(AgentAction(
                    round_num=a.get("round_num", 0),
                    timestamp=a.get("timestamp", ""),
                    platform=a.get("platform", ""),
                    agent_id=a.get("agent_id", 0),
                    agent_name=a.get("agent_name", ""),
                    action_type=a.get("action_type", ""),
                    action_args=a.get("action_args", {}),
                    result=a.get("result"),
                    success=a.get("success", True),
                ))
            return state
        except Exception as e:
            logger.error(f"加载运行状态失败: {str(e)}")
            return None
    @classmethod
    def _save_run_state(cls, state: SimulationRunState):
        """Persist a run state (including recent actions) and refresh the cache."""
        sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
        os.makedirs(sim_dir, exist_ok=True)
        state_file = os.path.join(sim_dir, "run_state.json")
        data = state.to_detail_dict()
        with open(state_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        cls._run_states[state.simulation_id] = state
    @classmethod
    def start_simulation(
        cls,
        simulation_id: str,
        platform: str = "parallel"  # twitter / reddit / parallel
    ) -> SimulationRunState:
        """
        Start a simulation run.

        Args:
            simulation_id: Simulation ID.
            platform: Run mode (twitter/reddit/parallel).

        Returns:
            SimulationRunState

        Raises:
            ValueError: if already running, not prepared, or the script is missing.
        """
        # Refuse to double-start
        existing = cls.get_run_state(simulation_id)
        if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]:
            raise ValueError(f"模拟已在运行中: {simulation_id}")
        # Load the prepared simulation config
        sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")
        if not os.path.exists(config_path):
            raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Initialize the run state from the time config
        time_config = config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)
        total_rounds = int(total_hours * 60 / minutes_per_round)
        state = SimulationRunState(
            simulation_id=simulation_id,
            runner_status=RunnerStatus.STARTING,
            total_rounds=total_rounds,
            total_simulation_hours=total_hours,
            started_at=datetime.now().isoformat(),
        )
        cls._save_run_state(state)
        # Pick the preset script for the requested mode
        if platform == "twitter":
            script_name = "run_twitter_simulation.py"
            state.twitter_running = True
        elif platform == "reddit":
            script_name = "run_reddit_simulation.py"
            state.reddit_running = True
        else:
            script_name = "run_parallel_simulation.py"
            state.twitter_running = True
            state.reddit_running = True
        script_path = os.path.join(sim_dir, script_name)
        if not os.path.exists(script_path):
            raise ValueError(f"脚本不存在: {script_path}")
        # Per-run action queue (cleaned up by the monitor thread on exit)
        action_queue = Queue()
        cls._action_queues[simulation_id] = action_queue
        # Launch the simulation subprocess
        try:
            # Build the command line
            cmd = [
                sys.executable,  # current Python interpreter
                script_path,
                "--config", "simulation_config.json",
                "--action-log", "actions.jsonl",  # JSONL log the monitor tails
            ]
            # Run inside the simulation directory
            # NOTE(review): stderr is PIPEd but only read after exit; a chatty
            # child could fill the pipe buffer and block — verify log volume.
            process = subprocess.Popen(
                cmd,
                cwd=sim_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )
            state.process_pid = process.pid
            state.runner_status = RunnerStatus.RUNNING
            cls._processes[simulation_id] = process
            cls._save_run_state(state)
            # Start the daemon thread that tails the action log
            monitor_thread = threading.Thread(
                target=cls._monitor_simulation,
                args=(simulation_id,),
                daemon=True
            )
            monitor_thread.start()
            cls._monitor_threads[simulation_id] = monitor_thread
            logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}")
        except Exception as e:
            state.runner_status = RunnerStatus.FAILED
            state.error = str(e)
            cls._save_run_state(state)
            raise
        return state
    @classmethod
    def _monitor_simulation(cls, simulation_id: str):
        """Tail the action log while the subprocess runs; finalize state on exit."""
        sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
        actions_log = os.path.join(sim_dir, "actions.jsonl")
        process = cls._processes.get(simulation_id)
        state = cls.get_run_state(simulation_id)
        if not process or not state:
            return
        last_position = 0
        try:
            while process.poll() is None:  # subprocess still alive
                # Read any newly appended log lines
                if os.path.exists(actions_log):
                    # NOTE(review): calling f.tell() after for-line iteration on
                    # a text-mode file raises OSError in CPython — consider a
                    # readline() loop; verify against the runtime behavior.
                    with open(actions_log, 'r', encoding='utf-8') as f:
                        f.seek(last_position)
                        for line in f:
                            line = line.strip()
                            if line:
                                try:
                                    action_data = json.loads(line)
                                    action = AgentAction(
                                        round_num=action_data.get("round", 0),
                                        timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                                        platform=action_data.get("platform", "unknown"),
                                        agent_id=action_data.get("agent_id", 0),
                                        agent_name=action_data.get("agent_name", ""),
                                        action_type=action_data.get("action_type", ""),
                                        action_args=action_data.get("action_args", {}),
                                        result=action_data.get("result"),
                                        success=action_data.get("success", True),
                                    )
                                    state.add_action(action)
                                    # Track the highest round seen so far
                                    if action.round_num > state.current_round:
                                        state.current_round = action.round_num
                                except json.JSONDecodeError:
                                    pass
                        last_position = f.tell()
                # Persist progress each tick
                cls._save_run_state(state)
                time.sleep(1)  # poll once per second
            # Subprocess has exited
            exit_code = process.returncode
            if exit_code == 0:
                state.runner_status = RunnerStatus.COMPLETED
                state.completed_at = datetime.now().isoformat()
                logger.info(f"模拟完成: {simulation_id}")
            else:
                state.runner_status = RunnerStatus.FAILED
                stderr = process.stderr.read() if process.stderr else ""
                state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
                logger.error(f"模拟失败: {simulation_id}, error={state.error}")
            state.twitter_running = False
            state.reddit_running = False
            cls._save_run_state(state)
        except Exception as e:
            logger.error(f"监控线程异常: {simulation_id}, error={str(e)}")
            state.runner_status = RunnerStatus.FAILED
            state.error = str(e)
            cls._save_run_state(state)
        finally:
            # Drop per-run registries
            cls._processes.pop(simulation_id, None)
            cls._action_queues.pop(simulation_id, None)
    @classmethod
    def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
        """Stop a running simulation, escalating to kill after 10 seconds."""
        state = cls.get_run_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")
        if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]:
            raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}")
        state.runner_status = RunnerStatus.STOPPING
        cls._save_run_state(state)
        # Terminate the subprocess (SIGTERM first, then SIGKILL)
        process = cls._processes.get(simulation_id)
        if process:
            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                process.kill()
        state.runner_status = RunnerStatus.STOPPED
        state.twitter_running = False
        state.reddit_running = False
        state.completed_at = datetime.now().isoformat()
        cls._save_run_state(state)
        logger.info(f"模拟已停止: {simulation_id}")
        return state
    @classmethod
    def get_actions(
        cls,
        simulation_id: str,
        limit: int = 100,
        offset: int = 0,
        platform: Optional[str] = None,
        agent_id: Optional[int] = None,
        round_num: Optional[int] = None
    ) -> List[AgentAction]:
        """
        Read the action history from the JSONL log.

        Args:
            simulation_id: Simulation ID.
            limit: Maximum number of actions to return.
            offset: Pagination offset.
            platform: Optional platform filter.
            agent_id: Optional agent filter.
            round_num: Optional round filter.

        Returns:
            Matching actions, newest first.
        """
        sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
        actions_log = os.path.join(sim_dir, "actions.jsonl")
        if not os.path.exists(actions_log):
            return []
        actions = []
        with open(actions_log, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    # Apply filters
                    if platform and data.get("platform") != platform:
                        continue
                    if agent_id is not None and data.get("agent_id") != agent_id:
                        continue
                    if round_num is not None and data.get("round") != round_num:
                        continue
                    actions.append(AgentAction(
                        round_num=data.get("round", 0),
                        timestamp=data.get("timestamp", ""),
                        platform=data.get("platform", ""),
                        agent_id=data.get("agent_id", 0),
                        agent_name=data.get("agent_name", ""),
                        action_type=data.get("action_type", ""),
                        action_args=data.get("action_args", {}),
                        result=data.get("result"),
                        success=data.get("success", True),
                    ))
                except json.JSONDecodeError:
                    continue
        # Newest first (file order is chronological)
        actions.reverse()
        # Paginate
        return actions[offset:offset + limit]
    @classmethod
    def get_timeline(
        cls,
        simulation_id: str,
        start_round: int = 0,
        end_round: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Build a per-round timeline of the simulation.

        Args:
            simulation_id: Simulation ID.
            start_round: First round to include.
            end_round: Last round to include (inclusive), or None for all.

        Returns:
            One summary dict per round, in ascending round order.
        """
        # NOTE(review): get_actions returns newest-first, so within a round
        # first_action_time/last_action_time may be swapped — verify intent.
        actions = cls.get_actions(simulation_id, limit=10000)
        # Group actions by round
        rounds: Dict[int, Dict[str, Any]] = {}
        for action in actions:
            round_num = action.round_num
            if round_num < start_round:
                continue
            if end_round is not None and round_num > end_round:
                continue
            if round_num not in rounds:
                rounds[round_num] = {
                    "round_num": round_num,
                    "twitter_actions": 0,
                    "reddit_actions": 0,
                    "active_agents": set(),
                    "action_types": {},
                    "first_action_time": action.timestamp,
                    "last_action_time": action.timestamp,
                }
            r = rounds[round_num]
            if action.platform == "twitter":
                r["twitter_actions"] += 1
            else:
                r["reddit_actions"] += 1
            r["active_agents"].add(action.agent_id)
            r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1
            r["last_action_time"] = action.timestamp
        # Emit rounds in ascending order
        result = []
        for round_num in sorted(rounds.keys()):
            r = rounds[round_num]
            result.append({
                "round_num": round_num,
                "twitter_actions": r["twitter_actions"],
                "reddit_actions": r["reddit_actions"],
                "total_actions": r["twitter_actions"] + r["reddit_actions"],
                "active_agents_count": len(r["active_agents"]),
                "active_agents": list(r["active_agents"]),
                "action_types": r["action_types"],
                "first_action_time": r["first_action_time"],
                "last_action_time": r["last_action_time"],
            })
        return result
    @classmethod
    def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
        """
        Aggregate per-agent statistics from the action log.

        Returns:
            Stats dicts sorted by total action count, descending.
        """
        actions = cls.get_actions(simulation_id, limit=10000)
        agent_stats: Dict[int, Dict[str, Any]] = {}
        for action in actions:
            agent_id = action.agent_id
            if agent_id not in agent_stats:
                agent_stats[agent_id] = {
                    "agent_id": agent_id,
                    "agent_name": action.agent_name,
                    "total_actions": 0,
                    "twitter_actions": 0,
                    "reddit_actions": 0,
                    "action_types": {},
                    "first_action_time": action.timestamp,
                    "last_action_time": action.timestamp,
                }
            stats = agent_stats[agent_id]
            stats["total_actions"] += 1
            if action.platform == "twitter":
                stats["twitter_actions"] += 1
            else:
                stats["reddit_actions"] += 1
            stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1
            stats["last_action_time"] = action.timestamp
        # Most active agents first
        result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
        return result

View file

@ -0,0 +1,386 @@
"""
Zep实体读取与过滤服务
从Zep图谱中读取节点筛选出符合预定义实体类型的节点
"""
from typing import Dict, Any, List, Optional, Set
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_entity_reader')
@dataclass
class EntityNode:
    """A node read from the Zep graph, plus its local neighbourhood."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges touching this node (populated only when enrichment is requested).
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Basic info about the nodes on the other end of those edges.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain, JSON-friendly dict."""
        return {
            "uuid": self.uuid,
            "name": self.name,
            "labels": self.labels,
            "summary": self.summary,
            "attributes": self.attributes,
            "related_edges": self.related_edges,
            "related_nodes": self.related_nodes,
        }

    def get_entity_type(self) -> Optional[str]:
        """Return the first label beyond the default "Entity"/"Node", else None."""
        non_default = (lbl for lbl in self.labels if lbl not in ["Entity", "Node"])
        return next(non_default, None)
@dataclass
class FilteredEntities:
    """Result bundle of one entity-filtering pass over a graph."""

    entities: List[EntityNode]  # nodes that matched a predefined type
    entity_types: Set[str]      # distinct entity types found among them
    total_count: int            # nodes inspected in total
    filtered_count: int         # nodes kept after filtering

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, expanding every contained entity."""
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            "entity_types": list(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
class ZepEntityReader:
    """Read and filter entity nodes from a Zep knowledge graph.

    Responsibilities:
    1. Fetch all nodes of a graph from Zep.
    2. Keep only nodes that match predefined entity types, i.e. whose
       labels include something beyond the default "Entity"/"Node".
    3. Optionally enrich each kept entity with its edges and the basic
       info of its neighbour nodes.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create a reader backed by the Zep cloud client.

        Args:
            api_key: Zep API key; falls back to Config.ZEP_API_KEY.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")
        self.client = Zep(api_key=self.api_key)

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Fetch every node of a graph.

        Args:
            graph_id: Zep graph id.

        Returns:
            List of plain node dicts (uuid, name, labels, summary, attributes).
        """
        logger.info(f"获取图谱 {graph_id} 的所有节点...")
        nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
        nodes_data = []
        for node in nodes:
            nodes_data.append({
                # SDK versions expose the id as either `uuid_` or `uuid`.
                "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
                "name": node.name or "",
                "labels": node.labels or [],
                "summary": node.summary or "",
                "attributes": node.attributes or {},
            })
        logger.info(f"共获取 {len(nodes_data)} 个节点")
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Fetch every edge of a graph.

        Args:
            graph_id: Zep graph id.

        Returns:
            List of plain edge dicts (uuid, name, fact, endpoint uuids,
            attributes).
        """
        logger.info(f"获取图谱 {graph_id} 的所有边...")
        edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
        edges_data = []
        for edge in edges:
            edges_data.append({
                # Same `uuid_`/`uuid` compatibility shim as for nodes.
                "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
                "name": edge.name or "",
                "fact": edge.fact or "",
                "source_node_uuid": edge.source_node_uuid,
                "target_node_uuid": edge.target_node_uuid,
                "attributes": edge.attributes or {},
            })
        logger.info(f"共获取 {len(edges_data)} 条边")
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """Fetch all edges attached to one node.

        Best-effort: returns an empty list (after logging a warning)
        when the Zep call fails, rather than propagating the error.

        Args:
            node_uuid: UUID of the node.

        Returns:
            List of plain edge dicts; empty on failure.
        """
        try:
            edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
            edges_data = []
            for edge in edges:
                edges_data.append({
                    "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
                    "name": edge.name or "",
                    "fact": edge.fact or "",
                    "source_node_uuid": edge.source_node_uuid,
                    "target_node_uuid": edge.target_node_uuid,
                    "attributes": edge.attributes or {},
                })
            return edges_data
        except Exception as e:
            logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """Keep only nodes that match a predefined entity type.

        Filtering rule:
        - A node labelled only "Entity"/"Node" matches no predefined type
          and is skipped.
        - A node with any additional label is kept; its entity type is the
          first custom label (or the first one matching
          ``defined_entity_types`` when that whitelist is given).

        Args:
            graph_id: Zep graph id.
            defined_entity_types: Optional whitelist of entity types.
            enrich_with_edges: Whether to attach each entity's related
                edges and neighbour-node summaries.

        Returns:
            FilteredEntities holding the kept entities plus counters.
        """
        logger.info(f"开始筛选图谱 {graph_id} 的实体...")
        # Fetch every node once up front.
        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)
        # Fetch every edge once too (used below for neighbour lookups).
        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
        # Map node uuid -> node payload for O(1) neighbour resolution.
        node_map = {n["uuid"]: n for n in all_nodes}
        # Select the qualifying entities.
        filtered_entities = []
        entity_types_found = set()
        for node in all_nodes:
            labels = node.get("labels", [])
            # Keep only nodes carrying a label beyond "Entity"/"Node".
            custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
            if not custom_labels:
                # Default labels only: not a predefined entity, skip.
                continue
            # Honour the optional type whitelist.
            if defined_entity_types:
                matching_labels = [l for l in custom_labels if l in defined_entity_types]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]
            entity_types_found.add(entity_type)
            # Build the entity object.
            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )
            # Attach related edges and neighbour nodes.
            if enrich_with_edges:
                related_edges = []
                related_node_uuids = set()
                for edge in all_edges:
                    if edge["source_node_uuid"] == node["uuid"]:
                        related_edges.append({
                            "direction": "outgoing",
                            "edge_name": edge["name"],
                            "fact": edge["fact"],
                            "target_node_uuid": edge["target_node_uuid"],
                        })
                        related_node_uuids.add(edge["target_node_uuid"])
                    elif edge["target_node_uuid"] == node["uuid"]:
                        related_edges.append({
                            "direction": "incoming",
                            "edge_name": edge["name"],
                            "fact": edge["fact"],
                            "source_node_uuid": edge["source_node_uuid"],
                        })
                        related_node_uuids.add(edge["source_node_uuid"])
                entity.related_edges = related_edges
                # Collect basic info for each neighbour node.
                related_nodes = []
                for related_uuid in related_node_uuids:
                    if related_uuid in node_map:
                        related_node = node_map[related_uuid]
                        related_nodes.append({
                            "uuid": related_node["uuid"],
                            "name": related_node["name"],
                            "labels": related_node["labels"],
                            "summary": related_node.get("summary", ""),
                        })
                entity.related_nodes = related_nodes
            filtered_entities.append(entity)
        logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                    f"实体类型: {entity_types_found}")
        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """Fetch one entity together with its full context (edges + neighbours).

        Args:
            graph_id: Zep graph id.
            entity_uuid: UUID of the entity node.

        Returns:
            The enriched EntityNode, or None on any failure (logged).
        """
        try:
            # Fetch the node itself.
            node = self.client.graph.node.get(uuid_=entity_uuid)
            if not node:
                return None
            # Fetch the node's edges.
            edges = self.get_node_edges(entity_uuid)
            # Fetch all nodes so neighbour uuids can be resolved to names.
            all_nodes = self.get_all_nodes(graph_id)
            node_map = {n["uuid"]: n for n in all_nodes}
            # Classify each edge relative to this entity.
            related_edges = []
            related_node_uuids = set()
            for edge in edges:
                if edge["source_node_uuid"] == entity_uuid:
                    related_edges.append({
                        "direction": "outgoing",
                        "edge_name": edge["name"],
                        "fact": edge["fact"],
                        "target_node_uuid": edge["target_node_uuid"],
                    })
                    related_node_uuids.add(edge["target_node_uuid"])
                else:
                    related_edges.append({
                        "direction": "incoming",
                        "edge_name": edge["name"],
                        "fact": edge["fact"],
                        "source_node_uuid": edge["source_node_uuid"],
                    })
                    related_node_uuids.add(edge["source_node_uuid"])
            # Collect basic info for the neighbours.
            related_nodes = []
            for related_uuid in related_node_uuids:
                if related_uuid in node_map:
                    related_node = node_map[related_uuid]
                    related_nodes.append({
                        "uuid": related_node["uuid"],
                        "name": related_node["name"],
                        "labels": related_node["labels"],
                        "summary": related_node.get("summary", ""),
                    })
            return EntityNode(
                uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
                name=node.name or "",
                labels=node.labels or [],
                summary=node.summary or "",
                attributes=node.attributes or {},
                related_edges=related_edges,
                related_nodes=related_nodes,
            )
        except Exception as e:
            logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """Fetch all entities of one specific type.

        Thin wrapper around filter_defined_entities with a single-type
        whitelist.

        Args:
            graph_id: Zep graph id.
            entity_type: Entity type, e.g. "Student" or "PublicFigure".
            enrich_with_edges: Whether to attach related-edge info.

        Returns:
            List of matching EntityNode objects.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities

238
backend/app/utils/retry.py Normal file
View file

@ -0,0 +1,238 @@
"""
API调用重试机制
用于处理LLM等外部API调用的重试逻辑
"""
import time
import random
import functools
from typing import Callable, Any, Optional, Type, Tuple
from ..utils.logger import get_logger
logger = get_logger('mirofish.retry')
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """Decorator that retries a callable with exponential backoff.

    Args:
        max_retries: Maximum number of retries after the first attempt.
        initial_delay: Delay before the first retry, in seconds.
        max_delay: Upper bound for any single delay, in seconds.
        backoff_factor: Multiplier applied to the delay after each retry.
        jitter: When True, scale each delay by a random factor in [0.5, 1.5).
        exceptions: Exception types that trigger a retry.
        on_retry: Optional callback invoked as (exception, retry_count).

    Usage:
        @retry_with_backoff(max_retries=3)
        def call_llm_api():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            next_delay = initial_delay
            attempt = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except exceptions as exc:
                    if attempt == max_retries:
                        logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(exc)}")
                        raise
                    # Cap the raw delay, then optionally jitter it.
                    wait = min(next_delay, max_delay)
                    if jitter:
                        wait = wait * (0.5 + random.random())
                    logger.warning(
                        f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(exc)}, "
                        f"将在{wait:.1f}秒后重试..."
                    )
                    if on_retry:
                        on_retry(exc, attempt + 1)
                    time.sleep(wait)
                    next_delay *= backoff_factor
                    attempt += 1
        return wrapper
    return decorator
def retry_with_backoff_async(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """Async counterpart of retry_with_backoff (awaits the wrapped coroutine)."""
    import asyncio

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            next_delay = initial_delay
            attempt = 0
            while True:
                try:
                    return await func(*args, **kwargs)
                except exceptions as exc:
                    if attempt == max_retries:
                        logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(exc)}")
                        raise
                    # Cap the raw delay, then optionally jitter it.
                    wait = min(next_delay, max_delay)
                    if jitter:
                        wait = wait * (0.5 + random.random())
                    logger.warning(
                        f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(exc)}, "
                        f"将在{wait:.1f}秒后重试..."
                    )
                    if on_retry:
                        on_retry(exc, attempt + 1)
                    await asyncio.sleep(wait)
                    next_delay *= backoff_factor
                    attempt += 1
        return wrapper
    return decorator
class RetryableAPIClient:
    """Wraps arbitrary callables with jittered exponential-backoff retries."""

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        backoff_factor: float = 2.0
    ):
        # Retry policy shared by all calls made through this client.
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor

    def call_with_retry(
        self,
        func: Callable,
        *args,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        **kwargs
    ) -> Any:
        """Invoke ``func`` and retry on failure.

        Args:
            func: Callable to invoke.
            *args: Positional arguments for ``func``.
            exceptions: Exception types that trigger a retry.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            The last exception once ``max_retries`` is exhausted.
        """
        next_delay = self.initial_delay
        attempt = 0
        while True:
            try:
                return func(*args, **kwargs)
            except exceptions as exc:
                if attempt == self.max_retries:
                    logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(exc)}")
                    raise
                # Cap the delay, then always apply jitter in [0.5, 1.5).
                wait = min(next_delay, self.max_delay) * (0.5 + random.random())
                logger.warning(
                    f"API调用第 {attempt + 1} 次尝试失败: {str(exc)}, "
                    f"将在{wait:.1f}秒后重试..."
                )
                time.sleep(wait)
                next_delay *= self.backoff_factor
                attempt += 1

    def call_batch_with_retry(
        self,
        items: list,
        process_func: Callable,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        continue_on_failure: bool = True
    ) -> Tuple[list, list]:
        """Process items one by one, retrying each independently.

        Args:
            items: Items to process.
            process_func: Callable receiving a single item.
            exceptions: Exception types that trigger a retry.
            continue_on_failure: When False, re-raise on the first item
                that still fails after retries.

        Returns:
            Tuple of (successful results, failure records), where each
            failure record holds the item's index, the item, and the error.
        """
        successes: list = []
        failed: list = []
        for position, item in enumerate(items):
            try:
                successes.append(
                    self.call_with_retry(process_func, item, exceptions=exceptions)
                )
            except Exception as exc:
                logger.error(f"处理第 {position + 1} 项失败: {str(exc)}")
                failed.append({
                    "index": position,
                    "item": item,
                    "error": str(exc)
                })
                if not continue_on_failure:
                    raise
        return successes, failed

View file

@ -20,3 +20,7 @@ pydantic>=2.0.0
# 文件处理
werkzeug>=3.0.0
# OASIS社交媒体模拟框架
oasis-ai>=0.1.0
camel-ai>=0.2.0

View file

@ -0,0 +1,138 @@
"""
动作日志记录器
用于记录OASIS模拟中每个Agent的动作供后端监控使用
"""
import json
import os
from datetime import datetime
from typing import Dict, Any, Optional
class ActionLogger:
    """Append-only JSONL logger for agent actions during an OASIS simulation.

    Every log_* method appends exactly one JSON object per line to
    ``log_path`` so the backend can tail the file for live monitoring.
    """

    def __init__(self, log_path: str):
        """
        Initialize the logger.

        Args:
            log_path: Path of the log file (JSONL format).
        """
        self.log_path = log_path
        self._ensure_dir()

    def _ensure_dir(self):
        """Create the log file's parent directory if it does not exist."""
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    def _write(self, entry: Dict[str, Any]):
        """Append one entry as a single JSON line (shared by all log_* methods)."""
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    def log_action(
        self,
        round_num: int,
        platform: str,
        agent_id: int,
        agent_name: str,
        action_type: str,
        action_args: Optional[Dict[str, Any]] = None,
        result: Optional[str] = None,
        success: bool = True
    ):
        """
        Record a single agent action.

        Args:
            round_num: Simulation round number.
            platform: Platform name (twitter/reddit).
            agent_id: Agent ID.
            agent_name: Agent display name.
            action_type: Action type.
            action_args: Action arguments.
            result: Execution result.
            success: Whether the action succeeded.
        """
        self._write({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "agent_id": agent_id,
            "agent_name": agent_name,
            "action_type": action_type,
            "action_args": action_args or {},
            "result": result,
            "success": success,
        })

    def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
        """Record the start of a round."""
        self._write({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_start",
            "simulated_hour": simulated_hour,
        })

    def log_round_end(self, round_num: int, actions_count: int, platform: str):
        """Record the end of a round."""
        self._write({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_end",
            "actions_count": actions_count,
        })

    def log_simulation_start(self, platform: str, config: Dict[str, Any]):
        """Record the start of a simulation run."""
        self._write({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            # NOTE(review): hours * 2 assumes 30-minute rounds — confirm
            # against time_config.minutes_per_round.
            "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2,
            "event_type": "simulation_start",
            "agents_count": len(config.get("agent_configs", [])),
        })

    def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
        """Record the end of a simulation run."""
        self._write({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "simulation_end",
            "total_rounds": total_rounds,
            "total_actions": total_actions,
        })
# Optional module-level singleton logger.
_global_logger: Optional[ActionLogger] = None


def get_logger(log_path: Optional[str] = None) -> ActionLogger:
    """Return the shared ActionLogger, rebinding it when log_path is given.

    Falls back to a logger writing "actions.jsonl" in the current
    directory on first use without an explicit path.
    """
    global _global_logger
    if log_path:
        _global_logger = ActionLogger(log_path)
    elif _global_logger is None:
        _global_logger = ActionLogger("actions.jsonl")
    return _global_logger

View file

@ -0,0 +1,503 @@
"""
OASIS 双平台并行模拟预设脚本
同时运行Twitter和Reddit模拟读取相同的配置文件
使用方式:
python run_parallel_simulation.py --config simulation_config.json [--action-log actions.jsonl]
"""
import argparse
import asyncio
import json
import os
import random
import sys
from datetime import datetime
from typing import Dict, Any, List, Optional
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from action_logger import ActionLogger
try:
from camel.models import ModelFactory
from camel.types import ModelPlatformType
import oasis
from oasis import (
ActionType,
LLMAction,
ManualAction,
generate_twitter_agent_graph,
generate_reddit_agent_graph
)
except ImportError as e:
print(f"错误: 缺少依赖 {e}")
print("请先安装: pip install oasis-ai camel-ai")
sys.exit(1)
# Actions Twitter agents may take during simulation.
TWITTER_ACTIONS = [
    ActionType.CREATE_POST,
    ActionType.LIKE_POST,
    ActionType.REPOST,
    ActionType.FOLLOW,
    ActionType.DO_NOTHING,
    ActionType.QUOTE_POST,
]
# Actions Reddit agents may take during simulation.
REDDIT_ACTIONS = [
    ActionType.LIKE_POST,
    ActionType.DISLIKE_POST,
    ActionType.CREATE_POST,
    ActionType.CREATE_COMMENT,
    ActionType.LIKE_COMMENT,
    ActionType.DISLIKE_COMMENT,
    ActionType.SEARCH_POSTS,
    ActionType.SEARCH_USER,
    ActionType.TREND,
    ActionType.REFRESH,
    ActionType.DO_NOTHING,
    ActionType.FOLLOW,
    ActionType.MUTE,
]
def load_config(config_path: str) -> Dict[str, Any]:
"""加载配置文件"""
with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f)
def create_model(config: Dict[str, Any]):
    """Build a camel-ai OpenAI model from the simulation config.

    OASIS configures models through camel-ai's ModelFactory:
    - plain OpenAI: only the OPENAI_API_KEY env var is needed;
    - custom endpoint: also export OPENAI_API_BASE_URL.
    """
    model_name = config.get("llm_model", "gpt-4o-mini")
    base_url = config.get("llm_base_url", "")
    if base_url:
        # Route all requests through the configured endpoint.
        os.environ["OPENAI_API_BASE_URL"] = base_url
    return ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=model_name,
    )
def get_active_agents_for_round(
    env,
    config: Dict[str, Any],
    current_hour: int,
    round_num: int
) -> List:
    """Pick which agents act this round, based on hour-of-day activity."""
    time_config = config.get("time_config", {})
    agent_configs = config.get("agent_configs", [])

    lo = time_config.get("agents_per_hour_min", 5)
    hi = time_config.get("agents_per_hour_max", 20)
    peak = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
    quiet = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

    # Scale the round's activity target by time of day.
    if current_hour in peak:
        scale = time_config.get("peak_activity_multiplier", 1.5)
    elif current_hour in quiet:
        scale = time_config.get("off_peak_activity_multiplier", 0.3)
    else:
        scale = 1.0

    target_count = int(random.uniform(lo, hi) * scale)

    # Each agent awake at this hour rolls once against its activity level.
    candidates = [
        cfg.get("agent_id", 0)
        for cfg in agent_configs
        if current_hour in cfg.get("active_hours", list(range(8, 23)))
        and random.random() < cfg.get("activity_level", 0.5)
    ]

    chosen = random.sample(
        candidates,
        min(target_count, len(candidates))
    ) if candidates else []

    # Resolve ids to live agent objects, skipping any that fail to load.
    picked = []
    for agent_id in chosen:
        try:
            picked.append((agent_id, env.agent_graph.get_agent(agent_id)))
        except Exception:
            pass
    return picked
async def run_twitter_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Run the Twitter half of the simulation.

    Loads agent profiles and timing/event settings from ``config``,
    drives OASIS round by round, and optionally logs every action.

    Args:
        config: Parsed simulation_config.json.
        simulation_dir: Directory holding the profile files and output DB.
        action_logger: Optional JSONL logger shared with the Reddit run.
    """
    print("[Twitter] 初始化...")
    model = create_model(config)

    # OASIS expects Twitter profiles in CSV format.
    profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
    if not os.path.exists(profile_path):
        print(f"[Twitter] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=TWITTER_ACTIONS,
    )

    # Map agent id -> display name for logging.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Start each run from a clean database.
    db_path = os.path.join(simulation_dir, "twitter_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
    )
    await env.reset()
    print("[Twitter] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start("twitter", config)
    total_actions = 0

    # Seed the timeline with the configured initial posts.
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                new_action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # Fix: an agent may author several initial posts; collect
                # them in a list instead of overwriting the previous one
                # (matches the Reddit runner's handling).
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(new_action)
                else:
                    initial_actions[agent] = new_action
                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        platform="twitter",
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                total_actions += 1
            except Exception:
                pass
        if initial_actions:
            await env.step(initial_actions)
            print(f"[Twitter] 已发布 {len(initial_actions)} 条初始帖子")

    # Main loop: each round advances the simulated clock by
    # minutes_per_round minutes.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()
    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )
        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, "twitter")

        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        # One log entry per activated agent. NOTE(review): the concrete
        # action the LLM chose is not visible here, so it is recorded
        # generically as LLM_ACTION.
        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform="twitter",
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), "twitter")

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[Twitter] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()
    if action_logger:
        action_logger.log_simulation_end("twitter", total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[Twitter] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
async def run_reddit_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Run the Reddit half of the simulation.

    Loads agent profiles and timing/event settings from ``config``,
    drives OASIS round by round, and optionally logs every action.

    Args:
        config: Parsed simulation_config.json.
        simulation_dir: Directory holding the profile files and output DB.
        action_logger: Optional JSONL logger shared with the Twitter run.
    """
    print("[Reddit] 初始化...")
    model = create_model(config)

    # OASIS expects Reddit profiles in JSON format.
    profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
    if not os.path.exists(profile_path):
        print(f"[Reddit] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_reddit_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=REDDIT_ACTIONS,
    )

    # Map agent id -> display name for logging.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Start each run from a clean database.
    db_path = os.path.join(simulation_dir, "reddit_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.REDDIT,
        database_path=db_path,
    )
    await env.reset()
    print("[Reddit] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start("reddit", config)
    total_actions = 0

    # Seed the board with the configured initial posts. An agent may
    # author several posts, so multiple actions per agent accumulate
    # into a list.
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    ))
                else:
                    initial_actions[agent] = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )
                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        platform="reddit",
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                total_actions += 1
            except Exception:
                pass
        if initial_actions:
            await env.step(initial_actions)
            print(f"[Reddit] 已发布 {len(initial_actions)} 条初始帖子")

    # Main loop: each round advances the simulated clock by
    # minutes_per_round minutes.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()
    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )
        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, "reddit")

        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        # One log entry per activated agent. NOTE(review): the concrete
        # action the LLM chose is not visible here, so it is recorded
        # generically as LLM_ACTION.
        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform="reddit",
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), "reddit")

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[Reddit] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()
    if action_logger:
        action_logger.log_simulation_end("reddit", total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[Reddit] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
async def main():
    """CLI entry point: parse arguments, load the config, run one or both platforms."""
    parser = argparse.ArgumentParser(description='OASIS双平台并行模拟')
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    parser.add_argument(
        '--twitter-only',
        action='store_true',
        help='只运行Twitter模拟'
    )
    parser.add_argument(
        '--reddit-only',
        action='store_true',
        help='只运行Reddit模拟'
    )
    parser.add_argument(
        '--action-log',
        type=str,
        default='actions.jsonl',
        help='动作日志文件路径 (默认: actions.jsonl)'
    )
    args = parser.parse_args()

    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    config = load_config(args.config)
    # Profiles and DBs live next to the config file; default to CWD.
    simulation_dir = os.path.dirname(args.config) or "."

    # Shared JSONL action logger used by both platform runs.
    action_log_path = os.path.join(simulation_dir, args.action_log)
    action_logger = ActionLogger(action_log_path)

    print("=" * 60)
    print("OASIS 双平台并行模拟")
    print(f"配置文件: {args.config}")
    print(f"模拟ID: {config.get('simulation_id', 'unknown')}")
    print(f"动作日志: {action_log_path}")
    print("=" * 60)

    time_config = config.get("time_config", {})
    print(f"\n模拟参数:")
    print(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
    print(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
    print(f" - Agent数量: {len(config.get('agent_configs', []))}")

    # Echo the LLM's reasoning for the generated config, if present
    # (truncated to 500 characters).
    reasoning = config.get("generation_reasoning", "")
    if reasoning:
        print(f"\nLLM配置推理:")
        print(f" {reasoning[:500]}..." if len(reasoning) > 500 else f" {reasoning}")

    print("\n" + "=" * 60)
    start_time = datetime.now()

    if args.twitter_only:
        await run_twitter_simulation(config, simulation_dir, action_logger)
    elif args.reddit_only:
        await run_reddit_simulation(config, simulation_dir, action_logger)
    else:
        # Run both platforms concurrently, sharing the same action logger.
        await asyncio.gather(
            run_twitter_simulation(config, simulation_dir, action_logger),
            run_reddit_simulation(config, simulation_dir, action_logger),
        )

    total_elapsed = (datetime.now() - start_time).total_seconds()
    print("\n" + "=" * 60)
    print(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
    print(f"动作日志已保存到: {action_log_path}")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())

View file

@ -0,0 +1,298 @@
"""
OASIS Reddit模拟预设脚本
此脚本读取配置文件中的参数来执行模拟实现全程自动化
使用方式:
python run_reddit_simulation.py --config /path/to/simulation_config.json
"""
import argparse
import asyncio
import json
import os
import random
import sys
from datetime import datetime
from typing import Dict, Any, List
# 添加项目路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from camel.models import ModelFactory
from camel.types import ModelPlatformType
import oasis
from oasis import (
ActionType,
LLMAction,
ManualAction,
generate_reddit_agent_graph
)
except ImportError as e:
print(f"错误: 缺少依赖 {e}")
print("请先安装: pip install oasis-ai camel-ai")
sys.exit(1)
class RedditSimulationRunner:
"""Reddit模拟运行器"""
# Reddit可用动作
AVAILABLE_ACTIONS = [
ActionType.LIKE_POST,
ActionType.DISLIKE_POST,
ActionType.CREATE_POST,
ActionType.CREATE_COMMENT,
ActionType.LIKE_COMMENT,
ActionType.DISLIKE_COMMENT,
ActionType.SEARCH_POSTS,
ActionType.SEARCH_USER,
ActionType.TREND,
ActionType.REFRESH,
ActionType.DO_NOTHING,
ActionType.FOLLOW,
ActionType.MUTE,
]
def __init__(self, config_path: str):
"""
初始化模拟运行器
Args:
config_path: 配置文件路径 (simulation_config.json)
"""
self.config_path = config_path
self.config = self._load_config()
self.simulation_dir = os.path.dirname(config_path)
def _load_config(self) -> Dict[str, Any]:
"""加载配置文件"""
with open(self.config_path, 'r', encoding='utf-8') as f:
return json.load(f)
def _get_profile_path(self) -> str:
"""获取Profile文件路径"""
return os.path.join(self.simulation_dir, "reddit_profiles.json")
def _get_db_path(self) -> str:
"""获取数据库路径"""
return os.path.join(self.simulation_dir, "reddit_simulation.db")
def _create_model(self):
"""
创建LLM模型
OASIS使用camel-ai的ModelFactory配置方式
- 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量
- 自定义API: 设置 OPENAI_API_KEY OPENAI_API_BASE_URL 环境变量
"""
import os
llm_model = self.config.get("llm_model", "gpt-4o-mini")
llm_base_url = self.config.get("llm_base_url", "")
# 如果配置了base_url设置环境变量
if llm_base_url:
os.environ["OPENAI_API_BASE_URL"] = llm_base_url
return ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=llm_model,
)
def _get_active_agents_for_round(
self,
env,
current_hour: int,
round_num: int
) -> List:
"""
根据时间和配置决定本轮激活哪些Agent
"""
time_config = self.config.get("time_config", {})
agent_configs = self.config.get("agent_configs", [])
base_min = time_config.get("agents_per_hour_min", 5)
base_max = time_config.get("agents_per_hour_max", 20)
peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])
if current_hour in peak_hours:
multiplier = time_config.get("peak_activity_multiplier", 1.5)
elif current_hour in off_peak_hours:
multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
else:
multiplier = 1.0
target_count = int(random.uniform(base_min, base_max) * multiplier)
candidates = []
for cfg in agent_configs:
agent_id = cfg.get("agent_id", 0)
active_hours = cfg.get("active_hours", list(range(8, 23)))
activity_level = cfg.get("activity_level", 0.5)
if current_hour not in active_hours:
continue
if random.random() < activity_level:
candidates.append(agent_id)
selected_ids = random.sample(
candidates,
min(target_count, len(candidates))
) if candidates else []
active_agents = []
for agent_id in selected_ids:
try:
agent = env.agent_graph.get_agent(agent_id)
active_agents.append((agent_id, agent))
except Exception:
pass
return active_agents
async def run(self):
    """Run the Reddit simulation end to end.

    Phases: print a banner, derive round counts from the time config, build
    the LLM model and agent graph, create a fresh OASIS environment, seed the
    configured initial posts, then loop rounds letting activated agents act
    via LLMAction. Progress is printed to stdout; nothing is returned.
    """
    print("=" * 60)
    print("OASIS Reddit模拟")
    print(f"配置文件: {self.config_path}")
    print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
    print("=" * 60)
    time_config = self.config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    # NOTE(review): integer division assumes minutes_per_round > 0; a zero in
    # the config would raise ZeroDivisionError here — confirm upstream validation.
    total_rounds = (total_hours * 60) // minutes_per_round
    print(f"\n模拟参数:")
    print(f" - 总模拟时长: {total_hours}小时")
    print(f" - 每轮时间: {minutes_per_round}分钟")
    print(f" - 总轮数: {total_rounds}")
    print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
    print("\n初始化LLM模型...")
    model = self._create_model()
    # Load the agent graph from the profile JSON; bail out early if missing.
    print("加载Agent Profile...")
    profile_path = self._get_profile_path()
    if not os.path.exists(profile_path):
        print(f"错误: Profile文件不存在: {profile_path}")
        return
    agent_graph = await generate_reddit_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=self.AVAILABLE_ACTIONS,
    )
    # Start every run from a clean database.
    db_path = self._get_db_path()
    if os.path.exists(db_path):
        os.remove(db_path)
        print(f"已删除旧数据库: {db_path}")
    print("创建OASIS环境...")
    env = oasis.make(
        agent_graph=agent_graph,
        platform=oasis.DefaultPlatformType.REDDIT,
        database_path=db_path,
    )
    await env.reset()
    print("环境初始化完成\n")
    # Seed the platform with the configured initial posts. Multiple posts by
    # the same agent are aggregated into a list of ManualActions.
    event_config = self.config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])
    if initial_posts:
        print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                if agent in initial_actions:
                    # Agent already has an action: promote it to a list and append.
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    ))
                else:
                    initial_actions[agent] = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )
            except Exception as e:
                print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
        if initial_actions:
            await env.step(initial_actions)
            # NOTE(review): this counts agents (dict entries), not posts.
            print(f" 已发布 {len(initial_actions)} 条初始帖子")
    # Main simulation loop: one env.step per round.
    print("\n开始模拟循环...")
    start_time = datetime.now()
    for round_num in range(total_rounds):
        # Derive the simulated wall-clock time for this round.
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1
        active_agents = self._get_active_agents_for_round(
            env, simulated_hour, round_num
        )
        if not active_agents:
            continue
        # Each activated agent decides its own action via the LLM.
        actions = {
            agent: LLMAction()
            for _, agent in active_agents
        }
        await env.step(actions)
        # Progress report on the first round and every 10th round after.
        if (round_num + 1) % 10 == 0 or round_num == 0:
            elapsed = (datetime.now() - start_time).total_seconds()
            progress = (round_num + 1) / total_rounds * 100
            print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                  f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                  f"- {len(active_agents)} agents active "
                  f"- elapsed: {elapsed:.1f}s")
    await env.close()
    total_elapsed = (datetime.now() - start_time).total_seconds()
    print(f"\n模拟完成!")
    print(f" - 总耗时: {total_elapsed:.1f}")
    print(f" - 数据库: {db_path}")
    print("=" * 60)
async def main():
    """CLI entry point: parse --config, check the file exists, run the simulation."""
    arg_parser = argparse.ArgumentParser(description='OASIS Reddit模拟')
    arg_parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    cli_args = arg_parser.parse_args()
    config_file = cli_args.config
    if not os.path.exists(config_file):
        print(f"错误: 配置文件不存在: {config_file}")
        sys.exit(1)
    await RedditSimulationRunner(config_file).run()
# Script entry point: run the async CLI under asyncio.
if __name__ == "__main__":
    asyncio.run(main())

View file

@ -0,0 +1,313 @@
"""
OASIS Twitter模拟预设脚本
此脚本读取配置文件中的参数来执行模拟实现全程自动化
使用方式:
python run_twitter_simulation.py --config /path/to/simulation_config.json
"""
import argparse
import asyncio
import json
import os
import random
import sys
from datetime import datetime
from typing import Dict, Any, List
# 添加项目路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from camel.models import ModelFactory
from camel.types import ModelPlatformType
import oasis
from oasis import (
ActionType,
LLMAction,
ManualAction,
generate_twitter_agent_graph
)
except ImportError as e:
print(f"错误: 缺少依赖 {e}")
print("请先安装: pip install oasis-ai camel-ai")
sys.exit(1)
class TwitterSimulationRunner:
    """Drives a fully automated OASIS Twitter simulation from a JSON config.

    The config file (simulation_config.json) supplies the LLM settings, the
    per-agent activity schedules, simulated-time parameters and the initial
    seed posts. Sibling files (profile CSV, SQLite database) are read and
    written in the same directory as the config file.
    """

    # Actions the LLM agents are allowed to take on the Twitter platform.
    AVAILABLE_ACTIONS = [
        ActionType.CREATE_POST,
        ActionType.LIKE_POST,
        ActionType.REPOST,
        ActionType.FOLLOW,
        ActionType.DO_NOTHING,
        ActionType.QUOTE_POST,
    ]

    def __init__(self, config_path: str):
        """
        Initialize the runner.

        Args:
            config_path: path to the configuration file (simulation_config.json)
        """
        self.config_path = config_path
        self.config = self._load_config()
        self.simulation_dir = os.path.dirname(config_path)

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the JSON configuration file."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Return the agent profile path (OASIS Twitter expects CSV format)."""
        return os.path.join(self.simulation_dir, "twitter_profiles.csv")

    def _get_db_path(self) -> str:
        """Return the path of the SQLite database OASIS writes into."""
        return os.path.join(self.simulation_dir, "twitter_simulation.db")

    def _create_model(self):
        """
        Create the LLM model via camel-ai's ModelFactory.

        OASIS reads credentials from environment variables:
        - standard OpenAI: only OPENAI_API_KEY is required
        - custom endpoint: OPENAI_API_KEY plus OPENAI_API_BASE_URL
        The ``llm_model`` config key maps to camel's ``model_type``.
        """
        import os
        llm_model = self.config.get("llm_model", "gpt-4o-mini")
        llm_base_url = self.config.get("llm_base_url", "")
        # A configured base_url must be exported via the environment, because
        # OASIS/camel reads the endpoint from env vars rather than from config.
        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url
        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """
        Decide which agents act in this round based on the clock and config.

        Args:
            env: OASIS environment
            current_hour: current simulated hour (0-23)
            round_num: current round number (unused; kept for interface parity)

        Returns:
            List of (agent_id, agent) tuples for the activated agents.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])
        # Baseline number of active agents per hour.
        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)
        # Scale by the time of day (peak vs. off-peak hours).
        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])
        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0
        target_count = int(random.uniform(base_min, base_max) * multiplier)
        # Per-agent eligibility: inside active hours, then pass an activity roll.
        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)
            if current_hour not in active_hours:
                continue
            if random.random() < activity_level:
                candidates.append(agent_id)
        # Randomly keep at most target_count of the eligible agents.
        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []
        # Resolve ids to agent objects; skip ids the graph does not know.
        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                pass
        return active_agents

    async def run(self):
        """Run the Twitter simulation end to end.

        Phases: print a banner, derive round counts from the time config,
        build the LLM model and agent graph, create a fresh OASIS
        environment, seed the configured initial posts, then loop rounds
        letting activated agents act via LLMAction.
        """
        print("=" * 60)
        print("OASIS Twitter模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print("=" * 60)
        # Simulated-time parameters.
        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)
        if minutes_per_round <= 0:
            # Robustness fix: a zero/negative config value would otherwise
            # raise ZeroDivisionError below; fall back to the default.
            minutes_per_round = 30
        # Total number of rounds (integer division drops any partial round).
        total_rounds = (total_hours * 60) // minutes_per_round
        print(f"\n模拟参数:")
        print(f" - 总模拟时长: {total_hours}小时")
        print(f" - 每轮时间: {minutes_per_round}分钟")
        print(f" - 总轮数: {total_rounds}")
        print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
        # Create the LLM model.
        print("\n初始化LLM模型...")
        model = self._create_model()
        # Load the agent graph from the profile CSV; bail out early if missing.
        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return
        agent_graph = await generate_twitter_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )
        # Start every run from a clean database.
        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)
            print(f"已删除旧数据库: {db_path}")
        # Create the environment.
        print("创建OASIS环境...")
        env = oasis.make(
            agent_graph=agent_graph,
            platform=oasis.DefaultPlatformType.TWITTER,
            database_path=db_path,
        )
        await env.reset()
        print("环境初始化完成\n")
        # Seed the platform with the configured initial posts.
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])
        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            initial_actions = {}
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = env.agent_graph.get_agent(agent_id)
                    action = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )
                    # Bug fix: a plain dict assignment used to overwrite the
                    # earlier action when one agent authors several initial
                    # posts, silently dropping posts. Aggregate into a list
                    # instead, mirroring the Reddit runner's handling.
                    if agent in initial_actions:
                        if not isinstance(initial_actions[agent], list):
                            initial_actions[agent] = [initial_actions[agent]]
                        initial_actions[agent].append(action)
                    else:
                        initial_actions[agent] = action
                except Exception as e:
                    print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
            if initial_actions:
                await env.step(initial_actions)
                # NOTE(review): this counts agents (dict entries), not posts.
                print(f" 已发布 {len(initial_actions)} 条初始帖子")
        # Main simulation loop: one env.step per round.
        print("\n开始模拟循环...")
        start_time = datetime.now()
        for round_num in range(total_rounds):
            # Derive the simulated wall-clock time for this round.
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1
            # Agents activated for this round.
            active_agents = self._get_active_agents_for_round(
                env, simulated_hour, round_num
            )
            if not active_agents:
                continue
            # Each activated agent decides its own action via the LLM.
            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }
            await env.step(actions)
            # Progress report on the first round and every 10th round after.
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")
        # Shut down the environment.
        await env.close()
        total_elapsed = (datetime.now() - start_time).total_seconds()
        print(f"\n模拟完成!")
        print(f" - 总耗时: {total_elapsed:.1f}")
        print(f" - 数据库: {db_path}")
        print("=" * 60)
async def main():
    """CLI entry point: parse --config, check the file exists, run the simulation."""
    arg_parser = argparse.ArgumentParser(description='OASIS Twitter模拟')
    arg_parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    cli_args = arg_parser.parse_args()
    config_file = cli_args.config
    if not os.path.exists(config_file):
        print(f"错误: 配置文件不存在: {config_file}")
        sys.exit(1)
    await TwitterSimulationRunner(config_file).run()
# Script entry point: run the async CLI under asyncio.
if __name__ == "__main__":
    asyncio.run(main())

View file

@ -0,0 +1,166 @@
"""
测试Profile格式生成是否符合OASIS要求
验证
1. Twitter Profile生成CSV格式
2. Reddit Profile生成JSON详细格式
"""
import os
import sys
import json
import csv
import tempfile
# 添加项目路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
def test_profile_formats() -> None:
    """Smoke-test that generated profiles match the formats OASIS expects.

    Exercises two save paths of OasisProfileGenerator:
    1. Twitter profiles as CSV with the required column headers.
    2. Reddit profiles as a JSON list with the required detail fields.
    Results are printed to stdout; nothing is returned or asserted.
    """
    print("=" * 60)
    print("OASIS Profile格式测试")
    print("=" * 60)
    # Two representative profiles: an individual user and an official
    # organisation account.
    test_profiles = [
        OasisAgentProfile(
            user_id=0,
            user_name="test_user_123",
            name="Test User",
            bio="A test user for validation",
            persona="Test User is an enthusiastic participant in social discussions.",
            karma=1500,
            friend_count=100,
            follower_count=200,
            statuses_count=500,
            age=25,
            gender="male",
            mbti="INTJ",
            country="China",
            profession="Student",
            interested_topics=["Technology", "Education"],
            source_entity_uuid="test-uuid-123",
            source_entity_type="Student",
        ),
        OasisAgentProfile(
            user_id=1,
            user_name="org_official_456",
            name="Official Organization",
            bio="Official account for Organization",
            persona="This is an official institutional account that communicates official positions.",
            karma=5000,
            friend_count=50,
            follower_count=10000,
            statuses_count=200,
            profession="Organization",
            interested_topics=["Public Policy", "Announcements"],
            source_entity_uuid="test-uuid-456",
            source_entity_type="University",
        ),
    ]
    # NOTE(review): __new__ deliberately bypasses __init__ — presumably the
    # real constructor requires external resources; confirm the _save_*
    # methods do not rely on state set up in __init__.
    generator = OasisProfileGenerator.__new__(OasisProfileGenerator)
    # Write into a throwaway directory that is cleaned up automatically.
    with tempfile.TemporaryDirectory() as temp_dir:
        twitter_path = os.path.join(temp_dir, "twitter_profiles.csv")
        reddit_path = os.path.join(temp_dir, "reddit_profiles.json")
        # 1) Twitter profiles: CSV format.
        print("\n1. 测试Twitter Profile (CSV格式)")
        print("-" * 40)
        generator._save_twitter_csv(test_profiles, twitter_path)
        # Read the CSV back and dump its shape for inspection.
        with open(twitter_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            rows = list(reader)
        print(f" 文件: {twitter_path}")
        print(f" 行数: {len(rows)}")
        print(f" 表头: {list(rows[0].keys())}")
        print(f"\n 示例数据 (第1行):")
        for key, value in rows[0].items():
            print(f" {key}: {value}")
        # Columns OASIS requires for a Twitter profile CSV.
        required_twitter_fields = ['user_id', 'user_name', 'name', 'bio',
                                   'friend_count', 'follower_count', 'statuses_count', 'created_at']
        missing = set(required_twitter_fields) - set(rows[0].keys())
        if missing:
            print(f"\n [错误] 缺少字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")
        # 2) Reddit profiles: detailed JSON format.
        print("\n2. 测试Reddit Profile (JSON详细格式)")
        print("-" * 40)
        generator._save_reddit_json(test_profiles, reddit_path)
        # Read the JSON back and dump the first entry for inspection.
        with open(reddit_path, 'r', encoding='utf-8') as f:
            reddit_data = json.load(f)
        print(f" 文件: {reddit_path}")
        print(f" 条目数: {len(reddit_data)}")
        print(f" 字段: {list(reddit_data[0].keys())}")
        print(f"\n 示例数据 (第1条):")
        print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4))
        # Required vs. optional fields of the detailed Reddit format.
        required_reddit_fields = ['realname', 'username', 'bio', 'persona']
        optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics']
        missing = set(required_reddit_fields) - set(reddit_data[0].keys())
        if missing:
            print(f"\n [错误] 缺少必需字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")
        present_optional = set(optional_reddit_fields) & set(reddit_data[0].keys())
        print(f" [信息] 可选字段: {present_optional}")
    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)
def show_expected_formats():
    """Print reference examples of the profile formats OASIS expects."""
    banner = "=" * 60
    print("\n" + banner)
    print("OASIS 期望的Profile格式参考")
    print(banner)

    print("\n1. Twitter Profile (CSV格式)")
    print("-" * 40)
    csv_sample = (
        "user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at\n"
        "0,user0,User Zero,I am user zero with interests in technology.,100,150,500,2023-01-01\n"
        "1,user1,User One,Tech enthusiast and coffee lover.,200,250,1000,2023-01-02"
    )
    print(csv_sample)

    print("\n2. Reddit Profile (JSON详细格式)")
    print("-" * 40)
    json_sample = [
        {
            "realname": "James Miller",
            "username": "millerhospitality",
            "bio": "Passionate about hospitality & tourism.",
            "persona": "James is a seasoned professional in the Hospitality & Tourism industry...",
            "age": 40,
            "gender": "male",
            "mbti": "ESTJ",
            "country": "UK",
            "profession": "Hospitality & Tourism",
            "interested_topics": ["Economics", "Business"]
        }
    ]
    print(json.dumps(json_sample, ensure_ascii=False, indent=2))
# Manual test entry point: run the format checks, then print the references.
if __name__ == "__main__":
    test_profile_formats()
    show_expected_formats()