MiroFish/backend/app/services/zep_graph_memory_updater.py

548 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Zep图谱记忆更新服务
将模拟中的Agent活动动态更新到Zep图谱中
"""
import os
import time
import threading
import json
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Empty
from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
# Module-level logger for this service
logger = get_logger('mirofish.zep_graph_memory_updater')
@dataclass
class AgentActivity:
    """A single agent action recorded during a simulation.

    ``to_episode_text`` turns the record into a natural-language line that is
    sent to Zep, so Zep can extract entities and relations from it.
    """

    platform: str  # twitter / reddit
    agent_id: int
    agent_name: str
    action_type: str  # CREATE_POST, LIKE_POST, etc.
    action_args: Dict[str, Any]
    round_num: int
    timestamp: str

    # action_type -> name of the method that describes it.  Method *names* are
    # stored (resolved with getattr) so subclass overrides are honoured, just
    # as with bound-method lookup.  The attribute is unannotated, so
    # @dataclass does not turn it into a field.
    _ACTION_DESCRIBERS = {
        "CREATE_POST": "_describe_create_post",
        "LIKE_POST": "_describe_like_post",
        "DISLIKE_POST": "_describe_dislike_post",
        "REPOST": "_describe_repost",
        "QUOTE_POST": "_describe_quote_post",
        "FOLLOW": "_describe_follow",
        "CREATE_COMMENT": "_describe_create_comment",
        "LIKE_COMMENT": "_describe_like_comment",
        "DISLIKE_COMMENT": "_describe_dislike_comment",
        "SEARCH_POSTS": "_describe_search",
        "SEARCH_USER": "_describe_search_user",
        "MUTE": "_describe_mute",
    }

    def to_episode_text(self) -> str:
        """Convert the activity into a text description to send to Zep.

        A plain natural-language description is used so Zep can extract
        entities and relations from it.  No simulation-specific prefix is
        added, to avoid misleading the graph update.

        Returns:
            ``"<agent_name>: <description>"``.  Unknown action types get a
            generic description.
        """
        method_name = self._ACTION_DESCRIBERS.get(self.action_type, "_describe_generic")
        description = getattr(self, method_name)()
        # "agent name: description", without any simulation prefix
        return f"{self.agent_name}: {description}"

    def _arg(self, key: str) -> Any:
        """Return ``action_args[key]``, or ``""`` if the key is missing.

        Also returns ``""`` when ``action_args`` is not a dict (e.g. ``None``
        coming from a JSON ``null``).  Calling ``.get`` on such a value would
        raise ``AttributeError`` and make ``to_episode_text`` fail.
        """
        if not isinstance(self.action_args, dict):
            return ""
        return self.action_args.get(key, "")

    def _describe_reaction(self, verb: str, noun: str, content: Any, author: Any) -> str:
        """Describe an action taken on another user's post or comment.

        Produces e.g. ``liked Bob's post: "hi"``, ``liked a post: "hi"``,
        ``liked a post by Bob`` or ``liked a post``, depending on which of
        ``content`` / ``author`` is non-empty.
        """
        if content and author:
            return f"{verb} {author}'s {noun}: \"{content}\""
        if content:
            return f"{verb} a {noun}: \"{content}\""
        if author:
            return f"{verb} a {noun} by {author}"
        return f"{verb} a {noun}"

    def _describe_create_post(self) -> str:
        """Create post - includes the post content when available."""
        content = self._arg("content")
        if content:
            return f"published a post: \"{content}\""
        return "published a post"

    def _describe_like_post(self) -> str:
        """Like post - includes post content and author info."""
        return self._describe_reaction(
            "liked", "post", self._arg("post_content"), self._arg("post_author_name")
        )

    def _describe_dislike_post(self) -> str:
        """Dislike post - includes post content and author info."""
        return self._describe_reaction(
            "disliked", "post", self._arg("post_content"), self._arg("post_author_name")
        )

    def _describe_repost(self) -> str:
        """Repost - includes original post content and author info."""
        return self._describe_reaction(
            "reposted", "post", self._arg("original_content"), self._arg("original_author_name")
        )

    def _describe_quote_post(self) -> str:
        """Quote post - includes original content, author info, and the quote comment.

        The quote comment is read from ``quote_content``, falling back to
        ``content``.
        """
        original_content = self._arg("original_content")
        original_author = self._arg("original_author_name")
        quote_content = self._arg("quote_content") or self._arg("content")
        if original_content and original_author:
            text = f"quoted {original_author}'s post \"{original_content}\""
        elif original_content:
            text = f"quoted a post \"{original_content}\""
        elif original_author:
            text = f"quoted a post by {original_author}"
        else:
            text = "quoted a post"
        if quote_content:
            text += f", commenting: \"{quote_content}\""
        return text

    def _describe_follow(self) -> str:
        """Follow user - includes the followed user's name."""
        target_user_name = self._arg("target_user_name")
        if target_user_name:
            return f"followed user \"{target_user_name}\""
        return "followed a user"

    def _describe_create_comment(self) -> str:
        """Create comment - includes comment content and the commented post's info."""
        content = self._arg("content")
        if not content:
            return "posted a comment"
        post_content = self._arg("post_content")
        post_author = self._arg("post_author_name")
        if post_content and post_author:
            return f"commented on {post_author}'s post \"{post_content}\": \"{content}\""
        if post_content:
            return f"commented on post \"{post_content}\": \"{content}\""
        if post_author:
            return f"commented on {post_author}'s post: \"{content}\""
        return f"commented: \"{content}\""

    def _describe_like_comment(self) -> str:
        """Like comment - includes comment content and author info."""
        return self._describe_reaction(
            "liked", "comment", self._arg("comment_content"), self._arg("comment_author_name")
        )

    def _describe_dislike_comment(self) -> str:
        """Dislike comment - includes comment content and author info."""
        return self._describe_reaction(
            "disliked", "comment", self._arg("comment_content"), self._arg("comment_author_name")
        )

    def _describe_search(self) -> str:
        """Search posts - includes the search keywords (``query`` or ``keyword``)."""
        query = self._arg("query") or self._arg("keyword")
        return f"searched for \"{query}\"" if query else "performed a search"

    def _describe_search_user(self) -> str:
        """Search user - includes the search keywords (``query`` or ``username``)."""
        query = self._arg("query") or self._arg("username")
        return f"searched for user \"{query}\"" if query else "searched for a user"

    def _describe_mute(self) -> str:
        """Mute user - includes the muted user's name."""
        target_user_name = self._arg("target_user_name")
        if target_user_name:
            return f"muted user \"{target_user_name}\""
        return "muted a user"

    def _describe_generic(self) -> str:
        """Generic description for unknown action types."""
        return f"performed {self.action_type} action"
class ZepGraphMemoryUpdater:
    """
    Zep graph memory updater.

    Receives agent activities from a running simulation (via ``add_activity``
    or ``add_activity_from_dict``, the latter taking records parsed from
    actions.jsonl) and pushes them to a Zep graph from a background worker
    thread.  Activities are grouped by platform; each platform is sent as one
    combined text episode every BATCH_SIZE activities.

    Every meaningful action is forwarded to Zep, and action_args carries the
    full context, e.g.:
    - the original text of liked/disliked posts
    - the original text of reposted/quoted posts
    - the names of followed/muted users
    - the original text of liked/disliked comments
    """

    # Batch size: how many activities accumulate per platform before sending
    BATCH_SIZE = 5

    # Platform display names (used for log output)
    PLATFORM_DISPLAY_NAMES = {
        'twitter': '世界1',
        'reddit': '世界2',
    }

    # Pause (seconds) after each batch send, to avoid sending requests too fast
    SEND_INTERVAL = 0.5

    # Retry configuration
    MAX_RETRIES = 3
    RETRY_DELAY = 2  # seconds; multiplied by the attempt number between retries

    def __init__(self, graph_id: str, api_key: Optional[str] = None):
        """
        Initialize the updater.

        Args:
            graph_id: Zep graph ID.
            api_key: Zep API key; optional, defaults to ``Config.ZEP_API_KEY``.

        Raises:
            ValueError: if no API key is given and none is configured.
        """
        self.graph_id = graph_id
        self.api_key = api_key or Config.ZEP_API_KEY

        if not self.api_key:
            raise ValueError("ZEP_API_KEY is not configured")

        self.client = Zep(api_key=self.api_key)

        # Incoming activities, consumed by the worker thread
        self._activity_queue: Queue = Queue()

        # Per-platform buffers; each platform is sent once it reaches BATCH_SIZE
        self._platform_buffers: Dict[str, List[AgentActivity]] = {
            'twitter': [],
            'reddit': [],
        }
        self._buffer_lock = threading.Lock()

        # Control flags
        self._running = False
        self._worker_thread: Optional[threading.Thread] = None

        # Statistics
        self._total_activities = 0  # activities actually added to the queue
        self._total_sent = 0  # batches successfully sent to Zep
        self._total_items_sent = 0  # activities successfully sent to Zep
        self._failed_count = 0  # batches that failed to send
        self._skipped_count = 0  # activities filtered out (DO_NOTHING)

        logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")

    def _get_platform_display_name(self, platform: str) -> str:
        """Return the display name for a platform (the raw name if unknown)."""
        return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)

    def start(self):
        """Start the background worker thread (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._worker_thread = threading.Thread(
            target=self._worker_loop,
            daemon=True,
            name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
        )
        self._worker_thread.start()
        logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")

    def stop(self):
        """
        Stop the background worker thread and send all pending activities.

        The worker is joined *before* the final flush.  Flushing while the
        worker is still running would race with it: an activity the worker
        has just dequeued but not yet buffered would be appended after the
        flush and never sent.
        """
        self._running = False

        if self._worker_thread and self._worker_thread.is_alive():
            self._worker_thread.join(timeout=10)
            if self._worker_thread.is_alive():
                # Best effort: e.g. the worker is still inside a slow send
                logger.warning(
                    f"ZepGraphMemoryUpdater worker did not exit within 10s, "
                    f"flushing anyway: graph_id={self.graph_id}"
                )

        # Send what is left in the queue and the buffers, including partial batches
        self._flush_remaining()

        logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
                    f"total_activities={self._total_activities}, "
                    f"batches_sent={self._total_sent}, "
                    f"items_sent={self._total_items_sent}, "
                    f"failed={self._failed_count}, "
                    f"skipped={self._skipped_count}")

    def add_activity(self, activity: AgentActivity):
        """
        Add an agent activity to the queue.

        Every meaningful action is queued, including:
        - CREATE_POST (create a post)
        - CREATE_COMMENT (comment)
        - QUOTE_POST (quote a post)
        - SEARCH_POSTS (search posts)
        - SEARCH_USER (search users)
        - LIKE_POST / DISLIKE_POST (like / dislike a post)
        - REPOST (repost)
        - FOLLOW (follow)
        - MUTE (mute)
        - LIKE_COMMENT / DISLIKE_COMMENT (like / dislike a comment)

        action_args carries the full context (post text, user names, ...).
        DO_NOTHING activities are not queued; they are counted in
        ``_skipped_count``.

        Args:
            activity: The agent activity record.
        """
        if activity.action_type == "DO_NOTHING":
            self._skipped_count += 1
            return

        self._activity_queue.put(activity)
        self._total_activities += 1
        logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")

    def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
        """
        Add an activity from a dict record.

        Records containing an ``event_type`` key are event entries, not agent
        actions, and are ignored.

        Args:
            data: Dict parsed from actions.jsonl.
            platform: Platform name (twitter/reddit).
        """
        if "event_type" in data:
            return

        activity = AgentActivity(
            platform=platform,
            agent_id=data.get("agent_id", 0),
            agent_name=data.get("agent_name", ""),
            action_type=data.get("action_type", ""),
            # `or {}` also covers an explicit JSON null, which .get's default does not
            action_args=data.get("action_args") or {},
            round_num=data.get("round", 0),
            timestamp=data.get("timestamp", datetime.now().isoformat()),
        )

        self.add_activity(activity)

    def _worker_loop(self):
        """
        Background loop: buffer activities per platform and send full batches.

        Runs while ``_running`` is set.  Whatever remains in the queue or the
        buffers afterwards is sent by ``_flush_remaining``, which ``stop()``
        calls after joining this thread.
        """
        while self._running:
            try:
                try:
                    # Wait at most 1s so a cleared _running flag is noticed promptly
                    activity = self._activity_queue.get(timeout=1)
                except Empty:
                    continue

                platform = activity.platform.lower()
                batch = None
                with self._buffer_lock:
                    buffer = self._platform_buffers.setdefault(platform, [])
                    buffer.append(activity)
                    if len(buffer) >= self.BATCH_SIZE:
                        batch = buffer[:self.BATCH_SIZE]
                        self._platform_buffers[platform] = buffer[self.BATCH_SIZE:]

                # Send after releasing the lock (sending may retry and sleep)
                if batch:
                    self._send_batch_activities(batch, platform)
                    # Pause between sends to avoid requesting too fast
                    time.sleep(self.SEND_INTERVAL)
            except Exception as e:
                logger.error(f"工作循环异常: {e}")
                time.sleep(1)

    def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
        """
        Send a batch of activities to the Zep graph as one combined text episode.

        Activity descriptions are joined with newlines.  An activity whose
        description cannot be built is logged and skipped rather than sinking
        the whole batch.  The send is attempted up to MAX_RETRIES times with a
        linearly increasing delay; a batch that still fails is counted in
        ``_failed_count`` and dropped.

        Args:
            activities: Agent activities to send.
            platform: Platform name (used for log output).
        """
        if not activities:
            return

        episode_texts = []
        for activity in activities:
            try:
                episode_texts.append(activity.to_episode_text())
            except Exception as e:
                logger.warning(
                    f"Skipping activity that could not be described: "
                    f"{activity.agent_name} - {activity.action_type}: {e}"
                )
        if not episode_texts:
            return

        combined_text = "\n".join(episode_texts)
        sent_count = len(episode_texts)

        for attempt in range(self.MAX_RETRIES):
            try:
                self.client.graph.add(
                    graph_id=self.graph_id,
                    type="text",
                    data=combined_text
                )

                self._total_sent += 1
                self._total_items_sent += sent_count
                display_name = self._get_platform_display_name(platform)
                logger.info(f"成功批量发送 {sent_count}{display_name}活动到图谱 {self.graph_id}")
                logger.debug(f"批量内容预览: {combined_text[:200]}...")
                return

            except Exception as e:
                if attempt < self.MAX_RETRIES - 1:
                    logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
                    time.sleep(self.RETRY_DELAY * (attempt + 1))
                else:
                    logger.error(f"批量发送到Zep失败已重试{self.MAX_RETRIES}次: {e}")
                    self._failed_count += 1

    def _flush_remaining(self):
        """
        Send every activity still in the queue or in the platform buffers.

        Leftovers are sent even if fewer than BATCH_SIZE, one batch per
        platform.  Pending work is moved out under ``_buffer_lock`` and sent
        after releasing it, so retry sleeps do not hold the lock.
        """
        with self._buffer_lock:
            # Move activities still in the queue into the platform buffers
            while True:
                try:
                    activity = self._activity_queue.get_nowait()
                except Empty:
                    break
                self._platform_buffers.setdefault(activity.platform.lower(), []).append(activity)

            pending = {platform: buffer for platform, buffer in self._platform_buffers.items() if buffer}

            # Clear all buffers
            for platform in self._platform_buffers:
                self._platform_buffers[platform] = []

        for platform, buffer in pending.items():
            display_name = self._get_platform_display_name(platform)
            logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
            self._send_batch_activities(buffer, platform)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the updater's counters and queue/buffer sizes."""
        with self._buffer_lock:
            buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}

        return {
            "graph_id": self.graph_id,
            "batch_size": self.BATCH_SIZE,
            "total_activities": self._total_activities,  # activities added to the queue
            "batches_sent": self._total_sent,  # batches sent successfully
            "items_sent": self._total_items_sent,  # activities sent successfully
            "failed_count": self._failed_count,  # batches that failed to send
            "skipped_count": self._skipped_count,  # activities filtered out (DO_NOTHING)
            "queue_size": self._activity_queue.qsize(),
            "buffer_sizes": buffer_sizes,  # per-platform buffer sizes
            "running": self._running,
        }
class ZepGraphMemoryManager:
    """
    Registry of Zep graph memory updaters, one per simulation.

    The registry dict is only read or modified while holding ``_lock``, so
    no caller iterates it while another thread changes it.
    """

    _updaters: Dict[str, ZepGraphMemoryUpdater] = {}
    _lock = threading.Lock()

    # Set by stop_all() so a repeated call is a no-op; cleared by
    # create_updater() so updaters created afterwards can still be stopped by
    # a later stop_all().
    _stop_all_done = False

    @classmethod
    def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
        """
        Create and start a graph memory updater for a simulation.

        An existing updater for the same simulation is removed and stopped
        first; a failure while stopping it is logged and does not prevent
        the new updater from being created.

        Args:
            simulation_id: Simulation ID.
            graph_id: Zep graph ID.

        Returns:
            The new, already-started ZepGraphMemoryUpdater instance.
        """
        with cls._lock:
            old_updater = cls._updaters.pop(simulation_id, None)
            if old_updater is not None:
                try:
                    old_updater.stop()
                except Exception as e:
                    logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")

            updater = ZepGraphMemoryUpdater(graph_id)
            updater.start()
            cls._updaters[simulation_id] = updater
            cls._stop_all_done = False

        logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
        return updater

    @classmethod
    def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
        """Return the simulation's updater, or None if there is none."""
        return cls._updaters.get(simulation_id)

    @classmethod
    def stop_updater(cls, simulation_id: str):
        """
        Stop and remove a simulation's updater (no-op if there is none).

        The updater is removed from the registry before being stopped, so it
        is unregistered even if ``stop()`` raises; the exception still
        propagates to the caller.
        """
        with cls._lock:
            updater = cls._updaters.pop(simulation_id, None)
            if updater is None:
                return
            updater.stop()

        logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")

    @classmethod
    def stop_all(cls):
        """
        Stop and remove all updaters.

        Repeated calls are no-ops until a new updater is created.  Errors from
        individual updaters are logged and do not stop the others.
        """
        with cls._lock:
            # Checked under the lock so concurrent callers cannot both proceed
            if cls._stop_all_done:
                return
            cls._stop_all_done = True

            for simulation_id, updater in list(cls._updaters.items()):
                try:
                    updater.stop()
                except Exception as e:
                    logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
            cls._updaters.clear()

        logger.info("已停止所有图谱记忆更新器")

    @classmethod
    def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
        """Return ``get_stats()`` for every registered updater, keyed by simulation ID."""
        # Snapshot under the lock: iterating the live dict while another thread
        # creates/removes an updater can raise "dictionary changed size".
        with cls._lock:
            updaters = list(cls._updaters.items())
        return {sim_id: updater.get_stats() for sim_id, updater in updaters}