""" txt2graph 可视化界面 基于Streamlit和PyVis实现知识图谱可视化 """ import os import tempfile import streamlit as st from pathlib import Path from pyvis.network import Network import streamlit.components.v1 as components from dotenv import load_dotenv load_dotenv() from text_extractor import extract_text, split_text_into_chunks from graph_builder import ZepGraphBuilder, GraphData # 页面配置 st.set_page_config( page_title="txt2graph - 知识图谱生成器", page_icon="🕸️", layout="wide", initial_sidebar_state="expanded" ) # 自定义CSS样式 st.markdown(""" """, unsafe_allow_html=True) # 实体类型对应的颜色 ENTITY_COLORS = { "Person": "#ff6b6b", "Company": "#4ecdc4", "Organization": "#45b7d1", "Location": "#96ceb4", "Product": "#ffeead", "Event": "#dcc6e0", "Media": "#ffb74d", } def create_pyvis_graph(graph_data: GraphData) -> str: """ 创建PyVis图并返回HTML """ # 创建网络图 net = Network( height="700px", width="100%", bgcolor="#0e1117", font_color="white", directed=True, select_menu=True, filter_menu=True, ) # 配置物理引擎 net.set_options(""" { "nodes": { "font": { "size": 14, "face": "Noto Sans SC, Arial" }, "borderWidth": 2, "shadow": true }, "edges": { "color": { "inherit": false, "color": "#555555", "highlight": "#667eea" }, "arrows": { "to": { "enabled": true, "scaleFactor": 0.5 } }, "smooth": { "type": "continuous", "roundness": 0.2 }, "font": { "size": 10, "color": "#888888", "face": "Noto Sans SC, Arial" } }, "physics": { "enabled": true, "barnesHut": { "gravitationalConstant": -5000, "centralGravity": 0.3, "springLength": 150, "springConstant": 0.04, "damping": 0.09 }, "stabilization": { "enabled": true, "iterations": 200 } }, "interaction": { "hover": true, "tooltipDelay": 100, "navigationButtons": true, "keyboard": true } } """) # 构建节点UUID到名称的映射 node_map = {node.uuid: node for node in graph_data.nodes} # 添加节点 for node in graph_data.nodes: # 确定节点类型和颜色 node_type = node.labels[0] if node.labels else "Unknown" color = ENTITY_COLORS.get(node_type, "#888888") # 构建工具提示 title = f"{node.name}
" title += f"类型: {node_type}

