Enhance interview prompt handling and update README.md

- Introduced a prefix to optimize interview prompts, ensuring agents respond directly with text without invoking tools.
- Updated the simulation API to utilize the optimized prompts for individual and batch interviews.
- Modified the `get_interview_history` function to allow for flexible platform querying, returning results from both Reddit and Twitter when no platform is specified.
- Enhanced README.md to include new prompt optimization details and updated API usage examples for clarity.
This commit is contained in:
666ghj 2025-12-08 16:08:33 +08:00
parent 1042d50306
commit 1f191cb21e
3 changed files with 120 additions and 33 deletions

View file

@ -672,6 +672,12 @@ Flask后端 模拟脚本
> **注意**: 所有Interview接口的参数都通过请求体(JSON)传递包括simulation_id。
>
> **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。
>
> **Prompt自动优化**: 系统会自动在用户提供的prompt前添加说明前缀避免Agent调用工具
> ```
> 原始prompt: "武汉大学发布撤销处分通告后你有什么看法"
> 优化后: "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:武汉大学发布撤销处分通告后你有什么看法"
> ```
#### 1. 采访单个Agent
@ -860,11 +866,11 @@ Flask后端 模拟脚本
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|------|------|------|--------|------|
| simulation_id | String | 是 | - | 模拟ID |
| platform | String | 否 | reddit | 平台类型reddit/twitter |
| agent_id | Integer | 否 | - | 过滤Agent ID |
| platform | String | 否 | null | 平台类型reddit/twitter不指定则返回两个平台的所有历史 |
| agent_id | Integer | 否 | - | 只获取该Agent的采访历史 |
| limit | Integer | 否 | 100 | 返回数量限制 |
**返回示例**:
**返回示例不指定platform返回双平台历史**:
```json
{
"success": true,
@ -875,7 +881,14 @@ Flask后端 模拟脚本
"agent_id": 0,
"response": "我认为...",
"prompt": "你对这件事有什么看法?",
"timestamp": "2025-12-08T10:00:00",
"timestamp": "2025-12-08T10:00:02",
"platform": "twitter"
},
{
"agent_id": 0,
"response": "从Reddit角度来看...",
"prompt": "你对这件事有什么看法?",
"timestamp": "2025-12-08T10:00:01",
"platform": "reddit"
},
...

View file

@ -19,6 +19,29 @@ from ..models.project import ProjectManager
logger = get_logger('mirofish.api.simulation')
# Interview prompt 优化前缀
# 添加此前缀可以避免Agent调用工具直接用文本回复
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"
def optimize_interview_prompt(prompt: str) -> str:
"""
优化Interview提问添加前缀避免Agent调用工具
Args:
prompt: 原始提问
Returns:
优化后的提问
"""
if not prompt:
return prompt
# 避免重复添加前缀
if prompt.startswith(INTERVIEW_PROMPT_PREFIX):
return prompt
return f"{INTERVIEW_PROMPT_PREFIX}{prompt}"
# ============== 实体读取接口 ==============
@simulation_bp.route('/entities/<graph_id>', methods=['GET'])
@ -1819,14 +1842,17 @@ def interview_agent():
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
}), 400
# 优化prompt添加前缀避免Agent调用工具
optimized_prompt = optimize_interview_prompt(prompt)
result = SimulationRunner.interview_agent(
simulation_id=simulation_id,
agent_id=agent_id,
prompt=prompt,
prompt=optimized_prompt,
platform=platform,
timeout=timeout
)
return jsonify({
"success": result.get("success", False),
"data": result
@ -1951,9 +1977,16 @@ def interview_agents_batch():
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
}), 400
# 优化每个采访项的prompt添加前缀避免Agent调用工具
optimized_interviews = []
for interview in interviews:
optimized_interview = interview.copy()
optimized_interview['prompt'] = optimize_interview_prompt(interview.get('prompt', ''))
optimized_interviews.append(optimized_interview)
result = SimulationRunner.interview_agents_batch(
simulation_id=simulation_id,
interviews=interviews,
interviews=optimized_interviews,
platform=platform,
timeout=timeout
)
@ -2051,9 +2084,12 @@ def interview_all_agents():
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
}), 400
# 优化prompt添加前缀避免Agent调用工具
optimized_prompt = optimize_interview_prompt(prompt)
result = SimulationRunner.interview_all_agents(
simulation_id=simulation_id,
prompt=prompt,
prompt=optimized_prompt,
platform=platform,
timeout=timeout
)
@ -2094,8 +2130,9 @@ def get_interview_history():
请求JSON
{
"simulation_id": "sim_xxxx", // 必填模拟ID
"platform": "reddit", // 可选平台类型reddit/twitter默认reddit
"agent_id": 0, // 可选过滤Agent ID
"platform": "reddit", // 可选平台类型reddit/twitter
// 不指定则返回两个平台的所有历史
"agent_id": 0, // 可选只获取该Agent的采访历史
"limit": 100 // 可选返回数量默认100
}
@ -2121,7 +2158,7 @@ def get_interview_history():
data = request.get_json() or {}
simulation_id = data.get('simulation_id')
platform = data.get('platform', 'reddit')
platform = data.get('platform') # 不指定则返回两个平台的历史
agent_id = data.get('agent_id')
limit = data.get('limit', 100)

View file

@ -1279,30 +1279,16 @@ class SimulationRunner:
}
@classmethod
def get_interview_history(
def _get_interview_history_from_db(
cls,
simulation_id: str,
platform: str = "reddit",
db_path: str,
platform_name: str,
agent_id: Optional[int] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
获取Interview历史记录从数据库读取
Args:
simulation_id: 模拟ID
platform: 平台类型reddit/twitter
agent_id: 过滤Agent ID可选
limit: 返回数量限制
Returns:
Interview历史记录列表
"""
"""从单个数据库获取Interview历史"""
import sqlite3
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
if not os.path.exists(db_path):
return []
@ -1312,8 +1298,6 @@ class SimulationRunner:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# 构建查询
# 注意ActionType.INTERVIEW.value 应该是字符串形式
if agent_id is not None:
cursor.execute("""
SELECT user_id, info, created_at
@ -1342,13 +1326,66 @@ class SimulationRunner:
"response": info.get("response", info),
"prompt": info.get("prompt", ""),
"timestamp": created_at,
"platform": platform
"platform": platform_name
})
conn.close()
except Exception as e:
logger.error(f"读取Interview历史失败: {e}")
logger.error(f"读取Interview历史失败 ({platform_name}): {e}")
return results
@classmethod
def get_interview_history(
cls,
simulation_id: str,
platform: str = None,
agent_id: Optional[int] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
获取Interview历史记录从数据库读取
Args:
simulation_id: 模拟ID
platform: 平台类型reddit/twitter/None
- "reddit": 只获取Reddit平台的历史
- "twitter": 只获取Twitter平台的历史
- None: 获取两个平台的所有历史
agent_id: 指定Agent ID可选只获取该Agent的历史
limit: 每个平台返回数量限制
Returns:
Interview历史记录列表
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
results = []
# 确定要查询的平台
if platform in ("reddit", "twitter"):
platforms = [platform]
else:
# 不指定platform时查询两个平台
platforms = ["twitter", "reddit"]
for p in platforms:
db_path = os.path.join(sim_dir, f"{p}_simulation.db")
platform_results = cls._get_interview_history_from_db(
db_path=db_path,
platform_name=p,
agent_id=agent_id,
limit=limit
)
results.extend(platform_results)
# 按时间降序排序
results.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
# 如果查询了多个平台,限制总数
if len(platforms) > 1 and len(results) > limit:
results = results[:limit]
return results