MiroFish/backend/app/services/zep_tools.py
666ghj 5ece3f670b Implement Report Agent for automated report generation and interaction
- Introduced the Report Agent module to facilitate the automatic generation of simulation analysis reports using LangChain and Zep, following the ReACT model.
- Added functionality for report outline planning, segmented content generation, and user interaction through a dialogue interface.
- Implemented new API endpoints for report generation, status checking, and retrieval, enhancing the overall reporting capabilities.
- Updated README.md to include detailed instructions on the new report generation features and API usage.
- Enhanced the project structure to accommodate the new report management functionalities, including report storage and retrieval mechanisms.
2025-12-09 15:10:55 +08:00

621 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Zep检索工具服务
封装图谱搜索、节点读取、边查询等工具供Report Agent使用
"""
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_tools')
@dataclass
class SearchResult:
    """Container for the outcome of one graph search.

    Holds the matched fact strings, the raw edge and node dictionaries,
    the query that produced them, and the number of results reported.
    """
    facts: List[str]
    edges: List[Dict[str, Any]]
    nodes: List[Dict[str, Any]]
    query: str
    total_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary keyed by field name."""
        field_names = ("facts", "edges", "nodes", "query", "total_count")
        return {field_name: getattr(self, field_name) for field_name in field_names}

    def to_text(self) -> str:
        """Render the result as newline-separated text for an LLM prompt.

        The first two lines give the query and the count; when there are
        facts, a heading follows and each fact is numbered from 1.
        """
        lines = [
            f"搜索查询: {self.query}",
            f"找到 {self.total_count} 条相关信息",
        ]
        if self.facts:
            lines.append("\n### 相关事实:")
            lines.extend(
                f"{position}. {fact}"
                for position, fact in enumerate(self.facts, 1)
            )
        return "\n".join(lines)
@dataclass
class NodeInfo:
    """Plain-data view of a single graph node."""
    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary keyed by field name."""
        field_names = ("uuid", "name", "labels", "summary", "attributes")
        return {field_name: getattr(self, field_name) for field_name in field_names}

    def to_text(self) -> str:
        """Render the node as two lines of text: name with type, then summary.

        The type is the first label other than the generic "Entity"/"Node"
        labels; a placeholder is used when no such label exists.
        """
        generic_labels = ["Entity", "Node"]
        entity_type = "未知类型"
        for label in self.labels:
            if label not in generic_labels:
                entity_type = label
                break
        header = f"实体: {self.name} (类型: {entity_type})"
        body = f"摘要: {self.summary}"
        return header + "\n" + body
@dataclass
class EdgeInfo:
    """Plain-data view of a single graph edge (a relation between two nodes)."""
    uuid: str
    name: str
    fact: str
    source_node_uuid: str
    target_node_uuid: str
    source_node_name: Optional[str] = None
    target_node_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary keyed by field name."""
        field_names = (
            "uuid",
            "name",
            "fact",
            "source_node_uuid",
            "target_node_uuid",
            "source_node_name",
            "target_node_name",
        )
        return {field_name: getattr(self, field_name) for field_name in field_names}

    def to_text(self) -> str:
        """Render the edge as two lines of text: the relation, then its fact.

        Each endpoint is shown by node name when one is set, otherwise by
        the first eight characters of its UUID.
        """
        def endpoint_label(node_name: Optional[str], node_uuid: str) -> str:
            # A missing or empty name falls back to a shortened UUID.
            if node_name:
                return node_name
            return node_uuid[:8]

        origin = endpoint_label(self.source_node_name, self.source_node_uuid)
        destination = endpoint_label(self.target_node_name, self.target_node_uuid)
        relation_line = f"关系: {origin} --[{self.name}]--> {destination}"
        fact_line = f"事实: {self.fact}"
        return relation_line + "\n" + fact_line
