""" Zep图谱构建模块 负责与Zep云服务交互,构建知识图谱 """ import os import time import uuid from typing import Optional, Callable from dataclasses import dataclass from zep_cloud.client import Zep from zep_cloud import EpisodeData, EntityEdgeSourceTarget from ontology import ENTITY_TYPES, EDGE_TYPES @dataclass class GraphNode: """图节点数据结构""" uuid: str name: str summary: str labels: list[str] attributes: dict @dataclass class GraphEdge: """图边数据结构""" uuid: str name: str fact: str source_node_uuid: str target_node_uuid: str attributes: dict @dataclass class GraphData: """完整图数据""" graph_id: str nodes: list[GraphNode] edges: list[GraphEdge] class ZepGraphBuilder: """Zep知识图谱构建器""" def __init__(self, api_key: Optional[str] = None): """ 初始化图谱构建器 Args: api_key: Zep API密钥,如果不提供则从环境变量ZEP_API_KEY读取 """ self.api_key = api_key or os.environ.get("ZEP_API_KEY") if not self.api_key: raise ValueError("需要提供ZEP_API_KEY,可以通过参数传入或设置环境变量") self.client = Zep(api_key=self.api_key) def create_graph(self, graph_id: Optional[str] = None, name: str = "Knowledge Graph") -> str: """ 创建新的图谱 Args: graph_id: 图谱ID,如果不提供则自动生成 name: 图谱名称 Returns: 图谱ID """ if graph_id is None: graph_id = f"graph_{uuid.uuid4().hex[:16]}" self.client.graph.create( graph_id=graph_id, name=name, description="Knowledge graph generated by txt2graph" ) return graph_id def set_ontology(self, graph_id: str): """ 为图谱设置自定义本体(实体和边类型) Args: graph_id: 图谱ID """ # 构建边类型的源目标映射 edge_definitions = {} # WORKS_FOR: Person -> Organization/Company edge_definitions["WORKS_FOR"] = ( EDGE_TYPES["WORKS_FOR"], [ EntityEdgeSourceTarget(source="Person", target="Organization"), EntityEdgeSourceTarget(source="Person", target="Company"), ] ) # LOCATED_IN: 多种实体 -> Location edge_definitions["LOCATED_IN"] = ( EDGE_TYPES["LOCATED_IN"], [ EntityEdgeSourceTarget(source="Person", target="Location"), EntityEdgeSourceTarget(source="Organization", target="Location"), EntityEdgeSourceTarget(source="Company", target="Location"), EntityEdgeSourceTarget(source="Event", target="Location"), ] ) # PART_OF: Organization -> Organization, Company -> Company edge_definitions["PART_OF"] = ( EDGE_TYPES["PART_OF"], [ EntityEdgeSourceTarget(source="Organization", target="Organization"), EntityEdgeSourceTarget(source="Company", target="Company"), ] ) # PRODUCES: Company -> Product edge_definitions["PRODUCES"] = ( EDGE_TYPES["PRODUCES"], [ EntityEdgeSourceTarget(source="Company", target="Product"), EntityEdgeSourceTarget(source="Organization", target="Product"), ] ) # PARTICIPATES_IN: Person/Organization/Company -> Event edge_definitions["PARTICIPATES_IN"] = ( EDGE_TYPES["PARTICIPATES_IN"], [ EntityEdgeSourceTarget(source="Person", target="Event"), EntityEdgeSourceTarget(source="Organization", target="Event"), EntityEdgeSourceTarget(source="Company", target="Event"), ] ) # COLLABORATES: 各种实体之间的合作 edge_definitions["COLLABORATES"] = ( EDGE_TYPES["COLLABORATES"], [ EntityEdgeSourceTarget(source="Person", target="Person"), EntityEdgeSourceTarget(source="Company", target="Company"), EntityEdgeSourceTarget(source="Organization", target="Organization"), EntityEdgeSourceTarget(source="Company", target="Organization"), ] ) # COMPETES: 公司之间的竞争 edge_definitions["COMPETES"] = ( EDGE_TYPES["COMPETES"], [ EntityEdgeSourceTarget(source="Company", target="Company"), ] ) # REPORTS: Media报道相关实体 edge_definitions["REPORTS"] = ( EDGE_TYPES["REPORTS"], [ EntityEdgeSourceTarget(source="Media", target="Person"), EntityEdgeSourceTarget(source="Media", target="Company"), EntityEdgeSourceTarget(source="Media", target="Organization"), EntityEdgeSourceTarget(source="Media", target="Event"), ] ) # 设置本体 self.client.graph.set_ontology( graph_ids=[graph_id], entities=ENTITY_TYPES, edges=edge_definitions, ) def add_text_to_graph( self, graph_id: str, text_chunks: list[str], batch_size: int = 3, progress_callback: Optional[Callable] = None ) -> list[str]: """ 将文本块分批添加到图谱中 Args: graph_id: 图谱ID text_chunks: 文本块列表 batch_size: 每批发送的块数量 progress_callback: 进度回调函数 Returns: 任务ID列表 """ task_ids = [] total_chunks = len(text_chunks) # 分批处理 for i in range(0, total_chunks, batch_size): batch_chunks = text_chunks[i:i + batch_size] batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size if progress_callback: progress_callback(f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...") # 构建episode数据 episodes = [ EpisodeData(data=chunk, type="text") for chunk in batch_chunks ] try: # 批量添加 batch_result = self.client.graph.add_batch( graph_id=graph_id, episodes=episodes ) if batch_result and batch_result[0].task_id: task_ids.append(batch_result[0].task_id) # 短暂等待,避免请求过快 time.sleep(1) except Exception as e: if progress_callback: progress_callback(f"批次 {batch_num} 发送失败: {str(e)}") raise return task_ids def wait_for_tasks( self, task_ids: list[str], timeout: int = 600, progress_callback: Optional[Callable] = None ): """ 等待所有任务完成 Args: task_ids: 任务ID列表 timeout: 超时时间(秒) progress_callback: 进度回调 """ if not task_ids: return start_time = time.time() pending_tasks = set(task_ids) completed_tasks = set() while pending_tasks: if time.time() - start_time > timeout: if progress_callback: progress_callback(f"警告: 部分任务超时,已完成 {len(completed_tasks)}/{len(task_ids)}") break for task_id in list(pending_tasks): try: task = self.client.task.get(task_id=task_id) if task.status == "completed": pending_tasks.remove(task_id) completed_tasks.add(task_id) elif task.status == "failed": pending_tasks.remove(task_id) if progress_callback: progress_callback(f"任务失败: {task.error}") except Exception as e: if progress_callback: progress_callback(f"检查任务状态出错: {str(e)}") if pending_tasks: if progress_callback: elapsed = int(time.time() - start_time) progress_callback(f"等待处理中... 已完成 {len(completed_tasks)}/{len(task_ids)} ({elapsed}秒)") time.sleep(3) if progress_callback: progress_callback(f"所有任务处理完成: {len(completed_tasks)}/{len(task_ids)}") def get_graph_data(self, graph_id: str) -> GraphData: """ 获取图谱的完整数据 Args: graph_id: 图谱ID Returns: GraphData对象,包含所有节点和边 """ # 获取所有节点 raw_nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id) nodes = [ GraphNode( uuid=node.uuid_, name=node.name, summary=node.summary or "", labels=node.labels or [], attributes=node.attributes or {} ) for node in raw_nodes ] # 获取所有边 raw_edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id) edges = [ GraphEdge( uuid=edge.uuid_, name=edge.name or "", fact=edge.fact or "", source_node_uuid=edge.source_node_uuid, target_node_uuid=edge.target_node_uuid, attributes=edge.attributes or {} ) for edge in raw_edges ] return GraphData( graph_id=graph_id, nodes=nodes, edges=edges ) def delete_graph(self, graph_id: str): """删除图谱""" self.client.graph.delete(graph_id=graph_id) def build_graph_from_text( text: str, graph_name: str = "Knowledge Graph", api_key: Optional[str] = None, chunk_size: int = 2000, progress_callback: Optional[Callable] = None ) -> GraphData: """ 便捷函数:从文本构建知识图谱 Args: text: 输入文本 graph_name: 图谱名称 api_key: Zep API密钥 chunk_size: 文本分块大小(默认2000字符) progress_callback: 进度回调 Returns: GraphData对象 """ from text_extractor import split_text_into_chunks builder = ZepGraphBuilder(api_key=api_key) # 创建图谱 graph_id = builder.create_graph(name=graph_name) if progress_callback: progress_callback(f"创建图谱: {graph_id}") # 设置本体 builder.set_ontology(graph_id) if progress_callback: progress_callback("设置实体类型...") # 分块处理文本 chunks = split_text_into_chunks(text, max_chunk_size=chunk_size) if progress_callback: progress_callback(f"文本分为 {len(chunks)} 个块") # 分批添加到图谱 task_ids = builder.add_text_to_graph( graph_id=graph_id, text_chunks=chunks, batch_size=3, progress_callback=progress_callback ) # 等待所有任务完成 if task_ids: builder.wait_for_tasks(task_ids, progress_callback=progress_callback) # 获取并返回图数据 return builder.get_graph_data(graph_id) if __name__ == "__main__": # 测试 from dotenv import load_dotenv load_dotenv() test_text = """ 武汉大学是中国著名的高等学府,位于湖北省武汉市。 该校的樱花季每年吸引大量游客。 马化腾是腾讯公司的创始人,腾讯总部位于深圳。 """ result = build_graph_from_text( text=test_text, graph_name="测试图谱", progress_callback=print ) print(f"\n节点数: {len(result.nodes)}") for node in result.nodes: print(f" - {node.name} ({node.labels})") print(f"\n边数: {len(result.edges)}") for edge in result.edges: print(f" - {edge.fact}")