"""
|
|
LLM-based graph builder service
|
|
Replaces Zep with direct LLM calls for entity/relationship extraction.
|
|
Two-pass approach: (1) extract entities, (2) discover relationships.
|
|
"""
|
|
|
|
import os
|
|
import uuid
|
|
import json
|
|
import logging
|
|
import traceback
|
|
from typing import Dict, Any, List, Optional, Callable
|
|
|
|
from ..utils.llm_client import LLMClient
|
|
from ..models.task import TaskManager, TaskStatus
|
|
from .text_processor import TextProcessor
|
|
|
|
# Module-level logger; namespaced under the application's logger hierarchy.
logger = logging.getLogger('mirofish.llm_graph_builder')

# Default chunk size — larger than Zep's default to capture more context per call
DEFAULT_CHUNK_SIZE = 2500
# Characters shared between consecutive chunks (presumably so entities spanning
# a chunk boundary are not lost — confirm against TextProcessor.split_text).
DEFAULT_CHUNK_OVERLAP = 200
|
|
|
|
# System prompt for Pass 1 (entity extraction). Contains exactly one %s
# placeholder, filled with the ontology's entity types serialized as JSON
# (see extract_entities). Applied with the % operator, so the prompt text
# must contain no other bare % characters.
ENTITY_EXTRACT_PROMPT = (
    "You are a knowledge graph entity extraction engine. Given a text chunk and an ontology schema, "
    "extract all entities mentioned in the text.\n\n"
    "ONTOLOGY SCHEMA (entity types only):\n%s\n\n"
    "RULES:\n"
    "1. Extract entities that match the entity_types defined above.\n"
    "2. Each entity needs: name (the canonical name used in the text), type (must match an entity_type name from the schema), "
    "summary (1-2 sentences describing the entity based on the text), and attributes (fill in any attributes defined for that type).\n"
    "3. Only extract entities explicitly mentioned or strongly implied in the text.\n"
    "4. Use the exact name as it appears in the text (e.g. 'Mira' not 'Mira the Socializer').\n"
    "5. If no entities are found, return an empty array.\n\n"
    'Return JSON: a single key "entities" with an array of objects, each having keys: name, type, summary, attributes.'
)
|
|
|
|
# System prompt for Pass 2 (relationship discovery). Contains exactly two %s
# placeholders, filled in order with (1) the ontology's edge types as JSON and
# (2) the list of entities discovered in Pass 1 (see discover_relationships).
# Applied with the % operator, so the prompt text must contain no other bare
# % characters.
RELATIONSHIP_EXTRACT_PROMPT = (
    "You are a knowledge graph relationship inference engine. Given a text section describing characters/entities "
    "in a social simulation, a list of known entities, and relationship types, infer all plausible relationships.\n\n"
    "RELATIONSHIP TYPES:\n%s\n\n"
    "KNOWN ENTITIES:\n%s\n\n"
    "RULES:\n"
    "1. Infer relationships between the known entities based on their described traits, roles, goals, and behaviors.\n"
    "2. Each relationship needs: name (must match a relationship type name above), source (entity name), target (entity name), "
    "fact (1 sentence explaining WHY this relationship is likely based on the text).\n"
    "3. Both source and target MUST be from the known entities list.\n"
    "4. Include both explicit relationships AND strongly implied ones based on complementary/conflicting traits.\n"
    " Examples of inference: if player A 'trades supplies' and player B 'buys tools', infer TRADE_WITH.\n"
    " If player A 'mediates conflicts' and player B 'causes chaos', infer OPPOSES.\n"
    " If two players share goals or would naturally work together, infer COLLABORATES_WITH.\n"
    "5. Be thorough — extract as many plausible relationships as the text supports.\n"
    "6. Do NOT invent relationships with no textual basis. Each fact must reference something from the text.\n\n"
    'Return JSON: a single key "relationships" with an array of objects, each having keys: name, source, target, fact.'
)
|
|
|
|
|
|
class LLMGraphBuilderService:
    """
    Graph builder using direct LLM calls instead of Zep.

    Two-pass extraction: entities first (per chunk), then relationships
    (over larger text windows). Graphs live in memory keyed by a generated
    graph id and can be persisted to / loaded from ``graph_data.json``.
    """

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """
        Args:
            llm_client: client used for all chat completions; a default
                LLMClient is constructed when omitted.
        """
        self.llm = llm_client or LLMClient()
        self.task_manager = TaskManager()
        # graph_id -> {"name", "ontology", "nodes": {lower_name: node}, "edges": [...]}
        self._graphs: Dict[str, Dict[str, Any]] = {}

    def create_graph(self, name: str) -> str:
        """Create an empty in-memory graph and return its generated id."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graphs[graph_id] = {
            "name": name,
            "ontology": None,   # set later via set_ontology
            "nodes": {},        # keyed by lower-cased entity name
            "edges": [],
        }
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Attach an ontology (entity_types / edge_types) to an existing graph.

        Silently ignores unknown graph ids.
        """
        if graph_id in self._graphs:
            self._graphs[graph_id]["ontology"] = ontology

    # ── Pass 1: Entity Extraction ──

    def extract_entities(
        self,
        graph_id: str,
        chunks: List[str],
        progress_callback: Optional[Callable] = None
    ):
        """
        Pass 1: Extract entities from each chunk.

        Each chunk is sent to the LLM together with the ontology's entity
        types; results are merged into the graph's node map (deduplicated by
        lower-cased name). Per-chunk failures are logged and skipped.

        Args:
            graph_id: id returned by create_graph (KeyError if unknown).
            chunks: pre-split text chunks to process.
            progress_callback: optional fn(message, fraction) for UI updates.

        Raises:
            RuntimeError: only if *every* chunk fails.
        """
        graph = self._graphs[graph_id]
        # Fix: tolerate a graph whose ontology was never set (it is
        # initialized to None) instead of raising AttributeError below.
        ontology = graph["ontology"] or {}

        # Entity-types-only schema, embedded verbatim in the system prompt.
        entity_types_json = json.dumps(
            ontology.get("entity_types", []), indent=2, ensure_ascii=False
        )

        total = len(chunks)
        success_count = 0
        fail_count = 0
        last_error = None

        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback(
                    f"[Pass 1] Extracting entities from chunk {i+1}/{total} "
                    f"(ok={success_count}, fail={fail_count})...",
                    (i + 1) / total
                )

            try:
                logger.info("[Pass 1] Entity extraction chunk %d/%d (%d chars)",
                            i + 1, total, len(chunk))
                result = self.llm.chat_json(
                    messages=[
                        {
                            "role": "system",
                            "content": ENTITY_EXTRACT_PROMPT % entity_types_json
                        },
                        {
                            "role": "user",
                            "content": f"Extract all entities from this text:\n\n{chunk}"
                        }
                    ],
                    temperature=0.1,  # low temperature: extraction should be near-deterministic
                    max_tokens=4096
                )
                entities = result.get("entities", [])
                logger.info("[Pass 1] Chunk %d: %d entities", i + 1, len(entities))
                self._merge_entities(graph_id, entities)
                success_count += 1
            except Exception as e:
                # Best effort per chunk: one bad LLM response must not abort the pass.
                fail_count += 1
                last_error = e
                logger.error("[Pass 1] Chunk %d failed: %s: %s", i + 1, type(e).__name__, e)
                logger.debug(traceback.format_exc())

        logger.info("[Pass 1] Complete: %d/%d succeeded, %d unique entities found",
                    success_count, total, len(graph['nodes']))

        if success_count == 0 and total > 0:
            raise RuntimeError(f"All {total} entity extraction calls failed. Last error: {last_error}")

    # ── Pass 2: Relationship Discovery ──

    def discover_relationships(
        self,
        graph_id: str,
        full_text: str,
        progress_callback: Optional[Callable] = None
    ):
        """
        Pass 2: Find relationships between known entities using larger text windows.

        Re-chunks the full text into bigger overlapping sections (5000/500) so
        cross-chunk relationships appear within one prompt, then asks the LLM
        to link only the entities discovered in Pass 1. Per-section failures
        are logged and skipped. No-op if the graph has no nodes.

        Args:
            graph_id: id returned by create_graph (KeyError if unknown).
            full_text: the complete source text (not the Pass-1 chunks).
            progress_callback: optional fn(message, fraction) for UI updates.
        """
        graph = self._graphs[graph_id]
        # Fix: tolerate a missing ontology instead of crashing on None.get(...).
        ontology = graph["ontology"] or {}
        nodes = graph["nodes"]

        if not nodes:
            logger.warning("[Pass 2] No entities to find relationships for")
            return

        # Edge-types schema embedded verbatim in the system prompt.
        edge_types_json = json.dumps(
            ontology.get("edge_types", []), indent=2, ensure_ascii=False
        )

        # Human-readable "- Name (PrimaryLabel)" list of known entities.
        # Node keys are lower-cased names, so name.lower() is the lookup key.
        entity_names = sorted(set(n["name"] for n in nodes.values()))
        entity_list = "\n".join(
            f"- {name} ({nodes[name.lower()]['labels'][0]})"
            for name in entity_names if name.lower() in nodes
        )

        # Use larger chunks than Pass 1 (5000 chars, 500 overlap):
        # relationship inference benefits from wider context.
        rel_chunks = TextProcessor.split_text(full_text, chunk_size=5000, overlap=500)
        total = len(rel_chunks)
        success_count = 0
        fail_count = 0

        for i, chunk in enumerate(rel_chunks):
            if progress_callback:
                progress_callback(
                    f"[Pass 2] Finding relationships in section {i+1}/{total} "
                    f"(edges so far: {len(graph['edges'])})...",
                    (i + 1) / total
                )

            try:
                logger.info("[Pass 2] Relationship discovery section %d/%d (%d chars)",
                            i + 1, total, len(chunk))
                result = self.llm.chat_json(
                    messages=[
                        {
                            "role": "system",
                            "content": RELATIONSHIP_EXTRACT_PROMPT % (edge_types_json, entity_list)
                        },
                        {
                            "role": "user",
                            "content": f"Find all relationships between the known entities in this text:\n\n{chunk}"
                        }
                    ],
                    temperature=0.1,  # low temperature: extraction should be near-deterministic
                    max_tokens=4096
                )
                rels = result.get("relationships", [])
                logger.info("[Pass 2] Section %d: %d relationships", i + 1, len(rels))
                self._merge_relationships(graph_id, rels)
                success_count += 1
            except Exception as e:
                # Best effort per section; a single failure must not abort the pass.
                fail_count += 1
                logger.error("[Pass 2] Section %d failed: %s: %s", i + 1, type(e).__name__, e)
                logger.debug(traceback.format_exc())

        # Fix: fail_count was computed but never reported; include it here.
        logger.info("[Pass 2] Complete: %d/%d succeeded (%d failed), %d total edges",
                    success_count, total, fail_count, len(graph['edges']))

    # ── Merge Helpers ──

    def _merge_entities(self, graph_id: str, entities: List[Dict[str, Any]]):
        """
        Merge extracted entities into the graph, deduplicating by lower-cased name.

        On a duplicate: keep the longer summary and fill in (never overwrite)
        missing attributes. Entity types not declared in the ontology fall
        back to the generic "Entity" label.
        """
        graph = self._graphs[graph_id]
        nodes = graph["nodes"]

        # Types declared by the ontology; anything else is labeled "Entity".
        valid_entity_types = set()
        if graph["ontology"]:
            for et in graph["ontology"].get("entity_types", []):
                valid_entity_types.add(et["name"])

        for entity in entities:
            name = entity.get("name", "").strip()
            if not name:
                continue  # skip nameless objects the LLM occasionally emits

            etype = entity.get("type", "Entity")
            key = name.lower()

            if key in nodes:
                existing = nodes[key]
                # Prefer the longer (presumably more informative) summary.
                new_summary = entity.get("summary", "")
                if new_summary and len(new_summary) > len(existing.get("summary", "")):
                    existing["summary"] = new_summary
                # Fill attribute gaps without clobbering earlier values.
                for k, v in (entity.get("attributes", {}) or {}).items():
                    if v and not existing["attributes"].get(k):
                        existing["attributes"][k] = v
            else:
                labels = [etype] if etype in valid_entity_types else ["Entity"]
                nodes[key] = {
                    "uuid": str(uuid.uuid4()),
                    "name": name,
                    "labels": labels,
                    "summary": entity.get("summary", ""),
                    "attributes": entity.get("attributes", {}) or {},
                    "created_at": None,
                }

    def _merge_relationships(self, graph_id: str, relationships: List[Dict[str, Any]]):
        """
        Merge extracted relationships into the graph.

        Deduplicates on (source, target, name), case-insensitive for the
        endpoints. Endpoints the LLM names that are not yet in the node map
        get placeholder "Entity" nodes so every edge resolves to real uuids.
        """
        graph = self._graphs[graph_id]
        nodes = graph["nodes"]
        edges = graph["edges"]

        # Seed the dedup set from edges merged in earlier sections.
        existing_edges = set()
        for e in edges:
            existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))

        for rel in relationships:
            rel_name = rel.get("name", "").strip()
            source = rel.get("source", "").strip()
            target = rel.get("target", "").strip()
            if not rel_name or not source or not target:
                continue  # malformed relationship object; skip

            edge_key = (source.lower(), target.lower(), rel_name)
            if edge_key in existing_edges:
                continue
            existing_edges.add(edge_key)

            source_node = nodes.get(source.lower())
            target_node = nodes.get(target.lower())
            source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
            target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())

            # Create placeholder nodes for endpoints Pass 1 did not find.
            if not source_node:
                nodes[source.lower()] = {
                    "uuid": source_uuid,
                    "name": source,
                    "labels": ["Entity"],
                    "summary": "",
                    "attributes": {},
                    "created_at": None,
                }
            if not target_node:
                nodes[target.lower()] = {
                    "uuid": target_uuid,
                    "name": target,
                    "labels": ["Entity"],
                    "summary": "",
                    "attributes": {},
                    "created_at": None,
                }

            # NOTE(review): the temporal fields (valid_at/invalid_at/expired_at)
            # and "episodes" look like Zep-schema compatibility placeholders,
            # always None/empty here — confirm downstream consumers expect them.
            edges.append({
                "uuid": str(uuid.uuid4()),
                "name": rel_name,
                "fact": rel.get("fact", ""),
                "fact_type": rel_name,
                "source_node_uuid": source_uuid,
                "target_node_uuid": target_uuid,
                "source_node_name": source,
                "target_node_name": target,
                "attributes": {},
                "created_at": None,
                "valid_at": None,
                "invalid_at": None,
                "expired_at": None,
                "episodes": [],
            })

    # ── Data Access ──

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """Return a serializable snapshot: nodes as a list, edges, and counts.

        Unknown graph ids yield an empty graph rather than raising.
        """
        graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
        nodes_list = list(graph["nodes"].values())
        edges_list = graph["edges"]
        return {
            "graph_id": graph_id,
            "nodes": nodes_list,
            "edges": edges_list,
            "node_count": len(nodes_list),
            "edge_count": len(edges_list),
        }

    def save_graph_data(self, graph_id: str, project_dir: str) -> str:
        """Write the graph snapshot to <project_dir>/graph_data.json; return the path."""
        data = self.get_graph_data(graph_id)
        path = os.path.join(project_dir, "graph_data.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return path

    @staticmethod
    def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
        """Load a previously saved snapshot, or None if the file does not exist."""
        path = os.path.join(project_dir, "graph_data.json")
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        return None

    def delete_graph(self, graph_id: str):
        """Drop the in-memory graph; no-op for unknown ids."""
        self._graphs.pop(graph_id, None)
|