Replace Zep with direct LLM calls for graph building
Add LLMGraphBuilderService that extracts entities/relationships from text chunks using Groq instead of Zep Cloud API. Graph data is persisted to disk as graph_data.json, with fallback to Zep for existing graphs.
This commit is contained in:
parent
034504c92a
commit
79519ddd54
2 changed files with 314 additions and 74 deletions
|
|
@ -12,6 +12,7 @@ from . import graph_bp
|
|||
from ..config import Config
|
||||
from ..services.ontology_generator import OntologyGenerator
|
||||
from ..services.graph_builder import GraphBuilderService
|
||||
from ..services.llm_graph_builder import LLMGraphBuilderService
|
||||
from ..services.text_processor import TextProcessor
|
||||
from ..utils.file_parser import FileParser
|
||||
from ..utils.logger import get_logger
|
||||
|
|
@ -282,17 +283,6 @@ def build_graph():
|
|||
try:
|
||||
logger.info("=== 开始构建图谱 ===")
|
||||
|
||||
# 检查配置
|
||||
errors = []
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors.append("ZEP_API_KEY is not configured")
|
||||
if errors:
|
||||
logger.error(f"配置错误: {errors}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "Configuration error: " + "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
# 解析请求
|
||||
data = request.get_json() or {}
|
||||
project_id = data.get('project_id')
|
||||
|
|
@ -374,16 +364,16 @@ def build_graph():
|
|||
def build_task():
|
||||
build_logger = get_logger('mirofish.build')
|
||||
try:
|
||||
build_logger.info(f"[{task_id}] 开始构建图谱...")
|
||||
build_logger.info(f"[{task_id}] 开始构建图谱 (LLM mode)...")
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
task_id,
|
||||
status=TaskStatus.PROCESSING,
|
||||
message="Initializing graph build service..."
|
||||
message="Initializing LLM graph build service..."
|
||||
)
|
||||
|
||||
# 创建图谱构建服务
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
||||
|
||||
# 创建 LLM 图谱构建服务(不需要 Zep)
|
||||
builder = LLMGraphBuilderService()
|
||||
|
||||
# 分块
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -391,71 +381,48 @@ def build_graph():
|
|||
progress=5
|
||||
)
|
||||
chunks = TextProcessor.split_text(
|
||||
text,
|
||||
chunk_size=chunk_size,
|
||||
text,
|
||||
chunk_size=chunk_size,
|
||||
overlap=chunk_overlap
|
||||
)
|
||||
total_chunks = len(chunks)
|
||||
|
||||
|
||||
# 创建图谱
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message="Creating Zep graph...",
|
||||
message="Creating graph...",
|
||||
progress=10
|
||||
)
|
||||
graph_id = builder.create_graph(name=graph_name)
|
||||
|
||||
|
||||
# 更新项目的graph_id
|
||||
project.graph_id = graph_id
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
|
||||
# 设置本体
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message="Setting ontology definition...",
|
||||
progress=15
|
||||
)
|
||||
builder.set_ontology(graph_id, ontology)
|
||||
|
||||
# 添加文本(progress_callback 签名是 (msg, progress_ratio))
|
||||
def add_progress_callback(msg, progress_ratio):
|
||||
progress = 15 + int(progress_ratio * 40) # 15% - 55%
|
||||
|
||||
# LLM extraction from chunks
|
||||
def extract_progress_callback(msg, progress_ratio):
|
||||
progress = 15 + int(progress_ratio * 75) # 15% - 90%
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=msg,
|
||||
progress=progress
|
||||
)
|
||||
|
||||
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=f"Adding {total_chunks} text chunks...",
|
||||
message=f"Extracting entities from {total_chunks} chunks via LLM...",
|
||||
progress=15
|
||||
)
|
||||
|
||||
episode_uuids = builder.add_text_batches(
|
||||
graph_id,
|
||||
|
||||
builder.extract_from_chunks(
|
||||
graph_id,
|
||||
chunks,
|
||||
batch_size=3,
|
||||
progress_callback=add_progress_callback
|
||||
progress_callback=extract_progress_callback
|
||||
)
|
||||
|
||||
# 等待Zep处理完成(查询每个episode的processed状态)
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message="Waiting for Zep to process data...",
|
||||
progress=55
|
||||
)
|
||||
|
||||
def wait_progress_callback(msg, progress_ratio):
|
||||
progress = 55 + int(progress_ratio * 35) # 55% - 90%
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=msg,
|
||||
progress=progress
|
||||
)
|
||||
|
||||
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
|
||||
|
||||
|
||||
# 获取图谱数据
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -463,15 +430,19 @@ def build_graph():
|
|||
progress=95
|
||||
)
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
|
||||
|
||||
# Persist graph data to disk
|
||||
project_dir = ProjectManager._get_project_dir(project_id)
|
||||
builder.save_graph_data(graph_id, project_dir)
|
||||
|
||||
# 更新项目状态
|
||||
project.status = ProjectStatus.GRAPH_COMPLETED
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
|
||||
node_count = graph_data.get("node_count", 0)
|
||||
edge_count = graph_data.get("edge_count", 0)
|
||||
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
|
||||
|
||||
|
||||
# 完成
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -486,16 +457,16 @@ def build_graph():
|
|||
"chunk_count": total_chunks
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 更新项目状态为失败
|
||||
build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}")
|
||||
build_logger.debug(traceback.format_exc())
|
||||
|
||||
|
||||
project.status = ProjectStatus.FAILED
|
||||
project.error = str(e)
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.FAILED,
|
||||
|
|
@ -565,21 +536,36 @@ def list_tasks():
|
|||
def get_graph_data(graph_id: str):
|
||||
"""
|
||||
获取图谱数据(节点和边)
|
||||
First tries disk (LLM builder), falls back to Zep if available.
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "ZEP_API_KEY is not configured"
|
||||
}), 500
|
||||
# Find which project owns this graph_id
|
||||
all_projects = ProjectManager.list_projects()
|
||||
for proj_summary in all_projects:
|
||||
proj = ProjectManager.get_project(proj_summary["project_id"])
|
||||
if proj and proj.graph_id == graph_id:
|
||||
project_dir = ProjectManager._get_project_dir(proj.project_id)
|
||||
graph_data = LLMGraphBuilderService.load_graph_data(project_dir)
|
||||
if graph_data:
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": graph_data
|
||||
})
|
||||
break
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
# Fallback to Zep if graph data not on disk
|
||||
if Config.ZEP_API_KEY:
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": graph_data
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": graph_data
|
||||
})
|
||||
"success": False,
|
||||
"error": f"Graph data not found for {graph_id}"
|
||||
}), 404
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
|
|
|
|||
254
backend/app/services/llm_graph_builder.py
Normal file
254
backend/app/services/llm_graph_builder.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""
|
||||
LLM-based graph builder service
|
||||
Replaces Zep with direct LLM calls for entity/relationship extraction
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
from .text_processor import TextProcessor
|
||||
|
||||
|
||||
EXTRACT_SYSTEM_PROMPT = """You are a knowledge graph extraction engine. Given a text chunk and an ontology schema, extract all entities and relationships.
|
||||
|
||||
ONTOLOGY SCHEMA:
|
||||
{ontology_json}
|
||||
|
||||
RULES:
|
||||
1. Extract entities that match the entity_types defined in the schema. Each entity needs: name, type (matching an entity_type name), summary (1-2 sentences), and any attributes defined for that type.
|
||||
2. Extract relationships between entities that match the edge_types defined in the schema. Each relationship needs: name (the edge type name), source (entity name), target (entity name), and a fact (short description of the relationship).
|
||||
3. Only extract entities and relationships that are explicitly mentioned or strongly implied in the text.
|
||||
4. Use consistent entity names across extractions (e.g., always "Mira" not sometimes "Mira" and sometimes "Mira the Socializer").
|
||||
5. If no entities or relationships are found, return empty arrays.
|
||||
|
||||
Return JSON in this exact format:
|
||||
{
|
||||
"entities": [
|
||||
{
|
||||
"name": "EntityName",
|
||||
"type": "EntityTypeName",
|
||||
"summary": "Brief description",
|
||||
"attributes": {"attr_name": "attr_value"}
|
||||
}
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"name": "EDGE_TYPE_NAME",
|
||||
"source": "SourceEntityName",
|
||||
"target": "TargetEntityName",
|
||||
"fact": "Description of this relationship"
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
|
||||
class LLMGraphBuilderService:
|
||||
"""
|
||||
Graph builder that uses direct LLM calls instead of Zep.
|
||||
Same interface as GraphBuilderService for drop-in replacement.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self.llm = llm_client or LLMClient()
|
||||
self.task_manager = TaskManager()
|
||||
# In-memory graph storage (keyed by graph_id)
|
||||
self._graphs: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def create_graph(self, name: str) -> str:
|
||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||
self._graphs[graph_id] = {
|
||||
"name": name,
|
||||
"ontology": None,
|
||||
"nodes": {}, # keyed by normalized name
|
||||
"edges": [],
|
||||
}
|
||||
return graph_id
|
||||
|
||||
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
||||
if graph_id in self._graphs:
|
||||
self._graphs[graph_id]["ontology"] = ontology
|
||||
|
||||
def extract_from_chunks(
|
||||
self,
|
||||
graph_id: str,
|
||||
chunks: List[str],
|
||||
progress_callback: Optional[Callable] = None
|
||||
):
|
||||
"""Extract entities and relationships from text chunks using LLM."""
|
||||
graph = self._graphs[graph_id]
|
||||
ontology = graph["ontology"]
|
||||
ontology_json = json.dumps(ontology, indent=2, ensure_ascii=False)
|
||||
|
||||
total = len(chunks)
|
||||
for i, chunk in enumerate(chunks):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
f"Extracting from chunk {i+1}/{total}...",
|
||||
(i + 1) / total
|
||||
)
|
||||
|
||||
try:
|
||||
result = self.llm.chat_json(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_SYSTEM_PROMPT.format(ontology_json=ontology_json)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Extract entities and relationships from this text:\n\n{chunk}"
|
||||
}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4096
|
||||
)
|
||||
self._merge_extraction(graph_id, result)
|
||||
except Exception as e:
|
||||
if progress_callback:
|
||||
progress_callback(f"Chunk {i+1} extraction error: {e}", (i + 1) / total)
|
||||
|
||||
def _merge_extraction(self, graph_id: str, result: Dict[str, Any]):
|
||||
"""Merge extracted entities/relationships into the graph, deduplicating by name."""
|
||||
graph = self._graphs[graph_id]
|
||||
nodes = graph["nodes"]
|
||||
edges = graph["edges"]
|
||||
|
||||
# Valid entity type names from ontology
|
||||
valid_entity_types = set()
|
||||
if graph["ontology"]:
|
||||
for et in graph["ontology"].get("entity_types", []):
|
||||
valid_entity_types.add(et["name"])
|
||||
|
||||
# Valid edge type names
|
||||
valid_edge_types = set()
|
||||
if graph["ontology"]:
|
||||
for et in graph["ontology"].get("edge_types", []):
|
||||
valid_edge_types.add(et["name"])
|
||||
|
||||
# Merge entities
|
||||
for entity in result.get("entities", []):
|
||||
name = entity.get("name", "").strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
etype = entity.get("type", "Entity")
|
||||
key = name.lower()
|
||||
|
||||
if key in nodes:
|
||||
# Update summary if new one is longer
|
||||
existing = nodes[key]
|
||||
new_summary = entity.get("summary", "")
|
||||
if new_summary and len(new_summary) > len(existing.get("summary", "")):
|
||||
existing["summary"] = new_summary
|
||||
# Merge attributes
|
||||
for k, v in entity.get("attributes", {}).items():
|
||||
if v and not existing["attributes"].get(k):
|
||||
existing["attributes"][k] = v
|
||||
else:
|
||||
labels = [etype] if etype in valid_entity_types else ["Entity"]
|
||||
nodes[key] = {
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"labels": labels,
|
||||
"summary": entity.get("summary", ""),
|
||||
"attributes": entity.get("attributes", {}) or {},
|
||||
"created_at": None,
|
||||
}
|
||||
|
||||
# Merge relationships (deduplicate by source+target+name)
|
||||
existing_edges = set()
|
||||
for e in edges:
|
||||
existing_edges.add((e["source_node_name"].lower(), e["target_node_name"].lower(), e["name"]))
|
||||
|
||||
for rel in result.get("relationships", []):
|
||||
rel_name = rel.get("name", "").strip()
|
||||
source = rel.get("source", "").strip()
|
||||
target = rel.get("target", "").strip()
|
||||
if not rel_name or not source or not target:
|
||||
continue
|
||||
|
||||
edge_key = (source.lower(), target.lower(), rel_name)
|
||||
if edge_key in existing_edges:
|
||||
continue
|
||||
existing_edges.add(edge_key)
|
||||
|
||||
# Resolve node UUIDs
|
||||
source_node = nodes.get(source.lower())
|
||||
target_node = nodes.get(target.lower())
|
||||
source_uuid = source_node["uuid"] if source_node else str(uuid.uuid4())
|
||||
target_uuid = target_node["uuid"] if target_node else str(uuid.uuid4())
|
||||
|
||||
# Create placeholder nodes if they don't exist
|
||||
if not source_node:
|
||||
nodes[source.lower()] = {
|
||||
"uuid": source_uuid,
|
||||
"name": source,
|
||||
"labels": ["Entity"],
|
||||
"summary": "",
|
||||
"attributes": {},
|
||||
"created_at": None,
|
||||
}
|
||||
if not target_node:
|
||||
nodes[target.lower()] = {
|
||||
"uuid": target_uuid,
|
||||
"name": target,
|
||||
"labels": ["Entity"],
|
||||
"summary": "",
|
||||
"attributes": {},
|
||||
"created_at": None,
|
||||
}
|
||||
|
||||
edges.append({
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"name": rel_name,
|
||||
"fact": rel.get("fact", ""),
|
||||
"fact_type": rel_name,
|
||||
"source_node_uuid": source_uuid,
|
||||
"target_node_uuid": target_uuid,
|
||||
"source_node_name": source,
|
||||
"target_node_name": target,
|
||||
"attributes": {},
|
||||
"created_at": None,
|
||||
"valid_at": None,
|
||||
"invalid_at": None,
|
||||
"expired_at": None,
|
||||
"episodes": [],
|
||||
})
|
||||
|
||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||
"""Return graph data in the same format as the Zep-based builder."""
|
||||
graph = self._graphs.get(graph_id, {"nodes": {}, "edges": []})
|
||||
nodes_list = list(graph["nodes"].values())
|
||||
edges_list = graph["edges"]
|
||||
|
||||
return {
|
||||
"graph_id": graph_id,
|
||||
"nodes": nodes_list,
|
||||
"edges": edges_list,
|
||||
"node_count": len(nodes_list),
|
||||
"edge_count": len(edges_list),
|
||||
}
|
||||
|
||||
def save_graph_data(self, graph_id: str, project_dir: str) -> str:
|
||||
"""Persist graph data to a JSON file in the project directory."""
|
||||
data = self.get_graph_data(graph_id)
|
||||
path = os.path.join(project_dir, "graph_data.json")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
return path
|
||||
|
||||
@staticmethod
|
||||
def load_graph_data(project_dir: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load persisted graph data from disk."""
|
||||
path = os.path.join(project_dir, "graph_data.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return None
|
||||
|
||||
def delete_graph(self, graph_id: str):
|
||||
self._graphs.pop(graph_id, None)
|
||||
Loading…
Reference in a new issue