MiroFish/backend/app/services/oasis_profile_generator.py

1200 lines
49 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
OASIS Agent Profile生成器
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
优化改进:
1. 调用Zep检索功能二次丰富节点信息
2. 优化提示词生成非常详细的人设
3. 区分个人实体和抽象群体实体
"""
import json
import random
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import EntityNode, ZepEntityReader
logger = get_logger('mirofish.oasis_profile')
@dataclass
class OasisAgentProfile:
    """Agent profile record with exporters for the Reddit and Twitter layouts."""

    # Core fields shared by every platform
    user_id: int
    user_name: str
    name: str
    bio: str
    persona: str

    # Optional field - Reddit style
    karma: int = 1000

    # Optional fields - Twitter style
    friend_count: int = 100
    follower_count: int = 150
    statuses_count: int = 500

    # Additional persona details
    age: Optional[int] = None
    gender: Optional[str] = None
    mbti: Optional[str] = None
    country: Optional[str] = None
    profession: Optional[str] = None
    interested_topics: List[str] = field(default_factory=list)

    # Source entity information
    source_entity_uuid: Optional[str] = None
    source_entity_type: Optional[str] = None

    created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))

    def _with_persona_extras(self, profile: Dict[str, Any]) -> Dict[str, Any]:
        """Copy each truthy optional persona attribute into ``profile`` and return it."""
        for attr in ("age", "gender", "mbti", "country", "profession", "interested_topics"):
            value = getattr(self, attr)
            if value:
                profile[attr] = value
        return profile

    def to_reddit_format(self) -> Dict[str, Any]:
        """Return the profile as a dict in the Reddit platform layout."""
        base = dict(
            user_id=self.user_id,
            username=self.user_name,  # exported under "username" (no underscore)
            name=self.name,
            bio=self.bio,
            persona=self.persona,
            karma=self.karma,
            created_at=self.created_at,
        )
        # Optional persona details are appended only when they carry a value
        return self._with_persona_extras(base)

    def to_twitter_format(self) -> Dict[str, Any]:
        """Return the profile as a dict in the Twitter platform layout."""
        base = dict(
            user_id=self.user_id,
            username=self.user_name,  # exported under "username" (no underscore)
            name=self.name,
            bio=self.bio,
            persona=self.persona,
            friend_count=self.friend_count,
            follower_count=self.follower_count,
            statuses_count=self.statuses_count,
            created_at=self.created_at,
        )
        # Optional persona details are appended only when they carry a value
        return self._with_persona_extras(base)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field, including unset ones, as a plain dict."""
        field_names = (
            "user_id", "user_name", "name", "bio", "persona",
            "karma", "friend_count", "follower_count", "statuses_count",
            "age", "gender", "mbti", "country", "profession",
            "interested_topics", "source_entity_uuid", "source_entity_type",
            "created_at",
        )
        return {key: getattr(self, key) for key in field_names}
class OasisProfileGenerator:
    """
    OASIS profile generator.

    Converts entities from the Zep graph into the Agent Profiles required by
    an OASIS simulation.

    Optimized features:
    1. Calls Zep graph retrieval to obtain richer context
    2. Generates very detailed personas (basic information, career history,
       personality traits, social media behavior, etc.)
    3. Distinguishes individual entities from abstract group entities
    """
    # MBTI type list
    MBTI_TYPES = [
        "INTJ", "INTP", "ENTJ", "ENTP",
        "INFJ", "INFP", "ENFJ", "ENFP",
        "ISTJ", "ISFJ", "ESTJ", "ESFJ",
        "ISTP", "ISFP", "ESTP", "ESFP"
    ]
    # Common country list
    COUNTRIES = [
        "China", "US", "UK", "Japan", "Germany", "France",
        "Canada", "Australia", "Brazil", "India", "South Korea"
    ]
    # Individual entity types (a concrete personal persona is generated);
    # compared against the lower-cased entity type
    INDIVIDUAL_ENTITY_TYPES = [
        "student", "alumni", "professor", "person", "publicfigure",
        "expert", "faculty", "official", "journalist", "activist"
    ]
    # Group / institution entity types (a representative account persona is
    # generated); compared against the lower-cased entity type
    GROUP_ENTITY_TYPES = [
        "university", "governmentagency", "organization", "ngo",
        "mediaoutlet", "company", "institution", "group", "community"
    ]
def __init__(
    self,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    model_name: Optional[str] = None,
    zep_api_key: Optional[str] = None,
    graph_id: Optional[str] = None
):
    """
    Set up the LLM client and, when a key is available, the Zep client.

    Args:
        api_key: LLM API key; falls back to ``Config.LLM_API_KEY``.
        base_url: LLM endpoint; falls back to ``Config.LLM_BASE_URL``.
        model_name: LLM model; falls back to ``Config.LLM_MODEL_NAME``.
        zep_api_key: Zep API key; falls back to ``Config.ZEP_API_KEY``.
        graph_id: Graph id stored for later Zep searches.

    Raises:
        ValueError: If no LLM API key is given or configured.
    """
    self.api_key = api_key or Config.LLM_API_KEY
    self.base_url = base_url or Config.LLM_BASE_URL
    self.model_name = model_name or Config.LLM_MODEL_NAME
    if not self.api_key:
        raise ValueError("LLM_API_KEY is not configured")

    self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    # The Zep client is optional: it is only used to retrieve extra context
    self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
    self.zep_client = None
    self.graph_id = graph_id
    if not self.zep_api_key:
        return
    try:
        self.zep_client = Zep(api_key=self.zep_api_key)
    except Exception as e:
        # Best effort: a failed Zep init leaves zep_client as None
        logger.warning(f"Zep客户端初始化失败: {e}")
def generate_profile_from_entity(
    self,
    entity: EntityNode,
    user_id: int,
    use_llm: bool = True
) -> OasisAgentProfile:
    """
    Build an OASIS Agent Profile from a single Zep entity.

    Args:
        entity: Zep entity node.
        user_id: User id assigned to the resulting profile.
        use_llm: When True the persona is generated by the LLM, otherwise
            by the rule-based generator.

    Returns:
        OasisAgentProfile
    """
    entity_type = entity.get_entity_type() or "Entity"

    # Basic information
    name = entity.name
    user_name = self._generate_username(name)

    if use_llm:
        # The context only feeds the LLM prompt, so it is built (including
        # the Zep searches it triggers) only on this branch.
        context = self._build_entity_context(entity)
        profile_data = self._generate_profile_with_llm(
            entity_name=name,
            entity_type=entity_type,
            entity_summary=entity.summary,
            entity_attributes=entity.attributes,
            context=context
        )
    else:
        profile_data = self._generate_profile_rule_based(
            entity_name=name,
            entity_type=entity_type,
            entity_summary=entity.summary,
            entity_attributes=entity.attributes
        )

    return OasisAgentProfile(
        user_id=user_id,
        user_name=user_name,
        name=name,
        bio=profile_data.get("bio", f"{entity_type}: {name}"),
        persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."),
        karma=profile_data.get("karma", random.randint(500, 5000)),
        friend_count=profile_data.get("friend_count", random.randint(50, 500)),
        follower_count=profile_data.get("follower_count", random.randint(100, 1000)),
        statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)),
        age=profile_data.get("age"),
        gender=profile_data.get("gender"),
        mbti=profile_data.get("mbti"),
        country=profile_data.get("country"),
        profession=profile_data.get("profession"),
        # A null/missing value from the generator becomes an empty list
        interested_topics=profile_data.get("interested_topics") or [],
        source_entity_uuid=entity.uuid,
        source_entity_type=entity_type,
    )
