Enhance interview prompt handling and update README.md
- Introduced a prefix to optimize interview prompts, ensuring agents respond directly with text without invoking tools. - Updated the simulation API to utilize the optimized prompts for individual and batch interviews. - Modified the `get_interview_history` function to allow for flexible platform querying, returning results from both Reddit and Twitter when no platform is specified. - Enhanced README.md to include new prompt optimization details and updated API usage examples for clarity.
This commit is contained in:
parent
1042d50306
commit
1f191cb21e
3 changed files with 120 additions and 33 deletions
|
|
@ -672,6 +672,12 @@ Flask后端 模拟脚本
|
|||
> **注意**: 所有Interview接口的参数都通过请求体(JSON)传递,包括simulation_id。
|
||||
>
|
||||
> **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。
|
||||
>
|
||||
> **Prompt自动优化**: 系统会自动在用户提供的prompt前添加说明前缀,避免Agent调用工具:
|
||||
> ```
|
||||
> 原始prompt: "武汉大学发布撤销处分通告后你有什么看法"
|
||||
> 优化后: "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:武汉大学发布撤销处分通告后你有什么看法"
|
||||
> ```
|
||||
|
||||
#### 1. 采访单个Agent
|
||||
|
||||
|
|
@ -860,11 +866,11 @@ Flask后端 模拟脚本
|
|||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| simulation_id | String | 是 | - | 模拟ID |
|
||||
| platform | String | 否 | reddit | 平台类型(reddit/twitter) |
|
||||
| agent_id | Integer | 否 | - | 过滤Agent ID |
|
||||
| platform | String | 否 | null | 平台类型(reddit/twitter),不指定则返回两个平台的所有历史 |
|
||||
| agent_id | Integer | 否 | - | 只获取该Agent的采访历史 |
|
||||
| limit | Integer | 否 | 100 | 返回数量限制 |
|
||||
|
||||
**返回示例**:
|
||||
**返回示例(不指定platform,返回双平台历史)**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
|
|
@ -875,7 +881,14 @@ Flask后端 模拟脚本
|
|||
"agent_id": 0,
|
||||
"response": "我认为...",
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"timestamp": "2025-12-08T10:00:00",
|
||||
"timestamp": "2025-12-08T10:00:02",
|
||||
"platform": "twitter"
|
||||
},
|
||||
{
|
||||
"agent_id": 0,
|
||||
"response": "从Reddit角度来看...",
|
||||
"prompt": "你对这件事有什么看法?",
|
||||
"timestamp": "2025-12-08T10:00:01",
|
||||
"platform": "reddit"
|
||||
},
|
||||
...
|
||||
|
|
|
|||
|
|
@ -19,6 +19,29 @@ from ..models.project import ProjectManager
|
|||
logger = get_logger('mirofish.api.simulation')
|
||||
|
||||
|
||||
# Interview prompt 优化前缀
|
||||
# 添加此前缀可以避免Agent调用工具,直接用文本回复
|
||||
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"
|
||||
|
||||
|
||||
def optimize_interview_prompt(prompt: str) -> str:
|
||||
"""
|
||||
优化Interview提问,添加前缀避免Agent调用工具
|
||||
|
||||
Args:
|
||||
prompt: 原始提问
|
||||
|
||||
Returns:
|
||||
优化后的提问
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
# 避免重复添加前缀
|
||||
if prompt.startswith(INTERVIEW_PROMPT_PREFIX):
|
||||
return prompt
|
||||
return f"{INTERVIEW_PROMPT_PREFIX}{prompt}"
|
||||
|
||||
|
||||
# ============== 实体读取接口 ==============
|
||||
|
||||
@simulation_bp.route('/entities/<graph_id>', methods=['GET'])
|
||||
|
|
@ -1819,14 +1842,17 @@ def interview_agent():
|
|||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||
}), 400
|
||||
|
||||
# 优化prompt,添加前缀避免Agent调用工具
|
||||
optimized_prompt = optimize_interview_prompt(prompt)
|
||||
|
||||
result = SimulationRunner.interview_agent(
|
||||
simulation_id=simulation_id,
|
||||
agent_id=agent_id,
|
||||
prompt=prompt,
|
||||
prompt=optimized_prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
return jsonify({
|
||||
"success": result.get("success", False),
|
||||
"data": result
|
||||
|
|
@ -1951,9 +1977,16 @@ def interview_agents_batch():
|
|||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||
}), 400
|
||||
|
||||
# 优化每个采访项的prompt,添加前缀避免Agent调用工具
|
||||
optimized_interviews = []
|
||||
for interview in interviews:
|
||||
optimized_interview = interview.copy()
|
||||
optimized_interview['prompt'] = optimize_interview_prompt(interview.get('prompt', ''))
|
||||
optimized_interviews.append(optimized_interview)
|
||||
|
||||
result = SimulationRunner.interview_agents_batch(
|
||||
simulation_id=simulation_id,
|
||||
interviews=interviews,
|
||||
interviews=optimized_interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
|
@ -2051,9 +2084,12 @@ def interview_all_agents():
|
|||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||
}), 400
|
||||
|
||||
# 优化prompt,添加前缀避免Agent调用工具
|
||||
optimized_prompt = optimize_interview_prompt(prompt)
|
||||
|
||||
result = SimulationRunner.interview_all_agents(
|
||||
simulation_id=simulation_id,
|
||||
prompt=prompt,
|
||||
prompt=optimized_prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
|
@ -2094,8 +2130,9 @@ def get_interview_history():
|
|||
请求(JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"platform": "reddit", // 可选,平台类型(reddit/twitter),默认reddit
|
||||
"agent_id": 0, // 可选,过滤Agent ID
|
||||
"platform": "reddit", // 可选,平台类型(reddit/twitter)
|
||||
// 不指定则返回两个平台的所有历史
|
||||
"agent_id": 0, // 可选,只获取该Agent的采访历史
|
||||
"limit": 100 // 可选,返回数量,默认100
|
||||
}
|
||||
|
||||
|
|
@ -2121,7 +2158,7 @@ def get_interview_history():
|
|||
data = request.get_json() or {}
|
||||
|
||||
simulation_id = data.get('simulation_id')
|
||||
platform = data.get('platform', 'reddit')
|
||||
platform = data.get('platform') # 不指定则返回两个平台的历史
|
||||
agent_id = data.get('agent_id')
|
||||
limit = data.get('limit', 100)
|
||||
|
||||
|
|
|
|||
|
|
@ -1279,30 +1279,16 @@ class SimulationRunner:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
def _get_interview_history_from_db(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = "reddit",
|
||||
db_path: str,
|
||||
platform_name: str,
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter)
|
||||
agent_id: 过滤Agent ID(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
"""从单个数据库获取Interview历史"""
|
||||
import sqlite3
|
||||
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return []
|
||||
|
||||
|
|
@ -1312,8 +1298,6 @@ class SimulationRunner:
|
|||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 构建查询
|
||||
# 注意:ActionType.INTERVIEW.value 应该是字符串形式
|
||||
if agent_id is not None:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
|
|
@ -1342,13 +1326,66 @@ class SimulationRunner:
|
|||
"response": info.get("response", info),
|
||||
"prompt": info.get("prompt", ""),
|
||||
"timestamp": created_at,
|
||||
"platform": platform
|
||||
"platform": platform_name
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取Interview历史失败: {e}")
|
||||
logger.error(f"读取Interview历史失败 ({platform_name}): {e}")
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = None,
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter/None)
|
||||
- "reddit": 只获取Reddit平台的历史
|
||||
- "twitter": 只获取Twitter平台的历史
|
||||
- None: 获取两个平台的所有历史
|
||||
agent_id: 指定Agent ID(可选,只获取该Agent的历史)
|
||||
limit: 每个平台返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
|
||||
results = []
|
||||
|
||||
# 确定要查询的平台
|
||||
if platform in ("reddit", "twitter"):
|
||||
platforms = [platform]
|
||||
else:
|
||||
# 不指定platform时,查询两个平台
|
||||
platforms = ["twitter", "reddit"]
|
||||
|
||||
for p in platforms:
|
||||
db_path = os.path.join(sim_dir, f"{p}_simulation.db")
|
||||
platform_results = cls._get_interview_history_from_db(
|
||||
db_path=db_path,
|
||||
platform_name=p,
|
||||
agent_id=agent_id,
|
||||
limit=limit
|
||||
)
|
||||
results.extend(platform_results)
|
||||
|
||||
# 按时间降序排序
|
||||
results.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
||||
|
||||
# 如果查询了多个平台,限制总数
|
||||
if len(platforms) > 1 and len(results) > limit:
|
||||
results = results[:limit]
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue