MiroFish/backend/app/services/llm_graph_builder.py
_Yusaki 79519ddd54 Replace Zep with direct LLM calls for graph building
Add LLMGraphBuilderService that extracts entities/relationships
from text chunks using Groq instead of Zep Cloud API. Graph data
is persisted to disk as graph_data.json, with fallback to Zep
for existing graphs.
2026-03-13 19:07:40 +07:00

254 lines
9.4 KiB
Python

"""
LLM-based graph builder service
Replaces Zep with direct LLM calls for entity/relationship extraction
"""
import json
import logging
import os
import uuid
from typing import Dict, Any, List, Optional, Callable

from ..utils.llm_client import LLMClient
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor
# System prompt for per-chunk entity/relationship extraction.
# This is a str.format() template: {ontology_json} is the only placeholder.
# Every literal brace in the JSON example below is doubled ({{ / }}) so that
# .format() emits it verbatim instead of parsing it as a replacement field
# (unescaped braces make .format() raise on every call).
EXTRACT_SYSTEM_PROMPT = """You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, extract all entities and relationships.
ONTOLOGY SCHEMA:
{ontology_json}
RULES:
1. Extract entities that match the entity_types defined in the schema. Each entity needs: name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.
2. Extract relationships between entities that match the edge_types defined in the schema. Each relationship needs: name (the edge type name), source (entity name), target (entity name), and a fact (short description of the relationship).
3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.
4. Use consistent entity names across extractions (e.g., always "Mira" not sometimes "Mira" and sometimes "Mira the Socializer").
5. If no entities or relationships are found, return empty arrays.
Return JSON in this exact format:
{{
"entities": [
{{
"name": "EntityName",
"type": "EntityTypeName",
"summary": "Brief description",
"attributes": {{"attr_name": "attr_value"}}
}}
],
"relationships": [
{{
"name": "EDGE_TYPE_NAME",
"source": "SourceEntityName",
"target": "TargetEntityName",
"fact": "Description of this relationship"
}}
]
}}"""
class LLMGraphBuilderService:
    """
    Graph builder that uses direct LLM calls instead of Zep.

    Same interface as GraphBuilderService for drop-in replacement.

    Graphs are held in memory in ``self._graphs`` (keyed by graph_id) and can
    be written to / read from ``<project_dir>/graph_data.json``. Each stored
    graph is a dict with keys ``name``, ``ontology`` (dict or None),
    ``nodes`` (node dicts keyed by lower-cased entity name) and ``edges``
    (list of edge dicts).
    """

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """Create the service.

        Args:
            llm_client: client used for extraction calls; a default
                ``LLMClient()`` is constructed when omitted.
        """
        self.llm = llm_client or LLMClient()
        self.task_manager = TaskManager()
        # In-memory graph storage (keyed by graph_id)
        self._graphs: Dict[str, Dict[str, Any]] = {}

    def create_graph(self, name: str) -> str:
        """Register a new, empty graph and return its generated id."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graphs[graph_id] = {
            "name": name,
            "ontology": None,
            "nodes": {},  # keyed by lower-cased entity name
            "edges": [],
        }
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Attach an ontology to a graph; unknown graph ids are ignored.

        The whole ontology is serialized into the extraction prompt, and the
        ``name`` values of its ``entity_types`` list are used to validate the
        labels of extracted entities.
        """
        if graph_id in self._graphs:
            self._graphs[graph_id]["ontology"] = ontology

    def extract_from_chunks(
        self,
        graph_id: str,
        chunks: List[str],
        progress_callback: Optional[Callable] = None
    ):
        """Extract entities and relationships from text chunks using the LLM.

        Each chunk is sent in its own ``chat_json`` request and the result is
        merged into the graph. A failing chunk does not abort the run: the
        error is logged, reported through ``progress_callback`` and processing
        continues with the next chunk.

        Args:
            graph_id: id returned by :meth:`create_graph`.
            chunks: text chunks to process, in order.
            progress_callback: optional ``callback(message, fraction)`` where
                fraction is ``chunk_number / len(chunks)``.

        Raises:
            KeyError: if ``graph_id`` is unknown.
        """
        graph = self._graphs[graph_id]
        ontology_json = json.dumps(graph["ontology"], indent=2, ensure_ascii=False)
        # The system prompt depends only on the ontology, so build it once.
        system_prompt = EXTRACT_SYSTEM_PROMPT.format(ontology_json=ontology_json)
        logger = logging.getLogger(__name__)
        total = len(chunks)
        for i, chunk in enumerate(chunks, start=1):
            if progress_callback:
                progress_callback(f"Extracting from chunk {i}/{total}...", i / total)
            try:
                result = self.llm.chat_json(
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {
                            "role": "user",
                            "content": f"Extract entities and relationships from this text:\n\n{chunk}",
                        },
                    ],
                    temperature=0.1,
                    max_tokens=4096,
                )
                self._merge_extraction(graph_id, result)
            except Exception as e:
                # Best-effort per chunk, but never silent: log even when no
                # callback is attached.
                logger.warning(
                    "Chunk %d/%d extraction failed for graph %s: %s",
                    i, total, graph_id, e, exc_info=True,
                )
                if progress_callback:
                    progress_callback(f"Chunk {i} extraction error: {e}", i / total)

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return ``value`` stripped if it is a string, otherwise ``""``."""
        return value.strip() if isinstance(value, str) else ""

    @staticmethod
    def _new_node(name: str, label: str = "Entity") -> Dict[str, Any]:
        """Build an empty node record in the Zep-compatible node shape."""
        return {
            "uuid": str(uuid.uuid4()),
            "name": name,
            "labels": [label],
            "summary": "",
            "attributes": {},
            "created_at": None,
        }

    @classmethod
    def _ensure_node(cls, nodes: Dict[str, Dict[str, Any]], name: str) -> Dict[str, Any]:
        """Return the node stored under ``name.lower()``, creating a generic placeholder if absent."""
        key = name.lower()
        node = nodes.get(key)
        if node is None:
            node = cls._new_node(name)
            nodes[key] = node
        return node

    @classmethod
    def _merge_entity(
        cls,
        nodes: Dict[str, Dict[str, Any]],
        entity: Any,
        valid_entity_types: set,
    ) -> None:
        """Insert one extracted entity, or fold it into the node with the same lower-cased name."""
        if not isinstance(entity, dict):
            return
        name = cls._clean_text(entity.get("name"))
        if not name:
            return
        etype = entity.get("type", "Entity")
        label = etype if isinstance(etype, str) and etype in valid_entity_types else "Entity"
        # Model output may carry null / non-string summaries and null
        # attributes; normalise so later merges can rely on str / dict.
        summary = entity.get("summary") or ""
        if not isinstance(summary, str):
            summary = ""
        attributes = entity.get("attributes")
        if not isinstance(attributes, dict):
            attributes = {}

        key = name.lower()
        existing = nodes.get(key)
        if existing is None:
            node = cls._new_node(name, label)
            node["summary"] = summary
            node["attributes"] = dict(attributes)
            nodes[key] = node
            return

        # A node first created as a relationship placeholder only has the
        # generic label; adopt a valid ontology type once one is extracted.
        if label != "Entity" and existing["labels"] == ["Entity"]:
            existing["labels"] = [label]
        # Keep the longest summary seen so far.
        if len(summary) > len(existing.get("summary") or ""):
            existing["summary"] = summary
        # Only fill attributes that do not already hold a truthy value.
        for attr_name, attr_value in attributes.items():
            if attr_value and not existing["attributes"].get(attr_name):
                existing["attributes"][attr_name] = attr_value

    def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
        """Merge one LLM extraction result into the graph.

        Entities are deduplicated by case-insensitive name. Relationships are
        deduplicated by (source, target, name), with case-insensitive
        endpoints. Endpoints that were never extracted as entities get
        placeholder nodes labelled ``Entity``. Malformed items (non-dict
        entries, missing/blank names) are skipped instead of aborting the rest
        of the merge.

        Raises:
            KeyError: if ``graph_id`` is unknown.
        """
        graph = self._graphs[graph_id]
        if not isinstance(result, dict):
            # Model-generated output that is not a JSON object has nothing to merge.
            return
        nodes = graph["nodes"]
        edges = graph["edges"]
        ontology = graph["ontology"] or {}

        # Entity type names declared by the ontology; others become "Entity".
        valid_entity_types = {
            et["name"]
            for et in ontology.get("entity_types") or []
            if isinstance(et, dict) and et.get("name")
        }
        # NOTE(review): relationship names are not validated against the
        # ontology's edge_types; the model's edge name is stored as-is.

        for entity in result.get("entities") or []:
            self._merge_entity(nodes, entity, valid_entity_types)

        # Dedup key: (lower-cased source, lower-cased target, exact edge name).
        seen_edges = {
            (e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"])
            for e in edges
        }
        for rel in result.get("relationships") or []:
            if not isinstance(rel, dict):
                continue
            rel_name = self._clean_text(rel.get("name"))
            source = self._clean_text(rel.get("source"))
            target = self._clean_text(rel.get("target"))
            if not (rel_name and source and target):
                continue
            edge_key = (source.lower(), target.lower(), rel_name)
            if edge_key in seen_edges:
                continue
            seen_edges.add(edge_key)

            # Resolve (or create) the source before looking up the target so a
            # self-loop on an unseen entity reuses the same placeholder rather
            # than replacing it with a second node under a different uuid.
            source_node = self._ensure_node(nodes, source)
            target_node = self._ensure_node(nodes, target)

            edges.append({
                "uuid": str(uuid.uuid4()),
                "name": rel_name,
                "fact": rel.get("fact") or "",
                "fact_type": rel_name,
                "source_node_uuid": source_node["uuid"],
                "target_node_uuid": target_node["uuid"],
                "source_node_name": source,
                "target_node_name": target,
                "attributes": {},
                "created_at": None,
                "valid_at": None,
                "invalid_at": None,
                "expired_at": None,
                "episodes": [],
            })

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """Return graph data in the same format as the Zep-based builder.

        Unknown graph ids yield an empty graph. The returned lists are fresh
        list objects; the node/edge dicts inside them are the live records.
        """
        graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
        nodes_list = list(graph["nodes"].values())
        edges_list = list(graph["edges"])
        return {
            "graph_id": graph_id,
            "nodes": nodes_list,
            "edges": edges_list,
            "node_count": len(nodes_list),
            "edge_count": len(edges_list),
        }

    def save_graph_data(self, graph_id: str, project_dir: str) -> str:
        """Persist graph data to ``<project_dir>/graph_data.json`` and return the path.

        The data is written to a sibling ``.tmp`` file and swapped in with
        ``os.replace``, so an interrupted write never leaves a truncated
        ``graph_data.json`` behind.
        """
        data = self.get_graph_data(graph_id)
        path = os.path.join(project_dir, "graph_data.json")
        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return path

    @staticmethod
    def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
        """Load persisted graph data from disk, or return None if no file exists."""
        path = os.path.join(project_dir, "graph_data.json")
        if not os.path.exists(path):
            return None
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def delete_graph(self, graph_id: str):
        """Drop a graph from memory; unknown ids are ignored."""
        self._graphs.pop(graph_id, None)