"""
txt2graph 可视化界面
基于Streamlit和PyVis实现知识图谱可视化
"""
import os
import tempfile
import streamlit as st
from pathlib import Path
from pyvis.network import Network
import streamlit.components.v1 as components
from dotenv import load_dotenv
# Load variables from a .env file before the project-local imports below.
# NOTE(review): presumably text_extractor / graph_builder read settings
# (e.g. API keys) from the environment at import time — confirm; if they do
# not, this call could move after the imports.
load_dotenv()
from text_extractor import extract_text, split_text_into_chunks
from graph_builder import ZepGraphBuilder, GraphData
# Page-level settings: browser tab title and icon, wide layout, and the
# sidebar shown expanded when the app loads.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="txt2graph - 知识图谱生成器",
    page_icon="🕸️",
)
# Inject custom CSS into the page (raw HTML must be explicitly allowed).
# NOTE(review): the payload is just a newline, so no styles are applied —
# the intended <style> block appears to be missing; restore it or drop this.
st.markdown("\n", unsafe_allow_html=True)
# Node fill colour for each entity type (looked up by a node's type label).
# Types absent from this table are drawn grey ("#888888") in the graph view.
ENTITY_COLORS = dict(
    Person="#ff6b6b",
    Company="#4ecdc4",
    Organization="#45b7d1",
    Location="#96ceb4",
    Product="#ffeead",
    Event="#dcc6e0",
    Media="#ffb74d",
)
def create_pyvis_graph(graph_data: GraphData) -> str:
    """Build an interactive PyVis network from *graph_data* and return its HTML.

    Every node is drawn as a dot coloured by its type (its first label, looked
    up in ``ENTITY_COLORS``; unknown types are grey) with a plain-text tooltip
    of name, type and a summary truncated to 200 characters. Every edge is a
    directed arrow labelled with the first 20 characters of its name; edges
    whose source or target node is not part of ``graph_data.nodes`` are skipped.

    Args:
        graph_data: Graph whose ``nodes`` expose ``uuid``, ``name``, ``labels``
            and ``summary``, and whose ``edges`` expose ``source_node_uuid``,
            ``target_node_uuid``, ``name`` and ``fact``.

    Returns:
        The HTML document generated by PyVis for the network.
    """
    net = Network(
        height="700px",
        width="100%",
        bgcolor="#0e1117",
        font_color="white",
        directed=True,
        select_menu=True,
        filter_menu=True,
    )

    # vis.js options: node/edge styling, Barnes-Hut physics, interaction.
    # NOTE(review): pyvis' set_options() appears to strip all whitespace from
    # this JSON before parsing, which would turn "Noto Sans SC, Arial" into
    # "NotoSansSC,Arial" — confirm against the installed pyvis version.
    net.set_options("""
    {
        "nodes": {
            "font": {
                "size": 14,
                "face": "Noto Sans SC, Arial"
            },
            "borderWidth": 2,
            "shadow": true
        },
        "edges": {
            "color": {
                "inherit": false,
                "color": "#555555",
                "highlight": "#667eea"
            },
            "arrows": {
                "to": {
                    "enabled": true,
                    "scaleFactor": 0.5
                }
            },
            "smooth": {
                "type": "continuous",
                "roundness": 0.2
            },
            "font": {
                "size": 10,
                "color": "#888888",
                "face": "Noto Sans SC, Arial"
            }
        },
        "physics": {
            "enabled": true,
            "barnesHut": {
                "gravitationalConstant": -5000,
                "centralGravity": 0.3,
                "springLength": 150,
                "springConstant": 0.04,
                "damping": 0.09
            },
            "stabilization": {
                "enabled": true,
                "iterations": 200
            }
        },
        "interaction": {
            "hover": true,
            "tooltipDelay": 100,
            "navigationButtons": true,
            "keyboard": true
        }
    }
    """)

    # Only membership is needed: an edge is drawn only if both ends exist.
    node_ids = {node.uuid for node in graph_data.nodes}

    for node in graph_data.nodes:
        # NOTE(review): only the first label decides type/colour — confirm the
        # label order produced upstream puts the specific entity type first.
        node_type = node.labels[0] if node.labels else "Unknown"
        color = ENTITY_COLORS.get(node_type, "#888888")

        # Plain-text tooltip, one field per line.
        title = f"{node.name}\n类型: {node_type}\n"
        if node.summary:
            title += f"{node.summary[:200]}{'...' if len(node.summary) > 200 else ''}"

        # Organisations are drawn largest, people medium, everything else small.
        if node_type == "Person":
            size = 25
        elif node_type in ("Company", "Organization"):
            size = 30
        else:
            size = 20

        net.add_node(
            node.uuid,
            label=node.name,
            title=title,
            color=color,
            size=size,
            shape="dot",
        )

    for edge in graph_data.edges:
        # Skip dangling edges whose endpoints were never added as nodes.
        if edge.source_node_uuid not in node_ids or edge.target_node_uuid not in node_ids:
            continue
        net.add_edge(
            edge.source_node_uuid,
            edge.target_node_uuid,
            # Prefer the full fact sentence as the hover text, else the name.
            title=edge.fact or edge.name,
            label=(edge.name or "")[:20],
        )

    return _pyvis_to_html(net)


def _pyvis_to_html(net) -> str:
    """Render the PyVis network *net* to an HTML string, leaving no temp files."""
    # Newer pyvis renders straight to a string: no file, no encoding round-trip.
    generate_html = getattr(net, "generate_html", None)
    if generate_html is not None:
        return generate_html()

    # Fallback for pyvis without generate_html(): go through a temp file.
    # The OS handle is closed before pyvis reopens the path (an open file
    # cannot be unlinked on Windows), and the file is always removed.
    fd, path = tempfile.mkstemp(suffix=".html")
    os.close(fd)
    try:
        net.save_graph(path)
        # NOTE(review): pyvis presumably writes with the platform default
        # encoding, so the file is read back with that same default.
        with open(path, "r") as html_file:
            return html_file.read()
    finally:
        os.unlink(path)
def display_stats(graph_data: GraphData):
"""显示图谱统计信息"""
col1, col2, col3 = st.columns(3)
with col1:
st.markdown(f"""