# MiroFish/backend/app/services/llm_graph_builder.py
# (source listing metadata: 343 lines, 13 KiB, Python)

"""
LLM-based graph builder service
Replaces Zep with direct LLM calls for entity/relationship extraction.
Two-pass approach: (1) extract entities, (2) discover relationships.
"""
import os
import uuid
import json
import logging
import traceback
from typing import Dict, Any, List, Optional, Callable
from ..utils.llm_client import LLMClient
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor
logger = logging.getLogger("mirofish.llm_graph_builder")

# Chunking defaults for entity extraction. They are deliberately larger than
# Zep's defaults so that each LLM call sees more surrounding context.
DEFAULT_CHUNK_SIZE = 2500
DEFAULT_CHUNK_OVERLAP = 200

# Pass 1 system prompt, assembled line by line.
# Filled with %-formatting: the single "%s" receives the JSON-encoded list of
# ontology entity types.
ENTITY_EXTRACT_PROMPT = "\n".join([
    "You are a knowledge graph entity extraction engine. Given a text chunk "
    "and an ontology schema, extract all entities mentioned in the text.",
    "",
    "ONTOLOGY SCHEMA (entity types only):",
    "%s",
    "",
    "RULES:",
    "1. Extract entities that match the entity_types defined above.",
    "2. Each entity needs: name (the canonical name used in the text), "
    "type (must match an entity_type name from the schema), "
    "summary (1-2 sentences describing the entity based on the text), "
    "and attributes (fill in any attributes defined for that type).",
    "3. Only extract entities explicitly mentioned or strongly implied in the text.",
    "4. Use the exact name as it appears in the text "
    "(e.g. 'Mira' not 'Mira the Socializer').",
    "5. If no entities are found, return an empty array.",
    "",
    'Return JSON: a single key "entities" with an array of objects, '
    "each having keys: name, type, summary, attributes.",
])

