fix(report_agent): refine tool call handling and response validation; enforce strict separation between tool calls and final answers
This commit is contained in:
parent
a795405428
commit
25aa4f75d2
2 changed files with 103 additions and 39 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -43,6 +43,7 @@ htmlcov/
|
|||
|
||||
# Cursor
|
||||
.cursor/
|
||||
.claude/
|
||||
|
||||
# 文档与测试程序
|
||||
mydoc/
|
||||
|
|
|
|||
|
|
@ -714,21 +714,25 @@ SECTION_SYSTEM_PROMPT_TEMPLATE = """\
|
|||
- interview_agents: 采访模拟Agent,获取不同角色的第一人称观点和真实反应
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
【ReACT工作流程】
|
||||
【工作流程】
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
1. Thought: [分析需要什么信息,规划检索策略]
|
||||
2. Action: [调用一个工具获取信息](每轮只能调用一个工具!)
|
||||
<tool_call>
|
||||
{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}}
|
||||
</tool_call>
|
||||
3. Observation: [系统返回工具结果]
|
||||
4. 重复步骤1-3,直到收集到足够信息
|
||||
5. Final Answer: [基于检索结果撰写章节内容]
|
||||
每次回复你只能做以下两件事之一(不可同时做):
|
||||
|
||||
⚠️ 重要规则:
|
||||
- 每轮只能调用一个工具,不要在一次回复中放多个 <tool_call>
|
||||
- 当你认为信息足够时,必须以 "Final Answer:" 开头输出最终内容
|
||||
选项A - 调用工具:
|
||||
输出你的思考,然后用以下格式调用一个工具:
|
||||
<tool_call>
|
||||
{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}}
|
||||
</tool_call>
|
||||
系统会执行工具并把结果返回给你。你不需要也不能自己编写工具返回结果。
|
||||
|
||||
选项B - 输出最终内容:
|
||||
当你已通过工具获取了足够信息,以 "Final Answer:" 开头输出章节内容。
|
||||
|
||||
⚠️ 严格禁止:
|
||||
- 禁止在一次回复中同时包含工具调用和 Final Answer
|
||||
- 禁止自己编造工具返回结果(Observation),所有工具结果由系统注入
|
||||
- 每次回复最多调用一个工具
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
【章节内容要求】
|
||||
|
|
@ -1056,21 +1060,20 @@ class ReportAgent:
|
|||
logger.error(f"工具执行失败: {tool_name}, 错误: {str(e)}")
|
||||
return f"工具执行失败: {str(e)}"
|
||||
|
||||
# 合法的工具名称集合,用于裸 JSON 兜底解析时校验
|
||||
VALID_TOOL_NAMES = {"insight_forge", "panorama_search", "quick_search", "interview_agents"}
|
||||
|
||||
def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从LLM响应中解析工具调用
|
||||
|
||||
支持的格式:
|
||||
<tool_call>
|
||||
{"name": "tool_name", "parameters": {"param1": "value1"}}
|
||||
</tool_call>
|
||||
|
||||
或者:
|
||||
[TOOL_CALL] tool_name(param1="value1", param2="value2")
|
||||
支持的格式(按优先级):
|
||||
1. <tool_call>{"name": "tool_name", "parameters": {...}}</tool_call>
|
||||
2. 裸 JSON(响应整体或单行就是一个工具调用 JSON)
|
||||
"""
|
||||
tool_calls = []
|
||||
|
||||
# 格式1: XML风格
|
||||
# 格式1: XML风格(标准格式)
|
||||
xml_pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
|
||||
for match in re.finditer(xml_pattern, response, re.DOTALL):
|
||||
try:
|
||||
|
|
@ -1079,24 +1082,47 @@ class ReportAgent:
|
|||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 格式2: 函数调用风格
|
||||
func_pattern = r'\[TOOL_CALL\]\s*(\w+)\s*\((.*?)\)'
|
||||
for match in re.finditer(func_pattern, response, re.DOTALL):
|
||||
tool_name = match.group(1)
|
||||
params_str = match.group(2)
|
||||
if tool_calls:
|
||||
return tool_calls
|
||||
|
||||
# 解析参数
|
||||
params = {}
|
||||
for param_match in re.finditer(r'(\w+)\s*=\s*["\']([^"\']*)["\']', params_str):
|
||||
params[param_match.group(1)] = param_match.group(2)
|
||||
# 格式2: 兜底 - LLM 直接输出裸 JSON(没包 <tool_call> 标签)
|
||||
# 只在格式1未匹配时尝试,避免误匹配正文中的 JSON
|
||||
stripped = response.strip()
|
||||
if stripped.startswith('{') and stripped.endswith('}'):
|
||||
try:
|
||||
call_data = json.loads(stripped)
|
||||
if self._is_valid_tool_call(call_data):
|
||||
tool_calls.append(call_data)
|
||||
return tool_calls
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
tool_calls.append({
|
||||
"name": tool_name,
|
||||
"parameters": params
|
||||
})
|
||||
# 响应可能包含思考文字 + 裸 JSON,尝试提取最后一个 JSON 对象
|
||||
json_pattern = r'(\{"(?:name|tool)"\s*:.*?\})\s*$'
|
||||
match = re.search(json_pattern, stripped, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
call_data = json.loads(match.group(1))
|
||||
if self._is_valid_tool_call(call_data):
|
||||
tool_calls.append(call_data)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _is_valid_tool_call(self, data: dict) -> bool:
|
||||
"""校验解析出的 JSON 是否是合法的工具调用"""
|
||||
# 支持 {"name": ..., "parameters": ...} 和 {"tool": ..., "params": ...} 两种键名
|
||||
tool_name = data.get("name") or data.get("tool")
|
||||
if tool_name and tool_name in self.VALID_TOOL_NAMES:
|
||||
# 统一键名为 name / parameters
|
||||
if "tool" in data:
|
||||
data["name"] = data.pop("tool")
|
||||
if "params" in data and "parameters" not in data:
|
||||
data["parameters"] = data.pop("params")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述文本"""
|
||||
desc_parts = ["可用工具:"]
|
||||
|
|
@ -1258,6 +1284,7 @@ class ReportAgent:
|
|||
tool_calls_count = 0
|
||||
max_iterations = 5 # 最大迭代轮数
|
||||
min_tool_calls = 3 # 最少工具调用次数
|
||||
conflict_retries = 0 # 工具调用与Final Answer同时出现的连续冲突次数
|
||||
used_tools = set() # 记录已调用过的工具名
|
||||
all_tools = {"insight_forge", "panorama_search", "quick_search", "interview_agents"}
|
||||
|
||||
|
|
@ -1297,6 +1324,42 @@ class ReportAgent:
|
|||
has_tool_calls = bool(tool_calls)
|
||||
has_final_answer = "Final Answer:" in response
|
||||
|
||||
# ── 冲突处理:LLM 同时输出了工具调用和 Final Answer ──
|
||||
if has_tool_calls and has_final_answer:
|
||||
conflict_retries += 1
|
||||
logger.warning(
|
||||
f"章节 {section.title} 第 {iteration+1} 轮: "
|
||||
f"LLM 同时输出工具调用和 Final Answer(第 {conflict_retries} 次冲突)"
|
||||
)
|
||||
|
||||
if conflict_retries <= 2:
|
||||
# 前两次:丢弃本次响应,要求 LLM 重新回复
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
"【格式错误】你在一次回复中同时包含了工具调用和 Final Answer,这是不允许的。\n"
|
||||
"每次回复只能做以下两件事之一:\n"
|
||||
"- 调用一个工具(输出一个 <tool_call> 块,不要写 Final Answer)\n"
|
||||
"- 输出最终内容(以 'Final Answer:' 开头,不要包含 <tool_call>)\n"
|
||||
"请重新回复,只做其中一件事。"
|
||||
),
|
||||
})
|
||||
continue
|
||||
else:
|
||||
# 第三次:降级处理,截断到第一个工具调用,强制执行
|
||||
logger.warning(
|
||||
f"章节 {section.title}: 连续 {conflict_retries} 次冲突,"
|
||||
"降级为截断执行第一个工具调用"
|
||||
)
|
||||
first_tool_end = response.find('</tool_call>')
|
||||
if first_tool_end != -1:
|
||||
response = response[:first_tool_end + len('</tool_call>')]
|
||||
tool_calls = self._parse_tool_calls(response)
|
||||
has_tool_calls = bool(tool_calls)
|
||||
has_final_answer = False
|
||||
conflict_retries = 0
|
||||
|
||||
# 记录 LLM 响应日志
|
||||
if self.report_logger:
|
||||
self.report_logger.log_llm_response(
|
||||
|
|
|
|||
Loading…
Reference in a new issue