# MiroFish/backend/app/api/graph.py

"""
图谱相关API路由
采用项目上下文机制,服务端持久化状态
"""
import os
import traceback
import threading
from flask import request, jsonify
from . import graph_bp
from ..config import Config
from ..services.ontology_generator import OntologyGenerator
from ..services.graph_builder import GraphBuilderService
from ..services.llm_graph_builder import LLMGraphBuilderService
from ..services.text_processor import TextProcessor
from ..utils.file_parser import FileParser
from ..utils.logger import get_logger
from ..models.task import TaskManager, TaskStatus
from ..models.project import ProjectManager, ProjectStatus
# 获取日志器
logger = get_logger('mirofish.api')
def allowed_file(filename: str) -> bool:
    """Check whether the uploaded filename carries a permitted extension."""
    if not filename or '.' not in filename:
        return False
    suffix = os.path.splitext(filename)[1]
    return suffix.lower().lstrip('.') in Config.ALLOWED_EXTENSIONS
# ============== 项目管理接口 ==============
@graph_bp.route('/project/<project_id>', methods=['GET'])
def get_project(project_id: str):
    """Fetch a single project's details by its id; 404 when it does not exist."""
    project = ProjectManager.get_project(project_id)
    if not project:
        return jsonify({
            "success": False,
            "error": f"Project not found: {project_id}"
        }), 404
    return jsonify({"success": True, "data": project.to_dict()})
@graph_bp.route('/project/list', methods=['GET'])
def list_projects():
    """List projects, capped by the optional ?limit= query argument (default 50)."""
    max_items = request.args.get('limit', 50, type=int)
    payload = [p.to_dict() for p in ProjectManager.list_projects(limit=max_items)]
    return jsonify({
        "success": True,
        "data": payload,
        "count": len(payload)
    })
@graph_bp.route('/project/<project_id>', methods=['DELETE'])
def delete_project(project_id: str):
    """Delete a project; 404 when it is missing or the deletion fails."""
    if not ProjectManager.delete_project(project_id):
        return jsonify({
            "success": False,
            "error": f"Project not found or delete failed: {project_id}"
        }), 404
    return jsonify({
        "success": True,
        "message": f"Project deleted: {project_id}"
    })
@graph_bp.route('/project/<project_id>/reset', methods=['POST'])
def reset_project(project_id: str):
    """Reset a project's build state so its graph can be rebuilt."""
    project = ProjectManager.get_project(project_id)
    if not project:
        return jsonify({
            "success": False,
            "error": f"Project not found: {project_id}"
        }), 404
    # Roll back to the ontology-ready state when an ontology already exists,
    # otherwise all the way back to the freshly-created state.
    project.status = (
        ProjectStatus.ONTOLOGY_GENERATED if project.ontology else ProjectStatus.CREATED
    )
    project.graph_id = None
    project.graph_build_task_id = None
    project.error = None
    ProjectManager.save_project(project)
    return jsonify({
        "success": True,
        "message": f"Project reset: {project_id}",
        "data": project.to_dict()
    })
