diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 39dba18..2eabc0e 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -12,6 +12,7 @@ from . import graph_bp from ..config import Config from ..services.ontology_generator import OntologyGenerator from ..services.graph_builder import GraphBuilderService +from ..services.llm_graph_builder import LLMGraphBuilderService from ..services.text_processor import TextProcessor from ..utils.file_parser import FileParser from ..utils.logger import get_logger @@ -282,17 +283,6 @@ def build_graph(): try: logger.info("=== 开始构建图谱 ===") - # 检查配置 - errors = [] - if not Config.ZEP_API_KEY: - errors.append("ZEP_API_KEY is not configured") - if errors: - logger.error(f"配置错误: {errors}") - return jsonify({ - "success": False, - "error": "Configuration error: " + "; ".join(errors) - }), 500 - # 解析请求 data = request.get_json() or {} project_id = data.get('project_id') @@ -374,16 +364,16 @@ def build_graph(): def build_task(): build_logger = get_logger('mirofish.build') try: - build_logger.info(f"[{task_id}] 开始构建图谱...") + build_logger.info(f"[{task_id}] 开始构建图谱 (LLM mode)...") task_manager.update_task( - task_id, + task_id, status=TaskStatus.PROCESSING, - message="Initializing graph build service..." + message="Initializing LLM graph build service..." 
) - - # 创建图谱构建服务 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) - + + # 创建 LLM 图谱构建服务(不需要 Zep) + builder = LLMGraphBuilderService() + # 分块 task_manager.update_task( task_id, @@ -391,71 +381,48 @@ def build_graph(): progress=5 ) chunks = TextProcessor.split_text( - text, - chunk_size=chunk_size, + text, + chunk_size=chunk_size, overlap=chunk_overlap ) total_chunks = len(chunks) - + # 创建图谱 task_manager.update_task( task_id, - message="Creating Zep graph...", + message="Creating graph...", progress=10 ) graph_id = builder.create_graph(name=graph_name) - + # 更新项目的graph_id project.graph_id = graph_id ProjectManager.save_project(project) - + # 设置本体 - task_manager.update_task( - task_id, - message="Setting ontology definition...", - progress=15 - ) builder.set_ontology(graph_id, ontology) - - # 添加文本(progress_callback 签名是 (msg, progress_ratio)) - def add_progress_callback(msg, progress_ratio): - progress = 15 + int(progress_ratio * 40) # 15% - 55% + + # LLM extraction from chunks + def extract_progress_callback(msg, progress_ratio): + progress = 15 + int(progress_ratio * 75) # 15% - 90% task_manager.update_task( task_id, message=msg, progress=progress ) - + task_manager.update_task( task_id, - message=f"Adding {total_chunks} text chunks...", + message=f"Extracting entities from {total_chunks} chunks via LLM...", progress=15 ) - - episode_uuids = builder.add_text_batches( - graph_id, + + builder.extract_from_chunks( + graph_id, chunks, - batch_size=3, - progress_callback=add_progress_callback + progress_callback=extract_progress_callback ) - - # 等待Zep处理完成(查询每个episode的processed状态) - task_manager.update_task( - task_id, - message="Waiting for Zep to process data...", - progress=55 - ) - - def wait_progress_callback(msg, progress_ratio): - progress = 55 + int(progress_ratio * 35) # 55% - 90% - task_manager.update_task( - task_id, - message=msg, - progress=progress - ) - - builder._wait_for_episodes(episode_uuids, wait_progress_callback) - + # 获取图谱数据 
task_manager.update_task( task_id, @@ -463,15 +430,19 @@ def build_graph(): progress=95 ) graph_data = builder.get_graph_data(graph_id) - + + # Persist graph data to disk + project_dir = ProjectManager._get_project_dir(project_id) + builder.save_graph_data(graph_id, project_dir) + # 更新项目状态 project.status = ProjectStatus.GRAPH_COMPLETED ProjectManager.save_project(project) - + node_count = graph_data.get("node_count", 0) edge_count = graph_data.get("edge_count", 0) build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") - + # 完成 task_manager.update_task( task_id, @@ -486,16 +457,16 @@ def build_graph(): "chunk_count": total_chunks } ) - + except Exception as e: # 更新项目状态为失败 build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}") build_logger.debug(traceback.format_exc()) - + project.status = ProjectStatus.FAILED project.error = str(e) ProjectManager.save_project(project) - + task_manager.update_task( task_id, status=TaskStatus.FAILED, @@ -565,21 +536,36 @@ def list_tasks(): def get_graph_data(graph_id: str): """ 获取图谱数据(节点和边) + First tries disk (LLM builder), falls back to Zep if available. 
"""
LLM-based graph builder service.

Replaces Zep with direct LLM calls for entity/relationship extraction.
"""

