- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings. - Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms. - Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management. - Implemented a robust retry mechanism for external API calls to improve system stability. - Enhanced task management model to include detailed progress tracking. - Added logging capabilities for action tracking during simulations. - Included new scripts for running parallel simulations and testing profile formats.
561 lines
21 KiB
Python
561 lines
21 KiB
Python
"""
|
||
OASIS Agent Profile生成器
|
||
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
|
||
"""
|
||
|
||
import json
|
||
import random
|
||
from typing import Dict, Any, List, Optional
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
|
||
from openai import OpenAI
|
||
|
||
from ..config import Config
|
||
from ..utils.logger import get_logger
|
||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||
|
||
# Module-level logger shared by every class and function in this file.
logger = get_logger('mirofish.oasis_profile')
|
||
|
||
|
||
@dataclass
class OasisAgentProfile:
    """Agent profile record used to feed the OASIS simulation platform.

    The first five fields are mandatory; everything else has a default.
    ``to_reddit_format`` and ``to_twitter_format`` export the
    platform-specific subset of fields, while ``to_dict`` exports every field.
    """

    # Fields common to every platform.
    user_id: int
    user_name: str
    name: str
    bio: str
    persona: str

    # Optional field - Reddit style.
    karma: int = 1000

    # Optional fields - Twitter style.
    friend_count: int = 100
    follower_count: int = 150
    statuses_count: int = 500

    # Extra persona details.
    age: Optional[int] = None
    gender: Optional[str] = None
    mbti: Optional[str] = None
    country: Optional[str] = None
    profession: Optional[str] = None
    interested_topics: List[str] = field(default_factory=list)

    # Information about the entity this profile was derived from.
    source_entity_uuid: Optional[str] = None
    source_entity_type: Optional[str] = None

    # Creation date as a "YYYY-MM-DD" string (local time at instantiation).
    created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))

    def _persona_extras(self) -> Dict[str, Any]:
        """Return the optional persona fields that carry a truthy value.

        Shared by both platform exports so they always agree on which extra
        keys are emitted and in which order. Falsy values (``None``, ``0``,
        ``""``, empty list) are omitted.
        """
        candidates = (
            ("age", self.age),
            ("gender", self.gender),
            ("mbti", self.mbti),
            ("country", self.country),
            ("profession", self.profession),
            ("interested_topics", self.interested_topics),
        )
        return {key: value for key, value in candidates if value}

    def to_reddit_format(self) -> Dict[str, Any]:
        """Export the profile as a Reddit-style dict.

        Returns:
            The common fields plus ``karma`` and ``created_at``, followed by
            any optional persona fields that are set.
        """
        profile: Dict[str, Any] = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "created_at": self.created_at,
        }
        # Append extra persona details, if any.
        profile.update(self._persona_extras())
        return profile

    def to_twitter_format(self) -> Dict[str, Any]:
        """Export the profile as a Twitter-style dict.

        Returns:
            The common fields plus the friend/follower/status counters and
            ``created_at``, followed by any optional persona fields that are
            set.
        """
        profile: Dict[str, Any] = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "created_at": self.created_at,
        }
        # Append extra persona details, if any.
        profile.update(self._persona_extras())
        return profile

    def to_dict(self) -> Dict[str, Any]:
        """Export every field, including unset (``None``) ones, as a dict."""
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "age": self.age,
            "gender": self.gender,
            "mbti": self.mbti,
            "country": self.country,
            "profession": self.profession,
            "interested_topics": self.interested_topics,
            "source_entity_uuid": self.source_entity_uuid,
            "source_entity_type": self.source_entity_type,
            "created_at": self.created_at,
        }
|
||
|
||
|
||
class OasisProfileGenerator:
    """Generate OASIS agent profiles from Zep graph entities.

    Each entity is turned into an ``OasisAgentProfile`` either by asking an
    LLM for a detailed persona or by applying simple per-entity-type rules.
    The resulting profiles can be saved as a Twitter CSV file or a Reddit
    JSON file.
    """

    # The sixteen MBTI personality types.
    MBTI_TYPES = [
        "INTJ", "INTP", "ENTJ", "ENTP",
        "INFJ", "INFP", "ENFJ", "ENFP",
        "ISTJ", "ISFJ", "ESTJ", "ESFJ",
        "ISTP", "ISFP", "ESTP", "ESFP"
    ]

    # Pool of common countries used for randomly assigned profiles.
    COUNTRIES = [
        "China", "US", "UK", "Japan", "Germany", "France",
        "Canada", "Australia", "Brazil", "India", "South Korea"
    ]

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: Optional[str] = None
    ):
        """Create the generator and its OpenAI client.

        Args:
            api_key: LLM API key; falls back to ``Config.LLM_API_KEY``.
            base_url: LLM endpoint; falls back to ``Config.LLM_BASE_URL``.
            model_name: Model identifier; falls back to
                ``Config.LLM_MODEL_NAME``.

        Raises:
            ValueError: If no API key is supplied or configured.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model_name = model_name or Config.LLM_MODEL_NAME

        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    def generate_profile_from_entity(
        self,
        entity: "EntityNode",
        user_id: int,
        use_llm: bool = True
    ) -> "OasisAgentProfile":
        """Generate an OASIS agent profile from a single Zep entity.

        Args:
            entity: Zep entity node.
            user_id: User id to assign to the profile (used by OASIS).
            use_llm: Whether to ask the LLM for a detailed persona; when
                False the rule-based generator is used.

        Returns:
            The populated ``OasisAgentProfile``.
        """
        entity_type = entity.get_entity_type() or "Entity"

        # Basic identity.
        name = entity.name
        user_name = self._generate_username(name)

        if use_llm:
            # Detailed persona from the LLM, enriched with graph context.
            profile_data = self._generate_profile_with_llm(
                entity_name=name,
                entity_type=entity_type,
                entity_summary=entity.summary,
                entity_attributes=entity.attributes,
                context=self._build_entity_context(entity)
            )
        else:
            # Basic persona from fixed rules.
            profile_data = self._generate_profile_rule_based(
                entity_name=name,
                entity_type=entity_type,
                entity_summary=entity.summary,
                entity_attributes=entity.attributes
            )

        def pick(key: str, fallback: Any) -> Any:
            # dict.get(key, default) only applies the default when the key is
            # absent; the LLM is explicitly allowed to answer with null, so
            # treat an explicit None the same as a missing key.
            value = profile_data.get(key)
            return fallback if value is None else value

        return OasisAgentProfile(
            user_id=user_id,
            user_name=user_name,
            name=name,
            # bio/persona must end up as non-empty strings: the savers call
            # string methods on them.
            bio=profile_data.get("bio") or f"{entity_type}: {name}",
            persona=(
                profile_data.get("persona")
                or entity.summary
                or f"A {entity_type} named {name}."
            ),
            karma=pick("karma", random.randint(500, 5000)),
            friend_count=pick("friend_count", random.randint(50, 500)),
            follower_count=pick("follower_count", random.randint(100, 1000)),
            statuses_count=pick("statuses_count", random.randint(100, 2000)),
            age=profile_data.get("age"),
            gender=profile_data.get("gender"),
            mbti=profile_data.get("mbti"),
            country=profile_data.get("country"),
            profession=profile_data.get("profession"),
            interested_topics=profile_data.get("interested_topics") or [],
            source_entity_uuid=entity.uuid,
            source_entity_type=entity_type,
        )

    def _generate_username(self, name: str) -> str:
        """Derive a username from a display name.

        The name is lower-cased, spaces become underscores and every
        character that is neither alphanumeric nor an underscore is dropped.
        A random three-digit suffix is appended to reduce (not eliminate)
        collisions. If nothing usable remains, the stem ``user`` is used.
        """
        lowered = (name or "").lower().replace(" ", "_")
        stem = ''.join(ch for ch in lowered if ch.isalnum() or ch == '_')
        if not stem:
            stem = "user"

        # Random suffix to make duplicates less likely.
        suffix = random.randint(100, 999)
        return f"{stem}_{suffix}"

    def _build_entity_context(self, entity: "EntityNode") -> str:
        """Build a textual context block describing an entity's neighbourhood.

        Uses the ``fact`` of up to the first 10 related edges and the ``name``
        of up to the first 5 related nodes; entries lacking those values are
        skipped.

        Returns:
            The sections joined by blank lines, or ``""`` when there is
            nothing to report.
        """
        sections = []

        # Facts carried by related edges (at most 10 edges are inspected).
        facts = [
            edge["fact"]
            for edge in (entity.related_edges or [])[:10]
            if edge.get("fact")
        ]
        if facts:
            sections.append("Related facts:\n" + "\n".join(f"- {fact}" for fact in facts))

        # Names of related nodes (at most 5 nodes are inspected).
        related_names = [
            node["name"]
            for node in (entity.related_nodes or [])[:5]
            if node.get("name")
        ]
        if related_names:
            sections.append(f"Related to: {', '.join(related_names)}")

        return "\n\n".join(sections)

    def _generate_profile_with_llm(
        self,
        entity_name: str,
        entity_type: str,
        entity_summary: str,
        entity_attributes: Dict[str, Any],
        context: str
    ) -> Dict[str, Any]:
        """Ask the LLM for a detailed persona.

        Any failure (API error after retries, invalid JSON, a JSON reply that
        is not an object) is logged and answered with the rule-based persona
        instead, so this method always returns a dict.

        Args:
            entity_name: Display name of the entity.
            entity_type: Entity type label.
            entity_summary: Free-text summary of the entity.
            entity_attributes: Attribute mapping, embedded in the prompt as JSON.
            context: Extra context text (see ``_build_entity_context``).

        Returns:
            A dict of persona fields.
        """
        # default=str is deliberate: the attributes only feed prompt text, so
        # stringifying an exotic value beats failing before the LLM is called.
        attributes_json = json.dumps(entity_attributes, ensure_ascii=False, default=str)

        prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes.

Entity Information:
- Name: {entity_name}
- Type: {entity_type}
- Summary: {entity_summary}
- Attributes: {attributes_json}

Context:
{context}

Generate a JSON object with the following fields:
{{
"bio": "A short bio (max 150 chars) suitable for social media",
"persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns",
"age": <integer between 18-65, or null if not applicable>,
"gender": "<male/female/other, or null if not applicable>",
"mbti": "<MBTI type like INTJ, ENFP, etc., or null>",
"country": "<country name, or null>",
"profession": "<profession/occupation, or null>",
"interested_topics": ["topic1", "topic2", ...]
}}

Important:
- The profile should be consistent with the entity type and context
- Make the persona feel realistic and suitable for social media simulation
- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account)
- Return ONLY the JSON object, no additional text"""

        try:
            # Call the LLM API through the project's retry helper.
            from ..utils.retry import RetryableAPIClient

            retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0)

            def call_llm():
                return self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {"role": "system", "content": "You are a profile generator for social media simulation. Generate realistic user profiles based on entity information."},
                        {"role": "user", "content": prompt}
                    ],
                    response_format={"type": "json_object"},
                    temperature=0.7
                )

            response = retry_client.call_with_retry(call_llm)
            result = json.loads(response.choices[0].message.content)
            if not isinstance(result, dict):
                # Callers use .get() on the result; anything but an object is
                # unusable, so route it to the fallback below.
                raise ValueError(
                    f"expected a JSON object, got {type(result).__name__}"
                )
            return result

        except Exception as e:
            # Deliberate best-effort boundary: never let an LLM problem stop
            # profile generation.
            logger.warning(f"LLM生成人设失败(已重试): {str(e)}, 使用规则生成")
            return self._generate_profile_rule_based(
                entity_name, entity_type, entity_summary, entity_attributes
            )

    def _generate_profile_rule_based(
        self,
        entity_name: str,
        entity_type: str,
        entity_summary: str,
        entity_attributes: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build a basic persona from fixed per-entity-type rules.

        Args:
            entity_name: Display name of the entity.
            entity_type: Entity type label; matched case-insensitively.
            entity_summary: Free-text summary (used only by the default rule).
            entity_attributes: Attribute mapping; ``None`` is treated as empty.

        Returns:
            A dict of persona fields. Media and institutional types get no
            age/gender/mbti/country; the other types get random ones.
        """
        attributes = entity_attributes or {}
        entity_type_lower = entity_type.lower()

        if entity_type_lower in ("student", "alumni"):
            return {
                "bio": f"{entity_type} with interests in academics and social issues.",
                "persona": f"{entity_name} is a {entity_type_lower} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
                "age": random.randint(18, 30),
                "gender": random.choice(["male", "female"]),
                "mbti": random.choice(self.MBTI_TYPES),
                "country": random.choice(self.COUNTRIES),
                "profession": "Student",
                "interested_topics": ["Education", "Social Issues", "Technology"],
            }

        if entity_type_lower in ("publicfigure", "expert", "faculty"):
            return {
                "bio": "Expert and thought leader in their field.",
                "persona": f"{entity_name} is a recognized {entity_type_lower} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
                "age": random.randint(35, 60),
                "gender": random.choice(["male", "female"]),
                "mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
                "country": random.choice(self.COUNTRIES),
                "profession": attributes.get("occupation", "Expert"),
                "interested_topics": ["Politics", "Economics", "Culture & Society"],
            }

        if entity_type_lower in ("mediaoutlet", "socialmediaplatform"):
            return {
                "bio": f"Official account for {entity_name}. News and updates.",
                "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
                "profession": "Media",
                "interested_topics": ["General News", "Current Events", "Public Affairs"],
            }

        if entity_type_lower in ("university", "governmentagency", "ngo", "organization"):
            return {
                "bio": f"Official account of {entity_name}.",
                "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
                "profession": entity_type,
                "interested_topics": ["Public Policy", "Community", "Official Announcements"],
            }

        # Default persona for every other entity type.
        return {
            "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
            "persona": entity_summary or f"{entity_name} is a {entity_type_lower} participating in social discussions.",
            "age": random.randint(25, 50),
            "gender": random.choice(["male", "female"]),
            "mbti": random.choice(self.MBTI_TYPES),
            "country": random.choice(self.COUNTRIES),
            "profession": entity_type,
            "interested_topics": ["General", "Social Issues"],
        }

    def generate_profiles_from_entities(
        self,
        entities: List["EntityNode"],
        use_llm: bool = True,
        progress_callback: Optional[callable] = None
    ) -> List["OasisAgentProfile"]:
        """Generate agent profiles for a batch of entities.

        The position of each entity in ``entities`` becomes its ``user_id``.
        If generating a profile raises, the error is logged and a minimal
        profile is appended instead, so the result always has one profile per
        entity.

        Args:
            entities: Entities to convert.
            use_llm: Whether to ask the LLM for detailed personas.
            progress_callback: Optional callable invoked before each entity
                as ``(current, total, message)``, with ``current`` 1-based.

        Returns:
            The list of generated profiles, in input order.
        """
        profiles = []
        total = len(entities)

        for idx, entity in enumerate(entities):
            if progress_callback:
                progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...")

            try:
                profile = self.generate_profile_from_entity(
                    entity=entity,
                    user_id=idx,
                    use_llm=use_llm
                )
            except Exception as e:
                logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
                # Fall back to a minimal profile so the batch stays complete.
                profile = OasisAgentProfile(
                    user_id=idx,
                    user_name=self._generate_username(entity.name),
                    name=entity.name,
                    bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}",
                    persona=entity.summary or "A participant in social discussions.",
                    source_entity_uuid=entity.uuid,
                    source_entity_type=entity.get_entity_type(),
                )

            profiles.append(profile)

        return profiles

    def save_profiles(
        self,
        profiles: List["OasisAgentProfile"],
        file_path: str,
        platform: str = "reddit"
    ):
        """Save profiles to a file in the format the target platform expects.

        OASIS platform formats:
        - Twitter: CSV
        - Reddit: JSON

        Args:
            profiles: Profiles to save.
            file_path: Destination path.
            platform: ``"twitter"`` selects CSV; any other value (default
                ``"reddit"``) selects JSON.
        """
        if platform == "twitter":
            self._save_twitter_csv(profiles, file_path)
        else:
            self._save_reddit_json(profiles, file_path)

    def _save_twitter_csv(self, profiles: List["OasisAgentProfile"], file_path: str):
        """Save Twitter profiles as CSV.

        CSV columns required by OASIS Twitter:
        user_id, user_name, name, bio, friend_count, follower_count,
        statuses_count, created_at

        A trailing ``.json`` suffix on ``file_path`` is swapped for ``.csv``;
        any other path is used unchanged.
        """
        import csv

        # Swap only a trailing ".json" suffix. A blanket str.replace would
        # also rewrite ".json" inside directory names.
        if file_path.endswith('.json'):
            file_path = file_path[:-len('.json')] + '.csv'

        headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count',
                   'follower_count', 'statuses_count', 'created_at']

        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(headers)

            for profile in profiles:
                # Flatten line breaks in the bio; csv.writer handles commas
                # and quotes itself. A missing bio becomes an empty cell.
                bio = (profile.bio or "").replace('\n', ' ').replace('\r', ' ')
                writer.writerow([
                    profile.user_id,
                    profile.user_name,
                    profile.name,
                    bio,
                    profile.friend_count,
                    profile.follower_count,
                    profile.statuses_count,
                    profile.created_at
                ])

        logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)")

    def _save_reddit_json(self, profiles: List["OasisAgentProfile"], file_path: str):
        """Save Reddit profiles as JSON.

        OASIS Reddit supports two JSON layouts:
        1. Basic: user_id, user_name, name, bio, karma, created_at
        2. Detailed: realname, username, bio, persona, age, gender, mbti,
           country, profession, interested_topics

        The detailed layout is written, to stay consistent with the
        user-supplied sample data ("36个简单人设.json").
        """
        data = []
        for profile in profiles:
            # Detailed layout (compatible with the user sample).
            item = {
                "realname": profile.name,
                "username": profile.user_name,
                # OASIS limits the bio to 150 characters.
                "bio": profile.bio[:150] if profile.bio else "",
                "persona": profile.persona or f"{profile.name} is a participant in social discussions.",
            }

            # Persona detail fields are written only when they are set.
            optional_fields = (
                ("age", profile.age),
                ("gender", profile.gender),
                ("mbti", profile.mbti),
                ("country", profile.country),
                ("profession", profile.profession),
                ("interested_topics", profile.interested_topics),
            )
            item.update({key: value for key, value in optional_fields if value})

            data.append(item)

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")

    # Old method name kept as an alias for backward compatibility.
    def save_profiles_to_json(
        self,
        profiles: List["OasisAgentProfile"],
        file_path: str,
        platform: str = "reddit"
    ):
        """[Deprecated] Use ``save_profiles()`` instead."""
        logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
        self.save_profiles(profiles, file_path, platform)
|
||
|