Enhance simulation configuration and management features
- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration. - Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests. - Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer. - Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations. - Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
parent
3c1d554152
commit
5b4f02f421
9 changed files with 243 additions and 53 deletions
|
|
@ -554,7 +554,8 @@ backend/
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"simulation_id": "sim_10b494550540",
|
"simulation_id": "sim_10b494550540",
|
||||||
"platform": "parallel"
|
"platform": "parallel",
|
||||||
|
"max_rounds": 100
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -562,6 +563,7 @@ backend/
|
||||||
|------|------|------|--------|------|
|
|------|------|------|--------|------|
|
||||||
| simulation_id | String | 是 | - | 模拟ID |
|
| simulation_id | String | 是 | - | 模拟ID |
|
||||||
| platform | String | 否 | parallel | 运行平台: twitter/reddit/parallel |
|
| platform | String | 否 | parallel | 运行平台: twitter/reddit/parallel |
|
||||||
|
| max_rounds | Integer | 否 | - | 最大模拟轮数,用于截断过长的模拟。如果配置中的轮数超过此值,将被截断 |
|
||||||
|
|
||||||
**返回示例**:
|
**返回示例**:
|
||||||
```json
|
```json
|
||||||
|
|
@ -573,11 +575,15 @@ backend/
|
||||||
"process_pid": 12345,
|
"process_pid": 12345,
|
||||||
"twitter_running": true,
|
"twitter_running": true,
|
||||||
"reddit_running": true,
|
"reddit_running": true,
|
||||||
"started_at": "2025-12-02T11:00:00"
|
"started_at": "2025-12-02T11:00:00",
|
||||||
|
"total_rounds": 100,
|
||||||
|
"max_rounds_applied": 100
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **说明**: `max_rounds_applied` 字段仅在指定了 `max_rounds` 参数时返回,表示实际应用的最大轮数限制。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
#### 5. 停止模拟
|
#### 5. 停止模拟
|
||||||
|
|
@ -1502,12 +1508,13 @@ curl -X POST http://localhost:5001/api/simulation/prepare/status \
|
||||||
|
|
||||||
# 等待status=completed
|
# 等待status=completed
|
||||||
|
|
||||||
# Step 7: 启动模拟
|
# Step 7: 启动模拟(可选指定max_rounds限制轮数)
|
||||||
curl -X POST http://localhost:5001/api/simulation/start \
|
curl -X POST http://localhost:5001/api/simulation/start \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"simulation_id": "sim_xxx",
|
"simulation_id": "sim_xxx",
|
||||||
"platform": "parallel"
|
"platform": "parallel",
|
||||||
|
"max_rounds": 50
|
||||||
}'
|
}'
|
||||||
|
|
||||||
# Step 8: 实时查询运行状态
|
# Step 8: 实时查询运行状态
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,11 @@ def create_app(config_class=Config):
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config.from_object(config_class)
|
app.config.from_object(config_class)
|
||||||
|
|
||||||
|
# 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式)
|
||||||
|
# Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置
|
||||||
|
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
|
||||||
|
app.json.ensure_ascii = False
|
||||||
|
|
||||||
# 设置日志
|
# 设置日志
|
||||||
logger = setup_logger('mirofish')
|
logger = setup_logger('mirofish')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1114,7 +1114,8 @@ def start_simulation():
|
||||||
请求(JSON):
|
请求(JSON):
|
||||||
{
|
{
|
||||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||||
"platform": "parallel" // 可选: twitter / reddit / parallel (默认)
|
"platform": "parallel", // 可选: twitter / reddit / parallel (默认)
|
||||||
|
"max_rounds": 100 // 可选: 最大模拟轮数,用于截断过长的模拟
|
||||||
}
|
}
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
|
|
@ -1141,6 +1142,22 @@ def start_simulation():
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
platform = data.get('platform', 'parallel')
|
platform = data.get('platform', 'parallel')
|
||||||
|
max_rounds = data.get('max_rounds') # 可选:最大模拟轮数
|
||||||
|
|
||||||
|
# 验证 max_rounds 参数
|
||||||
|
if max_rounds is not None:
|
||||||
|
try:
|
||||||
|
max_rounds = int(max_rounds)
|
||||||
|
if max_rounds <= 0:
|
||||||
|
return jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": "max_rounds 必须是正整数"
|
||||||
|
}), 400
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": "max_rounds 必须是有效的整数"
|
||||||
|
}), 400
|
||||||
|
|
||||||
if platform not in ['twitter', 'reddit', 'parallel']:
|
if platform not in ['twitter', 'reddit', 'parallel']:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -1187,15 +1204,19 @@ def start_simulation():
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 启动模拟
|
# 启动模拟
|
||||||
run_state = SimulationRunner.start_simulation(simulation_id, platform)
|
run_state = SimulationRunner.start_simulation(simulation_id, platform, max_rounds)
|
||||||
|
|
||||||
# 更新模拟状态
|
# 更新模拟状态
|
||||||
state.status = SimulationStatus.RUNNING
|
state.status = SimulationStatus.RUNNING
|
||||||
manager._save_simulation_state(state)
|
manager._save_simulation_state(state)
|
||||||
|
|
||||||
|
response_data = run_state.to_dict()
|
||||||
|
if max_rounds:
|
||||||
|
response_data['max_rounds_applied'] = max_rounds
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": run_state.to_dict()
|
"data": response_data
|
||||||
})
|
})
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,9 @@ class Config:
|
||||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
||||||
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
||||||
|
|
||||||
|
# JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式)
|
||||||
|
JSON_AS_ASCII = False
|
||||||
|
|
||||||
# LLM配置(统一使用OpenAI格式)
|
# LLM配置(统一使用OpenAI格式)
|
||||||
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
||||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||||
|
|
|
||||||
|
|
@ -85,8 +85,8 @@ class TimeSimulationConfig:
|
||||||
# 模拟总时长(模拟小时数)
|
# 模拟总时长(模拟小时数)
|
||||||
total_simulation_hours: int = 72 # 默认模拟72小时(3天)
|
total_simulation_hours: int = 72 # 默认模拟72小时(3天)
|
||||||
|
|
||||||
# 每轮代表的时间(模拟分钟)
|
# 每轮代表的时间(模拟分钟)- 默认60分钟(1小时),加快时间流速
|
||||||
minutes_per_round: int = 30
|
minutes_per_round: int = 60
|
||||||
|
|
||||||
# 每小时激活的Agent数量范围
|
# 每小时激活的Agent数量范围
|
||||||
agents_per_hour_min: int = 5
|
agents_per_hour_min: int = 5
|
||||||
|
|
@ -205,7 +205,7 @@ class SimulationConfigGenerator:
|
||||||
|
|
||||||
采用分步生成策略:
|
采用分步生成策略:
|
||||||
1. 生成时间配置和事件配置(轻量级)
|
1. 生成时间配置和事件配置(轻量级)
|
||||||
2. 分批生成Agent配置(每批10-15个)
|
2. 分批生成Agent配置(每批10-20个)
|
||||||
3. 生成平台配置
|
3. 生成平台配置
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -214,6 +214,13 @@ class SimulationConfigGenerator:
|
||||||
# 每批生成的Agent数量
|
# 每批生成的Agent数量
|
||||||
AGENTS_PER_BATCH = 15
|
AGENTS_PER_BATCH = 15
|
||||||
|
|
||||||
|
# 各步骤的上下文截断长度(字符数)
|
||||||
|
TIME_CONFIG_CONTEXT_LENGTH = 10000 # 时间配置
|
||||||
|
EVENT_CONFIG_CONTEXT_LENGTH = 8000 # 事件配置
|
||||||
|
ENTITY_SUMMARY_LENGTH = 300 # 实体摘要
|
||||||
|
AGENT_SUMMARY_LENGTH = 300 # Agent配置中的实体摘要
|
||||||
|
ENTITIES_PER_TYPE_DISPLAY = 20 # 每类实体显示数量
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|
@ -286,8 +293,9 @@ class SimulationConfigGenerator:
|
||||||
|
|
||||||
# ========== 步骤1: 生成时间配置 ==========
|
# ========== 步骤1: 生成时间配置 ==========
|
||||||
report_progress(1, "生成时间配置...")
|
report_progress(1, "生成时间配置...")
|
||||||
time_config_result = self._generate_time_config(context, len(entities))
|
num_entities = len(entities)
|
||||||
time_config = self._parse_time_config(time_config_result)
|
time_config_result = self._generate_time_config(context, num_entities)
|
||||||
|
time_config = self._parse_time_config(time_config_result, num_entities)
|
||||||
reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}")
|
reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}")
|
||||||
|
|
||||||
# ========== 步骤2: 生成事件配置 ==========
|
# ========== 步骤2: 生成事件配置 ==========
|
||||||
|
|
@ -411,11 +419,14 @@ class SimulationConfigGenerator:
|
||||||
|
|
||||||
for entity_type, type_entities in by_type.items():
|
for entity_type, type_entities in by_type.items():
|
||||||
lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
|
lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
|
||||||
for e in type_entities[:10]: # 每类最多显示10个
|
# 使用配置的显示数量和摘要长度
|
||||||
summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary
|
display_count = self.ENTITIES_PER_TYPE_DISPLAY
|
||||||
|
summary_len = self.ENTITY_SUMMARY_LENGTH
|
||||||
|
for e in type_entities[:display_count]:
|
||||||
|
summary_preview = (e.summary[:summary_len] + "...") if len(e.summary) > summary_len else e.summary
|
||||||
lines.append(f"- {e.name}: {summary_preview}")
|
lines.append(f"- {e.name}: {summary_preview}")
|
||||||
if len(type_entities) > 10:
|
if len(type_entities) > display_count:
|
||||||
lines.append(f" ... 还有 {len(type_entities) - 10} 个")
|
lines.append(f" ... 还有 {len(type_entities) - display_count} 个")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
@ -522,33 +533,56 @@ class SimulationConfigGenerator:
|
||||||
|
|
||||||
def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]:
|
def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]:
|
||||||
"""生成时间配置"""
|
"""生成时间配置"""
|
||||||
|
# 使用配置的上下文截断长度
|
||||||
|
context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH]
|
||||||
|
|
||||||
|
# 计算最大允许值(80%的agent数)
|
||||||
|
max_agents_allowed = max(1, int(num_entities * 0.9))
|
||||||
|
|
||||||
prompt = f"""基于以下模拟需求,生成时间模拟配置。
|
prompt = f"""基于以下模拟需求,生成时间模拟配置。
|
||||||
|
|
||||||
{context[:5000]}
|
{context_truncated}
|
||||||
|
|
||||||
## 任务
|
## 任务
|
||||||
请生成时间配置JSON,注意:
|
请生成时间配置JSON。
|
||||||
|
|
||||||
|
### 基本原则(仅供参考,需根据具体事件和参与群体灵活调整):
|
||||||
- 用户群体为中国人,需符合北京时间作息习惯
|
- 用户群体为中国人,需符合北京时间作息习惯
|
||||||
- 凌晨0-5点几乎无人活动(活跃度系数0.05)
|
- 凌晨0-5点几乎无人活动(活跃度系数0.05)
|
||||||
- 早上6-8点逐渐活跃(活跃度系数0.4)
|
- 早上6-8点逐渐活跃(活跃度系数0.4)
|
||||||
- 工作时间9-18点中等活跃(活跃度系数0.7)
|
- 工作时间9-18点中等活跃(活跃度系数0.7)
|
||||||
- 晚间19-22点是高峰期(活跃度系数1.5)
|
- 晚间19-22点是高峰期(活跃度系数1.5)
|
||||||
- 23点后活跃度下降(活跃度系数0.5)
|
- 23点后活跃度下降(活跃度系数0.5)
|
||||||
|
- 一般规律:凌晨低活跃、早间渐增、工作时段中等、晚间高峰
|
||||||
|
- **重要**:以下示例值仅供参考,你需要根据事件性质、参与群体特点来调整具体时段
|
||||||
|
- 例如:学生群体高峰可能是21-23点;媒体全天活跃;官方机构只在工作时间
|
||||||
|
- 例如:突发热点可能导致深夜也有讨论,off_peak_hours 可适当缩短
|
||||||
|
|
||||||
当前实体数量: {num_entities}
|
### 返回JSON格式(不要markdown)
|
||||||
|
|
||||||
返回JSON格式(不要markdown):
|
示例:
|
||||||
{{
|
{{
|
||||||
"total_simulation_hours": <72-168,根据事件性质决定>,
|
"total_simulation_hours": 72,
|
||||||
"minutes_per_round": <15-60>,
|
"minutes_per_round": 60,
|
||||||
"agents_per_hour_min": <每小时最少激活Agent数>,
|
"agents_per_hour_min": 5,
|
||||||
"agents_per_hour_max": <每小时最多激活Agent数>,
|
"agents_per_hour_max": 50,
|
||||||
"peak_hours": [19, 20, 21, 22],
|
"peak_hours": [19, 20, 21, 22],
|
||||||
"off_peak_hours": [0, 1, 2, 3, 4, 5],
|
"off_peak_hours": [0, 1, 2, 3, 4, 5],
|
||||||
"morning_hours": [6, 7, 8],
|
"morning_hours": [6, 7, 8],
|
||||||
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
|
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
|
||||||
"reasoning": "<简要说明>"
|
"reasoning": "针对该事件的时间配置说明"
|
||||||
}}"""
|
}}
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
- total_simulation_hours (int): 模拟总时长,24-168小时,突发事件短、持续话题长
|
||||||
|
- minutes_per_round (int): 每轮时长,30-120分钟,建议60分钟
|
||||||
|
- agents_per_hour_min (int): 每小时最少激活Agent数(取值范围: 1-{max_agents_allowed})
|
||||||
|
- agents_per_hour_max (int): 每小时最多激活Agent数(取值范围: 1-{max_agents_allowed})
|
||||||
|
- peak_hours (int数组): 高峰时段,根据事件参与群体调整
|
||||||
|
- off_peak_hours (int数组): 低谷时段,通常深夜凌晨
|
||||||
|
- morning_hours (int数组): 早间时段
|
||||||
|
- work_hours (int数组): 工作时段
|
||||||
|
- reasoning (string): 简要说明为什么这样配置"""
|
||||||
|
|
||||||
system_prompt = "你是社交媒体模拟专家。返回纯JSON格式,时间配置需符合中国人作息习惯。"
|
system_prompt = "你是社交媒体模拟专家。返回纯JSON格式,时间配置需符合中国人作息习惯。"
|
||||||
|
|
||||||
|
|
@ -562,23 +596,41 @@ class SimulationConfigGenerator:
|
||||||
"""获取默认时间配置(中国人作息)"""
|
"""获取默认时间配置(中国人作息)"""
|
||||||
return {
|
return {
|
||||||
"total_simulation_hours": 72,
|
"total_simulation_hours": 72,
|
||||||
"minutes_per_round": 30,
|
"minutes_per_round": 60, # 每轮1小时,加快时间流速
|
||||||
"agents_per_hour_min": max(1, num_entities // 15),
|
"agents_per_hour_min": max(1, num_entities // 15),
|
||||||
"agents_per_hour_max": max(5, num_entities // 5),
|
"agents_per_hour_max": max(5, num_entities // 5),
|
||||||
"peak_hours": [19, 20, 21, 22],
|
"peak_hours": [19, 20, 21, 22],
|
||||||
"off_peak_hours": [0, 1, 2, 3, 4, 5],
|
"off_peak_hours": [0, 1, 2, 3, 4, 5],
|
||||||
"morning_hours": [6, 7, 8],
|
"morning_hours": [6, 7, 8],
|
||||||
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
|
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
|
||||||
"reasoning": "使用默认中国人作息配置"
|
"reasoning": "使用默认中国人作息配置(每轮1小时)"
|
||||||
}
|
}
|
||||||
|
|
||||||
def _parse_time_config(self, result: Dict[str, Any]) -> TimeSimulationConfig:
|
def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig:
|
||||||
"""解析时间配置结果"""
|
"""解析时间配置结果,并验证agents_per_hour值不超过总agent数"""
|
||||||
|
# 获取原始值
|
||||||
|
agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15))
|
||||||
|
agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5))
|
||||||
|
|
||||||
|
# 验证并修正:确保不超过总agent数
|
||||||
|
if agents_per_hour_min > num_entities:
|
||||||
|
logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) 超过总Agent数 ({num_entities}),已修正")
|
||||||
|
agents_per_hour_min = max(1, num_entities // 10)
|
||||||
|
|
||||||
|
if agents_per_hour_max > num_entities:
|
||||||
|
logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) 超过总Agent数 ({num_entities}),已修正")
|
||||||
|
agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2)
|
||||||
|
|
||||||
|
# 确保 min < max
|
||||||
|
if agents_per_hour_min >= agents_per_hour_max:
|
||||||
|
agents_per_hour_min = max(1, agents_per_hour_max // 2)
|
||||||
|
logger.warning(f"agents_per_hour_min >= max,已修正为 {agents_per_hour_min}")
|
||||||
|
|
||||||
return TimeSimulationConfig(
|
return TimeSimulationConfig(
|
||||||
total_simulation_hours=result.get("total_simulation_hours", 72),
|
total_simulation_hours=result.get("total_simulation_hours", 72),
|
||||||
minutes_per_round=result.get("minutes_per_round", 30),
|
minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时
|
||||||
agents_per_hour_min=result.get("agents_per_hour_min", 5),
|
agents_per_hour_min=agents_per_hour_min,
|
||||||
agents_per_hour_max=result.get("agents_per_hour_max", 20),
|
agents_per_hour_max=agents_per_hour_max,
|
||||||
peak_hours=result.get("peak_hours", [19, 20, 21, 22]),
|
peak_hours=result.get("peak_hours", [19, 20, 21, 22]),
|
||||||
off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
|
off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
|
||||||
off_peak_activity_multiplier=0.05, # 凌晨几乎无人
|
off_peak_activity_multiplier=0.05, # 凌晨几乎无人
|
||||||
|
|
@ -616,11 +668,14 @@ class SimulationConfigGenerator:
|
||||||
for t, examples in type_examples.items()
|
for t, examples in type_examples.items()
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# 使用配置的上下文截断长度
|
||||||
|
context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH]
|
||||||
|
|
||||||
prompt = f"""基于以下模拟需求,生成事件配置。
|
prompt = f"""基于以下模拟需求,生成事件配置。
|
||||||
|
|
||||||
模拟需求: {simulation_requirement}
|
模拟需求: {simulation_requirement}
|
||||||
|
|
||||||
{context[:3000]}
|
{context_truncated}
|
||||||
|
|
||||||
## 可用实体类型及示例
|
## 可用实体类型及示例
|
||||||
{type_info}
|
{type_info}
|
||||||
|
|
@ -761,14 +816,15 @@ class SimulationConfigGenerator:
|
||||||
) -> List[AgentActivityConfig]:
|
) -> List[AgentActivityConfig]:
|
||||||
"""分批生成Agent配置"""
|
"""分批生成Agent配置"""
|
||||||
|
|
||||||
# 构建实体信息
|
# 构建实体信息(使用配置的摘要长度)
|
||||||
entity_list = []
|
entity_list = []
|
||||||
|
summary_len = self.AGENT_SUMMARY_LENGTH
|
||||||
for i, e in enumerate(entities):
|
for i, e in enumerate(entities):
|
||||||
entity_list.append({
|
entity_list.append({
|
||||||
"agent_id": start_idx + i,
|
"agent_id": start_idx + i,
|
||||||
"entity_name": e.name,
|
"entity_name": e.name,
|
||||||
"entity_type": e.get_entity_type() or "Unknown",
|
"entity_type": e.get_entity_type() or "Unknown",
|
||||||
"summary": e.summary[:150] if e.summary else ""
|
"summary": e.summary[:summary_len] if e.summary else ""
|
||||||
})
|
})
|
||||||
|
|
||||||
prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。
|
prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。
|
||||||
|
|
|
||||||
|
|
@ -280,7 +280,8 @@ class SimulationRunner:
|
||||||
def start_simulation(
|
def start_simulation(
|
||||||
cls,
|
cls,
|
||||||
simulation_id: str,
|
simulation_id: str,
|
||||||
platform: str = "parallel" # twitter / reddit / parallel
|
platform: str = "parallel", # twitter / reddit / parallel
|
||||||
|
max_rounds: int = None # 最大模拟轮数(可选,用于截断过长的模拟)
|
||||||
) -> SimulationRunState:
|
) -> SimulationRunState:
|
||||||
"""
|
"""
|
||||||
启动模拟
|
启动模拟
|
||||||
|
|
@ -288,6 +289,7 @@ class SimulationRunner:
|
||||||
Args:
|
Args:
|
||||||
simulation_id: 模拟ID
|
simulation_id: 模拟ID
|
||||||
platform: 运行平台 (twitter/reddit/parallel)
|
platform: 运行平台 (twitter/reddit/parallel)
|
||||||
|
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SimulationRunState
|
SimulationRunState
|
||||||
|
|
@ -313,6 +315,13 @@ class SimulationRunner:
|
||||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||||
total_rounds = int(total_hours * 60 / minutes_per_round)
|
total_rounds = int(total_hours * 60 / minutes_per_round)
|
||||||
|
|
||||||
|
# 如果指定了最大轮数,则截断
|
||||||
|
if max_rounds is not None and max_rounds > 0:
|
||||||
|
original_rounds = total_rounds
|
||||||
|
total_rounds = min(total_rounds, max_rounds)
|
||||||
|
if total_rounds < original_rounds:
|
||||||
|
logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||||
|
|
||||||
state = SimulationRunState(
|
state = SimulationRunState(
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
runner_status=RunnerStatus.STARTING,
|
runner_status=RunnerStatus.STARTING,
|
||||||
|
|
@ -358,6 +367,10 @@ class SimulationRunner:
|
||||||
"--config", config_path, # 使用完整配置文件路径
|
"--config", config_path, # 使用完整配置文件路径
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 如果指定了最大轮数,添加到命令行参数
|
||||||
|
if max_rounds is not None and max_rounds > 0:
|
||||||
|
cmd.extend(["--max-rounds", str(max_rounds)])
|
||||||
|
|
||||||
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||||
main_log_path = os.path.join(sim_dir, "simulation.log")
|
main_log_path = os.path.join(sim_dir, "simulation.log")
|
||||||
main_log_file = open(main_log_path, 'w', encoding='utf-8')
|
main_log_file = open(main_log_path, 'w', encoding='utf-8')
|
||||||
|
|
|
||||||
|
|
@ -404,9 +404,18 @@ async def run_twitter_simulation(
|
||||||
config: Dict[str, Any],
|
config: Dict[str, Any],
|
||||||
simulation_dir: str,
|
simulation_dir: str,
|
||||||
action_logger: Optional[PlatformActionLogger] = None,
|
action_logger: Optional[PlatformActionLogger] = None,
|
||||||
main_logger: Optional[SimulationLogManager] = None
|
main_logger: Optional[SimulationLogManager] = None,
|
||||||
|
max_rounds: Optional[int] = None
|
||||||
):
|
):
|
||||||
"""运行Twitter模拟"""
|
"""运行Twitter模拟
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 模拟配置
|
||||||
|
simulation_dir: 模拟目录
|
||||||
|
action_logger: 动作日志记录器
|
||||||
|
main_logger: 主日志管理器
|
||||||
|
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||||
|
"""
|
||||||
def log_info(msg):
|
def log_info(msg):
|
||||||
if main_logger:
|
if main_logger:
|
||||||
main_logger.info(f"[Twitter] {msg}")
|
main_logger.info(f"[Twitter] {msg}")
|
||||||
|
|
@ -494,6 +503,13 @@ async def run_twitter_simulation(
|
||||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||||
total_rounds = (total_hours * 60) // minutes_per_round
|
total_rounds = (total_hours * 60) // minutes_per_round
|
||||||
|
|
||||||
|
# 如果指定了最大轮数,则截断
|
||||||
|
if max_rounds is not None and max_rounds > 0:
|
||||||
|
original_rounds = total_rounds
|
||||||
|
total_rounds = min(total_rounds, max_rounds)
|
||||||
|
if total_rounds < original_rounds:
|
||||||
|
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||||
|
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
|
|
||||||
for round_num in range(total_rounds):
|
for round_num in range(total_rounds):
|
||||||
|
|
@ -552,9 +568,18 @@ async def run_reddit_simulation(
|
||||||
config: Dict[str, Any],
|
config: Dict[str, Any],
|
||||||
simulation_dir: str,
|
simulation_dir: str,
|
||||||
action_logger: Optional[PlatformActionLogger] = None,
|
action_logger: Optional[PlatformActionLogger] = None,
|
||||||
main_logger: Optional[SimulationLogManager] = None
|
main_logger: Optional[SimulationLogManager] = None,
|
||||||
|
max_rounds: Optional[int] = None
|
||||||
):
|
):
|
||||||
"""运行Reddit模拟"""
|
"""运行Reddit模拟
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 模拟配置
|
||||||
|
simulation_dir: 模拟目录
|
||||||
|
action_logger: 动作日志记录器
|
||||||
|
main_logger: 主日志管理器
|
||||||
|
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||||
|
"""
|
||||||
def log_info(msg):
|
def log_info(msg):
|
||||||
if main_logger:
|
if main_logger:
|
||||||
main_logger.info(f"[Reddit] {msg}")
|
main_logger.info(f"[Reddit] {msg}")
|
||||||
|
|
@ -649,6 +674,13 @@ async def run_reddit_simulation(
|
||||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||||
total_rounds = (total_hours * 60) // minutes_per_round
|
total_rounds = (total_hours * 60) // minutes_per_round
|
||||||
|
|
||||||
|
# 如果指定了最大轮数,则截断
|
||||||
|
if max_rounds is not None and max_rounds > 0:
|
||||||
|
original_rounds = total_rounds
|
||||||
|
total_rounds = min(total_rounds, max_rounds)
|
||||||
|
if total_rounds < original_rounds:
|
||||||
|
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||||
|
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
|
|
||||||
for round_num in range(total_rounds):
|
for round_num in range(total_rounds):
|
||||||
|
|
@ -721,6 +753,12 @@ async def main():
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='只运行Reddit模拟'
|
help='只运行Reddit模拟'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-rounds',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -746,9 +784,18 @@ async def main():
|
||||||
log_manager.info("=" * 60)
|
log_manager.info("=" * 60)
|
||||||
|
|
||||||
time_config = config.get("time_config", {})
|
time_config = config.get("time_config", {})
|
||||||
|
total_hours = time_config.get('total_simulation_hours', 72)
|
||||||
|
minutes_per_round = time_config.get('minutes_per_round', 30)
|
||||||
|
config_total_rounds = (total_hours * 60) // minutes_per_round
|
||||||
|
|
||||||
log_manager.info(f"模拟参数:")
|
log_manager.info(f"模拟参数:")
|
||||||
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
|
log_manager.info(f" - 总模拟时长: {total_hours}小时")
|
||||||
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
|
log_manager.info(f" - 每轮时间: {minutes_per_round}分钟")
|
||||||
|
log_manager.info(f" - 配置总轮数: {config_total_rounds}")
|
||||||
|
if args.max_rounds:
|
||||||
|
log_manager.info(f" - 最大轮数限制: {args.max_rounds}")
|
||||||
|
if args.max_rounds < config_total_rounds:
|
||||||
|
log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)")
|
||||||
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
|
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
|
||||||
|
|
||||||
log_manager.info("日志结构:")
|
log_manager.info("日志结构:")
|
||||||
|
|
@ -760,14 +807,14 @@ async def main():
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
|
|
||||||
if args.twitter_only:
|
if args.twitter_only:
|
||||||
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager)
|
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
|
||||||
elif args.reddit_only:
|
elif args.reddit_only:
|
||||||
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager)
|
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
|
||||||
else:
|
else:
|
||||||
# 并行运行(每个平台使用独立的日志记录器)
|
# 并行运行(每个平台使用独立的日志记录器)
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager),
|
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
|
||||||
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager),
|
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
|
||||||
)
|
)
|
||||||
|
|
||||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||||
|
|
|
||||||
|
|
@ -251,8 +251,12 @@ class RedditSimulationRunner:
|
||||||
|
|
||||||
return active_agents
|
return active_agents
|
||||||
|
|
||||||
async def run(self):
|
async def run(self, max_rounds: int = None):
|
||||||
"""运行Reddit模拟"""
|
"""运行Reddit模拟
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||||
|
"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("OASIS Reddit模拟")
|
print("OASIS Reddit模拟")
|
||||||
print(f"配置文件: {self.config_path}")
|
print(f"配置文件: {self.config_path}")
|
||||||
|
|
@ -264,10 +268,19 @@ class RedditSimulationRunner:
|
||||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||||
total_rounds = (total_hours * 60) // minutes_per_round
|
total_rounds = (total_hours * 60) // minutes_per_round
|
||||||
|
|
||||||
|
# 如果指定了最大轮数,则截断
|
||||||
|
if max_rounds is not None and max_rounds > 0:
|
||||||
|
original_rounds = total_rounds
|
||||||
|
total_rounds = min(total_rounds, max_rounds)
|
||||||
|
if total_rounds < original_rounds:
|
||||||
|
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||||
|
|
||||||
print(f"\n模拟参数:")
|
print(f"\n模拟参数:")
|
||||||
print(f" - 总模拟时长: {total_hours}小时")
|
print(f" - 总模拟时长: {total_hours}小时")
|
||||||
print(f" - 每轮时间: {minutes_per_round}分钟")
|
print(f" - 每轮时间: {minutes_per_round}分钟")
|
||||||
print(f" - 总轮数: {total_rounds}")
|
print(f" - 总轮数: {total_rounds}")
|
||||||
|
if max_rounds:
|
||||||
|
print(f" - 最大轮数限制: {max_rounds}")
|
||||||
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
||||||
|
|
||||||
print("\n初始化LLM模型...")
|
print("\n初始化LLM模型...")
|
||||||
|
|
@ -380,6 +393,12 @@ async def main():
|
||||||
required=True,
|
required=True,
|
||||||
help='配置文件路径 (simulation_config.json)'
|
help='配置文件路径 (simulation_config.json)'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-rounds',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -392,7 +411,7 @@ async def main():
|
||||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||||
|
|
||||||
runner = RedditSimulationRunner(args.config)
|
runner = RedditSimulationRunner(args.config)
|
||||||
await runner.run()
|
await runner.run(max_rounds=args.max_rounds)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -259,8 +259,12 @@ class TwitterSimulationRunner:
|
||||||
|
|
||||||
return active_agents
|
return active_agents
|
||||||
|
|
||||||
async def run(self):
|
async def run(self, max_rounds: int = None):
|
||||||
"""运行Twitter模拟"""
|
"""运行Twitter模拟
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||||
|
"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("OASIS Twitter模拟")
|
print("OASIS Twitter模拟")
|
||||||
print(f"配置文件: {self.config_path}")
|
print(f"配置文件: {self.config_path}")
|
||||||
|
|
@ -275,10 +279,19 @@ class TwitterSimulationRunner:
|
||||||
# 计算总轮数
|
# 计算总轮数
|
||||||
total_rounds = (total_hours * 60) // minutes_per_round
|
total_rounds = (total_hours * 60) // minutes_per_round
|
||||||
|
|
||||||
|
# 如果指定了最大轮数,则截断
|
||||||
|
if max_rounds is not None and max_rounds > 0:
|
||||||
|
original_rounds = total_rounds
|
||||||
|
total_rounds = min(total_rounds, max_rounds)
|
||||||
|
if total_rounds < original_rounds:
|
||||||
|
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||||
|
|
||||||
print(f"\n模拟参数:")
|
print(f"\n模拟参数:")
|
||||||
print(f" - 总模拟时长: {total_hours}小时")
|
print(f" - 总模拟时长: {total_hours}小时")
|
||||||
print(f" - 每轮时间: {minutes_per_round}分钟")
|
print(f" - 每轮时间: {minutes_per_round}分钟")
|
||||||
print(f" - 总轮数: {total_rounds}")
|
print(f" - 总轮数: {total_rounds}")
|
||||||
|
if max_rounds:
|
||||||
|
print(f" - 最大轮数限制: {max_rounds}")
|
||||||
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
||||||
|
|
||||||
# 创建模型
|
# 创建模型
|
||||||
|
|
@ -393,6 +406,12 @@ async def main():
|
||||||
required=True,
|
required=True,
|
||||||
help='配置文件路径 (simulation_config.json)'
|
help='配置文件路径 (simulation_config.json)'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-rounds',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -405,7 +424,7 @@ async def main():
|
||||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||||
|
|
||||||
runner = TwitterSimulationRunner(args.config)
|
runner = TwitterSimulationRunner(args.config)
|
||||||
await runner.run()
|
await runner.run(max_rounds=args.max_rounds)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue