MiroFish/backend/app/services/zep_entity_reader.py
666ghj 5f159f6d88 Enhance backend functionality with OASIS simulation features
- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings.
- Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms.
- Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management.
- Implemented a robust retry mechanism for external API calls to improve system stability.
- Enhanced task management model to include detailed progress tracking.
- Added logging capabilities for action tracking during simulations.
- Included new scripts for running parallel simulations and testing profile formats.
2025-12-01 15:03:44 +08:00

386 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Zep entity reading and filtering service.

Reads nodes from a Zep graph and keeps those that match the predefined entity types.
"""
from typing import Dict, Any, List, Optional, Set
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_entity_reader')
@dataclass
class EntityNode:
    """Entity node record: a graph node plus its related edges and nodes."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Information about the edges related to this node
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Information about the other nodes related to this node
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field of the node as a plain dictionary."""
        exported_fields = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {key: getattr(self, key) for key in exported_fields}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label other than "Entity"/"Node", or None if there is none."""
        custom_labels = (
            label for label in self.labels if label not in ("Entity", "Node")
        )
        return next(custom_labels, None)
@dataclass
class FilteredEntities:
    """Result of an entity filtering pass over a graph."""

    # Entities that passed the filter
    entities: List["EntityNode"]
    # Distinct entity types found among the kept entities
    entity_types: Set[str]
    # Number of nodes examined before filtering
    total_count: int
    # Number of entities kept after filtering
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the result as a plain dictionary.

        Returns:
            A dict with the serialised entities, the entity types as a list,
            and the two counters.
        """
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            # A set has no stable iteration order; sort so the serialised
            # output is identical from one run to the next.
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
class ZepEntityReader:
    """
    Zep entity reading and filtering service.

    Main features:
    1. Read every node of a Zep graph.
    2. Keep the nodes that match a predefined entity type, i.e. nodes whose
       labels contain something other than the default "Entity"/"Node".
    3. Attach the related edges and neighbouring nodes to each kept entity.
    """

    # Labels that do not identify a custom entity type.
    _DEFAULT_LABELS = ("Entity", "Node")

    def __init__(self, api_key: Optional[str] = None):
        """
        Create the Zep client.

        Args:
            api_key: Zep API key; falls back to Config.ZEP_API_KEY when omitted.

        Raises:
            ValueError: if no API key is available from either source.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")
        self.client = Zep(api_key=self.api_key)

    @staticmethod
    def _object_uuid(obj: Any) -> Any:
        """Return the object's `uuid_` attribute, falling back to `uuid`, then ''."""
        return getattr(obj, 'uuid_', None) or getattr(obj, 'uuid', '')

    @classmethod
    def _node_to_dict(cls, node: Any) -> Dict[str, Any]:
        """Convert an SDK node object into a plain dict with empty defaults."""
        return {
            "uuid": cls._object_uuid(node),
            "name": node.name or "",
            "labels": node.labels or [],
            "summary": node.summary or "",
            "attributes": node.attributes or {},
        }

    @classmethod
    def _edge_to_dict(cls, edge: Any) -> Dict[str, Any]:
        """Convert an SDK edge object into a plain dict with empty defaults."""
        return {
            "uuid": cls._object_uuid(edge),
            "name": edge.name or "",
            "fact": edge.fact or "",
            "source_node_uuid": edge.source_node_uuid,
            "target_node_uuid": edge.target_node_uuid,
            "attributes": edge.attributes or {},
        }

    @staticmethod
    def _index_edges_by_node(
        edges: List[Dict[str, Any]]
    ) -> Dict[Any, List[Dict[str, Any]]]:
        """
        Group edge dicts by the UUID of each node they touch.

        Every edge is listed under its source and, unless it is a self-loop,
        under its target. The original edge order is kept inside each list.
        """
        index: Dict[Any, List[Dict[str, Any]]] = {}
        for edge in edges:
            source = edge["source_node_uuid"]
            target = edge["target_node_uuid"]
            index.setdefault(source, []).append(edge)
            if target != source:
                index.setdefault(target, []).append(edge)
        return index

    @staticmethod
    def _build_context(
        entity_uuid: Any,
        edges: List[Dict[str, Any]],
        node_map: Dict[Any, Dict[str, Any]],
    ) -> tuple:
        """
        Build the related-edge and related-node lists for one entity.

        Args:
            entity_uuid: UUID of the entity the edges belong to.
            edges: Edge dicts touching the entity. An edge whose source is the
                entity is reported as "outgoing"; any other edge is reported
                as "incoming".
            node_map: Node UUID -> node dict, used to describe neighbours.

        Returns:
            A `(related_edges, related_nodes)` pair. Neighbours missing from
            `node_map` are left out of `related_nodes`.
        """
        related_edges = []
        # A dict is used as an insertion-ordered set so neighbours come out
        # in first-seen order instead of arbitrary set order.
        neighbour_uuids: Dict[Any, None] = {}

        for edge in edges:
            if edge["source_node_uuid"] == entity_uuid:
                direction, other_key = "outgoing", "target_node_uuid"
            else:
                direction, other_key = "incoming", "source_node_uuid"
            related_edges.append({
                "direction": direction,
                "edge_name": edge["name"],
                "fact": edge["fact"],
                other_key: edge[other_key],
            })
            neighbour_uuids[edge[other_key]] = None

        related_nodes = []
        for neighbour_uuid in neighbour_uuids:
            if neighbour_uuid not in node_map:
                continue
            neighbour = node_map[neighbour_uuid]
            related_nodes.append({
                "uuid": neighbour["uuid"],
                "name": neighbour["name"],
                "labels": neighbour["labels"],
                "summary": neighbour.get("summary", ""),
            })
        return related_edges, related_nodes

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Get all nodes of a graph.

        Args:
            graph_id: Graph ID.

        Returns:
            List of node dicts (uuid, name, labels, summary, attributes).
        """
        logger.info(f"获取图谱 {graph_id} 的所有节点...")
        # NOTE(review): a single request is issued; if the Zep API paginates
        # this endpoint only the first page is returned - confirm against the SDK.
        nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
        nodes_data = [self._node_to_dict(node) for node in nodes or []]
        logger.info(f"共获取 {len(nodes_data)} 个节点")
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Get all edges of a graph.

        Args:
            graph_id: Graph ID.

        Returns:
            List of edge dicts (uuid, name, fact, source/target UUIDs, attributes).
        """
        logger.info(f"获取图谱 {graph_id} 的所有边...")
        # NOTE(review): same single-request caveat as get_all_nodes - confirm
        # whether this endpoint is paginated.
        edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
        edges_data = [self._edge_to_dict(edge) for edge in edges or []]
        logger.info(f"共获取 {len(edges_data)} 条边")
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """
        Get all edges related to one node.

        This is best-effort: any failure is logged as a warning and an empty
        list is returned instead of raising.

        Args:
            node_uuid: Node UUID.

        Returns:
            List of edge dicts, or an empty list on failure.
        """
        try:
            edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
            return [self._edge_to_dict(edge) for edge in edges or []]
        except Exception as e:
            logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> "FilteredEntities":
        """
        Keep the nodes that match a predefined entity type.

        Filtering rules:
        - A node whose labels are only "Entity"/"Node" is skipped.
        - A node with any other label is kept; when `defined_entity_types`
          is given, at least one of its custom labels must be in that list.

        Args:
            graph_id: Graph ID.
            defined_entity_types: Optional list of entity types to keep.
            enrich_with_edges: Whether to attach related edges and nodes
                to every kept entity.

        Returns:
            FilteredEntities: the kept entities, the entity types found and
            the total/kept counts.
        """
        logger.info(f"开始筛选图谱 {graph_id} 的实体...")

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        # Edges are fetched only when enrichment was requested.
        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
        # Index the edges once so each entity reads only its own edges
        # instead of rescanning the whole edge list.
        edges_by_node = self._index_edges_by_node(all_edges)
        node_map = {n["uuid"]: n for n in all_nodes}

        filtered_entities = []
        entity_types_found = set()

        for node in all_nodes:
            labels = node.get("labels", [])
            custom_labels = [
                label for label in labels if label not in self._DEFAULT_LABELS
            ]
            if not custom_labels:
                # Only default labels: not a predefined entity type.
                continue

            if defined_entity_types:
                matching_labels = [
                    label for label in custom_labels
                    if label in defined_entity_types
                ]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]
            entity_types_found.add(entity_type)

            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )

            if enrich_with_edges:
                related_edges, related_nodes = self._build_context(
                    node["uuid"],
                    edges_by_node.get(node["uuid"], []),
                    node_map,
                )
                entity.related_edges = related_edges
                entity.related_nodes = related_nodes

            filtered_entities.append(entity)

        logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                    f"实体类型: {entity_types_found}")

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional["EntityNode"]:
        """
        Get one entity together with its edges and neighbouring nodes.

        Any exception raised while fetching or assembling the entity is
        logged as an error and turned into a None result.

        Args:
            graph_id: Graph ID, used to load the nodes for neighbour lookup.
            entity_uuid: Entity UUID.

        Returns:
            The EntityNode, or None when the node is missing or an error occurs.
        """
        try:
            node = self.client.graph.node.get(uuid_=entity_uuid)
            if not node:
                return None

            edges = self.get_node_edges(entity_uuid)

            # All graph nodes are loaded so neighbours can be described.
            all_nodes = self.get_all_nodes(graph_id)
            node_map = {n["uuid"]: n for n in all_nodes}

            related_edges, related_nodes = self._build_context(
                entity_uuid, edges, node_map
            )
            node_data = self._node_to_dict(node)

            return EntityNode(
                uuid=node_data["uuid"],
                name=node_data["name"],
                labels=node_data["labels"],
                summary=node_data["summary"],
                attributes=node_data["attributes"],
                related_edges=related_edges,
                related_nodes=related_nodes,
            )
        except Exception as e:
            logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List["EntityNode"]:
        """
        Get all entities of one type.

        Args:
            graph_id: Graph ID.
            entity_type: Entity type (e.g. "Student", "PublicFigure").
            enrich_with_edges: Whether to attach related edges and nodes.

        Returns:
            List of matching entities.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities