MiroFish/txt2graph/app.py
666ghj 9657061b26 Add initial implementation of txt2graph tool for knowledge graph generation
- Created a new Streamlit application for visualizing knowledge graphs.
- Implemented text extraction from PDF, Markdown, and TXT files.
- Developed graph building logic using Zep Cloud API.
- Added support for custom entity types and relationships.
- Included interactive HTML visualization for generated graphs.
- Updated .gitignore to include new directories and files.
- Added example environment configuration file (.env.example) for API key setup.
- Created README.md with installation and usage instructions.
- Introduced various utility scripts and styles for enhanced functionality.
2025-11-28 14:07:42 +08:00

497 lines
15 KiB
Python

"""
txt2graph 可视化界面
基于Streamlit和PyVis实现知识图谱可视化
"""
import os
import tempfile
import streamlit as st
from pathlib import Path
from pyvis.network import Network
import streamlit.components.v1 as components
from dotenv import load_dotenv
load_dotenv()
from text_extractor import extract_text, split_text_into_chunks
from graph_builder import ZepGraphBuilder, GraphData
# Page setup: browser-tab title and icon, wide content area, sidebar open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="txt2graph - 知识图谱生成器",
    page_icon="🕸️",
)
# Custom stylesheet injected into the page: web fonts, gradient title,
# statistics cards, per-entity-type tag colours and sidebar button styling.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+SC:wght@400;500;700&family=JetBrains+Mono&display=swap');
.main {
font-family: 'Noto Sans SC', sans-serif;
}
.stTitle {
font-weight: 700 !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.stats-card {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
border-radius: 12px;
padding: 20px;
margin: 10px 0;
border: 1px solid rgba(102, 126, 234, 0.3);
}
.stats-number {
font-size: 2.5rem;
font-weight: 700;
color: #667eea;
font-family: 'JetBrains Mono', monospace;
}
.stats-label {
font-size: 0.9rem;
color: #a0a0a0;
text-transform: uppercase;
letter-spacing: 1px;
}
.entity-tag {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
margin: 2px;
font-weight: 500;
}
.entity-Person { background: rgba(255, 107, 107, 0.2); color: #ff6b6b; border: 1px solid #ff6b6b; }
.entity-Company { background: rgba(78, 205, 196, 0.2); color: #4ecdc4; border: 1px solid #4ecdc4; }
.entity-Organization { background: rgba(69, 183, 209, 0.2); color: #45b7d1; border: 1px solid #45b7d1; }
.entity-Location { background: rgba(150, 206, 180, 0.2); color: #96ceb4; border: 1px solid #96ceb4; }
.entity-Product { background: rgba(255, 238, 173, 0.2); color: #ffeead; border: 1px solid #ffeead; }
.entity-Event { background: rgba(220, 198, 224, 0.2); color: #dcc6e0; border: 1px solid #dcc6e0; }
.entity-Media { background: rgba(255, 183, 77, 0.2); color: #ffb74d; border: 1px solid #ffb74d; }
.sidebar .stButton > button {
width: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
transition: all 0.3s ease;
}
.sidebar .stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 20px rgba(102, 126, 234, 0.4);
}
</style>
"""
# Raw HTML is required for a <style> tag, hence unsafe_allow_html.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Display colour for each known entity type (first node label). The same hex
# values are used by the `.entity-<Type>` rules in the injected stylesheet.
ENTITY_COLORS = dict(
    Person="#ff6b6b",
    Company="#4ecdc4",
    Organization="#45b7d1",
    Location="#96ceb4",
    Product="#ffeead",
    Event="#dcc6e0",
    Media="#ffb74d",
)
def create_pyvis_graph(graph_data: GraphData) -> str:
    """Build an interactive PyVis network for *graph_data* and return its HTML.

    Each node is coloured by its first label via ``ENTITY_COLORS`` (grey
    ``#888888`` for unlisted types) and sized by type. Its tooltip shows the
    name, the type and at most the first 200 characters of its summary. An
    edge is drawn only when both of its endpoints are among
    ``graph_data.nodes``.

    Args:
        graph_data: Graph to render. Reads ``nodes`` (``uuid``, ``name``,
            ``labels``, ``summary``) and ``edges`` (``source_node_uuid``,
            ``target_node_uuid``, ``name``, ``fact``).

    Returns:
        The standalone HTML document written by ``Network.save_graph``.
    """
    # Function-scope import, matching the local ``import traceback`` style
    # used elsewhere in this app.
    from html import escape

    net = Network(
        height="700px",
        width="100%",
        bgcolor="#0e1117",
        font_color="white",
        directed=True,
        select_menu=True,
        filter_menu=True,
    )

    # Styling, physics and interaction options, passed as a JSON string.
    # NOTE(review): some pyvis releases strip *all* whitespace from this string
    # before json.loads() (including inside values), which would turn the font
    # face below into "NotoSansSC,Arial" -- verify with the installed pyvis.
    net.set_options("""
    {
      "nodes": {
        "font": {"size": 14, "face": "Noto Sans SC, Arial"},
        "borderWidth": 2,
        "shadow": true
      },
      "edges": {
        "color": {"inherit": false, "color": "#555555", "highlight": "#667eea"},
        "arrows": {"to": {"enabled": true, "scaleFactor": 0.5}},
        "smooth": {"type": "continuous", "roundness": 0.2},
        "font": {"size": 10, "color": "#888888", "face": "Noto Sans SC, Arial"}
      },
      "physics": {
        "enabled": true,
        "barnesHut": {
          "gravitationalConstant": -5000,
          "centralGravity": 0.3,
          "springLength": 150,
          "springConstant": 0.04,
          "damping": 0.09
        },
        "stabilization": {"enabled": true, "iterations": 200}
      },
      "interaction": {
        "hover": true,
        "tooltipDelay": 100,
        "navigationButtons": true,
        "keyboard": true
      }
    }
    """)

    # Only membership is checked below, so a set of UUIDs is enough.
    node_ids = {node.uuid for node in graph_data.nodes}

    for node in graph_data.nodes:
        node_type = node.labels[0] if node.labels else "Unknown"
        color = ENTITY_COLORS.get(node_type, "#888888")

        # The tooltip is HTML (<b>, <i>, <br>). Name, type and summary come
        # from the user's documents, so escape them to keep them from
        # injecting markup into the rendered page.
        title = f"<b>{escape(str(node.name), quote=False)}</b><br>"
        title += f"<i>类型: {escape(str(node_type), quote=False)}</i><br><br>"
        if node.summary:
            # Truncate before escaping so an entity such as "&amp;" is never cut.
            summary = node.summary[:200] + ("..." if len(node.summary) > 200 else "")
            title += escape(summary, quote=False)

        # Companies/organisations largest, people medium, everything else small.
        if node_type == "Person":
            size = 25
        elif node_type in ("Company", "Organization"):
            size = 30
        else:
            size = 20

        net.add_node(
            node.uuid,
            label=node.name,
            title=title,
            color=color,
            size=size,
            shape="dot",
        )

    for edge in graph_data.edges:
        # Skip edges whose endpoints are not rendered nodes.
        if edge.source_node_uuid not in node_ids or edge.target_node_uuid not in node_ids:
            continue
        raw_title = edge.fact if edge.fact else edge.name
        net.add_edge(
            edge.source_node_uuid,
            edge.target_node_uuid,
            title=escape(str(raw_title), quote=False) if raw_title else raw_title,
            label=edge.name[:20] if edge.name else "",
        )

    # Save into a private temporary directory. No handle is held open while
    # PyVis writes the file, and the directory is removed even if saving or
    # reading raises. The original unlink was not in a finally block.
    with tempfile.TemporaryDirectory() as tmp_dir:
        html_path = Path(tmp_dir) / "graph.html"
        net.save_graph(str(html_path))
        return html_path.read_text(encoding="utf-8")
def display_stats(graph_data: GraphData):
    """Render three headline metric cards for *graph_data*.

    The cards show, in order: the number of entity nodes, the number of
    relationship edges, and the number of distinct entity types. A node's
    type is its first label, or "Unknown" when it has no labels.
    """
    distinct_types = {
        node.labels[0] if node.labels else "Unknown"
        for node in graph_data.nodes
    }
    metrics = [
        (len(graph_data.nodes), "实体节点"),
        (len(graph_data.edges), "关系边"),
        (len(distinct_types), "实体类型"),
    ]
    # One styled card per column; the classes come from the injected CSS.
    for column, (value, caption) in zip(st.columns(3), metrics):
        with column:
            st.markdown(
                "\n"
                '<div class="stats-card">\n'
                f'<div class="stats-number">{value}</div>\n'
                f'<div class="stats-label">{caption}</div>\n'
                "</div>\n",
                unsafe_allow_html=True,
            )
def display_entity_list(graph_data: GraphData):
    """Render the graph's entities, one tab per entity type.

    An entity's type is its first label, or "Unknown" when it has none. Tabs
    appear in the order types are first encountered. Each entity gets a
    collapsed expander showing its summary and, if present, its attributes
    as JSON.
    """
    st.subheader("实体列表")

    # Group nodes by type; dict insertion order fixes the tab order.
    grouped: dict = {}
    for node in graph_data.nodes:
        kind = node.labels[0] if node.labels else "Unknown"
        grouped.setdefault(kind, []).append(node)

    if not grouped:
        return

    for tab, members in zip(st.tabs(list(grouped)), grouped.values()):
        with tab:
            for member in members:
                with st.expander(f"{member.name}", expanded=False):
                    if member.summary:
                        st.write(member.summary)
                    if member.attributes:
                        st.json(member.attributes)
def _render_sidebar():
    """Draw the sidebar controls and return the values the user entered.

    Returns:
        Tuple ``(api_key, uploaded_file, text_input, chunk_size, graph_name,
        generate_clicked)`` read directly from the widgets.
    """
    with st.sidebar:
        st.header("配置")
        # Pre-filled from the ZEP_API_KEY environment variable (which
        # load_dotenv() may have populated from a .env file).
        api_key = st.text_input(
            "Zep API Key",
            type="password",
            value=os.environ.get("ZEP_API_KEY", ""),
            help="从 https://app.getzep.com 获取API Key",
        )
        # The key is passed explicitly to ZepGraphBuilder, so it is
        # deliberately NOT written back into os.environ. os.environ is
        # process-wide, and a Streamlit server can serve several browser
        # sessions from one process, so writing it would make one visitor's
        # key the pre-filled `value=` above for everybody else.
        st.divider()

        # Source 1: an uploaded document.
        st.header("上传文件")
        uploaded_file = st.file_uploader(
            "支持 .txt, .md, .pdf 文件",
            type=["txt", "md", "pdf"],
            help="上传要转换为知识图谱的文本文件",
        )

        # Source 2: text pasted directly.
        st.divider()
        st.header("或直接输入文本")
        text_input = st.text_area(
            "输入文本内容",
            height=150,
            placeholder="在此输入或粘贴文本...",
        )
        st.divider()

        with st.expander("高级设置"):
            chunk_size = st.slider(
                "文本分块大小",
                min_value=500,
                max_value=4000,
                value=2000,
                step=500,
                help="较小的块处理更稳定,较大的块包含更多上下文",
            )
            graph_name = st.text_input(
                "图谱名称",
                value="Knowledge Graph",
                help="为生成的图谱命名",
            )
        st.divider()

        generate_clicked = st.button("生成知识图谱", type="primary", use_container_width=True)

    return api_key, uploaded_file, text_input, chunk_size, graph_name, generate_clicked


def _load_text(uploaded_file, text_input):
    """Return the text to turn into a graph, or None when there is none.

    An uploaded file takes precedence over the text box. When the chosen
    source yields only blank text, a warning is shown and None is returned.
    """
    if uploaded_file:
        with st.spinner("正在提取文本..."):
            # extract_text() takes a filesystem path, so spill the upload into
            # a temporary file that keeps the original extension.
            with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp:
                tmp.write(uploaded_file.getvalue())
                tmp_path = tmp.name
            # The handle is closed at this point; always remove the file.
            try:
                text = extract_text(tmp_path)
            finally:
                os.unlink(tmp_path)
        if not text or not text.strip():
            # Previously an empty extraction silently did nothing.
            st.warning("未能从上传的文件中提取到文本")
            return None
        return text

    # Whitespace-only input is treated as no input rather than sent to Zep.
    if text_input and text_input.strip():
        return text_input

    st.warning("请上传文件或输入文本")
    return None


def _generate_graph(api_key, text_content, chunk_size, graph_name):
    """Build a Zep graph from *text_content*, reporting progress in the page.

    On success, ``st.session_state.graph_data`` and
    ``st.session_state.graph_id`` are set. On failure, the error message and
    traceback are shown instead.
    """
    st.info(f"提取了 {len(text_content)} 个字符的文本")
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(msg):
        # Forward builder progress messages to the status line.
        status_text.text(msg)

    try:
        builder = ZepGraphBuilder(api_key=api_key)

        status_text.text("创建图谱...")
        progress_bar.progress(10)
        graph_id = builder.create_graph(name=graph_name)

        status_text.text("配置实体类型...")
        progress_bar.progress(20)
        builder.set_ontology(graph_id)

        status_text.text("分割文本...")
        progress_bar.progress(30)
        chunks = split_text_into_chunks(text_content, max_chunk_size=chunk_size)
        st.info(f"文本已分为 {len(chunks)} 个块")

        # Send the chunks in batches of 3.
        status_text.text("正在发送数据到Zep...")
        progress_bar.progress(40)
        task_ids = builder.add_text_to_graph(
            graph_id=graph_id,
            text_chunks=chunks,
            batch_size=3,
            progress_callback=update_progress,
        )

        progress_bar.progress(60)
        status_text.text("等待Zep处理数据...")
        if task_ids:
            builder.wait_for_tasks(
                task_ids,
                timeout=600,
                progress_callback=update_progress,
            )

        status_text.text("获取图谱数据...")
        progress_bar.progress(90)
        st.session_state.graph_data = builder.get_graph_data(graph_id)
        st.session_state.graph_id = graph_id

        progress_bar.progress(100)
        status_text.text("完成!")
        st.success(f"知识图谱生成成功! Graph ID: {graph_id}")
    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the run.
        import traceback

        st.error(f"生成图谱时出错: {str(e)}")
        st.code(traceback.format_exc())


def _render_results(graph_data):
    """Show the stats, the interactive graph and the entity/relation lists."""
    display_stats(graph_data)
    st.divider()

    st.subheader("知识图谱可视化")
    if graph_data.nodes:
        with st.spinner("渲染图谱..."):
            html_content = create_pyvis_graph(graph_data)
            components.html(html_content, height=750, scrolling=True)
    else:
        st.warning("图谱中没有节点")
    st.divider()

    max_listed_edges = 50  # keep the relation list short
    col1, col2 = st.columns([1, 1])
    with col1:
        display_entity_list(graph_data)
    with col2:
        st.subheader("关系列表")
        if graph_data.edges:
            for edge in graph_data.edges[:max_listed_edges]:
                st.markdown(f"- **{edge.fact}**" if edge.fact else f"- {edge.name}")
            if len(graph_data.edges) > max_listed_edges:
                st.caption(f"...还有 {len(graph_data.edges) - max_listed_edges} 条关系")
        else:
            st.info("暂无关系数据")


def main():
    """Streamlit entry point: draw the UI, build a graph on request, show it."""
    st.title("txt2graph")
    st.markdown("*将文本转化为知识图谱*")

    (api_key, uploaded_file, text_input,
     chunk_size, graph_name, generate_btn) = _render_sidebar()

    # The last generated graph is kept in session state across reruns.
    if "graph_data" not in st.session_state:
        st.session_state.graph_data = None

    if generate_btn:
        if not api_key:
            st.error("请先配置 Zep API Key")
            return
        text_content = _load_text(uploaded_file, text_input)
        if text_content is None:
            return
        _generate_graph(api_key, text_content, chunk_size, graph_name)

    if st.session_state.graph_data:
        _render_results(st.session_state.graph_data)


# Render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()