feat(graph): implement pagination for fetching nodes and edges; add utility functions for streamlined data retrieval
parent d30a0a23ef
commit da6548e96f
4 changed files with 192 additions and 58 deletions
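For orientation, here is a minimal, hedged usage sketch of the new pagination helpers added in backend/app/utils/zep_paging.py; the import path, graph id, and API key below are placeholders, not part of this commit:

    # Hypothetical caller of the new paginated helpers.
    from zep_cloud.client import Zep

    from app.utils.zep_paging import fetch_all_edges, fetch_all_nodes  # assumed package path

    client = Zep(api_key="...")                   # placeholder credential
    nodes = fetch_all_nodes(client, "my-graph")   # paginated, capped at 2000 nodes by default
    edges = fetch_all_edges(client, "my-graph")   # paginated, no item cap
    print(f"{len(nodes)} nodes, {len(edges)} edges")

With the module defaults (3 attempts, 2.0 s initial delay, doubling), a failing page request is retried after 2 s and again after 4 s before the last error is re-raised.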
@@ -15,6 +15,7 @@ from zep_cloud import EpisodeData, EntityEdgeSourceTarget
 from ..config import Config
 from ..models.task import TaskManager, TaskStatus
+from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
 from .text_processor import TextProcessor
@@ -395,12 +396,12 @@ class GraphBuilderService:
     def _get_graph_info(self, graph_id: str) -> GraphInfo:
         """Get graph info."""
-        # Fetch nodes
-        nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
-
-        # Fetch edges
-        edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
+        # Fetch nodes (paginated)
+        nodes = fetch_all_nodes(self.client, graph_id)
+
+        # Fetch edges (paginated)
+        edges = fetch_all_edges(self.client, graph_id)
 
         # Collect entity types
         entity_types = set()
         for node in nodes:
@@ -408,7 +409,7 @@ class GraphBuilderService:
             for label in node.labels:
                 if label not in ["Entity", "Node"]:
                     entity_types.add(label)
 
         return GraphInfo(
             graph_id=graph_id,
             node_count=len(nodes),
@@ -426,9 +427,9 @@ class GraphBuilderService:
         Returns:
             Dict containing nodes and edges, with detailed data such as temporal info and attributes
         """
-        nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
-        edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
+        nodes = fetch_all_nodes(self.client, graph_id)
+        edges = fetch_all_edges(self.client, graph_id)
 
         # Build a node map for looking up node names
         node_map = {}
         for node in nodes:

@@ -11,6 +11,7 @@ from zep_cloud.client import Zep
 from ..config import Config
 from ..utils.logger import get_logger
+from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
 
 logger = get_logger('mirofish.zep_entity_reader')
@@ -125,22 +126,18 @@ class ZepEntityReader:
     def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
         """
-        Get all nodes of the graph (with retry).
+        Get all nodes of the graph (paginated).
 
         Args:
             graph_id: Graph ID
 
         Returns:
             List of nodes
         """
         logger.info(f"Fetching all nodes for graph {graph_id}...")
 
-        # Call the Zep API with retry
-        nodes = self._call_with_retry(
-            func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
-            operation_name=f"fetch nodes (graph={graph_id})"
-        )
+        nodes = fetch_all_nodes(self.client, graph_id)
 
         nodes_data = []
         for node in nodes:
             nodes_data.append({
@@ -150,28 +147,24 @@ class ZepEntityReader:
                 "summary": node.summary or "",
                 "attributes": node.attributes or {},
             })
 
         logger.info(f"Fetched {len(nodes_data)} nodes in total")
         return nodes_data
 
     def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
         """
-        Get all edges of the graph (with retry).
+        Get all edges of the graph (paginated).
 
         Args:
             graph_id: Graph ID
 
         Returns:
             List of edges
         """
         logger.info(f"Fetching all edges for graph {graph_id}...")
 
-        # Call the Zep API with retry
-        edges = self._call_with_retry(
-            func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
-            operation_name=f"fetch edges (graph={graph_id})"
-        )
+        edges = fetch_all_edges(self.client, graph_id)
 
         edges_data = []
         for edge in edges:
             edges_data.append({
@@ -182,7 +175,7 @@ class ZepEntityReader:
                 "target_node_uuid": edge.target_node_uuid,
                 "attributes": edge.attributes or {},
             })
 
         logger.info(f"Fetched {len(edges_data)} edges in total")
         return edges_data

@@ -18,6 +18,7 @@ from zep_cloud.client import Zep
 from ..config import Config
 from ..utils.logger import get_logger
 from ..utils.llm_client import LLMClient
+from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
 
 logger = get_logger('mirofish.zep_tools')
@@ -648,71 +649,67 @@ class ZepToolsService:
     def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
         """
-        Get all nodes of the graph.
+        Get all nodes of the graph (paginated).
 
         Args:
             graph_id: Graph ID
 
         Returns:
             List of nodes
         """
         logger.info(f"Fetching all nodes for graph {graph_id}...")
 
-        nodes = self._call_with_retry(
-            func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
-            operation_name=f"fetch nodes (graph={graph_id})"
-        )
+        nodes = fetch_all_nodes(self.client, graph_id)
 
         result = []
         for node in nodes:
+            node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
             result.append(NodeInfo(
-                uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
+                uuid=str(node_uuid) if node_uuid else "",
                 name=node.name or "",
                 labels=node.labels or [],
                 summary=node.summary or "",
                 attributes=node.attributes or {}
             ))
 
         logger.info(f"Fetched {len(result)} nodes")
         return result
 
     def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
         """
-        Get all edges of the graph (with temporal info).
+        Get all edges of the graph (paginated, with temporal info).
 
         Args:
             graph_id: Graph ID
             include_temporal: Whether to include temporal info (default True)
 
         Returns:
             List of edges (including created_at, valid_at, invalid_at, expired_at)
         """
         logger.info(f"Fetching all edges for graph {graph_id}...")
 
-        edges = self._call_with_retry(
-            func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
-            operation_name=f"fetch edges (graph={graph_id})"
-        )
+        edges = fetch_all_edges(self.client, graph_id)
 
         result = []
         for edge in edges:
+            edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
             edge_info = EdgeInfo(
-                uuid=getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
+                uuid=str(edge_uuid) if edge_uuid else "",
                 name=edge.name or "",
                 fact=edge.fact or "",
                 source_node_uuid=edge.source_node_uuid or "",
                 target_node_uuid=edge.target_node_uuid or ""
             )
 
             # Attach temporal info
             if include_temporal:
                 edge_info.created_at = getattr(edge, 'created_at', None)
                 edge_info.valid_at = getattr(edge, 'valid_at', None)
                 edge_info.invalid_at = getattr(edge, 'invalid_at', None)
                 edge_info.expired_at = getattr(edge, 'expired_at', None)
 
             result.append(edge_info)
 
         logger.info(f"Fetched {len(result)} edges")
         return result

backend/app/utils/zep_paging.py (new file, 143 additions)
@@ -0,0 +1,143 @@
"""Pagination helpers for reading Zep Graph data.

Zep's node/edge list endpoints use UUID-cursor pagination. This module wraps
the auto-paging logic (with per-page retry) and transparently returns the
complete list to callers.
"""

from __future__ import annotations

import time
from collections.abc import Callable
from typing import Any

from zep_cloud import InternalServerError
from zep_cloud.client import Zep

from .logger import get_logger

logger = get_logger('mirofish.zep_paging')

_DEFAULT_PAGE_SIZE = 100
_MAX_NODES = 2000
_DEFAULT_MAX_RETRIES = 3
_DEFAULT_RETRY_DELAY = 2.0  # seconds, doubles each retry


def _fetch_page_with_retry(
    api_call: Callable[..., list[Any]],
    *args: Any,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
    page_description: str = "page",
    **kwargs: Any,
) -> list[Any]:
    """Fetch a single page, retrying transient network/IO errors with exponential backoff."""
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    last_exception: Exception | None = None
    delay = retry_delay

    for attempt in range(max_retries):
        try:
            return api_call(*args, **kwargs)
        except (ConnectionError, TimeoutError, OSError, InternalServerError) as e:
            last_exception = e
            if attempt < max_retries - 1:
                logger.warning(
                    f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {delay:.1f}s..."
                )
                time.sleep(delay)
                delay *= 2
            else:
                logger.error(f"Zep {page_description} failed after {max_retries} attempts: {str(e)}")

    assert last_exception is not None
    raise last_exception


def fetch_all_nodes(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_items: int = _MAX_NODES,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
    """Fetch graph nodes page by page, returning at most max_items (default 2000). Each page request retries on failure."""
    all_nodes: list[Any] = []
    cursor: str | None = None
    page_num = 0

    while True:
        kwargs: dict[str, Any] = {"limit": page_size}
        if cursor is not None:
            kwargs["uuid_cursor"] = cursor

        page_num += 1
        batch = _fetch_page_with_retry(
            client.graph.node.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch nodes page {page_num} (graph={graph_id})",
            **kwargs,
        )
        if not batch:
            break

        all_nodes.extend(batch)
        if len(all_nodes) >= max_items:
            all_nodes = all_nodes[:max_items]
            logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}")
            break
        if len(batch) < page_size:
            break

        cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
        if cursor is None:
            logger.warning(f"Node missing uuid field, stopping pagination at {len(all_nodes)} nodes")
            break

    return all_nodes


def fetch_all_edges(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
    """Fetch all graph edges page by page, returning the complete list. Each page request retries on failure."""
    all_edges: list[Any] = []
    cursor: str | None = None
    page_num = 0

    while True:
        kwargs: dict[str, Any] = {"limit": page_size}
        if cursor is not None:
            kwargs["uuid_cursor"] = cursor

        page_num += 1
        batch = _fetch_page_with_retry(
            client.graph.edge.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch edges page {page_num} (graph={graph_id})",
            **kwargs,
        )
        if not batch:
            break

        all_edges.extend(batch)
        if len(batch) < page_size:
            break

        cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
        if cursor is None:
            logger.warning(f"Edge missing uuid field, stopping pagination at {len(all_edges)} edges")
            break

    return all_edges
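To illustrate how the UUID-cursor paging above terminates, here is a hedged smoke test against a stub client; every name in it is illustrative and nothing here is part of the commit:

    # Stub exercise of fetch_all_nodes: 250 fake nodes, page size 100,
    # so the helper should page three times (100 + 100 + 50) and stop.
    from types import SimpleNamespace

    from app.utils.zep_paging import fetch_all_nodes  # assumed package path

    _nodes = [SimpleNamespace(uuid_=f"uuid-{i}") for i in range(250)]

    def _get_by_graph_id(graph_id, limit=100, uuid_cursor=None):
        # Return up to `limit` nodes strictly after `uuid_cursor`, like the real endpoint.
        start = 0
        if uuid_cursor is not None:
            start = next(i for i, n in enumerate(_nodes) if n.uuid_ == uuid_cursor) + 1
        return _nodes[start:start + limit]

    stub = SimpleNamespace(
        graph=SimpleNamespace(node=SimpleNamespace(get_by_graph_id=_get_by_graph_id))
    )
    assert len(fetch_all_nodes(stub, "demo-graph", page_size=100)) == 250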