# ============== 接口1上传文件并生成本体 ==============
@graph_bp.route('/ontology/generate', methods=['POST'])
def generate_ontology():
    """
    Endpoint 1: upload documents and generate an ontology definition.

    Request (multipart/form-data):
        files: uploaded documents (PDF/MD/TXT; multiple allowed)
        simulation_requirement: simulation requirement description (required)
        project_name: project name (optional)
        additional_context: extra notes for the LLM (optional)

    Returns JSON:
        {
            "success": true,
            "data": {
                "project_id": "proj_xxxx",
                "project_name": "...",
                "ontology": {"entity_types": [...], "edge_types": [...]},
                "analysis_summary": "...",
                "files": [...],
                "total_text_length": 12345
            }
        }
    """
    try:
        logger.info("=== 开始生成本体定义 ===")
        # --- request parameters ---
        simulation_requirement = request.form.get('simulation_requirement', '')
        project_name = request.form.get('project_name', 'Unnamed Project')
        additional_context = request.form.get('additional_context', '')
        logger.debug(f"项目名称: {project_name}")
        logger.debug(f"模拟需求: {simulation_requirement[:100]}...")
        if not simulation_requirement:
            return jsonify({
                "success": False,
                "error": "Please provide simulation_requirement"
            }), 400
        # At least one non-empty upload is required.
        uploaded_files = request.files.getlist('files')
        if not uploaded_files or all(not f.filename for f in uploaded_files):
            return jsonify({
                "success": False,
                "error": "Please upload at least one document file"
            }), 400
        # --- create the project ---
        project = ProjectManager.create_project(name=project_name)
        project.simulation_requirement = simulation_requirement
        logger.info(f"创建项目: {project.project_id}")
        # --- save files and extract text ---
        document_texts = []
        # FIX: collect sections and join once instead of quadratic `all_text += ...`.
        text_sections = []
        for file in uploaded_files:
            # Files with a disallowed extension are silently skipped (original behavior).
            if file and file.filename and allowed_file(file.filename):
                file_info = ProjectManager.save_file_to_project(
                    project.project_id,
                    file,
                    file.filename
                )
                project.files.append({
                    "filename": file_info["original_filename"],
                    "size": file_info["size"]
                })
                text = FileParser.extract_text(file_info["path"])
                text = TextProcessor.preprocess_text(text)
                document_texts.append(text)
                text_sections.append(
                    f"\n\n=== {file_info['original_filename']} ===\n{text}"
                )
        if not document_texts:
            # Nothing usable was uploaded -> drop the just-created project.
            ProjectManager.delete_project(project.project_id)
            return jsonify({
                "success": False,
                "error": "No documents were processed successfully, please check file format"
            }), 400
        all_text = "".join(text_sections)
        # Persist the extracted text for the later /build step.
        project.total_text_length = len(all_text)
        ProjectManager.save_extracted_text(project.project_id, all_text)
        logger.info(f"文本提取完成,共 {len(all_text)} 字符")
        # --- generate the ontology via LLM ---
        logger.info("调用 LLM 生成本体定义...")
        generator = OntologyGenerator()
        ontology = generator.generate(
            document_texts=document_texts,
            simulation_requirement=simulation_requirement,
            additional_context=additional_context if additional_context else None
        )
        entity_count = len(ontology.get("entity_types", []))
        edge_count = len(ontology.get("edge_types", []))
        logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型")
        project.ontology = {
            "entity_types": ontology.get("entity_types", []),
            "edge_types": ontology.get("edge_types", [])
        }
        project.analysis_summary = ontology.get("analysis_summary", "")
        project.status = ProjectStatus.ONTOLOGY_GENERATED
        ProjectManager.save_project(project)
        logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}")
        return jsonify({
            "success": True,
            "data": {
                "project_id": project.project_id,
                "project_name": project.name,
                "ontology": project.ontology,
                "analysis_summary": project.analysis_summary,
                "files": project.files,
                "total_text_length": project.total_text_length
            }
        })
    except Exception as e:
        # FIX: the failure was previously returned to the client but never logged.
        logger.exception("Ontology generation failed")
        # NOTE(review): returning the traceback exposes internals to clients;
        # kept for parity with the other handlers in this file — consider gating
        # it behind a debug flag.
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
# ============== 接口2构建图谱 ==============
@graph_bp.route('/build', methods=['POST'])
def build_graph():
    """
    Endpoint 2: build the knowledge graph for an existing project.

    Request JSON:
        {
            "project_id": "proj_xxxx",   // required, from endpoint 1
            "graph_name": "...",         // optional
            "chunk_size": 500,           // optional, default 500
            "chunk_overlap": 50,         // optional, default 50
            "force": false               // optional, force a rebuild
        }

    Returns:
        {
            "success": true,
            "data": {
                "project_id": "proj_xxxx",
                "task_id": "task_xxxx",
                "message": "..."
            }
        }

    The build itself runs in a background daemon thread; clients poll
    progress via GET /task/<task_id>.
    """
    try:
        logger.info("=== 开始构建图谱 ===")
        # Parse the request body, tolerating an empty/missing JSON payload.
        data = request.get_json() or {}
        project_id = data.get('project_id')
        logger.debug(f"请求参数: project_id={project_id}")
        if not project_id:
            return jsonify({
                "success": False,
                "error": "Please provide project_id"
            }), 400
        # Look up the target project.
        project = ProjectManager.get_project(project_id)
        if not project:
            return jsonify({
                "success": False,
                "error": f"Project not found: {project_id}"
            }), 404
        # Validate project state before starting a build.
        force = data.get('force', False)  # force a rebuild even if one already ran
        if project.status == ProjectStatus.CREATED:
            return jsonify({
                "success": False,
                "error": "Ontology not yet generated for this project. Please call /ontology/generate first"
            }), 400
        if project.status == ProjectStatus.GRAPH_BUILDING and not force:
            return jsonify({
                "success": False,
                "error": "Graph is currently being built. To force rebuild, add force: true",
                "task_id": project.graph_build_task_id
            }), 400
        # On a forced rebuild, roll the project back to the ontology-ready state.
        if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
            project.status = ProjectStatus.ONTOLOGY_GENERATED
            project.graph_id = None
            project.graph_build_task_id = None
            project.error = None
        # Resolve build config: request value > stored project value > default.
        graph_name = data.get('graph_name', project.name or 'MiroFish Graph')
        chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE)
        chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP)
        # Persist the chosen chunking configuration on the project.
        project.chunk_size = chunk_size
        project.chunk_overlap = chunk_overlap
        # The text extracted by endpoint 1 must exist...
        text = ProjectManager.get_extracted_text(project_id)
        if not text:
            return jsonify({
                "success": False,
                "error": "Extracted text content not found"
            }), 400
        # ...and so must the ontology.
        ontology = project.ontology
        if not ontology:
            return jsonify({
                "success": False,
                "error": "Ontology definition not found"
            }), 400
        # Create the async task record the client can poll.
        task_manager = TaskManager()
        task_id = task_manager.create_task(f"Build graph: {graph_name}")
        logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}")
        # Mark the project as building BEFORE spawning the thread so a
        # concurrent /build request is rejected by the status check above.
        project.status = ProjectStatus.GRAPH_BUILDING
        project.graph_build_task_id = task_id
        ProjectManager.save_project(project)

        # Background worker: runs the full two-pass LLM build and reports
        # progress through the shared TaskManager. Captures `project`, `text`,
        # `ontology`, chunking config, and `task_id` from the enclosing scope.
        def build_task():
            build_logger = get_logger('mirofish.build')
            try:
                build_logger.info(f"[{task_id}] 开始构建图谱 (LLM mode)...")
                task_manager.update_task(
                    task_id,
                    status=TaskStatus.PROCESSING,
                    message="Initializing LLM graph build service..."
                )
                # LLM-based builder; no Zep backend needed for this path.
                from ..services.llm_graph_builder import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
                builder = LLMGraphBuilderService()
                # Use larger chunks for better context (never smaller than the
                # service's own defaults).
                entity_chunk_size = max(chunk_size, DEFAULT_CHUNK_SIZE)
                entity_chunk_overlap = max(chunk_overlap, DEFAULT_CHUNK_OVERLAP)
                # Split the source text into overlapping chunks.
                task_manager.update_task(
                    task_id,
                    message="Splitting text into chunks...",
                    progress=5
                )
                chunks = TextProcessor.split_text(
                    text,
                    chunk_size=entity_chunk_size,
                    overlap=entity_chunk_overlap
                )
                total_chunks = len(chunks)
                # Create the (empty) graph and record its id on the project
                # immediately, so it is known even if a later step fails.
                task_manager.update_task(
                    task_id,
                    message="Creating graph...",
                    progress=10
                )
                graph_id = builder.create_graph(name=graph_name)
                project.graph_id = graph_id
                ProjectManager.save_project(project)
                # Attach the ontology so extraction is constrained to its types.
                builder.set_ontology(graph_id, ontology)

                # Pass 1: entity extraction (maps onto 15%-55% of overall progress).
                def entity_progress_callback(msg, progress_ratio):
                    progress = 15 + int(progress_ratio * 40)  # 15% - 55%
                    task_manager.update_task(
                        task_id,
                        message=msg,
                        progress=progress
                    )

                task_manager.update_task(
                    task_id,
                    message=f"[Pass 1] Extracting entities from {total_chunks} chunks...",
                    progress=15
                )
                builder.extract_entities(
                    graph_id,
                    chunks,
                    progress_callback=entity_progress_callback
                )

                # Pass 2: relationship discovery (55%-90% of overall progress).
                def rel_progress_callback(msg, progress_ratio):
                    progress = 55 + int(progress_ratio * 35)  # 55% - 90%
                    task_manager.update_task(
                        task_id,
                        message=msg,
                        progress=progress
                    )

                task_manager.update_task(
                    task_id,
                    message="[Pass 2] Discovering relationships between entities...",
                    progress=55
                )
                builder.discover_relationships(
                    graph_id,
                    text,
                    progress_callback=rel_progress_callback
                )
                # Fetch the finished graph for the final task result.
                task_manager.update_task(
                    task_id,
                    message="Retrieving graph data...",
                    progress=95
                )
                graph_data = builder.get_graph_data(graph_id)
                # Persist graph data to disk so /data/<graph_id> can serve it
                # without the Zep fallback.
                project_dir = ProjectManager._get_project_dir(project_id)
                builder.save_graph_data(graph_id, project_dir)
                # Mark the project complete.
                project.status = ProjectStatus.GRAPH_COMPLETED
                ProjectManager.save_project(project)
                node_count = graph_data.get("node_count", 0)
                edge_count = graph_data.get("edge_count", 0)
                build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
                # Finish the task with a summary payload.
                task_manager.update_task(
                    task_id,
                    status=TaskStatus.COMPLETED,
                    message="Graph construction complete",
                    progress=100,
                    result={
                        "project_id": project_id,
                        "graph_id": graph_id,
                        "node_count": node_count,
                        "edge_count": edge_count,
                        "chunk_count": total_chunks
                    }
                )
            except Exception as e:
                # Record the failure on both the project and the task record.
                build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}")
                build_logger.debug(traceback.format_exc())
                project.status = ProjectStatus.FAILED
                project.error = str(e)
                ProjectManager.save_project(project)
                task_manager.update_task(
                    task_id,
                    status=TaskStatus.FAILED,
                    message=f"Build failed: {str(e)}",
                    error=traceback.format_exc()
                )

        # Run the build in a daemon thread so this request returns immediately.
        thread = threading.Thread(target=build_task, daemon=True)
        thread.start()
        return jsonify({
            "success": True,
            "data": {
                "project_id": project_id,
                "task_id": task_id,
                "message": "Graph build task started. Query progress via /task/{task_id}"
            }
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
# ============== 任务查询接口 ==============
@graph_bp.route('/task/<task_id>', methods=['GET'])
def get_task(task_id: str):
    """Return the status of a single background task; 404 when unknown."""
    task = TaskManager().get_task(task_id)
    if not task:
        return jsonify({
            "success": False,
            "error": f"Task not found: {task_id}"
        }), 404
    return jsonify({"success": True, "data": task.to_dict()})
@graph_bp.route('/tasks', methods=['GET'])
def list_tasks():
    """List every known background task with its current state."""
    serialized = [t.to_dict() for t in TaskManager().list_tasks()]
    return jsonify({
        "success": True,
        "data": serialized,
        "count": len(serialized)
    })
# ============== 图谱数据接口 ==============
@graph_bp.route('/data/<graph_id>', methods=['GET'])
def get_graph_data(graph_id: str):
    """
    Return the nodes and edges of a graph.

    Looks for on-disk data saved by the LLM builder first, then falls back
    to the Zep service when an API key is configured.
    """
    try:
        # Locate the first project that owns this graph id.
        owner = next(
            (p for p in ProjectManager.list_projects() if p.graph_id == graph_id),
            None
        )
        if owner is not None:
            project_dir = ProjectManager._get_project_dir(owner.project_id)
            graph_data = LLMGraphBuilderService.load_graph_data(project_dir)
            if graph_data:
                return jsonify({"success": True, "data": graph_data})
        # Fallback to Zep if graph data not on disk
        if Config.ZEP_API_KEY:
            builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
            return jsonify({
                "success": True,
                "data": builder.get_graph_data(graph_id)
            })
        return jsonify({
            "success": False,
            "error": f"Graph data not found for {graph_id}"
        }), 404
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
def delete_graph(graph_id: str):
    """Delete a graph stored in the Zep backend; 500 when no API key is configured."""
    try:
        api_key = Config.ZEP_API_KEY
        if not api_key:
            return jsonify({
                "success": False,
                "error": "ZEP_API_KEY is not configured"
            }), 500
        GraphBuilderService(api_key=api_key).delete_graph(graph_id)
        return jsonify({
            "success": True,
            "message": f"Graph deleted: {graph_id}"
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500