feat(graph): implement pagination for fetching nodes and edges; add utility functions for streamlined data retrieval

This commit is contained in:
666ghj 2026-02-27 15:53:29 +08:00
parent d30a0a23ef
commit da6548e96f
4 changed files with 192 additions and 58 deletions

View file

@ -15,6 +15,7 @@ from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from ..config import Config
from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .text_processor import TextProcessor
@ -395,12 +396,12 @@ class GraphBuilderService:
def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息"""
# 获取节点
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
# 获取边
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
# 获取节点(分页)
nodes = fetch_all_nodes(self.client, graph_id)
# 获取边(分页)
edges = fetch_all_edges(self.client, graph_id)
# 统计实体类型
entity_types = set()
for node in nodes:
@ -408,7 +409,7 @@ class GraphBuilderService:
for label in node.labels:
if label not in ["Entity", "Node"]:
entity_types.add(label)
return GraphInfo(
graph_id=graph_id,
node_count=len(nodes),
@ -426,9 +427,9 @@ class GraphBuilderService:
Returns:
包含nodes和edges的字典包括时间信息属性等详细数据
"""
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id)
# 创建节点映射用于获取节点名称
node_map = {}
for node in nodes:

View file

@ -11,6 +11,7 @@ from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_entity_reader')
@ -125,22 +126,18 @@ class ZepEntityReader:
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
"""
获取图谱的所有节点带重试机制
获取图谱的所有节点分页获取
Args:
graph_id: 图谱ID
Returns:
节点列表
"""
logger.info(f"获取图谱 {graph_id} 的所有节点...")
# 使用重试机制调用Zep API
nodes = self._call_with_retry(
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取节点(graph={graph_id})"
)
nodes = fetch_all_nodes(self.client, graph_id)
nodes_data = []
for node in nodes:
nodes_data.append({
@ -150,28 +147,24 @@ class ZepEntityReader:
"summary": node.summary or "",
"attributes": node.attributes or {},
})
logger.info(f"共获取 {len(nodes_data)} 个节点")
return nodes_data
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
"""
获取图谱的所有边带重试机制
获取图谱的所有边分页获取
Args:
graph_id: 图谱ID
Returns:
边列表
"""
logger.info(f"获取图谱 {graph_id} 的所有边...")
# 使用重试机制调用Zep API
edges = self._call_with_retry(
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取边(graph={graph_id})"
)
edges = fetch_all_edges(self.client, graph_id)
edges_data = []
for edge in edges:
edges_data.append({
@ -182,7 +175,7 @@ class ZepEntityReader:
"target_node_uuid": edge.target_node_uuid,
"attributes": edge.attributes or {},
})
logger.info(f"共获取 {len(edges_data)} 条边")
return edges_data

View file

@ -18,6 +18,7 @@ from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
from ..utils.llm_client import LLMClient
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_tools')
@ -648,71 +649,67 @@ class ZepToolsService:
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
"""
获取图谱的所有节点
获取图谱的所有节点分页获取
Args:
graph_id: 图谱ID
Returns:
节点列表
"""
logger.info(f"获取图谱 {graph_id} 的所有节点...")
nodes = self._call_with_retry(
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取节点(graph={graph_id})"
)
nodes = fetch_all_nodes(self.client, graph_id)
result = []
for node in nodes:
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
result.append(NodeInfo(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
uuid=str(node_uuid) if node_uuid else "",
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
))
logger.info(f"获取到 {len(result)} 个节点")
return result
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
"""
获取图谱的所有边包含时间信息
获取图谱的所有边分页获取包含时间信息
Args:
graph_id: 图谱ID
include_temporal: 是否包含时间信息默认True
Returns:
边列表包含created_at, valid_at, invalid_at, expired_at
"""
logger.info(f"获取图谱 {graph_id} 的所有边...")
edges = self._call_with_retry(
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取边(graph={graph_id})"
)
edges = fetch_all_edges(self.client, graph_id)
result = []
for edge in edges:
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
edge_info = EdgeInfo(
uuid=getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
uuid=str(edge_uuid) if edge_uuid else "",
name=edge.name or "",
fact=edge.fact or "",
source_node_uuid=edge.source_node_uuid or "",
target_node_uuid=edge.target_node_uuid or ""
)
# 添加时间信息
if include_temporal:
edge_info.created_at = getattr(edge, 'created_at', None)
edge_info.valid_at = getattr(edge, 'valid_at', None)
edge_info.invalid_at = getattr(edge, 'invalid_at', None)
edge_info.expired_at = getattr(edge, 'expired_at', None)
result.append(edge_info)
logger.info(f"获取到 {len(result)} 条边")
return result

View file

