Two-pass graph extraction: entities then relationships with larger chunks
This commit is contained in:
parent
e806898018
commit
10a85e76d6
2 changed files with 175 additions and 67 deletions
|
|
@ -372,8 +372,13 @@ def build_graph():
|
|||
)
|
||||
|
||||
# 创建 LLM 图谱构建服务(不需要 Zep)
|
||||
from ..services.llm_graph_builder import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
|
||||
builder = LLMGraphBuilderService()
|
||||
|
||||
# Use larger chunks for better context
|
||||
entity_chunk_size = max(chunk_size, DEFAULT_CHUNK_SIZE)
|
||||
entity_chunk_overlap = max(chunk_overlap, DEFAULT_CHUNK_OVERLAP)
|
||||
|
||||
# 分块
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -382,8 +387,8 @@ def build_graph():
|
|||
)
|
||||
chunks = TextProcessor.split_text(
|
||||
text,
|
||||
chunk_size=chunk_size,
|
||||
overlap=chunk_overlap
|
||||
chunk_size=entity_chunk_size,
|
||||
overlap=entity_chunk_overlap
|
||||
)
|
||||
total_chunks = len(chunks)
|
||||
|
||||
|
|
@ -402,9 +407,9 @@ def build_graph():
|
|||
# 设置本体
|
||||
builder.set_ontology(graph_id, ontology)
|
||||
|
||||
# LLM extraction from chunks
|
||||
def extract_progress_callback(msg, progress_ratio):
|
||||
progress = 15 + int(progress_ratio * 75) # 15% - 90%
|
||||
# Pass 1: Entity extraction
|
||||
def entity_progress_callback(msg, progress_ratio):
|
||||
progress = 15 + int(progress_ratio * 40) # 15% - 55%
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=msg,
|
||||
|
|
@ -413,14 +418,35 @@ def build_graph():
|
|||
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=f"Extracting entities from {total_chunks} chunks via LLM...",
|
||||
message=f"[Pass 1] Extracting entities from {total_chunks} chunks...",
|
||||
progress=15
|
||||
)
|
||||
|
||||
builder.extract_from_chunks(
|
||||
builder.extract_entities(
|
||||
graph_id,
|
||||
chunks,
|
||||
progress_callback=extract_progress_callback
|
||||
progress_callback=entity_progress_callback
|
||||
)
|
||||
|
||||
# Pass 2: Relationship discovery
|
||||
def rel_progress_callback(msg, progress_ratio):
|
||||
progress = 55 + int(progress_ratio * 35) # 55% - 90%
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=msg,
|
||||
progress=progress
|
||||
)
|
||||
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message="[Pass 2] Discovering relationships between entities...",
|
||||
progress=55
|
||||
)
|
||||
|
||||
builder.discover_relationships(
|
||||
graph_id,
|
||||
text,
|
||||
progress_callback=rel_progress_callback
|
||||
)
|
||||
|
||||
# 获取图谱数据
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
"""
|
||||
LLM-based graph builder service
|
||||
Replaces Zep with direct LLM calls for entity/relationship extraction
|
||||
Replaces Zep with direct LLM calls for entity/relationship extraction.
|
||||
Two-pass approach: (1) extract entities, (2) discover relationships.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
from ..utils.llm_client import LLMClient
|
||||
|
|
@ -15,35 +17,50 @@ from .text_processor import TextProcessor
|
|||
|
||||
logger = logging.getLogger('mirofish.llm_graph_builder')
|
||||
|
||||
# Default chunk size — larger than Zep's default to capture more context per call
|
||||
DEFAULT_CHUNK_SIZE = 2500
|
||||
DEFAULT_CHUNK_OVERLAP = 200
|
||||
|
||||
EXTRACT_SYSTEM_PROMPT_TEMPLATE = (
|
||||
"You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, "
|
||||
"extract all entities and relationships.\n\n"
|
||||
"ONTOLOGY SCHEMA:\n%s\n\n"
|
||||
ENTITY_EXTRACT_PROMPT = (
|
||||
"You are a knowledge graph entity extraction engine. Given a text chunk and an ontology schema, "
|
||||
"extract all entities mentioned in the text.\n\n"
|
||||
"ONTOLOGY SCHEMA (entity types only):\n%s\n\n"
|
||||
"RULES:\n"
|
||||
"1. Extract entities that match the entity_types defined in the schema. Each entity needs: "
|
||||
"name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.\n"
|
||||
"2. Extract relationships between entities that match the edge_types defined in the schema. "
|
||||
"Each relationship needs: name (the edge type name), source (entity name), target (entity name), "
|
||||
"and a fact (short description of the relationship).\n"
|
||||
"3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.\n"
|
||||
"4. Use consistent entity names across extractions.\n"
|
||||
"5. If no entities or relationships are found, return empty arrays.\n\n"
|
||||
'Return JSON with keys "entities" (array of objects with name, type, summary, attributes) '
|
||||
'and "relationships" (array of objects with name, source, target, fact).'
|
||||
"1. Extract entities that match the entity_types defined above.\n"
|
||||
"2. Each entity needs: name (the canonical name used in the text), type (must match an entity_type name from the schema), "
|
||||
"summary (1-2 sentences describing the entity based on the text), and attributes (fill in any attributes defined for that type).\n"
|
||||
"3. Only extract entities explicitly mentioned or strongly implied in the text.\n"
|
||||
"4. Use the exact name as it appears in the text (e.g. 'Mira' not 'Mira the Socializer').\n"
|
||||
"5. If no entities are found, return an empty array.\n\n"
|
||||
'Return JSON: a single key "entities" with an array of objects, each having keys: name, type, summary, attributes.'
|
||||
)
|
||||
|
||||
RELATIONSHIP_EXTRACT_PROMPT = (
|
||||
"You are a knowledge graph relationship extraction engine. Given a text section, a list of known entities, "
|
||||
"and an ontology schema, extract all relationships between the entities.\n\n"
|
||||
"ONTOLOGY SCHEMA (edge types):\n%s\n\n"
|
||||
"KNOWN ENTITIES:\n%s\n\n"
|
||||
"RULES:\n"
|
||||
"1. Find relationships between the known entities that match the edge_types defined above.\n"
|
||||
"2. Each relationship needs: name (must match an edge_type name), source (entity name), target (entity name), "
|
||||
"fact (1 sentence describing the specific relationship found in the text).\n"
|
||||
"3. Both source and target MUST be from the known entities list.\n"
|
||||
"4. Only extract relationships explicitly stated or strongly implied in the text.\n"
|
||||
"5. Extract ALL relationships you can find — be thorough.\n"
|
||||
"6. If no relationships are found, return an empty array.\n\n"
|
||||
'Return JSON: a single key "relationships" with an array of objects, each having keys: name, source, target, fact.'
|
||||
)
|
||||
|
||||
|
||||
class LLMGraphBuilderService:
|
||||
"""
|
||||
Graph builder that uses direct LLM calls instead of Zep.
|
||||
Same interface as GraphBuilderService for drop-in replacement.
|
||||
Graph builder using direct LLM calls instead of Zep.
|
||||
Two-pass extraction: entities first, then relationships.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self.llm = llm_client or LLMClient()
|
||||
self.task_manager = TaskManager()
|
||||
# In-memory graph storage (keyed by graph_id)
|
||||
self._graphs: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def create_graph(self, name: str) -> str:
|
||||
|
|
@ -51,7 +68,7 @@ class LLMGraphBuilderService:
|
|||
self._graphs[graph_id] = {
|
||||
"name": name,
|
||||
"ontology": None,
|
||||
"nodes": {}, # keyed by normalized name
|
||||
"nodes": {},
|
||||
"edges": [],
|
||||
}
|
||||
return graph_id
|
||||
|
|
@ -60,16 +77,22 @@ class LLMGraphBuilderService:
|
|||
if graph_id in self._graphs:
|
||||
self._graphs[graph_id]["ontology"] = ontology
|
||||
|
||||
def extract_from_chunks(
|
||||
# ── Pass 1: Entity Extraction ──
|
||||
|
||||
def extract_entities(
|
||||
self,
|
||||
graph_id: str,
|
||||
chunks: List[str],
|
||||
progress_callback: Optional[Callable] = None
|
||||
):
|
||||
"""Extract entities and relationships from text chunks using LLM."""
|
||||
"""Pass 1: Extract entities from each chunk."""
|
||||
graph = self._graphs[graph_id]
|
||||
ontology = graph["ontology"]
|
||||
ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False)
|
||||
|
||||
# Build entity-types-only schema for the prompt
|
||||
entity_types_json = json.dumps(
|
||||
ontology.get("entity_types", []), indent=2, ensure_ascii=False
|
||||
)
|
||||
|
||||
total = len(chunks)
|
||||
success_count = 0
|
||||
|
|
@ -79,65 +102,125 @@ class LLMGraphBuilderService:
|
|||
for i, chunk in enumerate(chunks):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
f"Extracting from chunk {i+1}/{total} (ok={success_count}, fail={fail_count})...",
|
||||
f"[Pass 1] Extracting entities from chunk {i+1}/{total} "
|
||||
f"(ok={success_count}, fail={fail_count})...",
|
||||
(i + 1) / total
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(f"Extracting chunk {i+1}/{total} ({len(chunk)} chars)")
|
||||
logger.info(f"[Pass 1] Entity extraction chunk {i+1}/{total} ({len(chunk)} chars)")
|
||||
result = self.llm.chat_json(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_SYSTEM_PROMPT_TEMPLATE % ontology_json
|
||||
"content": ENTITY_EXTRACT_PROMPT % entity_types_json
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Extract entities and relationships from this text:\n\n{chunk}"
|
||||
"content": f"Extract all entities from this text:\n\n{chunk}"
|
||||
}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4096
|
||||
)
|
||||
entities = result.get("entities", [])
|
||||
rels = result.get("relationships", [])
|
||||
logger.info(f"Chunk {i+1}: extracted {len(entities)} entities, {len(rels)} relationships")
|
||||
self._merge_extraction(graph_id, result)
|
||||
logger.info(f"[Pass 1] Chunk {i+1}: {len(entities)} entities")
|
||||
self._merge_entities(graph_id, entities)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
fail_count += 1
|
||||
last_error = e
|
||||
import traceback
|
||||
logger.error(f"Chunk {i+1} extraction failed: {type(e).__name__}: {e}")
|
||||
logger.error(f"Chunk {i+1} traceback: {traceback.format_exc()}")
|
||||
if progress_callback:
|
||||
progress_callback(f"Chunk {i+1} error: {e}", (i + 1) / total)
|
||||
logger.error(f"[Pass 1] Chunk {i+1} failed: {type(e).__name__}: {e}")
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
logger.info(f"Extraction complete: {success_count}/{total} succeeded, {fail_count} failed")
|
||||
logger.info(f"[Pass 1] Complete: {success_count}/{total} succeeded, "
|
||||
f"{len(graph['nodes'])} unique entities found")
|
||||
|
||||
if success_count == 0 and total > 0:
|
||||
raise RuntimeError(f"All {total} chunks failed extraction. Last error: {last_error}")
|
||||
raise RuntimeError(f"All {total} entity extraction calls failed. Last error: {last_error}")
|
||||
|
||||
def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
|
||||
"""Merge extracted entities/relationships into the graph, deduplicating by name."""
|
||||
# ── Pass 2: Relationship Discovery ──
|
||||
|
||||
def discover_relationships(
|
||||
self,
|
||||
graph_id: str,
|
||||
full_text: str,
|
||||
progress_callback: Optional[Callable] = None
|
||||
):
|
||||
"""Pass 2: Find relationships between known entities using larger text windows."""
|
||||
graph = self._graphs[graph_id]
|
||||
ontology = graph["ontology"]
|
||||
nodes = graph["nodes"]
|
||||
|
||||
if not nodes:
|
||||
logger.warning("[Pass 2] No entities to find relationships for")
|
||||
return
|
||||
|
||||
# Build edge types schema
|
||||
edge_types_json = json.dumps(
|
||||
ontology.get("edge_types", []), indent=2, ensure_ascii=False
|
||||
)
|
||||
|
||||
# Build entity list string
|
||||
entity_names = sorted(set(n["name"] for n in nodes.values()))
|
||||
entity_list = "\n".join(f"- {name} ({nodes[name.lower()]['labels'][0]})"
|
||||
for name in entity_names if name.lower() in nodes)
|
||||
|
||||
# Use larger chunks for relationship discovery (5000 chars, 500 overlap)
|
||||
rel_chunks = TextProcessor.split_text(full_text, chunk_size=5000, overlap=500)
|
||||
total = len(rel_chunks)
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for i, chunk in enumerate(rel_chunks):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
f"[Pass 2] Finding relationships in section {i+1}/{total} "
|
||||
f"(edges so far: {len(graph['edges'])})...",
|
||||
(i + 1) / total
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(f"[Pass 2] Relationship discovery section {i+1}/{total} ({len(chunk)} chars)")
|
||||
result = self.llm.chat_json(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": RELATIONSHIP_EXTRACT_PROMPT % (edge_types_json, entity_list)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Find all relationships between the known entities in this text:\n\n{chunk}"
|
||||
}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4096
|
||||
)
|
||||
rels = result.get("relationships", [])
|
||||
logger.info(f"[Pass 2] Section {i+1}: {len(rels)} relationships")
|
||||
self._merge_relationships(graph_id, rels)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
fail_count += 1
|
||||
logger.error(f"[Pass 2] Section {i+1} failed: {type(e).__name__}: {e}")
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
logger.info(f"[Pass 2] Complete: {success_count}/{total} succeeded, "
|
||||
f"{len(graph['edges'])} total edges")
|
||||
|
||||
# ── Merge Helpers ──
|
||||
|
||||
def _merge_entities(self, graph_id: str, entities: List[Dict[str, Any]]):
|
||||
"""Merge extracted entities into the graph."""
|
||||
graph = self._graphs[graph_id]
|
||||
nodes = graph["nodes"]
|
||||
edges = graph["edges"]
|
||||
|
||||
# Valid entity type names from ontology
|
||||
valid_entity_types = set()
|
||||
if graph["ontology"]:
|
||||
for et in graph["ontology"].get("entity_types", []):
|
||||
valid_entity_types.add(et["name"])
|
||||
|
||||
# Valid edge type names
|
||||
valid_edge_types = set()
|
||||
if graph["ontology"]:
|
||||
for et in graph["ontology"].get("edge_types", []):
|
||||
valid_edge_types.add(et["name"])
|
||||
|
||||
# Merge entities
|
||||
for entity in result.get("entities", []):
|
||||
for entity in entities:
|
||||
name = entity.get("name", "").strip()
|
||||
if not name:
|
||||
continue
|
||||
|
|
@ -146,13 +229,11 @@ class LLMGraphBuilderService:
|
|||
key = name.lower()
|
||||
|
||||
if key in nodes:
|
||||
# Update summary if new one is longer
|
||||
existing = nodes[key]
|
||||
new_summary = entity.get("summary", "")
|
||||
if new_summary and len(new_summary) > len(existing.get("summary", "")):
|
||||
existing["summary"] = new_summary
|
||||
# Merge attributes
|
||||
for k, v in entity.get("attributes", {}).items():
|
||||
for k, v in (entity.get("attributes", {}) or {}).items():
|
||||
if v and not existing["attributes"].get(k):
|
||||
existing["attributes"][k] = v
|
||||
else:
|
||||
|
|
@ -166,12 +247,17 @@ class LLMGraphBuilderService:
|
|||
"created_at": None,
|
||||
}
|
||||
|
||||
# Merge relationships (deduplicate by source+target+name)
|
||||
def _merge_relationships(self, graph_id: str, relationships: List[Dict[str, Any]]):
|
||||
"""Merge extracted relationships into the graph."""
|
||||
graph = self._graphs[graph_id]
|
||||
nodes = graph["nodes"]
|
||||
edges = graph["edges"]
|
||||
|
||||
existing_edges = set()
|
||||
for e in edges:
|
||||
existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))
|
||||
|
||||
for rel in result.get("relationships", []):
|
||||
for rel in relationships:
|
||||
rel_name = rel.get("name", "").strip()
|
||||
source = rel.get("source", "").strip()
|
||||
target = rel.get("target", "").strip()
|
||||
|
|
@ -183,13 +269,11 @@ class LLMGraphBuilderService:
|
|||
continue
|
||||
existing_edges.add(edge_key)
|
||||
|
||||
# Resolve node UUIDs
|
||||
source_node = nodes.get(source.lower())
|
||||
target_node = nodes.get(target.lower())
|
||||
source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
|
||||
target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())
|
||||
|
||||
# Create placeholder nodes if they don't exist
|
||||
if not source_node:
|
||||
nodes[source.lower()] = {
|
||||
"uuid": source_uuid,
|
||||
|
|
@ -226,12 +310,12 @@ class LLMGraphBuilderService:
|
|||
"episodes": [],
|
||||
})
|
||||
|
||||
# ── Data Access ──
|
||||
|
||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||
"""Return graph data in the same format as the Zep-based builder."""
|
||||
graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
|
||||
nodes_list = list(graph["nodes"].values())
|
||||
edges_list = graph["edges"]
|
||||
|
||||
return {
|
||||
"graph_id": graph_id,
|
||||
"nodes": nodes_list,
|
||||
|
|
@ -241,7 +325,6 @@ class LLMGraphBuilderService:
|
|||
}
|
||||
|
||||
def save_graph_data(self, graph_id: str, project_dir: str) -> str:
|
||||
"""Persist graph data to a JSON file in the project directory."""
|
||||
data = self.get_graph_data(graph_id)
|
||||
path = os.path.join(project_dir, "graph_data.json")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
|
|
@ -250,7 +333,6 @@ class LLMGraphBuilderService:
|
|||
|
||||
@staticmethod
|
||||
def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load persisted graph data from disk."""
|
||||
path = os.path.join(project_dir, "graph_data.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
|
|
|
|||
Loading…
Reference in a new issue