Enhance action context enrichment and update activity logging
- Added context enrichment for actions in `fetch_new_actions_from_db`, providing complete information for posts, comments, and user interactions. - Introduced a new `_enrich_action_context` function to supplement action arguments with relevant details such as post content, author names, and comment information. - Updated the `ZepGraphMemoryUpdater` to batch send activities by platform, improving efficiency in processing and logging. - Enhanced logging to include detailed statistics on sent activities and skipped actions, ensuring better traceability and monitoring of the activity flow.
This commit is contained in:
parent
1f191cb21e
commit
3f750ffda2
2 changed files with 443 additions and 54 deletions
|
|
@ -67,59 +67,131 @@ class AgentActivity:
|
|||
return "发布了一条帖子"
|
||||
|
||||
def _describe_like_post(self) -> str:
|
||||
post_id = self.action_args.get("post_id", "")
|
||||
return f"点赞了帖子#{post_id}" if post_id else "点赞了一条帖子"
|
||||
"""点赞帖子 - 包含帖子原文和作者信息"""
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
if post_content and post_author:
|
||||
return f"点赞了{post_author}的帖子:「{post_content}」"
|
||||
elif post_content:
|
||||
return f"点赞了一条帖子:「{post_content}」"
|
||||
elif post_author:
|
||||
return f"点赞了{post_author}的一条帖子"
|
||||
return "点赞了一条帖子"
|
||||
|
||||
def _describe_dislike_post(self) -> str:
|
||||
post_id = self.action_args.get("post_id", "")
|
||||
return f"踩了帖子#{post_id}" if post_id else "踩了一条帖子"
|
||||
"""踩帖子 - 包含帖子原文和作者信息"""
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
if post_content and post_author:
|
||||
return f"踩了{post_author}的帖子:「{post_content}」"
|
||||
elif post_content:
|
||||
return f"踩了一条帖子:「{post_content}」"
|
||||
elif post_author:
|
||||
return f"踩了{post_author}的一条帖子"
|
||||
return "踩了一条帖子"
|
||||
|
||||
def _describe_repost(self) -> str:
|
||||
post_id = self.action_args.get("post_id", "")
|
||||
return f"转发了帖子#{post_id}" if post_id else "转发了一条帖子"
|
||||
"""转发帖子 - 包含原帖内容和作者信息"""
|
||||
original_content = self.action_args.get("original_content", "")
|
||||
original_author = self.action_args.get("original_author_name", "")
|
||||
|
||||
if original_content and original_author:
|
||||
return f"转发了{original_author}的帖子:「{original_content}」"
|
||||
elif original_content:
|
||||
return f"转发了一条帖子:「{original_content}」"
|
||||
elif original_author:
|
||||
return f"转发了{original_author}的一条帖子"
|
||||
return "转发了一条帖子"
|
||||
|
||||
def _describe_quote_post(self) -> str:
|
||||
quoted_id = self.action_args.get("quoted_id", "")
|
||||
content = self.action_args.get("content", "")
|
||||
if quoted_id:
|
||||
if content:
|
||||
return f"引用帖子#{quoted_id}并评论:「{content}」"
|
||||
return f"引用了帖子#{quoted_id}"
|
||||
return "引用了一条帖子"
|
||||
"""引用帖子 - 包含原帖内容、作者信息和引用评论"""
|
||||
original_content = self.action_args.get("original_content", "")
|
||||
original_author = self.action_args.get("original_author_name", "")
|
||||
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
||||
|
||||
base = ""
|
||||
if original_content and original_author:
|
||||
base = f"引用了{original_author}的帖子「{original_content}」"
|
||||
elif original_content:
|
||||
base = f"引用了一条帖子「{original_content}」"
|
||||
elif original_author:
|
||||
base = f"引用了{original_author}的一条帖子"
|
||||
else:
|
||||
base = "引用了一条帖子"
|
||||
|
||||
if quote_content:
|
||||
base += f",并评论道:「{quote_content}」"
|
||||
return base
|
||||
|
||||
def _describe_follow(self) -> str:
|
||||
target_id = self.action_args.get("user_id", "") or self.action_args.get("target_id", "")
|
||||
return f"关注了用户#{target_id}" if target_id else "关注了一个用户"
|
||||
"""关注用户 - 包含被关注用户的名称"""
|
||||
target_user_name = self.action_args.get("target_user_name", "")
|
||||
|
||||
if target_user_name:
|
||||
return f"关注了用户「{target_user_name}」"
|
||||
return "关注了一个用户"
|
||||
|
||||
def _describe_create_comment(self) -> str:
|
||||
"""发表评论 - 包含评论内容和所评论的帖子信息"""
|
||||
content = self.action_args.get("content", "")
|
||||
post_id = self.action_args.get("post_id", "")
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
if content:
|
||||
base = f"评论道:「{content}」"
|
||||
if post_id:
|
||||
base = f"在帖子#{post_id}下{base}"
|
||||
return base
|
||||
return f"在帖子#{post_id}下发表了评论" if post_id else "发表了评论"
|
||||
if post_content and post_author:
|
||||
return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」"
|
||||
elif post_content:
|
||||
return f"在帖子「{post_content}」下评论道:「{content}」"
|
||||
elif post_author:
|
||||
return f"在{post_author}的帖子下评论道:「{content}」"
|
||||
return f"评论道:「{content}」"
|
||||
return "发表了评论"
|
||||
|
||||
def _describe_like_comment(self) -> str:
|
||||
comment_id = self.action_args.get("comment_id", "")
|
||||
return f"点赞了评论#{comment_id}" if comment_id else "点赞了一条评论"
|
||||
"""点赞评论 - 包含评论内容和作者信息"""
|
||||
comment_content = self.action_args.get("comment_content", "")
|
||||
comment_author = self.action_args.get("comment_author_name", "")
|
||||
|
||||
if comment_content and comment_author:
|
||||
return f"点赞了{comment_author}的评论:「{comment_content}」"
|
||||
elif comment_content:
|
||||
return f"点赞了一条评论:「{comment_content}」"
|
||||
elif comment_author:
|
||||
return f"点赞了{comment_author}的一条评论"
|
||||
return "点赞了一条评论"
|
||||
|
||||
def _describe_dislike_comment(self) -> str:
|
||||
comment_id = self.action_args.get("comment_id", "")
|
||||
return f"踩了评论#{comment_id}" if comment_id else "踩了一条评论"
|
||||
"""踩评论 - 包含评论内容和作者信息"""
|
||||
comment_content = self.action_args.get("comment_content", "")
|
||||
comment_author = self.action_args.get("comment_author_name", "")
|
||||
|
||||
if comment_content and comment_author:
|
||||
return f"踩了{comment_author}的评论:「{comment_content}」"
|
||||
elif comment_content:
|
||||
return f"踩了一条评论:「{comment_content}」"
|
||||
elif comment_author:
|
||||
return f"踩了{comment_author}的一条评论"
|
||||
return "踩了一条评论"
|
||||
|
||||
def _describe_search(self) -> str:
|
||||
"""搜索帖子 - 包含搜索关键词"""
|
||||
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
||||
return f"搜索了「{query}」" if query else "进行了搜索"
|
||||
|
||||
def _describe_search_user(self) -> str:
|
||||
"""搜索用户 - 包含搜索关键词"""
|
||||
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
||||
return f"搜索了用户「{query}」" if query else "搜索了用户"
|
||||
|
||||
def _describe_mute(self) -> str:
|
||||
target_id = self.action_args.get("user_id", "") or self.action_args.get("target_id", "")
|
||||
return f"屏蔽了用户#{target_id}" if target_id else "屏蔽了一个用户"
|
||||
"""屏蔽用户 - 包含被屏蔽用户的名称"""
|
||||
target_user_name = self.action_args.get("target_user_name", "")
|
||||
|
||||
if target_user_name:
|
||||
return f"屏蔽了用户「{target_user_name}」"
|
||||
return "屏蔽了一个用户"
|
||||
|
||||
def _describe_generic(self) -> str:
|
||||
# 对于未知的动作类型,生成通用描述
|
||||
|
|
@ -131,9 +203,18 @@ class ZepGraphMemoryUpdater:
|
|||
Zep图谱记忆更新器
|
||||
|
||||
监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。
|
||||
每条活动单独发送到Zep,确保图谱能正确解析实体和关系。
|
||||
按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。
|
||||
|
||||
所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息:
|
||||
- 点赞/踩的帖子原文
|
||||
- 转发/引用的帖子原文
|
||||
- 关注/屏蔽的用户名
|
||||
- 点赞/踩的评论原文
|
||||
"""
|
||||
|
||||
# 批量发送大小(每个平台累积多少条后发送)
|
||||
BATCH_SIZE = 5
|
||||
|
||||
# 发送间隔(秒),避免请求过快
|
||||
SEND_INTERVAL = 0.5
|
||||
|
||||
|
|
@ -160,16 +241,25 @@ class ZepGraphMemoryUpdater:
|
|||
# 活动队列
|
||||
self._activity_queue: Queue = Queue()
|
||||
|
||||
# 按平台分组的活动缓冲区(每个平台各自累积到BATCH_SIZE后批量发送)
|
||||
self._platform_buffers: Dict[str, List[AgentActivity]] = {
|
||||
'twitter': [],
|
||||
'reddit': [],
|
||||
}
|
||||
self._buffer_lock = threading.Lock()
|
||||
|
||||
# 控制标志
|
||||
self._running = False
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
|
||||
# 统计
|
||||
self._total_activities = 0
|
||||
self._total_sent = 0
|
||||
self._failed_count = 0
|
||||
self._total_activities = 0 # 实际添加到队列的活动数
|
||||
self._total_sent = 0 # 成功发送到Zep的批次数
|
||||
self._total_items_sent = 0 # 成功发送到Zep的活动条数
|
||||
self._failed_count = 0 # 发送失败的批次数
|
||||
self._skipped_count = 0 # 被过滤跳过的活动数(DO_NOTHING)
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}")
|
||||
logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||
|
||||
def start(self):
|
||||
"""启动后台工作线程"""
|
||||
|
|
@ -197,22 +287,40 @@ class ZepGraphMemoryUpdater:
|
|||
|
||||
logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
||||
f"total_activities={self._total_activities}, "
|
||||
f"total_sent={self._total_sent}, "
|
||||
f"failed={self._failed_count}")
|
||||
f"batches_sent={self._total_sent}, "
|
||||
f"items_sent={self._total_items_sent}, "
|
||||
f"failed={self._failed_count}, "
|
||||
f"skipped={self._skipped_count}")
|
||||
|
||||
def add_activity(self, activity: AgentActivity):
|
||||
"""
|
||||
添加一个agent活动到队列
|
||||
|
||||
所有有意义的行为都会被添加到队列,包括:
|
||||
- CREATE_POST(发帖)
|
||||
- CREATE_COMMENT(评论)
|
||||
- QUOTE_POST(引用帖子)
|
||||
- SEARCH_POSTS(搜索帖子)
|
||||
- SEARCH_USER(搜索用户)
|
||||
- LIKE_POST/DISLIKE_POST(点赞/踩帖子)
|
||||
- REPOST(转发)
|
||||
- FOLLOW(关注)
|
||||
- MUTE(屏蔽)
|
||||
- LIKE_COMMENT/DISLIKE_COMMENT(点赞/踩评论)
|
||||
|
||||
action_args中会包含完整的上下文信息(如帖子原文、用户名等)。
|
||||
|
||||
Args:
|
||||
activity: Agent活动记录
|
||||
"""
|
||||
# 跳过DO_NOTHING类型的活动
|
||||
if activity.action_type == "DO_NOTHING":
|
||||
self._skipped_count += 1
|
||||
return
|
||||
|
||||
self._activity_queue.put(activity)
|
||||
self._total_activities += 1
|
||||
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")
|
||||
|
||||
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
||||
"""
|
||||
|
|
@ -239,16 +347,29 @@ class ZepGraphMemoryUpdater:
|
|||
self.add_activity(activity)
|
||||
|
||||
def _worker_loop(self):
|
||||
"""后台工作循环 - 逐条发送活动到Zep"""
|
||||
"""后台工作循环 - 按平台批量发送活动到Zep"""
|
||||
while self._running or not self._activity_queue.empty():
|
||||
try:
|
||||
# 尝试从队列获取活动(超时1秒)
|
||||
try:
|
||||
activity = self._activity_queue.get(timeout=1)
|
||||
# 立即发送单条活动
|
||||
self._send_single_activity(activity)
|
||||
# 发送间隔,避免请求过快
|
||||
time.sleep(self.SEND_INTERVAL)
|
||||
|
||||
# 将活动添加到对应平台的缓冲区
|
||||
platform = activity.platform.lower()
|
||||
with self._buffer_lock:
|
||||
if platform not in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
self._platform_buffers[platform].append(activity)
|
||||
|
||||
# 检查该平台是否达到批量大小
|
||||
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
|
||||
batch = self._platform_buffers[platform][:self.BATCH_SIZE]
|
||||
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
|
||||
# 释放锁后再发送
|
||||
self._send_batch_activities(batch, platform)
|
||||
# 发送间隔,避免请求过快
|
||||
time.sleep(self.SEND_INTERVAL)
|
||||
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
|
|
@ -256,14 +377,20 @@ class ZepGraphMemoryUpdater:
|
|||
logger.error(f"工作循环异常: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def _send_single_activity(self, activity: AgentActivity):
|
||||
def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
|
||||
"""
|
||||
发送单条活动到Zep图谱
|
||||
批量发送活动到Zep图谱(合并为一条文本)
|
||||
|
||||
Args:
|
||||
activity: 单条Agent活动
|
||||
activities: Agent活动列表
|
||||
platform: 平台名称
|
||||
"""
|
||||
episode_text = activity.to_episode_text()
|
||||
if not activities:
|
||||
return
|
||||
|
||||
# 将多条活动合并为一条文本,用换行分隔
|
||||
episode_texts = [activity.to_episode_text() for activity in activities]
|
||||
combined_text = "\n".join(episode_texts)
|
||||
|
||||
# 带重试的发送
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
|
|
@ -271,38 +398,62 @@ class ZepGraphMemoryUpdater:
|
|||
self.client.graph.add(
|
||||
graph_id=self.graph_id,
|
||||
type="text",
|
||||
data=episode_text
|
||||
data=combined_text
|
||||
)
|
||||
|
||||
self._total_sent += 1
|
||||
logger.debug(f"成功发送活动到图谱 {self.graph_id}: {episode_text[:50]}...")
|
||||
self._total_items_sent += len(activities)
|
||||
logger.info(f"成功批量发送 {len(activities)} 条{platform}活动到图谱 {self.graph_id}")
|
||||
logger.debug(f"批量内容预览: {combined_text[:200]}...")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.MAX_RETRIES - 1:
|
||||
logger.warning(f"发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||
logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
||||
else:
|
||||
logger.error(f"发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
|
||||
logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
|
||||
self._failed_count += 1
|
||||
|
||||
def _flush_remaining(self):
|
||||
"""发送队列中剩余的活动(逐条发送)"""
|
||||
"""发送队列和缓冲区中剩余的活动"""
|
||||
# 首先处理队列中剩余的活动,添加到缓冲区
|
||||
while not self._activity_queue.empty():
|
||||
try:
|
||||
activity = self._activity_queue.get_nowait()
|
||||
self._send_single_activity(activity)
|
||||
platform = activity.platform.lower()
|
||||
with self._buffer_lock:
|
||||
if platform not in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
self._platform_buffers[platform].append(activity)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
# 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条)
|
||||
with self._buffer_lock:
|
||||
for platform, buffer in self._platform_buffers.items():
|
||||
if buffer:
|
||||
logger.info(f"发送{platform}平台剩余的 {len(buffer)} 条活动")
|
||||
self._send_batch_activities(buffer, platform)
|
||||
# 清空所有缓冲区
|
||||
for platform in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
with self._buffer_lock:
|
||||
buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
|
||||
|
||||
return {
|
||||
"graph_id": self.graph_id,
|
||||
"total_activities": self._total_activities,
|
||||
"total_sent": self._total_sent,
|
||||
"failed_count": self._failed_count,
|
||||
"batch_size": self.BATCH_SIZE,
|
||||
"total_activities": self._total_activities, # 添加到队列的活动总数
|
||||
"batches_sent": self._total_sent, # 成功发送的批次数
|
||||
"items_sent": self._total_items_sent, # 成功发送的活动条数
|
||||
"failed_count": self._failed_count, # 发送失败的批次数
|
||||
"skipped_count": self._skipped_count, # 被过滤跳过的活动数(DO_NOTHING)
|
||||
"queue_size": self._activity_queue.qsize(),
|
||||
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小
|
||||
"running": self._running,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -615,7 +615,7 @@ def fetch_new_actions_from_db(
|
|||
agent_names: Dict[int, str]
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
从数据库中获取新的动作记录
|
||||
从数据库中获取新的动作记录,并补充完整的上下文信息
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径
|
||||
|
|
@ -624,7 +624,7 @@ def fetch_new_actions_from_db(
|
|||
|
||||
Returns:
|
||||
(actions_list, new_last_rowid)
|
||||
- actions_list: 动作列表,每个元素包含 agent_id, agent_name, action_type, action_args
|
||||
- actions_list: 动作列表,每个元素包含 agent_id, agent_name, action_type, action_args(含上下文信息)
|
||||
- new_last_rowid: 新的最大 rowid 值
|
||||
"""
|
||||
actions = []
|
||||
|
|
@ -684,6 +684,9 @@ def fetch_new_actions_from_db(
|
|||
# 转换动作类型名称
|
||||
action_type = ACTION_TYPE_MAP.get(action, action.upper())
|
||||
|
||||
# 补充上下文信息(帖子内容、用户名等)
|
||||
_enrich_action_context(cursor, action_type, simplified_args, agent_names)
|
||||
|
||||
actions.append({
|
||||
'agent_id': user_id,
|
||||
'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
|
||||
|
|
@ -698,6 +701,241 @@ def fetch_new_actions_from_db(
|
|||
return actions, new_last_rowid
|
||||
|
||||
|
||||
def _enrich_action_context(
|
||||
cursor,
|
||||
action_type: str,
|
||||
action_args: Dict[str, Any],
|
||||
agent_names: Dict[int, str]
|
||||
) -> None:
|
||||
"""
|
||||
为动作补充上下文信息(帖子内容、用户名等)
|
||||
|
||||
Args:
|
||||
cursor: 数据库游标
|
||||
action_type: 动作类型
|
||||
action_args: 动作参数(会被修改)
|
||||
agent_names: agent_id -> agent_name 映射
|
||||
"""
|
||||
try:
|
||||
# 点赞/踩帖子:补充帖子内容和作者
|
||||
if action_type in ('LIKE_POST', 'DISLIKE_POST'):
|
||||
post_id = action_args.get('post_id')
|
||||
if post_id:
|
||||
post_info = _get_post_info(cursor, post_id, agent_names)
|
||||
if post_info:
|
||||
action_args['post_content'] = post_info.get('content', '')
|
||||
action_args['post_author_name'] = post_info.get('author_name', '')
|
||||
|
||||
# 转发帖子:补充原帖内容和作者
|
||||
elif action_type == 'REPOST':
|
||||
new_post_id = action_args.get('new_post_id')
|
||||
if new_post_id:
|
||||
# 转发帖子的 original_post_id 指向原帖
|
||||
cursor.execute("""
|
||||
SELECT original_post_id FROM post WHERE post_id = ?
|
||||
""", (new_post_id,))
|
||||
row = cursor.fetchone()
|
||||
if row and row[0]:
|
||||
original_post_id = row[0]
|
||||
original_info = _get_post_info(cursor, original_post_id, agent_names)
|
||||
if original_info:
|
||||
action_args['original_content'] = original_info.get('content', '')
|
||||
action_args['original_author_name'] = original_info.get('author_name', '')
|
||||
|
||||
# 引用帖子:补充原帖内容、作者和引用评论
|
||||
elif action_type == 'QUOTE_POST':
|
||||
quoted_id = action_args.get('quoted_id')
|
||||
new_post_id = action_args.get('new_post_id')
|
||||
|
||||
if quoted_id:
|
||||
original_info = _get_post_info(cursor, quoted_id, agent_names)
|
||||
if original_info:
|
||||
action_args['original_content'] = original_info.get('content', '')
|
||||
action_args['original_author_name'] = original_info.get('author_name', '')
|
||||
|
||||
# 获取引用帖子的评论内容(quote_content)
|
||||
if new_post_id:
|
||||
cursor.execute("""
|
||||
SELECT quote_content FROM post WHERE post_id = ?
|
||||
""", (new_post_id,))
|
||||
row = cursor.fetchone()
|
||||
if row and row[0]:
|
||||
action_args['quote_content'] = row[0]
|
||||
|
||||
# 关注用户:补充被关注用户的名称
|
||||
elif action_type == 'FOLLOW':
|
||||
follow_id = action_args.get('follow_id')
|
||||
if follow_id:
|
||||
# 从 follow 表获取 followee_id
|
||||
cursor.execute("""
|
||||
SELECT followee_id FROM follow WHERE follow_id = ?
|
||||
""", (follow_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
followee_id = row[0]
|
||||
target_name = _get_user_name(cursor, followee_id, agent_names)
|
||||
if target_name:
|
||||
action_args['target_user_name'] = target_name
|
||||
|
||||
# 屏蔽用户:补充被屏蔽用户的名称
|
||||
elif action_type == 'MUTE':
|
||||
# 从 action_args 中获取 user_id 或 target_id
|
||||
target_id = action_args.get('user_id') or action_args.get('target_id')
|
||||
if target_id:
|
||||
target_name = _get_user_name(cursor, target_id, agent_names)
|
||||
if target_name:
|
||||
action_args['target_user_name'] = target_name
|
||||
|
||||
# 点赞/踩评论:补充评论内容和作者
|
||||
elif action_type in ('LIKE_COMMENT', 'DISLIKE_COMMENT'):
|
||||
comment_id = action_args.get('comment_id')
|
||||
if comment_id:
|
||||
comment_info = _get_comment_info(cursor, comment_id, agent_names)
|
||||
if comment_info:
|
||||
action_args['comment_content'] = comment_info.get('content', '')
|
||||
action_args['comment_author_name'] = comment_info.get('author_name', '')
|
||||
|
||||
# 发表评论:补充所评论的帖子信息
|
||||
elif action_type == 'CREATE_COMMENT':
|
||||
post_id = action_args.get('post_id')
|
||||
if post_id:
|
||||
post_info = _get_post_info(cursor, post_id, agent_names)
|
||||
if post_info:
|
||||
action_args['post_content'] = post_info.get('content', '')
|
||||
action_args['post_author_name'] = post_info.get('author_name', '')
|
||||
|
||||
except Exception as e:
|
||||
# 补充上下文失败不影响主流程
|
||||
print(f"补充动作上下文失败: {e}")
|
||||
|
||||
|
||||
def _get_post_info(
|
||||
cursor,
|
||||
post_id: int,
|
||||
agent_names: Dict[int, str]
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
获取帖子信息
|
||||
|
||||
Args:
|
||||
cursor: 数据库游标
|
||||
post_id: 帖子ID
|
||||
agent_names: agent_id -> agent_name 映射
|
||||
|
||||
Returns:
|
||||
包含 content 和 author_name 的字典,或 None
|
||||
"""
|
||||
try:
|
||||
cursor.execute("""
|
||||
SELECT p.content, p.user_id, u.agent_id
|
||||
FROM post p
|
||||
LEFT JOIN user u ON p.user_id = u.user_id
|
||||
WHERE p.post_id = ?
|
||||
""", (post_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
content = row[0] or ''
|
||||
user_id = row[1]
|
||||
agent_id = row[2]
|
||||
|
||||
# 优先使用 agent_names 中的名称
|
||||
author_name = ''
|
||||
if agent_id is not None and agent_id in agent_names:
|
||||
author_name = agent_names[agent_id]
|
||||
elif user_id:
|
||||
# 从 user 表获取名称
|
||||
cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,))
|
||||
user_row = cursor.fetchone()
|
||||
if user_row:
|
||||
author_name = user_row[0] or user_row[1] or ''
|
||||
|
||||
return {'content': content, 'author_name': author_name}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _get_user_name(
|
||||
cursor,
|
||||
user_id: int,
|
||||
agent_names: Dict[int, str]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取用户名称
|
||||
|
||||
Args:
|
||||
cursor: 数据库游标
|
||||
user_id: 用户ID
|
||||
agent_names: agent_id -> agent_name 映射
|
||||
|
||||
Returns:
|
||||
用户名称,或 None
|
||||
"""
|
||||
try:
|
||||
cursor.execute("""
|
||||
SELECT agent_id, name, user_name FROM user WHERE user_id = ?
|
||||
""", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
agent_id = row[0]
|
||||
name = row[1]
|
||||
user_name = row[2]
|
||||
|
||||
# 优先使用 agent_names 中的名称
|
||||
if agent_id is not None and agent_id in agent_names:
|
||||
return agent_names[agent_id]
|
||||
return name or user_name or ''
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_info(
|
||||
cursor,
|
||||
comment_id: int,
|
||||
agent_names: Dict[int, str]
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
获取评论信息
|
||||
|
||||
Args:
|
||||
cursor: 数据库游标
|
||||
comment_id: 评论ID
|
||||
agent_names: agent_id -> agent_name 映射
|
||||
|
||||
Returns:
|
||||
包含 content 和 author_name 的字典,或 None
|
||||
"""
|
||||
try:
|
||||
cursor.execute("""
|
||||
SELECT c.content, c.user_id, u.agent_id
|
||||
FROM comment c
|
||||
LEFT JOIN user u ON c.user_id = u.user_id
|
||||
WHERE c.comment_id = ?
|
||||
""", (comment_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
content = row[0] or ''
|
||||
user_id = row[1]
|
||||
agent_id = row[2]
|
||||
|
||||
# 优先使用 agent_names 中的名称
|
||||
author_name = ''
|
||||
if agent_id is not None and agent_id in agent_names:
|
||||
author_name = agent_names[agent_id]
|
||||
elif user_id:
|
||||
# 从 user 表获取名称
|
||||
cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,))
|
||||
user_row = cursor.fetchone()
|
||||
if user_row:
|
||||
author_name = user_row[0] or user_row[1] or ''
|
||||
|
||||
return {'content': content, 'author_name': author_name}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def create_model(config: Dict[str, Any], use_boost: bool = False):
|
||||
"""
|
||||
创建LLM模型
|
||||
|
|
|
|||
Loading…
Reference in a new issue