# Pass 2 system prompt, assembled line by line.
# Filled with %-formatting: the first "%s" receives the JSON-encoded edge
# types; the second receives the bulleted list of known entities.
RELATIONSHIP_EXTRACT_PROMPT = "\n".join([
    "You are a knowledge graph relationship extraction engine. Given a text "
    "section, a list of known entities, and an ontology schema, extract all "
    "relationships between the entities.",
    "",
    "ONTOLOGY SCHEMA (edge types):",
    "%s",
    "",
    "KNOWN ENTITIES:",
    "%s",
    "",
    "RULES:",
    "1. Find relationships between the known entities that match the "
    "edge_types defined above.",
    "2. Each relationship needs: name (must match an edge_type name), "
    "source (entity name), target (entity name), "
    "fact (1 sentence describing the specific relationship found in the text).",
    "3. Both source and target MUST be from the known entities list.",
    "4. Only extract relationships explicitly stated or strongly implied in the text.",
    "5. Extract ALL relationships you can find — be thorough.",
    "6. If no relationships are found, return an empty array.",
    "",
    'Return JSON: a single key "relationships" with an array of objects, '
    "each having keys: name, source, target, fact.",
])
class LLMGraphBuilderService:
    """
    Graph builder using direct LLM calls instead of Zep.

    Two-pass extraction: entities first (``extract_entities``), then
    relationships between the known entities (``discover_relationships``).
    Graphs are held in memory, keyed by the id returned from
    ``create_graph``. ``save_graph_data`` / ``load_graph_data`` persist a
    graph as ``graph_data.json`` inside a project directory.

    Nodes are keyed by the lower-cased entity name, so entity names are
    matched case-insensitively across chunks and passes.
    """

    def __init__(self, llm_client: Optional["LLMClient"] = None):
        self.llm = llm_client or LLMClient()
        # Not referenced inside this class; presumably used by callers — verify.
        self.task_manager = TaskManager()
        # graph_id -> {"name", "ontology", "nodes": {lower_name: node}, "edges": [edge]}
        self._graphs: Dict[str, Dict[str, Any]] = {}

    def create_graph(self, name: str) -> str:
        """Register a new, empty graph and return its generated id."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graphs[graph_id] = {
            "name": name,
            "ontology": None,
            "nodes": {},
            "edges": [],
        }
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None:
        """Attach an ontology schema to a graph. Unknown graph ids are silently ignored."""
        if graph_id in self._graphs:
            self._graphs[graph_id]["ontology"] = ontology

    # ── Pass 1: Entity Extraction ──
    def extract_entities(
        self,
        graph_id: str,
        chunks: List[str],
        progress_callback: Optional[Callable] = None
    ) -> None:
        """
        Pass 1: extract entities from each chunk and merge them into the graph.

        Each chunk gets one LLM call. A failing chunk is logged and skipped.

        Args:
            graph_id: id returned by ``create_graph``.
            chunks: text chunks to process.
            progress_callback: optional ``callback(message, fraction)``, called
                before each chunk with fraction ``(i + 1) / len(chunks)``.

        Raises:
            KeyError: if ``graph_id`` is unknown.
            RuntimeError: if there was at least one chunk and every call failed.
        """
        graph = self._graphs[graph_id]
        # A graph whose ontology was never set must not crash on ``None.get``.
        ontology = graph["ontology"] or {}
        entity_types_json = json.dumps(
            ontology.get("entity_types", []), indent=2, ensure_ascii=False
        )
        # The system prompt is identical for every chunk, so build it once.
        system_prompt = ENTITY_EXTRACT_PROMPT % entity_types_json

        total = len(chunks)
        success_count = 0
        fail_count = 0
        last_error: Optional[Exception] = None
        for i, chunk in enumerate(chunks, start=1):
            if progress_callback:
                progress_callback(
                    f"[Pass 1] Extracting entities from chunk {i}/{total} "
                    f"(ok={success_count}, fail={fail_count})...",
                    i / total
                )
            # Deliberately broad: one bad chunk (network, bad JSON, ...) must
            # not abort the whole pass.
            try:
                logger.info("[Pass 1] Entity extraction chunk %d/%d (%d chars)",
                            i, total, len(chunk))
                entities = self._chat_json_list(
                    system_prompt,
                    f"Extract all entities from this text:\n\n{chunk}",
                    "entities",
                )
                logger.info("[Pass 1] Chunk %d: %d entities", i, len(entities))
                self._merge_entities(graph_id, entities)
                success_count += 1
            except Exception as e:
                fail_count += 1
                last_error = e
                logger.error("[Pass 1] Chunk %d failed: %s: %s", i, type(e).__name__, e)
                logger.debug(traceback.format_exc())

        logger.info("[Pass 1] Complete: %d/%d succeeded, %d unique entities found",
                    success_count, total, len(graph["nodes"]))
        if success_count == 0 and total > 0:
            raise RuntimeError(
                f"All {total} entity extraction calls failed. Last error: {last_error}"
            ) from last_error

    # ── Pass 2: Relationship Discovery ──
    def discover_relationships(
        self,
        graph_id: str,
        full_text: str,
        progress_callback: Optional[Callable] = None,
        chunk_size: int = 5000,
        chunk_overlap: int = 500,
    ) -> None:
        """
        Pass 2: find relationships between the entities found in Pass 1.

        The full text is re-split into larger windows than Pass 1 uses, so a
        relationship that spans several paragraphs can fall inside one call.
        A failing window is logged and skipped. Unlike Pass 1, this pass does
        not raise when every call fails; it logs an error instead.

        Args:
            graph_id: id returned by ``create_graph``.
            full_text: the complete source text.
            progress_callback: optional ``callback(message, fraction)``, called
                before each window.
            chunk_size: window size passed to ``TextProcessor.split_text``.
            chunk_overlap: window overlap passed to ``TextProcessor.split_text``.

        Raises:
            KeyError: if ``graph_id`` is unknown.
        """
        graph = self._graphs[graph_id]
        nodes = graph["nodes"]
        if not nodes:
            logger.warning("[Pass 2] No entities to find relationships for")
            return

        ontology = graph["ontology"] or {}
        edge_types_json = json.dumps(
            ontology.get("edge_types", []), indent=2, ensure_ascii=False
        )
        # Built once: the entity list does not change until merges below.
        system_prompt = RELATIONSHIP_EXTRACT_PROMPT % (
            edge_types_json,
            self._format_entity_list(nodes),
        )
        rel_chunks = TextProcessor.split_text(
            full_text, chunk_size=chunk_size, overlap=chunk_overlap
        )

        total = len(rel_chunks)
        success_count = 0
        fail_count = 0
        last_error: Optional[Exception] = None
        for i, chunk in enumerate(rel_chunks, start=1):
            if progress_callback:
                progress_callback(
                    f"[Pass 2] Finding relationships in section {i}/{total} "
                    f"(edges so far: {len(graph['edges'])})...",
                    i / total
                )
            # Deliberately broad: one bad window must not abort the pass.
            try:
                logger.info("[Pass 2] Relationship discovery section %d/%d (%d chars)",
                            i, total, len(chunk))
                rels = self._chat_json_list(
                    system_prompt,
                    f"Find all relationships between the known entities in this text:\n\n{chunk}",
                    "relationships",
                )
                logger.info("[Pass 2] Section %d: %d relationships", i, len(rels))
                self._merge_relationships(graph_id, rels)
                success_count += 1
            except Exception as e:
                fail_count += 1
                last_error = e
                logger.error("[Pass 2] Section %d failed: %s: %s", i, type(e).__name__, e)
                logger.debug(traceback.format_exc())

        logger.info("[Pass 2] Complete: %d/%d succeeded, %d total edges",
                    success_count, total, len(graph["edges"]))
        if success_count == 0 and total > 0:
            logger.error("[Pass 2] All %d relationship discovery calls failed. Last error: %s",
                         total, last_error)

    # ── LLM / formatting helpers ──
    def _chat_json_list(self, system_prompt: str, user_prompt: str,
                        key: str) -> List[Dict[str, Any]]:
        """
        Send one system+user exchange via ``chat_json`` and return ``result[key]``.

        A missing or null ``key`` yields ``[]``. Non-dict items in the list are
        dropped.

        Raises:
            ValueError: if the response is not a JSON object, or ``key`` holds
                something other than a list. The caller then counts the chunk
                as failed instead of merging garbage.
        """
        result = self.llm.chat_json(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            max_tokens=4096
        )
        if not isinstance(result, dict):
            raise ValueError(f"Expected a JSON object from the LLM, got {type(result).__name__}")
        items = result.get(key)
        if items is None:
            return []
        if not isinstance(items, list):
            raise ValueError(f"Expected '{key}' to be a list, got {type(items).__name__}")
        return [item for item in items if isinstance(item, dict)]

    @staticmethod
    def _format_entity_list(nodes: Dict[str, Dict[str, Any]]) -> str:
        """Render nodes as ``- Name (Label)`` lines, sorted by name, for the Pass 2 prompt."""
        return "\n".join(
            f"- {node['name']} ({(node.get('labels') or ['Entity'])[0]})"
            for node in sorted(nodes.values(), key=lambda n: n["name"])
        )

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce an LLM-provided scalar to a stripped string; ``None`` becomes ``""``."""
        if value is None:
            return ""
        return str(value).strip()

    @staticmethod
    def _get_or_create_node(nodes: Dict[str, Dict[str, Any]], name: str) -> Dict[str, Any]:
        """Return the node for ``name`` (case-insensitive), creating a generic placeholder if absent."""
        key = name.lower()
        node = nodes.get(key)
        if node is None:
            node = {
                "uuid": str(uuid.uuid4()),
                "name": name,
                "labels": ["Entity"],
                "summary": "",
                "attributes": {},
                "created_at": None,
            }
            nodes[key] = node
        return node

    # ── Merge Helpers ──
    def _merge_entities(self, graph_id: str, entities: List[Dict[str, Any]]) -> None:
        """
        Merge extracted entities into the graph, deduplicating by lower-cased name.

        A new entity becomes a node whose label is its type when that type is
        in the ontology, and ``Entity`` otherwise. For an existing node:
        - the longer summary wins;
        - a generic ``Entity`` label is upgraded once a valid type appears;
        - attributes are only filled where the node has no truthy value yet.
        Malformed items (non-dicts, empty names) are skipped.
        """
        graph = self._graphs[graph_id]
        nodes = graph["nodes"]
        ontology = graph["ontology"] or {}
        valid_entity_types = {
            et["name"]
            for et in ontology.get("entity_types", [])
            if isinstance(et, dict) and "name" in et
        }

        for entity in entities:
            if not isinstance(entity, dict):
                continue
            name = self._as_text(entity.get("name"))
            if not name:
                continue
            etype = self._as_text(entity.get("type")) or "Entity"
            summary = self._as_text(entity.get("summary"))
            raw_attrs = entity.get("attributes")
            # Copy so the node does not alias the LLM response object.
            attributes = dict(raw_attrs) if isinstance(raw_attrs, dict) else {}

            key = name.lower()
            existing = nodes.get(key)
            if existing is None:
                nodes[key] = {
                    "uuid": str(uuid.uuid4()),
                    "name": name,
                    "labels": [etype] if etype in valid_entity_types else ["Entity"],
                    "summary": summary,
                    "attributes": attributes,
                    "created_at": None,
                }
                continue

            if len(summary) > len(existing.get("summary") or ""):
                existing["summary"] = summary
            if existing.get("labels") == ["Entity"] and etype in valid_entity_types:
                existing["labels"] = [etype]
            existing_attrs = existing.setdefault("attributes", {})
            for attr_key, attr_value in attributes.items():
                if attr_value and not existing_attrs.get(attr_key):
                    existing_attrs[attr_key] = attr_value

    def _merge_relationships(self, graph_id: str, relationships: List[Dict[str, Any]]) -> None:
        """
        Merge extracted relationships into the graph as edges.

        Edges are deduplicated on (source, target, name), with source and
        target compared case-insensitively. An endpoint missing from the graph
        gets a generic ``Entity`` placeholder node, so every edge uuid
        resolves to an existing node. This holds even for a self-loop to an
        unknown name. Edge endpoint names use the node's canonical spelling.
        Malformed items are skipped.
        """
        graph = self._graphs[graph_id]
        nodes = graph["nodes"]
        edges = graph["edges"]
        seen = {
            (e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"])
            for e in edges
        }

        for rel in relationships:
            if not isinstance(rel, dict):
                continue
            rel_name = self._as_text(rel.get("name"))
            source = self._as_text(rel.get("source"))
            target = self._as_text(rel.get("target"))
            if not (rel_name and source and target):
                continue
            edge_key = (source.lower(), target.lower(), rel_name)
            if edge_key in seen:
                continue
            seen.add(edge_key)

            # Sequential get-or-create: if source == target (case-insensitively)
            # the second lookup finds the node the first call just created.
            source_node = self._get_or_create_node(nodes, source)
            target_node = self._get_or_create_node(nodes, target)
            edges.append({
                "uuid": str(uuid.uuid4()),
                "name": rel_name,
                "fact": self._as_text(rel.get("fact")),
                "fact_type": rel_name,
                "source_node_uuid": source_node["uuid"],
                "target_node_uuid": target_node["uuid"],
                "source_node_name": source_node["name"],
                "target_node_name": target_node["name"],
                "attributes": {},
                "created_at": None,
                "valid_at": None,
                "invalid_at": None,
                "expired_at": None,
                "episodes": [],
            })

    # ── Data Access ──
    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """
        Return a JSON-serializable snapshot of a graph.

        Unknown ids yield an empty graph. The returned lists are new list
        objects, but the node and edge dicts inside them are the live records.
        """
        graph = self._graphs.get(graph_id) or {"nodes": {}, "edges": []}
        nodes_list = list(graph["nodes"].values())
        edges_list = list(graph["edges"])
        return {
            "graph_id": graph_id,
            "nodes": nodes_list,
            "edges": edges_list,
            "node_count": len(nodes_list),
            "edge_count": len(edges_list),
        }

    def save_graph_data(self, graph_id: str, project_dir: str) -> str:
        """
        Write the graph to ``<project_dir>/graph_data.json`` and return that path.

        Creates ``project_dir`` if needed. The file is first written to a
        temporary sibling and then moved into place with ``os.replace``, so
        a crash mid-write does not leave a truncated ``graph_data.json``.
        """
        data = self.get_graph_data(graph_id)
        os.makedirs(project_dir, exist_ok=True)
        path = os.path.join(project_dir, "graph_data.json")
        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return path

    @staticmethod
    def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
        """Load ``<project_dir>/graph_data.json``, or return ``None`` if it does not exist."""
        path = os.path.join(project_dir, "graph_data.json")
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def delete_graph(self, graph_id: str) -> None:
        """Drop a graph from memory. Unknown ids are ignored."""
        self._graphs.pop(graph_id, None)