"""
Zep retrieval tool service.

Wraps graph search, node reads and edge queries as tools for the Report Agent.
"""

import time
from collections import Counter
from typing import Dict, Any, Callable, List, Optional, Tuple
from dataclasses import dataclass

from zep_cloud.client import Zep

from ..config import Config
from ..utils.logger import get_logger

logger = get_logger('mirofish.zep_tools')

# Labels that this module treats as carrying no specific entity type; they are
# skipped when determining a node's type or counting the type distribution.
_GENERIC_LABELS = ("Entity", "Node")

# Separators used to split a query into keywords for local search: the ASCII
# comma, the full-width comma and the ideographic enumeration comma (the
# previous code replaced the ASCII comma twice and never split on "，").
_KEYWORD_SEPARATORS = str.maketrans({',': ' ', '，': ' ', '、': ' '})


def _get_uuid(obj: Any) -> str:
    """Return a Zep SDK object's UUID, reading ``uuid_`` first and then ``uuid``."""
    return getattr(obj, 'uuid_', None) or getattr(obj, 'uuid', '')


@dataclass
class SearchResult:
    """Result of a graph search (remote or local fallback)."""
    facts: List[str]
    edges: List[Dict[str, Any]]
    nodes: List[Dict[str, Any]]
    query: str
    total_count: int  # number of entries in ``facts``

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dictionary."""
        return {
            "facts": self.facts,
            "edges": self.edges,
            "nodes": self.nodes,
            "query": self.query,
            "total_count": self.total_count
        }

    def to_text(self) -> str:
        """Render the query, the fact count and a numbered fact list as text for an LLM."""
        text_parts = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"]

        if self.facts:
            text_parts.append("\n### 相关事实:")
            text_parts.extend(f"{i}. {fact}" for i, fact in enumerate(self.facts, 1))

        return "\n".join(text_parts)


@dataclass
class NodeInfo:
    """A graph node (entity) as returned by the Zep SDK, flattened to plain fields."""
    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the node as a plain dictionary."""
        return {
            "uuid": self.uuid,
            "name": self.name,
            "labels": self.labels,
            "summary": self.summary,
            "attributes": self.attributes
        }

    def to_text(self) -> str:
        """Render the node as text; its type is the first non-generic label."""
        entity_type = next(
            (label for label in self.labels if label not in _GENERIC_LABELS),
            "未知类型",
        )
        return f"实体: {self.name} (类型: {entity_type})\n摘要: {self.summary}"


def _to_node_info(node: Any) -> NodeInfo:
    """Convert a Zep SDK node object into a ``NodeInfo``, replacing falsy fields with empty values."""
    return NodeInfo(
        uuid=_get_uuid(node),
        name=node.name or "",
        labels=node.labels or [],
        summary=node.summary or "",
        attributes=node.attributes or {},
    )


@dataclass
class EdgeInfo:
    """A graph edge (relation) between two nodes, identified by their UUIDs."""
    uuid: str
    name: str
    fact: str
    source_node_uuid: str
    target_node_uuid: str
    source_node_name: Optional[str] = None
    target_node_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the edge as a plain dictionary."""
        return {
            "uuid": self.uuid,
            "name": self.name,
            "fact": self.fact,
            "source_node_uuid": self.source_node_uuid,
            "target_node_uuid": self.target_node_uuid,
            "source_node_name": self.source_node_name,
            "target_node_name": self.target_node_name
        }

    def to_text(self) -> str:
        """Render the edge as text, falling back to an 8-char UUID prefix when a node name is missing."""
        source = self.source_node_name or self.source_node_uuid[:8]
        target = self.target_node_name or self.target_node_uuid[:8]
        return f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}"


class ZepToolsService:
    """
    Zep retrieval tool service.

    Provides graph retrieval tools that the Report Agent can call:
    1. search_graph - semantic graph search (with local keyword fallback)
    2. get_all_nodes - fetch all nodes of a graph
    3. get_all_edges - fetch all edges of a graph
    4. get_node_detail - fetch a single node's details
    5. get_node_edges - fetch the edges touching a node
    6. get_entities_by_type - fetch entities carrying a given label
    7. get_entity_summary - fetch an entity's relation summary
    """

    # Retry configuration: maximum attempts, and the initial delay in seconds
    # between attempts (doubled after every failure).
    MAX_RETRIES = 3
    RETRY_DELAY = 2.0

    def __init__(self, api_key: Optional[str] = None):
        """
        Create the service.

        Args:
            api_key: Zep API key; defaults to ``Config.ZEP_API_KEY``.

        Raises:
            ValueError: if no API key is available.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")

        self.client = Zep(api_key=self.api_key)
        logger.info("ZepToolsService 初始化完成")

    def _call_with_retry(
        self,
        func: Callable[[], Any],
        operation_name: str,
        max_retries: Optional[int] = None
    ) -> Any:
        """
        Call ``func`` with exponential back-off retries.

        Args:
            func: zero-argument callable performing the API request.
            operation_name: label used in log messages.
            max_retries: total number of attempts; a falsy value means
                ``MAX_RETRIES``. Clamped to at least one attempt.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            The exception raised by the final failed attempt.
        """
        # Clamp so a negative value cannot skip the loop and end in ``raise None``.
        attempts = max(1, max_retries or self.MAX_RETRIES)
        delay = self.RETRY_DELAY

        for attempt in range(attempts):
            try:
                return func()
            except Exception as e:
                if attempt < attempts - 1:
                    logger.warning(
                        f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
                        f"{delay:.1f}秒后重试..."
                    )
                    time.sleep(delay)
                    delay *= 2
                else:
                    logger.error(f"Zep {operation_name} 在 {attempts} 次尝试后仍失败: {str(e)}")
                    raise

    def search_graph(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges"
    ) -> SearchResult:
        """
        Search the graph.

        Calls Zep Cloud's ``graph.search`` (with the ``cross_encoder`` reranker).
        If that fails after retries, or its response cannot be parsed, falls
        back to local keyword matching via ``_local_search``.

        Args:
            graph_id: graph ID (standalone graph).
            query: search query.
            limit: maximum number of results requested.
            scope: search scope, "edges" or "nodes" ("both" is additionally
                understood by the local fallback).

        Returns:
            SearchResult: edge facts plus "[name]: summary" strings for nodes.
        """
        logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...")

        try:
            search_results = self._call_with_retry(
                func=lambda: self.client.graph.search(
                    graph_id=graph_id,
                    query=query,
                    limit=limit,
                    scope=scope,
                    reranker="cross_encoder"
                ),
                operation_name=f"图谱搜索(graph={graph_id})"
            )
            facts, edges, nodes = self._parse_search_results(search_results)

            logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实")

            return SearchResult(
                facts=facts,
                edges=edges,
                nodes=nodes,
                query=query,
                total_count=len(facts)
            )

        except Exception as e:
            logger.warning(f"Zep Search API失败,降级为本地搜索: {str(e)}")
            return self._local_search(graph_id, query, limit, scope)

    @staticmethod
    def _parse_search_results(
        search_results: Any
    ) -> Tuple[List[str], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Turn a Zep search response into ``(facts, edges, nodes)`` lists."""
        facts: List[str] = []
        edges: List[Dict[str, Any]] = []
        nodes: List[Dict[str, Any]] = []

        for edge in getattr(search_results, 'edges', None) or []:
            fact = getattr(edge, 'fact', '')
            if fact:
                facts.append(fact)
            edges.append({
                "uuid": _get_uuid(edge),
                "name": getattr(edge, 'name', ''),
                "fact": fact,
                "source_node_uuid": getattr(edge, 'source_node_uuid', ''),
                "target_node_uuid": getattr(edge, 'target_node_uuid', ''),
            })

        for node in getattr(search_results, 'nodes', None) or []:
            name = getattr(node, 'name', '')
            summary = getattr(node, 'summary', '')
            nodes.append({
                "uuid": _get_uuid(node),
                "name": name,
                "labels": getattr(node, 'labels', []),
                "summary": summary,
            })
            # A node's summary is also reported as a fact.
            if summary:
                facts.append(f"[{name}]: {summary}")

        return facts, edges, nodes

    def _local_search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges"
    ) -> SearchResult:
        """
        Local keyword-matching search (fallback for the Zep search API).

        Fetches all edges and/or nodes and scores them locally: 100 when the
        whole (lower-cased) query is a substring of a field, otherwise 10 per
        matching keyword. Keywords are the query split on whitespace and
        commas, keeping tokens longer than one character. Failures are logged
        and whatever was collected so far is returned.

        Args:
            graph_id: graph ID.
            query: search query.
            limit: maximum number of edges and of nodes to keep.
            scope: "edges", "nodes" or "both".

        Returns:
            SearchResult: matched items, highest score first.
        """
        logger.info(f"使用本地搜索: query={query[:30]}...")

        facts = []
        edges_result = []
        nodes_result = []

        query_lower = query.lower()
        has_query = bool(query_lower.strip())
        keywords = [
            word for word in query_lower.translate(_KEYWORD_SEPARATORS).split()
            if len(word) > 1
        ]

        def match_score(text: str) -> int:
            """Score how well ``text`` matches the query (0 means no match)."""
            if not text:
                return 0
            text_lower = text.lower()
            # An empty query is a substring of everything; don't let it match.
            if has_query and query_lower in text_lower:
                return 100
            return sum(10 for keyword in keywords if keyword in text_lower)

        try:
            if scope in ["edges", "both"]:
                scored_edges = []
                for edge in self.get_all_edges(graph_id):
                    score = match_score(edge.fact) + match_score(edge.name)
                    if score > 0:
                        scored_edges.append((score, edge))

                scored_edges.sort(key=lambda item: item[0], reverse=True)

                for _, edge in scored_edges[:limit]:
                    if edge.fact:
                        facts.append(edge.fact)
                    edges_result.append({
                        "uuid": edge.uuid,
                        "name": edge.name,
                        "fact": edge.fact,
                        "source_node_uuid": edge.source_node_uuid,
                        "target_node_uuid": edge.target_node_uuid,
                    })

            if scope in ["nodes", "both"]:
                scored_nodes = []
                for node in self.get_all_nodes(graph_id):
                    score = match_score(node.name) + match_score(node.summary)
                    if score > 0:
                        scored_nodes.append((score, node))

                scored_nodes.sort(key=lambda item: item[0], reverse=True)

                for _, node in scored_nodes[:limit]:
                    nodes_result.append({
                        "uuid": node.uuid,
                        "name": node.name,
                        "labels": node.labels,
                        "summary": node.summary,
                    })
                    if node.summary:
                        facts.append(f"[{node.name}]: {node.summary}")

            logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实")

        except Exception as e:
            # Best effort: keep whatever was matched before the failure.
            logger.error(f"本地搜索失败: {str(e)}")

        return SearchResult(
            facts=facts,
            edges=edges_result,
            nodes=nodes_result,
            query=query,
            total_count=len(facts)
        )

    def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
        """
        Fetch the nodes of a graph via ``graph.node.get_by_graph_id``.

        NOTE(review): only a single call is made; confirm whether the SDK
        paginates this endpoint.

        Args:
            graph_id: graph ID.

        Returns:
            List of nodes.
        """
        logger.info(f"获取图谱 {graph_id} 的所有节点...")

        nodes = self._call_with_retry(
            func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
            operation_name=f"获取节点(graph={graph_id})"
        )

        result = [_to_node_info(node) for node in nodes]

        logger.info(f"获取到 {len(result)} 个节点")
        return result

    def get_all_edges(self, graph_id: str) -> List[EdgeInfo]:
        """
        Fetch the edges of a graph via ``graph.edge.get_by_graph_id``.

        Node names are not resolved; ``source_node_name``/``target_node_name``
        stay ``None``.

        Args:
            graph_id: graph ID.

        Returns:
            List of edges.
        """
        logger.info(f"获取图谱 {graph_id} 的所有边...")

        edges = self._call_with_retry(
            func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
            operation_name=f"获取边(graph={graph_id})"
        )

        result = [
            EdgeInfo(
                uuid=_get_uuid(edge),
                name=edge.name or "",
                fact=edge.fact or "",
                source_node_uuid=edge.source_node_uuid or "",
                target_node_uuid=edge.target_node_uuid or ""
            )
            for edge in edges
        ]

        logger.info(f"获取到 {len(result)} 条边")
        return result

    def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]:
        """
        Fetch a single node's details.

        Args:
            node_uuid: node UUID.

        Returns:
            The node, or None if it was not returned or the request failed
            (failures are logged, not raised).
        """
        logger.info(f"获取节点详情: {node_uuid[:8]}...")

        try:
            node = self._call_with_retry(
                func=lambda: self.client.graph.node.get(uuid_=node_uuid),
                operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)"
            )
            if not node:
                return None
            return _to_node_info(node)
        except Exception as e:
            logger.error(f"获取节点详情失败: {str(e)}")
            return None

    def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]:
        """
        Fetch all edges that have the given node as source or target.

        Downloads every edge of the graph and filters locally.

        Args:
            graph_id: graph ID.
            node_uuid: node UUID.

        Returns:
            List of matching edges; empty if the request failed (logged).
        """
        logger.info(f"获取节点 {node_uuid[:8]}... 的相关边")

        try:
            result = [
                edge for edge in self.get_all_edges(graph_id)
                if node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
            ]

            logger.info(f"找到 {len(result)} 条与节点相关的边")
            return result

        except Exception as e:
            logger.warning(f"获取节点边失败: {str(e)}")
            return []

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str
    ) -> List[NodeInfo]:
        """
        Fetch the entities whose labels include ``entity_type``.

        Args:
            graph_id: graph ID.
            entity_type: label to match exactly (e.g. Student, PublicFigure).

        Returns:
            Matching entities.
        """
        logger.info(f"获取类型为 {entity_type} 的实体...")

        filtered = [node for node in self.get_all_nodes(graph_id) if entity_type in node.labels]

        logger.info(f"找到 {len(filtered)} 个 {entity_type} 类型的实体")
        return filtered

    def get_entity_summary(
        self,
        graph_id: str,
        entity_name: str
    ) -> Dict[str, Any]:
        """
        Build a relation summary for the named entity.

        Combines a graph search for the name with the edges of the first node
        whose name matches case-insensitively (if any).

        Args:
            graph_id: graph ID.
            entity_name: entity name.

        Returns:
            Dict with the entity info (or None), related facts, related edges
            and the number of related edges.
        """
        logger.info(f"获取实体 {entity_name} 的关系摘要...")

        search_result = self.search_graph(
            graph_id=graph_id,
            query=entity_name,
            limit=20
        )

        target_name = entity_name.lower()
        entity_node = next(
            (node for node in self.get_all_nodes(graph_id) if node.name.lower() == target_name),
            None,
        )

        related_edges = self.get_node_edges(graph_id, entity_node.uuid) if entity_node else []

        return {
            "entity_name": entity_name,
            "entity_info": entity_node.to_dict() if entity_node else None,
            "related_facts": search_result.facts,
            "related_edges": [e.to_dict() for e in related_edges],
            "total_relations": len(related_edges)
        }

    def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]:
        """
        Compute node/edge counts and the entity-type and relation-type distributions.

        Args:
            graph_id: graph ID.

        Returns:
            Statistics dictionary.
        """
        logger.info(f"获取图谱 {graph_id} 的统计信息...")

        nodes = self.get_all_nodes(graph_id)
        edges = self.get_all_edges(graph_id)
        return self._build_statistics(graph_id, nodes, edges)

    @staticmethod
    def _build_statistics(
        graph_id: str,
        nodes: List[NodeInfo],
        edges: List[EdgeInfo]
    ) -> Dict[str, Any]:
        """Build the statistics dictionary from already-fetched nodes and edges."""
        # Entity-type distribution ignores the generic labels.
        entity_types = Counter(
            label
            for node in nodes
            for label in node.labels
            if label not in _GENERIC_LABELS
        )
        relation_types = Counter(edge.name for edge in edges)

        return {
            "graph_id": graph_id,
            "total_nodes": len(nodes),
            "total_edges": len(edges),
            "entity_types": dict(entity_types),
            "relation_types": dict(relation_types)
        }

    def get_simulation_context(
        self,
        graph_id: str,
        simulation_requirement: str,
        limit: int = 30
    ) -> Dict[str, Any]:
        """
        Gather context relevant to a simulation requirement.

        Args:
            graph_id: graph ID.
            simulation_requirement: description of the simulation requirement.
            limit: cap on search results and on the number of entities returned.

        Returns:
            Dict with related facts, graph statistics, typed entities (first
            ``limit``) and the total count of typed entities.
        """
        logger.info(f"获取模拟上下文: {simulation_requirement[:50]}...")

        search_result = self.search_graph(
            graph_id=graph_id,
            query=simulation_requirement,
            limit=limit
        )

        # Fetch nodes and edges once and reuse them for both the statistics and
        # the entity list (the node list used to be downloaded twice).
        all_nodes = self.get_all_nodes(graph_id)
        all_edges = self.get_all_edges(graph_id)
        stats = self._build_statistics(graph_id, all_nodes, all_edges)

        # Keep only entities that carry a specific (non-generic) type label.
        entities = []
        for node in all_nodes:
            custom_labels = [label for label in node.labels if label not in _GENERIC_LABELS]
            if custom_labels:
                entities.append({
                    "name": node.name,
                    "type": custom_labels[0],
                    "summary": node.summary
                })

        return {
            "simulation_requirement": simulation_requirement,
            "related_facts": search_result.facts,
            "graph_statistics": stats,
            "entities": entities[:limit],
            "total_entities": len(entities)
        }