import os
import uuid
import json
from typing import Dict, Any, List, Optional, Callable

from ..utils.llm_client import LLMClient
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor


# NOTE: this template is rendered with str.format(ontology_json=...), so every
# literal JSON brace in the example payload MUST be doubled ({{ / }}).
# Un-escaped braces would make str.format treat the example JSON as replacement
# fields and raise KeyError/ValueError before the LLM is ever called.
EXTRACT_SYSTEM_PROMPT = """You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, extract all entities and relationships.

ONTOLOGY SCHEMA:
{ontology_json}

RULES:
1. Extract entities that match the entity_types defined in the schema. Each entity needs: name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.
2. Extract relationships between entities that match the edge_types defined in the schema. Each relationship needs: name (the edge type name), source (entity name), target (entity name), and a fact (short description of the relationship).
3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.
4. Use consistent entity names across extractions (e.g., always "Mira" not sometimes "Mira" and sometimes "Mira the Socializer").
5. If no entities or relationships are found, return empty arrays.

Return JSON in this exact format:
{{
    "entities": [
        {{
            "name": "EntityName",
            "type": "EntityTypeName",
            "summary": "Brief description",
            "attributes": {{"attr_name": "attr_value"}}
        }}
    ],
    "relationships": [
        {{
            "name": "EDGE_TYPE_NAME",
            "source": "SourceEntityName",
            "target": "TargetEntityName",
            "fact": "Description of this relationship"
        }}
    ]
}}"""


class LLMGraphBuilderService:
    """
    Graph builder that uses direct LLM calls instead of Zep.

    Exposes the same interface as GraphBuilderService so it can act as a
    drop-in replacement. Graphs live in memory (per service instance);
    use save_graph_data() / load_graph_data() to persist to disk.
    """

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """
        Args:
            llm_client: optional pre-configured LLM client; a default
                LLMClient() is created when omitted.
        """
        self.llm = llm_client or LLMClient()
        self.task_manager = TaskManager()
        # In-memory graph storage, keyed by graph_id.
        self._graphs: Dict[str, Dict[str, Any]] = {}

    def create_graph(self, name: str) -> str:
        """Create an empty in-memory graph and return its generated id."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graphs[graph_id] = {
            "name": name,
            "ontology": None,
            "nodes": {},  # keyed by normalized (lower-cased) entity name
            "edges": [],
        }
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Attach an ontology schema to the graph (no-op for unknown ids)."""
        if graph_id in self._graphs:
            self._graphs[graph_id]["ontology"] = ontology

    def extract_from_chunks(
        self,
        graph_id: str,
        chunks: List[str],
        progress_callback: Optional[Callable] = None
    ):
        """
        Extract entities and relationships from text chunks using the LLM.

        Args:
            graph_id: id returned by create_graph(); raises KeyError if unknown.
            chunks: text chunks to process sequentially.
            progress_callback: optional (message, ratio) callable; ratio is in
                (0, 1] and reaches 1.0 on the final chunk.

        Per-chunk LLM failures are reported through the callback and skipped,
        so one bad chunk does not abort the whole build.
        """
        graph = self._graphs[graph_id]
        ontology = graph["ontology"]
        ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False)
        # Render once outside the loop; the template escapes its literal
        # braces so only {ontology_json} is substituted here.
        system_prompt = EXTRACT_SYSTEM_PROMPT.format(ontology_json=ontology_json)

        total = len(chunks)
        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback(
                    f"Extracting from chunk {i+1}/{total}...",
                    (i + 1) / total
                )

            try:
                result = self.llm.chat_json(
                    messages=[
                        {
                            "role": "system",
                            "content": system_prompt
                        },
                        {
                            "role": "user",
                            "content": f"Extract entities and relationships from this text:\n\n{chunk}"
                        }
                    ],
                    temperature=0.1,  # low temperature for deterministic extraction
                    max_tokens=4096
                )
                self._merge_extraction(graph_id, result)
            except Exception as e:
                # Best-effort: surface the error via progress and continue.
                if progress_callback:
                    progress_callback(f"Chunk {i+1} extraction error: {e}", (i + 1) / total)

    def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
        """
        Merge one LLM extraction result into the graph.

        Entities are deduplicated by lower-cased name (longer summaries win,
        attributes are merged without overwriting); relationships are
        deduplicated by (source, target, edge-name). Relationship endpoints
        that were never extracted as entities get placeholder "Entity" nodes.
        """
        graph = self._graphs[graph_id]
        nodes = graph["nodes"]
        edges = graph["edges"]

        # Valid entity type names from ontology
        valid_entity_types = set()
        if graph["ontology"]:
            for et in graph["ontology"].get("entity_types", []):
                valid_entity_types.add(et["name"])

        # Valid edge type names
        valid_edge_types = set()
        if graph["ontology"]:
            for et in graph["ontology"].get("edge_types", []):
                valid_edge_types.add(et["name"])

        # Merge entities
        for entity in result.get("entities", []):
            name = entity.get("name", "").strip()
            if not name:
                continue

            etype = entity.get("type", "Entity")
            key = name.lower()

            if key in nodes:
                # Update summary if the new one is longer (assumed richer).
                existing = nodes[key]
                new_summary = entity.get("summary", "")
                if new_summary and len(new_summary) > len(existing.get("summary", "")):
                    existing["summary"] = new_summary
                # Merge attributes without clobbering values already present.
                for k, v in entity.get("attributes", {}).items():
                    if v and not existing["attributes"].get(k):
                        existing["attributes"][k] = v
            else:
                # Unknown types are downgraded to the generic "Entity" label.
                labels = [etype] if etype in valid_entity_types else ["Entity"]
                nodes[key] = {
                    "uuid": str(uuid.uuid4()),
                    "name": name,
                    "labels": labels,
                    "summary": entity.get("summary", ""),
                    "attributes": entity.get("attributes", {}) or {},
                    "created_at": None,
                }

        # Merge relationships (deduplicate by source+target+name)
        existing_edges = set()
        for e in edges:
            existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))

        for rel in result.get("relationships", []):
            rel_name = rel.get("name", "").strip()
            source = rel.get("source", "").strip()
            target = rel.get("target", "").strip()
            if not rel_name or not source or not target:
                continue

            edge_key = (source.lower(), target.lower(), rel_name)
            if edge_key in existing_edges:
                continue
            existing_edges.add(edge_key)

            # Resolve node UUIDs; mint fresh ones for unseen endpoints.
            source_node = nodes.get(source.lower())
            target_node = nodes.get(target.lower())
            source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
            target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())

            # Create placeholder nodes if the endpoints don't exist yet.
            if not source_node:
                nodes[source.lower()] = {
                    "uuid": source_uuid,
                    "name": source,
                    "labels": ["Entity"],
                    "summary": "",
                    "attributes": {},
                    "created_at": None,
                }
            if not target_node:
                nodes[target.lower()] = {
                    "uuid": target_uuid,
                    "name": target,
                    "labels": ["Entity"],
                    "summary": "",
                    "attributes": {},
                    "created_at": None,
                }

            # Edge shape mirrors the Zep-based builder's output so downstream
            # consumers (frontend, persistence) need no changes.
            edges.append({
                "uuid": str(uuid.uuid4()),
                "name": rel_name,
                "fact": rel.get("fact", ""),
                "fact_type": rel_name,
                "source_node_uuid": source_uuid,
                "target_node_uuid": target_uuid,
                "source_node_name": source,
                "target_node_name": target,
                "attributes": {},
                "created_at": None,
                "valid_at": None,
                "invalid_at": None,
                "expired_at": None,
                "episodes": [],
            })

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """Return graph data in the same format as the Zep-based builder.

        Unknown graph_ids yield an empty graph rather than raising.
        """
        graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
        nodes_list = list(graph["nodes"].values())
        edges_list = graph["edges"]

        return {
            "graph_id": graph_id,
            "nodes": nodes_list,
            "edges": edges_list,
            "node_count": len(nodes_list),
            "edge_count": len(edges_list),
        }

    def save_graph_data(self, graph_id: str, project_dir: str) -> str:
        """Persist graph data to graph_data.json in the project directory.

        Returns the path written. Creates the directory if it is missing.
        """
        data = self.get_graph_data(graph_id)
        os.makedirs(project_dir, exist_ok=True)  # tolerate a missing project dir
        path = os.path.join(project_dir, "graph_data.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return path

    @staticmethod
    def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
        """Load persisted graph data from disk; None when no file exists."""
        path = os.path.join(project_dir, "graph_data.json")
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        return None

    def delete_graph(self, graph_id: str):
        """Drop the in-memory graph; silently ignores unknown ids."""
        self._graphs.pop(graph_id, None)