diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 83221b3..359b096 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -20,6 +20,81 @@ from ..models.project import ProjectManager logger = get_logger('mirofish.api.simulation') +def _load_disk_graph_data(graph_id: str): + """Load graph data from disk by searching all projects for matching graph_id.""" + all_projects = ProjectManager.list_projects() + for proj in all_projects: + if proj.graph_id == graph_id: + project_dir = ProjectManager._get_project_dir(proj.project_id) + return LLMGraphBuilderService.load_graph_data(project_dir) + return None + + +def _get_filtered_entities(graph_id: str, entity_types=None, enrich=True): + """Get filtered entities, trying disk first then Zep fallback.""" + disk_data = _load_disk_graph_data(graph_id) + if disk_data: + manager = SimulationManager() + return manager._filter_entities_from_data(disk_data, entity_types) + + reader = ZepEntityReader() + return reader.filter_defined_entities( + graph_id=graph_id, + defined_entity_types=entity_types, + enrich_with_edges=enrich + ) + + +def _get_entity_by_uuid(graph_id: str, entity_uuid: str): + """Get a single entity by UUID, trying disk first then Zep fallback.""" + disk_data = _load_disk_graph_data(graph_id) + if disk_data: + from ..services.zep_entity_reader import EntityNode + nodes = disk_data.get("nodes", []) + edges = disk_data.get("edges", []) + node_map = {n["uuid"]: n for n in nodes} + + node = node_map.get(entity_uuid) + if not node: + return None + + related_edges = [] + related_nodes = [] + for edge in edges: + if edge.get("source_node_uuid") == entity_uuid or edge.get("target_node_uuid") == entity_uuid: + related_edges.append({ + "uuid": edge.get("uuid", ""), + "name": edge.get("name", ""), + "fact": edge.get("fact", ""), + "source_node_uuid": edge.get("source_node_uuid", ""), + "target_node_uuid": edge.get("target_node_uuid", ""), + "source_node_name": edge.get("source_node_name", ""), + "target_node_name": edge.get("target_node_name", ""), + }) + other_uuid = (edge.get("target_node_uuid") if edge.get("source_node_uuid") == entity_uuid + else edge.get("source_node_uuid")) + other_node = node_map.get(other_uuid) + if other_node: + related_nodes.append({ + "uuid": other_node.get("uuid", ""), + "name": other_node.get("name", ""), + "labels": other_node.get("labels", []), + }) + + return EntityNode( + uuid=entity_uuid, + name=node.get("name", ""), + labels=node.get("labels", []), + summary=node.get("summary", ""), + attributes=node.get("attributes", {}), + related_edges=related_edges, + related_nodes=related_nodes, + ) + + reader = ZepEntityReader() + return reader.get_entity_with_context(graph_id, entity_uuid) + + # Interview prompt 优化前缀 # 添加此前缀可以避免Agent调用工具,直接用文本回复 INTERVIEW_PROMPT_PREFIX = "Based on your persona, all past memories and actions, respond directly with text without calling any tools: " @@ -57,25 +132,14 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY is not configured" - }), 500 - entity_types_str = request.args.get('entity_types', '') entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None enrich = request.args.get('enrich', 'true').lower() == 'true' - + logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") - - reader = ZepEntityReader() - result = reader.filter_defined_entities( - graph_id=graph_id, - defined_entity_types=entity_types, - enrich_with_edges=enrich - ) - + + result = _get_filtered_entities(graph_id, entity_types, enrich) + return jsonify({ "success": True, "data": result.to_dict() @@ -94,14 +158,7 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY is not configured" - }), 500 - - reader = ZepEntityReader() - entity = reader.get_entity_with_context(graph_id, entity_uuid) + entity = _get_entity_by_uuid(graph_id, entity_uuid) if not entity: return jsonify({ @@ -127,20 +184,10 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY is not configured" - }), 500 - enrich = request.args.get('enrich', 'true').lower() == 'true' - - reader = ZepEntityReader() - entities = reader.get_entities_by_type( - graph_id=graph_id, - entity_type=entity_type, - enrich_with_edges=enrich - ) + + result = _get_filtered_entities(graph_id, [entity_type], enrich) + entities = result.entities return jsonify({ "success": True, @@ -473,27 +520,9 @@ def prepare_simulation(): try: logger.info(f"同步获取实体数量: graph_id={state.graph_id}") - # Try disk-stored graph data first (LLM-built graphs) - disk_graph_data = None - all_projects = ProjectManager.list_projects() - for proj in all_projects: - if proj.graph_id == state.graph_id: - project_dir = ProjectManager._get_project_dir(proj.project_id) - disk_graph_data = LLMGraphBuilderService.load_graph_data(project_dir) - break - - if disk_graph_data: - manager = SimulationManager() - filtered_preview = manager._filter_entities_from_data( - disk_graph_data, entity_types_list - ) - else: - reader = ZepEntityReader() - filtered_preview = reader.filter_defined_entities( - graph_id=state.graph_id, - defined_entity_types=entity_types_list, - enrich_with_edges=False - ) + filtered_preview = _get_filtered_entities( + state.graph_id, entity_types_list, enrich=False + ) state.entities_count = filtered_preview.filtered_count state.entity_types = list(filtered_preview.entity_types) @@ -1412,12 +1441,7 @@ def generate_profiles(): use_llm = data.get('use_llm', True) platform = data.get('platform', 'reddit') - reader = ZepEntityReader() - filtered = reader.filter_defined_entities( - graph_id=graph_id, - defined_entity_types=entity_types, - enrich_with_edges=True - ) + filtered = _get_filtered_entities(graph_id, entity_types, enrich=True) if filtered.filtered_count == 0: return jsonify({