Two-pass graph extraction: entities then relationships with larger chunks

This commit is contained in:
_Yusaki 2026-03-13 19:30:30 +07:00
parent e806898018
commit 10a85e76d6
2 changed files with 175 additions and 67 deletions

View file

@ -372,8 +372,13 @@ def build_graph():
) )
# 创建 LLM 图谱构建服务(不需要 Zep # 创建 LLM 图谱构建服务(不需要 Zep
from ..services.llm_graph_builder import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
builder = LLMGraphBuilderService() builder = LLMGraphBuilderService()
# Use larger chunks for better context
entity_chunk_size = max(chunk_size, DEFAULT_CHUNK_SIZE)
entity_chunk_overlap = max(chunk_overlap, DEFAULT_CHUNK_OVERLAP)
# 分块 # 分块
task_manager.update_task( task_manager.update_task(
task_id, task_id,
@ -382,8 +387,8 @@ def build_graph():
) )
chunks = TextProcessor.split_text( chunks = TextProcessor.split_text(
text, text,
chunk_size=chunk_size, chunk_size=entity_chunk_size,
overlap=chunk_overlap overlap=entity_chunk_overlap
) )
total_chunks = len(chunks) total_chunks = len(chunks)
@ -402,9 +407,9 @@ def build_graph():
# 设置本体 # 设置本体
builder.set_ontology(graph_id, ontology) builder.set_ontology(graph_id, ontology)
# LLM extraction from chunks # Pass 1: Entity extraction
def extract_progress_callback(msg, progress_ratio): def entity_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 75) # 15% - 90% progress = 15 + int(progress_ratio * 40) # 15% - 55%
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=msg, message=msg,
@ -413,14 +418,35 @@ def build_graph():
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=f"Extracting entities from {total_chunks} chunks via LLM...", message=f"[Pass 1] Extracting entities from {total_chunks} chunks...",
progress=15 progress=15
) )
builder.extract_from_chunks( builder.extract_entities(
graph_id, graph_id,
chunks, chunks,
progress_callback=extract_progress_callback progress_callback=entity_progress_callback
)
# Pass 2: Relationship discovery
def rel_progress_callback(msg, progress_ratio):
progress = 55 + int(progress_ratio * 35) # 55% - 90%
task_manager.update_task(
task_id,
message=msg,
progress=progress
)
task_manager.update_task(
task_id,
message="[Pass 2] Discovering relationships between entities...",
progress=55
)
builder.discover_relationships(
graph_id,
text,
progress_callback=rel_progress_callback
) )
# 获取图谱数据 # 获取图谱数据

View file

@ -1,12 +1,14 @@
""" """
LLM-based graph builder service LLM-based graph builder service
Replaces Zep with direct LLM calls for entity/relationship extraction Replaces Zep with direct LLM calls for entity/relationship extraction.
Two-pass approach: (1) extract entities, (2) discover relationships.
""" """
import os import os
import uuid import uuid
import json import json
import logging import logging
import traceback
from typing import Dict, Any, List, Optional, Callable from typing import Dict, Any, List, Optional, Callable
from ..utils.llm_client import LLMClient from ..utils.llm_client import LLMClient
@ -15,35 +17,50 @@ from .text_processor import TextProcessor
logger = logging.getLogger('mirofish.llm_graph_builder') logger = logging.getLogger('mirofish.llm_graph_builder')
# Default chunk size — larger than Zep's default to capture more context per call
DEFAULT_CHUNK_SIZE = 2500
DEFAULT_CHUNK_OVERLAP = 200
EXTRACT_SYSTEM_PROMPT_TEMPLATE = ( ENTITY_EXTRACT_PROMPT = (
"You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, " "You are a knowledge graph entity extraction engine. Given a text chunk and an ontology schema, "
"extract all entities and relationships.\n\n" "extract all entities mentioned in the text.\n\n"
"ONTOLOGY SCHEMA:\n%s\n\n" "ONTOLOGY SCHEMA (entity types only):\n%s\n\n"
"RULES:\n" "RULES:\n"
"1. Extract entities that match the entity_types defined in the schema. Each entity needs: " "1. Extract entities that match the entity_types defined above.\n"
"name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.\n" "2. Each entity needs: name (the canonical name used in the text), type (must match an entity_type name from the schema), "
"2. Extract relationships between entities that match the edge_types defined in the schema. " "summary (1-2 sentences describing the entity based on the text), and attributes (fill in any attributes defined for that type).\n"
"Each relationship needs: name (the edge type name), source (entity name), target (entity name), " "3. Only extract entities explicitly mentioned or strongly implied in the text.\n"
"and a fact (short description of the relationship).\n" "4. Use the exact name as it appears in the text (e.g. 'Mira' not 'Mira the Socializer').\n"
"3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.\n" "5. If no entities are found, return an empty array.\n\n"
"4. Use consistent entity names across extractions.\n" 'Return JSON: a single key "entities" with an array of objects, each having keys: name, type, summary, attributes.'
"5. If no entities or relationships are found, return empty arrays.\n\n" )
'Return JSON with keys "entities" (array of objects with name, type, summary, attributes) '
'and "relationships" (array of objects with name, source, target, fact).' RELATIONSHIP_EXTRACT_PROMPT = (
"You are a knowledge graph relationship extraction engine. Given a text section, a list of known entities, "
"and an ontology schema, extract all relationships between the entities.\n\n"
"ONTOLOGY SCHEMA (edge types):\n%s\n\n"
"KNOWN ENTITIES:\n%s\n\n"
"RULES:\n"
"1. Find relationships between the known entities that match the edge_types defined above.\n"
"2. Each relationship needs: name (must match an edge_type name), source (entity name), target (entity name), "
"fact (1 sentence describing the specific relationship found in the text).\n"
"3. Both source and target MUST be from the known entities list.\n"
"4. Only extract relationships explicitly stated or strongly implied in the text.\n"
"5. Extract ALL relationships you can find — be thorough.\n"
"6. If no relationships are found, return an empty array.\n\n"
'Return JSON: a single key "relationships" with an array of objects, each having keys: name, source, target, fact.'
) )
class LLMGraphBuilderService: class LLMGraphBuilderService:
""" """
Graph builder that uses direct LLM calls instead of Zep. Graph builder using direct LLM calls instead of Zep.
Same interface as GraphBuilderService for drop-in replacement. Two-pass extraction: entities first, then relationships.
""" """
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm = llm_client or LLMClient() self.llm = llm_client or LLMClient()
self.task_manager = TaskManager() self.task_manager = TaskManager()
# In-memory graph storage (keyed by graph_id)
self._graphs: Dict[str, Dict[str, Any]] = {} self._graphs: Dict[str, Dict[str, Any]] = {}
def create_graph(self, name: str) -> str: def create_graph(self, name: str) -> str:
@ -51,7 +68,7 @@ class LLMGraphBuilderService:
self._graphs[graph_id] = { self._graphs[graph_id] = {
"name": name, "name": name,
"ontology": None, "ontology": None,
"nodes": {}, # keyed by normalized name "nodes": {},
"edges": [], "edges": [],
} }
return graph_id return graph_id
@ -60,16 +77,22 @@ class LLMGraphBuilderService:
if graph_id in self._graphs: if graph_id in self._graphs:
self._graphs[graph_id]["ontology"] = ontology self._graphs[graph_id]["ontology"] = ontology
def extract_from_chunks( # ── Pass 1: Entity Extraction ──
def extract_entities(
self, self,
graph_id: str, graph_id: str,
chunks: List[str], chunks: List[str],
progress_callback: Optional[Callable] = None progress_callback: Optional[Callable] = None
): ):
"""Extract entities and relationships from text chunks using LLM.""" """Pass 1: Extract entities from each chunk."""
graph = self._graphs[graph_id] graph = self._graphs[graph_id]
ontology = graph["ontology"] ontology = graph["ontology"]
ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False)
# Build entity-types-only schema for the prompt
entity_types_json = json.dumps(
ontology.get("entity_types", []), indent=2, ensure_ascii=False
)
total = len(chunks) total = len(chunks)
success_count = 0 success_count = 0
@ -79,65 +102,125 @@ class LLMGraphBuilderService:
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
if progress_callback: if progress_callback:
progress_callback( progress_callback(
f"Extracting from chunk {i+1}/{total} (ok={success_count}, fail={fail_count})...", f"[Pass 1] Extracting entities from chunk {i+1}/{total} "
f"(ok={success_count}, fail={fail_count})...",
(i + 1) / total (i + 1) / total
) )
try: try:
logger.info(f"Extracting chunk {i+1}/{total} ({len(chunk)} chars)") logger.info(f"[Pass 1] Entity extraction chunk {i+1}/{total} ({len(chunk)} chars)")
result = self.llm.chat_json( result = self.llm.chat_json(
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": EXTRACT_SYSTEM_PROMPT_TEMPLATE % ontology_json "content": ENTITY_EXTRACT_PROMPT % entity_types_json
}, },
{ {
"role": "user", "role": "user",
"content": f"Extract entities and relationships from this text:\n\n{chunk}" "content": f"Extract all entities from this text:\n\n{chunk}"
} }
], ],
temperature=0.1, temperature=0.1,
max_tokens=4096 max_tokens=4096
) )
entities = result.get("entities", []) entities = result.get("entities", [])
rels = result.get("relationships", []) logger.info(f"[Pass 1] Chunk {i+1}: {len(entities)} entities")
logger.info(f"Chunk {i+1}: extracted {len(entities)} entities, {len(rels)} relationships") self._merge_entities(graph_id, entities)
self._merge_extraction(graph_id, result)
success_count += 1 success_count += 1
except Exception as e: except Exception as e:
fail_count += 1 fail_count += 1
last_error = e last_error = e
import traceback logger.error(f"[Pass 1] Chunk {i+1} failed: {type(e).__name__}: {e}")
logger.error(f"Chunk {i+1} extraction failed: {type(e).__name__}: {e}") logger.debug(traceback.format_exc())
logger.error(f"Chunk {i+1} traceback: {traceback.format_exc()}")
if progress_callback:
progress_callback(f"Chunk {i+1} error: {e}", (i + 1) / total)
logger.info(f"Extraction complete: {success_count}/{total} succeeded, {fail_count} failed") logger.info(f"[Pass 1] Complete: {success_count}/{total} succeeded, "
f"{len(graph['nodes'])} unique entities found")
if success_count == 0 and total > 0: if success_count == 0 and total > 0:
raise RuntimeError(f"All {total} chunks failed extraction. Last error: {last_error}") raise RuntimeError(f"All {total} entity extraction calls failed. Last error: {last_error}")
def _merge_extraction(self, graph_id: str, result: Dict[str, Any]): # ── Pass 2: Relationship Discovery ──
"""Merge extracted entities/relationships into the graph, deduplicating by name."""
def discover_relationships(
self,
graph_id: str,
full_text: str,
progress_callback: Optional[Callable] = None
):
"""Pass 2: Find relationships between known entities using larger text windows."""
graph = self._graphs[graph_id]
ontology = graph["ontology"]
nodes = graph["nodes"]
if not nodes:
logger.warning("[Pass 2] No entities to find relationships for")
return
# Build edge types schema
edge_types_json = json.dumps(
ontology.get("edge_types", []), indent=2, ensure_ascii=False
)
# Build entity list string
entity_names = sorted(set(n["name"] for n in nodes.values()))
entity_list = "\n".join(f"- {name} ({nodes[name.lower()]['labels'][0]})"
for name in entity_names if name.lower() in nodes)
# Use larger chunks for relationship discovery (5000 chars, 500 overlap)
rel_chunks = TextProcessor.split_text(full_text, chunk_size=5000, overlap=500)
total = len(rel_chunks)
success_count = 0
fail_count = 0
for i, chunk in enumerate(rel_chunks):
if progress_callback:
progress_callback(
f"[Pass 2] Finding relationships in section {i+1}/{total} "
f"(edges so far: {len(graph['edges'])})...",
(i + 1) / total
)
try:
logger.info(f"[Pass 2] Relationship discovery section {i+1}/{total} ({len(chunk)} chars)")
result = self.llm.chat_json(
messages=[
{
"role": "system",
"content": RELATIONSHIP_EXTRACT_PROMPT % (edge_types_json, entity_list)
},
{
"role": "user",
"content": f"Find all relationships between the known entities in this text:\n\n{chunk}"
}
],
temperature=0.1,
max_tokens=4096
)
rels = result.get("relationships", [])
logger.info(f"[Pass 2] Section {i+1}: {len(rels)} relationships")
self._merge_relationships(graph_id, rels)
success_count += 1
except Exception as e:
fail_count += 1
logger.error(f"[Pass 2] Section {i+1} failed: {type(e).__name__}: {e}")
logger.debug(traceback.format_exc())
logger.info(f"[Pass 2] Complete: {success_count}/{total} succeeded, "
f"{len(graph['edges'])} total edges")
# ── Merge Helpers ──
def _merge_entities(self, graph_id: str, entities: List[Dict[str, Any]]):
"""Merge extracted entities into the graph."""
graph = self._graphs[graph_id] graph = self._graphs[graph_id]
nodes = graph["nodes"] nodes = graph["nodes"]
edges = graph["edges"]
# Valid entity type names from ontology
valid_entity_types = set() valid_entity_types = set()
if graph["ontology"]: if graph["ontology"]:
for et in graph["ontology"].get("entity_types", []): for et in graph["ontology"].get("entity_types", []):
valid_entity_types.add(et["name"]) valid_entity_types.add(et["name"])
# Valid edge type names for entity in entities:
valid_edge_types = set()
if graph["ontology"]:
for et in graph["ontology"].get("edge_types", []):
valid_edge_types.add(et["name"])
# Merge entities
for entity in result.get("entities", []):
name = entity.get("name", "").strip() name = entity.get("name", "").strip()
if not name: if not name:
continue continue
@ -146,13 +229,11 @@ class LLMGraphBuilderService:
key = name.lower() key = name.lower()
if key in nodes: if key in nodes:
# Update summary if new one is longer
existing = nodes[key] existing = nodes[key]
new_summary = entity.get("summary", "") new_summary = entity.get("summary", "")
if new_summary and len(new_summary) > len(existing.get("summary", "")): if new_summary and len(new_summary) > len(existing.get("summary", "")):
existing["summary"] = new_summary existing["summary"] = new_summary
# Merge attributes for k, v in (entity.get("attributes", {}) or {}).items():
for k, v in entity.get("attributes", {}).items():
if v and not existing["attributes"].get(k): if v and not existing["attributes"].get(k):
existing["attributes"][k] = v existing["attributes"][k] = v
else: else:
@ -166,12 +247,17 @@ class LLMGraphBuilderService:
"created_at": None, "created_at": None,
} }
# Merge relationships (deduplicate by source+target+name) def _merge_relationships(self, graph_id: str, relationships: List[Dict[str, Any]]):
"""Merge extracted relationships into the graph."""
graph = self._graphs[graph_id]
nodes = graph["nodes"]
edges = graph["edges"]
existing_edges = set() existing_edges = set()
for e in edges: for e in edges:
existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"])) existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))
for rel in result.get("relationships", []): for rel in relationships:
rel_name = rel.get("name", "").strip() rel_name = rel.get("name", "").strip()
source = rel.get("source", "").strip() source = rel.get("source", "").strip()
target = rel.get("target", "").strip() target = rel.get("target", "").strip()
@ -183,13 +269,11 @@ class LLMGraphBuilderService:
continue continue
existing_edges.add(edge_key) existing_edges.add(edge_key)
# Resolve node UUIDs
source_node = nodes.get(source.lower()) source_node = nodes.get(source.lower())
target_node = nodes.get(target.lower()) target_node = nodes.get(target.lower())
source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4()) source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4()) target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())
# Create placeholder nodes if they don't exist
if not source_node: if not source_node:
nodes[source.lower()] = { nodes[source.lower()] = {
"uuid": source_uuid, "uuid": source_uuid,
@ -226,12 +310,12 @@ class LLMGraphBuilderService:
"episodes": [], "episodes": [],
}) })
# ── Data Access ──
def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
"""Return graph data in the same format as the Zep-based builder."""
graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []}) graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
nodes_list = list(graph["nodes"].values()) nodes_list = list(graph["nodes"].values())
edges_list = graph["edges"] edges_list = graph["edges"]
return { return {
"graph_id": graph_id, "graph_id": graph_id,
"nodes": nodes_list, "nodes": nodes_list,
@ -241,7 +325,6 @@ class LLMGraphBuilderService:
} }
def save_graph_data(self, graph_id: str, project_dir: str) -> str: def save_graph_data(self, graph_id: str, project_dir: str) -> str:
"""Persist graph data to a JSON file in the project directory."""
data = self.get_graph_data(graph_id) data = self.get_graph_data(graph_id)
path = os.path.join(project_dir, "graph_data.json") path = os.path.join(project_dir, "graph_data.json")
with open(path, "w", encoding="utf-8") as f: with open(path, "w", encoding="utf-8") as f:
@ -250,7 +333,6 @@ class LLMGraphBuilderService:
@staticmethod @staticmethod
def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]: def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
"""Load persisted graph data from disk."""
path = os.path.join(project_dir, "graph_data.json") path = os.path.join(project_dir, "graph_data.json")
if os.path.exists(path): if os.path.exists(path):
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f: