- Updated the `get_graph_data` method in `graph_builder.py` to include additional attributes such as creation time, validity periods, and episodes for nodes and edges, improving the richness of the graph data. - Modified the detail panel in `Process.vue` to present new attributes, including properties, episodes, and timestamps, enhancing user interaction and data visibility. - Improved styling for the detail panel to better organize and present the comprehensive information for selected nodes and edges.
499 lines
17 KiB
Python
499 lines
17 KiB
Python
"""
|
||
图谱构建服务
|
||
接口2:使用Zep API构建Standalone Graph
|
||
"""
|
||
|
||
import os
|
||
import uuid
|
||
import time
|
||
import threading
|
||
from typing import Dict, Any, List, Optional, Callable
|
||
from dataclasses import dataclass
|
||
|
||
from zep_cloud.client import Zep
|
||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||
|
||
from ..config import Config
|
||
from ..models.task import TaskManager, TaskStatus
|
||
from .text_processor import TextProcessor
|
||
|
||
|
||
@dataclass
|
||
class GraphInfo:
|
||
"""图谱信息"""
|
||
graph_id: str
|
||
node_count: int
|
||
edge_count: int
|
||
entity_types: List[str]
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
return {
|
||
"graph_id": self.graph_id,
|
||
"node_count": self.node_count,
|
||
"edge_count": self.edge_count,
|
||
"entity_types": self.entity_types,
|
||
}
|
||
|
||
|
||
class GraphBuilderService:
|
||
"""
|
||
图谱构建服务
|
||
负责调用Zep API构建知识图谱
|
||
"""
|
||
|
||
def __init__(self, api_key: Optional[str] = None):
|
||
self.api_key = api_key or Config.ZEP_API_KEY
|
||
if not self.api_key:
|
||
raise ValueError("ZEP_API_KEY 未配置")
|
||
|
||
self.client = Zep(api_key=self.api_key)
|
||
self.task_manager = TaskManager()
|
||
|
||
def build_graph_async(
|
||
self,
|
||
text: str,
|
||
ontology: Dict[str, Any],
|
||
graph_name: str = "MiroFish Graph",
|
||
chunk_size: int = 500,
|
||
chunk_overlap: int = 50,
|
||
batch_size: int = 3
|
||
) -> str:
|
||
"""
|
||
异步构建图谱
|
||
|
||
Args:
|
||
text: 输入文本
|
||
ontology: 本体定义(来自接口1的输出)
|
||
graph_name: 图谱名称
|
||
chunk_size: 文本块大小
|
||
chunk_overlap: 块重叠大小
|
||
batch_size: 每批发送的块数量
|
||
|
||
Returns:
|
||
任务ID
|
||
"""
|
||
# 创建任务
|
||
task_id = self.task_manager.create_task(
|
||
task_type="graph_build",
|
||
metadata={
|
||
"graph_name": graph_name,
|
||
"chunk_size": chunk_size,
|
||
"text_length": len(text),
|
||
}
|
||
)
|
||
|
||
# 在后台线程中执行构建
|
||
thread = threading.Thread(
|
||
target=self._build_graph_worker,
|
||
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size)
|
||
)
|
||
thread.daemon = True
|
||
thread.start()
|
||
|
||
return task_id
|
||
|
||
def _build_graph_worker(
|
||
self,
|
||
task_id: str,
|
||
text: str,
|
||
ontology: Dict[str, Any],
|
||
graph_name: str,
|
||
chunk_size: int,
|
||
chunk_overlap: int,
|
||
batch_size: int
|
||
):
|
||
"""图谱构建工作线程"""
|
||
try:
|
||
self.task_manager.update_task(
|
||
task_id,
|
||
status=TaskStatus.PROCESSING,
|
||
progress=5,
|
||
message="开始构建图谱..."
|
||
)
|
||
|
||
# 1. 创建图谱
|
||
graph_id = self.create_graph(graph_name)
|
||
self.task_manager.update_task(
|
||
task_id,
|
||
progress=10,
|
||
message=f"图谱已创建: {graph_id}"
|
||
)
|
||
|
||
# 2. 设置本体
|
||
self.set_ontology(graph_id, ontology)
|
||
self.task_manager.update_task(
|
||
task_id,
|
||
progress=15,
|
||
message="本体已设置"
|
||
)
|
||
|
||
# 3. 文本分块
|
||
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
|
||
total_chunks = len(chunks)
|
||
self.task_manager.update_task(
|
||
task_id,
|
||
progress=20,
|
||
message=f"文本已分割为 {total_chunks} 个块"
|
||
)
|
||
|
||
# 4. 分批发送数据
|
||
episode_uuids = self.add_text_batches(
|
||
graph_id, chunks, batch_size,
|
||
lambda msg, prog: self.task_manager.update_task(
|
||
task_id,
|
||
progress=20 + int(prog * 0.4), # 20-60%
|
||
message=msg
|
||
)
|
||
)
|
||
|
||
# 5. 等待Zep处理完成
|
||
self.task_manager.update_task(
|
||
task_id,
|
||
progress=60,
|
||
message="等待Zep处理数据..."
|
||
)
|
||
|
||
self._wait_for_episodes(
|
||
episode_uuids,
|
||
lambda msg, prog: self.task_manager.update_task(
|
||
task_id,
|
||
progress=60 + int(prog * 0.3), # 60-90%
|
||
message=msg
|
||
)
|
||
)
|
||
|
||
# 6. 获取图谱信息
|
||
self.task_manager.update_task(
|
||
task_id,
|
||
progress=90,
|
||
message="获取图谱信息..."
|
||
)
|
||
|
||
graph_info = self._get_graph_info(graph_id)
|
||
|
||
# 完成
|
||
self.task_manager.complete_task(task_id, {
|
||
"graph_id": graph_id,
|
||
"graph_info": graph_info.to_dict(),
|
||
"chunks_processed": total_chunks,
|
||
})
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
error_msg = f"{str(e)}\n{traceback.format_exc()}"
|
||
self.task_manager.fail_task(task_id, error_msg)
|
||
|
||
def create_graph(self, name: str) -> str:
|
||
"""创建Zep图谱(公开方法)"""
|
||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||
|
||
self.client.graph.create(
|
||
graph_id=graph_id,
|
||
name=name,
|
||
description="MiroFish Social Simulation Graph"
|
||
)
|
||
|
||
return graph_id
|
||
|
||
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
||
"""设置图谱本体(公开方法)"""
|
||
import warnings
|
||
from typing import Optional
|
||
from pydantic import Field
|
||
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
|
||
|
||
# 抑制 Pydantic v2 关于 Field(default=None) 的警告
|
||
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略
|
||
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
|
||
|
||
# Zep 保留名称,不能作为属性名
|
||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||
|
||
def safe_attr_name(attr_name: str) -> str:
|
||
"""将保留名称转换为安全名称"""
|
||
if attr_name.lower() in RESERVED_NAMES:
|
||
return f"entity_{attr_name}"
|
||
return attr_name
|
||
|
||
# 动态创建实体类型
|
||
entity_types = {}
|
||
for entity_def in ontology.get("entity_types", []):
|
||
name = entity_def["name"]
|
||
description = entity_def.get("description", f"A {name} entity.")
|
||
|
||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
||
attrs = {"__doc__": description}
|
||
annotations = {}
|
||
|
||
for attr_def in entity_def.get("attributes", []):
|
||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||
attr_desc = attr_def.get("description", attr_name)
|
||
# Zep API 需要 Field 的 description,这是必需的
|
||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
||
|
||
attrs["__annotations__"] = annotations
|
||
|
||
# 动态创建类
|
||
entity_class = type(name, (EntityModel,), attrs)
|
||
entity_class.__doc__ = description
|
||
entity_types[name] = entity_class
|
||
|
||
# 动态创建边类型
|
||
edge_definitions = {}
|
||
for edge_def in ontology.get("edge_types", []):
|
||
name = edge_def["name"]
|
||
description = edge_def.get("description", f"A {name} relationship.")
|
||
|
||
# 创建属性字典和类型注解
|
||
attrs = {"__doc__": description}
|
||
annotations = {}
|
||
|
||
for attr_def in edge_def.get("attributes", []):
|
||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||
attr_desc = attr_def.get("description", attr_name)
|
||
# Zep API 需要 Field 的 description,这是必需的
|
||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
||
|
||
attrs["__annotations__"] = annotations
|
||
|
||
# 动态创建类
|
||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||
edge_class = type(class_name, (EdgeModel,), attrs)
|
||
edge_class.__doc__ = description
|
||
|
||
# 构建source_targets
|
||
source_targets = []
|
||
for st in edge_def.get("source_targets", []):
|
||
source_targets.append(
|
||
EntityEdgeSourceTarget(
|
||
source=st.get("source", "Entity"),
|
||
target=st.get("target", "Entity")
|
||
)
|
||
)
|
||
|
||
if source_targets:
|
||
edge_definitions[name] = (edge_class, source_targets)
|
||
|
||
# 调用Zep API设置本体
|
||
if entity_types or edge_definitions:
|
||
self.client.graph.set_ontology(
|
||
graph_ids=[graph_id],
|
||
entities=entity_types if entity_types else None,
|
||
edges=edge_definitions if edge_definitions else None,
|
||
)
|
||
|
||
def add_text_batches(
|
||
self,
|
||
graph_id: str,
|
||
chunks: List[str],
|
||
batch_size: int = 3,
|
||
progress_callback: Optional[Callable] = None
|
||
) -> List[str]:
|
||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
||
episode_uuids = []
|
||
total_chunks = len(chunks)
|
||
|
||
for i in range(0, total_chunks, batch_size):
|
||
batch_chunks = chunks[i:i + batch_size]
|
||
batch_num = i // batch_size + 1
|
||
total_batches = (total_chunks + batch_size - 1) // batch_size
|
||
|
||
if progress_callback:
|
||
progress = (i + len(batch_chunks)) / total_chunks
|
||
progress_callback(
|
||
f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...",
|
||
progress
|
||
)
|
||
|
||
# 构建episode数据
|
||
episodes = [
|
||
EpisodeData(data=chunk, type="text")
|
||
for chunk in batch_chunks
|
||
]
|
||
|
||
# 发送到Zep
|
||
try:
|
||
batch_result = self.client.graph.add_batch(
|
||
graph_id=graph_id,
|
||
episodes=episodes
|
||
)
|
||
|
||
# 收集返回的 episode uuid
|
||
if batch_result and isinstance(batch_result, list):
|
||
for ep in batch_result:
|
||
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
|
||
if ep_uuid:
|
||
episode_uuids.append(ep_uuid)
|
||
|
||
# 避免请求过快
|
||
time.sleep(1)
|
||
|
||
except Exception as e:
|
||
if progress_callback:
|
||
progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0)
|
||
raise
|
||
|
||
return episode_uuids
|
||
|
||
def _wait_for_episodes(
|
||
self,
|
||
episode_uuids: List[str],
|
||
progress_callback: Optional[Callable] = None,
|
||
timeout: int = 600
|
||
):
|
||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
||
if not episode_uuids:
|
||
if progress_callback:
|
||
progress_callback("无需等待(没有 episode)", 1.0)
|
||
return
|
||
|
||
start_time = time.time()
|
||
pending_episodes = set(episode_uuids)
|
||
completed_count = 0
|
||
total_episodes = len(episode_uuids)
|
||
|
||
if progress_callback:
|
||
progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0)
|
||
|
||
while pending_episodes:
|
||
if time.time() - start_time > timeout:
|
||
if progress_callback:
|
||
progress_callback(
|
||
f"部分文本块超时,已完成 {completed_count}/{total_episodes}",
|
||
completed_count / total_episodes
|
||
)
|
||
break
|
||
|
||
# 检查每个 episode 的处理状态
|
||
for ep_uuid in list(pending_episodes):
|
||
try:
|
||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
||
is_processed = getattr(episode, 'processed', False)
|
||
|
||
if is_processed:
|
||
pending_episodes.remove(ep_uuid)
|
||
completed_count += 1
|
||
|
||
except Exception as e:
|
||
# 忽略单个查询错误,继续
|
||
pass
|
||
|
||
elapsed = int(time.time() - start_time)
|
||
if progress_callback:
|
||
progress_callback(
|
||
f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
|
||
completed_count / total_episodes if total_episodes > 0 else 0
|
||
)
|
||
|
||
if pending_episodes:
|
||
time.sleep(3) # 每3秒检查一次
|
||
|
||
if progress_callback:
|
||
progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0)
|
||
|
||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||
"""获取图谱信息"""
|
||
# 获取节点
|
||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||
|
||
# 获取边
|
||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||
|
||
# 统计实体类型
|
||
entity_types = set()
|
||
for node in nodes:
|
||
if node.labels:
|
||
for label in node.labels:
|
||
if label not in ["Entity", "Node"]:
|
||
entity_types.add(label)
|
||
|
||
return GraphInfo(
|
||
graph_id=graph_id,
|
||
node_count=len(nodes),
|
||
edge_count=len(edges),
|
||
entity_types=list(entity_types)
|
||
)
|
||
|
||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||
"""
|
||
获取完整图谱数据(包含详细信息)
|
||
|
||
Args:
|
||
graph_id: 图谱ID
|
||
|
||
Returns:
|
||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
||
"""
|
||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||
|
||
# 创建节点映射用于获取节点名称
|
||
node_map = {}
|
||
for node in nodes:
|
||
node_map[node.uuid_] = node.name or ""
|
||
|
||
nodes_data = []
|
||
for node in nodes:
|
||
# 获取创建时间
|
||
created_at = getattr(node, 'created_at', None)
|
||
if created_at:
|
||
created_at = str(created_at)
|
||
|
||
nodes_data.append({
|
||
"uuid": node.uuid_,
|
||
"name": node.name,
|
||
"labels": node.labels or [],
|
||
"summary": node.summary or "",
|
||
"attributes": node.attributes or {},
|
||
"created_at": created_at,
|
||
})
|
||
|
||
edges_data = []
|
||
for edge in edges:
|
||
# 获取时间信息
|
||
created_at = getattr(edge, 'created_at', None)
|
||
valid_at = getattr(edge, 'valid_at', None)
|
||
invalid_at = getattr(edge, 'invalid_at', None)
|
||
expired_at = getattr(edge, 'expired_at', None)
|
||
|
||
# 获取 episodes
|
||
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
|
||
if episodes and not isinstance(episodes, list):
|
||
episodes = [str(episodes)]
|
||
elif episodes:
|
||
episodes = [str(e) for e in episodes]
|
||
|
||
# 获取 fact_type
|
||
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
|
||
|
||
edges_data.append({
|
||
"uuid": edge.uuid_,
|
||
"name": edge.name or "",
|
||
"fact": edge.fact or "",
|
||
"fact_type": fact_type,
|
||
"source_node_uuid": edge.source_node_uuid,
|
||
"target_node_uuid": edge.target_node_uuid,
|
||
"source_node_name": node_map.get(edge.source_node_uuid, ""),
|
||
"target_node_name": node_map.get(edge.target_node_uuid, ""),
|
||
"attributes": edge.attributes or {},
|
||
"created_at": str(created_at) if created_at else None,
|
||
"valid_at": str(valid_at) if valid_at else None,
|
||
"invalid_at": str(invalid_at) if invalid_at else None,
|
||
"expired_at": str(expired_at) if expired_at else None,
|
||
"episodes": episodes or [],
|
||
})
|
||
|
||
return {
|
||
"graph_id": graph_id,
|
||
"nodes": nodes_data,
|
||
"edges": edges_data,
|
||
"node_count": len(nodes_data),
|
||
"edge_count": len(edges_data),
|
||
}
|
||
|
||
def delete_graph(self, graph_id: str):
|
||
"""删除图谱"""
|
||
self.client.graph.delete(graph_id=graph_id)
|
||
|