628 lines
20 KiB
Python
628 lines
20 KiB
Python
"""
|
||
Graph-related API routes.
Uses a project-context mechanism; state is persisted on the server side.
|
||
"""
|
||
|
||
import os
|
||
import traceback
|
||
import threading
|
||
from flask import request, jsonify
|
||
|
||
from . import graph_bp
|
||
from ..config import Config
|
||
from ..services.ontology_generator import OntologyGenerator
|
||
from ..services.graph_builder import GraphBuilderService
|
||
from ..services.llm_graph_builder import LLMGraphBuilderService
|
||
from ..services.text_processor import TextProcessor
|
||
from ..utils.file_parser import FileParser
|
||
from ..utils.logger import get_logger
|
||
from ..models.task import TaskManager, TaskStatus
|
||
from ..models.project import ProjectManager, ProjectStatus
|
||
|
||
# 获取日志器
|
||
logger = get_logger('mirofish.api')
|
||
|
||
|
||
def allowed_file(filename: str) -> bool:
    """Return whether ``filename`` has an extension the upload API accepts.

    A falsy name, or a name that contains no dot at all, is rejected
    immediately. Otherwise the final extension is lower-cased, stripped of
    its leading dot and looked up in ``Config.ALLOWED_EXTENSIONS``.
    """
    if filename and '.' in filename:
        _, suffix = os.path.splitext(filename)
        normalized = suffix.lower().lstrip('.')
        return normalized in Config.ALLOWED_EXTENSIONS
    return False
|
||
|
||
|
||
# ============== 项目管理接口 ==============
|
||
|
||
@graph_bp.route('/project/<project_id>', methods=['GET'])
def get_project(project_id: str):
    """Return the stored details of a single project.

    Responds with the project's ``to_dict()`` payload, or with a 404 error
    body when ``ProjectManager.get_project`` yields nothing for
    ``project_id``.
    """
    found = ProjectManager.get_project(project_id)
    if found:
        payload = {
            "success": True,
            "data": found.to_dict()
        }
        return jsonify(payload)

    failure = {
        "success": False,
        "error": f"Project not found: {project_id}"
    }
    return jsonify(failure), 404
|
||
|
||
|
||
@graph_bp.route('/project/list', methods=['GET'])
def list_projects():
    """List projects, honouring the optional ``limit`` query parameter.

    ``limit`` is read from the query string with ``type=int`` and a default
    of 50, then forwarded to ``ProjectManager.list_projects``. The response
    carries each project's ``to_dict()`` output plus the number of projects
    returned.
    """
    max_items = request.args.get('limit', 50, type=int)
    found = ProjectManager.list_projects(limit=max_items)

    serialized = [item.to_dict() for item in found]
    return jsonify({
        "success": True,
        "data": serialized,
        "count": len(found)
    })
|
||
|
||
|
||
@graph_bp.route('/project/<project_id>', methods=['DELETE'])
def delete_project(project_id: str):
    """Delete a project.

    Returns a success message when ``ProjectManager.delete_project`` reports
    success, otherwise a 404 error body (the manager's falsy result does not
    distinguish "not found" from "delete failed").
    """
    deleted = ProjectManager.delete_project(project_id)
    if deleted:
        return jsonify({
            "success": True,
            "message": f"Project deleted: {project_id}"
        })

    return jsonify({
        "success": False,
        "error": f"Project not found or delete failed: {project_id}"
    }), 404
|
||
|
||
|
||
@graph_bp.route('/project/<project_id>/reset', methods=['POST'])
def reset_project(project_id: str):
    """Reset a project's state so its graph can be built again.

    The status goes back to ``ONTOLOGY_GENERATED`` when the project already
    has an ontology and to ``CREATED`` otherwise; the graph id, build task id
    and error are cleared and the project is saved. Responds 404 when the
    project does not exist.
    """
    target = ProjectManager.get_project(project_id)
    if not target:
        return jsonify({
            "success": False,
            "error": f"Project not found: {project_id}"
        }), 404

    # Roll back to the furthest stage that is still valid without a graph.
    target.status = (
        ProjectStatus.ONTOLOGY_GENERATED
        if target.ontology
        else ProjectStatus.CREATED
    )

    # Drop everything that belonged to the previous graph build.
    target.graph_id = None
    target.graph_build_task_id = None
    target.error = None
    ProjectManager.save_project(target)

    response_body = {
        "success": True,
        "message": f"Project reset: {project_id}",
        "data": target.to_dict()
    }
    return jsonify(response_body)
|
||
|
||
|
||
# ============== 接口1:上传文件并生成本体 ==============
|
||
|
||
@graph_bp.route('/ontology/generate', methods=['POST'])
def generate_ontology():
    """
    Endpoint 1: upload documents and generate an ontology definition from them.

    Request: multipart/form-data

    Form fields:
        files: uploaded documents (PDF/MD/TXT per the original notes; the
            accepted extensions are whatever ``Config.ALLOWED_EXTENSIONS``
            lists), one or more
        simulation_requirement: description of the simulation requirement (required)
        project_name: project name (optional, defaults to "Unnamed Project")
        additional_context: extra notes passed to the generator (optional)

    Returns:
        {
            "success": true,
            "data": {
                "project_id": "proj_xxxx",
                "project_name": "...",
                "ontology": {
                    "entity_types": [...],
                    "edge_types": [...]
                },
                "analysis_summary": "...",
                "files": [...],
                "total_text_length": 12345
            }
        }

    Error responses:
        400 when ``simulation_requirement`` is missing, when no file was
        uploaded, or when none of the uploads has an allowed extension;
        500 (with the error text and traceback) on any unexpected exception.
    """
    try:
        logger.info("=== 开始生成本体定义 ===")

        # Read the form fields.
        simulation_requirement = request.form.get('simulation_requirement', '')
        project_name = request.form.get('project_name', 'Unnamed Project')
        additional_context = request.form.get('additional_context', '')

        logger.debug(f"项目名称: {project_name}")
        logger.debug(f"模拟需求: {simulation_requirement[:100]}...")

        if not simulation_requirement:
            return jsonify({
                "success": False,
                "error": "Please provide simulation_requirement"
            }), 400

        # Collect the uploaded files.
        uploaded_files = request.files.getlist('files')
        if not uploaded_files or all(not f.filename for f in uploaded_files):
            return jsonify({
                "success": False,
                "error": "Please upload at least one document file"
            }), 400

        # Validate the uploads BEFORE creating the project, so a request with
        # only unsupported files no longer creates a project just to delete it.
        accepted_files = [
            f for f in uploaded_files
            if f and f.filename and allowed_file(f.filename)
        ]
        if not accepted_files:
            return jsonify({
                "success": False,
                "error": "No documents were processed successfully, please check file format"
            }), 400

        # Create the project.
        project = ProjectManager.create_project(name=project_name)
        project.simulation_requirement = simulation_requirement
        logger.info(f"创建项目: {project.project_id}")

        # Save each file into the project and extract its text.
        document_texts = []
        text_sections = []

        for file in accepted_files:
            file_info = ProjectManager.save_file_to_project(
                project.project_id,
                file,
                file.filename
            )
            project.files.append({
                "filename": file_info["original_filename"],
                "size": file_info["size"]
            })

            text = FileParser.extract_text(file_info["path"])
            text = TextProcessor.preprocess_text(text)
            document_texts.append(text)
            text_sections.append(
                f"\n\n=== {file_info['original_filename']} ===\n{text}"
            )

        # One join instead of repeated string concatenation in the loop.
        all_text = "".join(text_sections)

        # Persist the combined extracted text.
        project.total_text_length = len(all_text)
        ProjectManager.save_extracted_text(project.project_id, all_text)
        logger.info(f"文本提取完成,共 {len(all_text)} 字符")

        # Generate the ontology.
        logger.info("调用 LLM 生成本体定义...")
        generator = OntologyGenerator()
        ontology = generator.generate(
            document_texts=document_texts,
            simulation_requirement=simulation_requirement,
            additional_context=additional_context if additional_context else None
        )

        # Store the ontology on the project.
        entity_count = len(ontology.get("entity_types", []))
        edge_count = len(ontology.get("edge_types", []))
        logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型")

        project.ontology = {
            "entity_types": ontology.get("entity_types", []),
            "edge_types": ontology.get("edge_types", [])
        }
        project.analysis_summary = ontology.get("analysis_summary", "")
        project.status = ProjectStatus.ONTOLOGY_GENERATED
        ProjectManager.save_project(project)
        logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}")

        return jsonify({
            "success": True,
            "data": {
                "project_id": project.project_id,
                "project_name": project.name,
                "ontology": project.ontology,
                "analysis_summary": project.analysis_summary,
                "files": project.files,
                "total_text_length": project.total_text_length
            }
        })

    except Exception as e:
        # Record the failure server-side; it used to be visible only in the
        # HTTP response.
        logger.error(f"Ontology generation failed: {str(e)}")
        logger.debug(traceback.format_exc())

        # NOTE(review): the traceback is still returned to the client to keep
        # the response shape unchanged; consider dropping it in production.
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== 接口2:构建图谱 ==============
|
||
|
||
@graph_bp.route('/build', methods=['POST'])
def build_graph():
    """
    Endpoint 2: build the graph for an existing project.

    Request (JSON):
        {
            "project_id": "proj_xxxx",  // required, returned by endpoint 1
            "graph_name": "...",        // optional, defaults to the project name
            "chunk_size": 500,          // optional, defaults to the project's stored
                                        // value or Config.DEFAULT_CHUNK_SIZE
            "chunk_overlap": 50,        // optional, defaults to the project's stored
                                        // value or Config.DEFAULT_CHUNK_OVERLAP
            "force": false              // optional, restart even if a build is running
        }

    Returns:
        {
            "success": true,
            "data": {
                "project_id": "proj_xxxx",
                "task_id": "task_xxxx",
                "message": "Graph build task started. ..."
            }
        }

    The build itself runs in a daemon thread; progress and the final result
    are reported through the task created here.

    Error responses:
        400 for a missing project_id, non-integer chunk settings, a project
        without an ontology / extracted text, or a build already in progress
        (unless ``force`` is set); 404 for an unknown project; 500 (with the
        error text and traceback) on any unexpected exception.
    """
    try:
        logger.info("=== 开始构建图谱 ===")

        # Parse the request. silent=True makes a missing/invalid JSON body
        # yield None instead of raising, so it is reported by the 400
        # validation below rather than by the catch-all 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        project_id = data.get('project_id')
        logger.debug(f"请求参数: project_id={project_id}")

        if not project_id:
            return jsonify({
                "success": False,
                "error": "Please provide project_id"
            }), 400

        # Load the project.
        project = ProjectManager.get_project(project_id)
        if not project:
            return jsonify({
                "success": False,
                "error": f"Project not found: {project_id}"
            }), 404

        # Check the project status.
        force = data.get('force', False)  # force a rebuild

        if project.status == ProjectStatus.CREATED:
            return jsonify({
                "success": False,
                "error": "Ontology not yet generated for this project. Please call /ontology/generate first"
            }), 400

        if project.status == ProjectStatus.GRAPH_BUILDING and not force:
            return jsonify({
                "success": False,
                "error": "Graph is currently being built. To force rebuild, add force: true",
                "task_id": project.graph_build_task_id
            }), 400

        # A forced rebuild starts again from the "ontology generated" state.
        if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
            project.status = ProjectStatus.ONTOLOGY_GENERATED
            project.graph_id = None
            project.graph_build_task_id = None
            project.error = None

        # Resolve the build configuration.
        graph_name = data.get('graph_name', project.name or 'MiroFish Graph')

        # An explicit JSON null is treated like an absent key.
        chunk_size = data.get('chunk_size')
        if chunk_size is None:
            chunk_size = project.chunk_size or Config.DEFAULT_CHUNK_SIZE
        chunk_overlap = data.get('chunk_overlap')
        if chunk_overlap is None:
            chunk_overlap = project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP

        # Reject non-numeric values here; previously they only surfaced as a
        # TypeError inside the background thread (max() below).
        try:
            chunk_size = int(chunk_size)
            chunk_overlap = int(chunk_overlap)
        except (TypeError, ValueError, OverflowError):
            return jsonify({
                "success": False,
                "error": "chunk_size and chunk_overlap must be integers"
            }), 400

        # Record the configuration on the project.
        project.chunk_size = chunk_size
        project.chunk_overlap = chunk_overlap

        # Load the previously extracted text.
        text = ProjectManager.get_extracted_text(project_id)
        if not text:
            return jsonify({
                "success": False,
                "error": "Extracted text content not found"
            }), 400

        # Load the ontology.
        ontology = project.ontology
        if not ontology:
            return jsonify({
                "success": False,
                "error": "Ontology definition not found"
            }), 400

        # Create the asynchronous task.
        task_manager = TaskManager()
        task_id = task_manager.create_task(f"Build graph: {graph_name}")
        logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}")

        # Mark the project as building and persist it before the thread starts.
        project.status = ProjectStatus.GRAPH_BUILDING
        project.graph_build_task_id = task_id
        ProjectManager.save_project(project)

        def build_task():
            """Background job: chunk the text, build the graph, report progress."""
            build_logger = get_logger('mirofish.build')
            try:
                build_logger.info(f"[{task_id}] 开始构建图谱 (LLM mode)...")
                task_manager.update_task(
                    task_id,
                    status=TaskStatus.PROCESSING,
                    message="Initializing LLM graph build service..."
                )

                # Create the LLM graph build service (no Zep required).
                from ..services.llm_graph_builder import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
                builder = LLMGraphBuilderService()

                # Use larger chunks for better context: never go below the
                # builder's own defaults.
                entity_chunk_size = max(chunk_size, DEFAULT_CHUNK_SIZE)
                entity_chunk_overlap = max(chunk_overlap, DEFAULT_CHUNK_OVERLAP)

                # Split the text into chunks.
                task_manager.update_task(
                    task_id,
                    message="Splitting text into chunks...",
                    progress=5
                )
                chunks = TextProcessor.split_text(
                    text,
                    chunk_size=entity_chunk_size,
                    overlap=entity_chunk_overlap
                )
                total_chunks = len(chunks)

                # Create the graph.
                task_manager.update_task(
                    task_id,
                    message="Creating graph...",
                    progress=10
                )
                graph_id = builder.create_graph(name=graph_name)

                # Store the graph id on the project.
                project.graph_id = graph_id
                ProjectManager.save_project(project)

                # Attach the ontology.
                builder.set_ontology(graph_id, ontology)

                # Pass 1: entity extraction (mapped onto 15% - 55%).
                def entity_progress_callback(msg, progress_ratio):
                    progress = 15 + int(progress_ratio * 40)
                    task_manager.update_task(
                        task_id,
                        message=msg,
                        progress=progress
                    )

                task_manager.update_task(
                    task_id,
                    message=f"[Pass 1] Extracting entities from {total_chunks} chunks...",
                    progress=15
                )

                builder.extract_entities(
                    graph_id,
                    chunks,
                    progress_callback=entity_progress_callback
                )

                # Pass 2: relationship discovery (mapped onto 55% - 90%).
                def rel_progress_callback(msg, progress_ratio):
                    progress = 55 + int(progress_ratio * 35)
                    task_manager.update_task(
                        task_id,
                        message=msg,
                        progress=progress
                    )

                task_manager.update_task(
                    task_id,
                    message="[Pass 2] Discovering relationships between entities...",
                    progress=55
                )

                builder.discover_relationships(
                    graph_id,
                    text,
                    progress_callback=rel_progress_callback
                )

                # Fetch the resulting graph data.
                task_manager.update_task(
                    task_id,
                    message="Retrieving graph data...",
                    progress=95
                )
                graph_data = builder.get_graph_data(graph_id)

                # Persist graph data to disk.
                project_dir = ProjectManager._get_project_dir(project_id)
                builder.save_graph_data(graph_id, project_dir)

                # Mark the project as completed.
                project.status = ProjectStatus.GRAPH_COMPLETED
                ProjectManager.save_project(project)

                node_count = graph_data.get("node_count", 0)
                edge_count = graph_data.get("edge_count", 0)
                build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")

                # Done.
                task_manager.update_task(
                    task_id,
                    status=TaskStatus.COMPLETED,
                    message="Graph construction complete",
                    progress=100,
                    result={
                        "project_id": project_id,
                        "graph_id": graph_id,
                        "node_count": node_count,
                        "edge_count": edge_count,
                        "chunk_count": total_chunks
                    }
                )

            except Exception as e:
                # Mark both the project and the task as failed.
                build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}")
                build_logger.debug(traceback.format_exc())

                project.status = ProjectStatus.FAILED
                project.error = str(e)
                ProjectManager.save_project(project)

                task_manager.update_task(
                    task_id,
                    status=TaskStatus.FAILED,
                    message=f"Build failed: {str(e)}",
                    error=traceback.format_exc()
                )

        # Start the background thread.
        thread = threading.Thread(target=build_task, daemon=True)
        thread.start()

        return jsonify({
            "success": True,
            "data": {
                "project_id": project_id,
                "task_id": task_id,
                "message": "Graph build task started. Query progress via /task/{task_id}"
            }
        })

    except Exception as e:
        # Record the failure server-side; it used to be visible only in the
        # HTTP response.
        logger.error(f"Failed to start graph build: {str(e)}")
        logger.debug(traceback.format_exc())

        # NOTE(review): the traceback is still returned to the client to keep
        # the response shape unchanged; consider dropping it in production.
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== 任务查询接口 ==============
|
||
|
||
@graph_bp.route('/task/<task_id>', methods=['GET'])
def get_task(task_id: str):
    """Return the current state of one task.

    Responds with the task's ``to_dict()`` payload, or with a 404 error body
    when ``TaskManager`` has no task under ``task_id``.
    """
    record = TaskManager().get_task(task_id)
    if record:
        return jsonify({
            "success": True,
            "data": record.to_dict()
        })

    return jsonify({
        "success": False,
        "error": f"Task not found: {task_id}"
    }), 404
|
||
|
||
|
||
@graph_bp.route('/tasks', methods=['GET'])
def list_tasks():
    """List every task known to ``TaskManager``.

    The response carries each task's ``to_dict()`` output plus the number of
    tasks returned.
    """
    known_tasks = TaskManager().list_tasks()
    serialized = [entry.to_dict() for entry in known_tasks]

    return jsonify({
        "success": True,
        "data": serialized,
        "count": len(known_tasks)
    })
|
||
|
||
|
||
# ============== 图谱数据接口 ==============
|
||
|
||
@graph_bp.route('/data/<graph_id>', methods=['GET'])
def get_graph_data(graph_id: str):
    """Return the graph data (nodes and edges) for ``graph_id``.

    The graph data saved on disk for the owning project (LLM builder) is
    tried first. If no project owns the graph, or nothing could be loaded,
    the request falls back to Zep when ``Config.ZEP_API_KEY`` is set.
    Responds 404 when neither source applies and 500 (with the error text and
    traceback) on any exception.
    """
    try:
        # Locate the first project that owns this graph_id.
        # NOTE(review): list_projects() is called without a limit here -
        # confirm its default does not cap the search to recent projects.
        owner = next(
            (
                candidate
                for candidate in ProjectManager.list_projects()
                if candidate.graph_id == graph_id
            ),
            None,
        )

        if owner is not None:
            stored = LLMGraphBuilderService.load_graph_data(
                ProjectManager._get_project_dir(owner.project_id)
            )
            if stored:
                return jsonify({
                    "success": True,
                    "data": stored
                })

        # Nothing on disk and no Zep credentials: the graph cannot be served.
        if not Config.ZEP_API_KEY:
            return jsonify({
                "success": False,
                "error": f"Graph data not found for {graph_id}"
            }), 404

        # Fall back to Zep.
        zep_builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
        remote_data = zep_builder.get_graph_data(graph_id)
        return jsonify({
            "success": True,
            "data": remote_data
        })

    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
def delete_graph(graph_id: str):
    """Delete a Zep graph.

    Requires ``Config.ZEP_API_KEY``; without it the endpoint answers with a
    500 error body. Any exception raised while deleting is likewise reported
    as a 500 together with its traceback.
    """
    try:
        if Config.ZEP_API_KEY:
            zep_builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
            zep_builder.delete_graph(graph_id)
            return jsonify({
                "success": True,
                "message": f"Graph deleted: {graph_id}"
            })

        return jsonify({
            "success": False,
            "error": "ZEP_API_KEY is not configured"
        }), 500

    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|