Two-pass graph extraction: entities then relationships with larger chunks

This commit is contained in:
_Yusaki 2026-03-13 19:30:30 +07:00
parent e806898018
commit 10a85e76d6
2 changed files with 175 additions and 67 deletions

View file

@ -372,8 +372,13 @@ def build_graph():
)
# 创建 LLM 图谱构建服务(不需要 Zep)
from ..services.llm_graph_builder import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
builder = LLMGraphBuilderService()
# Use larger chunks for better context
entity_chunk_size = max(chunk_size, DEFAULT_CHUNK_SIZE)
entity_chunk_overlap = max(chunk_overlap, DEFAULT_CHUNK_OVERLAP)
# 分块
task_manager.update_task(
task_id,
@ -382,8 +387,8 @@ def build_graph():
)
chunks = TextProcessor.split_text(
text,
chunk_size=chunk_size,
overlap=chunk_overlap
chunk_size=entity_chunk_size,
overlap=entity_chunk_overlap
)
total_chunks = len(chunks)
@ -402,9 +407,9 @@ def build_graph():
# 设置本体
builder.set_ontology(graph_id, ontology)
# LLM extraction from chunks
def extract_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 75) # 15% - 90%
# Pass 1: Entity extraction
def entity_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 40) # 15% - 55%
task_manager.update_task(
task_id,
message=msg,
@ -413,14 +418,35 @@ def build_graph():
task_manager.update_task(
task_id,
message=f"Extracting entities from {total_chunks} chunks via LLM...",
message=f"[Pass 1] Extracting entities from {total_chunks} chunks...",
progress=15
)
builder.extract_from_chunks(
builder.extract_entities(
graph_id,
chunks,
progress_callback=extract_progress_callback
progress_callback=entity_progress_callback
)
# Pass 2: Relationship discovery
def rel_progress_callback(msg, progress_ratio):
    # Pass 2 owns the 55%-90% slice of the overall progress bar; map the
    # local 0..1 ratio into that band (int() truncates toward zero).
    overall = 55 + int(progress_ratio * 35)
    task_manager.update_task(task_id, message=msg, progress=overall)
task_manager.update_task(
task_id,
message="[Pass 2] Discovering relationships between entities...",
progress=55
)
builder.discover_relationships(
graph_id,
text,
progress_callback=rel_progress_callback
)
# 获取图谱数据

View file

@ -1,12 +1,14 @@
"""
LLM-based graph builder service
Replaces Zep with direct LLM calls for entity/relationship extraction
Replaces Zep with direct LLM calls for entity/relationship extraction.
Two-pass approach: (1) extract entities, (2) discover relationships.
"""
import os
import uuid
import json
import logging
import traceback
from typing import Dict, Any, List, Optional, Callable
from ..utils.llm_client import LLMClient
@ -15,35 +17,50 @@ from .text_processor import TextProcessor
logger = logging.getLogger('mirofish.llm_graph_builder')
# Default chunk size — larger than Zep's default to capture more context per call
DEFAULT_CHUNK_SIZE = 2500
DEFAULT_CHUNK_OVERLAP = 200
EXTRACT_SYSTEM_PROMPT_TEMPLATE = (
"You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, "
"extract all entities and relationships.\n\n"
"ONTOLOGY SCHEMA:\n%s\n\n"
ENTITY_EXTRACT_PROMPT = (
"You are a knowledge graph entity extraction engine. Given a text chunk and an ontology schema, "
"extract all entities mentioned in the text.\n\n"
"ONTOLOGY SCHEMA (entity types only):\n%s\n\n"
"RULES:\n"
"1. Extract entities that match the entity_types defined in the schema. Each entity needs: "
"name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.\n"
"2. Extract relationships between entities that match the edge_types defined in the schema. "
"Each relationship needs: name (the edge type name), source (entity name), target (entity name), "
"and a fact (short description of the relationship).\n"
"3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.\n"
"4. Use consistent entity names across extractions.\n"
"5. If no entities or relationships are found, return empty arrays.\n\n"
'Return JSON with keys "entities" (array of objects with name, type, summary, attributes) '
'and "relationships" (array of objects with name, source, target, fact).'
"1. Extract entities that match the entity_types defined above.\n"
"2. Each entity needs: name (the canonical name used in the text), type (must match an entity_type name from the schema), "
"summary (1-2 sentences describing the entity based on the text), and attributes (fill in any attributes defined for that type).\n"
"3. Only extract entities explicitly mentioned or strongly implied in the text.\n"
"4. Use the exact name as it appears in the text (e.g. 'Mira' not 'Mira the Socializer').\n"
"5. If no entities are found, return an empty array.\n\n"
'Return JSON: a single key "entities" with an array of objects, each having keys: name, type, summary, attributes.'
)
RELATIONSHIP_EXTRACT_PROMPT = (
"You are a knowledge graph relationship extraction engine. Given a text section, a list of known entities, "
"and an ontology schema, extract all relationships between the entities.\n\n"
"ONTOLOGY SCHEMA (edge types):\n%s\n\n"
"KNOWN ENTITIES:\n%s\n\n"
"RULES:\n"
"1. Find relationships between the known entities that match the edge_types defined above.\n"
"2. Each relationship needs: name (must match an edge_type name), source (entity name), target (entity name), "
"fact (1 sentence describing the specific relationship found in the text).\n"
"3. Both source and target MUST be from the known entities list.\n"
"4. Only extract relationships explicitly stated or strongly implied in the text.\n"
"5. Extract ALL relationships you can find — be thorough.\n"
"6. If no relationships are found, return an empty array.\n\n"
'Return JSON: a single key "relationships" with an array of objects, each having keys: name, source, target, fact.'
)
class LLMGraphBuilderService:
"""
Graph builder that uses direct LLM calls instead of Zep.
Same interface as GraphBuilderService for drop-in replacement.
Graph builder using direct LLM calls instead of Zep.
Two-pass extraction: entities first, then relationships.
"""
def __init__(self, llm_client: Optional[LLMClient] = None):
    """Initialize the builder.

    Args:
        llm_client: Optional pre-configured LLM client. A default
            ``LLMClient()`` is created only when this is ``None`` —
            an identity check, so a falsy-but-valid client object
            passed by a caller is still honored (``or`` would not).
    """
    self.llm = llm_client if llm_client is not None else LLMClient()
    self.task_manager = TaskManager()
    # In-memory graph storage, keyed by graph_id
    self._graphs: Dict[str, Dict[str, Any]] = {}
def create_graph(self, name: str) -> str:
@ -51,7 +68,7 @@ class LLMGraphBuilderService:
self._graphs[graph_id] = {
"name": name,
"ontology": None,
"nodes": {}, # keyed by normalized name
"nodes": {},
"edges": [],
}
return graph_id
@ -60,16 +77,22 @@ class LLMGraphBuilderService:
if graph_id in self._graphs:
self._graphs[graph_id]["ontology"] = ontology
def extract_from_chunks(
# ── Pass 1: Entity Extraction ──
def extract_entities(
self,
graph_id: str,
chunks: List[str],
progress_callback: Optional[Callable] = None
):
"""Extract entities and relationships from text chunks using LLM."""
"""Pass 1: Extract entities from each chunk."""
graph = self._graphs[graph_id]
ontology = graph["ontology"]
ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False)
# Build entity-types-only schema for the prompt
entity_types_json = json.dumps(
ontology.get("entity_types", []), indent=2, ensure_ascii=False
)
total = len(chunks)
success_count = 0
@ -79,65 +102,125 @@ class LLMGraphBuilderService:
for i, chunk in enumerate(chunks):
if progress_callback:
progress_callback(
f"Extracting from chunk {i+1}/{total} (ok={success_count}, fail={fail_count})...",
f"[Pass 1] Extracting entities from chunk {i+1}/{total} "
f"(ok={success_count}, fail={fail_count})...",
(i + 1) / total
)
try:
logger.info(f"Extracting chunk {i+1}/{total} ({len(chunk)} chars)")
logger.info(f"[Pass 1] Entity extraction chunk {i+1}/{total} ({len(chunk)} chars)")
result = self.llm.chat_json(
messages=[
{
"role": "system",
"content": EXTRACT_SYSTEM_PROMPT_TEMPLATE % ontology_json
"content": ENTITY_EXTRACT_PROMPT % entity_types_json
},
{
"role": "user",
"content": f"Extract entities and relationships from this text:\n\n{chunk}"
"content": f"Extract all entities from this text:\n\n{chunk}"
}
],
temperature=0.1,
max_tokens=4096
)
entities = result.get("entities", [])
rels = result.get("relationships", [])
logger.info(f"Chunk {i+1}: extracted {len(entities)} entities, {len(rels)} relationships")
self._merge_extraction(graph_id, result)
logger.info(f"[Pass 1] Chunk {i+1}: {len(entities)} entities")
self._merge_entities(graph_id, entities)
success_count += 1
except Exception as e:
fail_count += 1
last_error = e
import traceback
logger.error(f"Chunk {i+1} extraction failed: {type(e).__name__}: {e}")
logger.error(f"Chunk {i+1} traceback: {traceback.format_exc()}")
if progress_callback:
progress_callback(f"Chunk {i+1} error: {e}", (i + 1) / total)
logger.error(f"[Pass 1] Chunk {i+1} failed: {type(e).__name__}: {e}")
logger.debug(traceback.format_exc())
logger.info(f"Extraction complete: {success_count}/{total} succeeded, {fail_count} failed")
logger.info(f"[Pass 1] Complete: {success_count}/{total} succeeded, "
f"{len(graph['nodes'])} unique entities found")
if success_count == 0 and total > 0:
raise RuntimeError(f"All {total} chunks failed extraction. Last error: {last_error}")
raise RuntimeError(f"All {total} entity extraction calls failed. Last error: {last_error}")
def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
"""Merge extracted entities/relationships into the graph, deduplicating by name."""
# ── Pass 2: Relationship Discovery ──
def discover_relationships(
    self,
    graph_id: str,
    full_text: str,
    progress_callback: Optional[Callable] = None
):
    """Pass 2: discover relationships between already-extracted entities.

    Re-chunks the full text with larger windows (5000 chars / 500 overlap)
    so each LLM call sees more context, then asks the model to connect the
    known entities using the ontology's edge types. Failures are logged and
    skipped per section; this pass never raises on total failure (a graph
    with entities but no edges is still usable).

    Args:
        graph_id: Graph previously populated by extract_entities().
        full_text: The complete source text.
        progress_callback: Optional fn(message, ratio) invoked per section.
    """
    graph = self._graphs[graph_id]
    ontology = graph["ontology"]
    nodes = graph["nodes"]

    if not nodes:
        logger.warning("[Pass 2] No entities to find relationships for")
        return

    # Edge-type schema for the prompt; tolerate an ontology that was never set.
    edge_types_json = json.dumps(
        (ontology or {}).get("edge_types", []), indent=2, ensure_ascii=False
    )

    # Known-entity list for the prompt. Nodes are keyed by lowercased name,
    # so each display name maps back to its node entry for the type label.
    entity_lines = []
    for name in sorted({n["name"] for n in nodes.values()}):
        node = nodes.get(name.lower())
        if node is None:
            continue
        # Guard: a node stored without labels must not crash the whole pass
        # (the original indexed ['labels'][0] unconditionally — IndexError).
        labels = node.get("labels") or ["Entity"]
        entity_lines.append(f"- {name} ({labels[0]})")
    entity_list = "\n".join(entity_lines)

    # Use larger chunks for relationship discovery (5000 chars, 500 overlap)
    rel_chunks = TextProcessor.split_text(full_text, chunk_size=5000, overlap=500)
    total = len(rel_chunks)
    success_count = 0
    fail_count = 0

    for i, chunk in enumerate(rel_chunks):
        if progress_callback:
            progress_callback(
                f"[Pass 2] Finding relationships in section {i+1}/{total} "
                f"(edges so far: {len(graph['edges'])})...",
                (i + 1) / total
            )
        try:
            logger.info(f"[Pass 2] Relationship discovery section {i+1}/{total} ({len(chunk)} chars)")
            result = self.llm.chat_json(
                messages=[
                    {
                        "role": "system",
                        "content": RELATIONSHIP_EXTRACT_PROMPT % (edge_types_json, entity_list)
                    },
                    {
                        "role": "user",
                        "content": f"Find all relationships between the known entities in this text:\n\n{chunk}"
                    }
                ],
                temperature=0.1,
                max_tokens=4096
            )
            rels = result.get("relationships", [])
            logger.info(f"[Pass 2] Section {i+1}: {len(rels)} relationships")
            self._merge_relationships(graph_id, rels)
            success_count += 1
        except Exception as e:
            fail_count += 1
            logger.error(f"[Pass 2] Section {i+1} failed: {type(e).__name__}: {e}")
            logger.debug(traceback.format_exc())

    logger.info(f"[Pass 2] Complete: {success_count}/{total} succeeded, "
                f"{len(graph['edges'])} total edges")
    if total > 0 and success_count == 0:
        # Unlike pass 1 we don't raise — but total failure should be visible
        # to the operator, not buried in per-section error logs.
        logger.warning(f"[Pass 2] All {total} relationship discovery calls failed")
# ── Merge Helpers ──
def _merge_entities(self, graph_id: str, entities: List[Dict[str, Any]]):
"""Merge extracted entities into the graph."""
graph = self._graphs[graph_id]
nodes = graph["nodes"]
edges = graph["edges"]
# Valid entity type names from ontology
valid_entity_types = set()
if graph["ontology"]:
for et in graph["ontology"].get("entity_types", []):
valid_entity_types.add(et["name"])
# Valid edge type names
valid_edge_types = set()
if graph["ontology"]:
for et in graph["ontology"].get("edge_types", []):
valid_edge_types.add(et["name"])
# Merge entities
for entity in result.get("entities", []):
for entity in entities:
name = entity.get("name", "").strip()
if not name:
continue
@ -146,13 +229,11 @@ class LLMGraphBuilderService:
key = name.lower()
if key in nodes:
# Update summary if new one is longer
existing = nodes[key]
new_summary = entity.get("summary", "")
if new_summary and len(new_summary) > len(existing.get("summary", "")):
existing["summary"] = new_summary
# Merge attributes
for k, v in entity.get("attributes", {}).items():
for k, v in (entity.get("attributes", {}) or {}).items():
if v and not existing["attributes"].get(k):
existing["attributes"][k] = v
else:
@ -166,12 +247,17 @@ class LLMGraphBuilderService:
"created_at": None,
}
# Merge relationships (deduplicate by source+target+name)
def _merge_relationships(self, graph_id: str, relationships: List[Dict[str, Any]]):
"""Merge extracted relationships into the graph."""
graph = self._graphs[graph_id]
nodes = graph["nodes"]
edges = graph["edges"]
existing_edges = set()
for e in edges:
existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))
for rel in result.get("relationships", []):
for rel in relationships:
rel_name = rel.get("name", "").strip()
source = rel.get("source", "").strip()
target = rel.get("target", "").strip()
@ -183,13 +269,11 @@ class LLMGraphBuilderService:
continue
existing_edges.add(edge_key)
# Resolve node UUIDs
source_node = nodes.get(source.lower())
target_node = nodes.get(target.lower())
source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())
# Create placeholder nodes if they don't exist
if not source_node:
nodes[source.lower()] = {
"uuid": source_uuid,
@ -226,12 +310,12 @@ class LLMGraphBuilderService:
"episodes": [],
})
# ── Data Access ──
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
"""Return graph data in the same format as the Zep-based builder."""
graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
nodes_list = list(graph["nodes"].values())
edges_list = graph["edges"]
return {
"graph_id": graph_id,
"nodes": nodes_list,
@ -241,7 +325,6 @@ class LLMGraphBuilderService:
}
def save_graph_data(self, graph_id: str, project_dir: str) -> str:
"""Persist graph data to a JSON file in the project directory."""
data = self.get_graph_data(graph_id)
path = os.path.join(project_dir, "graph_data.json")
with open(path, "w", encoding="utf-8") as f:
@ -250,7 +333,6 @@ class LLMGraphBuilderService:
@staticmethod
def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
"""Load persisted graph data from disk."""
path = os.path.join(project_dir, "graph_data.json")
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as f: