# MiroFish/backend/app/services/llm_graph_builder.py
# (261 lines, 10 KiB, Python)

"""
LLM-based graph builder service
Replaces Zep with direct LLM calls for entity/relationship extraction
"""
import os
import uuid
import json
import logging
from typing import Dict, Any, List, Optional, Callable
from ..utils.llm_client import LLMClient
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor
# Module logger; as a dotted name it is a child of the 'mirofish' logger.
logger = logging.getLogger('mirofish.llm_graph_builder')
# System prompt used for every chunk extraction call.
# This is a %-format template with exactly one "%s" placeholder, which
# LLMGraphBuilderService.extract_from_chunks fills with the JSON-serialized
# ontology. Any literal percent sign added to this text must be written as
# "%%", otherwise the %-formatting will fail.
# NOTE(review): rule 4 asks for consistent names "across extractions", but each
# chunk is sent in its own request without the names extracted so far; in this
# module cross-chunk consistency only comes from the case-insensitive
# (lower-cased) name merging done when results are merged into the graph.
EXTRACT_SYSTEM_PROMPT_TEMPLATE = (
"You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, "
"extract all entities and relationships.\n\n"
"ONTOLOGY SCHEMA:\n%s\n\n"
"RULES:\n"
"1. Extract entities that match the entity_types defined in the schema. Each entity needs: "
"name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.\n"
"2. Extract relationships between entities that match the edge_types defined in the schema. "
"Each relationship needs: name (the edge type name), source (entity name), target (entity name), "
"and a fact (short description of the relationship).\n"
"3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.\n"
"4. Use consistent entity names across extractions.\n"
"5. If no entities or relationships are found, return empty arrays.\n\n"
'Return JSON with keys "entities" (array of objects with name, type, summary, attributes) '
'and "relationships" (array of objects with name, source, target, fact).'
)
class LLMGraphBuilderService:
    """
    Graph builder that uses direct LLM calls instead of Zep.

    Same interface as GraphBuilderService for drop-in replacement.
    Graphs are held in memory on the instance (``self._graphs``, keyed by
    graph id); ``save_graph_data`` / ``load_graph_data`` persist them as JSON.
    """

    def __init__(self, llm_client: Optional["LLMClient"] = None):
        """
        Args:
            llm_client: Client used for extraction calls. A default
                ``LLMClient()`` is created when omitted.
        """
        self.llm = llm_client if llm_client is not None else LLMClient()
        self.task_manager = TaskManager()
        # In-memory graph storage (keyed by graph_id)
        self._graphs: Dict[str, Dict[str, Any]] = {}

    def create_graph(self, name: str) -> str:
        """Create an empty in-memory graph and return its generated id."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graphs[graph_id] = {
            "name": name,
            "ontology": None,
            "nodes": {},  # keyed by lower-cased entity name
            "edges": [],
        }
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Attach an ontology schema to a graph; unknown ids are ignored (with a warning)."""
        graph = self._graphs.get(graph_id)
        if graph is None:
            logger.warning(f"set_ontology: unknown graph_id {graph_id!r}; ontology ignored")
            return
        graph["ontology"] = ontology

    def extract_from_chunks(
        self,
        graph_id: str,
        chunks: List[str],
        progress_callback: Optional[Callable] = None
    ):
        """
        Extract entities and relationships from text chunks using the LLM.

        Each chunk is sent to the LLM in its own request and the result is
        merged into the graph. A chunk that fails (LLM error, malformed
        response) is logged and skipped so one bad response does not abort
        the whole build.

        Args:
            graph_id: Id returned by ``create_graph``.
            chunks: Text chunks, processed in order.
            progress_callback: Optional ``callback(message, fraction)`` called
                before each chunk and again after each failed chunk.

        Raises:
            KeyError: If ``graph_id`` is unknown.
            RuntimeError: If there was at least one chunk and all of them failed.
        """
        graph = self._graphs[graph_id]
        ontology_json = json.dumps(graph["ontology"], indent=2, ensure_ascii=False)
        # The system prompt is identical for every chunk, so build it once.
        system_prompt = EXTRACT_SYSTEM_PROMPT_TEMPLATE % ontology_json
        total = len(chunks)
        success_count = 0
        fail_count = 0
        last_error: Optional[Exception] = None

        for i, chunk in enumerate(chunks, start=1):
            if progress_callback:
                progress_callback(
                    f"Extracting from chunk {i}/{total} (ok={success_count}, fail={fail_count})...",
                    i / total
                )
            try:
                logger.info(f"Extracting chunk {i}/{total} ({len(chunk)} chars)")
                result = self.llm.chat_json(
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {
                            "role": "user",
                            "content": f"Extract entities and relationships from this text:\n\n{chunk}"
                        },
                    ],
                    temperature=0.1,
                    max_tokens=4096
                )
                if not isinstance(result, dict):
                    raise ValueError(
                        f"expected a JSON object from the LLM, got {type(result).__name__}"
                    )
                # `or []` tolerates keys that are present but null.
                entities = result.get("entities") or []
                rels = result.get("relationships") or []
                logger.info(f"Chunk {i}: extracted {len(entities)} entities, {len(rels)} relationships")
                self._merge_extraction(graph_id, result)
                success_count += 1
            except Exception as e:
                # Deliberately broad: per-chunk failures are counted and
                # reported, and only an all-chunks failure is raised below.
                fail_count += 1
                last_error = e
                logger.exception(f"Chunk {i} extraction failed: {type(e).__name__}: {e}")
                if progress_callback:
                    progress_callback(f"Chunk {i} error: {e}", i / total)

        logger.info(f"Extraction complete: {success_count}/{total} succeeded, {fail_count} failed")
        if success_count == 0 and total > 0:
            raise RuntimeError(f"All {total} chunks failed extraction. Last error: {last_error}")

    @staticmethod
    def _clean_str(value: Any) -> str:
        """Return ``value`` stripped when it is a string, else ``""`` (tolerates JSON null/other types)."""
        return value.strip() if isinstance(value, str) else ""

    @staticmethod
    def _new_node(
        name: str,
        labels: List[str],
        summary: str = "",
        attributes: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build a node record with the fields every stored node carries."""
        return {
            "uuid": str(uuid.uuid4()),
            "name": name,
            "labels": labels,
            "summary": summary,
            "attributes": attributes if attributes is not None else {},
            "created_at": None,
        }

    def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
        """
        Merge one chunk's extraction result into the graph, deduplicating by name.

        Entities are keyed by lower-cased name. An entity whose ``type`` is not
        one of the ontology's ``entity_types`` names is labelled ``"Entity"``.
        Relationships are deduplicated on (source, target, name). Their names
        are stored as reported and are not checked against ``edge_types``.
        Malformed items (non-objects, null or non-string names) are skipped
        rather than aborting the merge part-way through.
        """
        graph = self._graphs[graph_id]
        # Assumes ontology["entity_types"] is a list of {"name": ...} dicts.
        ontology = graph["ontology"] or {}
        valid_entity_types = {
            et["name"]
            for et in ontology.get("entity_types") or []
            if isinstance(et, dict) and isinstance(et.get("name"), str)
        }
        # Entities first, so relationship endpoints resolve to typed nodes.
        self._merge_entities(graph["nodes"], result.get("entities") or [], valid_entity_types)
        self._merge_relationships(graph["nodes"], graph["edges"], result.get("relationships") or [])

    def _merge_entities(
        self,
        nodes: Dict[str, Dict[str, Any]],
        entities: List[Any],
        valid_entity_types: set,
    ):
        """Add new entity nodes or enrich the existing node with the same lower-cased name."""
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            name = self._clean_str(entity.get("name"))
            if not name:
                continue
            etype = entity.get("type", "Entity")
            label = etype if isinstance(etype, str) and etype in valid_entity_types else "Entity"
            summary = entity.get("summary")
            if not isinstance(summary, str):
                summary = ""
            attributes = entity.get("attributes")
            if not isinstance(attributes, dict):
                attributes = {}

            key = name.lower()
            existing = nodes.get(key)
            if existing is None:
                nodes[key] = self._new_node(name, [label], summary, dict(attributes))
                continue

            # Keep the longest summary seen so far.
            if len(summary) > len(existing.get("summary") or ""):
                existing["summary"] = summary
            # Only fill attributes that are missing or empty; never overwrite a known value.
            for attr_key, attr_value in attributes.items():
                if attr_value and not existing["attributes"].get(attr_key):
                    existing["attributes"][attr_key] = attr_value
            # A node first created as a relationship placeholder (or with an
            # off-ontology type) has the generic "Entity" label; adopt the
            # first valid ontology type reported for it later.
            if label != "Entity" and existing.get("labels") == ["Entity"]:
                existing["labels"] = [label]

    def _merge_relationships(
        self,
        nodes: Dict[str, Dict[str, Any]],
        edges: List[Dict[str, Any]],
        relationships: List[Any],
    ):
        """Append new, non-duplicate edges, creating placeholder nodes for unseen endpoints."""
        # Dedup key: lower-cased endpoint names plus the exact relationship name.
        seen = {
            (e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"])
            for e in edges
        }
        for rel in relationships:
            if not isinstance(rel, dict):
                continue
            rel_name = self._clean_str(rel.get("name"))
            source = self._clean_str(rel.get("source"))
            target = self._clean_str(rel.get("target"))
            if not (rel_name and source and target):
                continue
            edge_key = (source.lower(), target.lower(), rel_name)
            if edge_key in seen:
                continue
            seen.add(edge_key)

            # Resolve endpoints one after the other so a self-loop (or
            # case-variant endpoints) on an unseen entity shares one node.
            source_node = self._get_or_create_node(nodes, source)
            target_node = self._get_or_create_node(nodes, target)
            fact = rel.get("fact")
            edges.append({
                "uuid": str(uuid.uuid4()),
                "name": rel_name,
                "fact": fact if isinstance(fact, str) else "",
                "fact_type": rel_name,
                "source_node_uuid": source_node["uuid"],
                "target_node_uuid": target_node["uuid"],
                "source_node_name": source,
                "target_node_name": target,
                "attributes": {},
                "created_at": None,
                "valid_at": None,
                "invalid_at": None,
                "expired_at": None,
                "episodes": [],
            })

    def _get_or_create_node(self, nodes: Dict[str, Dict[str, Any]], name: str) -> Dict[str, Any]:
        """Return the node for ``name`` (case-insensitive), creating an "Entity" placeholder if absent."""
        key = name.lower()
        node = nodes.get(key)
        if node is None:
            node = self._new_node(name, ["Entity"])
            nodes[key] = node
        return node

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """
        Return graph data in the same format as the Zep-based builder.

        Unknown ids yield an empty graph. The returned ``nodes``/``edges`` are
        new lists, so adding or removing items does not alter the stored graph
        (the node/edge dicts inside them are still shared).
        """
        graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
        nodes_list = list(graph["nodes"].values())
        edges_list = list(graph["edges"])
        return {
            "graph_id": graph_id,
            "nodes": nodes_list,
            "edges": edges_list,
            "node_count": len(nodes_list),
            "edge_count": len(edges_list),
        }

    def save_graph_data(self, graph_id: str, project_dir: str) -> str:
        """
        Persist graph data to ``<project_dir>/graph_data.json`` and return the path.

        The JSON is written to a temporary sibling file and then moved into
        place with ``os.replace``, so a failed or interrupted write leaves any
        previous ``graph_data.json`` intact instead of truncated.
        """
        data = self.get_graph_data(graph_id)
        path = os.path.join(project_dir, "graph_data.json")
        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception:
            # Best-effort removal of the partial temp file before re-raising.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        return path

    @staticmethod
    def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
        """Load persisted graph data from disk, or return None if no file exists."""
        path = os.path.join(project_dir, "graph_data.json")
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def delete_graph(self, graph_id: str):
        """Drop a graph from memory; unknown ids are ignored."""
        self._graphs.pop(graph_id, None)