Enhance simulation configuration and management features

- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration.
- Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests.
- Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer.
- Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations.
- Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
666ghj 2025-12-05 15:50:54 +08:00
parent 3c1d554152
commit 5b4f02f421
9 changed files with 243 additions and 53 deletions

View file

@ -554,7 +554,8 @@ backend/
```json ```json
{ {
"simulation_id": "sim_10b494550540", "simulation_id": "sim_10b494550540",
"platform": "parallel" "platform": "parallel",
"max_rounds": 100
} }
``` ```
@ -562,6 +563,7 @@ backend/
|------|------|------|--------|------| |------|------|------|--------|------|
| simulation_id | String | 是 | - | 模拟ID | | simulation_id | String | 是 | - | 模拟ID |
| platform | String | 否 | parallel | 运行平台: twitter/reddit/parallel | | platform | String | 否 | parallel | 运行平台: twitter/reddit/parallel |
| max_rounds | Integer | 否 | - | 最大模拟轮数,用于截断过长的模拟。如果配置中的轮数超过此值,将被截断 |
**返回示例**: **返回示例**:
```json ```json
@ -573,11 +575,15 @@ backend/
"process_pid": 12345, "process_pid": 12345,
"twitter_running": true, "twitter_running": true,
"reddit_running": true, "reddit_running": true,
"started_at": "2025-12-02T11:00:00" "started_at": "2025-12-02T11:00:00",
"total_rounds": 100,
"max_rounds_applied": 100
} }
} }
``` ```
> **说明**: `max_rounds_applied` 字段仅在指定了 `max_rounds` 参数时返回,表示实际应用的最大轮数限制。
--- ---
#### 5. 停止模拟 #### 5. 停止模拟
@ -1502,12 +1508,13 @@ curl -X POST http://localhost:5001/api/simulation/prepare/status \
# 等待status=completed # 等待status=completed
# Step 7: 启动模拟 # Step 7: 启动模拟可选指定max_rounds限制轮数
curl -X POST http://localhost:5001/api/simulation/start \ curl -X POST http://localhost:5001/api/simulation/start \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"simulation_id": "sim_xxx", "simulation_id": "sim_xxx",
"platform": "parallel" "platform": "parallel",
"max_rounds": 50
}' }'
# Step 8: 实时查询运行状态 # Step 8: 实时查询运行状态

View file

@ -15,6 +15,11 @@ def create_app(config_class=Config):
app = Flask(__name__) app = Flask(__name__)
app.config.from_object(config_class) app.config.from_object(config_class)
# 设置JSON编码确保中文直接显示而不是 \uXXXX 格式)
# Flask >= 2.3 使用 app.json.ensure_ascii旧版本使用 JSON_AS_ASCII 配置
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
app.json.ensure_ascii = False
# 设置日志 # 设置日志
logger = setup_logger('mirofish') logger = setup_logger('mirofish')

View file

@ -1114,7 +1114,8 @@ def start_simulation():
请求JSON 请求JSON
{ {
"simulation_id": "sim_xxxx", // 必填模拟ID "simulation_id": "sim_xxxx", // 必填模拟ID
"platform": "parallel" // 可选: twitter / reddit / parallel (默认) "platform": "parallel", // 可选: twitter / reddit / parallel (默认)
"max_rounds": 100 // 可选: 最大模拟轮数用于截断过长的模拟
} }
返回 返回
@ -1141,6 +1142,22 @@ def start_simulation():
}), 400 }), 400
platform = data.get('platform', 'parallel') platform = data.get('platform', 'parallel')
max_rounds = data.get('max_rounds') # 可选:最大模拟轮数
# 验证 max_rounds 参数
if max_rounds is not None:
try:
max_rounds = int(max_rounds)
if max_rounds <= 0:
return jsonify({
"success": False,
"error": "max_rounds 必须是正整数"
}), 400
except (ValueError, TypeError):
return jsonify({
"success": False,
"error": "max_rounds 必须是有效的整数"
}), 400
if platform not in ['twitter', 'reddit', 'parallel']: if platform not in ['twitter', 'reddit', 'parallel']:
return jsonify({ return jsonify({
@ -1187,15 +1204,19 @@ def start_simulation():
}), 400 }), 400
# 启动模拟 # 启动模拟
run_state = SimulationRunner.start_simulation(simulation_id, platform) run_state = SimulationRunner.start_simulation(simulation_id, platform, max_rounds)
# 更新模拟状态 # 更新模拟状态
state.status = SimulationStatus.RUNNING state.status = SimulationStatus.RUNNING
manager._save_simulation_state(state) manager._save_simulation_state(state)
response_data = run_state.to_dict()
if max_rounds:
response_data['max_rounds_applied'] = max_rounds
return jsonify({ return jsonify({
"success": True, "success": True,
"data": run_state.to_dict() "data": response_data
}) })
except ValueError as e: except ValueError as e:

View file

@ -24,6 +24,9 @@ class Config:
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key') SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true' DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
# JSON配置 - 禁用ASCII转义让中文直接显示而不是 \uXXXX 格式)
JSON_AS_ASCII = False
# LLM配置统一使用OpenAI格式 # LLM配置统一使用OpenAI格式
LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_API_KEY = os.environ.get('LLM_API_KEY')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')

View file

@ -85,8 +85,8 @@ class TimeSimulationConfig:
# 模拟总时长(模拟小时数) # 模拟总时长(模拟小时数)
total_simulation_hours: int = 72 # 默认模拟72小时3天 total_simulation_hours: int = 72 # 默认模拟72小时3天
# 每轮代表的时间(模拟分钟) # 每轮代表的时间(模拟分钟)- 默认60分钟1小时加快时间流速
minutes_per_round: int = 30 minutes_per_round: int = 60
# 每小时激活的Agent数量范围 # 每小时激活的Agent数量范围
agents_per_hour_min: int = 5 agents_per_hour_min: int = 5
@ -205,7 +205,7 @@ class SimulationConfigGenerator:
采用分步生成策略 采用分步生成策略
1. 生成时间配置和事件配置轻量级 1. 生成时间配置和事件配置轻量级
2. 分批生成Agent配置每批10-15 2. 分批生成Agent配置每批10-20
3. 生成平台配置 3. 生成平台配置
""" """
@ -214,6 +214,13 @@ class SimulationConfigGenerator:
# 每批生成的Agent数量 # 每批生成的Agent数量
AGENTS_PER_BATCH = 15 AGENTS_PER_BATCH = 15
# 各步骤的上下文截断长度(字符数)
TIME_CONFIG_CONTEXT_LENGTH = 10000 # 时间配置
EVENT_CONFIG_CONTEXT_LENGTH = 8000 # 事件配置
ENTITY_SUMMARY_LENGTH = 300 # 实体摘要
AGENT_SUMMARY_LENGTH = 300 # Agent配置中的实体摘要
ENTITIES_PER_TYPE_DISPLAY = 20 # 每类实体显示数量
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -286,8 +293,9 @@ class SimulationConfigGenerator:
# ========== 步骤1: 生成时间配置 ========== # ========== 步骤1: 生成时间配置 ==========
report_progress(1, "生成时间配置...") report_progress(1, "生成时间配置...")
time_config_result = self._generate_time_config(context, len(entities)) num_entities = len(entities)
time_config = self._parse_time_config(time_config_result) time_config_result = self._generate_time_config(context, num_entities)
time_config = self._parse_time_config(time_config_result, num_entities)
reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}") reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}")
# ========== 步骤2: 生成事件配置 ========== # ========== 步骤2: 生成事件配置 ==========
@ -411,11 +419,14 @@ class SimulationConfigGenerator:
for entity_type, type_entities in by_type.items(): for entity_type, type_entities in by_type.items():
lines.append(f"\n### {entity_type} ({len(type_entities)}个)") lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
for e in type_entities[:10]: # 每类最多显示10个 # 使用配置的显示数量和摘要长度
summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary display_count = self.ENTITIES_PER_TYPE_DISPLAY
summary_len = self.ENTITY_SUMMARY_LENGTH
for e in type_entities[:display_count]:
summary_preview = (e.summary[:summary_len] + "...") if len(e.summary) > summary_len else e.summary
lines.append(f"- {e.name}: {summary_preview}") lines.append(f"- {e.name}: {summary_preview}")
if len(type_entities) > 10: if len(type_entities) > display_count:
lines.append(f" ... 还有 {len(type_entities) - 10}") lines.append(f" ... 还有 {len(type_entities) - display_count}")
return "\n".join(lines) return "\n".join(lines)
@ -522,33 +533,56 @@ class SimulationConfigGenerator:
def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]: def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]:
"""生成时间配置""" """生成时间配置"""
# 使用配置的上下文截断长度
context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH]
# 计算最大允许值80%的agent数
max_agents_allowed = max(1, int(num_entities * 0.9))
prompt = f"""基于以下模拟需求,生成时间模拟配置。 prompt = f"""基于以下模拟需求,生成时间模拟配置。
{context[:5000]} {context_truncated}
## 任务 ## 任务
请生成时间配置JSON注意 请生成时间配置JSON
### 基本原则(仅供参考,需根据具体事件和参与群体灵活调整):
- 用户群体为中国人需符合北京时间作息习惯 - 用户群体为中国人需符合北京时间作息习惯
- 凌晨0-5点几乎无人活动活跃度系数0.05 - 凌晨0-5点几乎无人活动活跃度系数0.05
- 早上6-8点逐渐活跃活跃度系数0.4 - 早上6-8点逐渐活跃活跃度系数0.4
- 工作时间9-18点中等活跃活跃度系数0.7 - 工作时间9-18点中等活跃活跃度系数0.7
- 晚间19-22点是高峰期活跃度系数1.5 - 晚间19-22点是高峰期活跃度系数1.5
- 23点后活跃度下降活跃度系数0.5 - 23点后活跃度下降活跃度系数0.5
- 一般规律凌晨低活跃早间渐增工作时段中等晚间高峰
- **重要**以下示例值仅供参考你需要根据事件性质参与群体特点来调整具体时段
- 例如学生群体高峰可能是21-23媒体全天活跃官方机构只在工作时间
- 例如突发热点可能导致深夜也有讨论off_peak_hours 可适当缩短
当前实体数量: {num_entities} ### 返回JSON格式不要markdown
返回JSON格式不要markdown 示例
{{ {{
"total_simulation_hours": <72-168根据事件性质决定>, "total_simulation_hours": 72,
"minutes_per_round": <15-60>, "minutes_per_round": 60,
"agents_per_hour_min": <每小时最少激活Agent数>, "agents_per_hour_min": 5,
"agents_per_hour_max": <每小时最多激活Agent数>, "agents_per_hour_max": 50,
"peak_hours": [19, 20, 21, 22], "peak_hours": [19, 20, 21, 22],
"off_peak_hours": [0, 1, 2, 3, 4, 5], "off_peak_hours": [0, 1, 2, 3, 4, 5],
"morning_hours": [6, 7, 8], "morning_hours": [6, 7, 8],
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
"reasoning": "<简要说明>" "reasoning": "针对该事件的时间配置说明"
}}""" }}
字段说明
- total_simulation_hours (int): 模拟总时长24-168小时突发事件短持续话题长
- minutes_per_round (int): 每轮时长30-120分钟建议60分钟
- agents_per_hour_min (int): 每小时最少激活Agent数取值范围: 1-{max_agents_allowed}
- agents_per_hour_max (int): 每小时最多激活Agent数取值范围: 1-{max_agents_allowed}
- peak_hours (int数组): 高峰时段根据事件参与群体调整
- off_peak_hours (int数组): 低谷时段通常深夜凌晨
- morning_hours (int数组): 早间时段
- work_hours (int数组): 工作时段
- reasoning (string): 简要说明为什么这样配置"""
system_prompt = "你是社交媒体模拟专家。返回纯JSON格式时间配置需符合中国人作息习惯。" system_prompt = "你是社交媒体模拟专家。返回纯JSON格式时间配置需符合中国人作息习惯。"
@ -562,23 +596,41 @@ class SimulationConfigGenerator:
"""获取默认时间配置(中国人作息)""" """获取默认时间配置(中国人作息)"""
return { return {
"total_simulation_hours": 72, "total_simulation_hours": 72,
"minutes_per_round": 30, "minutes_per_round": 60, # 每轮1小时加快时间流速
"agents_per_hour_min": max(1, num_entities // 15), "agents_per_hour_min": max(1, num_entities // 15),
"agents_per_hour_max": max(5, num_entities // 5), "agents_per_hour_max": max(5, num_entities // 5),
"peak_hours": [19, 20, 21, 22], "peak_hours": [19, 20, 21, 22],
"off_peak_hours": [0, 1, 2, 3, 4, 5], "off_peak_hours": [0, 1, 2, 3, 4, 5],
"morning_hours": [6, 7, 8], "morning_hours": [6, 7, 8],
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
"reasoning": "使用默认中国人作息配置" "reasoning": "使用默认中国人作息配置每轮1小时"
} }
def _parse_time_config(self, result: Dict[str, Any]) -> TimeSimulationConfig: def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig:
"""解析时间配置结果""" """解析时间配置结果并验证agents_per_hour值不超过总agent数"""
# 获取原始值
agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15))
agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5))
# 验证并修正确保不超过总agent数
if agents_per_hour_min > num_entities:
logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) 超过总Agent数 ({num_entities}),已修正")
agents_per_hour_min = max(1, num_entities // 10)
if agents_per_hour_max > num_entities:
logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) 超过总Agent数 ({num_entities}),已修正")
agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2)
# 确保 min < max
if agents_per_hour_min >= agents_per_hour_max:
agents_per_hour_min = max(1, agents_per_hour_max // 2)
logger.warning(f"agents_per_hour_min >= max已修正为 {agents_per_hour_min}")
return TimeSimulationConfig( return TimeSimulationConfig(
total_simulation_hours=result.get("total_simulation_hours", 72), total_simulation_hours=result.get("total_simulation_hours", 72),
minutes_per_round=result.get("minutes_per_round", 30), minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时
agents_per_hour_min=result.get("agents_per_hour_min", 5), agents_per_hour_min=agents_per_hour_min,
agents_per_hour_max=result.get("agents_per_hour_max", 20), agents_per_hour_max=agents_per_hour_max,
peak_hours=result.get("peak_hours", [19, 20, 21, 22]), peak_hours=result.get("peak_hours", [19, 20, 21, 22]),
off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]), off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
off_peak_activity_multiplier=0.05, # 凌晨几乎无人 off_peak_activity_multiplier=0.05, # 凌晨几乎无人
@ -616,11 +668,14 @@ class SimulationConfigGenerator:
for t, examples in type_examples.items() for t, examples in type_examples.items()
]) ])
# 使用配置的上下文截断长度
context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH]
prompt = f"""基于以下模拟需求,生成事件配置。 prompt = f"""基于以下模拟需求,生成事件配置。
模拟需求: {simulation_requirement} 模拟需求: {simulation_requirement}
{context[:3000]} {context_truncated}
## 可用实体类型及示例 ## 可用实体类型及示例
{type_info} {type_info}
@ -761,14 +816,15 @@ class SimulationConfigGenerator:
) -> List[AgentActivityConfig]: ) -> List[AgentActivityConfig]:
"""分批生成Agent配置""" """分批生成Agent配置"""
# 构建实体信息 # 构建实体信息(使用配置的摘要长度)
entity_list = [] entity_list = []
summary_len = self.AGENT_SUMMARY_LENGTH
for i, e in enumerate(entities): for i, e in enumerate(entities):
entity_list.append({ entity_list.append({
"agent_id": start_idx + i, "agent_id": start_idx + i,
"entity_name": e.name, "entity_name": e.name,
"entity_type": e.get_entity_type() or "Unknown", "entity_type": e.get_entity_type() or "Unknown",
"summary": e.summary[:150] if e.summary else "" "summary": e.summary[:summary_len] if e.summary else ""
}) })
prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。 prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。

View file

@ -280,7 +280,8 @@ class SimulationRunner:
def start_simulation( def start_simulation(
cls, cls,
simulation_id: str, simulation_id: str,
platform: str = "parallel" # twitter / reddit / parallel platform: str = "parallel", # twitter / reddit / parallel
max_rounds: int = None # 最大模拟轮数(可选,用于截断过长的模拟)
) -> SimulationRunState: ) -> SimulationRunState:
""" """
启动模拟 启动模拟
@ -288,6 +289,7 @@ class SimulationRunner:
Args: Args:
simulation_id: 模拟ID simulation_id: 模拟ID
platform: 运行平台 (twitter/reddit/parallel) platform: 运行平台 (twitter/reddit/parallel)
max_rounds: 最大模拟轮数可选用于截断过长的模拟
Returns: Returns:
SimulationRunState SimulationRunState
@ -313,6 +315,13 @@ class SimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30) minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = int(total_hours * 60 / minutes_per_round) total_rounds = int(total_hours * 60 / minutes_per_round)
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
state = SimulationRunState( state = SimulationRunState(
simulation_id=simulation_id, simulation_id=simulation_id,
runner_status=RunnerStatus.STARTING, runner_status=RunnerStatus.STARTING,
@ -358,6 +367,10 @@ class SimulationRunner:
"--config", config_path, # 使用完整配置文件路径 "--config", config_path, # 使用完整配置文件路径
] ]
# 如果指定了最大轮数,添加到命令行参数
if max_rounds is not None and max_rounds > 0:
cmd.extend(["--max-rounds", str(max_rounds)])
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞 # 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
main_log_path = os.path.join(sim_dir, "simulation.log") main_log_path = os.path.join(sim_dir, "simulation.log")
main_log_file = open(main_log_path, 'w', encoding='utf-8') main_log_file = open(main_log_path, 'w', encoding='utf-8')

View file

@ -404,9 +404,18 @@ async def run_twitter_simulation(
config: Dict[str, Any], config: Dict[str, Any],
simulation_dir: str, simulation_dir: str,
action_logger: Optional[PlatformActionLogger] = None, action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None main_logger: Optional[SimulationLogManager] = None,
max_rounds: Optional[int] = None
): ):
"""运行Twitter模拟""" """运行Twitter模拟
Args:
config: 模拟配置
simulation_dir: 模拟目录
action_logger: 动作日志记录器
main_logger: 主日志管理器
max_rounds: 最大模拟轮数可选用于截断过长的模拟
"""
def log_info(msg): def log_info(msg):
if main_logger: if main_logger:
main_logger.info(f"[Twitter] {msg}") main_logger.info(f"[Twitter] {msg}")
@ -494,6 +503,13 @@ async def run_twitter_simulation(
minutes_per_round = time_config.get("minutes_per_round", 30) minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
start_time = datetime.now() start_time = datetime.now()
for round_num in range(total_rounds): for round_num in range(total_rounds):
@ -552,9 +568,18 @@ async def run_reddit_simulation(
config: Dict[str, Any], config: Dict[str, Any],
simulation_dir: str, simulation_dir: str,
action_logger: Optional[PlatformActionLogger] = None, action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None main_logger: Optional[SimulationLogManager] = None,
max_rounds: Optional[int] = None
): ):
"""运行Reddit模拟""" """运行Reddit模拟
Args:
config: 模拟配置
simulation_dir: 模拟目录
action_logger: 动作日志记录器
main_logger: 主日志管理器
max_rounds: 最大模拟轮数可选用于截断过长的模拟
"""
def log_info(msg): def log_info(msg):
if main_logger: if main_logger:
main_logger.info(f"[Reddit] {msg}") main_logger.info(f"[Reddit] {msg}")
@ -649,6 +674,13 @@ async def run_reddit_simulation(
minutes_per_round = time_config.get("minutes_per_round", 30) minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
start_time = datetime.now() start_time = datetime.now()
for round_num in range(total_rounds): for round_num in range(total_rounds):
@ -721,6 +753,12 @@ async def main():
action='store_true', action='store_true',
help='只运行Reddit模拟' help='只运行Reddit模拟'
) )
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args() args = parser.parse_args()
@ -746,9 +784,18 @@ async def main():
log_manager.info("=" * 60) log_manager.info("=" * 60)
time_config = config.get("time_config", {}) time_config = config.get("time_config", {})
total_hours = time_config.get('total_simulation_hours', 72)
minutes_per_round = time_config.get('minutes_per_round', 30)
config_total_rounds = (total_hours * 60) // minutes_per_round
log_manager.info(f"模拟参数:") log_manager.info(f"模拟参数:")
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时") log_manager.info(f" - 总模拟时长: {total_hours}小时")
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟") log_manager.info(f" - 每轮时间: {minutes_per_round}分钟")
log_manager.info(f" - 配置总轮数: {config_total_rounds}")
if args.max_rounds:
log_manager.info(f" - 最大轮数限制: {args.max_rounds}")
if args.max_rounds < config_total_rounds:
log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)")
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}") log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
log_manager.info("日志结构:") log_manager.info("日志结构:")
@ -760,14 +807,14 @@ async def main():
start_time = datetime.now() start_time = datetime.now()
if args.twitter_only: if args.twitter_only:
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager) await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
elif args.reddit_only: elif args.reddit_only:
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager) await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
else: else:
# 并行运行(每个平台使用独立的日志记录器) # 并行运行(每个平台使用独立的日志记录器)
await asyncio.gather( await asyncio.gather(
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager), run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager), run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
) )
total_elapsed = (datetime.now() - start_time).total_seconds() total_elapsed = (datetime.now() - start_time).total_seconds()

View file

@ -251,8 +251,12 @@ class RedditSimulationRunner:
return active_agents return active_agents
async def run(self): async def run(self, max_rounds: int = None):
"""运行Reddit模拟""" """运行Reddit模拟
Args:
max_rounds: 最大模拟轮数可选用于截断过长的模拟
"""
print("=" * 60) print("=" * 60)
print("OASIS Reddit模拟") print("OASIS Reddit模拟")
print(f"配置文件: {self.config_path}") print(f"配置文件: {self.config_path}")
@ -264,10 +268,19 @@ class RedditSimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30) minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
print(f"\n模拟参数:") print(f"\n模拟参数:")
print(f" - 总模拟时长: {total_hours}小时") print(f" - 总模拟时长: {total_hours}小时")
print(f" - 每轮时间: {minutes_per_round}分钟") print(f" - 每轮时间: {minutes_per_round}分钟")
print(f" - 总轮数: {total_rounds}") print(f" - 总轮数: {total_rounds}")
if max_rounds:
print(f" - 最大轮数限制: {max_rounds}")
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
print("\n初始化LLM模型...") print("\n初始化LLM模型...")
@ -380,6 +393,12 @@ async def main():
required=True, required=True,
help='配置文件路径 (simulation_config.json)' help='配置文件路径 (simulation_config.json)'
) )
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args() args = parser.parse_args()
@ -392,7 +411,7 @@ async def main():
setup_oasis_logging(os.path.join(simulation_dir, "log")) setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = RedditSimulationRunner(args.config) runner = RedditSimulationRunner(args.config)
await runner.run() await runner.run(max_rounds=args.max_rounds)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -259,8 +259,12 @@ class TwitterSimulationRunner:
return active_agents return active_agents
async def run(self): async def run(self, max_rounds: int = None):
"""运行Twitter模拟""" """运行Twitter模拟
Args:
max_rounds: 最大模拟轮数可选用于截断过长的模拟
"""
print("=" * 60) print("=" * 60)
print("OASIS Twitter模拟") print("OASIS Twitter模拟")
print(f"配置文件: {self.config_path}") print(f"配置文件: {self.config_path}")
@ -275,10 +279,19 @@ class TwitterSimulationRunner:
# 计算总轮数 # 计算总轮数
total_rounds = (total_hours * 60) // minutes_per_round total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
print(f"\n模拟参数:") print(f"\n模拟参数:")
print(f" - 总模拟时长: {total_hours}小时") print(f" - 总模拟时长: {total_hours}小时")
print(f" - 每轮时间: {minutes_per_round}分钟") print(f" - 每轮时间: {minutes_per_round}分钟")
print(f" - 总轮数: {total_rounds}") print(f" - 总轮数: {total_rounds}")
if max_rounds:
print(f" - 最大轮数限制: {max_rounds}")
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
# 创建模型 # 创建模型
@ -393,6 +406,12 @@ async def main():
required=True, required=True,
help='配置文件路径 (simulation_config.json)' help='配置文件路径 (simulation_config.json)'
) )
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args() args = parser.parse_args()
@ -405,7 +424,7 @@ async def main():
setup_oasis_logging(os.path.join(simulation_dir, "log")) setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = TwitterSimulationRunner(args.config) runner = TwitterSimulationRunner(args.config)
await runner.run() await runner.run(max_rounds=args.max_rounds)
if __name__ == "__main__": if __name__ == "__main__":