" if node.summary: title += f"{node.summary[:200]}{'...' if len(node.summary) > 200 else ''}" # 根据节点类型调整大小 size = 25 if node_type == "Person" else 30 if node_type in ["Company", "Organization"] else 20 net.add_node( node.uuid, label=node.name, title=title, color=color, size=size, shape="dot", ) # 添加边 for edge in graph_data.edges: if edge.source_node_uuid in node_map and edge.target_node_uuid in node_map: # 构建边的工具提示 title = edge.fact if edge.fact else edge.name net.add_edge( edge.source_node_uuid, edge.target_node_uuid, title=title, label=edge.name[:20] if edge.name else "", ) # 生成HTML with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False, encoding='utf-8') as f: net.save_graph(f.name) with open(f.name, 'r', encoding='utf-8') as html_file: html_content = html_file.read() os.unlink(f.name) return html_content def display_stats(graph_data: GraphData): """显示图谱统计信息""" col1, col2, col3 = st.columns(3) with col1: st.markdown(f"""
{len(graph_data.nodes)}
实体节点
""", unsafe_allow_html=True) with col2: st.markdown(f"""
{len(graph_data.edges)}
关系边
""", unsafe_allow_html=True) # 统计实体类型分布 type_counts = {} for node in graph_data.nodes: node_type = node.labels[0] if node.labels else "Unknown" type_counts[node_type] = type_counts.get(node_type, 0) + 1 with col3: st.markdown(f"""
{len(type_counts)}
实体类型
""", unsafe_allow_html=True) def display_entity_list(graph_data: GraphData): """显示实体列表""" st.subheader("实体列表") # 按类型分组 entities_by_type = {} for node in graph_data.nodes: node_type = node.labels[0] if node.labels else "Unknown" if node_type not in entities_by_type: entities_by_type[node_type] = [] entities_by_type[node_type].append(node) # 创建标签页 if entities_by_type: tabs = st.tabs(list(entities_by_type.keys())) for tab, (entity_type, entities) in zip(tabs, entities_by_type.items()): with tab: for entity in entities: with st.expander(f"{entity.name}", expanded=False): if entity.summary: st.write(entity.summary) if entity.attributes: st.json(entity.attributes) def main(): # 标题 st.title("txt2graph") st.markdown("*将文本转化为知识图谱*") # 侧边栏 with st.sidebar: st.header("配置") # API Key api_key = st.text_input( "Zep API Key", type="password", value=os.environ.get("ZEP_API_KEY", ""), help="从 https://app.getzep.com 获取API Key" ) if api_key: os.environ["ZEP_API_KEY"] = api_key st.divider() # 文件上传 st.header("上传文件") uploaded_file = st.file_uploader( "支持 .txt, .md, .pdf 文件", type=["txt", "md", "pdf"], help="上传要转换为知识图谱的文本文件" ) # 或者直接输入文本 st.divider() st.header("或直接输入文本") text_input = st.text_area( "输入文本内容", height=150, placeholder="在此输入或粘贴文本..." ) st.divider() # 高级设置 with st.expander("高级设置"): chunk_size = st.slider( "文本分块大小", min_value=500, max_value=4000, value=2000, step=500, help="较小的块处理更稳定,较大的块包含更多上下文" ) graph_name = st.text_input( "图谱名称", value="Knowledge Graph", help="为生成的图谱命名" ) st.divider() # 生成按钮 generate_btn = st.button("生成知识图谱", type="primary", use_container_width=True) # 主内容区 if "graph_data" not in st.session_state: st.session_state.graph_data = None if generate_btn: if not api_key: st.error("请先配置 Zep API Key") return # 获取文本内容 text_content = None if uploaded_file: with st.spinner("正在提取文本..."): # 保存上传的文件到临时位置 with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp: tmp.write(uploaded_file.getvalue()) tmp_path = tmp.name try: text_content = extract_text(tmp_path) finally: os.unlink(tmp_path) elif text_input: text_content = text_input else: st.warning("请上传文件或输入文本") return if text_content: st.info(f"提取了 {len(text_content)} 个字符的文本") # 进度显示 progress_bar = st.progress(0) status_text = st.empty() try: # 创建图谱构建器 builder = ZepGraphBuilder(api_key=api_key) # 创建图谱 status_text.text("创建图谱...") progress_bar.progress(10) graph_id = builder.create_graph(name=graph_name) # 设置本体 status_text.text("配置实体类型...") progress_bar.progress(20) builder.set_ontology(graph_id) # 分块 status_text.text("分割文本...") progress_bar.progress(30) chunks = split_text_into_chunks(text_content, max_chunk_size=chunk_size) st.info(f"文本已分为 {len(chunks)} 个块") # 添加到图谱 status_text.text("正在发送数据到Zep...") progress_bar.progress(40) def update_progress(msg): status_text.text(msg) # 分批发送数据 task_ids = builder.add_text_to_graph( graph_id=graph_id, text_chunks=chunks, batch_size=3, progress_callback=update_progress ) # 等待处理完成 progress_bar.progress(60) status_text.text("等待Zep处理数据...") if task_ids: builder.wait_for_tasks( task_ids, timeout=600, progress_callback=update_progress ) # 获取图数据 status_text.text("获取图谱数据...") progress_bar.progress(90) st.session_state.graph_data = builder.get_graph_data(graph_id) st.session_state.graph_id = graph_id progress_bar.progress(100) status_text.text("完成!") st.success(f"知识图谱生成成功! Graph ID: {graph_id}") except Exception as e: st.error(f"生成图谱时出错: {str(e)}") import traceback st.code(traceback.format_exc()) # 显示图谱 if st.session_state.graph_data: graph_data = st.session_state.graph_data # 统计信息 display_stats(graph_data) st.divider() # 图谱可视化 st.subheader("知识图谱可视化") if graph_data.nodes: with st.spinner("渲染图谱..."): html_content = create_pyvis_graph(graph_data) components.html(html_content, height=750, scrolling=True) else: st.warning("图谱中没有节点") st.divider() # 实体列表 col1, col2 = st.columns([1, 1]) with col1: display_entity_list(graph_data) with col2: st.subheader("关系列表") if graph_data.edges: for edge in graph_data.edges[:50]: # 只显示前50条 st.markdown(f"- **{edge.fact}**" if edge.fact else f"- {edge.name}") if len(graph_data.edges) > 50: st.caption(f"...还有 {len(graph_data.edges) - 50} 条关系") else: st.info("暂无关系数据") if __name__ == "__main__": main()