MiroFish/txt2graph/graph_builder.py
666ghj 9657061b26 Add initial implementation of txt2graph tool for knowledge graph generation
- Created a new Streamlit application for visualizing knowledge graphs.
- Implemented text extraction from PDF, Markdown, and TXT files.
- Developed graph building logic using Zep Cloud API.
- Added support for custom entity types and relationships.
- Included interactive HTML visualization for generated graphs.
- Updated .gitignore to include new directories and files.
- Added example environment configuration file (.env.example) for API key setup.
- Created README.md with installation and usage instructions.
- Introduced various utility scripts and styles for enhanced functionality.
2025-11-28 14:07:42 +08:00

415 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Zep图谱构建模块
负责与Zep云服务交互构建知识图谱
"""
import os
import time
import uuid
from typing import Optional, Callable
from dataclasses import dataclass
from zep_cloud.client import Zep
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from ontology import ENTITY_TYPES, EDGE_TYPES
@dataclass
class GraphNode:
    """A single entity node read back from a Zep graph.

    Mirrors the fields extracted in ZepGraphBuilder.get_graph_data.
    """
    uuid: str          # Zep-assigned unique node identifier
    name: str          # entity display name
    summary: str       # node summary text ("" when Zep returns none)
    labels: list[str]  # entity-type labels ([] when Zep returns none)
    attributes: dict   # extra typed attributes ({} when Zep returns none)
@dataclass
class GraphEdge:
    """A single relationship edge read back from a Zep graph.

    Mirrors the fields extracted in ZepGraphBuilder.get_graph_data.
    """
    uuid: str              # Zep-assigned unique edge identifier
    name: str              # edge-type name ("" when Zep returns none)
    fact: str              # natural-language fact ("" when Zep returns none)
    source_node_uuid: str  # uuid of the edge's source node
    target_node_uuid: str  # uuid of the edge's target node
    attributes: dict       # extra typed attributes ({} when Zep returns none)
@dataclass
class GraphData:
    """Complete snapshot of one graph: its id plus all nodes and edges."""
    graph_id: str           # id of the graph on Zep
    nodes: list[GraphNode]  # every node in the graph
    edges: list[GraphEdge]  # every edge in the graph
class ZepGraphBuilder:
    """Knowledge-graph builder backed by the Zep cloud service.

    Wraps a ``zep_cloud`` client and exposes the full pipeline used by
    this tool: create a graph, install the custom ontology, ingest text
    chunks in batches, poll the ingestion tasks, and read the resulting
    nodes/edges back as plain dataclasses.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the builder and the underlying Zep client.

        Args:
            api_key: Zep API key; when omitted, the ZEP_API_KEY
                environment variable is used instead.

        Raises:
            ValueError: if no API key is available from either source.
        """
        self.api_key = api_key or os.environ.get("ZEP_API_KEY")
        if not self.api_key:
            raise ValueError("需要提供ZEP_API_KEY可以通过参数传入或设置环境变量")
        self.client = Zep(api_key=self.api_key)

    def create_graph(self, graph_id: Optional[str] = None, name: str = "Knowledge Graph") -> str:
        """
        Create a new graph on Zep.

        Args:
            graph_id: Graph id to use; a random "graph_<16 hex>" id is
                generated when not provided.
            name: Human-readable graph name.

        Returns:
            The graph id actually used.
        """
        if graph_id is None:
            graph_id = f"graph_{uuid.uuid4().hex[:16]}"
        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="Knowledge graph generated by txt2graph"
        )
        return graph_id

    def set_ontology(self, graph_id: str) -> None:
        """
        Install the custom ontology (entity and edge types) on a graph.

        Pairs each edge model from EDGE_TYPES with the (source, target)
        entity-type combinations it is allowed to connect, then pushes
        ENTITY_TYPES and those edge definitions to Zep.

        Args:
            graph_id: Target graph id.
        """
        # Map edge-type name -> (edge model, allowed source/target pairs).
        edge_definitions = {}
        # WORKS_FOR: Person -> Organization/Company
        edge_definitions["WORKS_FOR"] = (
            EDGE_TYPES["WORKS_FOR"],
            [
                EntityEdgeSourceTarget(source="Person", target="Organization"),
                EntityEdgeSourceTarget(source="Person", target="Company"),
            ]
        )
        # LOCATED_IN: several entity kinds -> Location
        edge_definitions["LOCATED_IN"] = (
            EDGE_TYPES["LOCATED_IN"],
            [
                EntityEdgeSourceTarget(source="Person", target="Location"),
                EntityEdgeSourceTarget(source="Organization", target="Location"),
                EntityEdgeSourceTarget(source="Company", target="Location"),
                EntityEdgeSourceTarget(source="Event", target="Location"),
            ]
        )
        # PART_OF: Organization -> Organization, Company -> Company
        edge_definitions["PART_OF"] = (
            EDGE_TYPES["PART_OF"],
            [
                EntityEdgeSourceTarget(source="Organization", target="Organization"),
                EntityEdgeSourceTarget(source="Company", target="Company"),
            ]
        )
        # PRODUCES: Company/Organization -> Product
        edge_definitions["PRODUCES"] = (
            EDGE_TYPES["PRODUCES"],
            [
                EntityEdgeSourceTarget(source="Company", target="Product"),
                EntityEdgeSourceTarget(source="Organization", target="Product"),
            ]
        )
        # PARTICIPATES_IN: Person/Organization/Company -> Event
        edge_definitions["PARTICIPATES_IN"] = (
            EDGE_TYPES["PARTICIPATES_IN"],
            [
                EntityEdgeSourceTarget(source="Person", target="Event"),
                EntityEdgeSourceTarget(source="Organization", target="Event"),
                EntityEdgeSourceTarget(source="Company", target="Event"),
            ]
        )
        # COLLABORATES: cooperation between the various entity kinds
        edge_definitions["COLLABORATES"] = (
            EDGE_TYPES["COLLABORATES"],
            [
                EntityEdgeSourceTarget(source="Person", target="Person"),
                EntityEdgeSourceTarget(source="Company", target="Company"),
                EntityEdgeSourceTarget(source="Organization", target="Organization"),
                EntityEdgeSourceTarget(source="Company", target="Organization"),
            ]
        )
        # COMPETES: competition between companies
        edge_definitions["COMPETES"] = (
            EDGE_TYPES["COMPETES"],
            [
                EntityEdgeSourceTarget(source="Company", target="Company"),
            ]
        )
        # REPORTS: Media reporting on other entities
        edge_definitions["REPORTS"] = (
            EDGE_TYPES["REPORTS"],
            [
                EntityEdgeSourceTarget(source="Media", target="Person"),
                EntityEdgeSourceTarget(source="Media", target="Company"),
                EntityEdgeSourceTarget(source="Media", target="Organization"),
                EntityEdgeSourceTarget(source="Media", target="Event"),
            ]
        )
        # Push the ontology to Zep.
        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=ENTITY_TYPES,
            edges=edge_definitions,
        )

    def add_text_to_graph(
        self,
        graph_id: str,
        text_chunks: list[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> list[str]:
        """
        Add text chunks to the graph in batches.

        Args:
            graph_id: Target graph id.
            text_chunks: Pre-split text chunks to ingest.
            batch_size: Number of chunks sent per add_batch call.
            progress_callback: Optional callable receiving progress strings.

        Returns:
            List of Zep task ids to wait on (at most one per batch).

        Raises:
            Exception: any error from the Zep API is re-raised after
                being reported through progress_callback.
        """
        task_ids = []
        total_chunks = len(text_chunks)
        # Send the chunks batch_size at a time.
        for i in range(0, total_chunks, batch_size):
            batch_chunks = text_chunks[i:i + batch_size]
            batch_num = i // batch_size + 1
            total_batches = (total_chunks + batch_size - 1) // batch_size  # ceil division
            if progress_callback:
                progress_callback(f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...")
            # Wrap each chunk as a plain-text episode.
            episodes = [
                EpisodeData(data=chunk, type="text")
                for chunk in batch_chunks
            ]
            try:
                # Submit the whole batch in one API call.
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes
                )
                # NOTE(review): only the first element's task_id is kept per
                # batch — confirm against the SDK that add_batch returns one
                # shared task, otherwise later episodes are never awaited.
                if batch_result and batch_result[0].task_id:
                    task_ids.append(batch_result[0].task_id)
                # Brief pause to avoid sending requests too quickly.
                time.sleep(1)
            except Exception as e:
                if progress_callback:
                    progress_callback(f"批次 {batch_num} 发送失败: {str(e)}")
                raise
        return task_ids

    def wait_for_tasks(
        self,
        task_ids: list[str],
        timeout: int = 600,
        progress_callback: Optional[Callable] = None
    ):
        """
        Poll Zep until every task completes, fails, or the timeout passes.

        Args:
            task_ids: Task ids returned by add_text_to_graph.
            timeout: Maximum total wait in seconds; on expiry the method
                reports a warning (if a callback is set) and returns.
            progress_callback: Optional callable receiving progress strings.
        """
        if not task_ids:
            return
        start_time = time.time()
        pending_tasks = set(task_ids)
        completed_tasks = set()
        while pending_tasks:
            if time.time() - start_time > timeout:
                if progress_callback:
                    progress_callback(f"警告: 部分任务超时,已完成 {len(completed_tasks)}/{len(task_ids)}")
                break
            # Iterate over a copy so we can remove finished tasks in-place.
            for task_id in list(pending_tasks):
                try:
                    task = self.client.task.get(task_id=task_id)
                    if task.status == "completed":
                        pending_tasks.remove(task_id)
                        completed_tasks.add(task_id)
                    elif task.status == "failed":
                        # Failed tasks stop being polled but do not abort the wait.
                        pending_tasks.remove(task_id)
                        if progress_callback:
                            progress_callback(f"任务失败: {task.error}")
                except Exception as e:
                    # Transient polling errors: report and retry next round.
                    if progress_callback:
                        progress_callback(f"检查任务状态出错: {str(e)}")
            if pending_tasks:
                if progress_callback:
                    elapsed = int(time.time() - start_time)
                    progress_callback(f"等待处理中... 已完成 {len(completed_tasks)}/{len(task_ids)} ({elapsed}秒)")
                time.sleep(3)
        if progress_callback:
            progress_callback(f"所有任务处理完成: {len(completed_tasks)}/{len(task_ids)}")

    def get_graph_data(self, graph_id: str) -> GraphData:
        """
        Fetch the complete contents of a graph.

        Args:
            graph_id: Target graph id.

        Returns:
            GraphData holding every node and edge, with None fields
            normalized to "" / [] / {}.
        """
        # Fetch all nodes.
        raw_nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
        nodes = [
            GraphNode(
                uuid=node.uuid_,
                name=node.name,
                summary=node.summary or "",
                labels=node.labels or [],
                attributes=node.attributes or {}
            )
            for node in raw_nodes
        ]
        # Fetch all edges.
        raw_edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
        edges = [
            GraphEdge(
                uuid=edge.uuid_,
                name=edge.name or "",
                fact=edge.fact or "",
                source_node_uuid=edge.source_node_uuid,
                target_node_uuid=edge.target_node_uuid,
                attributes=edge.attributes or {}
            )
            for edge in raw_edges
        ]
        return GraphData(
            graph_id=graph_id,
            nodes=nodes,
            edges=edges
        )

    def delete_graph(self, graph_id: str):
        """Delete the graph with the given id from Zep."""
        self.client.graph.delete(graph_id=graph_id)
def build_graph_from_text(
    text: str,
    graph_name: str = "Knowledge Graph",
    api_key: Optional[str] = None,
    chunk_size: int = 2000,
    progress_callback: Optional[Callable] = None
) -> GraphData:
    """
    Convenience wrapper: run the full text -> knowledge-graph pipeline.

    Creates a graph, installs the ontology, splits the input text into
    chunks, ingests them in batches, waits for Zep to finish processing,
    and returns the resulting graph contents.

    Args:
        text: Input text.
        graph_name: Name for the new graph.
        api_key: Zep API key (falls back to the ZEP_API_KEY env var).
        chunk_size: Maximum characters per text chunk (default 2000).
        progress_callback: Optional callable receiving progress strings.

    Returns:
        GraphData with every node and edge of the finished graph.
    """
    from text_extractor import split_text_into_chunks

    def notify(message: str) -> None:
        # Forward progress text only when a callback was supplied.
        if progress_callback:
            progress_callback(message)

    builder = ZepGraphBuilder(api_key=api_key)

    # Create the graph, then install the custom ontology on it.
    graph_id = builder.create_graph(name=graph_name)
    notify(f"创建图谱: {graph_id}")
    builder.set_ontology(graph_id)
    notify("设置实体类型...")

    # Split the input into ingestible chunks.
    chunks = split_text_into_chunks(text, max_chunk_size=chunk_size)
    notify(f"文本分为 {len(chunks)} 个块")

    # Ingest in batches and wait for all processing tasks to settle.
    task_ids = builder.add_text_to_graph(
        graph_id=graph_id,
        text_chunks=chunks,
        batch_size=3,
        progress_callback=progress_callback
    )
    if task_ids:
        builder.wait_for_tasks(task_ids, progress_callback=progress_callback)

    # Read the finished graph back.
    return builder.get_graph_data(graph_id)
if __name__ == "__main__":
    # Manual smoke test — needs ZEP_API_KEY in the environment or a .env file.
    from dotenv import load_dotenv
    load_dotenv()

    test_text = """
    武汉大学是中国著名的高等学府,位于湖北省武汉市。
    该校的樱花季每年吸引大量游客。
    马化腾是腾讯公司的创始人,腾讯总部位于深圳。
    """

    result = build_graph_from_text(
        text=test_text,
        graph_name="测试图谱",
        progress_callback=print
    )

    # Dump the extracted entities and relationships.
    print(f"\n节点数: {len(result.nodes)}")
    for graph_node in result.nodes:
        print(f" - {graph_node.name} ({graph_node.labels})")

    print(f"\n边数: {len(result.edges)}")
    for graph_edge in result.edges:
        print(f" - {graph_edge.fact}")