Replace Zep with direct LLM calls for graph building

Add LLMGraphBuilderService that extracts entities/relationships
from text chunks using Groq instead of Zep Cloud API. Graph data
is persisted to disk as graph_data.json, with fallback to Zep
for existing graphs.
This commit is contained in:
_Yusaki 2026-03-13 19:07:40 +07:00
parent 034504c92a
commit 79519ddd54
2 changed files with 314 additions and 74 deletions

View file

@ -12,6 +12,7 @@ from . import graph_bp
from ..config import Config
from ..services.ontology_generator import OntologyGenerator
from ..services.graph_builder import GraphBuilderService
from ..services.llm_graph_builder import LLMGraphBuilderService
from ..services.text_processor import TextProcessor
from ..utils.file_parser import FileParser
from ..utils.logger import get_logger
@ -282,17 +283,6 @@ def build_graph():
try:
logger.info("=== 开始构建图谱 ===")
# 检查配置
errors = []
if not Config.ZEP_API_KEY:
errors.append("ZEP_API_KEY is not configured")
if errors:
logger.error(f"配置错误: {errors}")
return jsonify({
"success": False,
"error": "Configuration error: " + "; ".join(errors)
}), 500
# 解析请求
data = request.get_json() or {}
project_id = data.get('project_id')
@ -374,16 +364,16 @@ def build_graph():
def build_task():
build_logger = get_logger('mirofish.build')
try:
build_logger.info(f"[{task_id}] 开始构建图谱...")
build_logger.info(f"[{task_id}] 开始构建图谱 (LLM mode)...")
task_manager.update_task(
task_id,
task_id,
status=TaskStatus.PROCESSING,
message="Initializing graph build service..."
message="Initializing LLM graph build service..."
)
# 创建图谱构建服务
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
# 创建 LLM 图谱构建服务(不需要 Zep
builder = LLMGraphBuilderService()
# 分块
task_manager.update_task(
task_id,
@ -391,71 +381,48 @@ def build_graph():
progress=5
)
chunks = TextProcessor.split_text(
text,
chunk_size=chunk_size,
text,
chunk_size=chunk_size,
overlap=chunk_overlap
)
total_chunks = len(chunks)
# 创建图谱
task_manager.update_task(
task_id,
message="Creating Zep graph...",
message="Creating graph...",
progress=10
)
graph_id = builder.create_graph(name=graph_name)
# 更新项目的graph_id
project.graph_id = graph_id
ProjectManager.save_project(project)
# 设置本体
task_manager.update_task(
task_id,
message="Setting ontology definition...",
progress=15
)
builder.set_ontology(graph_id, ontology)
# 添加文本progress_callback 签名是 (msg, progress_ratio)
def add_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 40) # 15% - 55%
# LLM extraction from chunks
def extract_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 75) # 15% - 90%
task_manager.update_task(
task_id,
message=msg,
progress=progress
)
task_manager.update_task(
task_id,
message=f"Adding {total_chunks} text chunks...",
message=f"Extracting entities from {total_chunks} chunks via LLM...",
progress=15
)
episode_uuids = builder.add_text_batches(
graph_id,
builder.extract_from_chunks(
graph_id,
chunks,
batch_size=3,
progress_callback=add_progress_callback
progress_callback=extract_progress_callback
)
# 等待Zep处理完成查询每个episode的processed状态
task_manager.update_task(
task_id,
message="Waiting for Zep to process data...",
progress=55
)
def wait_progress_callback(msg, progress_ratio):
progress = 55 + int(progress_ratio * 35) # 55% - 90%
task_manager.update_task(
task_id,
message=msg,
progress=progress
)
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
# 获取图谱数据
task_manager.update_task(
task_id,
@ -463,15 +430,19 @@ def build_graph():
progress=95
)
graph_data = builder.get_graph_data(graph_id)
# Persist graph data to disk
project_dir = ProjectManager._get_project_dir(project_id)
builder.save_graph_data(graph_id, project_dir)
# 更新项目状态
project.status = ProjectStatus.GRAPH_COMPLETED
ProjectManager.save_project(project)
node_count = graph_data.get("node_count", 0)
edge_count = graph_data.get("edge_count", 0)
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
# 完成
task_manager.update_task(
task_id,
@ -486,16 +457,16 @@ def build_graph():
"chunk_count": total_chunks
}
)
except Exception as e:
# 更新项目状态为失败
build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}")
build_logger.debug(traceback.format_exc())
project.status = ProjectStatus.FAILED
project.error = str(e)
ProjectManager.save_project(project)
task_manager.update_task(
task_id,
status=TaskStatus.FAILED,
@ -565,21 +536,36 @@ def list_tasks():
def get_graph_data(graph_id: str):
"""
获取图谱数据节点和边
First tries disk (LLM builder), falls back to Zep if available.
"""
try:
if not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": "ZEP_API_KEY is not configured"
}), 500
# Find which project owns this graph_id
all_projects = ProjectManager.list_projects()
for proj_summary in all_projects:
proj = ProjectManager.get_project(proj_summary["project_id"])
if proj and proj.graph_id == graph_id:
project_dir = ProjectManager._get_project_dir(proj.project_id)
graph_data = LLMGraphBuilderService.load_graph_data(project_dir)
if graph_data:
return jsonify({
"success": True,
"data": graph_data
})
break
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
graph_data = builder.get_graph_data(graph_id)
# Fallback to Zep if graph data not on disk
if Config.ZEP_API_KEY:
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
graph_data = builder.get_graph_data(graph_id)
return jsonify({
"success": True,
"data": graph_data
})
return jsonify({
"success": True,
"data": graph_data
})
"success": False,
"error": f"Graph data not found for {graph_id}"
}), 404
except Exception as e:
return jsonify({

View file

@ -0,0 +1,254 @@
"""
LLM-based graph builder service
Replaces Zep with direct LLM calls for entity/relationship extraction
"""
import os
import uuid
import json
from typing import Dict, Any, List, Optional, Callable
from ..utils.llm_client import LLMClient
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor
EXTRACT_SYSTEM_PROMPT = """You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, extract all entities and relationships.
ONTOLOGY SCHEMA:
{ontology_json}
RULES:
1. Extract entities that match the entity_types defined in the schema. Each entity needs: name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.
2. Extract relationships between entities that match the edge_types defined in the schema. Each relationship needs: name (the edge type name), source (entity name), target (entity name), and a fact (short description of the relationship).
3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.
4. Use consistent entity names across extractions (e.g., always "Mira" not sometimes "Mira" and sometimes "Mira the Socializer").
5. If no entities or relationships are found, return empty arrays.
Return JSON in this exact format:
{
"entities": [
{
"name": "EntityName",
"type": "EntityTypeName",
"summary": "Brief description",
"attributes": {"attr_name": "attr_value"}
}
],
"relationships": [
{
"name": "EDGE_TYPE_NAME",
"source": "SourceEntityName",
"target": "TargetEntityName",
"fact": "Description of this relationship"
}
]
}"""
class LLMGraphBuilderService:
"""
Graph builder that uses direct LLM calls instead of Zep.
Same interface as GraphBuilderService for drop-in replacement.
"""
def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm = llm_client or LLMClient()
self.task_manager = TaskManager()
# In-memory graph storage (keyed by graph_id)
self._graphs: Dict[str, Dict[str, Any]] = {}
def create_graph(self, name: str) -> str:
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self._graphs[graph_id] = {
"name": name,
"ontology": None,
"nodes": {}, # keyed by normalized name
"edges": [],
}
return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
if graph_id in self._graphs:
self._graphs[graph_id]["ontology"] = ontology
def extract_from_chunks(
self,
graph_id: str,
chunks: List[str],
progress_callback: Optional[Callable] = None
):
"""Extract entities and relationships from text chunks using LLM."""
graph = self._graphs[graph_id]
ontology = graph["ontology"]
ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False)
total = len(chunks)
for i, chunk in enumerate(chunks):
if progress_callback:
progress_callback(
f"Extracting from chunk {i+1}/{total}...",
(i + 1) / total
)
try:
result = self.llm.chat_json(
messages=[
{
"role": "system",
"content": EXTRACT_SYSTEM_PROMPT.format(ontology_json=ontology_json)
},
{
"role": "user",
"content": f"Extract entities and relationships from this text:\n\n{chunk}"
}
],
temperature=0.1,
max_tokens=4096
)
self._merge_extraction(graph_id, result)
except Exception as e:
if progress_callback:
progress_callback(f"Chunk {i+1} extraction error: {e}", (i + 1) / total)
def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
"""Merge extracted entities/relationships into the graph, deduplicating by name."""
graph = self._graphs[graph_id]
nodes = graph["nodes"]
edges = graph["edges"]
# Valid entity type names from ontology
valid_entity_types = set()
if graph["ontology"]:
for et in graph["ontology"].get("entity_types", []):
valid_entity_types.add(et["name"])
# Valid edge type names
valid_edge_types = set()
if graph["ontology"]:
for et in graph["ontology"].get("edge_types", []):
valid_edge_types.add(et["name"])
# Merge entities
for entity in result.get("entities", []):
name = entity.get("name", "").strip()
if not name:
continue
etype = entity.get("type", "Entity")
key = name.lower()
if key in nodes:
# Update summary if new one is longer
existing = nodes[key]
new_summary = entity.get("summary", "")
if new_summary and len(new_summary) > len(existing.get("summary", "")):
existing["summary"] = new_summary
# Merge attributes
for k, v in entity.get("attributes", {}).items():
if v and not existing["attributes"].get(k):
existing["attributes"][k] = v
else:
labels = [etype] if etype in valid_entity_types else ["Entity"]
nodes[key] = {
"uuid": str(uuid.uuid4()),
"name": name,
"labels": labels,
"summary": entity.get("summary", ""),
"attributes": entity.get("attributes", {}) or {},
"created_at": None,
}
# Merge relationships (deduplicate by source+target+name)
existing_edges = set()
for e in edges:
existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))
for rel in result.get("relationships", []):
rel_name = rel.get("name", "").strip()
source = rel.get("source", "").strip()
target = rel.get("target", "").strip()
if not rel_name or not source or not target:
continue
edge_key = (source.lower(), target.lower(), rel_name)
if edge_key in existing_edges:
continue
existing_edges.add(edge_key)
# Resolve node UUIDs
source_node = nodes.get(source.lower())
target_node = nodes.get(target.lower())
source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())
# Create placeholder nodes if they don't exist
if not source_node:
nodes[source.lower()] = {
"uuid": source_uuid,
"name": source,
"labels": ["Entity"],
"summary": "",
"attributes": {},
"created_at": None,
}
if not target_node:
nodes[target.lower()] = {
"uuid": target_uuid,
"name": target,
"labels": ["Entity"],
"summary": "",
"attributes": {},
"created_at": None,
}
edges.append({
"uuid": str(uuid.uuid4()),
"name": rel_name,
"fact": rel.get("fact", ""),
"fact_type": rel_name,
"source_node_uuid": source_uuid,
"target_node_uuid": target_uuid,
"source_node_name": source,
"target_node_name": target,
"attributes": {},
"created_at": None,
"valid_at": None,
"invalid_at": None,
"expired_at": None,
"episodes": [],
})
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
"""Return graph data in the same format as the Zep-based builder."""
graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
nodes_list = list(graph["nodes"].values())
edges_list = graph["edges"]
return {
"graph_id": graph_id,
"nodes": nodes_list,
"edges": edges_list,
"node_count": len(nodes_list),
"edge_count": len(edges_list),
}
def save_graph_data(self, graph_id: str, project_dir: str) -> str:
"""Persist graph data to a JSON file in the project directory."""
data = self.get_graph_data(graph_id)
path = os.path.join(project_dir, "graph_data.json")
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
return path
@staticmethod
def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
"""Load persisted graph data from disk."""
path = os.path.join(project_dir, "graph_data.json")
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
return None
def delete_graph(self, graph_id: str):
self._graphs.pop(graph_id, None)