@ -0,0 +1,143 @@
"""Zep Graph 分页读取工具。
Zep node/edge 列表接口使用 UUID cursor 分页
本模块封装自动翻页逻辑含单页重试对调用方透明地返回完整列表
"""
from __future__ import annotations
import time
from collections.abc import Callable
from typing import Any
from zep_cloud import InternalServerError
from zep_cloud.client import Zep
from .logger import get_logger
logger = get_logger('mirofish.zep_paging')
_DEFAULT_PAGE_SIZE = 100
_MAX_NODES = 2000
_DEFAULT_MAX_RETRIES = 3
_DEFAULT_RETRY_DELAY = 2.0 # seconds, doubles each retry
def _fetch_page_with_retry(
    api_call: Callable[..., list[Any]],
    *args: Any,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
    page_description: str = "page",
    **kwargs: Any,
) -> list[Any]:
    """Fetch a single page, retrying transient failures with exponential backoff.

    Only network/IO-style transient errors are retried; any other exception
    (bad request, auth failure, ...) propagates immediately.

    Args:
        api_call: The Zep SDK list call to invoke
            (e.g. ``client.graph.node.get_by_graph_id``).
        *args: Positional arguments forwarded to ``api_call``.
        max_retries: Total number of attempts; must be >= 1.
        retry_delay: Initial sleep between attempts in seconds; doubled
            after each failed attempt.
        page_description: Human-readable label used in log messages.
        **kwargs: Keyword arguments forwarded to ``api_call``.

    Returns:
        The page of items returned by ``api_call``.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        ConnectionError, TimeoutError, OSError, InternalServerError:
            The last transient error, once all attempts are exhausted.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")
    delay = retry_delay
    for attempt in range(max_retries):
        try:
            return api_call(*args, **kwargs)
        except (ConnectionError, TimeoutError, OSError, InternalServerError) as e:
            if attempt >= max_retries - 1:
                # Final attempt: bare re-raise inside the except block keeps
                # the original traceback and exception context intact (no
                # assert-based bookkeeping or post-loop raise needed).
                logger.error(f"Zep {page_description} failed after {max_retries} attempts: {str(e)}")
                raise
            logger.warning(
                f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {delay:.1f}s..."
            )
            time.sleep(delay)
            delay *= 2
    # The loop always returns on success or re-raises on the last attempt.
    raise AssertionError("unreachable")
def fetch_all_nodes(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_items: int = _MAX_NODES,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
    """Fetch a graph's nodes page by page, capped at ``max_items`` (default 2000).

    Each page request carries its own retry logic via
    ``_fetch_page_with_retry``; the last node's UUID of each page is used
    as the cursor for the next page.

    Args:
        client: Zep SDK client.
        graph_id: Target graph ID.
        page_size: Number of nodes requested per page.
        max_items: Hard cap on the number of nodes returned.
        max_retries: Attempts per page request.
        retry_delay: Initial backoff between attempts, in seconds.

    Returns:
        Up to ``max_items`` node objects, in API order.
    """
    collected: list[Any] = []
    cursor: str | None = None
    page = 0
    while True:
        page += 1
        params: dict[str, Any] = {"limit": page_size}
        if cursor is not None:
            params["uuid_cursor"] = cursor
        batch = _fetch_page_with_retry(
            client.graph.node.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch nodes page {page} (graph={graph_id})",
            **params,
        )
        # Empty page means the server has nothing further to return.
        if not batch:
            return collected
        collected += batch
        if len(collected) >= max_items:
            collected = collected[:max_items]
            logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}")
            return collected
        # A short page is the last page — no need for another round trip.
        if len(batch) < page_size:
            return collected
        cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
        if cursor is None:
            logger.warning(f"Node missing uuid field, stopping pagination at {len(collected)} nodes")
            return collected
def fetch_all_edges(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
    max_items: int | None = None,
) -> list[Any]:
    """Fetch a graph's edges page by page, returning the full list.

    Each page request carries its own retry logic via
    ``_fetch_page_with_retry``; the last edge's UUID of each page is used
    as the cursor for the next page.

    Args:
        client: Zep SDK client.
        graph_id: Target graph ID.
        page_size: Number of edges requested per page.
        max_retries: Attempts per page request.
        retry_delay: Initial backoff between attempts, in seconds.
        max_items: Optional hard cap on the number of edges returned,
            mirroring ``fetch_all_nodes``. ``None`` (the default) keeps the
            original unlimited behavior.

    Returns:
        Edge objects, in API order (capped at ``max_items`` if given).
    """
    all_edges: list[Any] = []
    cursor: str | None = None
    page_num = 0
    while True:
        kwargs: dict[str, Any] = {"limit": page_size}
        if cursor is not None:
            kwargs["uuid_cursor"] = cursor
        page_num += 1
        batch = _fetch_page_with_retry(
            client.graph.edge.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch edges page {page_num} (graph={graph_id})",
            **kwargs,
        )
        # Empty page means the server has nothing further to return.
        if not batch:
            break
        all_edges.extend(batch)
        # Optional cap for parity with fetch_all_nodes; inactive by default.
        if max_items is not None and len(all_edges) >= max_items:
            all_edges = all_edges[:max_items]
            logger.warning(f"Edge count reached limit ({max_items}), stopping pagination for graph {graph_id}")
            break
        # A short page is the last page — no need for another round trip.
        if len(batch) < page_size:
            break
        cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
        if cursor is None:
            logger.warning(f"Edge missing uuid field, stopping pagination at {len(all_edges)} edges")
            break
    return all_edges