def _generate_username(self, name: str) -> str:
"""生成用户名"""
# 移除特殊字符,转换为小写
username = name.lower().replace(" ", "_")
username = ''.join(c for c in username if c.isalnum() or c == '_')
# 添加随机后缀避免重复
suffix = random.randint(100, 999)
return f"{username}_{suffix}"
def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]:
"""
使用Zep图谱混合搜索功能获取实体相关的丰富信息
Zep没有内置混合搜索接口需要分别搜索edges和nodes然后合并结果。
使用并行请求同时搜索,提高效率。
Args:
entity: 实体节点对象
Returns:
包含facts, node_summaries, context的字典
"""
import concurrent.futures
if not self.zep_client:
return {"facts": [], "node_summaries": [], "context": ""}
entity_name = entity.name
results = {
"facts": [],
"node_summaries": [],
"context": ""
}
# 必须有graph_id才能进行搜索
if not self.graph_id:
logger.debug(f"跳过Zep检索未设置graph_id")
return results
comprehensive_query = f"All information, activities, events, relationships, and background about {entity_name}"
def search_edges():
"""搜索边(事实/关系)- 带重试机制"""
max_retries = 3
last_exception = None
delay = 2.0
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=30,
scope="edges",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
def search_nodes():
"""搜索节点(实体摘要)- 带重试机制"""
max_retries = 3
last_exception = None
delay = 2.0
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=20,
scope="nodes",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
try:
# 并行执行edges和nodes搜索
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
edge_future = executor.submit(search_edges)
node_future = executor.submit(search_nodes)
# 获取结果
edge_result = edge_future.result(timeout=30)
node_result = node_future.result(timeout=30)
# 处理边搜索结果
all_facts = set()
if edge_result and hasattr(edge_result, 'edges') and edge_result.edges:
for edge in edge_result.edges:
if hasattr(edge, 'fact') and edge.fact:
all_facts.add(edge.fact)
results["facts"] = list(all_facts)
# 处理节点搜索结果
all_summaries = set()
if node_result and hasattr(node_result, 'nodes') and node_result.nodes:
for node in node_result.nodes:
if hasattr(node, 'summary') and node.summary:
all_summaries.add(node.summary)
if hasattr(node, 'name') and node.name and node.name != entity_name:
all_summaries.add(f"Related entity: {node.name}")
results["node_summaries"] = list(all_summaries)
# 构建综合上下文
context_parts = []
if results["facts"]:
context_parts.append("Factual information:\n" + "\n".join(f"- {f}" for f in results["facts"][:20]))
if results["node_summaries"]:
context_parts.append("Related entities:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10]))
results["context"] = "\n\n".join(context_parts)
logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点")
except concurrent.futures.TimeoutError:
logger.warning(f"Zep检索超时 ({entity_name})")
except Exception as e:
logger.warning(f"Zep检索失败 ({entity_name}): {e}")
return results
def _build_entity_context(self, entity: EntityNode) -> str:
"""
构建实体的完整上下文信息
包括:
1. 实体本身的边信息(事实)
2. 关联节点的详细信息
3. Zep混合检索到的丰富信息
"""
context_parts = []
# 1. 添加实体属性信息
if entity.attributes:
attrs = []
for key, value in entity.attributes.items():
if value and str(value).strip():
attrs.append(f"- {key}: {value}")
if attrs:
context_parts.append("### Entity Attributes\n" + "\n".join(attrs))
# 2. 添加相关边信息(事实/关系)
existing_facts = set()
if entity.related_edges:
relationships = []
for edge in entity.related_edges: # 不限制数量
fact = edge.get("fact", "")
edge_name = edge.get("edge_name", "")
direction = edge.get("direction", "")
if fact:
relationships.append(f"- {fact}")
existing_facts.add(fact)
elif edge_name:
if direction == "outgoing":
relationships.append(f"- {entity.name} --[{edge_name}]--> (related entity)")
else:
relationships.append(f"- (related entity) --[{edge_name}]--> {entity.name}")
if relationships:
context_parts.append("### Related Facts and Relationships\n" + "\n".join(relationships))
# 3. 添加关联节点的详细信息
if entity.related_nodes:
related_info = []
for node in entity.related_nodes: # 不限制数量
node_name = node.get("name", "")
node_labels = node.get("labels", [])
node_summary = node.get("summary", "")
# 过滤掉默认标签
custom_labels = [l for l in node_labels if l not in ["Entity", "Node"]]
label_str = f" ({', '.join(custom_labels)})" if custom_labels else ""
if node_summary:
related_info.append(f"- **{node_name}**{label_str}: {node_summary}")
else:
related_info.append(f"- **{node_name}**{label_str}")
if related_info:
context_parts.append("### Related Entity Information\n" + "\n".join(related_info))
# 4. 使用Zep混合检索获取更丰富的信息
zep_results = self._search_zep_for_entity(entity)
if zep_results.get("facts"):
# 去重:排除已存在的事实
new_facts = [f for f in zep_results["facts"] if f not in existing_facts]
if new_facts:
context_parts.append("### Facts Retrieved from Zep\n" + "\n".join(f"- {f}" for f in new_facts[:15]))
if zep_results.get("node_summaries"):
context_parts.append("### Related Nodes Retrieved from Zep\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10]))
return "\n\n".join(context_parts)
def _is_individual_entity(self, entity_type: str) -> bool:
"""判断是否是个人类型实体"""
return entity_type.lower() in self.INDIVIDUAL_ENTITY_TYPES
def _is_group_entity(self, entity_type: str) -> bool:
"""判断是否是群体/机构类型实体"""
return entity_type.lower() in self.GROUP_ENTITY_TYPES
def _generate_profile_with_llm(
self,
entity_name: str,
entity_type: str,
entity_summary: str,
entity_attributes: Dict[str, Any],
context: str
) -> Dict[str, Any]:
"""
使用LLM生成非常详细的人设
根据实体类型区分:
- 个人实体:生成具体的人物设定
- 群体/机构实体:生成代表性账号设定
"""
is_individual = self._is_individual_entity(entity_type)
if is_individual:
prompt = self._build_individual_persona_prompt(
entity_name, entity_type, entity_summary, entity_attributes, context
)
else:
prompt = self._build_group_persona_prompt(
entity_name, entity_type, entity_summary, entity_attributes, context
)
# 尝试多次生成,直到成功或达到最大重试次数
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": self._get_system_prompt(is_individual)},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
)
content = response.choices[0].message.content
# 检查是否被截断finish_reason不是'stop'
finish_reason = response.choices[0].finish_reason
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
result = json.loads(content)
# 验证必需字段
if "bio" not in result or not result["bio"]:
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
if "persona" not in result or not result["persona"]:
result["persona"] = entity_summary or f"{entity_name} is a {entity_type}."
return result
except json.JSONDecodeError as je:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
# 尝试修复JSON
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
if result.get("_fixed"):
del result["_fixed"]
return result
last_error = je
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(1 * (attempt + 1)) # 指数退避
logger.warning(f"LLM生成人设失败{max_attempts}次尝试): {last_error}, 使用规则生成")
return self._generate_profile_rule_based(
entity_name, entity_type, entity_summary, entity_attributes
)
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON输出被max_tokens限制截断"""
import re
# 如果JSON被截断尝试闭合它
content = content.strip()
# 计算未闭合的括号
open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']')
# 检查是否有未闭合的字符串
# 简单检查:如果最后一个引号后没有逗号或闭合括号,可能是字符串被截断
if content and content[-1] not in '",}]':
# 尝试闭合字符串
content += '"'
# 闭合括号
content += ']' * open_brackets
content += '}' * open_braces
return content
def _try_fix_json(self, content: str, entity_name: str, entity_type: str, entity_summary: str = "") -> Dict[str, Any]:
"""尝试修复损坏的JSON"""
import re
# 1. 首先尝试修复被截断的情况
content = self._fix_truncated_json(content)
# 2. 尝试提取JSON部分
json_match = re.search(r'\{[\s\S]*\}', content)
if json_match:
json_str = json_match.group()
# 3. 处理字符串中的换行符问题
# 找到所有字符串值并替换其中的换行符
def fix_string_newlines(match):
s = match.group(0)
# 替换字符串内的实际换行符为空格
s = s.replace('\n', ' ').replace('\r', ' ')
# 替换多余空格
s = re.sub(r'\s+', ' ', s)
return s
# 匹配JSON字符串值
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
# 4. 尝试解析
try:
result = json.loads(json_str)
result["_fixed"] = True
return result
except json.JSONDecodeError as e:
# 5. 如果还是失败,尝试更激进的修复
try:
# 移除所有控制字符
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
# 替换所有连续空白
json_str = re.sub(r'\s+', ' ', json_str)
result = json.loads(json_str)
result["_fixed"] = True
return result
except:
pass
# 6. 尝试从内容中提取部分信息
bio_match = re.search(r'"bio"\s*:\s*"([^"]*)"', content)
persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # 可能被截断
bio = bio_match.group(1) if bio_match else (entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}")
persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name}是一个{entity_type}")
# 如果提取到了有意义的内容,标记为已修复
if bio_match or persona_match:
logger.info(f"从损坏的JSON中提取了部分信息")
return {
"bio": bio,
"persona": persona,
"_fixed": True
}
# 7. 完全失败,返回基础结构
logger.warning(f"JSON修复失败返回基础结构")
return {
"bio": entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}",
"persona": entity_summary or f"{entity_name}是一个{entity_type}"
}
def _get_system_prompt(self, is_individual: bool) -> str:
"""获取系统提示词"""
base_prompt = "You are a social media user profile generation expert. Generate detailed, realistic personas for public opinion simulation, maximally faithful to existing real-world information. You must return valid JSON format, all string values must not contain unescaped newlines."
return base_prompt
def _build_individual_persona_prompt(
self,
entity_name: str,
entity_type: str,
entity_summary: str,
entity_attributes: Dict[str, Any],
context: str
) -> str:
"""构建个人实体的详细人设提示词"""
attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None"
context_str = context[:3000] if context else "No additional context"
return f"""Generate a detailed social media user persona for this entity, maximally faithful to existing real-world information.
Entity Name: {entity_name}
Entity Type: {entity_type}
Entity Summary: {entity_summary}
Entity Attributes: {attrs_str}
Context Information:
{context_str}
Generate JSON with the following fields:
1. bio: Social media biography, 200 words
2. persona: Detailed persona description (2000 words of plain text), must include:
- Basic information (age, profession, educational background, location)
- Background (important experiences, connection to events, social relationships)
- Personality traits (MBTI type, core personality, emotional expression style)
- Social media behavior (posting frequency, content preferences, interaction style, language characteristics)
- Stances and opinions (attitudes toward topics, content that may provoke or move them)
- Unique traits (catchphrases, special experiences, personal hobbies)
- Personal memories (important part of the persona, describe this individual's connection to events, and their existing actions and reactions in those events)
3. age: Age as a number (must be an integer)
4. gender: Must be in English: "male" or "female"
5. mbti: MBTI type (e.g., INTJ, ENFP)
6. country: Country name
7. profession: Profession
8. interested_topics: Array of topics of interest
Important:
- All field values must be strings or numbers, do not use newline characters
- persona must be a coherent text description
- Write in English (gender field must be "male" or "female")
- Content must be consistent with entity information
- age must be a valid integer, gender must be "male" or "female"
"""
def _build_group_persona_prompt(
self,
entity_name: str,
entity_type: str,
entity_summary: str,
entity_attributes: Dict[str, Any],
context: str
) -> str:
"""构建群体/机构实体的详细人设提示词"""
attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None"
context_str = context[:3000] if context else "No additional context"
return f"""Generate a detailed social media account profile for this institution/group entity, maximally faithful to existing real-world information.
Entity Name: {entity_name}
Entity Type: {entity_type}
Entity Summary: {entity_summary}
Entity Attributes: {attrs_str}
Context Information:
{context_str}
Generate JSON with the following fields:
1. bio: Official account biography, 200 words, professional and appropriate
2. persona: Detailed account profile description (2000 words of plain text), must include:
- Institutional basic information (official name, nature, founding background, primary functions)
- Account positioning (account type, target audience, core functions)
- Communication style (language characteristics, common expressions, taboo topics)
- Content characteristics (content types, posting frequency, active time periods)
- Stances and attitudes (official positions on core topics, approach to handling controversies)
- Special notes (group portrait represented, operational habits)
- Institutional memory (important part of the persona, describe this institution's connection to events, and its existing actions and reactions in those events)
3. age: Fixed value 30 (virtual age for institutional accounts)
4. gender: Fixed value "other" (institutional accounts use "other" to indicate non-personal)
5. mbti: MBTI type to describe account style, e.g., ISTJ for rigorous and conservative
6. country: Country name
7. profession: Institutional function description
8. interested_topics: Array of focus areas
Important:
- All field values must be strings or numbers, null values are not allowed
- persona must be a coherent text description, do not use newline characters
- Write in English (gender field must be "other")
- age must be integer 30, gender must be string "other"
- Institutional account communication must align with its identity positioning"""
def _generate_profile_rule_based(
self,
entity_name: str,
entity_type: str,
entity_summary: str,
entity_attributes: Dict[str, Any]
) -> Dict[str, Any]:
"""使用规则生成基础人设"""
# 根据实体类型生成不同的人设
entity_type_lower = entity_type.lower()
if entity_type_lower in ["student", "alumni"]:
return {
"bio": f"{entity_type} with interests in academics and social issues.",
"persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
"age": random.randint(18, 30),
"gender": random.choice(["male", "female"]),
"mbti": random.choice(self.MBTI_TYPES),
"country": random.choice(self.COUNTRIES),
"profession": "Student",
"interested_topics": ["Education", "Social Issues", "Technology"],
}
elif entity_type_lower in ["publicfigure", "expert", "faculty"]:
return {
"bio": f"Expert and thought leader in their field.",
"persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
"age": random.randint(35, 60),
"gender": random.choice(["male", "female"]),
"mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
"country": random.choice(self.COUNTRIES),
"profession": entity_attributes.get("occupation", "Expert"),
"interested_topics": ["Politics", "Economics", "Culture & Society"],
}
elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]:
return {
"bio": f"Official account for {entity_name}. News and updates.",
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
"age": 30, # 机构虚拟年龄
"gender": "other", # 机构使用other
"mbti": "ISTJ", # 机构风格:严谨保守
"country": "China",
"profession": "Media",
"interested_topics": ["General News", "Current Events", "Public Affairs"],
}
elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]:
return {
"bio": f"Official account of {entity_name}.",
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
"age": 30, # 机构虚拟年龄
"gender": "other", # 机构使用other
"mbti": "ISTJ", # 机构风格:严谨保守
"country": "China",
"profession": entity_type,
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
}
else:
# 默认人设
return {
"bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
"persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.",
"age": random.randint(25, 50),
"gender": random.choice(["male", "female"]),
"mbti": random.choice(self.MBTI_TYPES),
"country": random.choice(self.COUNTRIES),
"profession": entity_type,
"interested_topics": ["General", "Social Issues"],
}
def set_graph_id(self, graph_id: str) -> None:
    """Store the graph id that later Zep searches will use."""
    self.graph_id = graph_id
def generate_profiles_from_entities(
    self,
    entities: List[EntityNode],
    use_llm: bool = True,
    progress_callback: Optional[callable] = None,
    graph_id: Optional[str] = None,
    parallel_count: int = 5,
    realtime_output_path: Optional[str] = None,
    output_platform: str = "reddit"
) -> List[OasisAgentProfile]:
    """
    Generate Agent Profiles for a batch of entities using a thread pool.

    Args:
        entities: Entity list.
        use_llm: Whether to use the LLM for detailed personas.
        progress_callback: Optional callable ``(current, total, message)``
            invoked after each entity finishes. An exception it raises is
            logged and otherwise ignored.
        graph_id: Graph id used for Zep retrieval; overrides the stored
            one when given.
        parallel_count: Number of worker threads (values below 1 are
            treated as 1).
        realtime_output_path: When given, the profiles generated so far
            are rewritten to this path after every completed entity.
        output_platform: ``"reddit"`` writes the realtime file as JSON;
            any other value writes it as CSV in the Twitter layout.

    Returns:
        One profile per entity, in the same order as ``entities``. An
        entity whose generation failed gets a minimal fallback profile.
    """
    import concurrent.futures
    import csv

    # Set graph_id for Zep retrieval
    if graph_id:
        self.graph_id = graph_id

    total = len(entities)
    # Pre-allocated so results land at their entity's index
    profiles = [None] * total
    completed_count = 0
    # ThreadPoolExecutor rejects max_workers <= 0
    worker_count = max(1, parallel_count)

    def save_profiles_realtime():
        """Rewrite the realtime output file with every profile generated so far."""
        if not realtime_output_path:
            return
        existing_profiles = [p for p in profiles if p is not None]
        if not existing_profiles:
            return
        try:
            if output_platform == "reddit":
                # Reddit: JSON
                profiles_data = [p.to_reddit_format() for p in existing_profiles]
                with open(realtime_output_path, 'w', encoding='utf-8') as f:
                    json.dump(profiles_data, f, ensure_ascii=False, indent=2)
            else:
                # Twitter: CSV. Optional keys differ between profiles, so the
                # header is the ordered union of all keys; DictWriter raises
                # on a row key that is missing from fieldnames.
                profiles_data = [p.to_twitter_format() for p in existing_profiles]
                fieldnames = list(dict.fromkeys(key for row in profiles_data for key in row))
                with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(profiles_data)
        except Exception as e:
            # Realtime output is best effort and must not stop generation
            logger.warning(f"实时保存 profiles 失败: {e}")

    def build_fallback(idx: int, entity: EntityNode, entity_type: str) -> OasisAgentProfile:
        """Build the minimal profile used when generation fails."""
        return OasisAgentProfile(
            user_id=idx,
            user_name=self._generate_username(entity.name),
            name=entity.name,
            bio=f"{entity_type}: {entity.name}",
            persona=entity.summary or f"A participant in social discussions.",
            source_entity_uuid=entity.uuid,
            source_entity_type=entity_type,
        )

    def generate_single_profile(idx: int, entity: EntityNode) -> tuple:
        """Worker: return ``(idx, profile, error_message_or_None)``."""
        entity_type = entity.get_entity_type() or "Entity"
        try:
            profile = self.generate_profile_from_entity(
                entity=entity,
                user_id=idx,
                use_llm=use_llm
            )
            # Echo the generated persona to the console
            self._print_generated_profile(entity.name, entity_type, profile)
            return idx, profile, None
        except Exception as e:
            logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
            return idx, build_fallback(idx, entity, entity_type), str(e)

    logger.info(f"开始并行生成 {total} 个Agent人设并行数: {worker_count}...")
    print(f"\n{'='*60}")
    print(f"Starting agent persona generation - {total} entities, parallel count: {worker_count}")
    print(f"{'='*60}\n")

    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        future_to_entity = {
            executor.submit(generate_single_profile, idx, entity): (idx, entity)
            for idx, entity in enumerate(entities)
        }

        # Results are collected on the calling thread only, so the shared
        # list and counter need no lock.
        for future in concurrent.futures.as_completed(future_to_entity):
            idx, entity = future_to_entity[future]
            entity_type = entity.get_entity_type() or "Entity"

            # Only the worker result is guarded here, so a failure in the
            # bookkeeping below can never replace a good profile.
            try:
                _, profile, error = future.result()
            except Exception as e:
                logger.error(f"处理实体 {entity.name} 时发生异常: {str(e)}")
                profile, error = build_fallback(idx, entity, entity_type), str(e)

            profiles[idx] = profile
            completed_count += 1
            current = completed_count

            # Write to the realtime file (fallback profiles included)
            save_profiles_realtime()

            if progress_callback:
                try:
                    progress_callback(
                        current,
                        total,
                        f"Completed {current}/{total}: {entity.name} ({entity_type})"
                    )
                except Exception as e:
                    logger.warning(f"progress_callback failed for {entity.name}: {e}")

            if error:
                logger.warning(f"[{current}/{total}] {entity.name} 使用备用人设: {error}")
            else:
                logger.info(f"[{current}/{total}] 成功生成人设: {entity.name} ({entity_type})")

    print(f"\n{'='*60}")
    print(f"Persona generation complete! Generated {sum(1 for p in profiles if p)} agents")
    print(f"{'='*60}\n")
    return profiles
def _print_generated_profile(self, entity_name: str, entity_type: str, profile: OasisAgentProfile):
"""实时输出生成的人设到控制台(完整内容,不截断)"""
separator = "-" * 70
# 构建完整输出内容(不截断)
topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else 'None'
output_lines = [
f"\n{separator}",
f"[Generated] {entity_name} ({entity_type})",
f"{separator}",
f"Username: {profile.user_name}",
f"",
f"[Bio]",
f"{profile.bio}",
f"",
f"[Detailed Persona]",
f"{profile.persona}",
f"",
f"[Basic Attributes]",
f"Age: {profile.age} | Gender: {profile.gender} | MBTI: {profile.mbti}",
f"Profession: {profile.profession} | Country: {profile.country}",
f"Interests: {topics_str}",
separator
]
output = "\n".join(output_lines)
# 只输出到控制台避免重复logger不再输出完整内容
print(output)
def save_profiles(
    self,
    profiles: List[OasisAgentProfile],
    file_path: str,
    platform: str = "reddit"
):
    """
    Save profiles to a file in the format that matches the platform.

    - ``"twitter"``: CSV format
    - anything else (default ``"reddit"``): JSON format

    Args:
        profiles: Profile list.
        file_path: Destination file path.
        platform: Platform type (``"reddit"`` or ``"twitter"``).
    """
    writer = self._save_twitter_csv if platform == "twitter" else self._save_reddit_json
    writer(profiles, file_path)
def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str):
    """
    Save Twitter profiles as CSV.

    Columns written:
    - user_id: row index, starting at 0
    - name: the profile's name
    - username: the profile's user name
    - user_char: bio plus persona joined into one line (the detailed
      persona text)
    - description: the bio on one line (the short public description)

    A trailing ``.json`` extension on ``file_path`` is replaced with
    ``.csv``; any other path is used as given.

    Args:
        profiles: Profiles to write, in row order.
        file_path: Destination path.
    """
    import csv

    # Swap only the trailing suffix: str.replace would also rewrite a
    # ".json" that appears earlier in the path, e.g. in a directory name.
    json_suffix = '.json'
    if file_path.endswith(json_suffix):
        file_path = file_path[:-len(json_suffix)] + '.csv'

    def single_line(text: str) -> str:
        """Replace line breaks with spaces so each value stays on one CSV line."""
        return text.replace('\n', ' ').replace('\r', ' ')

    with open(file_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['user_id', 'name', 'username', 'user_char', 'description'])

        for idx, profile in enumerate(profiles):
            # A missing bio is written as an empty string instead of crashing
            bio = profile.bio or ""
            # user_char: full persona (bio + persona)
            user_char = bio
            if profile.persona and profile.persona != bio:
                user_char = f"{bio} {profile.persona}"
            writer.writerow([
                idx,                     # user_id: sequential id starting at 0
                profile.name,            # name
                profile.user_name,       # username
                single_line(user_char),  # user_char: full persona
                single_line(bio),        # description: short bio
            ])

    logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)")
def _normalize_gender(self, gender: Optional[str]) -> str:
"""
标准化gender字段为OASIS要求的英文格式
OASIS要求: male, female, other
"""
if not gender:
return "other"
gender_lower = gender.lower().strip()
# 中文映射
gender_map = {
"": "male",
"": "female",
"机构": "other",
"其他": "other",
# 英文已有
"male": "male",
"female": "female",
"other": "other",
}
return gender_map.get(gender_lower, "other")
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
    """
    Save Reddit profiles as a JSON array.

    Every item always carries: user_id, username, name, bio (at most 150
    characters), persona, karma, created_at, age, gender (normalised to
    male / female / other), mbti and country, with defaults filled in for
    missing values. profession and interested_topics are added only when
    present.

    Args:
        profiles: Profiles to write.
        file_path: Destination JSON path.
    """
    def as_reddit_item(idx: int, profile: OasisAgentProfile) -> Dict[str, Any]:
        """Convert one profile into its JSON item."""
        item = {
            # user_id must always be present; fall back to the list position
            "user_id": idx if profile.user_id is None else profile.user_id,
            "username": profile.user_name,
            "name": profile.name,
            "bio": profile.bio[:150] if profile.bio else f"{profile.name}",
            "persona": profile.persona or f"{profile.name} is a participant in social discussions.",
            "karma": profile.karma or 1000,
            "created_at": profile.created_at,
            # These fields always get a value
            "age": profile.age or 30,
            "gender": self._normalize_gender(profile.gender),
            "mbti": profile.mbti or "ISTJ",
            # NOTE(review): this default is Chinese while other defaults in
            # this file use "China" - confirm which one consumers expect
            "country": profile.country or "中国",
        }
        # Optional fields
        if profile.profession:
            item["profession"] = profile.profession
        if profile.interested_topics:
            item["interested_topics"] = profile.interested_topics
        return item

    data = [as_reddit_item(idx, profile) for idx, profile in enumerate(profiles)]

    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON格式包含user_id字段)")
# Old method name kept as an alias for backward compatibility
def save_profiles_to_json(
    self,
    profiles: List[OasisAgentProfile],
    file_path: str,
    platform: str = "reddit"
):
    """[Deprecated] Use save_profiles() instead; this logs a warning and delegates to it."""
    logger.warning("save_profiles_to_json已废弃请使用save_profiles方法")
    self.save_profiles(profiles, file_path, platform=platform)