""" OASIS Agent Profile生成器 将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 优化改进: 1. 调用Zep检索功能二次丰富节点信息 2. 优化提示词生成非常详细的人设 3. 区分个人实体和抽象群体实体 """ import json import random import time from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime from openai import OpenAI from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader logger = get_logger('mirofish.oasis_profile') @dataclass class OasisAgentProfile: """OASIS Agent Profile数据结构""" # 通用字段 user_id: int user_name: str name: str bio: str persona: str # 可选字段 - Reddit风格 karma: int = 1000 # 可选字段 - Twitter风格 friend_count: int = 100 follower_count: int = 150 statuses_count: int = 500 # 额外人设信息 age: Optional[int] = None gender: Optional[str] = None mbti: Optional[str] = None country: Optional[str] = None profession: Optional[str] = None interested_topics: List[str] = field(default_factory=list) # 来源实体信息 source_entity_uuid: Optional[str] = None source_entity_type: Optional[str] = None created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) def to_reddit_format(self) -> Dict[str, Any]: """转换为Reddit平台格式""" profile = { "user_id": self.user_id, "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) "name": self.name, "bio": self.bio, "persona": self.persona, "karma": self.karma, "created_at": self.created_at, } # 添加额外人设信息(如果有) if self.age: profile["age"] = self.age if self.gender: profile["gender"] = self.gender if self.mbti: profile["mbti"] = self.mbti if self.country: profile["country"] = self.country if self.profession: profile["profession"] = self.profession if self.interested_topics: profile["interested_topics"] = self.interested_topics return profile def to_twitter_format(self) -> Dict[str, Any]: """转换为Twitter平台格式""" profile = { "user_id": self.user_id, "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) "name": self.name, "bio": self.bio, "persona": self.persona, "friend_count": self.friend_count, "follower_count": self.follower_count, "statuses_count": self.statuses_count, "created_at": self.created_at, } # 添加额外人设信息 if self.age: profile["age"] = self.age if self.gender: profile["gender"] = self.gender if self.mbti: profile["mbti"] = self.mbti if self.country: profile["country"] = self.country if self.profession: profile["profession"] = self.profession if self.interested_topics: profile["interested_topics"] = self.interested_topics return profile def to_dict(self) -> Dict[str, Any]: """转换为完整字典格式""" return { "user_id": self.user_id, "user_name": self.user_name, "name": self.name, "bio": self.bio, "persona": self.persona, "karma": self.karma, "friend_count": self.friend_count, "follower_count": self.follower_count, "statuses_count": self.statuses_count, "age": self.age, "gender": self.gender, "mbti": self.mbti, "country": self.country, "profession": self.profession, "interested_topics": self.interested_topics, "source_entity_uuid": self.source_entity_uuid, "source_entity_type": self.source_entity_type, "created_at": self.created_at, } class OasisProfileGenerator: """ OASIS Profile生成器 将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile 优化特性: 1. 调用Zep图谱检索功能获取更丰富的上下文 2. 生成非常详细的人设(包括基本信息、职业经历、性格特征、社交媒体行为等) 3. 区分个人实体和抽象群体实体 """ # MBTI类型列表 MBTI_TYPES = [ "INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", "ISTJ", "ISFJ", "ESTJ", "ESFJ", "ISTP", "ISFP", "ESTP", "ESFP" ] # 常见国家列表 COUNTRIES = [ "China", "US", "UK", "Japan", "Germany", "France", "Canada", "Australia", "Brazil", "India", "South Korea" ] # 个人类型实体(需要生成具体人设) INDIVIDUAL_ENTITY_TYPES = [ "student", "alumni", "professor", "person", "publicfigure", "expert", "faculty", "official", "journalist", "activist" ] # 群体/机构类型实体(需要生成群体代表人设) GROUP_ENTITY_TYPES = [ "university", "governmentagency", "organization", "ngo", "mediaoutlet", "company", "institution", "group", "community" ] def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model_name: Optional[str] = None, zep_api_key: Optional[str] = None, graph_id: Optional[str] = None ): self.api_key = api_key or Config.LLM_API_KEY self.base_url = base_url or Config.LLM_BASE_URL self.model_name = model_name or Config.LLM_MODEL_NAME if not self.api_key: raise ValueError("LLM_API_KEY is not configured") self.client = OpenAI( api_key=self.api_key, base_url=self.base_url ) # Zep客户端用于检索丰富上下文 self.zep_api_key = zep_api_key or Config.ZEP_API_KEY self.zep_client = None self.graph_id = graph_id if self.zep_api_key: try: self.zep_client = Zep(api_key=self.zep_api_key) except Exception as e: logger.warning(f"Zep客户端初始化失败: {e}") def generate_profile_from_entity( self, entity: EntityNode, user_id: int, use_llm: bool = True ) -> OasisAgentProfile: """ 从Zep实体生成OASIS Agent Profile Args: entity: Zep实体节点 user_id: 用户ID(用于OASIS) use_llm: 是否使用LLM生成详细人设 Returns: OasisAgentProfile """ entity_type = entity.get_entity_type() or "Entity" # 基础信息 name = entity.name user_name = self._generate_username(name) # 构建上下文信息 context = self._build_entity_context(entity) if use_llm: # 使用LLM生成详细人设 profile_data = self._generate_profile_with_llm( entity_name=name, entity_type=entity_type, entity_summary=entity.summary, entity_attributes=entity.attributes, context=context ) else: # 使用规则生成基础人设 profile_data = self._generate_profile_rule_based( entity_name=name, entity_type=entity_type, entity_summary=entity.summary, entity_attributes=entity.attributes ) return OasisAgentProfile( user_id=user_id, user_name=user_name, name=name, bio=profile_data.get("bio", f"{entity_type}: {name}"), persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."), karma=profile_data.get("karma", random.randint(500, 5000)), friend_count=profile_data.get("friend_count", random.randint(50, 500)), follower_count=profile_data.get("follower_count", random.randint(100, 1000)), statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)), age=profile_data.get("age"), gender=profile_data.get("gender"), mbti=profile_data.get("mbti"), country=profile_data.get("country"), profession=profile_data.get("profession"), interested_topics=profile_data.get("interested_topics", []), source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) def _generate_username(self, name: str) -> str: """生成用户名""" # 移除特殊字符,转换为小写 username = name.lower().replace(" ", "_") username = ''.join(c for c in username if c.isalnum() or c == '_') # 添加随机后缀避免重复 suffix = random.randint(100, 999) return f"{username}_{suffix}" def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: """ 使用Zep图谱混合搜索功能获取实体相关的丰富信息 Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。 使用并行请求同时搜索,提高效率。 Args: entity: 实体节点对象 Returns: 包含facts, node_summaries, context的字典 """ import concurrent.futures if not self.zep_client: return {"facts": [], "node_summaries": [], "context": ""} entity_name = entity.name results = { "facts": [], "node_summaries": [], "context": "" } # 必须有graph_id才能进行搜索 if not self.graph_id: logger.debug(f"跳过Zep检索:未设置graph_id") return results comprehensive_query = f"All information, activities, events, relationships, and background about {entity_name}" def search_edges(): """搜索边(事实/关系)- 带重试机制""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: return self.zep_client.graph.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, scope="edges", reranker="rrf" ) except Exception as e: last_exception = e if attempt < max_retries - 1: logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") time.sleep(delay) delay *= 2 else: logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}") return None def search_nodes(): """搜索节点(实体摘要)- 带重试机制""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: return self.zep_client.graph.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, scope="nodes", reranker="rrf" ) except Exception as e: last_exception = e if attempt < max_retries - 1: logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") time.sleep(delay) delay *= 2 else: logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}") return None try: # 并行执行edges和nodes搜索 with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: edge_future = executor.submit(search_edges) node_future = executor.submit(search_nodes) # 获取结果 edge_result = edge_future.result(timeout=30) node_result = node_future.result(timeout=30) # 处理边搜索结果 all_facts = set() if edge_result and hasattr(edge_result, 'edges') and edge_result.edges: for edge in edge_result.edges: if hasattr(edge, 'fact') and edge.fact: all_facts.add(edge.fact) results["facts"] = list(all_facts) # 处理节点搜索结果 all_summaries = set() if node_result and hasattr(node_result, 'nodes') and node_result.nodes: for node in node_result.nodes: if hasattr(node, 'summary') and node.summary: all_summaries.add(node.summary) if hasattr(node, 'name') and node.name and node.name != entity_name: all_summaries.add(f"Related entity: {node.name}") results["node_summaries"] = list(all_summaries) # 构建综合上下文 context_parts = [] if results["facts"]: context_parts.append("Factual information:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) if results["node_summaries"]: context_parts.append("Related entities:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) results["context"] = "\n\n".join(context_parts) logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点") except concurrent.futures.TimeoutError: logger.warning(f"Zep检索超时 ({entity_name})") except Exception as e: logger.warning(f"Zep检索失败 ({entity_name}): {e}") return results def _build_entity_context(self, entity: EntityNode) -> str: """ 构建实体的完整上下文信息 包括: 1. 实体本身的边信息(事实) 2. 关联节点的详细信息 3. Zep混合检索到的丰富信息 """ context_parts = [] # 1. 添加实体属性信息 if entity.attributes: attrs = [] for key, value in entity.attributes.items(): if value and str(value).strip(): attrs.append(f"- {key}: {value}") if attrs: context_parts.append("### Entity Attributes\n" + "\n".join(attrs)) # 2. 添加相关边信息(事实/关系) existing_facts = set() if entity.related_edges: relationships = [] for edge in entity.related_edges: # 不限制数量 fact = edge.get("fact", "") edge_name = edge.get("edge_name", "") direction = edge.get("direction", "") if fact: relationships.append(f"- {fact}") existing_facts.add(fact) elif edge_name: if direction == "outgoing": relationships.append(f"- {entity.name} --[{edge_name}]--> (related entity)") else: relationships.append(f"- (related entity) --[{edge_name}]--> {entity.name}") if relationships: context_parts.append("### Related Facts and Relationships\n" + "\n".join(relationships)) # 3. 添加关联节点的详细信息 if entity.related_nodes: related_info = [] for node in entity.related_nodes: # 不限制数量 node_name = node.get("name", "") node_labels = node.get("labels", []) node_summary = node.get("summary", "") # 过滤掉默认标签 custom_labels = [l for l in node_labels if l not in ["Entity", "Node"]] label_str = f" ({', '.join(custom_labels)})" if custom_labels else "" if node_summary: related_info.append(f"- **{node_name}**{label_str}: {node_summary}") else: related_info.append(f"- **{node_name}**{label_str}") if related_info: context_parts.append("### Related Entity Information\n" + "\n".join(related_info)) # 4. 使用Zep混合检索获取更丰富的信息 zep_results = self._search_zep_for_entity(entity) if zep_results.get("facts"): # 去重:排除已存在的事实 new_facts = [f for f in zep_results["facts"] if f not in existing_facts] if new_facts: context_parts.append("### Facts Retrieved from Zep\n" + "\n".join(f"- {f}" for f in new_facts[:15])) if zep_results.get("node_summaries"): context_parts.append("### Related Nodes Retrieved from Zep\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) return "\n\n".join(context_parts) def _is_individual_entity(self, entity_type: str) -> bool: """判断是否是个人类型实体""" return entity_type.lower() in self.INDIVIDUAL_ENTITY_TYPES def _is_group_entity(self, entity_type: str) -> bool: """判断是否是群体/机构类型实体""" return entity_type.lower() in self.GROUP_ENTITY_TYPES def _generate_profile_with_llm( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> Dict[str, Any]: """ 使用LLM生成非常详细的人设 根据实体类型区分: - 个人实体:生成具体的人物设定 - 群体/机构实体:生成代表性账号设定 """ is_individual = self._is_individual_entity(entity_type) if is_individual: prompt = self._build_individual_persona_prompt( entity_name, entity_type, entity_summary, entity_attributes, context ) else: prompt = self._build_group_persona_prompt( entity_name, entity_type, entity_summary, entity_attributes, context ) # 尝试多次生成,直到成功或达到最大重试次数 max_attempts = 3 last_error = None for attempt in range(max_attempts): try: response = self.client.chat.completions.create( model=self.model_name, messages=[ {"role": "system", "content": self._get_system_prompt(is_individual)}, {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 # 不设置max_tokens,让LLM自由发挥 ) content = response.choices[0].message.content # 检查是否被截断(finish_reason不是'stop') finish_reason = response.choices[0].finish_reason if finish_reason == 'length': logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") content = self._fix_truncated_json(content) # 尝试解析JSON try: result = json.loads(content) # 验证必需字段 if "bio" not in result or not result["bio"]: result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" if "persona" not in result or not result["persona"]: result["persona"] = entity_summary or f"{entity_name} is a {entity_type}." return result except json.JSONDecodeError as je: logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}") # 尝试修复JSON result = self._try_fix_json(content, entity_name, entity_type, entity_summary) if result.get("_fixed"): del result["_fixed"] return result last_error = je except Exception as e: logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") last_error = e import time time.sleep(1 * (attempt + 1)) # 指数退避 logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成") return self._generate_profile_rule_based( entity_name, entity_type, entity_summary, entity_attributes ) def _fix_truncated_json(self, content: str) -> str: """修复被截断的JSON(输出被max_tokens限制截断)""" import re # 如果JSON被截断,尝试闭合它 content = content.strip() # 计算未闭合的括号 open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') # 检查是否有未闭合的字符串 # 简单检查:如果最后一个引号后没有逗号或闭合括号,可能是字符串被截断 if content and content[-1] not in '",}]': # 尝试闭合字符串 content += '"' # 闭合括号 content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_json(self, content: str, entity_name: str, entity_type: str, entity_summary: str = "") -> Dict[str, Any]: """尝试修复损坏的JSON""" import re # 1. 首先尝试修复被截断的情况 content = self._fix_truncated_json(content) # 2. 尝试提取JSON部分 json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() # 3. 处理字符串中的换行符问题 # 找到所有字符串值并替换其中的换行符 def fix_string_newlines(match): s = match.group(0) # 替换字符串内的实际换行符为空格 s = s.replace('\n', ' ').replace('\r', ' ') # 替换多余空格 s = re.sub(r'\s+', ' ', s) return s # 匹配JSON字符串值 json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) # 4. 尝试解析 try: result = json.loads(json_str) result["_fixed"] = True return result except json.JSONDecodeError as e: # 5. 如果还是失败,尝试更激进的修复 try: # 移除所有控制字符 json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) # 替换所有连续空白 json_str = re.sub(r'\s+', ' ', json_str) result = json.loads(json_str) result["_fixed"] = True return result except: pass # 6. 尝试从内容中提取部分信息 bio_match = re.search(r'"bio"\s*:\s*"([^"]*)"', content) persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # 可能被截断 bio = bio_match.group(1) if bio_match else (entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}") persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name}是一个{entity_type}。") # 如果提取到了有意义的内容,标记为已修复 if bio_match or persona_match: logger.info(f"从损坏的JSON中提取了部分信息") return { "bio": bio, "persona": persona, "_fixed": True } # 7. 完全失败,返回基础结构 logger.warning(f"JSON修复失败,返回基础结构") return { "bio": entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name}是一个{entity_type}。" } def _get_system_prompt(self, is_individual: bool) -> str: """获取系统提示词""" base_prompt = "You are a social media user profile generation expert. Generate detailed, realistic personas for public opinion simulation, maximally faithful to existing real-world information. You must return valid JSON format, all string values must not contain unescaped newlines." return base_prompt def _build_individual_persona_prompt( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> str: """构建个人实体的详细人设提示词""" attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" context_str = context[:3000] if context else "No additional context" return f"""Generate a detailed social media user persona for this entity, maximally faithful to existing real-world information. Entity Name: {entity_name} Entity Type: {entity_type} Entity Summary: {entity_summary} Entity Attributes: {attrs_str} Context Information: {context_str} Generate JSON with the following fields: 1. bio: Social media biography, 200 words 2. persona: Detailed persona description (2000 words of plain text), must include: - Basic information (age, profession, educational background, location) - Background (important experiences, connection to events, social relationships) - Personality traits (MBTI type, core personality, emotional expression style) - Social media behavior (posting frequency, content preferences, interaction style, language characteristics) - Stances and opinions (attitudes toward topics, content that may provoke or move them) - Unique traits (catchphrases, special experiences, personal hobbies) - Personal memories (important part of the persona, describe this individual's connection to events, and their existing actions and reactions in those events) 3. age: Age as a number (must be an integer) 4. gender: Must be in English: "male" or "female" 5. mbti: MBTI type (e.g., INTJ, ENFP) 6. country: Country name 7. profession: Profession 8. interested_topics: Array of topics of interest Important: - All field values must be strings or numbers, do not use newline characters - persona must be a coherent text description - Write in English (gender field must be "male" or "female") - Content must be consistent with entity information - age must be a valid integer, gender must be "male" or "female" """ def _build_group_persona_prompt( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> str: """构建群体/机构实体的详细人设提示词""" attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" context_str = context[:3000] if context else "No additional context" return f"""Generate a detailed social media account profile for this institution/group entity, maximally faithful to existing real-world information. Entity Name: {entity_name} Entity Type: {entity_type} Entity Summary: {entity_summary} Entity Attributes: {attrs_str} Context Information: {context_str} Generate JSON with the following fields: 1. bio: Official account biography, 200 words, professional and appropriate 2. persona: Detailed account profile description (2000 words of plain text), must include: - Institutional basic information (official name, nature, founding background, primary functions) - Account positioning (account type, target audience, core functions) - Communication style (language characteristics, common expressions, taboo topics) - Content characteristics (content types, posting frequency, active time periods) - Stances and attitudes (official positions on core topics, approach to handling controversies) - Special notes (group portrait represented, operational habits) - Institutional memory (important part of the persona, describe this institution's connection to events, and its existing actions and reactions in those events) 3. age: Fixed value 30 (virtual age for institutional accounts) 4. gender: Fixed value "other" (institutional accounts use "other" to indicate non-personal) 5. mbti: MBTI type to describe account style, e.g., ISTJ for rigorous and conservative 6. country: Country name 7. profession: Institutional function description 8. interested_topics: Array of focus areas Important: - All field values must be strings or numbers, null values are not allowed - persona must be a coherent text description, do not use newline characters - Write in English (gender field must be "other") - age must be integer 30, gender must be string "other" - Institutional account communication must align with its identity positioning""" def _generate_profile_rule_based( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any] ) -> Dict[str, Any]: """使用规则生成基础人设""" # 根据实体类型生成不同的人设 entity_type_lower = entity_type.lower() if entity_type_lower in ["student", "alumni"]: return { "bio": f"{entity_type} with interests in academics and social issues.", "persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.", "age": random.randint(18, 30), "gender": random.choice(["male", "female"]), "mbti": random.choice(self.MBTI_TYPES), "country": random.choice(self.COUNTRIES), "profession": "Student", "interested_topics": ["Education", "Social Issues", "Technology"], } elif entity_type_lower in ["publicfigure", "expert", "faculty"]: return { "bio": f"Expert and thought leader in their field.", "persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.", "age": random.randint(35, 60), "gender": random.choice(["male", "female"]), "mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]), "country": random.choice(self.COUNTRIES), "profession": entity_attributes.get("occupation", "Expert"), "interested_topics": ["Politics", "Economics", "Culture & Society"], } elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]: return { "bio": f"Official account for {entity_name}. News and updates.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.", "age": 30, # 机构虚拟年龄 "gender": "other", # 机构使用other "mbti": "ISTJ", # 机构风格:严谨保守 "country": "China", "profession": "Media", "interested_topics": ["General News", "Current Events", "Public Affairs"], } elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]: return { "bio": f"Official account of {entity_name}.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", "age": 30, # 机构虚拟年龄 "gender": "other", # 机构使用other "mbti": "ISTJ", # 机构风格:严谨保守 "country": "China", "profession": entity_type, "interested_topics": ["Public Policy", "Community", "Official Announcements"], } else: # 默认人设 return { "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.", "age": random.randint(25, 50), "gender": random.choice(["male", "female"]), "mbti": random.choice(self.MBTI_TYPES), "country": random.choice(self.COUNTRIES), "profession": entity_type, "interested_topics": ["General", "Social Issues"], } def set_graph_id(self, graph_id: str): """设置图谱ID用于Zep检索""" self.graph_id = graph_id def generate_profiles_from_entities( self, entities: List[EntityNode], use_llm: bool = True, progress_callback: Optional[callable] = None, graph_id: Optional[str] = None, parallel_count: int = 5, realtime_output_path: Optional[str] = None, output_platform: str = "reddit" ) -> List[OasisAgentProfile]: """ 批量从实体生成Agent Profile(支持并行生成) Args: entities: 实体列表 use_llm: 是否使用LLM生成详细人设 progress_callback: 进度回调函数 (current, total, message) graph_id: 图谱ID,用于Zep检索获取更丰富上下文 parallel_count: 并行生成数量,默认5 realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次) output_platform: 输出平台格式 ("reddit" 或 "twitter") Returns: Agent Profile列表 """ import concurrent.futures from threading import Lock # 设置graph_id用于Zep检索 if graph_id: self.graph_id = graph_id total = len(entities) profiles = [None] * total # 预分配列表保持顺序 completed_count = [0] # 使用列表以便在闭包中修改 lock = Lock() # 实时写入文件的辅助函数 def save_profiles_realtime(): """实时保存已生成的 profiles 到文件""" if not realtime_output_path: return with lock: # 过滤出已生成的 profiles existing_profiles = [p for p in profiles if p is not None] if not existing_profiles: return try: if output_platform == "reddit": # Reddit JSON 格式 profiles_data = [p.to_reddit_format() for p in existing_profiles] with open(realtime_output_path, 'w', encoding='utf-8') as f: json.dump(profiles_data, f, ensure_ascii=False, indent=2) else: # Twitter CSV 格式 import csv profiles_data = [p.to_twitter_format() for p in existing_profiles] if profiles_data: fieldnames = list(profiles_data[0].keys()) with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(profiles_data) except Exception as e: logger.warning(f"实时保存 profiles 失败: {e}") def generate_single_profile(idx: int, entity: EntityNode) -> tuple: """生成单个profile的工作函数""" entity_type = entity.get_entity_type() or "Entity" try: profile = self.generate_profile_from_entity( entity=entity, user_id=idx, use_llm=use_llm ) # 实时输出生成的人设到控制台和日志 self._print_generated_profile(entity.name, entity_type, profile) return idx, profile, None except Exception as e: logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}") # 创建一个基础profile fallback_profile = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), name=entity.name, bio=f"{entity_type}: {entity.name}", persona=entity.summary or f"A participant in social discussions.", source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) return idx, fallback_profile, str(e) logger.info(f"开始并行生成 {total} 个Agent人设(并行数: {parallel_count})...") print(f"\n{'='*60}") print(f"Starting agent persona generation - {total} entities, parallel count: {parallel_count}") print(f"{'='*60}\n") # 使用线程池并行执行 with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_count) as executor: # 提交所有任务 future_to_entity = { executor.submit(generate_single_profile, idx, entity): (idx, entity) for idx, entity in enumerate(entities) } # 收集结果 for future in concurrent.futures.as_completed(future_to_entity): idx, entity = future_to_entity[future] entity_type = entity.get_entity_type() or "Entity" try: result_idx, profile, error = future.result() profiles[result_idx] = profile with lock: completed_count[0] += 1 current = completed_count[0] # 实时写入文件 save_profiles_realtime() if progress_callback: progress_callback( current, total, f"Completed {current}/{total}: {entity.name} ({entity_type})" ) if error: logger.warning(f"[{current}/{total}] {entity.name} 使用备用人设: {error}") else: logger.info(f"[{current}/{total}] 成功生成人设: {entity.name} ({entity_type})") except Exception as e: logger.error(f"处理实体 {entity.name} 时发生异常: {str(e)}") with lock: completed_count[0] += 1 profiles[idx] = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), name=entity.name, bio=f"{entity_type}: {entity.name}", persona=entity.summary or "A participant in social discussions.", source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) # 实时写入文件(即使是备用人设) save_profiles_realtime() print(f"\n{'='*60}") print(f"Persona generation complete! Generated {len([p for p in profiles if p])} agents") print(f"{'='*60}\n") return profiles def _print_generated_profile(self, entity_name: str, entity_type: str, profile: OasisAgentProfile): """实时输出生成的人设到控制台(完整内容,不截断)""" separator = "-" * 70 # 构建完整输出内容(不截断) topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else 'None' output_lines = [ f"\n{separator}", f"[Generated] {entity_name} ({entity_type})", f"{separator}", f"Username: {profile.user_name}", f"", f"[Bio]", f"{profile.bio}", f"", f"[Detailed Persona]", f"{profile.persona}", f"", f"[Basic Attributes]", f"Age: {profile.age} | Gender: {profile.gender} | MBTI: {profile.mbti}", f"Profession: {profile.profession} | Country: {profile.country}", f"Interests: {topics_str}", separator ] output = "\n".join(output_lines) # 只输出到控制台(避免重复,logger不再输出完整内容) print(output) def save_profiles( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): """ 保存Profile到文件(根据平台选择正确格式) OASIS平台格式要求: - Twitter: CSV格式 - Reddit: JSON格式 Args: profiles: Profile列表 file_path: 文件路径 platform: 平台类型 ("reddit" 或 "twitter") """ if platform == "twitter": self._save_twitter_csv(profiles, file_path) else: self._save_reddit_json(profiles, file_path) def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str): """ 保存Twitter Profile为CSV格式(符合OASIS官方要求) OASIS Twitter要求的CSV字段: - user_id: 用户ID(根据CSV顺序从0开始) - name: 用户真实姓名 - username: 系统中的用户名 - user_char: 详细人设描述(注入到LLM系统提示中,指导Agent行为) - description: 简短的公开简介(显示在用户资料页面) user_char vs description 区别: - user_char: 内部使用,LLM系统提示,决定Agent如何思考和行动 - description: 外部显示,其他用户可见的简介 """ import csv # 确保文件扩展名是.csv if not file_path.endswith('.csv'): file_path = file_path.replace('.json', '.csv') with open(file_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) # 写入OASIS要求的表头 headers = ['user_id', 'name', 'username', 'user_char', 'description'] writer.writerow(headers) # 写入数据行 for idx, profile in enumerate(profiles): # user_char: 完整人设(bio + persona),用于LLM系统提示 user_char = profile.bio if profile.persona and profile.persona != profile.bio: user_char = f"{profile.bio} {profile.persona}" # 处理换行符(CSV中用空格替代) user_char = user_char.replace('\n', ' ').replace('\r', ' ') # description: 简短简介,用于外部显示 description = profile.bio.replace('\n', ' ').replace('\r', ' ') row = [ idx, # user_id: 从0开始的顺序ID profile.name, # name: 真实姓名 profile.user_name, # username: 用户名 user_char, # user_char: 完整人设(内部LLM使用) description # description: 简短简介(外部显示) ] writer.writerow(row) logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)") def _normalize_gender(self, gender: Optional[str]) -> str: """ 标准化gender字段为OASIS要求的英文格式 OASIS要求: male, female, other """ if not gender: return "other" gender_lower = gender.lower().strip() # 中文映射 gender_map = { "男": "male", "女": "female", "机构": "other", "其他": "other", # 英文已有 "male": "male", "female": "female", "other": "other", } return gender_map.get(gender_lower, "other") def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): """ 保存Reddit Profile为JSON格式 使用与 to_reddit_format() 一致的格式,确保 OASIS 能正确读取。 必须包含 user_id 字段,这是 OASIS agent_graph.get_agent() 匹配的关键! 必需字段: - user_id: 用户ID(整数,用于匹配 initial_posts 中的 poster_agent_id) - username: 用户名 - name: 显示名称 - bio: 简介 - persona: 详细人设 - age: 年龄(整数) - gender: "male", "female", 或 "other" - mbti: MBTI类型 - country: 国家 """ data = [] for idx, profile in enumerate(profiles): # 使用与 to_reddit_format() 一致的格式 item = { "user_id": profile.user_id if profile.user_id is not None else idx, # 关键:必须包含 user_id "username": profile.user_name, "name": profile.name, "bio": profile.bio[:150] if profile.bio else f"{profile.name}", "persona": profile.persona or f"{profile.name} is a participant in social discussions.", "karma": profile.karma if profile.karma else 1000, "created_at": profile.created_at, # OASIS必需字段 - 确保都有默认值 "age": profile.age if profile.age else 30, "gender": self._normalize_gender(profile.gender), "mbti": profile.mbti if profile.mbti else "ISTJ", "country": profile.country if profile.country else "中国", } # 可选字段 if profile.profession: item["profession"] = profile.profession if profile.interested_topics: item["interested_topics"] = profile.interested_topics data.append(item) with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON格式,包含user_id字段)") # 保留旧方法名作为别名,保持向后兼容 def save_profiles_to_json( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): """[已废弃] 请使用 save_profiles() 方法""" logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") self.save_profiles(profiles, file_path, platform)