From 10a85e76d6c84d77f8507a8393e0576f3d7596d1 Mon Sep 17 00:00:00 2001 From: _Yusaki Date: Fri, 13 Mar 2026 19:30:30 +0700 Subject: [PATCH] Two-pass graph extraction: entities then relationships with larger chunks --- backend/app/api/graph.py | 42 ++++- backend/app/services/llm_graph_builder.py | 200 +++++++++++++++------- 2 files changed, 175 insertions(+), 67 deletions(-) diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 74a902d..ae8da01 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -372,8 +372,13 @@ def build_graph(): ) # 创建 LLM 图谱构建服务(不需要 Zep) + from ..services.llm_graph_builder import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP builder = LLMGraphBuilderService() + # Use larger chunks for better context + entity_chunk_size = max(chunk_size, DEFAULT_CHUNK_SIZE) + entity_chunk_overlap = max(chunk_overlap, DEFAULT_CHUNK_OVERLAP) + # 分块 task_manager.update_task( task_id, @@ -382,8 +387,8 @@ def build_graph(): ) chunks = TextProcessor.split_text( text, - chunk_size=chunk_size, - overlap=chunk_overlap + chunk_size=entity_chunk_size, + overlap=entity_chunk_overlap ) total_chunks = len(chunks) @@ -402,9 +407,9 @@ def build_graph(): # 设置本体 builder.set_ontology(graph_id, ontology) - # LLM extraction from chunks - def extract_progress_callback(msg, progress_ratio): - progress = 15 + int(progress_ratio * 75) # 15% - 90% + # Pass 1: Entity extraction + def entity_progress_callback(msg, progress_ratio): + progress = 15 + int(progress_ratio * 40) # 15% - 55% task_manager.update_task( task_id, message=msg, @@ -413,14 +418,35 @@ def build_graph(): task_manager.update_task( task_id, - message=f"Extracting entities from {total_chunks} chunks via LLM...", + message=f"[Pass 1] Extracting entities from {total_chunks} chunks...", progress=15 ) - builder.extract_from_chunks( + builder.extract_entities( graph_id, chunks, - progress_callback=extract_progress_callback + progress_callback=entity_progress_callback + ) + + # Pass 2: Relationship discovery + def rel_progress_callback(msg, progress_ratio): + progress = 55 + int(progress_ratio * 35) # 55% - 90% + task_manager.update_task( + task_id, + message=msg, + progress=progress + ) + + task_manager.update_task( + task_id, + message="[Pass 2] Discovering relationships between entities...", + progress=55 + ) + + builder.discover_relationships( + graph_id, + text, + progress_callback=rel_progress_callback ) # 获取图谱数据 diff --git a/backend/app/services/llm_graph_builder.py b/backend/app/services/llm_graph_builder.py index 021503d..d331175 100644 --- a/backend/app/services/llm_graph_builder.py +++ b/backend/app/services/llm_graph_builder.py @@ -1,12 +1,14 @@ """ LLM-based graph builder service -Replaces Zep with direct LLM calls for entity/relationship extraction +Replaces Zep with direct LLM calls for entity/relationship extraction. +Two-pass approach: (1) extract entities, (2) discover relationships. """ import os import uuid import json import logging +import traceback from typing import Dict, Any, List, Optional, Callable from ..utils.llm_client import LLMClient @@ -15,35 +17,50 @@ from .text_processor import TextProcessor logger = logging.getLogger('mirofish.llm_graph_builder') +# Default chunk size — larger than Zep's default to capture more context per call +DEFAULT_CHUNK_SIZE = 2500 +DEFAULT_CHUNK_OVERLAP = 200 -EXTRACT_SYSTEM_PROMPT_TEMPLATE = ( - "You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, " - "extract all entities and relationships.\n\n" - "ONTOLOGY SCHEMA:\n%s\n\n" +ENTITY_EXTRACT_PROMPT = ( + "You are a knowledge graph entity extraction engine. Given a text chunk and an ontology schema, " + "extract all entities mentioned in the text.\n\n" + "ONTOLOGY SCHEMA (entity types only):\n%s\n\n" "RULES:\n" - "1. Extract entities that match the entity_types defined in the schema. Each entity needs: " - "name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.\n" - "2. Extract relationships between entities that match the edge_types defined in the schema. " - "Each relationship needs: name (the edge type name), source (entity name), target (entity name), " - "and a fact (short description of the relationship).\n" - "3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.\n" - "4. Use consistent entity names across extractions.\n" - "5. If no entities or relationships are found, return empty arrays.\n\n" - 'Return JSON with keys "entities" (array of objects with name, type, summary, attributes) ' - 'and "relationships" (array of objects with name, source, target, fact).' + "1. Extract entities that match the entity_types defined above.\n" + "2. Each entity needs: name (the canonical name used in the text), type (must match an entity_type name from the schema), " + "summary (1-2 sentences describing the entity based on the text), and attributes (fill in any attributes defined for that type).\n" + "3. Only extract entities explicitly mentioned or strongly implied in the text.\n" + "4. Use the exact name as it appears in the text (e.g. 'Mira' not 'Mira the Socializer').\n" + "5. If no entities are found, return an empty array.\n\n" + 'Return JSON: a single key "entities" with an array of objects, each having keys: name, type, summary, attributes.' +) + +RELATIONSHIP_EXTRACT_PROMPT = ( + "You are a knowledge graph relationship extraction engine. Given a text section, a list of known entities, " + "and an ontology schema, extract all relationships between the entities.\n\n" + "ONTOLOGY SCHEMA (edge types):\n%s\n\n" + "KNOWN ENTITIES:\n%s\n\n" + "RULES:\n" + "1. Find relationships between the known entities that match the edge_types defined above.\n" + "2. Each relationship needs: name (must match an edge_type name), source (entity name), target (entity name), " + "fact (1 sentence describing the specific relationship found in the text).\n" + "3. Both source and target MUST be from the known entities list.\n" + "4. Only extract relationships explicitly stated or strongly implied in the text.\n" + "5. Extract ALL relationships you can find — be thorough.\n" + "6. If no relationships are found, return an empty array.\n\n" + 'Return JSON: a single key "relationships" with an array of objects, each having keys: name, source, target, fact.' ) class LLMGraphBuilderService: """ - Graph builder that uses direct LLM calls instead of Zep. - Same interface as GraphBuilderService for drop-in replacement. + Graph builder using direct LLM calls instead of Zep. + Two-pass extraction: entities first, then relationships. """ def __init__(self, llm_client: Optional[LLMClient] = None): self.llm = llm_client or LLMClient() self.task_manager = TaskManager() - # In-memory graph storage (keyed by graph_id) self._graphs: Dict[str, Dict[str, Any]] = {} def create_graph(self, name: str) -> str: @@ -51,7 +68,7 @@ class LLMGraphBuilderService: self._graphs[graph_id] = { "name": name, "ontology": None, - "nodes": {}, # keyed by normalized name + "nodes": {}, "edges": [], } return graph_id @@ -60,16 +77,22 @@ class LLMGraphBuilderService: if graph_id in self._graphs: self._graphs[graph_id]["ontology"] = ontology - def extract_from_chunks( + # ── Pass 1: Entity Extraction ── + + def extract_entities( self, graph_id: str, chunks: List[str], progress_callback: Optional[Callable] = None ): - """Extract entities and relationships from text chunks using LLM.""" + """Pass 1: Extract entities from each chunk.""" graph = self._graphs[graph_id] ontology = graph["ontology"] - ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False) + + # Build entity-types-only schema for the prompt + entity_types_json = json.dumps( + ontology.get("entity_types", []), indent=2, ensure_ascii=False + ) total = len(chunks) success_count = 0 @@ -79,65 +102,125 @@ class LLMGraphBuilderService: for i, chunk in enumerate(chunks): if progress_callback: progress_callback( - f"Extracting from chunk {i+1}/{total} (ok={success_count}, fail={fail_count})...", + f"[Pass 1] Extracting entities from chunk {i+1}/{total} " + f"(ok={success_count}, fail={fail_count})...", (i + 1) / total ) try: - logger.info(f"Extracting chunk {i+1}/{total} ({len(chunk)} chars)") + logger.info(f"[Pass 1] Entity extraction chunk {i+1}/{total} ({len(chunk)} chars)") result = self.llm.chat_json( messages=[ { "role": "system", - "content": EXTRACT_SYSTEM_PROMPT_TEMPLATE % ontology_json + "content": ENTITY_EXTRACT_PROMPT % entity_types_json }, { "role": "user", - "content": f"Extract entities and relationships from this text:\n\n{chunk}" + "content": f"Extract all entities from this text:\n\n{chunk}" } ], temperature=0.1, max_tokens=4096 ) entities = result.get("entities", []) - rels = result.get("relationships", []) - logger.info(f"Chunk {i+1}: extracted {len(entities)} entities, {len(rels)} relationships") - self._merge_extraction(graph_id, result) + logger.info(f"[Pass 1] Chunk {i+1}: {len(entities)} entities") + self._merge_entities(graph_id, entities) success_count += 1 except Exception as e: fail_count += 1 last_error = e - import traceback - logger.error(f"Chunk {i+1} extraction failed: {type(e).__name__}: {e}") - logger.error(f"Chunk {i+1} traceback: {traceback.format_exc()}") - if progress_callback: - progress_callback(f"Chunk {i+1} error: {e}", (i + 1) / total) + logger.error(f"[Pass 1] Chunk {i+1} failed: {type(e).__name__}: {e}") + logger.debug(traceback.format_exc()) - logger.info(f"Extraction complete: {success_count}/{total} succeeded, {fail_count} failed") + logger.info(f"[Pass 1] Complete: {success_count}/{total} succeeded, " + f"{len(graph['nodes'])} unique entities found") if success_count == 0 and total > 0: - raise RuntimeError(f"All {total} chunks failed extraction. Last error: {last_error}") + raise RuntimeError(f"All {total} entity extraction calls failed. Last error: {last_error}") - def _merge_extraction(self, graph_id: str, result: Dict[str, Any]): - """Merge extracted entities/relationships into the graph, deduplicating by name.""" + # ── Pass 2: Relationship Discovery ── + + def discover_relationships( + self, + graph_id: str, + full_text: str, + progress_callback: Optional[Callable] = None + ): + """Pass 2: Find relationships between known entities using larger text windows.""" + graph = self._graphs[graph_id] + ontology = graph["ontology"] + nodes = graph["nodes"] + + if not nodes: + logger.warning("[Pass 2] No entities to find relationships for") + return + + # Build edge types schema + edge_types_json = json.dumps( + ontology.get("edge_types", []), indent=2, ensure_ascii=False + ) + + # Build entity list string + entity_names = sorted(set(n["name"] for n in nodes.values())) + entity_list = "\n".join(f"- {name} ({nodes[name.lower()]['labels'][0]})" + for name in entity_names if name.lower() in nodes) + + # Use larger chunks for relationship discovery (5000 chars, 500 overlap) + rel_chunks = TextProcessor.split_text(full_text, chunk_size=5000, overlap=500) + total = len(rel_chunks) + success_count = 0 + fail_count = 0 + + for i, chunk in enumerate(rel_chunks): + if progress_callback: + progress_callback( + f"[Pass 2] Finding relationships in section {i+1}/{total} " + f"(edges so far: {len(graph['edges'])})...", + (i + 1) / total + ) + + try: + logger.info(f"[Pass 2] Relationship discovery section {i+1}/{total} ({len(chunk)} chars)") + result = self.llm.chat_json( + messages=[ + { + "role": "system", + "content": RELATIONSHIP_EXTRACT_PROMPT % (edge_types_json, entity_list) + }, + { + "role": "user", + "content": f"Find all relationships between the known entities in this text:\n\n{chunk}" + } + ], + temperature=0.1, + max_tokens=4096 + ) + rels = result.get("relationships", []) + logger.info(f"[Pass 2] Section {i+1}: {len(rels)} relationships") + self._merge_relationships(graph_id, rels) + success_count += 1 + except Exception as e: + fail_count += 1 + logger.error(f"[Pass 2] Section {i+1} failed: {type(e).__name__}: {e}") + logger.debug(traceback.format_exc()) + + logger.info(f"[Pass 2] Complete: {success_count}/{total} succeeded, " + f"{len(graph['edges'])} total edges") + + # ── Merge Helpers ── + + def _merge_entities(self, graph_id: str, entities: List[Dict[str, Any]]): + """Merge extracted entities into the graph.""" graph = self._graphs[graph_id] nodes = graph["nodes"] - edges = graph["edges"] - # Valid entity type names from ontology valid_entity_types = set() if graph["ontology"]: for et in graph["ontology"].get("entity_types", []): valid_entity_types.add(et["name"]) - # Valid edge type names - valid_edge_types = set() - if graph["ontology"]: - for et in graph["ontology"].get("edge_types", []): - valid_edge_types.add(et["name"]) - - # Merge entities - for entity in result.get("entities", []): + for entity in entities: name = entity.get("name", "").strip() if not name: continue @@ -146,13 +229,11 @@ class LLMGraphBuilderService: key = name.lower() if key in nodes: - # Update summary if new one is longer existing = nodes[key] new_summary = entity.get("summary", "") if new_summary and len(new_summary) > len(existing.get("summary", "")): existing["summary"] = new_summary - # Merge attributes - for k, v in entity.get("attributes", {}).items(): + for k, v in (entity.get("attributes", {}) or {}).items(): if v and not existing["attributes"].get(k): existing["attributes"][k] = v else: @@ -166,12 +247,17 @@ class LLMGraphBuilderService: "created_at": None, } - # Merge relationships (deduplicate by source+target+name) + def _merge_relationships(self, graph_id: str, relationships: List[Dict[str, Any]]): + """Merge extracted relationships into the graph.""" + graph = self._graphs[graph_id] + nodes = graph["nodes"] + edges = graph["edges"] + existing_edges = set() for e in edges: existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"])) - for rel in result.get("relationships", []): + for rel in relationships: rel_name = rel.get("name", "").strip() source = rel.get("source", "").strip() target = rel.get("target", "").strip() @@ -183,13 +269,11 @@ class LLMGraphBuilderService: continue existing_edges.add(edge_key) - # Resolve node UUIDs source_node = nodes.get(source.lower()) target_node = nodes.get(target.lower()) source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4()) target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4()) - # Create placeholder nodes if they don't exist if not source_node: nodes[source.lower()] = { "uuid": source_uuid, @@ -226,12 +310,12 @@ class LLMGraphBuilderService: "episodes": [], }) + # ── Data Access ── + def get_graph_data(self, graph_id: str) -> Dict[str, Any]: - """Return graph data in the same format as the Zep-based builder.""" graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []}) nodes_list = list(graph["nodes"].values()) edges_list = graph["edges"] - return { "graph_id": graph_id, "nodes": nodes_list, @@ -241,7 +325,6 @@ class LLMGraphBuilderService: } def save_graph_data(self, graph_id: str, project_dir: str) -> str: - """Persist graph data to a JSON file in the project directory.""" data = self.get_graph_data(graph_id) path = os.path.join(project_dir, "graph_data.json") with open(path, "w", encoding="utf-8") as f: @@ -250,7 +333,6 @@ class LLMGraphBuilderService: @staticmethod def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]: - """Load persisted graph data from disk.""" path = os.path.join(project_dir, "graph_data.json") if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: