Enhance interview prompt handling and update README.md
- Introduced a prefix to optimize interview prompts, ensuring agents respond directly with text without invoking tools.
- Updated the simulation API to use the optimized prompts for individual, batch, and all-agent interviews.
- Modified the `get_interview_history` function to allow flexible platform querying: when no platform is specified, it returns results from both Reddit and Twitter.
- Enhanced README.md with the new prompt optimization details and updated API usage examples for clarity.
This commit is contained in:
parent
1042d50306
commit
1f191cb21e
3 changed files with 120 additions and 33 deletions
|
|
@ -672,6 +672,12 @@ Flask后端 模拟脚本
|
||||||
> **注意**: 所有Interview接口的参数都通过请求体(JSON)传递,包括simulation_id。
|
> **注意**: 所有Interview接口的参数都通过请求体(JSON)传递,包括simulation_id。
|
||||||
>
|
>
|
||||||
> **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。
|
> **双平台模式说明**: 当不指定`platform`参数时,双平台模拟会同时采访两个平台并返回整合结果。
|
||||||
|
>
|
||||||
|
> **Prompt自动优化**: 系统会自动在用户提供的prompt前添加说明前缀,避免Agent调用工具:
|
||||||
|
> ```
|
||||||
|
> 原始prompt: "武汉大学发布撤销处分通告后你有什么看法"
|
||||||
|
> 优化后: "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:武汉大学发布撤销处分通告后你有什么看法"
|
||||||
|
> ```
|
||||||
|
|
||||||
#### 1. 采访单个Agent
|
#### 1. 采访单个Agent
|
||||||
|
|
||||||
|
|
@ -860,11 +866,11 @@ Flask后端 模拟脚本
|
||||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||||
|------|------|------|--------|------|
|
|------|------|------|--------|------|
|
||||||
| simulation_id | String | 是 | - | 模拟ID |
|
| simulation_id | String | 是 | - | 模拟ID |
|
||||||
| platform | String | 否 | reddit | 平台类型(reddit/twitter) |
|
| platform | String | 否 | null | 平台类型(reddit/twitter),不指定则返回两个平台的所有历史 |
|
||||||
| agent_id | Integer | 否 | - | 过滤Agent ID |
|
| agent_id | Integer | 否 | - | 只获取该Agent的采访历史 |
|
||||||
| limit | Integer | 否 | 100 | 返回数量限制 |
|
| limit | Integer | 否 | 100 | 返回数量限制 |
|
||||||
|
|
||||||
**返回示例**:
|
**返回示例(不指定platform,返回双平台历史)**:
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
|
|
@ -875,7 +881,14 @@ Flask后端 模拟脚本
|
||||||
"agent_id": 0,
|
"agent_id": 0,
|
||||||
"response": "我认为...",
|
"response": "我认为...",
|
||||||
"prompt": "你对这件事有什么看法?",
|
"prompt": "你对这件事有什么看法?",
|
||||||
"timestamp": "2025-12-08T10:00:00",
|
"timestamp": "2025-12-08T10:00:02",
|
||||||
|
"platform": "twitter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"agent_id": 0,
|
||||||
|
"response": "从Reddit角度来看...",
|
||||||
|
"prompt": "你对这件事有什么看法?",
|
||||||
|
"timestamp": "2025-12-08T10:00:01",
|
||||||
"platform": "reddit"
|
"platform": "reddit"
|
||||||
},
|
},
|
||||||
...
|
...
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,29 @@ from ..models.project import ProjectManager
|
||||||
logger = get_logger('mirofish.api.simulation')
|
logger = get_logger('mirofish.api.simulation')
|
||||||
|
|
||||||
|
|
||||||
|
# Interview prompt optimization prefix.
# Prepending it asks the agent to answer directly in text instead of calling tools.
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"


def optimize_interview_prompt(prompt: str) -> str:
    """
    Prepend INTERVIEW_PROMPT_PREFIX to an interview question.

    Args:
        prompt: The original question supplied by the caller.

    Returns:
        The question with the prefix prepended. The input is returned
        unchanged when it is empty/None, when it is not a string, or when
        it already starts with the prefix (ignoring leading whitespace).
    """
    # Request bodies are arbitrary JSON, so the value may be None, a number,
    # a list, etc. Leave such values untouched rather than raising
    # AttributeError on .startswith().
    if not isinstance(prompt, str) or not prompt:
        return prompt

    # Avoid adding the prefix twice; tolerate leading whitespace in front of
    # an already-prefixed prompt.
    if prompt.lstrip().startswith(INTERVIEW_PROMPT_PREFIX):
        return prompt

    return f"{INTERVIEW_PROMPT_PREFIX}{prompt}"
|
||||||
|
|
||||||
|
|
||||||
# ============== 实体读取接口 ==============
|
# ============== 实体读取接口 ==============
|
||||||
|
|
||||||
@simulation_bp.route('/entities/<graph_id>', methods=['GET'])
|
@simulation_bp.route('/entities/<graph_id>', methods=['GET'])
|
||||||
|
|
@ -1819,14 +1842,17 @@ def interview_agent():
|
||||||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
|
# 优化prompt,添加前缀避免Agent调用工具
|
||||||
|
optimized_prompt = optimize_interview_prompt(prompt)
|
||||||
|
|
||||||
result = SimulationRunner.interview_agent(
|
result = SimulationRunner.interview_agent(
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
prompt=prompt,
|
prompt=optimized_prompt,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": result.get("success", False),
|
"success": result.get("success", False),
|
||||||
"data": result
|
"data": result
|
||||||
|
|
@ -1951,9 +1977,16 @@ def interview_agents_batch():
|
||||||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
|
# 优化每个采访项的prompt,添加前缀避免Agent调用工具
|
||||||
|
optimized_interviews = []
|
||||||
|
for interview in interviews:
|
||||||
|
optimized_interview = interview.copy()
|
||||||
|
optimized_interview['prompt'] = optimize_interview_prompt(interview.get('prompt', ''))
|
||||||
|
optimized_interviews.append(optimized_interview)
|
||||||
|
|
||||||
result = SimulationRunner.interview_agents_batch(
|
result = SimulationRunner.interview_agents_batch(
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
interviews=interviews,
|
interviews=optimized_interviews,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
@ -2051,9 +2084,12 @@ def interview_all_agents():
|
||||||
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
"error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
|
# 优化prompt,添加前缀避免Agent调用工具
|
||||||
|
optimized_prompt = optimize_interview_prompt(prompt)
|
||||||
|
|
||||||
result = SimulationRunner.interview_all_agents(
|
result = SimulationRunner.interview_all_agents(
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
prompt=prompt,
|
prompt=optimized_prompt,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
@ -2094,8 +2130,9 @@ def get_interview_history():
|
||||||
请求(JSON):
|
请求(JSON):
|
||||||
{
|
{
|
||||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||||
"platform": "reddit", // 可选,平台类型(reddit/twitter),默认reddit
|
"platform": "reddit", // 可选,平台类型(reddit/twitter)
|
||||||
"agent_id": 0, // 可选,过滤Agent ID
|
// 不指定则返回两个平台的所有历史
|
||||||
|
"agent_id": 0, // 可选,只获取该Agent的采访历史
|
||||||
"limit": 100 // 可选,返回数量,默认100
|
"limit": 100 // 可选,返回数量,默认100
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2121,7 +2158,7 @@ def get_interview_history():
|
||||||
data = request.get_json() or {}
|
data = request.get_json() or {}
|
||||||
|
|
||||||
simulation_id = data.get('simulation_id')
|
simulation_id = data.get('simulation_id')
|
||||||
platform = data.get('platform', 'reddit')
|
platform = data.get('platform') # 不指定则返回两个平台的历史
|
||||||
agent_id = data.get('agent_id')
|
agent_id = data.get('agent_id')
|
||||||
limit = data.get('limit', 100)
|
limit = data.get('limit', 100)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1279,30 +1279,16 @@ class SimulationRunner:
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_interview_history(
|
def _get_interview_history_from_db(
|
||||||
cls,
|
cls,
|
||||||
simulation_id: str,
|
db_path: str,
|
||||||
platform: str = "reddit",
|
platform_name: str,
|
||||||
agent_id: Optional[int] = None,
|
agent_id: Optional[int] = None,
|
||||||
limit: int = 100
|
limit: int = 100
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""从单个数据库获取Interview历史"""
|
||||||
获取Interview历史记录(从数据库读取)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
simulation_id: 模拟ID
|
|
||||||
platform: 平台类型(reddit/twitter)
|
|
||||||
agent_id: 过滤Agent ID(可选)
|
|
||||||
limit: 返回数量限制
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Interview历史记录列表
|
|
||||||
"""
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
|
||||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
|
||||||
|
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -1312,8 +1298,6 @@ class SimulationRunner:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 构建查询
|
|
||||||
# 注意:ActionType.INTERVIEW.value 应该是字符串形式
|
|
||||||
if agent_id is not None:
|
if agent_id is not None:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT user_id, info, created_at
|
SELECT user_id, info, created_at
|
||||||
|
|
@ -1342,13 +1326,66 @@ class SimulationRunner:
|
||||||
"response": info.get("response", info),
|
"response": info.get("response", info),
|
||||||
"prompt": info.get("prompt", ""),
|
"prompt": info.get("prompt", ""),
|
||||||
"timestamp": created_at,
|
"timestamp": created_at,
|
||||||
"platform": platform
|
"platform": platform_name
|
||||||
})
|
})
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取Interview历史失败: {e}")
|
logger.error(f"读取Interview历史失败 ({platform_name}): {e}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_interview_history(
    cls,
    simulation_id: str,
    platform: Optional[str] = None,
    agent_id: Optional[int] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """
    Fetch interview history records from the per-platform simulation databases.

    Args:
        simulation_id: Simulation ID; selects the directory under RUN_STATE_DIR.
        platform: Platform to query.
            - "reddit": only the Reddit database
            - "twitter": only the Twitter database
            - None (or any other value): both platforms, merged
        agent_id: Optional agent ID, forwarded to the per-database reader so
            that only that agent's history is returned.
        limit: Forwarded to each per-platform read. When both platforms are
            queried, the merged list is additionally capped at this total.

    Returns:
        Interview history records sorted by "timestamp" in descending order;
        records whose timestamp is missing or None sort last.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)

    # Decide which platform databases to read.
    if platform in ("reddit", "twitter"):
        platforms = [platform]
    else:
        # No (recognised) platform given: query both.
        platforms = ["twitter", "reddit"]

    results: List[Dict[str, Any]] = []
    for name in platforms:
        db_path = os.path.join(sim_dir, f"{name}_simulation.db")
        results.extend(
            cls._get_interview_history_from_db(
                db_path=db_path,
                platform_name=name,
                agent_id=agent_id,
                limit=limit
            )
        )

    def _timestamp_key(record: Dict[str, Any]) -> Any:
        # dict.get's default only covers a missing key; a stored NULL comes
        # back as None and would make the sort raise TypeError against str.
        value = record.get("timestamp")
        return "" if value is None else value

    # Newest first.
    results.sort(key=_timestamp_key, reverse=True)

    # Each platform was limited separately, so the merged list can exceed
    # the requested total; trim it.
    if len(platforms) > 1 and len(results) > limit:
        results = results[:limit]

    return results
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue