""" Zep实体读取与过滤服务 从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 """ import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_entity_reader') # 用于泛型返回类型 T = TypeVar('T') @dataclass class EntityNode: """实体节点数据结构""" uuid: str name: str labels: List[str] summary: str attributes: Dict[str, Any] # 相关的边信息 related_edges: List[Dict[str, Any]] = field(default_factory=list) # 相关的其他节点信息 related_nodes: List[Dict[str, Any]] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, "name": self.name, "labels": self.labels, "summary": self.summary, "attributes": self.attributes, "related_edges": self.related_edges, "related_nodes": self.related_nodes, } def get_entity_type(self) -> Optional[str]: """获取实体类型(排除默认的Entity标签)""" for label in self.labels: if label not in ["Entity", "Node"]: return label return None @dataclass class FilteredEntities: """过滤后的实体集合""" entities: List[EntityNode] entity_types: Set[str] total_count: int filtered_count: int def to_dict(self) -> Dict[str, Any]: return { "entities": [e.to_dict() for e in self.entities], "entity_types": list(self.entity_types), "total_count": self.total_count, "filtered_count": self.filtered_count, } class ZepEntityReader: """ Zep实体读取与过滤服务 主要功能: 1. 从Zep图谱读取所有节点 2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点) 3. 获取每个实体的相关边和关联节点信息 """ def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY if not self.api_key: raise ValueError("ZEP_API_KEY is not configured") self.client = Zep(api_key=self.api_key) def _call_with_retry( self, func: Callable[[], T], operation_name: str, max_retries: int = 3, initial_delay: float = 2.0 ) -> T: """ 带重试机制的Zep API调用 Args: func: 要执行的函数(无参数的lambda或callable) operation_name: 操作名称,用于日志 max_retries: 最大重试次数(默认3次,即最多尝试3次) initial_delay: 初始延迟秒数 Returns: API调用结果 """ last_exception = None delay = initial_delay for attempt in range(max_retries): try: return func() except Exception as e: last_exception = e if attempt < max_retries - 1: logger.warning( f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, " f"{delay:.1f}秒后重试..." ) time.sleep(delay) delay *= 2 # 指数退避 else: logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}") raise last_exception def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ 获取图谱的所有节点(分页获取) Args: graph_id: 图谱ID Returns: 节点列表 """ logger.info(f"获取图谱 {graph_id} 的所有节点...") nodes = fetch_all_nodes(self.client, graph_id) nodes_data = [] for node in nodes: nodes_data.append({ "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), "name": node.name or "", "labels": node.labels or [], "summary": node.summary or "", "attributes": node.attributes or {}, }) logger.info(f"共获取 {len(nodes_data)} 个节点") return nodes_data def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ 获取图谱的所有边(分页获取) Args: graph_id: 图谱ID Returns: 边列表 """ logger.info(f"获取图谱 {graph_id} 的所有边...") edges = fetch_all_edges(self.client, graph_id) edges_data = [] for edge in edges: edges_data.append({ "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), "name": edge.name or "", "fact": edge.fact or "", "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, "attributes": edge.attributes or {}, }) logger.info(f"共获取 {len(edges_data)} 条边") return edges_data def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: """ 获取指定节点的所有相关边(带重试机制) Args: node_uuid: 节点UUID Returns: 边列表 """ try: # 使用重试机制调用Zep API edges = self._call_with_retry( func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), operation_name=f"获取节点边(node={node_uuid[:8]}...)" ) edges_data = [] for edge in edges: edges_data.append({ "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), "name": edge.name or "", "fact": edge.fact or "", "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, "attributes": edge.attributes or {}, }) return edges_data except Exception as e: logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") return [] def filter_defined_entities( self, graph_id: str, defined_entity_types: Optional[List[str]] = None, enrich_with_edges: bool = True ) -> FilteredEntities: """ 筛选出符合预定义实体类型的节点 筛选逻辑: - 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过 - 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留 Args: graph_id: 图谱ID defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型) enrich_with_edges: 是否获取每个实体的相关边信息 Returns: FilteredEntities: 过滤后的实体集合 """ logger.info(f"开始筛选图谱 {graph_id} 的实体...") # 获取所有节点 all_nodes = self.get_all_nodes(graph_id) total_count = len(all_nodes) # 获取所有边(用于后续关联查找) all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] # 构建节点UUID到节点数据的映射 node_map = {n["uuid"]: n for n in all_nodes} # 筛选符合条件的实体 filtered_entities = [] entity_types_found = set() for node in all_nodes: labels = node.get("labels", []) # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 custom_labels = [l for l in labels if l not in ["Entity", "Node"]] if not custom_labels: # 只有默认标签,跳过 continue # 如果指定了预定义类型,检查是否匹配 if defined_entity_types: matching_labels = [l for l in custom_labels if l in defined_entity_types] if not matching_labels: continue entity_type = matching_labels[0] else: entity_type = custom_labels[0] entity_types_found.add(entity_type) # 创建实体节点对象 entity = EntityNode( uuid=node["uuid"], name=node["name"], labels=labels, summary=node["summary"], attributes=node["attributes"], ) # 获取相关边和节点 if enrich_with_edges: related_edges = [] related_node_uuids = set() for edge in all_edges: if edge["source_node_uuid"] == node["uuid"]: related_edges.append({ "direction": "outgoing", "edge_name": edge["name"], "fact": edge["fact"], "target_node_uuid": edge["target_node_uuid"], }) related_node_uuids.add(edge["target_node_uuid"]) elif edge["target_node_uuid"] == node["uuid"]: related_edges.append({ "direction": "incoming", "edge_name": edge["name"], "fact": edge["fact"], "source_node_uuid": edge["source_node_uuid"], }) related_node_uuids.add(edge["source_node_uuid"]) entity.related_edges = related_edges # 获取关联节点的基本信息 related_nodes = [] for related_uuid in related_node_uuids: if related_uuid in node_map: related_node = node_map[related_uuid] related_nodes.append({ "uuid": related_node["uuid"], "name": related_node["name"], "labels": related_node["labels"], "summary": related_node.get("summary", ""), }) entity.related_nodes = related_nodes filtered_entities.append(entity) logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, " f"实体类型: {entity_types_found}") return FilteredEntities( entities=filtered_entities, entity_types=entity_types_found, total_count=total_count, filtered_count=len(filtered_entities), ) def get_entity_with_context( self, graph_id: str, entity_uuid: str ) -> Optional[EntityNode]: """ 获取单个实体及其完整上下文(边和关联节点,带重试机制) Args: graph_id: 图谱ID entity_uuid: 实体UUID Returns: EntityNode或None """ try: # 使用重试机制获取节点 node = self._call_with_retry( func=lambda: self.client.graph.node.get(uuid_=entity_uuid), operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" ) if not node: return None # 获取节点的边 edges = self.get_node_edges(entity_uuid) # 获取所有节点用于关联查找 all_nodes = self.get_all_nodes(graph_id) node_map = {n["uuid"]: n for n in all_nodes} # 处理相关边和节点 related_edges = [] related_node_uuids = set() for edge in edges: if edge["source_node_uuid"] == entity_uuid: related_edges.append({ "direction": "outgoing", "edge_name": edge["name"], "fact": edge["fact"], "target_node_uuid": edge["target_node_uuid"], }) related_node_uuids.add(edge["target_node_uuid"]) else: related_edges.append({ "direction": "incoming", "edge_name": edge["name"], "fact": edge["fact"], "source_node_uuid": edge["source_node_uuid"], }) related_node_uuids.add(edge["source_node_uuid"]) # 获取关联节点信息 related_nodes = [] for related_uuid in related_node_uuids: if related_uuid in node_map: related_node = node_map[related_uuid] related_nodes.append({ "uuid": related_node["uuid"], "name": related_node["name"], "labels": related_node["labels"], "summary": related_node.get("summary", ""), }) return EntityNode( uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), name=node.name or "", labels=node.labels or [], summary=node.summary or "", attributes=node.attributes or {}, related_edges=related_edges, related_nodes=related_nodes, ) except Exception as e: logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}") return None def get_entities_by_type( self, graph_id: str, entity_type: str, enrich_with_edges: bool = True ) -> List[EntityNode]: """ 获取指定类型的所有实体 Args: graph_id: 图谱ID entity_type: 实体类型(如 "Student", "PublicFigure" 等) enrich_with_edges: 是否获取相关边信息 Returns: 实体列表 """ result = self.filter_defined_entities( graph_id=graph_id, defined_entity_types=[entity_type], enrich_with_edges=enrich_with_edges ) return result.entities