class ZepToolsService:
    """
    Zep retrieval tool service.

    Bundles the graph retrieval tools that the Report Agent can call:
    1. search_graph - semantic search over the graph
    2. get_all_nodes - fetch every node of a graph
    3. get_all_edges - fetch every edge of a graph
    4. get_node_detail - fetch one node by UUID
    5. get_node_edges - fetch the edges attached to one node
    6. get_entities_by_type - fetch nodes carrying a given label
    7. get_entity_summary - relation summary for a named entity

    It also offers get_graph_statistics and get_simulation_context.
    """

    # Retry configuration: number of attempts and the initial delay in
    # seconds (the delay doubles after every failed attempt).
    MAX_RETRIES = 3
    RETRY_DELAY = 2.0

    def __init__(self, api_key: Optional[str] = None):
        """
        Create the Zep client.

        Args:
            api_key: Zep API key; falls back to Config.ZEP_API_KEY when omitted.

        Raises:
            ValueError: if no API key is available from either source.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")
        self.client = Zep(api_key=self.api_key)
        logger.info("ZepToolsService 初始化完成")

    def _call_with_retry(self, func, operation_name: str, max_retries: Optional[int] = None):
        """
        Call ``func`` and retry with exponential backoff when it raises.

        Args:
            func: zero-argument callable performing the API call.
            operation_name: label used in the log messages.
            max_retries: total number of attempts; None or 0 selects
                MAX_RETRIES. Values below 1 are raised to a single attempt.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            Exception: the exception of the final failed attempt.
        """
        # Always make at least one attempt so a failure surfaces as the real
        # exception rather than an empty loop.
        max_retries = max(1, max_retries or self.MAX_RETRIES)
        delay = self.RETRY_DELAY

        for attempt in range(1, max_retries + 1):
            try:
                return func()
            except Exception as e:
                if attempt >= max_retries:
                    logger.error(
                        f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}"
                    )
                    raise
                logger.warning(
                    f"Zep {operation_name} 第 {attempt} 次尝试失败: {str(e)[:100]}, "
                    f"{delay:.1f}秒后重试..."
                )
                time.sleep(delay)
                delay *= 2

    @staticmethod
    def _extract_uuid(obj) -> str:
        """Return the object's ``uuid_`` attribute, else ``uuid``, else ''."""
        return getattr(obj, 'uuid_', None) or getattr(obj, 'uuid', '')

    @classmethod
    def _to_node_info(cls, node) -> NodeInfo:
        """Convert an SDK node object into a NodeInfo, replacing None fields with empty values."""
        return NodeInfo(
            uuid=cls._extract_uuid(node),
            name=node.name or "",
            labels=node.labels or [],
            summary=node.summary or "",
            attributes=node.attributes or {}
        )

    @classmethod
    def _to_edge_info(cls, edge) -> EdgeInfo:
        """Convert an SDK edge object into an EdgeInfo, replacing None fields with empty strings."""
        return EdgeInfo(
            uuid=cls._extract_uuid(edge),
            name=edge.name or "",
            fact=edge.fact or "",
            source_node_uuid=edge.source_node_uuid or "",
            target_node_uuid=edge.target_node_uuid or ""
        )

    def search_graph(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges"
    ) -> SearchResult:
        """
        Search the graph for information related to ``query``.

        Calls the Zep graph search API (with the "cross_encoder" reranker).
        If that call fails after retries, falls back to local keyword
        matching over all edges/nodes.

        Args:
            graph_id: graph ID.
            query: search query.
            limit: maximum number of results.
            scope: search scope, passed through to the API and the fallback.

        Returns:
            SearchResult: matched facts plus the edge/node dictionaries.
        """
        logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...")

        try:
            search_results = self._call_with_retry(
                func=lambda: self.client.graph.search(
                    graph_id=graph_id,
                    query=query,
                    limit=limit,
                    scope=scope,
                    reranker="cross_encoder"
                ),
                operation_name=f"图谱搜索(graph={graph_id})"
            )

            facts = []
            edges = []
            nodes = []

            # Edge hits: every non-empty fact is collected.
            for edge in getattr(search_results, 'edges', None) or []:
                if getattr(edge, 'fact', None):
                    facts.append(edge.fact)
                edges.append({
                    "uuid": self._extract_uuid(edge),
                    "name": getattr(edge, 'name', ''),
                    "fact": getattr(edge, 'fact', ''),
                    "source_node_uuid": getattr(edge, 'source_node_uuid', ''),
                    "target_node_uuid": getattr(edge, 'target_node_uuid', ''),
                })

            # Node hits: a non-empty summary also counts as a fact.
            for node in getattr(search_results, 'nodes', None) or []:
                nodes.append({
                    "uuid": self._extract_uuid(node),
                    "name": getattr(node, 'name', ''),
                    "labels": getattr(node, 'labels', []),
                    "summary": getattr(node, 'summary', ''),
                })
                if getattr(node, 'summary', None):
                    facts.append(f"[{node.name}]: {node.summary}")

            logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实")
            return SearchResult(
                facts=facts,
                edges=edges,
                nodes=nodes,
                query=query,
                total_count=len(facts)
            )
        except Exception as e:
            logger.warning(f"Zep Search API失败降级为本地搜索: {str(e)}")
            # Fallback: local keyword matching.
            return self._local_search(graph_id, query, limit, scope)

    def _local_search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges"
    ) -> SearchResult:
        """
        Local keyword-matching search, used when the Zep search API fails.

        Fetches all edges and/or nodes and scores each one against the
        query. Any error is logged and whatever was collected so far is
        returned (best effort).

        Args:
            graph_id: graph ID.
            query: search query.
            limit: maximum number of edges and of nodes to return.
            scope: "edges", "nodes" or "both".

        Returns:
            SearchResult: matched facts plus the edge/node dictionaries.
        """
        logger.info(f"使用本地搜索: query={query[:30]}...")

        facts = []
        edges_result = []
        nodes_result = []

        # Simple tokenization: treat ASCII and fullwidth commas as
        # separators, split on whitespace, and drop one-character tokens.
        query_lower = query.lower()
        normalized = query_lower.replace(',', ' ').replace('，', ' ')
        keywords = [word for word in normalized.split() if len(word) > 1]

        def match_score(text: str) -> int:
            """Score text: 100 if it contains the whole query, else 10 per keyword found."""
            if not text:
                return 0
            text_lower = text.lower()
            if query_lower in text_lower:
                return 100
            return sum(10 for keyword in keywords if keyword in text_lower)

        try:
            if scope in ["edges", "both"]:
                scored_edges = []
                for edge in self.get_all_edges(graph_id):
                    score = match_score(edge.fact) + match_score(edge.name)
                    if score > 0:
                        scored_edges.append((score, edge))
                # Stable sort: equal scores keep their fetch order.
                scored_edges.sort(key=lambda pair: pair[0], reverse=True)

                for _, edge in scored_edges[:limit]:
                    if edge.fact:
                        facts.append(edge.fact)
                    edges_result.append({
                        "uuid": edge.uuid,
                        "name": edge.name,
                        "fact": edge.fact,
                        "source_node_uuid": edge.source_node_uuid,
                        "target_node_uuid": edge.target_node_uuid,
                    })

            if scope in ["nodes", "both"]:
                scored_nodes = []
                for node in self.get_all_nodes(graph_id):
                    score = match_score(node.name) + match_score(node.summary)
                    if score > 0:
                        scored_nodes.append((score, node))
                scored_nodes.sort(key=lambda pair: pair[0], reverse=True)

                for _, node in scored_nodes[:limit]:
                    nodes_result.append({
                        "uuid": node.uuid,
                        "name": node.name,
                        "labels": node.labels,
                        "summary": node.summary,
                    })
                    if node.summary:
                        facts.append(f"[{node.name}]: {node.summary}")

            logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实")
        except Exception as e:
            # Deliberate best effort: keep whatever was gathered before the failure.
            logger.error(f"本地搜索失败: {str(e)}")

        return SearchResult(
            facts=facts,
            edges=edges_result,
            nodes=nodes_result,
            query=query,
            total_count=len(facts)
        )

    def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
        """
        Fetch the nodes of a graph.

        Args:
            graph_id: graph ID.

        Returns:
            List of NodeInfo, one per node returned by the API.

        Raises:
            Exception: if the API call still fails after all retries.
        """
        logger.info(f"获取图谱 {graph_id} 的所有节点...")
        nodes = self._call_with_retry(
            func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
            operation_name=f"获取节点(graph={graph_id})"
        )
        # Treat a None response as "no nodes".
        result = [self._to_node_info(node) for node in (nodes or [])]
        logger.info(f"获取到 {len(result)} 个节点")
        return result

    def get_all_edges(self, graph_id: str) -> List[EdgeInfo]:
        """
        Fetch the edges of a graph.

        Args:
            graph_id: graph ID.

        Returns:
            List of EdgeInfo, one per edge returned by the API.

        Raises:
            Exception: if the API call still fails after all retries.
        """
        logger.info(f"获取图谱 {graph_id} 的所有边...")
        edges = self._call_with_retry(
            func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
            operation_name=f"获取边(graph={graph_id})"
        )
        # Treat a None response as "no edges".
        result = [self._to_edge_info(edge) for edge in (edges or [])]
        logger.info(f"获取到 {len(result)} 条边")
        return result

    def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]:
        """
        Fetch a single node by UUID.

        Args:
            node_uuid: node UUID.

        Returns:
            The NodeInfo, or None when the node is missing or the call fails
            (failures are logged, not raised).
        """
        logger.info(f"获取节点详情: {node_uuid[:8]}...")
        try:
            node = self._call_with_retry(
                func=lambda: self.client.graph.node.get(uuid_=node_uuid),
                operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)"
            )
            if not node:
                return None
            return self._to_node_info(node)
        except Exception as e:
            logger.error(f"获取节点详情失败: {str(e)}")
            return None

    def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]:
        """
        Fetch every edge that touches a node.

        Fetches all edges of the graph and keeps those whose source or
        target is the given node.

        Args:
            graph_id: graph ID.
            node_uuid: node UUID.

        Returns:
            List of matching EdgeInfo; an empty list when the fetch fails
            (failures are logged, not raised).
        """
        logger.info(f"获取节点 {node_uuid[:8]}... 的相关边")
        try:
            result = [
                edge for edge in self.get_all_edges(graph_id)
                if node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
            ]
            logger.info(f"找到 {len(result)} 条与节点相关的边")
            return result
        except Exception as e:
            logger.warning(f"获取节点边失败: {str(e)}")
            return []

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str
    ) -> List[NodeInfo]:
        """
        Fetch the nodes that carry a given label.

        Args:
            graph_id: graph ID.
            entity_type: entity type label (e.g. Student, PublicFigure).

        Returns:
            List of NodeInfo whose labels contain ``entity_type``.
        """
        logger.info(f"获取类型为 {entity_type} 的实体...")
        filtered = [
            node for node in self.get_all_nodes(graph_id)
            if entity_type in node.labels
        ]
        logger.info(f"找到 {len(filtered)} 个 {entity_type} 类型的实体")
        return filtered

    def get_entity_summary(
        self,
        graph_id: str,
        entity_name: str
    ) -> Dict[str, Any]:
        """
        Build a relation summary for a named entity.

        Searches the graph for the name, then looks for a node whose name
        matches case-insensitively and collects the edges attached to it.

        Args:
            graph_id: graph ID.
            entity_name: entity name.

        Returns:
            Dict with keys entity_name, entity_info (None when no node
            matches), related_facts, related_edges and total_relations.
        """
        logger.info(f"获取实体 {entity_name} 的关系摘要...")

        search_result = self.search_graph(
            graph_id=graph_id,
            query=entity_name,
            limit=20
        )

        # First node whose name equals the requested name, ignoring case.
        wanted = entity_name.lower()
        entity_node = next(
            (node for node in self.get_all_nodes(graph_id) if node.name.lower() == wanted),
            None
        )

        related_edges = []
        if entity_node:
            related_edges = self.get_node_edges(graph_id, entity_node.uuid)

        return {
            "entity_name": entity_name,
            "entity_info": entity_node.to_dict() if entity_node else None,
            "related_facts": search_result.facts,
            "related_edges": [e.to_dict() for e in related_edges],
            "total_relations": len(related_edges)
        }

    @staticmethod
    def _build_statistics(
        graph_id: str,
        nodes: List[NodeInfo],
        edges: List[EdgeInfo]
    ) -> Dict[str, Any]:
        """Count nodes, edges, custom labels and relation names for already-fetched data."""
        # Entity type distribution; the generic "Entity"/"Node" labels are skipped.
        entity_types: Dict[str, int] = {}
        for node in nodes:
            for label in node.labels:
                if label not in ["Entity", "Node"]:
                    entity_types[label] = entity_types.get(label, 0) + 1

        # Relation type distribution, keyed by edge name.
        relation_types: Dict[str, int] = {}
        for edge in edges:
            relation_types[edge.name] = relation_types.get(edge.name, 0) + 1

        return {
            "graph_id": graph_id,
            "total_nodes": len(nodes),
            "total_edges": len(edges),
            "entity_types": entity_types,
            "relation_types": relation_types
        }

    def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]:
        """
        Compute statistics for a graph.

        Args:
            graph_id: graph ID.

        Returns:
            Dict with keys graph_id, total_nodes, total_edges, entity_types
            and relation_types (the last two map a name to its count).
        """
        logger.info(f"获取图谱 {graph_id} 的统计信息...")
        nodes = self.get_all_nodes(graph_id)
        edges = self.get_all_edges(graph_id)
        return self._build_statistics(graph_id, nodes, edges)

    def get_simulation_context(
        self,
        graph_id: str,
        simulation_requirement: str,
        limit: int = 30
    ) -> Dict[str, Any]:
        """
        Gather context for a simulation requirement.

        Combines a graph search for the requirement text, the graph
        statistics, and the list of typed entities.

        Args:
            graph_id: graph ID.
            simulation_requirement: description of the simulation requirement.
            limit: cap applied to the search and to the returned entities.

        Returns:
            Dict with keys simulation_requirement, related_facts,
            graph_statistics, entities (at most ``limit``) and
            total_entities (count before the cap).
        """
        logger.info(f"获取模拟上下文: {simulation_requirement[:50]}...")

        search_result = self.search_graph(
            graph_id=graph_id,
            query=simulation_requirement,
            limit=limit
        )

        # Fetch nodes and edges once and reuse them for both the statistics
        # and the entity list, instead of fetching the nodes a second time.
        all_nodes = self.get_all_nodes(graph_id)
        all_edges = self.get_all_edges(graph_id)
        stats = self._build_statistics(graph_id, all_nodes, all_edges)

        # Keep only nodes that have a label besides the generic "Entity"/"Node".
        entities = []
        for node in all_nodes:
            custom_labels = [label for label in node.labels if label not in ["Entity", "Node"]]
            if custom_labels:
                entities.append({
                    "name": node.name,
                    "type": custom_labels[0],
                    "summary": node.summary
                })

        return {
            "simulation_requirement": simulation_requirement,
            "related_facts": search_result.facts,
            "graph_statistics": stats,
            "entities": entities[:limit],  # cap the number returned
            "total_entities": len(entities)
        }