- Created a new Streamlit application for visualizing knowledge graphs. - Implemented text extraction from PDF, Markdown, and TXT files. - Developed graph building logic using Zep Cloud API. - Added support for custom entity types and relationships. - Included interactive HTML visualization for generated graphs. - Updated .gitignore to include new directories and files. - Added example environment configuration file (.env.example) for API key setup. - Created README.md with installation and usage instructions. - Introduced various utility scripts and styles for enhanced functionality.
415 lines
13 KiB
Python
415 lines
13 KiB
Python
"""
|
||
Zep图谱构建模块
|
||
负责与Zep云服务交互,构建知识图谱
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import uuid
|
||
from typing import Optional, Callable
|
||
from dataclasses import dataclass
|
||
|
||
from zep_cloud.client import Zep
|
||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||
|
||
from ontology import ENTITY_TYPES, EDGE_TYPES
|
||
|
||
|
||
@dataclass
class GraphNode:
    """Plain record describing one node of a knowledge graph."""

    # Identifier of the node.
    uuid: str
    # Name of the node.
    name: str
    # Summary text for the node.
    summary: str
    # Labels attached to the node.
    labels: list[str]
    # Additional attributes of the node, as a dictionary.
    attributes: dict
|
||
|
||
|
||
@dataclass
class GraphEdge:
    """Plain record describing one edge (relationship) of a knowledge graph."""

    # Identifier of the edge.
    uuid: str
    # Name of the edge.
    name: str
    # Fact text carried by the edge.
    fact: str
    # Identifier of the node the edge starts from.
    source_node_uuid: str
    # Identifier of the node the edge points to.
    target_node_uuid: str
    # Additional attributes of the edge, as a dictionary.
    attributes: dict
|
||
|
||
|
||
@dataclass
class GraphData:
    """Complete contents of one graph: its id plus every node and edge."""

    # Identifier of the graph the data belongs to.
    graph_id: str
    # Nodes of the graph.
    nodes: list[GraphNode]
    # Edges of the graph.
    edges: list[GraphEdge]
|
||
|
||
|
||
class ZepGraphBuilder:
    """Builds knowledge graphs through the Zep cloud service.

    Wraps a ``Zep`` client and exposes the individual pipeline steps: create
    a graph, register the custom ontology, upload text in batches, wait for
    the resulting tasks, and read the nodes and edges back.
    """

    # Allowed (source entity type, target entity type) pairs for each edge
    # type registered by set_ontology. Insertion order is the registration
    # order.
    _EDGE_SOURCE_TARGETS = {
        # WORKS_FOR: Person -> Organization/Company
        "WORKS_FOR": (
            ("Person", "Organization"),
            ("Person", "Company"),
        ),
        # LOCATED_IN: several entity types -> Location
        "LOCATED_IN": (
            ("Person", "Location"),
            ("Organization", "Location"),
            ("Company", "Location"),
            ("Event", "Location"),
        ),
        # PART_OF: Organization -> Organization, Company -> Company
        "PART_OF": (
            ("Organization", "Organization"),
            ("Company", "Company"),
        ),
        # PRODUCES: Company/Organization -> Product
        "PRODUCES": (
            ("Company", "Product"),
            ("Organization", "Product"),
        ),
        # PARTICIPATES_IN: Person/Organization/Company -> Event
        "PARTICIPATES_IN": (
            ("Person", "Event"),
            ("Organization", "Event"),
            ("Company", "Event"),
        ),
        # COLLABORATES: cooperation between various entity types
        "COLLABORATES": (
            ("Person", "Person"),
            ("Company", "Company"),
            ("Organization", "Organization"),
            ("Company", "Organization"),
        ),
        # COMPETES: competition between companies
        "COMPETES": (
            ("Company", "Company"),
        ),
        # REPORTS: Media reporting on other entities
        "REPORTS": (
            ("Media", "Person"),
            ("Media", "Company"),
            ("Media", "Organization"),
            ("Media", "Event"),
        ),
    }

    # Pause between two consecutive batch uploads, to avoid sending requests
    # too quickly.
    _BATCH_DELAY_SECONDS = 1
    # Pause between two rounds of task-status polling.
    _POLL_INTERVAL_SECONDS = 3
    # Number of items requested per page when listing nodes or edges.
    _PAGE_SIZE = 100

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialise the graph builder.

        Args:
            api_key: Zep API key. When omitted (or empty), the value of the
                ``ZEP_API_KEY`` environment variable is used.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.environ.get("ZEP_API_KEY")
        if not self.api_key:
            raise ValueError("需要提供ZEP_API_KEY,可以通过参数传入或设置环境变量")

        self.client = Zep(api_key=self.api_key)

    def create_graph(self, graph_id: Optional[str] = None, name: str = "Knowledge Graph") -> str:
        """
        Create a new graph.

        Args:
            graph_id: Graph id. When ``None``, an id of the form
                ``graph_<16 hex characters>`` is generated.
            name: Graph name.

        Returns:
            The id of the created graph.
        """
        if graph_id is None:
            graph_id = f"graph_{uuid.uuid4().hex[:16]}"

        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="Knowledge graph generated by txt2graph"
        )

        return graph_id

    def set_ontology(self, graph_id: str):
        """
        Register the custom ontology (entity and edge types) for a graph.

        Every edge type listed in ``_EDGE_SOURCE_TARGETS`` is paired with its
        model from ``EDGE_TYPES`` and the list of allowed source/target
        entity types.

        Args:
            graph_id: Graph id.
        """
        edge_definitions = {
            edge_name: (
                EDGE_TYPES[edge_name],
                [
                    EntityEdgeSourceTarget(source=source, target=target)
                    for source, target in endpoints
                ],
            )
            for edge_name, endpoints in self._EDGE_SOURCE_TARGETS.items()
        }

        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=ENTITY_TYPES,
            edges=edge_definitions,
        )

    def add_text_to_graph(
        self,
        graph_id: str,
        text_chunks: list[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> list[str]:
        """
        Upload text chunks to the graph in batches.

        Args:
            graph_id: Graph id.
            text_chunks: Text chunks to upload.
            batch_size: Number of chunks sent per request; must be >= 1.
            progress_callback: Optional callable that receives progress
                messages (one string per call).

        Returns:
            The task ids collected from the batch responses.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Any error raised while sending a batch is reported to
                ``progress_callback`` and then re-raised unchanged.
        """
        # A negative step would make range() empty and silently upload
        # nothing; zero would fail with an unrelated range() error message.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        task_ids = []
        total_chunks = len(text_chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for batch_num, offset in enumerate(range(0, total_chunks, batch_size), start=1):
            batch_chunks = text_chunks[offset:offset + batch_size]

            if progress_callback:
                progress_callback(f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...")

            # One text episode per chunk.
            episodes = [
                EpisodeData(data=chunk, type="text")
                for chunk in batch_chunks
            ]

            try:
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes
                )

                # Only the first result's task id is recorded.
                # NOTE(review): presumably all episodes of one batch share a
                # single task - confirm against the Zep API.
                if batch_result and batch_result[0].task_id:
                    task_ids.append(batch_result[0].task_id)
            except Exception as e:
                if progress_callback:
                    progress_callback(f"批次 {batch_num} 发送失败: {str(e)}")
                raise

            # Throttle between requests; nothing follows the last batch, so
            # no pause is needed there.
            if batch_num < total_batches:
                time.sleep(self._BATCH_DELAY_SECONDS)

        return task_ids

    def wait_for_tasks(
        self,
        task_ids: list[str],
        timeout: int = 600,
        progress_callback: Optional[Callable] = None
    ):
        """
        Poll the given tasks until each one has completed or failed.

        Errors raised while checking a task are reported to
        ``progress_callback`` and the task is checked again in the next
        round. When the timeout expires, a warning is reported and the method
        returns without the final "all tasks processed" message.

        Args:
            task_ids: Task ids to wait for; duplicates are counted once.
            timeout: Maximum waiting time in seconds.
            progress_callback: Optional callable that receives progress
                messages (one string per call).
        """
        if not task_ids:
            return

        # Monotonic clock: elapsed time is unaffected by system clock changes.
        start_time = time.monotonic()
        pending_tasks = set(task_ids)
        total_tasks = len(pending_tasks)
        completed_tasks = set()

        while pending_tasks:
            if time.monotonic() - start_time > timeout:
                if progress_callback:
                    progress_callback(f"警告: 部分任务超时,已完成 {len(completed_tasks)}/{total_tasks}")
                return

            # Iterate over a snapshot because finished tasks are removed
            # from the set inside the loop.
            for task_id in list(pending_tasks):
                try:
                    task = self.client.task.get(task_id=task_id)

                    if task.status == "completed":
                        pending_tasks.remove(task_id)
                        completed_tasks.add(task_id)
                    elif task.status == "failed":
                        pending_tasks.remove(task_id)
                        if progress_callback:
                            progress_callback(f"任务失败: {task.error}")
                except Exception as e:
                    # Best effort: keep the task pending and retry next round.
                    if progress_callback:
                        progress_callback(f"检查任务状态出错: {str(e)}")

            if pending_tasks:
                if progress_callback:
                    elapsed = int(time.monotonic() - start_time)
                    progress_callback(f"等待处理中... 已完成 {len(completed_tasks)}/{total_tasks} ({elapsed}秒)")
                time.sleep(self._POLL_INTERVAL_SECONDS)

        if progress_callback:
            progress_callback(f"所有任务处理完成: {len(completed_tasks)}/{total_tasks}")

    @staticmethod
    def _fetch_all_pages(fetch_page, graph_id, page_size=_PAGE_SIZE):
        """Collect every item of a cursor-paginated listing call.

        ``fetch_page`` is called with ``graph_id`` and ``limit`` keyword
        arguments, plus ``uuid_cursor`` (the ``uuid_`` of the last item of
        the previous page) on every call after the first. Fetching stops
        when a page is empty or contains no item that has not been seen yet,
        so a listing that ignores the cursor cannot loop forever.
        """
        items = []
        seen_uuids = set()
        cursor = None

        while True:
            kwargs = {"graph_id": graph_id, "limit": page_size}
            if cursor is not None:
                kwargs["uuid_cursor"] = cursor

            page = fetch_page(**kwargs) or []
            fresh = [item for item in page if item.uuid_ not in seen_uuids]
            if not fresh:
                return items

            seen_uuids.update(item.uuid_ for item in fresh)
            items.extend(fresh)

            cursor = page[-1].uuid_
            if cursor is None:
                # Without a usable cursor the next page cannot be requested.
                return items

    def get_graph_data(self, graph_id: str) -> "GraphData":
        """
        Fetch the complete data of a graph.

        Nodes and edges are read page by page so that graphs larger than a
        single page are returned in full. Missing summaries, names and facts
        become empty strings; missing labels and attributes become empty
        containers.

        NOTE(review): relies on the ``limit`` / ``uuid_cursor`` keyword
        arguments of the SDK's ``get_by_graph_id`` calls - confirm against
        the installed zep-cloud version.

        Args:
            graph_id: Graph id.

        Returns:
            A GraphData object holding all nodes and edges.
        """
        raw_nodes = self._fetch_all_pages(
            self.client.graph.node.get_by_graph_id, graph_id
        )
        nodes = [
            GraphNode(
                uuid=node.uuid_,
                name=node.name,
                summary=node.summary or "",
                labels=node.labels or [],
                attributes=node.attributes or {}
            )
            for node in raw_nodes
        ]

        raw_edges = self._fetch_all_pages(
            self.client.graph.edge.get_by_graph_id, graph_id
        )
        edges = [
            GraphEdge(
                uuid=edge.uuid_,
                name=edge.name or "",
                fact=edge.fact or "",
                source_node_uuid=edge.source_node_uuid,
                target_node_uuid=edge.target_node_uuid,
                attributes=edge.attributes or {}
            )
            for edge in raw_edges
        ]

        return GraphData(
            graph_id=graph_id,
            nodes=nodes,
            edges=edges
        )

    def delete_graph(self, graph_id: str):
        """Delete the graph with the given id."""
        self.client.graph.delete(graph_id=graph_id)
|
||
|
||
|
||
def build_graph_from_text(
    text: str,
    graph_name: str = "Knowledge Graph",
    api_key: Optional[str] = None,
    chunk_size: int = 2000,
    progress_callback: Optional[Callable] = None
) -> GraphData:
    """
    Convenience function: build a knowledge graph from a piece of text.

    Runs the whole pipeline in order: create the graph, set its ontology,
    split the text into chunks, upload the chunks in batches of three, wait
    for the returned tasks (if any) and read the graph data back.

    Args:
        text: Input text.
        graph_name: Name given to the new graph.
        api_key: Zep API key, passed through to ZepGraphBuilder.
        chunk_size: Maximum chunk size handed to the text splitter
            (default 2000).
        progress_callback: Optional callable that receives progress messages.

    Returns:
        A GraphData object for the newly built graph.
    """
    from text_extractor import split_text_into_chunks

    def report(message: str) -> None:
        # Forward a progress message only when a callback was supplied.
        if progress_callback:
            progress_callback(message)

    graph_builder = ZepGraphBuilder(api_key=api_key)

    # Step 1: create the graph.
    new_graph_id = graph_builder.create_graph(name=graph_name)
    report(f"创建图谱: {new_graph_id}")

    # Step 2: register entity and edge types.
    graph_builder.set_ontology(new_graph_id)
    report("设置实体类型...")

    # Step 3: split the text into chunks.
    pieces = split_text_into_chunks(text, max_chunk_size=chunk_size)
    report(f"文本分为 {len(pieces)} 个块")

    # Step 4: upload the chunks batch by batch.
    submitted_task_ids = graph_builder.add_text_to_graph(
        graph_id=new_graph_id,
        text_chunks=pieces,
        batch_size=3,
        progress_callback=progress_callback
    )

    # Step 5: wait until the submitted tasks have finished.
    if submitted_task_ids:
        graph_builder.wait_for_tasks(submitted_task_ids, progress_callback=progress_callback)

    # Step 6: fetch and return the graph data.
    return graph_builder.get_graph_data(new_graph_id)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 测试
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
test_text = """
|
||
武汉大学是中国著名的高等学府,位于湖北省武汉市。
|
||
该校的樱花季每年吸引大量游客。
|
||
马化腾是腾讯公司的创始人,腾讯总部位于深圳。
|
||
"""
|
||
|
||
result = build_graph_from_text(
|
||
text=test_text,
|
||
graph_name="测试图谱",
|
||
progress_callback=print
|
||
)
|
||
|
||
print(f"\n节点数: {len(result.nodes)}")
|
||
for node in result.nodes:
|
||
print(f" - {node.name} ({node.labels})")
|
||
|
||
print(f"\n边数: {len(result.edges)}")
|
||
for edge in result.edges:
|
||
print(f" - {edge.fact}")
|