From 5f159f6d88e8c3eebfa8417e171dc78060813721 Mon Sep 17 00:00:00 2001 From: 666ghj <670939375@qq.com> Date: Mon, 1 Dec 2025 15:03:44 +0800 Subject: [PATCH] Enhance backend functionality with OASIS simulation features - Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings. - Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms. - Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management. - Implemented a robust retry mechanism for external API calls to improve system stability. - Enhanced task management model to include detailed progress tracking. - Added logging capabilities for action tracking during simulations. - Included new scripts for running parallel simulations and testing profile formats. --- backend/README.md | 1442 ++++++++++++++++- backend/app/__init__.py | 3 +- backend/app/api/__init__.py | 2 + backend/app/api/simulation.py | 1330 +++++++++++++++ backend/app/config.py | 14 + backend/app/models/task.py | 10 +- backend/app/services/__init__.py | 43 +- .../app/services/oasis_profile_generator.py | 561 +++++++ .../services/simulation_config_generator.py | 584 +++++++ backend/app/services/simulation_manager.py | 546 +++++++ backend/app/services/simulation_runner.py | 670 ++++++++ backend/app/services/zep_entity_reader.py | 386 +++++ backend/app/utils/retry.py | 238 +++ backend/requirements.txt | 4 + backend/scripts/action_logger.py | 138 ++ backend/scripts/run_parallel_simulation.py | 503 ++++++ backend/scripts/run_reddit_simulation.py | 298 ++++ backend/scripts/run_twitter_simulation.py | 313 ++++ backend/scripts/test_profile_format.py | 166 ++ 19 files changed, 7202 insertions(+), 49 deletions(-) create mode 100644 backend/app/api/simulation.py create mode 100644 
backend/app/services/oasis_profile_generator.py create mode 100644 backend/app/services/simulation_config_generator.py create mode 100644 backend/app/services/simulation_manager.py create mode 100644 backend/app/services/simulation_runner.py create mode 100644 backend/app/services/zep_entity_reader.py create mode 100644 backend/app/utils/retry.py create mode 100644 backend/scripts/action_logger.py create mode 100644 backend/scripts/run_parallel_simulation.py create mode 100644 backend/scripts/run_reddit_simulation.py create mode 100644 backend/scripts/run_twitter_simulation.py create mode 100644 backend/scripts/test_profile_format.py diff --git a/backend/README.md b/backend/README.md index ee54adb..2521693 100644 --- a/backend/README.md +++ b/backend/README.md @@ -7,23 +7,37 @@ ``` backend/ ├── app/ -│ ├── __init__.py # Flask应用工厂 -│ ├── config.py # 配置管理 -│ ├── api/ # API路由 -│ │ ├── __init__.py -│ │ └── graph.py # 图谱相关接口 -│ ├── services/ # 业务逻辑层 -│ │ ├── ontology_generator.py # 本体生成服务 -│ │ ├── graph_builder.py # 图谱构建服务 -│ │ └── text_processor.py # 文本处理服务 -│ ├── models/ # 数据模型 -│ │ ├── task.py # 任务状态管理 -│ │ └── project.py # 项目上下文管理 -│ └── utils/ # 工具模块 -│ ├── file_parser.py # 文件解析 -│ └── llm_client.py # LLM客户端 +│ ├── __init__.py # Flask应用工厂 +│ ├── config.py # 配置管理 +│ ├── api/ # API路由 +│ │ ├── __init__.py # Blueprint注册 +│ │ ├── graph.py # Step1: 图谱相关接口 +│ │ └── simulation.py # Step2: 模拟相关接口 +│ ├── services/ # 业务逻辑层 +│ │ ├── __init__.py # 服务模块导出 +│ │ ├── ontology_generator.py # 本体生成服务 +│ │ ├── graph_builder.py # 图谱构建服务 +│ │ ├── text_processor.py # 文本处理服务 +│ │ ├── zep_entity_reader.py # Zep实体读取与过滤 +│ │ ├── oasis_profile_generator.py # Agent Profile生成器 +│ │ ├── simulation_config_generator.py # LLM智能配置生成器(核心) +│ │ └── simulation_manager.py # 模拟管理器 +│ ├── models/ # 数据模型 +│ │ ├── task.py # 任务状态管理 +│ │ └── project.py # 项目上下文管理 +│ └── utils/ # 工具模块 +│ ├── file_parser.py # 文件解析 +│ ├── llm_client.py # LLM客户端 +│ └── logger.py # 日志工具 +├── scripts/ # 预设模拟脚本 +│ ├── 
run_twitter_simulation.py # Twitter模拟脚本 +│ ├── run_reddit_simulation.py # Reddit模拟脚本 +│ └── run_parallel_simulation.py # 双平台并行脚本 +├── uploads/ # 上传文件存储 +│ ├── projects/ # 项目文件 +│ └── simulations/ # 模拟数据(含配置和脚本副本) ├── requirements.txt -└── run.py # 启动入口 +└── run.py # 启动入口 ``` ## 安装 @@ -39,13 +53,16 @@ pip install -r requirements.txt 在项目根目录 `MiroFish/.env` 中配置: ```bash -# LLM配置 +# LLM配置(统一使用OpenAI格式) LLM_API_KEY=your-llm-api-key LLM_BASE_URL=https://openrouter.ai/api/v1 LLM_MODEL_NAME=gpt-4o-mini # Zep配置 ZEP_API_KEY=your-zep-api-key + +# OASIS模拟配置(可选) +OASIS_DEFAULT_MAX_ROUNDS=10 ``` ## 启动服务 @@ -54,22 +71,63 @@ ZEP_API_KEY=your-zep-api-key python run.py ``` -服务默认运行在 http://localhost:5000 +服务默认运行在 http://localhost:5001 --- -## API接口 +# 系统架构 -### 核心工作流程 +## 完整工作流程 ``` -1. 上传文件 + 生成本体(接口1) +┌─────────────────────────────────────────────────────────────────────────┐ +│ Step 1: 图谱构建 │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 上传文档 ──→ 生成本体定义 ──→ 构建Zep图谱 ──→ 图谱数据 │ +│ (PDF/MD/TXT) (LLM分析) (异步任务) (节点/边) │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Step 2: 实体读取与模拟准备 │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 读取图谱节点 ──→ 过滤符合条件实体 ──→ 生成Agent Profile ──→ 生成脚本 │ +│ (Zep API) (按Labels筛选) (LLM生成人设) (OASIS启动) │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Step 3: 双平台并行模拟 │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ Twitter模拟 │ │ Reddit模拟 │ │ +│ │ (短平快交互) │ 并行运行 │ (深度话题讨论) │ │ +│ └─────────────────┘ └─────────────────┘ │ +│ │ │ +│ ▼ │ +│ 同一批智能体,模拟真实社交环境 │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +# Step 1: 
图谱构建 API + +## 核心工作流程 + +``` +1. 上传文件 + 生成本体 POST /api/graph/ontology/generate - → 自动创建项目,返回 project_id + → 返回 project_id -2. 构建图谱(接口2) +2. 构建图谱 POST /api/graph/build - → 传入 project_id → 返回 task_id 3. 查询任务进度 @@ -91,10 +149,10 @@ python run.py | 字段 | 类型 | 必填 | 说明 | |------|------|------|------| -| `files` | File | ✅ | PDF/MD/TXT文件,可多个 | -| `simulation_requirement` | Text | ✅ | 模拟需求描述 | -| `project_name` | Text | ❌ | 项目名称 | -| `additional_context` | Text | ❌ | 额外说明 | +| `files` | File | 是 | PDF/MD/TXT文件,可多个 | +| `simulation_requirement` | Text | 是 | 模拟需求描述 | +| `project_name` | Text | 否 | 项目名称 | +| `additional_context` | Text | 否 | 额外说明 | **响应示例:** ```json @@ -106,23 +164,26 @@ python run.py "ontology": { "entity_types": [ { - "name": "Person", - "description": "Individuals who can express opinions", - "attributes": [...] + "name": "Student", + "description": "Students enrolled in educational institutions", + "attributes": [ + {"name": "student_id", "type": "text", "description": "Unique identifier"}, + {"name": "major", "type": "text", "description": "Field of study"} + ] } ], "edge_types": [ { "name": "AFFILIATED_WITH", - "description": "Indicates affiliation", - "source_targets": [...] 
+ "description": "Indicates affiliation between entities", + "source_targets": [ + {"source": "Student", "target": "University"} + ] } ] }, "analysis_summary": "分析说明...", - "files": [ - {"filename": "报告.pdf", "size": 123456} - ], + "files": [{"filename": "报告.pdf", "size": 123456}], "total_text_length": 20833 } } @@ -148,10 +209,10 @@ python run.py | 字段 | 类型 | 必填 | 说明 | |------|------|------|------| -| `project_id` | string | ✅ | 来自接口1的返回 | -| `graph_name` | string | ❌ | 图谱名称 | -| `chunk_size` | int | ❌ | 文本块大小,默认500 | -| `chunk_overlap` | int | ❌ | 块重叠字符,默认50 | +| `project_id` | string | 是 | 来自接口1的返回 | +| `graph_name` | string | 否 | 图谱名称 | +| `chunk_size` | int | 否 | 文本块大小,默认500 | +| `chunk_overlap` | int | 否 | 块重叠字符,默认50 | **响应:** ```json @@ -178,7 +239,7 @@ python run.py "task_id": "task_xyz789", "status": "processing", "progress": 45, - "message": "添加文本块 (15/30)...", + "message": "Zep处理中... 15/30 完成", "result": null } } @@ -211,17 +272,1191 @@ python run.py --- -## 实体设计原则 +# Step 2: 实体读取与模拟运行 API + +## 核心设计理念 + +**全程自动化,无需人工设置参数:** +- 脚本是**预设的**,不是动态生成 +- 所有模拟参数由**LLM智能生成** +- LLM读取模拟需求+文档+图谱信息,自动设置最佳参数 +- **通过API接口启动和监控模拟**,前端可实时展示 + +## 核心工作流程 + +``` +1. 创建模拟 + POST /api/simulation/create + → 返回 simulation_id + +2. 准备模拟环境(异步任务) + POST /api/simulation/prepare + Body: { "simulation_id": "sim_xxxx" } + → 返回 task_id(立即响应) + + 查询进度: + POST /api/simulation/prepare/status + Body: { "task_id": "task_xxxx" } + → 返回 status, progress, result + +3. 开始模拟 + POST /api/simulation/start + Body: { "simulation_id": "sim_xxxx", "platform": "parallel" } + → 在后台启动OASIS模拟进程 + → 返回运行状态 + +4. 实时监控(前端轮询) + GET /api/simulation/{simulation_id}/run-status/detail + → 返回当前进度、最近Agent动作 + +5. 
停止模拟(可选) + POST /api/simulation/stop + Body: { "simulation_id": "sim_xxxx" } +``` + +--- + +## 实体读取接口 + +### 获取图谱实体(已过滤) + +**GET** `/api/simulation/entities/{graph_id}` + +获取图谱中符合预定义实体类型的节点。 + +**实体过滤逻辑:** +- Zep对符合预定义类型的实体,Labels为 `["Entity", "Student"]` +- 对不符合预定义类型的实体,Labels仅为 `["Entity"]` +- **筛选规则**:只保留Labels中包含除"Entity"和"Node"之外标签的节点 + +**Query参数:** + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `entity_types` | string | 否 | 逗号分隔的实体类型,用于进一步过滤 | +| `enrich` | boolean | 否 | 是否获取相关边信息,默认true | + +**响应示例:** +```json +{ + "success": true, + "data": { + "entities": [ + { + "uuid": "node_uuid_123", + "name": "杨景媛", + "labels": ["Entity", "Student"], + "summary": "武汉大学学生,图书馆事件当事人", + "attributes": { + "student_id": "2021001", + "major": "计算机科学" + }, + "related_edges": [ + { + "direction": "outgoing", + "edge_name": "AFFILIATED_WITH", + "fact": "杨景媛是武汉大学的学生", + "target_node_uuid": "node_uuid_456" + } + ], + "related_nodes": [ + { + "uuid": "node_uuid_456", + "name": "武汉大学", + "labels": ["Entity", "University"], + "summary": "中国著名高等学府" + } + ] + } + ], + "entity_types": ["Student", "University", "PublicFigure"], + "total_count": 100, + "filtered_count": 45 + } +} +``` + +--- + +### 获取单个实体详情 + +**GET** `/api/simulation/entities/{graph_id}/{entity_uuid}` + +获取单个实体的完整信息,包含所有相关边和关联节点。 + +--- + +### 按类型获取实体 + +**GET** `/api/simulation/entities/{graph_id}/by-type/{entity_type}` + +获取指定类型(如Student、PublicFigure)的所有实体。 + +--- + +## 模拟管理接口 + +### 创建模拟 + +**POST** `/api/simulation/create` + +**请求(JSON):** +```json +{ + "project_id": "proj_abc123def456", + "graph_id": "mirofish_xxxx", + "enable_twitter": true, + "enable_reddit": true, + "max_rounds": 10, + "agents_per_round": -1 +} +``` + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `project_id` | string | 是 | 项目ID | +| `graph_id` | string | 否 | 图谱ID,不提供则从project获取 | +| `enable_twitter` | boolean | 否 | 启用Twitter模拟,默认true | +| `enable_reddit` | boolean | 否 | 启用Reddit模拟,默认true | +| `max_rounds` | int | 否 | 
最大模拟轮数,默认10 | +| `agents_per_round` | int | 否 | 每轮激活智能体数,-1表示全部 | + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_abc123def456", + "config": { + "project_id": "proj_xxxx", + "graph_id": "mirofish_xxxx", + "enable_twitter": true, + "enable_reddit": true, + "max_rounds": 10 + }, + "status": "created", + "created_at": "2025-12-01T10:00:00" + } +} +``` + +--- + +### 准备模拟环境(异步任务) + +**POST** `/api/simulation/prepare` + +**异步接口**:这是一个耗时操作,接口会立即返回`task_id`,通过`/prepare/status`查询进度。 + +执行模拟准备流程(LLM智能生成所有参数,带自动重试机制): +1. 从Zep图谱读取并过滤实体 +2. 为每个实体生成OASIS Agent Profile(带重试) +3. LLM智能生成模拟配置(带重试) +4. 保存配置文件和复制预设脚本 + +**请求(JSON):** +```json +{ + "simulation_id": "sim_xxxx", + "entity_types": ["Student", "PublicFigure"], + "use_llm_for_profiles": true +} +``` + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "task_id": "task_xxxx", + "status": "preparing", + "message": "准备任务已启动,请通过 /api/simulation/prepare/status 查询进度" + } +} +``` + +--- + +### 查询准备进度 + +**POST** `/api/simulation/prepare/status` + +查询准备任务的执行进度。 + +**请求(JSON):** +```json +{ + "task_id": "task_xxxx" +} +``` + +**响应示例:** +```json +{ + "success": true, + "data": { + "task_id": "task_xxxx", + "task_type": "simulation_prepare", + "status": "processing", + "progress": 45, + "message": "[2/4] 生成Agent人设: 35/93 - 生成 教授张三 的人设...", + "progress_detail": { + "current_stage": "generating_profiles", + "current_stage_name": "生成Agent人设", + "stage_index": 2, + "total_stages": 4, + "stage_progress": 38, + "current_item": 35, + "total_items": 93, + "item_description": "生成 教授张三 的人设..." 
+ }, + "result": null, + "error": null, + "metadata": { + "project_id": "proj_xxxx", + "simulation_id": "sim_xxxx" + } + } +} +``` + +**进度详情字段(progress_detail):** + +| 字段 | 类型 | 说明 | +|------|------|------| +| `current_stage` | string | 当前阶段标识 (reading/generating_profiles/generating_config/copying_scripts) | +| `current_stage_name` | string | 当前阶段中文名称 | +| `stage_index` | int | 当前阶段序号 (1-4) | +| `total_stages` | int | 总阶段数 (4) | +| `stage_progress` | int | 当前阶段内进度 (0-100) | +| `current_item` | int | 当前处理的项目序号 | +| `total_items` | int | 当前阶段总项目数 | +| `item_description` | string | 当前项目描述 | + +**阶段说明:** + +| 阶段 | 名称 | 权重 | 说明 | +|------|------|------|------| +| 1 | 读取图谱实体 | 0-20% | 从Zep读取并过滤实体 | +| 2 | 生成Agent人设 | 20-70% | 为每个实体生成OASIS Profile | +| 3 | 生成模拟配置 | 70-90% | LLM智能生成模拟参数 | +| 4 | 准备模拟脚本 | 90-100% | 复制预设脚本到模拟目录 | + +**状态值(status):** +- `pending` - 等待中 +- `processing` - 处理中 +- `completed` - 已完成(此时result包含结果) +- `failed` - 失败(此时error包含错误信息) + +**完成后的响应:** +```json +{ + "success": true, + "data": { + "task_id": "task_xxxx", + "status": "completed", + "progress": 100, + "message": "任务完成", + "result": { + "simulation_id": "sim_xxxx", + "project_id": "proj_xxxx", + "graph_id": "mirofish_xxxx", + "status": "ready", + "entities_count": 93, + "profiles_count": 93, + "entity_types": ["University", "Student", ...], + "config_generated": true, + "error": null + } + } +} +``` + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `entity_types` | array | 否 | 指定实体类型进行过滤 | +| `use_llm_for_profiles` | boolean | 否 | 是否使用LLM生成人设,默认true | + +**注意**:`simulation_requirement`和`document_text`自动从项目中获取 + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_abc123def456", + "status": "ready", + "entities_count": 45, + "profiles_count": 45, + "entity_types": ["Student", "PublicFigure", "University"], + "config_generated": true, + "config_reasoning": "根据武汉大学图书馆事件的特点,设置72小时模拟时长...", + "run_instructions": { + "simulation_dir": "/path/to/sim_xxx", + 
"commands": {...}, + "instructions": "..." + } + } +} +``` + +--- + +### 获取模拟状态 + +**GET** `/api/simulation/{simulation_id}` + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_abc123def456", + "status": "ready", + "entities_count": 45, + "profiles_count": 45, + "entity_types": ["Student", "PublicFigure"], + "current_round": 0, + "twitter_status": "not_started", + "reddit_status": "not_started" + } +} +``` + +--- + +### 列出所有模拟 + +**GET** `/api/simulation/list` + +| Query参数 | 类型 | 说明 | +|-----------|------|------| +| `project_id` | string | 按项目ID过滤(可选) | + +--- + +### 获取Agent Profile + +**GET** `/api/simulation/{simulation_id}/profiles` + +| Query参数 | 类型 | 说明 | +|-----------|------|------| +| `platform` | string | 平台类型:reddit 或 twitter | + +**响应示例:** +```json +{ + "success": true, + "data": { + "platform": "reddit", + "count": 45, + "profiles": [ + { + "user_id": 0, + "user_name": "yangjingyuan_123", + "name": "杨景媛", + "bio": "武汉大学学生,关注教育公平与学生权益", + "persona": "杨景媛是一名积极参与社会讨论的大学生,性格内敛但观点鲜明...", + "karma": 1500, + "age": 22, + "gender": "female", + "mbti": "INFJ", + "country": "China", + "profession": "Student", + "interested_topics": ["Education", "Social Issues"] + } + ] + } +} +``` + +--- + +### 获取模拟配置 + +**GET** `/api/simulation/{simulation_id}/config` + +获取LLM智能生成的完整配置,包含: +- `time_config`: 时间配置 +- `agent_configs`: 每个Agent的活动配置 +- `event_config`: 事件配置 +- `generation_reasoning`: LLM的配置推理说明 + +--- + +### 下载文件 + +| 接口 | 说明 | +|------|------| +| GET `/api/simulation/{id}/config/download` | 下载配置文件 | +| GET `/api/simulation/{id}/script/{script_name}/download` | 下载脚本文件 | + +**脚本名称:** +- `run_twitter_simulation.py` +- `run_reddit_simulation.py` +- `run_parallel_simulation.py` + +--- + +### 直接生成Profile + +**POST** `/api/simulation/generate-profiles` + +不创建模拟,直接从图谱生成Agent Profile。 + +```json +{ + "graph_id": "mirofish_xxxx", + "entity_types": ["Student", "PublicFigure"], + "use_llm": true, + "platform": "reddit" +} +``` + +--- + +## 模拟运行控制接口 + 
+### 开始模拟 + +**POST** `/api/simulation/start` + +启动OASIS模拟,在后台运行。 + +**请求(JSON):** +```json +{ + "simulation_id": "sim_xxxx", // 必填 + "platform": "parallel" // 可选: twitter / reddit / parallel (默认) +} +``` + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "running", + "process_pid": 12345, + "twitter_running": true, + "reddit_running": true, + "total_rounds": 144, + "total_simulation_hours": 72, + "started_at": "2025-12-01T10:00:00" + } +} +``` + +--- + +### 停止模拟 + +**POST** `/api/simulation/stop` + +停止正在运行的模拟。 + +**请求(JSON):** +```json +{ + "simulation_id": "sim_xxxx" // 必填 +} +``` + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "stopped", + "completed_at": "2025-12-01T12:00:00", + "twitter_actions_count": 500, + "reddit_actions_count": 650 + } +} +``` + +--- + +## 实时状态监控接口 + +### 获取运行状态(基础) + +**GET** `/api/simulation/{simulation_id}/run-status` + +获取模拟运行的实时状态,用于前端轮询。 + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "running", + "current_round": 25, + "total_rounds": 144, + "progress_percent": 17.4, + "simulated_hours": 12, + "total_simulation_hours": 72, + "twitter_running": true, + "reddit_running": true, + "twitter_actions_count": 150, + "reddit_actions_count": 200, + "total_actions_count": 350, + "started_at": "2025-12-01T10:00:00", + "updated_at": "2025-12-01T10:30:00" + } +} +``` + +**运行状态值(runner_status):** +- `idle` - 未运行 +- `starting` - 启动中 +- `running` - 运行中 +- `paused` - 已暂停 +- `stopping` - 停止中 +- `stopped` - 已停止 +- `completed` - 已完成 +- `failed` - 失败 + +--- + +### 获取运行状态(详细,含最近动作) + +**GET** `/api/simulation/{simulation_id}/run-status/detail` + +获取详细运行状态,包含最近的Agent动作列表,**用于前端实时展示动态**。 + +**响应示例:** +```json +{ + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "running", + "current_round": 25, + "progress_percent": 17.4, + "recent_actions": [ + { + 
"round_num": 25, + "timestamp": "2025-12-01T10:30:00", + "platform": "twitter", + "agent_id": 3, + "agent_name": "Entity Name", + "action_type": "CREATE_POST", + "action_args": {"content": "Post content..."}, + "result": null, + "success": true + }, + { + "round_num": 25, + "timestamp": "2025-12-01T10:29:55", + "platform": "reddit", + "agent_id": 7, + "agent_name": "Another Entity", + "action_type": "LIKE_POST", + "action_args": {"post_id": 5}, + "success": true + } + ] + } +} +``` + +--- + +### 获取动作历史 + +**GET** `/api/simulation/{simulation_id}/actions` + +获取完整的Agent动作历史记录。 + +**Query参数:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| `limit` | int | 返回数量(默认100) | +| `offset` | int | 偏移量(默认0) | +| `platform` | string | 过滤平台(twitter/reddit) | +| `agent_id` | int | 过滤Agent ID | +| `round_num` | int | 过滤轮次 | + +--- + +### 获取时间线 + +**GET** `/api/simulation/{simulation_id}/timeline` + +获取按轮次汇总的时间线,用于前端展示进度条和时间线视图。 + +**Query参数:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| `start_round` | int | 起始轮次(默认0) | +| `end_round` | int | 结束轮次(默认全部) | + +**响应示例:** +```json +{ + "success": true, + "data": { + "rounds_count": 25, + "timeline": [ + { + "round_num": 1, + "twitter_actions": 10, + "reddit_actions": 15, + "total_actions": 25, + "active_agents_count": 8, + "active_agents": [0, 1, 3, 5, 7, 10, 12, 15], + "action_types": {"CREATE_POST": 5, "LIKE_POST": 10, "LLM_ACTION": 10}, + "first_action_time": "2025-12-01T10:00:00", + "last_action_time": "2025-12-01T10:05:00" + } + ] + } +} +``` + +--- + +### 获取Agent统计 + +**GET** `/api/simulation/{simulation_id}/agent-stats` + +获取每个Agent的活跃度统计,用于展示排行榜。 + +**响应示例:** +```json +{ + "success": true, + "data": { + "agents_count": 45, + "stats": [ + { + "agent_id": 3, + "agent_name": "Active Agent", + "total_actions": 50, + "twitter_actions": 30, + "reddit_actions": 20, + "action_types": {"CREATE_POST": 10, "LIKE_POST": 25, "REPOST": 15}, + "first_action_time": "2025-12-01T10:00:00", + "last_action_time": "2025-12-01T12:30:00" + } + ] + 
} +} +``` + +--- + +## 数据库查询接口 + +### 获取帖子 + +**GET** `/api/simulation/{simulation_id}/posts` + +从模拟数据库获取帖子列表。 + +**Query参数:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| `platform` | string | 平台类型(twitter/reddit,默认reddit) | +| `limit` | int | 返回数量(默认50) | +| `offset` | int | 偏移量 | + +--- + +### 获取评论 + +**GET** `/api/simulation/{simulation_id}/comments` + +从Reddit模拟数据库获取评论列表。 + +**Query参数:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| `post_id` | string | 过滤帖子ID(可选) | +| `limit` | int | 返回数量(默认50) | +| `offset` | int | 偏移量 | + +--- + +# 服务层实现细节 + +## 1. ZepEntityReader(Zep实体读取服务) + +**文件:** `app/services/zep_entity_reader.py` + +### 核心功能 + +| 方法 | 说明 | +|------|------| +| `get_all_nodes(graph_id)` | 获取图谱所有节点 | +| `get_all_edges(graph_id)` | 获取图谱所有边 | +| `filter_defined_entities(graph_id, ...)` | 筛选符合条件的实体 | +| `get_entity_with_context(graph_id, uuid)` | 获取实体完整上下文 | +| `get_entities_by_type(graph_id, type)` | 按类型获取实体 | + +### 数据结构 + +```python +@dataclass +class EntityNode: + uuid: str # 节点UUID + name: str # 实体名称 + labels: List[str] # 标签列表 ["Entity", "Student"] + summary: str # 实体摘要 + attributes: Dict[str, Any] # 属性字典 + related_edges: List[Dict] # 相关边信息 + related_nodes: List[Dict] # 关联节点信息 + + def get_entity_type(self) -> Optional[str]: + """获取实体类型(排除默认Entity标签)""" + +@dataclass +class FilteredEntities: + entities: List[EntityNode] # 实体列表 + entity_types: Set[str] # 发现的实体类型 + total_count: int # 总节点数 + filtered_count: int # 过滤后数量 +``` + +### 过滤逻辑示例 + +```python +# Zep返回的节点Labels示例: +# 符合预定义类型: ["Entity", "Student"] +# 不符合预定义类型: ["Entity"] + +for node in all_nodes: + labels = node.get("labels", []) + custom_labels = [l for l in labels if l not in ["Entity", "Node"]] + + if not custom_labels: + # 只有默认标签,跳过 + continue + + # 保留符合条件的实体 + entity_type = custom_labels[0] + filtered_entities.append(node) +``` + +--- + +## 2. 
OasisProfileGenerator(Agent Profile生成器) + +**文件:** `app/services/oasis_profile_generator.py` + +### 核心功能 + +| 方法 | 说明 | +|------|------| +| `generate_profile_from_entity(entity, user_id)` | 从实体生成单个Profile | +| `generate_profiles_from_entities(entities)` | 批量生成Profile | +| `save_profiles_to_json(profiles, path, platform)` | 保存到JSON文件 | + +### Profile数据结构 + +```python +@dataclass +class OasisAgentProfile: + # 基础字段 + user_id: int # 用户ID + user_name: str # 用户名 + name: str # 显示名称 + bio: str # 简介(max 150字符) + persona: str # 详细人设描述 + + # Reddit字段 + karma: int = 1000 + + # Twitter字段 + friend_count: int = 100 + follower_count: int = 150 + statuses_count: int = 500 + + # 人设详情 + age: Optional[int] = None + gender: Optional[str] = None + mbti: Optional[str] = None # INTJ, ENFP等 + country: Optional[str] = None + profession: Optional[str] = None + interested_topics: List[str] = [] + + # 来源信息 + source_entity_uuid: Optional[str] = None + source_entity_type: Optional[str] = None +``` + +### Profile生成策略 + +**1. LLM生成(默认)** + +使用LLM根据实体信息生成详细人设: + +```python +prompt = f""" +Entity: {entity_name} ({entity_type}) +Summary: {entity_summary} +Context: {related_edges_and_nodes} + +Generate a social media user profile with: +- bio (max 150 chars) +- persona (detailed description) +- age, gender, mbti, country +- profession, interested_topics +""" +``` + +**2. 规则生成(Fallback)** + +根据实体类型使用预定义模板: + +| 实体类型 | 生成策略 | +|----------|----------| +| Student/Alumni | 年龄18-30,学生身份,关注教育话题 | +| PublicFigure/Expert | 年龄35-60,专业人士,政治经济话题 | +| MediaOutlet | 媒体官方账号,新闻时事话题 | +| University/GovernmentAgency | 机构官方账号,政策公告话题 | + +--- + +## 3. 
SimulationConfigGenerator(模拟配置智能生成器) + +**文件:** `app/services/simulation_config_generator.py` + +### 核心功能 + +使用LLM分析模拟需求、文档内容、图谱实体信息,自动生成最佳的模拟参数配置。 + +| 方法 | 说明 | +|------|------| +| `generate_config(...)` | 智能生成完整模拟配置 | +| `_build_context(...)` | 构建LLM上下文(最大5万字) | +| `_generate_config_with_llm(...)` | 调用LLM生成配置 | +| `_generate_default_config(...)` | 默认配置(LLM失败时) | + +### LLM智能生成的配置内容 + +**1. TimeSimulationConfig(时间配置)** +```python +@dataclass +class TimeSimulationConfig: + total_simulation_hours: int = 72 # 模拟总时长(小时) + minutes_per_round: int = 30 # 每轮代表的时间(分钟) + agents_per_hour_min: int = 5 # 每小时激活Agent数量(最小) + agents_per_hour_max: int = 20 # 每小时激活Agent数量(最大) + peak_hours: List[int] # 高峰时段 [9,10,11,14,15,20,21,22] + off_peak_hours: List[int] # 低谷时段 [0,1,2,3,4,5] + peak_activity_multiplier: float = 1.5 # 高峰活跃度乘数 + off_peak_activity_multiplier: float = 0.3 # 低谷活跃度乘数 +``` + +**2. AgentActivityConfig(每个Agent的活动配置)** +```python +@dataclass +class AgentActivityConfig: + agent_id: int + entity_uuid: str + entity_name: str + entity_type: str + + activity_level: float = 0.5 # 整体活跃度 (0.0-1.0) + posts_per_hour: float = 1.0 # 每小时发帖频率 + comments_per_hour: float = 2.0 # 每小时评论频率 + active_hours: List[int] # 活跃时间段 (0-23) + response_delay_min: int = 5 # 响应延迟最小值(分钟) + response_delay_max: int = 60 # 响应延迟最大值(分钟) + sentiment_bias: float = 0.0 # 情感倾向 (-1到1) + stance: str = "neutral" # 立场 (supportive/opposing/neutral/observer) + influence_weight: float = 1.0 # 影响力权重 +``` + +**3. 不同实体类型的默认参数差异** + +| 实体类型 | 活跃度 | 发帖频率 | 响应延迟 | 影响力 | +|----------|--------|----------|----------|--------| +| University/GovernmentAgency | 0.2 | 0.1/小时 | 60-240分钟 | 3.0 | +| MediaOutlet | 0.6 | 1.0/小时 | 5-30分钟 | 2.5 | +| PublicFigure/Expert | 0.5 | 0.3/小时 | 10-60分钟 | 2.0 | +| Student/Person | 0.7 | 0.5/小时 | 1-20分钟 | 1.0 | + +--- + +## 4. 
SimulationManager(模拟管理器) + +**文件:** `app/services/simulation_manager.py` + +### 核心功能 + +| 方法 | 说明 | +|------|------| +| `create_simulation(project_id, graph_id, ...)` | 创建模拟 | +| `prepare_simulation(simulation_id, ...)` | 准备模拟环境(调用配置生成器) | +| `get_simulation(simulation_id)` | 获取模拟状态 | +| `get_profiles(simulation_id, platform)` | 获取Profile | +| `get_simulation_config(simulation_id)` | 获取模拟配置 | +| `get_run_instructions(simulation_id)` | 获取运行说明 | + +### 模拟状态流转 + +``` +created → preparing → ready → running → completed + ↓ ↓ + failed paused +``` + +### 生成的文件结构 + +``` +uploads/simulations/sim_xxxx/ +├── state.json # 模拟状态 +├── simulation_config.json # LLM生成的模拟配置(核心文件) +├── reddit_profiles.json # Reddit Agent Profile(JSON格式) +├── twitter_profiles.csv # Twitter Agent Profile(CSV格式) +├── run_reddit_simulation.py # 预设Reddit模拟脚本 +├── run_twitter_simulation.py # 预设Twitter模拟脚本 +├── run_parallel_simulation.py # 预设双平台并行脚本 +├── reddit_simulation.db # Reddit数据库(运行后生成) +└── twitter_simulation.db # Twitter数据库(运行后生成) +``` + +**重要:OASIS平台的Profile格式要求不同:** +- **Twitter**: 使用CSV格式,字段:`user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at` +- **Reddit**: 使用JSON格式,支持详细人设字段:`realname,username,bio,persona,age,gender,mbti,country,profession,interested_topics` + +### 配置文件示例 (simulation_config.json) + +```json +{ + "simulation_id": "sim_abc123", + "project_id": "proj_xxx", + "graph_id": "mirofish_xxx", + "simulation_requirement": "分析武汉大学图书馆事件舆论传播", + + "time_config": { + "total_simulation_hours": 72, + "minutes_per_round": 30, + "agents_per_hour_min": 5, + "agents_per_hour_max": 15, + "peak_hours": [9, 10, 11, 14, 15, 20, 21, 22], + "off_peak_hours": [0, 1, 2, 3, 4, 5], + "peak_activity_multiplier": 1.5, + "off_peak_activity_multiplier": 0.3 + }, + + "agent_configs": [ + { + "agent_id": 0, + "entity_name": "武汉大学", + "entity_type": "University", + "activity_level": 0.15, + "posts_per_hour": 0.08, + "comments_per_hour": 0.02, + "active_hours": [9, 10, 11, 14, 15, 16, 
17], + "response_delay_min": 120, + "response_delay_max": 360, + "sentiment_bias": 0.1, + "stance": "neutral", + "influence_weight": 4.0 + }, + { + "agent_id": 1, + "entity_name": "杨景媛", + "entity_type": "Student", + "activity_level": 0.8, + "posts_per_hour": 0.5, + "comments_per_hour": 2.0, + "active_hours": [7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "response_delay_min": 1, + "response_delay_max": 15, + "sentiment_bias": -0.3, + "stance": "opposing", + "influence_weight": 1.5 + } + ], + + "event_config": { + "initial_posts": [ + { + "poster_agent_id": 1, + "content": "今天在图书馆发生的事情让我非常失望..." + } + ], + "hot_topics": ["图书馆事件", "学生权益", "校方回应"], + "narrative_direction": "事件发酵后各方反应的模拟" + }, + + "generation_reasoning": "根据武汉大学图书馆事件的特点:1)涉及学生与校方的冲突,设置学生高活跃度、校方低频但高影响力;2)事件性质属于短期热点,设置72小时模拟时长;3)主要当事人杨景媛设置为高活跃度且持opposing立场..." +} + +--- + +## 5. 预设模拟脚本 + +**目录:** `backend/scripts/` + +脚本是**预设的**,不是动态生成。每次准备模拟时,脚本会被复制到模拟目录。 + +### 脚本说明 + +| 脚本 | 说明 | +|------|------| +| `run_twitter_simulation.py` | Twitter单平台模拟 | +| `run_reddit_simulation.py` | Reddit单平台模拟 | +| `run_parallel_simulation.py` | 双平台并行模拟(推荐) | + +### 脚本工作原理 + +```python +# 脚本读取配置文件,自动设置所有参数 +class TwitterSimulationRunner: + def __init__(self, config_path: str): + self.config = self._load_config() # 读取simulation_config.json + + def _get_active_agents_for_round(self, env, current_hour, round_num): + """根据时间和配置决定本轮激活哪些Agent""" + time_config = self.config.get("time_config", {}) + agent_configs = self.config.get("agent_configs", []) + + # 1. 检查是否高峰/低谷时段,调整激活数量 + # 2. 遍历每个Agent配置,检查是否在活跃时间 + # 3. 根据activity_level计算激活概率 + # 4. 返回本轮应激活的Agent列表 + ... + + async def run(self): + # 1. 创建LLM模型 + # 2. 加载Agent图 + # 3. 执行初始事件(从event_config读取) + # 4. 主循环:根据配置激活不同Agent + ... 
+``` + +### 使用方式 + +```bash +# 进入模拟目录 +cd backend/uploads/simulations/sim_xxxx/ + +# 运行模拟 +python run_parallel_simulation.py --config simulation_config.json + +# 其他选项 +python run_parallel_simulation.py --config simulation_config.json --twitter-only +python run_parallel_simulation.py --config simulation_config.json --reddit-only +``` + +--- + +## 6. Profile文件格式说明 + +**OASIS对两个平台的Profile格式有不同要求:** + +### Twitter Profile (CSV格式) + +```csv +user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at +0,user0,User Zero,I am user zero with interests in technology.,100,150,500,2023-01-01 +1,user1,User One,Tech enthusiast and coffee lover.,200,250,1000,2023-01-02 +``` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `user_id` | int | 用户ID | +| `user_name` | string | 用户名 | +| `name` | string | 显示名称 | +| `bio` | string | 简介 | +| `friend_count` | int | 关注数 | +| `follower_count` | int | 粉丝数 | +| `statuses_count` | int | 发帖数 | +| `created_at` | string | 创建日期 | + +### Reddit Profile (JSON详细格式) + +```json +[ + { + "realname": "Test User", + "username": "test_user_123", + "bio": "A test user for validation", + "persona": "Test User is an enthusiastic participant in social discussions.", + "age": 25, + "gender": "male", + "mbti": "INTJ", + "country": "China", + "profession": "Student", + "interested_topics": ["Technology", "Education"] + } +] +``` + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `realname` | string | 是 | 真实姓名 | +| `username` | string | 是 | 用户名 | +| `bio` | string | 是 | 简介(最大150字符) | +| `persona` | string | 是 | 详细人设描述 | +| `age` | int | 否 | 年龄 | +| `gender` | string | 否 | 性别 | +| `mbti` | string | 否 | MBTI人格类型 | +| `country` | string | 否 | 国家 | +| `profession` | string | 否 | 职业 | +| `interested_topics` | array | 否 | 感兴趣话题列表 | + +--- + +## 7. 
OASIS平台动作类型 + +### Twitter可用动作 + +| 动作 | 说明 | +|------|------| +| `CREATE_POST` | 发布推文 | +| `LIKE_POST` | 点赞推文 | +| `REPOST` | 转发推文 | +| `FOLLOW` | 关注用户 | +| `QUOTE_POST` | 引用转发 | +| `DO_NOTHING` | 不执行动作 | + +### Reddit可用动作 + +| 动作 | 说明 | +|------|------| +| `CREATE_POST` | 发布帖子 | +| `CREATE_COMMENT` | 发表评论 | +| `LIKE_POST` | 点赞帖子 | +| `DISLIKE_POST` | 踩帖子 | +| `LIKE_COMMENT` | 点赞评论 | +| `DISLIKE_COMMENT` | 踩评论 | +| `SEARCH_POSTS` | 搜索帖子 | +| `SEARCH_USER` | 搜索用户 | +| `TREND` | 查看热门 | +| `REFRESH` | 刷新推荐 | +| `FOLLOW` | 关注用户 | +| `MUTE` | 屏蔽用户 | +| `DO_NOTHING` | 不执行动作 | + +--- + +# 实体设计原则 本系统专为社会舆论模拟设计,实体必须是: -**✅ 可以是:** +**可以是:** - 具体的个人(有名有姓) - 注册的公司、组织、机构 - 媒体机构 - 政府部门 +- 高校、NGO等 -**❌ 不可以是:** +**不可以是:** - 抽象概念(如"技术"、"创新") - 情绪、观点、趋势 - 泛指的群体(如"用户"、"消费者") @@ -230,10 +1465,127 @@ python run.py --- -## 项目状态流转 +# 项目状态流转 ``` created → ontology_generated → graph_building → graph_completed ↓ failed ``` + +--- + +# 运行模拟 + +准备完成后,进入模拟数据目录运行预设脚本: + +```bash +# 激活conda环境 +conda activate MiroFish + +# 进入模拟目录 +cd backend/uploads/simulations/sim_xxxx/ + +# 运行单平台模拟 +python run_reddit_simulation.py --config simulation_config.json +# 或 +python run_twitter_simulation.py --config simulation_config.json + +# 运行双平台并行模拟(推荐) +python run_parallel_simulation.py --config simulation_config.json +``` + +### 脚本参数 + +| 参数 | 说明 | +|------|------| +| `--config` | 配置文件路径(必填) | +| `--twitter-only` | 只运行Twitter模拟(仅parallel脚本) | +| `--reddit-only` | 只运行Reddit模拟(仅parallel脚本) | + +### 输出文件 + +模拟运行后会生成: +- `twitter_simulation.db` - Twitter模拟数据库 +- `reddit_simulation.db` - Reddit模拟数据库 + +可使用SQLite工具查看模拟结果(帖子、评论、点赞等) + +--- + +# API调用重试机制 + +**文件:** `app/utils/retry.py` + +为LLM等外部API调用提供自动重试功能,提高系统稳定性。 + +## 重试策略 + +- **最大重试次数**:3次 +- **退避策略**:指数退避(1s → 2s → 4s) +- **最大延迟**:30秒 +- **随机抖动**:避免请求堆积 + +## 使用方式 + +**装饰器方式:** +```python +from app.utils.retry import retry_with_backoff + +@retry_with_backoff(max_retries=3) +def call_llm_api(): + return client.chat.completions.create(...) 
+``` + +**客户端方式:** +```python +from app.utils.retry import RetryableAPIClient + +retry_client = RetryableAPIClient(max_retries=3) +result = retry_client.call_with_retry(some_function, arg1, arg2) +``` + +**批量处理(单项失败不影响其他):** +```python +results, failures = retry_client.call_batch_with_retry( + items=entities, + process_func=generate_profile, + continue_on_failure=True +) +``` + +## 已应用重试机制的模块 + +| 模块 | 说明 | +|------|------| +| `OasisProfileGenerator` | LLM生成Agent人设 | +| `SimulationConfigGenerator` | LLM生成模拟配置 | + +--- + +# 依赖说明 + +``` +# Flask框架 +flask>=3.0.0 +flask-cors>=4.0.0 + +# Zep Cloud SDK +zep-cloud>=2.0.0 + +# OpenAI SDK(LLM调用) +openai>=1.0.0 + +# PDF处理 +PyMuPDF>=1.24.0 + +# 环境变量 +python-dotenv>=1.0.0 + +# 数据验证 +pydantic>=2.0.0 + +# OASIS社交媒体模拟 +oasis-ai>=0.1.0 +camel-ai>=0.2.0 +``` diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 7ad7c69..b11dc70 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -46,8 +46,9 @@ def create_app(config_class=Config): return response # 注册蓝图 - from .api import graph_bp + from .api import graph_bp, simulation_bp app.register_blueprint(graph_bp, url_prefix='/api/graph') + app.register_blueprint(simulation_bp, url_prefix='/api/simulation') # 健康检查 @app.route('/health') diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index 9723f92..ad7a722 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -5,6 +5,8 @@ API路由模块 from flask import Blueprint graph_bp = Blueprint('graph', __name__) +simulation_bp = Blueprint('simulation', __name__) from . import graph # noqa: E402, F401 +from . import simulation # noqa: E402, F401 diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py new file mode 100644 index 0000000..6d8651b --- /dev/null +++ b/backend/app/api/simulation.py @@ -0,0 +1,1330 @@ +""" +模拟相关API路由 +Step2: Zep实体读取与过滤、OASIS模拟准备与运行(全程自动化) +""" + +import os +import traceback +from flask import request, jsonify, send_file + +from . 
import simulation_bp +from ..config import Config +from ..services.zep_entity_reader import ZepEntityReader +from ..services.oasis_profile_generator import OasisProfileGenerator +from ..services.simulation_manager import SimulationManager, SimulationStatus +from ..services.simulation_runner import SimulationRunner, RunnerStatus +from ..utils.logger import get_logger +from ..models.project import ProjectManager + +logger = get_logger('mirofish.api.simulation') + + +# ============== 实体读取接口 ============== + +@simulation_bp.route('/entities/', methods=['GET']) +def get_graph_entities(graph_id: str): + """ + 获取图谱中的所有实体(已过滤) + + 只返回符合预定义实体类型的节点(Labels不只是Entity的节点) + + Query参数: + entity_types: 逗号分隔的实体类型列表(可选,用于进一步过滤) + enrich: 是否获取相关边信息(默认true) + """ + try: + if not Config.ZEP_API_KEY: + return jsonify({ + "success": False, + "error": "ZEP_API_KEY未配置" + }), 500 + + entity_types_str = request.args.get('entity_types', '') + entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None + enrich = request.args.get('enrich', 'true').lower() == 'true' + + logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") + + reader = ZepEntityReader() + result = reader.filter_defined_entities( + graph_id=graph_id, + defined_entity_types=entity_types, + enrich_with_edges=enrich + ) + + return jsonify({ + "success": True, + "data": result.to_dict() + }) + + except Exception as e: + logger.error(f"获取图谱实体失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/entities//', methods=['GET']) +def get_entity_detail(graph_id: str, entity_uuid: str): + """获取单个实体的详细信息""" + try: + if not Config.ZEP_API_KEY: + return jsonify({ + "success": False, + "error": "ZEP_API_KEY未配置" + }), 500 + + reader = ZepEntityReader() + entity = reader.get_entity_with_context(graph_id, entity_uuid) + + if not entity: + return jsonify({ + "success": 
False, + "error": f"实体不存在: {entity_uuid}" + }), 404 + + return jsonify({ + "success": True, + "data": entity.to_dict() + }) + + except Exception as e: + logger.error(f"获取实体详情失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/entities//by-type/', methods=['GET']) +def get_entities_by_type(graph_id: str, entity_type: str): + """获取指定类型的所有实体""" + try: + if not Config.ZEP_API_KEY: + return jsonify({ + "success": False, + "error": "ZEP_API_KEY未配置" + }), 500 + + enrich = request.args.get('enrich', 'true').lower() == 'true' + + reader = ZepEntityReader() + entities = reader.get_entities_by_type( + graph_id=graph_id, + entity_type=entity_type, + enrich_with_edges=enrich + ) + + return jsonify({ + "success": True, + "data": { + "entity_type": entity_type, + "count": len(entities), + "entities": [e.to_dict() for e in entities] + } + }) + + except Exception as e: + logger.error(f"获取实体失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +# ============== 模拟管理接口 ============== + +@simulation_bp.route('/create', methods=['POST']) +def create_simulation(): + """ + 创建新的模拟 + + 注意:max_rounds等参数由LLM智能生成,无需手动设置 + + 请求(JSON): + { + "project_id": "proj_xxxx", // 必填 + "graph_id": "mirofish_xxxx", // 可选,如不提供则从project获取 + "enable_twitter": true, // 可选,默认true + "enable_reddit": true // 可选,默认true + } + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "project_id": "proj_xxxx", + "graph_id": "mirofish_xxxx", + "status": "created", + "enable_twitter": true, + "enable_reddit": true, + "created_at": "2025-12-01T10:00:00" + } + } + """ + try: + data = request.get_json() or {} + + project_id = data.get('project_id') + if not project_id: + return jsonify({ + "success": False, + "error": "请提供 project_id" + }), 400 + + project = ProjectManager.get_project(project_id) + if not project: + return jsonify({ + 
"success": False, + "error": f"项目不存在: {project_id}" + }), 404 + + graph_id = data.get('graph_id') or project.graph_id + if not graph_id: + return jsonify({ + "success": False, + "error": "项目尚未构建图谱,请先调用 /api/graph/build" + }), 400 + + manager = SimulationManager() + state = manager.create_simulation( + project_id=project_id, + graph_id=graph_id, + enable_twitter=data.get('enable_twitter', True), + enable_reddit=data.get('enable_reddit', True), + ) + + return jsonify({ + "success": True, + "data": state.to_dict() + }) + + except Exception as e: + logger.error(f"创建模拟失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/prepare', methods=['POST']) +def prepare_simulation(): + """ + 准备模拟环境(异步任务,LLM智能生成所有参数) + + 这是一个耗时操作,接口会立即返回task_id, + 使用 GET /api/simulation/prepare/status 查询进度 + + 步骤: + 1. 从Zep图谱读取并过滤实体 + 2. 为每个实体生成OASIS Agent Profile(带重试机制) + 3. LLM智能生成模拟配置(带重试机制) + 4. 保存配置文件和预设脚本 + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "entity_types": ["Student", "PublicFigure"], // 可选,指定实体类型 + "use_llm_for_profiles": true // 可选,是否用LLM生成人设 + } + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "task_id": "task_xxxx", + "status": "preparing", + "message": "准备任务已启动" + } + } + """ + import threading + from ..models.task import TaskManager, TaskStatus + + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + manager = SimulationManager() + state = manager.get_simulation(simulation_id) + + if not state: + return jsonify({ + "success": False, + "error": f"模拟不存在: {simulation_id}" + }), 404 + + # 从项目获取必要信息 + project = ProjectManager.get_project(state.project_id) + if not project: + return jsonify({ + "success": False, + "error": f"项目不存在: {state.project_id}" + }), 404 + + # 获取模拟需求 + simulation_requirement = 
project.simulation_requirement or "" + if not simulation_requirement: + return jsonify({ + "success": False, + "error": "项目缺少模拟需求描述 (simulation_requirement)" + }), 400 + + # 获取文档文本 + document_text = ProjectManager.get_extracted_text(state.project_id) or "" + + entity_types_list = data.get('entity_types') + use_llm_for_profiles = data.get('use_llm_for_profiles', True) + + # 创建异步任务 + task_manager = TaskManager() + task_id = task_manager.create_task( + task_type="simulation_prepare", + metadata={ + "simulation_id": simulation_id, + "project_id": state.project_id + } + ) + + # 更新模拟状态 + state.status = SimulationStatus.PREPARING + manager._save_simulation_state(state) + + # 定义后台任务 + def run_prepare(): + try: + task_manager.update_task( + task_id, + status=TaskStatus.PROCESSING, + progress=0, + message="开始准备模拟环境..." + ) + + # 准备模拟(带进度回调) + # 存储阶段进度详情 + stage_details = {} + + def progress_callback(stage, progress, message, **kwargs): + # 计算总进度 + stage_weights = { + "reading": (0, 20), # 0-20% + "generating_profiles": (20, 70), # 20-70% + "generating_config": (70, 90), # 70-90% + "copying_scripts": (90, 100) # 90-100% + } + + start, end = stage_weights.get(stage, (0, 100)) + current_progress = int(start + (end - start) * progress / 100) + + # 构建详细进度信息 + stage_names = { + "reading": "读取图谱实体", + "generating_profiles": "生成Agent人设", + "generating_config": "生成模拟配置", + "copying_scripts": "准备模拟脚本" + } + + stage_index = list(stage_weights.keys()).index(stage) + 1 if stage in stage_weights else 1 + total_stages = len(stage_weights) + + # 更新阶段详情 + stage_details[stage] = { + "stage_name": stage_names.get(stage, stage), + "stage_progress": progress, + "current": kwargs.get("current", 0), + "total": kwargs.get("total", 0), + "item_name": kwargs.get("item_name", "") + } + + # 构建详细进度信息 + detail = stage_details[stage] + progress_detail_data = { + "current_stage": stage, + "current_stage_name": stage_names.get(stage, stage), + "stage_index": stage_index, + "total_stages": total_stages, + 
"stage_progress": progress, + "current_item": detail["current"], + "total_items": detail["total"], + "item_description": message + } + + # 构建简洁消息 + if detail["total"] > 0: + detailed_message = ( + f"[{stage_index}/{total_stages}] {stage_names.get(stage, stage)}: " + f"{detail['current']}/{detail['total']} - {message}" + ) + else: + detailed_message = f"[{stage_index}/{total_stages}] {stage_names.get(stage, stage)}: {message}" + + task_manager.update_task( + task_id, + progress=current_progress, + message=detailed_message, + progress_detail=progress_detail_data + ) + + result_state = manager.prepare_simulation( + simulation_id=simulation_id, + simulation_requirement=simulation_requirement, + document_text=document_text, + defined_entity_types=entity_types_list, + use_llm_for_profiles=use_llm_for_profiles, + progress_callback=progress_callback + ) + + # 任务完成 + task_manager.complete_task( + task_id, + result=result_state.to_simple_dict() + ) + + except Exception as e: + logger.error(f"准备模拟失败: {str(e)}") + task_manager.fail_task(task_id, str(e)) + + # 更新模拟状态为失败 + state = manager.get_simulation(simulation_id) + if state: + state.status = SimulationStatus.FAILED + state.error = str(e) + manager._save_simulation_state(state) + + # 启动后台线程 + thread = threading.Thread(target=run_prepare, daemon=True) + thread.start() + + return jsonify({ + "success": True, + "data": { + "simulation_id": simulation_id, + "task_id": task_id, + "status": "preparing", + "message": "准备任务已启动,请通过 /api/simulation/prepare/status 查询进度" + } + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 404 + + except Exception as e: + logger.error(f"启动准备任务失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/prepare/status', methods=['POST']) +def get_prepare_status(): + """ + 查询准备任务进度 + + 请求(JSON): + { + "task_id": "task_xxxx" // 必填,prepare返回的task_id + } + + 返回: + { + 
"success": true, + "data": { + "task_id": "task_xxxx", + "status": "processing", // pending/processing/completed/failed + "progress": 45, // 0-100 总进度 + "message": "[2/4] 生成Agent人设: 35/93 - 生成 教授张三 的人设...", + "progress_detail": { // 详细进度信息 + "current_stage": "generating_profiles", + "current_stage_name": "生成Agent人设", + "stage_index": 2, // 当前阶段序号 + "total_stages": 4, // 总阶段数 + "stage_progress": 38, // 阶段内进度 0-100 + "current_item": 35, // 当前处理项目序号 + "total_items": 93, // 当前阶段总项目数 + "item_description": "生成 教授张三 的人设..." + }, + "result": null, // 完成后返回结果 + "error": null // 失败时返回错误信息 + } + } + """ + from ..models.task import TaskManager + + try: + data = request.get_json() or {} + + task_id = data.get('task_id') + if not task_id: + return jsonify({ + "success": False, + "error": "请提供 task_id" + }), 400 + + task_manager = TaskManager() + task = task_manager.get_task(task_id) + + if not task: + return jsonify({ + "success": False, + "error": f"任务不存在: {task_id}" + }), 404 + + return jsonify({ + "success": True, + "data": task.to_dict() + }) + + except Exception as e: + logger.error(f"查询任务状态失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e) + }), 500 + + +@simulation_bp.route('/', methods=['GET']) +def get_simulation(simulation_id: str): + """获取模拟状态""" + try: + manager = SimulationManager() + state = manager.get_simulation(simulation_id) + + if not state: + return jsonify({ + "success": False, + "error": f"模拟不存在: {simulation_id}" + }), 404 + + result = state.to_dict() + + # 如果模拟已准备好,附加运行说明 + if state.status == SimulationStatus.READY: + result["run_instructions"] = manager.get_run_instructions(simulation_id) + + return jsonify({ + "success": True, + "data": result + }) + + except Exception as e: + logger.error(f"获取模拟状态失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/list', methods=['GET']) +def list_simulations(): + """ + 列出所有模拟 + + Query参数: + project_id: 
按项目ID过滤(可选) + """ + try: + project_id = request.args.get('project_id') + + manager = SimulationManager() + simulations = manager.list_simulations(project_id=project_id) + + return jsonify({ + "success": True, + "data": [s.to_dict() for s in simulations], + "count": len(simulations) + }) + + except Exception as e: + logger.error(f"列出模拟失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//profiles', methods=['GET']) +def get_simulation_profiles(simulation_id: str): + """ + 获取模拟的Agent Profile + + Query参数: + platform: 平台类型(reddit/twitter,默认reddit) + """ + try: + platform = request.args.get('platform', 'reddit') + + manager = SimulationManager() + profiles = manager.get_profiles(simulation_id, platform=platform) + + return jsonify({ + "success": True, + "data": { + "platform": platform, + "count": len(profiles), + "profiles": profiles + } + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 404 + + except Exception as e: + logger.error(f"获取Profile失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//config', methods=['GET']) +def get_simulation_config(simulation_id: str): + """ + 获取模拟配置(LLM智能生成的完整配置) + + 返回包含: + - time_config: 时间配置(模拟时长、轮次、高峰/低谷时段) + - agent_configs: 每个Agent的活动配置(活跃度、发言频率、立场等) + - event_config: 事件配置(初始帖子、热点话题) + - platform_configs: 平台配置 + - generation_reasoning: LLM的配置推理说明 + """ + try: + manager = SimulationManager() + config = manager.get_simulation_config(simulation_id) + + if not config: + return jsonify({ + "success": False, + "error": f"模拟配置不存在,请先调用 /prepare 接口" + }), 404 + + return jsonify({ + "success": True, + "data": config + }) + + except Exception as e: + logger.error(f"获取配置失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + 
+@simulation_bp.route('//config/download', methods=['GET']) +def download_simulation_config(simulation_id: str): + """下载模拟配置文件""" + try: + manager = SimulationManager() + sim_dir = manager._get_simulation_dir(simulation_id) + config_path = os.path.join(sim_dir, "simulation_config.json") + + if not os.path.exists(config_path): + return jsonify({ + "success": False, + "error": "配置文件不存在,请先调用 /prepare 接口" + }), 404 + + return send_file( + config_path, + as_attachment=True, + download_name="simulation_config.json" + ) + + except Exception as e: + logger.error(f"下载配置失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//script//download', methods=['GET']) +def download_simulation_script(simulation_id: str, script_name: str): + """ + 下载模拟脚本文件 + + script_name可选值: + - run_twitter_simulation.py + - run_reddit_simulation.py + - run_parallel_simulation.py + """ + try: + manager = SimulationManager() + sim_dir = manager._get_simulation_dir(simulation_id) + + # 验证脚本名称 + allowed_scripts = [ + "run_twitter_simulation.py", + "run_reddit_simulation.py", + "run_parallel_simulation.py" + ] + + if script_name not in allowed_scripts: + return jsonify({ + "success": False, + "error": f"未知脚本: {script_name},可选: {allowed_scripts}" + }), 400 + + script_path = os.path.join(sim_dir, script_name) + + if not os.path.exists(script_path): + return jsonify({ + "success": False, + "error": "脚本文件不存在,请先调用 /prepare 接口" + }), 404 + + return send_file( + script_path, + as_attachment=True, + download_name=script_name + ) + + except Exception as e: + logger.error(f"下载脚本失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +# ============== Profile生成接口(独立使用) ============== + +@simulation_bp.route('/generate-profiles', methods=['POST']) +def generate_profiles(): + """ + 直接从图谱生成OASIS Agent Profile(不创建模拟) + + 请求(JSON): + { + "graph_id": 
"mirofish_xxxx", // 必填 + "entity_types": ["Student"], // 可选 + "use_llm": true, // 可选 + "platform": "reddit" // 可选 + } + """ + try: + data = request.get_json() or {} + + graph_id = data.get('graph_id') + if not graph_id: + return jsonify({ + "success": False, + "error": "请提供 graph_id" + }), 400 + + entity_types = data.get('entity_types') + use_llm = data.get('use_llm', True) + platform = data.get('platform', 'reddit') + + reader = ZepEntityReader() + filtered = reader.filter_defined_entities( + graph_id=graph_id, + defined_entity_types=entity_types, + enrich_with_edges=True + ) + + if filtered.filtered_count == 0: + return jsonify({ + "success": False, + "error": "没有找到符合条件的实体" + }), 400 + + generator = OasisProfileGenerator() + profiles = generator.generate_profiles_from_entities( + entities=filtered.entities, + use_llm=use_llm + ) + + if platform == "reddit": + profiles_data = [p.to_reddit_format() for p in profiles] + elif platform == "twitter": + profiles_data = [p.to_twitter_format() for p in profiles] + else: + profiles_data = [p.to_dict() for p in profiles] + + return jsonify({ + "success": True, + "data": { + "platform": platform, + "entity_types": list(filtered.entity_types), + "count": len(profiles_data), + "profiles": profiles_data + } + }) + + except Exception as e: + logger.error(f"生成Profile失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +# ============== 模拟运行控制接口 ============== + +@simulation_bp.route('/start', methods=['POST']) +def start_simulation(): + """ + 开始运行模拟 + + 请求(JSON): + { + "simulation_id": "sim_xxxx", // 必填,模拟ID + "platform": "parallel" // 可选: twitter / reddit / parallel (默认) + } + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "running", + "process_pid": 12345, + "twitter_running": true, + "reddit_running": true, + "started_at": "2025-12-01T10:00:00" + } + } + """ + try: + data = request.get_json() or {} + + 
simulation_id = data.get('simulation_id') + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + platform = data.get('platform', 'parallel') + + if platform not in ['twitter', 'reddit', 'parallel']: + return jsonify({ + "success": False, + "error": f"无效的平台类型: {platform},可选: twitter/reddit/parallel" + }), 400 + + # 检查模拟是否已准备好 + manager = SimulationManager() + state = manager.get_simulation(simulation_id) + + if not state: + return jsonify({ + "success": False, + "error": f"模拟不存在: {simulation_id}" + }), 404 + + if state.status != SimulationStatus.READY: + return jsonify({ + "success": False, + "error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口" + }), 400 + + # 启动模拟 + run_state = SimulationRunner.start_simulation(simulation_id, platform) + + # 更新模拟状态 + state.status = SimulationStatus.RUNNING + manager._save_simulation_state(state) + + return jsonify({ + "success": True, + "data": run_state.to_dict() + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 400 + + except Exception as e: + logger.error(f"启动模拟失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('/stop', methods=['POST']) +def stop_simulation(): + """ + 停止模拟 + + 请求(JSON): + { + "simulation_id": "sim_xxxx" // 必填,模拟ID + } + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "stopped", + "completed_at": "2025-12-01T12:00:00" + } + } + """ + try: + data = request.get_json() or {} + + simulation_id = data.get('simulation_id') + if not simulation_id: + return jsonify({ + "success": False, + "error": "请提供 simulation_id" + }), 400 + + run_state = SimulationRunner.stop_simulation(simulation_id) + + # 更新模拟状态 + manager = SimulationManager() + state = manager.get_simulation(simulation_id) + if state: + state.status = SimulationStatus.PAUSED + manager._save_simulation_state(state) + + 
return jsonify({ + "success": True, + "data": run_state.to_dict() + }) + + except ValueError as e: + return jsonify({ + "success": False, + "error": str(e) + }), 400 + + except Exception as e: + logger.error(f"停止模拟失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +# ============== 实时状态监控接口 ============== + +@simulation_bp.route('//run-status', methods=['GET']) +def get_run_status(simulation_id: str): + """ + 获取模拟运行实时状态(用于前端轮询) + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "running", + "current_round": 5, + "total_rounds": 144, + "progress_percent": 3.5, + "simulated_hours": 2, + "total_simulation_hours": 72, + "twitter_running": true, + "reddit_running": true, + "twitter_actions_count": 150, + "reddit_actions_count": 200, + "total_actions_count": 350, + "started_at": "2025-12-01T10:00:00", + "updated_at": "2025-12-01T10:30:00" + } + } + """ + try: + run_state = SimulationRunner.get_run_state(simulation_id) + + if not run_state: + return jsonify({ + "success": True, + "data": { + "simulation_id": simulation_id, + "runner_status": "idle", + "current_round": 0, + "total_rounds": 0, + "progress_percent": 0, + "twitter_actions_count": 0, + "reddit_actions_count": 0, + "total_actions_count": 0, + } + }) + + return jsonify({ + "success": True, + "data": run_state.to_dict() + }) + + except Exception as e: + logger.error(f"获取运行状态失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//run-status/detail', methods=['GET']) +def get_run_status_detail(simulation_id: str): + """ + 获取模拟运行详细状态(包含最近动作) + + 用于前端展示实时动态 + + 返回: + { + "success": true, + "data": { + "simulation_id": "sim_xxxx", + "runner_status": "running", + "current_round": 5, + ... 
+ "recent_actions": [ + { + "round_num": 5, + "timestamp": "2025-12-01T10:30:00", + "platform": "twitter", + "agent_id": 3, + "agent_name": "Agent Name", + "action_type": "CREATE_POST", + "action_args": {"content": "..."}, + "result": null, + "success": true + }, + ... + ] + } + } + """ + try: + run_state = SimulationRunner.get_run_state(simulation_id) + + if not run_state: + return jsonify({ + "success": True, + "data": { + "simulation_id": simulation_id, + "runner_status": "idle", + "recent_actions": [] + } + }) + + return jsonify({ + "success": True, + "data": run_state.to_detail_dict() + }) + + except Exception as e: + logger.error(f"获取详细状态失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//actions', methods=['GET']) +def get_simulation_actions(simulation_id: str): + """ + 获取模拟中的Agent动作历史 + + Query参数: + limit: 返回数量(默认100) + offset: 偏移量(默认0) + platform: 过滤平台(twitter/reddit) + agent_id: 过滤Agent ID + round_num: 过滤轮次 + + 返回: + { + "success": true, + "data": { + "count": 100, + "actions": [...] 
+ } + } + """ + try: + limit = request.args.get('limit', 100, type=int) + offset = request.args.get('offset', 0, type=int) + platform = request.args.get('platform') + agent_id = request.args.get('agent_id', type=int) + round_num = request.args.get('round_num', type=int) + + actions = SimulationRunner.get_actions( + simulation_id=simulation_id, + limit=limit, + offset=offset, + platform=platform, + agent_id=agent_id, + round_num=round_num + ) + + return jsonify({ + "success": True, + "data": { + "count": len(actions), + "actions": [a.to_dict() for a in actions] + } + }) + + except Exception as e: + logger.error(f"获取动作历史失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//timeline', methods=['GET']) +def get_simulation_timeline(simulation_id: str): + """ + 获取模拟时间线(按轮次汇总) + + 用于前端展示进度条和时间线视图 + + Query参数: + start_round: 起始轮次(默认0) + end_round: 结束轮次(默认全部) + + 返回每轮的汇总信息 + """ + try: + start_round = request.args.get('start_round', 0, type=int) + end_round = request.args.get('end_round', type=int) + + timeline = SimulationRunner.get_timeline( + simulation_id=simulation_id, + start_round=start_round, + end_round=end_round + ) + + return jsonify({ + "success": True, + "data": { + "rounds_count": len(timeline), + "timeline": timeline + } + }) + + except Exception as e: + logger.error(f"获取时间线失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//agent-stats', methods=['GET']) +def get_agent_stats(simulation_id: str): + """ + 获取每个Agent的统计信息 + + 用于前端展示Agent活跃度排行、动作分布等 + """ + try: + stats = SimulationRunner.get_agent_stats(simulation_id) + + return jsonify({ + "success": True, + "data": { + "agents_count": len(stats), + "stats": stats + } + }) + + except Exception as e: + logger.error(f"获取Agent统计失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": 
traceback.format_exc() + }), 500 + + +# ============== 数据库查询接口 ============== + +@simulation_bp.route('//posts', methods=['GET']) +def get_simulation_posts(simulation_id: str): + """ + 获取模拟中的帖子 + + Query参数: + platform: 平台类型(twitter/reddit) + limit: 返回数量(默认50) + offset: 偏移量 + + 返回帖子列表(从SQLite数据库读取) + """ + try: + platform = request.args.get('platform', 'reddit') + limit = request.args.get('limit', 50, type=int) + offset = request.args.get('offset', 0, type=int) + + sim_dir = os.path.join( + os.path.dirname(__file__), + f'../../uploads/simulations/{simulation_id}' + ) + + db_file = f"{platform}_simulation.db" + db_path = os.path.join(sim_dir, db_file) + + if not os.path.exists(db_path): + return jsonify({ + "success": True, + "data": { + "platform": platform, + "count": 0, + "posts": [], + "message": "数据库不存在,模拟可能尚未运行" + } + }) + + import sqlite3 + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + try: + cursor.execute(""" + SELECT * FROM post + ORDER BY created_at DESC + LIMIT ? OFFSET ? 
+ """, (limit, offset)) + + posts = [dict(row) for row in cursor.fetchall()] + + cursor.execute("SELECT COUNT(*) FROM post") + total = cursor.fetchone()[0] + + except sqlite3.OperationalError: + posts = [] + total = 0 + + conn.close() + + return jsonify({ + "success": True, + "data": { + "platform": platform, + "total": total, + "count": len(posts), + "posts": posts + } + }) + + except Exception as e: + logger.error(f"获取帖子失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + +@simulation_bp.route('//comments', methods=['GET']) +def get_simulation_comments(simulation_id: str): + """ + 获取模拟中的评论(仅Reddit) + + Query参数: + post_id: 过滤帖子ID(可选) + limit: 返回数量 + offset: 偏移量 + """ + try: + post_id = request.args.get('post_id') + limit = request.args.get('limit', 50, type=int) + offset = request.args.get('offset', 0, type=int) + + sim_dir = os.path.join( + os.path.dirname(__file__), + f'../../uploads/simulations/{simulation_id}' + ) + + db_path = os.path.join(sim_dir, "reddit_simulation.db") + + if not os.path.exists(db_path): + return jsonify({ + "success": True, + "data": { + "count": 0, + "comments": [] + } + }) + + import sqlite3 + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + try: + if post_id: + cursor.execute(""" + SELECT * FROM comment + WHERE post_id = ? + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """, (post_id, limit, offset)) + else: + cursor.execute(""" + SELECT * FROM comment + ORDER BY created_at DESC + LIMIT ? OFFSET ? 
+ """, (limit, offset)) + + comments = [dict(row) for row in cursor.fetchall()] + + except sqlite3.OperationalError: + comments = [] + + conn.close() + + return jsonify({ + "success": True, + "data": { + "count": len(comments), + "comments": comments + } + }) + + except Exception as e: + logger.error(f"获取评论失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 diff --git a/backend/app/config.py b/backend/app/config.py index 32ad513..a23556c 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -41,6 +41,20 @@ class Config: DEFAULT_CHUNK_SIZE = 500 # 默认切块大小 DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小 + # OASIS模拟配置 + OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) + OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations') + + # OASIS平台可用动作配置 + OASIS_TWITTER_ACTIONS = [ + 'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST' + ] + OASIS_REDDIT_ACTIONS = [ + 'LIKE_POST', 'DISLIKE_POST', 'CREATE_POST', 'CREATE_COMMENT', + 'LIKE_COMMENT', 'DISLIKE_COMMENT', 'SEARCH_POSTS', 'SEARCH_USER', + 'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE' + ] + @classmethod def validate(cls): """验证必要配置""" diff --git a/backend/app/models/task.py b/backend/app/models/task.py index 2741f11..e15f35f 100644 --- a/backend/app/models/task.py +++ b/backend/app/models/task.py @@ -27,11 +27,12 @@ class Task: status: TaskStatus created_at: datetime updated_at: datetime - progress: int = 0 # 进度百分比 0-100 + progress: int = 0 # 总进度百分比 0-100 message: str = "" # 状态消息 result: Optional[Dict] = None # 任务结果 error: Optional[str] = None # 错误信息 metadata: Dict = field(default_factory=dict) # 额外元数据 + progress_detail: Dict = field(default_factory=dict) # 详细进度信息 def to_dict(self) -> Dict[str, Any]: """转换为字典""" @@ -43,6 +44,7 @@ class Task: "updated_at": self.updated_at.isoformat(), "progress": self.progress, "message": self.message, + "progress_detail": 
self.progress_detail, "result": self.result, "error": self.error, "metadata": self.metadata, @@ -108,7 +110,8 @@ class TaskManager: progress: Optional[int] = None, message: Optional[str] = None, result: Optional[Dict] = None, - error: Optional[str] = None + error: Optional[str] = None, + progress_detail: Optional[Dict] = None ): """ 更新任务状态 @@ -120,6 +123,7 @@ class TaskManager: message: 消息 result: 结果 error: 错误信息 + progress_detail: 详细进度信息 """ with self._task_lock: task = self._tasks.get(task_id) @@ -135,6 +139,8 @@ class TaskManager: task.result = result if error is not None: task.error = error + if progress_detail is not None: + task.progress_detail = progress_detail def complete_task(self, task_id: str, result: Dict): """标记任务完成""" diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index 42cdc5b..b8f2e67 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -5,6 +5,47 @@ from .ontology_generator import OntologyGenerator from .graph_builder import GraphBuilderService from .text_processor import TextProcessor +from .zep_entity_reader import ZepEntityReader, EntityNode, FilteredEntities +from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile +from .simulation_manager import SimulationManager, SimulationState, SimulationStatus +from .simulation_config_generator import ( + SimulationConfigGenerator, + SimulationParameters, + AgentActivityConfig, + TimeSimulationConfig, + EventConfig, + PlatformConfig +) +from .simulation_runner import ( + SimulationRunner, + SimulationRunState, + RunnerStatus, + AgentAction, + RoundSummary +) -__all__ = ['OntologyGenerator', 'GraphBuilderService', 'TextProcessor'] +__all__ = [ + 'OntologyGenerator', + 'GraphBuilderService', + 'TextProcessor', + 'ZepEntityReader', + 'EntityNode', + 'FilteredEntities', + 'OasisProfileGenerator', + 'OasisAgentProfile', + 'SimulationManager', + 'SimulationState', + 'SimulationStatus', + 'SimulationConfigGenerator', + 
'SimulationParameters', + 'AgentActivityConfig', + 'TimeSimulationConfig', + 'EventConfig', + 'PlatformConfig', + 'SimulationRunner', + 'SimulationRunState', + 'RunnerStatus', + 'AgentAction', + 'RoundSummary', +] diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py new file mode 100644 index 0000000..fe0144a --- /dev/null +++ b/backend/app/services/oasis_profile_generator.py @@ -0,0 +1,561 @@ +""" +OASIS Agent Profile生成器 +将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 +""" + +import json +import random +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field +from datetime import datetime + +from openai import OpenAI + +from ..config import Config +from ..utils.logger import get_logger +from .zep_entity_reader import EntityNode, ZepEntityReader + +logger = get_logger('mirofish.oasis_profile') + + +@dataclass +class OasisAgentProfile: + """OASIS Agent Profile数据结构""" + # 通用字段 + user_id: int + user_name: str + name: str + bio: str + persona: str + + # 可选字段 - Reddit风格 + karma: int = 1000 + + # 可选字段 - Twitter风格 + friend_count: int = 100 + follower_count: int = 150 + statuses_count: int = 500 + + # 额外人设信息 + age: Optional[int] = None + gender: Optional[str] = None + mbti: Optional[str] = None + country: Optional[str] = None + profession: Optional[str] = None + interested_topics: List[str] = field(default_factory=list) + + # 来源实体信息 + source_entity_uuid: Optional[str] = None + source_entity_type: Optional[str] = None + + created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) + + def to_reddit_format(self) -> Dict[str, Any]: + """转换为Reddit平台格式""" + profile = { + "user_id": self.user_id, + "user_name": self.user_name, + "name": self.name, + "bio": self.bio, + "persona": self.persona, + "karma": self.karma, + "created_at": self.created_at, + } + + # 添加额外人设信息(如果有) + if self.age: + profile["age"] = self.age + if self.gender: + profile["gender"] = self.gender + if 
self.mbti: + profile["mbti"] = self.mbti + if self.country: + profile["country"] = self.country + if self.profession: + profile["profession"] = self.profession + if self.interested_topics: + profile["interested_topics"] = self.interested_topics + + return profile + + def to_twitter_format(self) -> Dict[str, Any]: + """转换为Twitter平台格式""" + profile = { + "user_id": self.user_id, + "user_name": self.user_name, + "name": self.name, + "bio": self.bio, + "persona": self.persona, + "friend_count": self.friend_count, + "follower_count": self.follower_count, + "statuses_count": self.statuses_count, + "created_at": self.created_at, + } + + # 添加额外人设信息 + if self.age: + profile["age"] = self.age + if self.gender: + profile["gender"] = self.gender + if self.mbti: + profile["mbti"] = self.mbti + if self.country: + profile["country"] = self.country + if self.profession: + profile["profession"] = self.profession + if self.interested_topics: + profile["interested_topics"] = self.interested_topics + + return profile + + def to_dict(self) -> Dict[str, Any]: + """转换为完整字典格式""" + return { + "user_id": self.user_id, + "user_name": self.user_name, + "name": self.name, + "bio": self.bio, + "persona": self.persona, + "karma": self.karma, + "friend_count": self.friend_count, + "follower_count": self.follower_count, + "statuses_count": self.statuses_count, + "age": self.age, + "gender": self.gender, + "mbti": self.mbti, + "country": self.country, + "profession": self.profession, + "interested_topics": self.interested_topics, + "source_entity_uuid": self.source_entity_uuid, + "source_entity_type": self.source_entity_type, + "created_at": self.created_at, + } + + +class OasisProfileGenerator: + """ + OASIS Profile生成器 + + 将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile + """ + + # MBTI类型列表 + MBTI_TYPES = [ + "INTJ", "INTP", "ENTJ", "ENTP", + "INFJ", "INFP", "ENFJ", "ENFP", + "ISTJ", "ISFJ", "ESTJ", "ESFJ", + "ISTP", "ISFP", "ESTP", "ESFP" + ] + + # 常见国家列表 + COUNTRIES = [ + "China", "US", "UK", "Japan", 
"Germany", "France", + "Canada", "Australia", "Brazil", "India", "South Korea" + ] + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model_name: Optional[str] = None + ): + self.api_key = api_key or Config.LLM_API_KEY + self.base_url = base_url or Config.LLM_BASE_URL + self.model_name = model_name or Config.LLM_MODEL_NAME + + if not self.api_key: + raise ValueError("LLM_API_KEY 未配置") + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + def generate_profile_from_entity( + self, + entity: EntityNode, + user_id: int, + use_llm: bool = True + ) -> OasisAgentProfile: + """ + 从Zep实体生成OASIS Agent Profile + + Args: + entity: Zep实体节点 + user_id: 用户ID(用于OASIS) + use_llm: 是否使用LLM生成详细人设 + + Returns: + OasisAgentProfile + """ + entity_type = entity.get_entity_type() or "Entity" + + # 基础信息 + name = entity.name + user_name = self._generate_username(name) + + # 构建上下文信息 + context = self._build_entity_context(entity) + + if use_llm: + # 使用LLM生成详细人设 + profile_data = self._generate_profile_with_llm( + entity_name=name, + entity_type=entity_type, + entity_summary=entity.summary, + entity_attributes=entity.attributes, + context=context + ) + else: + # 使用规则生成基础人设 + profile_data = self._generate_profile_rule_based( + entity_name=name, + entity_type=entity_type, + entity_summary=entity.summary, + entity_attributes=entity.attributes + ) + + return OasisAgentProfile( + user_id=user_id, + user_name=user_name, + name=name, + bio=profile_data.get("bio", f"{entity_type}: {name}"), + persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."), + karma=profile_data.get("karma", random.randint(500, 5000)), + friend_count=profile_data.get("friend_count", random.randint(50, 500)), + follower_count=profile_data.get("follower_count", random.randint(100, 1000)), + statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)), + age=profile_data.get("age"), + 
gender=profile_data.get("gender"), + mbti=profile_data.get("mbti"), + country=profile_data.get("country"), + profession=profile_data.get("profession"), + interested_topics=profile_data.get("interested_topics", []), + source_entity_uuid=entity.uuid, + source_entity_type=entity_type, + ) + + def _generate_username(self, name: str) -> str: + """生成用户名""" + # 移除特殊字符,转换为小写 + username = name.lower().replace(" ", "_") + username = ''.join(c for c in username if c.isalnum() or c == '_') + + # 添加随机后缀避免重复 + suffix = random.randint(100, 999) + return f"{username}_{suffix}" + + def _build_entity_context(self, entity: EntityNode) -> str: + """构建实体的上下文信息""" + context_parts = [] + + # 添加相关边信息 + if entity.related_edges: + relationships = [] + for edge in entity.related_edges[:10]: # 最多取10条 + if edge.get("fact"): + relationships.append(edge["fact"]) + + if relationships: + context_parts.append("Related facts:\n" + "\n".join(f"- {r}" for r in relationships)) + + # 添加关联节点信息 + if entity.related_nodes: + related_names = [n["name"] for n in entity.related_nodes[:5]] + if related_names: + context_parts.append(f"Related to: {', '.join(related_names)}") + + return "\n\n".join(context_parts) + + def _generate_profile_with_llm( + self, + entity_name: str, + entity_type: str, + entity_summary: str, + entity_attributes: Dict[str, Any], + context: str + ) -> Dict[str, Any]: + """使用LLM生成详细人设""" + + prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes. 
+ +Entity Information: +- Name: {entity_name} +- Type: {entity_type} +- Summary: {entity_summary} +- Attributes: {json.dumps(entity_attributes, ensure_ascii=False)} + +Context: +{context} + +Generate a JSON object with the following fields: +{{ + "bio": "A short bio (max 150 chars) suitable for social media", + "persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns", + "age": <number>, + "gender": "<male/female/other>", + "mbti": "<MBTI type>", + "country": "<country>", + "profession": "<profession>", + "interested_topics": ["topic1", "topic2", ...] +}} + +Important: +- The profile should be consistent with the entity type and context +- Make the persona feel realistic and suitable for social media simulation +- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account) +- Return ONLY the JSON object, no additional text""" + + try: + # 使用重试机制调用LLM API + from ..utils.retry import RetryableAPIClient + + retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0) + + def call_llm(): + return self.client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": "You are a profile generator for social media simulation. 
Generate realistic user profiles based on entity information."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"}, + temperature=0.7 + ) + + response = retry_client.call_with_retry(call_llm) + result = json.loads(response.choices[0].message.content) + return result + + except Exception as e: + logger.warning(f"LLM生成人设失败(已重试): {str(e)}, 使用规则生成") + return self._generate_profile_rule_based( + entity_name, entity_type, entity_summary, entity_attributes + ) + + def _generate_profile_rule_based( + self, + entity_name: str, + entity_type: str, + entity_summary: str, + entity_attributes: Dict[str, Any] + ) -> Dict[str, Any]: + """使用规则生成基础人设""" + + # 根据实体类型生成不同的人设 + entity_type_lower = entity_type.lower() + + if entity_type_lower in ["student", "alumni"]: + return { + "bio": f"{entity_type} with interests in academics and social issues.", + "persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.", + "age": random.randint(18, 30), + "gender": random.choice(["male", "female"]), + "mbti": random.choice(self.MBTI_TYPES), + "country": random.choice(self.COUNTRIES), + "profession": "Student", + "interested_topics": ["Education", "Social Issues", "Technology"], + } + + elif entity_type_lower in ["publicfigure", "expert", "faculty"]: + return { + "bio": f"Expert and thought leader in their field.", + "persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. 
They are known for their expertise and influence in public discourse.", + "age": random.randint(35, 60), + "gender": random.choice(["male", "female"]), + "mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]), + "country": random.choice(self.COUNTRIES), + "profession": entity_attributes.get("occupation", "Expert"), + "interested_topics": ["Politics", "Economics", "Culture & Society"], + } + + elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]: + return { + "bio": f"Official account for {entity_name}. News and updates.", + "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.", + "profession": "Media", + "interested_topics": ["General News", "Current Events", "Public Affairs"], + } + + elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]: + return { + "bio": f"Official account of {entity_name}.", + "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", + "profession": entity_type, + "interested_topics": ["Public Policy", "Community", "Official Announcements"], + } + + else: + # 默认人设 + return { + "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}", + "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.", + "age": random.randint(25, 50), + "gender": random.choice(["male", "female"]), + "mbti": random.choice(self.MBTI_TYPES), + "country": random.choice(self.COUNTRIES), + "profession": entity_type, + "interested_topics": ["General", "Social Issues"], + } + + def generate_profiles_from_entities( + self, + entities: List[EntityNode], + use_llm: bool = True, + progress_callback: Optional[callable] = None + ) -> List[OasisAgentProfile]: + """ + 批量从实体生成Agent Profile + + Args: + entities: 实体列表 + use_llm: 是否使用LLM生成详细人设 
+ progress_callback: 进度回调函数 (current, total, message) + + Returns: + Agent Profile列表 + """ + profiles = [] + total = len(entities) + + for idx, entity in enumerate(entities): + if progress_callback: + progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...") + + try: + profile = self.generate_profile_from_entity( + entity=entity, + user_id=idx, + use_llm=use_llm + ) + profiles.append(profile) + + except Exception as e: + logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}") + # 创建一个基础profile + profiles.append(OasisAgentProfile( + user_id=idx, + user_name=self._generate_username(entity.name), + name=entity.name, + bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}", + persona=entity.summary or f"A participant in social discussions.", + source_entity_uuid=entity.uuid, + source_entity_type=entity.get_entity_type(), + )) + + return profiles + + def save_profiles( + self, + profiles: List[OasisAgentProfile], + file_path: str, + platform: str = "reddit" + ): + """ + 保存Profile到文件(根据平台选择正确格式) + + OASIS平台格式要求: + - Twitter: CSV格式 + - Reddit: JSON格式 + + Args: + profiles: Profile列表 + file_path: 文件路径 + platform: 平台类型 ("reddit" 或 "twitter") + """ + if platform == "twitter": + self._save_twitter_csv(profiles, file_path) + else: + self._save_reddit_json(profiles, file_path) + + def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str): + """ + 保存Twitter Profile为CSV格式 + + OASIS Twitter要求的CSV字段: + user_id, user_name, name, bio, friend_count, follower_count, statuses_count, created_at + """ + import csv + + # 确保文件扩展名是.csv + if not file_path.endswith('.csv'): + file_path = file_path.replace('.json', '.csv') + + with open(file_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + + # 写入表头 + headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count', + 'follower_count', 'statuses_count', 'created_at'] + writer.writerow(headers) + + # 写入数据行 + for profile in profiles: + # bio需要处理换行符和逗号 + bio = profile.bio.replace('\n', ' 
').replace('\r', ' ') + row = [ + profile.user_id, + profile.user_name, + profile.name, + bio, + profile.friend_count, + profile.follower_count, + profile.statuses_count, + profile.created_at + ] + writer.writerow(row) + + logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)") + + def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): + """ + 保存Reddit Profile为JSON格式 + + OASIS Reddit支持两种JSON格式: + 1. 基础格式: user_id, user_name, name, bio, karma, created_at + 2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics + + 我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致 + """ + data = [] + for profile in profiles: + # 使用详细格式(与用户示例兼容) + item = { + "realname": profile.name, + "username": profile.user_name, + "bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符 + "persona": profile.persona or f"{profile.name} is a participant in social discussions.", + } + + # 添加人设详情字段 + if profile.age: + item["age"] = profile.age + if profile.gender: + item["gender"] = profile.gender + if profile.mbti: + item["mbti"] = profile.mbti + if profile.country: + item["country"] = profile.country + if profile.profession: + item["profession"] = profile.profession + if profile.interested_topics: + item["interested_topics"] = profile.interested_topics + + data.append(item) + + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)") + + # 保留旧方法名作为别名,保持向后兼容 + def save_profiles_to_json( + self, + profiles: List[OasisAgentProfile], + file_path: str, + platform: str = "reddit" + ): + """[已废弃] 请使用 save_profiles() 方法""" + logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") + self.save_profiles(profiles, file_path, platform) + diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py new file mode 100644 index 0000000..b75ee24 --- 
/dev/null +++ b/backend/app/services/simulation_config_generator.py @@ -0,0 +1,584 @@ +""" +模拟配置智能生成器 +使用LLM根据模拟需求、文档内容、图谱信息自动生成细致的模拟参数 +实现全程自动化,无需人工设置参数 +""" + +import json +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field, asdict +from datetime import datetime + +from openai import OpenAI + +from ..config import Config +from ..utils.logger import get_logger +from .zep_entity_reader import EntityNode, ZepEntityReader + +logger = get_logger('mirofish.simulation_config') + + +@dataclass +class AgentActivityConfig: + """单个Agent的活动配置""" + agent_id: int + entity_uuid: str + entity_name: str + entity_type: str + + # 活跃度配置 (0.0-1.0) + activity_level: float = 0.5 # 整体活跃度 + + # 发言频率(每小时预期发言次数) + posts_per_hour: float = 1.0 + comments_per_hour: float = 2.0 + + # 活跃时间段(24小时制,0-23) + active_hours: List[int] = field(default_factory=lambda: list(range(8, 23))) + + # 响应速度(对热点事件的反应延迟,单位:模拟分钟) + response_delay_min: int = 5 + response_delay_max: int = 60 + + # 情感倾向 (-1.0到1.0,负面到正面) + sentiment_bias: float = 0.0 + + # 立场(对特定话题的态度) + stance: str = "neutral" # supportive, opposing, neutral, observer + + # 影响力权重(决定其发言被其他Agent看到的概率) + influence_weight: float = 1.0 + + +@dataclass +class TimeSimulationConfig: + """时间模拟配置""" + # 模拟总时长(模拟小时数) + total_simulation_hours: int = 72 # 默认模拟72小时(3天) + + # 每轮代表的时间(模拟分钟) + minutes_per_round: int = 30 + + # 每小时激活的Agent数量范围 + agents_per_hour_min: int = 5 + agents_per_hour_max: int = 20 + + # 高峰时段(活跃度提升) + peak_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 14, 15, 20, 21, 22]) + peak_activity_multiplier: float = 1.5 + + # 低谷时段(活跃度降低) + off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6]) + off_peak_activity_multiplier: float = 0.3 + + +@dataclass +class EventConfig: + """事件配置""" + # 初始事件(模拟开始时的触发事件) + initial_posts: List[Dict[str, Any]] = field(default_factory=list) + + # 定时事件(在特定时间触发的事件) + scheduled_events: List[Dict[str, Any]] = field(default_factory=list) + + # 热点话题关键词 
+ hot_topics: List[str] = field(default_factory=list) + + # 舆论引导方向 + narrative_direction: str = "" + + +@dataclass +class PlatformConfig: + """平台特定配置""" + platform: str # twitter or reddit + + # 推荐算法权重 + recency_weight: float = 0.4 # 时间新鲜度 + popularity_weight: float = 0.3 # 热度 + relevance_weight: float = 0.3 # 相关性 + + # 病毒传播阈值(达到多少互动后触发扩散) + viral_threshold: int = 10 + + # 回声室效应强度(相似观点聚集程度) + echo_chamber_strength: float = 0.5 + + +@dataclass +class SimulationParameters: + """完整的模拟参数配置""" + # 基础信息 + simulation_id: str + project_id: str + graph_id: str + simulation_requirement: str + + # 时间配置 + time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig) + + # Agent配置列表 + agent_configs: List[AgentActivityConfig] = field(default_factory=list) + + # 事件配置 + event_config: EventConfig = field(default_factory=EventConfig) + + # 平台配置 + twitter_config: Optional[PlatformConfig] = None + reddit_config: Optional[PlatformConfig] = None + + # LLM配置 + llm_model: str = "" + llm_base_url: str = "" + + # 生成元数据 + generated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + generation_reasoning: str = "" # LLM的推理说明 + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "simulation_id": self.simulation_id, + "project_id": self.project_id, + "graph_id": self.graph_id, + "simulation_requirement": self.simulation_requirement, + "time_config": asdict(self.time_config), + "agent_configs": [asdict(a) for a in self.agent_configs], + "event_config": asdict(self.event_config), + "twitter_config": asdict(self.twitter_config) if self.twitter_config else None, + "reddit_config": asdict(self.reddit_config) if self.reddit_config else None, + "llm_model": self.llm_model, + "llm_base_url": self.llm_base_url, + "generated_at": self.generated_at, + "generation_reasoning": self.generation_reasoning, + } + + def to_json(self, indent: int = 2) -> str: + """转换为JSON字符串""" + return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent) + + +class 
SimulationConfigGenerator: + """ + 模拟配置智能生成器 + + 使用LLM分析模拟需求、文档内容、图谱实体信息, + 自动生成最佳的模拟参数配置 + """ + + # 上下文最大字符数 + MAX_CONTEXT_LENGTH = 50000 + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model_name: Optional[str] = None + ): + self.api_key = api_key or Config.LLM_API_KEY + self.base_url = base_url or Config.LLM_BASE_URL + self.model_name = model_name or Config.LLM_MODEL_NAME + + if not self.api_key: + raise ValueError("LLM_API_KEY 未配置") + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + def generate_config( + self, + simulation_id: str, + project_id: str, + graph_id: str, + simulation_requirement: str, + document_text: str, + entities: List[EntityNode], + enable_twitter: bool = True, + enable_reddit: bool = True, + ) -> SimulationParameters: + """ + 智能生成完整的模拟配置 + + Args: + simulation_id: 模拟ID + project_id: 项目ID + graph_id: 图谱ID + simulation_requirement: 模拟需求描述 + document_text: 原始文档内容 + entities: 过滤后的实体列表 + enable_twitter: 是否启用Twitter + enable_reddit: 是否启用Reddit + + Returns: + SimulationParameters: 完整的模拟参数 + """ + logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}") + + # 1. 构建上下文信息(截断到50000字符) + context = self._build_context( + simulation_requirement=simulation_requirement, + document_text=document_text, + entities=entities + ) + + # 2. 调用LLM生成配置 + llm_result = self._generate_config_with_llm( + context=context, + entities=entities, + enable_twitter=enable_twitter, + enable_reddit=enable_reddit + ) + + # 3. 
构建SimulationParameters对象 + params = self._build_parameters( + simulation_id=simulation_id, + project_id=project_id, + graph_id=graph_id, + simulation_requirement=simulation_requirement, + entities=entities, + llm_result=llm_result, + enable_twitter=enable_twitter, + enable_reddit=enable_reddit + ) + + logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置") + + return params + + def _build_context( + self, + simulation_requirement: str, + document_text: str, + entities: List[EntityNode] + ) -> str: + """构建LLM上下文,截断到最大长度""" + + # 实体摘要 + entity_summary = self._summarize_entities(entities) + + # 构建上下文 + context_parts = [ + f"## 模拟需求\n{simulation_requirement}", + f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}", + ] + + current_length = sum(len(p) for p in context_parts) + remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量 + + if remaining_length > 0 and document_text: + doc_text = document_text[:remaining_length] + if len(document_text) > remaining_length: + doc_text += "\n...(文档已截断)" + context_parts.append(f"\n## 原始文档内容\n{doc_text}") + + return "\n".join(context_parts) + + def _summarize_entities(self, entities: List[EntityNode]) -> str: + """生成实体摘要""" + lines = [] + + # 按类型分组 + by_type: Dict[str, List[EntityNode]] = {} + for e in entities: + t = e.get_entity_type() or "Unknown" + if t not in by_type: + by_type[t] = [] + by_type[t].append(e) + + for entity_type, type_entities in by_type.items(): + lines.append(f"\n### {entity_type} ({len(type_entities)}个)") + for e in type_entities[:10]: # 每类最多显示10个 + summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary + lines.append(f"- {e.name}: {summary_preview}") + if len(type_entities) > 10: + lines.append(f" ... 
还有 {len(type_entities) - 10} 个") + + return "\n".join(lines) + + def _generate_config_with_llm( + self, + context: str, + entities: List[EntityNode], + enable_twitter: bool, + enable_reddit: bool + ) -> Dict[str, Any]: + """调用LLM生成配置""" + + # 构建实体列表用于Agent配置 + entity_list = [] + for i, e in enumerate(entities): + entity_list.append({ + "agent_id": i, + "entity_uuid": e.uuid, + "entity_name": e.name, + "entity_type": e.get_entity_type() or "Unknown", + "summary": e.summary[:200] if e.summary else "" + }) + + prompt = f"""你是一个社交媒体舆论模拟专家。请根据以下信息,生成详细的模拟参数配置。 + +{context} + +## 实体列表(需要为每个实体生成活动配置) +```json +{json.dumps(entity_list, ensure_ascii=False, indent=2)} +``` + +## 任务 +请生成一个JSON配置,包含以下部分: + +1. **time_config** - 时间模拟配置 + - total_simulation_hours: 模拟总时长(小时),根据事件性质决定(短期热点24-72小时,长期舆论168-336小时) + - minutes_per_round: 每轮代表的时间(分钟),建议15-60 + - agents_per_hour_min/max: 每小时激活的Agent数量范围 + - peak_hours: 高峰时段列表(0-23) + - off_peak_hours: 低谷时段列表 + +2. **agent_configs** - 每个Agent的活动配置(必须为每个实体生成) + 对于每个agent_id,设置: + - activity_level: 活跃度(0.0-1.0),官方机构通常0.1-0.3,媒体0.3-0.5,个人0.5-0.9 + - posts_per_hour: 每小时发帖频率,官方机构0.05-0.2,媒体0.5-2,个人0.1-1 + - comments_per_hour: 每小时评论频率 + - active_hours: 活跃时间段列表,官方通常工作时间,个人更分散 + - response_delay_min/max: 响应延迟(模拟分钟),官方较慢(30-180),个人较快(1-30) + - sentiment_bias: 情感倾向(-1到1),根据实体立场设置 + - stance: 立场(supportive/opposing/neutral/observer) + - influence_weight: 影响力权重,知名人物和媒体较高 + +3. **event_config** - 事件配置 + - initial_posts: 初始帖子列表,包含content和poster_agent_id + - hot_topics: 热点话题关键词列表 + - narrative_direction: 舆论发展方向描述 + +4. **platform_configs** - 平台配置(如果启用) + - viral_threshold: 病毒传播阈值 + - echo_chamber_strength: 回声室效应强度(0-1) + +5. 
**reasoning** - 你的推理说明,解释为什么这样设置参数 + +## 重要原则 +- 官方机构(University、GovernmentAgency)发言频率低但影响力大 +- 媒体(MediaOutlet)发言频率中等,传播速度快 +- 个人(Student、PublicFigure)发言频率高但影响力分散 +- 根据模拟需求判断各实体的立场和情感倾向 +- 时间配置要符合真实社交媒体的使用规律 + +请返回JSON格式,不要包含markdown代码块标记。""" + + try: + # 使用重试机制调用LLM API + from ..utils.retry import RetryableAPIClient + + retry_client = RetryableAPIClient(max_retries=3, initial_delay=2.0, max_delay=60.0) + + def call_llm(): + return self.client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "system", + "content": "你是社交媒体舆论模拟专家,擅长设计真实的模拟参数。返回纯JSON格式,不要markdown。" + }, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"}, + temperature=0.7, + max_tokens=8000 + ) + + response = retry_client.call_with_retry(call_llm) + result = json.loads(response.choices[0].message.content) + logger.info(f"LLM配置生成成功") + return result + + except Exception as e: + logger.error(f"LLM配置生成失败(已重试): {str(e)}") + # 返回默认配置 + return self._generate_default_config(entities) + + def _generate_default_config(self, entities: List[EntityNode]) -> Dict[str, Any]: + """生成默认配置(LLM失败时的fallback)""" + agent_configs = [] + + for i, e in enumerate(entities): + entity_type = (e.get_entity_type() or "Unknown").lower() + + # 根据实体类型设置默认参数 + if entity_type in ["university", "governmentagency", "ngo"]: + config = { + "agent_id": i, + "activity_level": 0.2, + "posts_per_hour": 0.1, + "comments_per_hour": 0.05, + "active_hours": list(range(9, 18)), + "response_delay_min": 60, + "response_delay_max": 240, + "sentiment_bias": 0.0, + "stance": "neutral", + "influence_weight": 3.0 + } + elif entity_type in ["mediaoutlet"]: + config = { + "agent_id": i, + "activity_level": 0.6, + "posts_per_hour": 1.0, + "comments_per_hour": 0.5, + "active_hours": list(range(6, 24)), + "response_delay_min": 5, + "response_delay_max": 30, + "sentiment_bias": 0.0, + "stance": "observer", + "influence_weight": 2.5 + } + elif entity_type in ["publicfigure", "expert"]: + config = { + 
"agent_id": i, + "activity_level": 0.5, + "posts_per_hour": 0.3, + "comments_per_hour": 0.5, + "active_hours": list(range(8, 23)), + "response_delay_min": 10, + "response_delay_max": 60, + "sentiment_bias": 0.0, + "stance": "neutral", + "influence_weight": 2.0 + } + else: # Student, Person, etc. + config = { + "agent_id": i, + "activity_level": 0.7, + "posts_per_hour": 0.5, + "comments_per_hour": 1.0, + "active_hours": list(range(7, 24)), + "response_delay_min": 1, + "response_delay_max": 20, + "sentiment_bias": 0.0, + "stance": "neutral", + "influence_weight": 1.0 + } + + agent_configs.append(config) + + return { + "time_config": { + "total_simulation_hours": 72, + "minutes_per_round": 30, + "agents_per_hour_min": max(1, len(entities) // 10), + "agents_per_hour_max": max(5, len(entities) // 3), + "peak_hours": [9, 10, 11, 14, 15, 20, 21, 22], + "off_peak_hours": [0, 1, 2, 3, 4, 5] + }, + "agent_configs": agent_configs, + "event_config": { + "initial_posts": [], + "hot_topics": [], + "narrative_direction": "" + }, + "reasoning": "使用默认配置(LLM生成失败)" + } + + def _build_parameters( + self, + simulation_id: str, + project_id: str, + graph_id: str, + simulation_requirement: str, + entities: List[EntityNode], + llm_result: Dict[str, Any], + enable_twitter: bool, + enable_reddit: bool + ) -> SimulationParameters: + """根据LLM结果构建SimulationParameters对象""" + + # 时间配置 + time_cfg = llm_result.get("time_config", {}) + time_config = TimeSimulationConfig( + total_simulation_hours=time_cfg.get("total_simulation_hours", 72), + minutes_per_round=time_cfg.get("minutes_per_round", 30), + agents_per_hour_min=time_cfg.get("agents_per_hour_min", 5), + agents_per_hour_max=time_cfg.get("agents_per_hour_max", 20), + peak_hours=time_cfg.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]), + off_peak_hours=time_cfg.get("off_peak_hours", [0, 1, 2, 3, 4, 5]), + peak_activity_multiplier=time_cfg.get("peak_activity_multiplier", 1.5), + 
off_peak_activity_multiplier=time_cfg.get("off_peak_activity_multiplier", 0.3) + ) + + # Agent配置 + agent_configs = [] + llm_agent_configs = {cfg["agent_id"]: cfg for cfg in llm_result.get("agent_configs", [])} + + for i, entity in enumerate(entities): + cfg = llm_agent_configs.get(i, {}) + + agent_config = AgentActivityConfig( + agent_id=i, + entity_uuid=entity.uuid, + entity_name=entity.name, + entity_type=entity.get_entity_type() or "Unknown", + activity_level=cfg.get("activity_level", 0.5), + posts_per_hour=cfg.get("posts_per_hour", 0.5), + comments_per_hour=cfg.get("comments_per_hour", 1.0), + active_hours=cfg.get("active_hours", list(range(8, 23))), + response_delay_min=cfg.get("response_delay_min", 5), + response_delay_max=cfg.get("response_delay_max", 60), + sentiment_bias=cfg.get("sentiment_bias", 0.0), + stance=cfg.get("stance", "neutral"), + influence_weight=cfg.get("influence_weight", 1.0) + ) + agent_configs.append(agent_config) + + # 事件配置 + event_cfg = llm_result.get("event_config", {}) + event_config = EventConfig( + initial_posts=event_cfg.get("initial_posts", []), + scheduled_events=event_cfg.get("scheduled_events", []), + hot_topics=event_cfg.get("hot_topics", []), + narrative_direction=event_cfg.get("narrative_direction", "") + ) + + # 平台配置 + twitter_config = None + reddit_config = None + + platform_cfgs = llm_result.get("platform_configs", {}) + + if enable_twitter: + tw_cfg = platform_cfgs.get("twitter", {}) + twitter_config = PlatformConfig( + platform="twitter", + recency_weight=tw_cfg.get("recency_weight", 0.4), + popularity_weight=tw_cfg.get("popularity_weight", 0.3), + relevance_weight=tw_cfg.get("relevance_weight", 0.3), + viral_threshold=tw_cfg.get("viral_threshold", 10), + echo_chamber_strength=tw_cfg.get("echo_chamber_strength", 0.5) + ) + + if enable_reddit: + rd_cfg = platform_cfgs.get("reddit", {}) + reddit_config = PlatformConfig( + platform="reddit", + recency_weight=rd_cfg.get("recency_weight", 0.3), + 
popularity_weight=rd_cfg.get("popularity_weight", 0.4), + relevance_weight=rd_cfg.get("relevance_weight", 0.3), + viral_threshold=rd_cfg.get("viral_threshold", 15), + echo_chamber_strength=rd_cfg.get("echo_chamber_strength", 0.6) + ) + + return SimulationParameters( + simulation_id=simulation_id, + project_id=project_id, + graph_id=graph_id, + simulation_requirement=simulation_requirement, + time_config=time_config, + agent_configs=agent_configs, + event_config=event_config, + twitter_config=twitter_config, + reddit_config=reddit_config, + llm_model=self.model_name, + llm_base_url=self.base_url, + generation_reasoning=llm_result.get("reasoning", "") + ) + + diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py new file mode 100644 index 0000000..fbb42a6 --- /dev/null +++ b/backend/app/services/simulation_manager.py @@ -0,0 +1,546 @@ +""" +OASIS模拟管理器 +管理Twitter和Reddit双平台并行模拟 +使用预设脚本 + LLM智能生成配置参数 +""" + +import os +import json +import shutil +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum + +from ..config import Config +from ..utils.logger import get_logger +from .zep_entity_reader import ZepEntityReader, FilteredEntities +from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile +from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters + +logger = get_logger('mirofish.simulation') + + +class SimulationStatus(str, Enum): + """模拟状态""" + CREATED = "created" + PREPARING = "preparing" + READY = "ready" + RUNNING = "running" + PAUSED = "paused" + COMPLETED = "completed" + FAILED = "failed" + + +class PlatformType(str, Enum): + """平台类型""" + TWITTER = "twitter" + REDDIT = "reddit" + + +@dataclass +class SimulationState: + """模拟状态""" + simulation_id: str + project_id: str + graph_id: str + + # 平台启用状态 + enable_twitter: bool = True + enable_reddit: bool = True + + # 状态 + status: 
SimulationStatus = SimulationStatus.CREATED + + # 准备阶段数据 + entities_count: int = 0 + profiles_count: int = 0 + entity_types: List[str] = field(default_factory=list) + + # 配置生成信息 + config_generated: bool = False + config_reasoning: str = "" + + # 运行时数据 + current_round: int = 0 + twitter_status: str = "not_started" + reddit_status: str = "not_started" + + # 时间戳 + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + + # 错误信息 + error: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """完整状态字典(内部使用)""" + return { + "simulation_id": self.simulation_id, + "project_id": self.project_id, + "graph_id": self.graph_id, + "enable_twitter": self.enable_twitter, + "enable_reddit": self.enable_reddit, + "status": self.status.value, + "entities_count": self.entities_count, + "profiles_count": self.profiles_count, + "entity_types": self.entity_types, + "config_generated": self.config_generated, + "config_reasoning": self.config_reasoning, + "current_round": self.current_round, + "twitter_status": self.twitter_status, + "reddit_status": self.reddit_status, + "created_at": self.created_at, + "updated_at": self.updated_at, + "error": self.error, + } + + def to_simple_dict(self) -> Dict[str, Any]: + """简化状态字典(API返回使用)""" + return { + "simulation_id": self.simulation_id, + "project_id": self.project_id, + "graph_id": self.graph_id, + "status": self.status.value, + "entities_count": self.entities_count, + "profiles_count": self.profiles_count, + "entity_types": self.entity_types, + "config_generated": self.config_generated, + "error": self.error, + } + + +class SimulationManager: + """ + 模拟管理器 + + 核心功能: + 1. 从Zep图谱读取实体并过滤 + 2. 生成OASIS Agent Profile + 3. 使用LLM智能生成模拟配置参数 + 4. 
准备预设脚本所需的所有文件 + """ + + # 模拟数据存储目录 + SIMULATION_DATA_DIR = os.path.join( + os.path.dirname(__file__), + '../../uploads/simulations' + ) + + # 预设脚本目录 + SCRIPTS_DIR = os.path.join( + os.path.dirname(__file__), + '../../scripts' + ) + + def __init__(self): + # 确保目录存在 + os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) + + # 内存中的模拟状态缓存 + self._simulations: Dict[str, SimulationState] = {} + + def _get_simulation_dir(self, simulation_id: str) -> str: + """获取模拟数据目录""" + sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id) + os.makedirs(sim_dir, exist_ok=True) + return sim_dir + + def _save_simulation_state(self, state: SimulationState): + """保存模拟状态到文件""" + sim_dir = self._get_simulation_dir(state.simulation_id) + state_file = os.path.join(sim_dir, "state.json") + + state.updated_at = datetime.now().isoformat() + + with open(state_file, 'w', encoding='utf-8') as f: + json.dump(state.to_dict(), f, ensure_ascii=False, indent=2) + + self._simulations[state.simulation_id] = state + + def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]: + """从文件加载模拟状态""" + if simulation_id in self._simulations: + return self._simulations[simulation_id] + + sim_dir = self._get_simulation_dir(simulation_id) + state_file = os.path.join(sim_dir, "state.json") + + if not os.path.exists(state_file): + return None + + with open(state_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + state = SimulationState( + simulation_id=simulation_id, + project_id=data.get("project_id", ""), + graph_id=data.get("graph_id", ""), + enable_twitter=data.get("enable_twitter", True), + enable_reddit=data.get("enable_reddit", True), + status=SimulationStatus(data.get("status", "created")), + entities_count=data.get("entities_count", 0), + profiles_count=data.get("profiles_count", 0), + entity_types=data.get("entity_types", []), + config_generated=data.get("config_generated", False), + config_reasoning=data.get("config_reasoning", ""), + 
current_round=data.get("current_round", 0), + twitter_status=data.get("twitter_status", "not_started"), + reddit_status=data.get("reddit_status", "not_started"), + created_at=data.get("created_at", datetime.now().isoformat()), + updated_at=data.get("updated_at", datetime.now().isoformat()), + error=data.get("error"), + ) + + self._simulations[simulation_id] = state + return state + + def create_simulation( + self, + project_id: str, + graph_id: str, + enable_twitter: bool = True, + enable_reddit: bool = True, + ) -> SimulationState: + """ + 创建新的模拟 + + Args: + project_id: 项目ID + graph_id: Zep图谱ID + enable_twitter: 是否启用Twitter模拟 + enable_reddit: 是否启用Reddit模拟 + + Returns: + SimulationState + """ + import uuid + simulation_id = f"sim_{uuid.uuid4().hex[:12]}" + + state = SimulationState( + simulation_id=simulation_id, + project_id=project_id, + graph_id=graph_id, + enable_twitter=enable_twitter, + enable_reddit=enable_reddit, + status=SimulationStatus.CREATED, + ) + + self._save_simulation_state(state) + logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}") + + return state + + def prepare_simulation( + self, + simulation_id: str, + simulation_requirement: str, + document_text: str, + defined_entity_types: Optional[List[str]] = None, + use_llm_for_profiles: bool = True, + progress_callback: Optional[callable] = None + ) -> SimulationState: + """ + 准备模拟环境(全程自动化) + + 步骤: + 1. 从Zep图谱读取并过滤实体 + 2. 为每个实体生成OASIS Agent Profile(可选LLM增强) + 3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等) + 4. 保存配置文件和Profile文件 + 5. 
复制预设脚本到模拟目录 + + Args: + simulation_id: 模拟ID + simulation_requirement: 模拟需求描述(用于LLM生成配置) + document_text: 原始文档内容(用于LLM理解背景) + defined_entity_types: 预定义的实体类型(可选) + use_llm_for_profiles: 是否使用LLM生成详细人设 + progress_callback: 进度回调函数 (stage, progress, message) + + Returns: + SimulationState + """ + state = self._load_simulation_state(simulation_id) + if not state: + raise ValueError(f"模拟不存在: {simulation_id}") + + try: + state.status = SimulationStatus.PREPARING + self._save_simulation_state(state) + + sim_dir = self._get_simulation_dir(simulation_id) + + # ========== 阶段1: 读取并过滤实体 ========== + if progress_callback: + progress_callback("reading", 0, "正在连接Zep图谱...") + + reader = ZepEntityReader() + + if progress_callback: + progress_callback("reading", 30, "正在读取节点数据...") + + filtered = reader.filter_defined_entities( + graph_id=state.graph_id, + defined_entity_types=defined_entity_types, + enrich_with_edges=True + ) + + state.entities_count = filtered.filtered_count + state.entity_types = list(filtered.entity_types) + + if progress_callback: + progress_callback( + "reading", 100, + f"完成,共 {filtered.filtered_count} 个实体", + current=filtered.filtered_count, + total=filtered.filtered_count + ) + + if filtered.filtered_count == 0: + state.status = SimulationStatus.FAILED + state.error = "没有找到符合条件的实体,请检查图谱是否正确构建" + self._save_simulation_state(state) + return state + + # ========== 阶段2: 生成Agent Profile ========== + total_entities = len(filtered.entities) + + if progress_callback: + progress_callback( + "generating_profiles", 0, + "开始生成...", + current=0, + total=total_entities + ) + + generator = OasisProfileGenerator() + + def profile_progress(current, total, msg): + if progress_callback: + progress_callback( + "generating_profiles", + int(current / total * 100), + msg, + current=current, + total=total, + item_name=msg + ) + + profiles = generator.generate_profiles_from_entities( + entities=filtered.entities, + use_llm=use_llm_for_profiles, + progress_callback=profile_progress + ) + 
+ state.profiles_count = len(profiles) + + # 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式) + if progress_callback: + progress_callback( + "generating_profiles", 95, + "保存Profile文件...", + current=total_entities, + total=total_entities + ) + + if state.enable_reddit: + generator.save_profiles( + profiles=profiles, + file_path=os.path.join(sim_dir, "reddit_profiles.json"), + platform="reddit" + ) + + if state.enable_twitter: + # Twitter使用CSV格式!这是OASIS的要求 + generator.save_profiles( + profiles=profiles, + file_path=os.path.join(sim_dir, "twitter_profiles.csv"), + platform="twitter" + ) + + if progress_callback: + progress_callback( + "generating_profiles", 100, + f"完成,共 {len(profiles)} 个Profile", + current=len(profiles), + total=len(profiles) + ) + + # ========== 阶段3: LLM智能生成模拟配置 ========== + if progress_callback: + progress_callback( + "generating_config", 0, + "正在分析模拟需求...", + current=0, + total=3 + ) + + config_generator = SimulationConfigGenerator() + + if progress_callback: + progress_callback( + "generating_config", 30, + "正在调用LLM生成配置...", + current=1, + total=3 + ) + + sim_params = config_generator.generate_config( + simulation_id=simulation_id, + project_id=state.project_id, + graph_id=state.graph_id, + simulation_requirement=simulation_requirement, + document_text=document_text, + entities=filtered.entities, + enable_twitter=state.enable_twitter, + enable_reddit=state.enable_reddit + ) + + if progress_callback: + progress_callback( + "generating_config", 70, + "正在保存配置文件...", + current=2, + total=3 + ) + + # 保存配置文件 + config_path = os.path.join(sim_dir, "simulation_config.json") + with open(config_path, 'w', encoding='utf-8') as f: + f.write(sim_params.to_json()) + + state.config_generated = True + state.config_reasoning = sim_params.generation_reasoning + + if progress_callback: + progress_callback( + "generating_config", 100, + "配置生成完成", + current=3, + total=3 + ) + + # ========== 阶段4: 复制预设脚本 ========== + script_files = ["run_twitter_simulation.py", 
"run_reddit_simulation.py", + "run_parallel_simulation.py", "action_logger.py"] + + if progress_callback: + progress_callback( + "copying_scripts", 0, + "开始准备脚本...", + current=0, + total=len(script_files) + ) + + self._copy_preset_scripts(sim_dir) + + if progress_callback: + progress_callback( + "copying_scripts", 100, + f"完成,共 {len(script_files)} 个脚本", + current=len(script_files), + total=len(script_files) + ) + + # 更新状态 + state.status = SimulationStatus.READY + self._save_simulation_state(state) + + logger.info(f"模拟准备完成: {simulation_id}, " + f"entities={state.entities_count}, profiles={state.profiles_count}") + + return state + + except Exception as e: + logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}") + import traceback + logger.error(traceback.format_exc()) + state.status = SimulationStatus.FAILED + state.error = str(e) + self._save_simulation_state(state) + raise + + def _copy_preset_scripts(self, sim_dir: str): + """复制预设脚本到模拟目录""" + scripts = [ + "run_twitter_simulation.py", + "run_reddit_simulation.py", + "run_parallel_simulation.py" + ] + + for script in scripts: + src = os.path.join(self.SCRIPTS_DIR, script) + dst = os.path.join(sim_dir, script) + + if os.path.exists(src): + shutil.copy2(src, dst) + logger.debug(f"复制脚本: {script}") + else: + logger.warning(f"预设脚本不存在: {src}") + + def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: + """获取模拟状态""" + return self._load_simulation_state(simulation_id) + + def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]: + """列出所有模拟""" + simulations = [] + + if os.path.exists(self.SIMULATION_DATA_DIR): + for sim_id in os.listdir(self.SIMULATION_DATA_DIR): + state = self._load_simulation_state(sim_id) + if state: + if project_id is None or state.project_id == project_id: + simulations.append(state) + + return simulations + + def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]: + """获取模拟的Agent Profile""" + state = 
self._load_simulation_state(simulation_id) + if not state: + raise ValueError(f"模拟不存在: {simulation_id}") + + sim_dir = self._get_simulation_dir(simulation_id) + profile_path = os.path.join(sim_dir, f"{platform}_profiles.json") + + if not os.path.exists(profile_path): + return [] + + with open(profile_path, 'r', encoding='utf-8') as f: + return json.load(f) + + def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: + """获取模拟配置""" + sim_dir = self._get_simulation_dir(simulation_id) + config_path = os.path.join(sim_dir, "simulation_config.json") + + if not os.path.exists(config_path): + return None + + with open(config_path, 'r', encoding='utf-8') as f: + return json.load(f) + + def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: + """获取运行说明""" + sim_dir = self._get_simulation_dir(simulation_id) + config_path = os.path.join(sim_dir, "simulation_config.json") + + return { + "simulation_dir": sim_dir, + "config_file": config_path, + "commands": { + "twitter": f"python run_twitter_simulation.py --config simulation_config.json", + "reddit": f"python run_reddit_simulation.py --config simulation_config.json", + "parallel": f"python run_parallel_simulation.py --config simulation_config.json", + }, + "instructions": ( + f"1. 进入模拟目录: cd {sim_dir}\n" + f"2. 激活conda环境: conda activate MiroFish\n" + f"3. 
运行模拟:\n" + f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n" + f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n" + f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json" + ) + } diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py new file mode 100644 index 0000000..68d4b41 --- /dev/null +++ b/backend/app/services/simulation_runner.py @@ -0,0 +1,670 @@ +""" +OASIS模拟运行器 +在后台运行模拟并记录每个Agent的动作,支持实时状态监控 +""" + +import os +import sys +import json +import time +import asyncio +import threading +import subprocess +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from queue import Queue + +from ..config import Config +from ..utils.logger import get_logger + +logger = get_logger('mirofish.simulation_runner') + + +class RunnerStatus(str, Enum): + """运行器状态""" + IDLE = "idle" + STARTING = "starting" + RUNNING = "running" + PAUSED = "paused" + STOPPING = "stopping" + STOPPED = "stopped" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class AgentAction: + """Agent动作记录""" + round_num: int + timestamp: str + platform: str # twitter / reddit + agent_id: int + agent_name: str + action_type: str # CREATE_POST, LIKE_POST, etc. 
@dataclass
class RoundSummary:
    """Aggregated statistics for one simulation round."""
    round_num: int
    start_time: str
    end_time: Optional[str] = None
    simulated_hour: int = 0
    twitter_actions: int = 0
    reddit_actions: int = 0
    active_agents: List[int] = field(default_factory=list)
    actions: List[AgentAction] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        serialized = [act.to_dict() for act in self.actions]
        return {
            "round_num": self.round_num,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "simulated_hour": self.simulated_hour,
            "twitter_actions": self.twitter_actions,
            "reddit_actions": self.reddit_actions,
            "active_agents": self.active_agents,
            "actions_count": len(self.actions),
            "actions": serialized,
        }


@dataclass
class SimulationRunState:
    """Live run-time state of a simulation (updated by the monitor thread)."""
    simulation_id: str
    runner_status: RunnerStatus = RunnerStatus.IDLE

    # Progress counters.
    current_round: int = 0
    total_rounds: int = 0
    simulated_hours: int = 0
    total_simulation_hours: int = 0

    # Per-platform flags and counters.
    twitter_running: bool = False
    reddit_running: bool = False
    twitter_actions_count: int = 0
    reddit_actions_count: int = 0

    # Per-round summaries.
    rounds: List[RoundSummary] = field(default_factory=list)

    # Most recent actions, newest first (live frontend display).
    recent_actions: List[AgentAction] = field(default_factory=list)
    max_recent_actions: int = 50

    # Timestamps.
    started_at: Optional[str] = None
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None

    # Last error, if any.
    error: Optional[str] = None

    # PID of the child process (used to stop the simulation).
    process_pid: Optional[int] = None

    def add_action(self, action: AgentAction):
        """Prepend an action to the recent list and bump the platform counter."""
        # Newest first, capped at max_recent_actions entries.
        self.recent_actions = [action] + self.recent_actions[: self.max_recent_actions - 1]

        if action.platform == "twitter":
            self.twitter_actions_count += 1
        else:
            # Anything not tagged "twitter" is attributed to reddit.
            self.reddit_actions_count += 1

        self.updated_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Summary without the action list."""
        percent = round(self.current_round / max(self.total_rounds, 1) * 100, 1)
        return {
            "simulation_id": self.simulation_id,
            "runner_status": self.runner_status.value,
            "current_round": self.current_round,
            "total_rounds": self.total_rounds,
            "simulated_hours": self.simulated_hours,
            "total_simulation_hours": self.total_simulation_hours,
            "progress_percent": percent,
            "twitter_running": self.twitter_running,
            "reddit_running": self.reddit_running,
            "twitter_actions_count": self.twitter_actions_count,
            "reddit_actions_count": self.reddit_actions_count,
            "total_actions_count": self.twitter_actions_count + self.reddit_actions_count,
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "completed_at": self.completed_at,
            "error": self.error,
            "process_pid": self.process_pid,
        }

    def to_detail_dict(self) -> Dict[str, Any]:
        """Summary plus the recent actions."""
        detail = self.to_dict()
        detail["recent_actions"] = [act.to_dict() for act in self.recent_actions]
        detail["rounds_count"] = len(self.rounds)
        return detail
    # Directory where per-simulation run state is stored
    # (same root as SimulationManager's simulation directories).
    RUN_STATE_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )

    # Class-level registries shared by every caller in this process.
    # NOTE(review): mutable class attributes make this class a process-wide
    # singleton — confirm no per-instance isolation is ever needed.
    _run_states: Dict[str, SimulationRunState] = {}
    _processes: Dict[str, subprocess.Popen] = {}
    _action_queues: Dict[str, Queue] = {}
    _monitor_threads: Dict[str, threading.Thread] = {}

    @classmethod
    def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
        """Return the run state, preferring the in-memory cache over disk."""
        if simulation_id in cls._run_states:
            return cls._run_states[simulation_id]

        # Fall back to the persisted file (e.g. after a backend restart).
        state = cls._load_run_state(simulation_id)
        if state:
            cls._run_states[simulation_id] = state
        return state

    @classmethod
    def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
        """Rebuild a SimulationRunState from <sim_dir>/run_state.json.

        Returns None when the file is absent or unreadable; a corrupt file is
        logged and treated as "no state" rather than raising.
        """
        state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
        if not os.path.exists(state_file):
            return None

        try:
            with open(state_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            state = SimulationRunState(
                simulation_id=simulation_id,
                runner_status=RunnerStatus(data.get("runner_status", "idle")),
                current_round=data.get("current_round", 0),
                total_rounds=data.get("total_rounds", 0),
                simulated_hours=data.get("simulated_hours", 0),
                total_simulation_hours=data.get("total_simulation_hours", 0),
                twitter_running=data.get("twitter_running", False),
                reddit_running=data.get("reddit_running", False),
                twitter_actions_count=data.get("twitter_actions_count", 0),
                reddit_actions_count=data.get("reddit_actions_count", 0),
                started_at=data.get("started_at"),
                updated_at=data.get("updated_at", datetime.now().isoformat()),
                completed_at=data.get("completed_at"),
                error=data.get("error"),
                process_pid=data.get("process_pid"),
            )

            # Restore the recent-action list persisted by to_detail_dict().
            actions_data = data.get("recent_actions", [])
            for a in actions_data:
                state.recent_actions.append(AgentAction(
                    round_num=a.get("round_num", 0),
                    timestamp=a.get("timestamp", ""),
                    platform=a.get("platform", ""),
                    agent_id=a.get("agent_id", 0),
                    agent_name=a.get("agent_name", ""),
                    action_type=a.get("action_type", ""),
                    action_args=a.get("action_args", {}),
                    result=a.get("result"),
                    success=a.get("success", True),
                ))

            return state
        except Exception as e:
            logger.error(f"加载运行状态失败: {str(e)}")
            return None

    @classmethod
    def _save_run_state(cls, state: SimulationRunState):
        """Persist the run state (including recent actions) and refresh the cache."""
        sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
        os.makedirs(sim_dir, exist_ok=True)
        state_file = os.path.join(sim_dir, "run_state.json")

        # to_detail_dict() includes recent_actions so they survive a restart.
        data = state.to_detail_dict()

        with open(state_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        cls._run_states[state.simulation_id] = state
cls._save_run_state(state) + + # 确定运行哪个脚本 + if platform == "twitter": + script_name = "run_twitter_simulation.py" + state.twitter_running = True + elif platform == "reddit": + script_name = "run_reddit_simulation.py" + state.reddit_running = True + else: + script_name = "run_parallel_simulation.py" + state.twitter_running = True + state.reddit_running = True + + script_path = os.path.join(sim_dir, script_name) + + if not os.path.exists(script_path): + raise ValueError(f"脚本不存在: {script_path}") + + # 创建动作队列 + action_queue = Queue() + cls._action_queues[simulation_id] = action_queue + + # 启动模拟进程 + try: + # 构建运行命令 + cmd = [ + sys.executable, # Python解释器 + script_path, + "--config", "simulation_config.json", + "--action-log", "actions.jsonl", # 动作日志文件 + ] + + # 设置工作目录为模拟目录 + process = subprocess.Popen( + cmd, + cwd=sim_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + state.process_pid = process.pid + state.runner_status = RunnerStatus.RUNNING + cls._processes[simulation_id] = process + cls._save_run_state(state) + + # 启动监控线程 + monitor_thread = threading.Thread( + target=cls._monitor_simulation, + args=(simulation_id,), + daemon=True + ) + monitor_thread.start() + cls._monitor_threads[simulation_id] = monitor_thread + + logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}") + + except Exception as e: + state.runner_status = RunnerStatus.FAILED + state.error = str(e) + cls._save_run_state(state) + raise + + return state + + @classmethod + def _monitor_simulation(cls, simulation_id: str): + """监控模拟进程,解析动作日志""" + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + actions_log = os.path.join(sim_dir, "actions.jsonl") + + process = cls._processes.get(simulation_id) + state = cls.get_run_state(simulation_id) + + if not process or not state: + return + + last_position = 0 + + try: + while process.poll() is None: # 进程仍在运行 + # 读取动作日志 + if os.path.exists(actions_log): + with open(actions_log, 'r', 
encoding='utf-8') as f: + f.seek(last_position) + for line in f: + line = line.strip() + if line: + try: + action_data = json.loads(line) + action = AgentAction( + round_num=action_data.get("round", 0), + timestamp=action_data.get("timestamp", datetime.now().isoformat()), + platform=action_data.get("platform", "unknown"), + agent_id=action_data.get("agent_id", 0), + agent_name=action_data.get("agent_name", ""), + action_type=action_data.get("action_type", ""), + action_args=action_data.get("action_args", {}), + result=action_data.get("result"), + success=action_data.get("success", True), + ) + state.add_action(action) + + # 更新轮次 + if action.round_num > state.current_round: + state.current_round = action.round_num + + except json.JSONDecodeError: + pass + last_position = f.tell() + + # 定期保存状态 + cls._save_run_state(state) + time.sleep(1) # 每秒检查一次 + + # 进程结束 + exit_code = process.returncode + + if exit_code == 0: + state.runner_status = RunnerStatus.COMPLETED + state.completed_at = datetime.now().isoformat() + logger.info(f"模拟完成: {simulation_id}") + else: + state.runner_status = RunnerStatus.FAILED + stderr = process.stderr.read() if process.stderr else "" + state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}" + logger.error(f"模拟失败: {simulation_id}, error={state.error}") + + state.twitter_running = False + state.reddit_running = False + cls._save_run_state(state) + + except Exception as e: + logger.error(f"监控线程异常: {simulation_id}, error={str(e)}") + state.runner_status = RunnerStatus.FAILED + state.error = str(e) + cls._save_run_state(state) + + finally: + # 清理 + cls._processes.pop(simulation_id, None) + cls._action_queues.pop(simulation_id, None) + + @classmethod + def stop_simulation(cls, simulation_id: str) -> SimulationRunState: + """停止模拟""" + state = cls.get_run_state(simulation_id) + if not state: + raise ValueError(f"模拟不存在: {simulation_id}") + + if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]: + raise ValueError(f"模拟未在运行: 
{simulation_id}, status={state.runner_status}") + + state.runner_status = RunnerStatus.STOPPING + cls._save_run_state(state) + + # 终止进程 + process = cls._processes.get(simulation_id) + if process: + process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + process.kill() + + state.runner_status = RunnerStatus.STOPPED + state.twitter_running = False + state.reddit_running = False + state.completed_at = datetime.now().isoformat() + cls._save_run_state(state) + + logger.info(f"模拟已停止: {simulation_id}") + return state + + @classmethod + def get_actions( + cls, + simulation_id: str, + limit: int = 100, + offset: int = 0, + platform: Optional[str] = None, + agent_id: Optional[int] = None, + round_num: Optional[int] = None + ) -> List[AgentAction]: + """ + 获取动作历史 + + Args: + simulation_id: 模拟ID + limit: 返回数量限制 + offset: 偏移量 + platform: 过滤平台 + agent_id: 过滤Agent + round_num: 过滤轮次 + + Returns: + 动作列表 + """ + sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) + actions_log = os.path.join(sim_dir, "actions.jsonl") + + if not os.path.exists(actions_log): + return [] + + actions = [] + + with open(actions_log, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + try: + data = json.loads(line) + + # 过滤 + if platform and data.get("platform") != platform: + continue + if agent_id is not None and data.get("agent_id") != agent_id: + continue + if round_num is not None and data.get("round") != round_num: + continue + + actions.append(AgentAction( + round_num=data.get("round", 0), + timestamp=data.get("timestamp", ""), + platform=data.get("platform", ""), + agent_id=data.get("agent_id", 0), + agent_name=data.get("agent_name", ""), + action_type=data.get("action_type", ""), + action_args=data.get("action_args", {}), + result=data.get("result"), + success=data.get("success", True), + )) + + except json.JSONDecodeError: + continue + + # 按时间倒序排列 + actions.reverse() + + # 分页 + return actions[offset:offset + 
limit] + + @classmethod + def get_timeline( + cls, + simulation_id: str, + start_round: int = 0, + end_round: Optional[int] = None + ) -> List[Dict[str, Any]]: + """ + 获取模拟时间线(按轮次汇总) + + Args: + simulation_id: 模拟ID + start_round: 起始轮次 + end_round: 结束轮次 + + Returns: + 每轮的汇总信息 + """ + actions = cls.get_actions(simulation_id, limit=10000) + + # 按轮次分组 + rounds: Dict[int, Dict[str, Any]] = {} + + for action in actions: + round_num = action.round_num + + if round_num < start_round: + continue + if end_round is not None and round_num > end_round: + continue + + if round_num not in rounds: + rounds[round_num] = { + "round_num": round_num, + "twitter_actions": 0, + "reddit_actions": 0, + "active_agents": set(), + "action_types": {}, + "first_action_time": action.timestamp, + "last_action_time": action.timestamp, + } + + r = rounds[round_num] + + if action.platform == "twitter": + r["twitter_actions"] += 1 + else: + r["reddit_actions"] += 1 + + r["active_agents"].add(action.agent_id) + r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1 + r["last_action_time"] = action.timestamp + + # 转换为列表 + result = [] + for round_num in sorted(rounds.keys()): + r = rounds[round_num] + result.append({ + "round_num": round_num, + "twitter_actions": r["twitter_actions"], + "reddit_actions": r["reddit_actions"], + "total_actions": r["twitter_actions"] + r["reddit_actions"], + "active_agents_count": len(r["active_agents"]), + "active_agents": list(r["active_agents"]), + "action_types": r["action_types"], + "first_action_time": r["first_action_time"], + "last_action_time": r["last_action_time"], + }) + + return result + + @classmethod + def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: + """ + 获取每个Agent的统计信息 + + Returns: + Agent统计列表 + """ + actions = cls.get_actions(simulation_id, limit=10000) + + agent_stats: Dict[int, Dict[str, Any]] = {} + + for action in actions: + agent_id = action.agent_id + + if agent_id not in agent_stats: + 
agent_stats[agent_id] = { + "agent_id": agent_id, + "agent_name": action.agent_name, + "total_actions": 0, + "twitter_actions": 0, + "reddit_actions": 0, + "action_types": {}, + "first_action_time": action.timestamp, + "last_action_time": action.timestamp, + } + + stats = agent_stats[agent_id] + stats["total_actions"] += 1 + + if action.platform == "twitter": + stats["twitter_actions"] += 1 + else: + stats["reddit_actions"] += 1 + + stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1 + stats["last_action_time"] = action.timestamp + + # 按总动作数排序 + result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True) + + return result + diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py new file mode 100644 index 0000000..b2a3a3b --- /dev/null +++ b/backend/app/services/zep_entity_reader.py @@ -0,0 +1,386 @@ +""" +Zep实体读取与过滤服务 +从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 +""" + +from typing import Dict, Any, List, Optional, Set +from dataclasses import dataclass, field + +from zep_cloud.client import Zep + +from ..config import Config +from ..utils.logger import get_logger + +logger = get_logger('mirofish.zep_entity_reader') + + +@dataclass +class EntityNode: + """实体节点数据结构""" + uuid: str + name: str + labels: List[str] + summary: str + attributes: Dict[str, Any] + # 相关的边信息 + related_edges: List[Dict[str, Any]] = field(default_factory=list) + # 相关的其他节点信息 + related_nodes: List[Dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "uuid": self.uuid, + "name": self.name, + "labels": self.labels, + "summary": self.summary, + "attributes": self.attributes, + "related_edges": self.related_edges, + "related_nodes": self.related_nodes, + } + + def get_entity_type(self) -> Optional[str]: + """获取实体类型(排除默认的Entity标签)""" + for label in self.labels: + if label not in ["Entity", "Node"]: + return label + return None + + +@dataclass +class 
FilteredEntities: + """过滤后的实体集合""" + entities: List[EntityNode] + entity_types: Set[str] + total_count: int + filtered_count: int + + def to_dict(self) -> Dict[str, Any]: + return { + "entities": [e.to_dict() for e in self.entities], + "entity_types": list(self.entity_types), + "total_count": self.total_count, + "filtered_count": self.filtered_count, + } + + +class ZepEntityReader: + """ + Zep实体读取与过滤服务 + + 主要功能: + 1. 从Zep图谱读取所有节点 + 2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点) + 3. 获取每个实体的相关边和关联节点信息 + """ + + def __init__(self, api_key: Optional[str] = None): + self.api_key = api_key or Config.ZEP_API_KEY + if not self.api_key: + raise ValueError("ZEP_API_KEY 未配置") + + self.client = Zep(api_key=self.api_key) + + def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: + """ + 获取图谱的所有节点 + + Args: + graph_id: 图谱ID + + Returns: + 节点列表 + """ + logger.info(f"获取图谱 {graph_id} 的所有节点...") + + nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id) + + nodes_data = [] + for node in nodes: + nodes_data.append({ + "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + "name": node.name or "", + "labels": node.labels or [], + "summary": node.summary or "", + "attributes": node.attributes or {}, + }) + + logger.info(f"共获取 {len(nodes_data)} 个节点") + return nodes_data + + def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: + """ + 获取图谱的所有边 + + Args: + graph_id: 图谱ID + + Returns: + 边列表 + """ + logger.info(f"获取图谱 {graph_id} 的所有边...") + + edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id) + + edges_data = [] + for edge in edges: + edges_data.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": edge.name or "", + "fact": edge.fact or "", + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "attributes": edge.attributes or {}, + }) + + logger.info(f"共获取 {len(edges_data)} 条边") + return edges_data + + def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: 
+ """ + 获取指定节点的所有相关边 + + Args: + node_uuid: 节点UUID + + Returns: + 边列表 + """ + try: + edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid) + + edges_data = [] + for edge in edges: + edges_data.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": edge.name or "", + "fact": edge.fact or "", + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "attributes": edge.attributes or {}, + }) + + return edges_data + except Exception as e: + logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") + return [] + + def filter_defined_entities( + self, + graph_id: str, + defined_entity_types: Optional[List[str]] = None, + enrich_with_edges: bool = True + ) -> FilteredEntities: + """ + 筛选出符合预定义实体类型的节点 + + 筛选逻辑: + - 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过 + - 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留 + + Args: + graph_id: 图谱ID + defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型) + enrich_with_edges: 是否获取每个实体的相关边信息 + + Returns: + FilteredEntities: 过滤后的实体集合 + """ + logger.info(f"开始筛选图谱 {graph_id} 的实体...") + + # 获取所有节点 + all_nodes = self.get_all_nodes(graph_id) + total_count = len(all_nodes) + + # 获取所有边(用于后续关联查找) + all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] + + # 构建节点UUID到节点数据的映射 + node_map = {n["uuid"]: n for n in all_nodes} + + # 筛选符合条件的实体 + filtered_entities = [] + entity_types_found = set() + + for node in all_nodes: + labels = node.get("labels", []) + + # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 + custom_labels = [l for l in labels if l not in ["Entity", "Node"]] + + if not custom_labels: + # 只有默认标签,跳过 + continue + + # 如果指定了预定义类型,检查是否匹配 + if defined_entity_types: + matching_labels = [l for l in custom_labels if l in defined_entity_types] + if not matching_labels: + continue + entity_type = matching_labels[0] + else: + entity_type = custom_labels[0] + + entity_types_found.add(entity_type) + + # 创建实体节点对象 + entity = EntityNode( + uuid=node["uuid"], + name=node["name"], 
+ labels=labels, + summary=node["summary"], + attributes=node["attributes"], + ) + + # 获取相关边和节点 + if enrich_with_edges: + related_edges = [] + related_node_uuids = set() + + for edge in all_edges: + if edge["source_node_uuid"] == node["uuid"]: + related_edges.append({ + "direction": "outgoing", + "edge_name": edge["name"], + "fact": edge["fact"], + "target_node_uuid": edge["target_node_uuid"], + }) + related_node_uuids.add(edge["target_node_uuid"]) + elif edge["target_node_uuid"] == node["uuid"]: + related_edges.append({ + "direction": "incoming", + "edge_name": edge["name"], + "fact": edge["fact"], + "source_node_uuid": edge["source_node_uuid"], + }) + related_node_uuids.add(edge["source_node_uuid"]) + + entity.related_edges = related_edges + + # 获取关联节点的基本信息 + related_nodes = [] + for related_uuid in related_node_uuids: + if related_uuid in node_map: + related_node = node_map[related_uuid] + related_nodes.append({ + "uuid": related_node["uuid"], + "name": related_node["name"], + "labels": related_node["labels"], + "summary": related_node.get("summary", ""), + }) + + entity.related_nodes = related_nodes + + filtered_entities.append(entity) + + logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, " + f"实体类型: {entity_types_found}") + + return FilteredEntities( + entities=filtered_entities, + entity_types=entity_types_found, + total_count=total_count, + filtered_count=len(filtered_entities), + ) + + def get_entity_with_context( + self, + graph_id: str, + entity_uuid: str + ) -> Optional[EntityNode]: + """ + 获取单个实体及其完整上下文(边和关联节点) + + Args: + graph_id: 图谱ID + entity_uuid: 实体UUID + + Returns: + EntityNode或None + """ + try: + # 获取节点 + node = self.client.graph.node.get(uuid_=entity_uuid) + + if not node: + return None + + # 获取节点的边 + edges = self.get_node_edges(entity_uuid) + + # 获取所有节点用于关联查找 + all_nodes = self.get_all_nodes(graph_id) + node_map = {n["uuid"]: n for n in all_nodes} + + # 处理相关边和节点 + related_edges = [] + related_node_uuids = set() + + for edge 
in edges: + if edge["source_node_uuid"] == entity_uuid: + related_edges.append({ + "direction": "outgoing", + "edge_name": edge["name"], + "fact": edge["fact"], + "target_node_uuid": edge["target_node_uuid"], + }) + related_node_uuids.add(edge["target_node_uuid"]) + else: + related_edges.append({ + "direction": "incoming", + "edge_name": edge["name"], + "fact": edge["fact"], + "source_node_uuid": edge["source_node_uuid"], + }) + related_node_uuids.add(edge["source_node_uuid"]) + + # 获取关联节点信息 + related_nodes = [] + for related_uuid in related_node_uuids: + if related_uuid in node_map: + related_node = node_map[related_uuid] + related_nodes.append({ + "uuid": related_node["uuid"], + "name": related_node["name"], + "labels": related_node["labels"], + "summary": related_node.get("summary", ""), + }) + + return EntityNode( + uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + name=node.name or "", + labels=node.labels or [], + summary=node.summary or "", + attributes=node.attributes or {}, + related_edges=related_edges, + related_nodes=related_nodes, + ) + + except Exception as e: + logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}") + return None + + def get_entities_by_type( + self, + graph_id: str, + entity_type: str, + enrich_with_edges: bool = True + ) -> List[EntityNode]: + """ + 获取指定类型的所有实体 + + Args: + graph_id: 图谱ID + entity_type: 实体类型(如 "Student", "PublicFigure" 等) + enrich_with_edges: 是否获取相关边信息 + + Returns: + 实体列表 + """ + result = self.filter_defined_entities( + graph_id=graph_id, + defined_entity_types=[entity_type], + enrich_with_edges=enrich_with_edges + ) + return result.entities + + diff --git a/backend/app/utils/retry.py b/backend/app/utils/retry.py new file mode 100644 index 0000000..819b1cf --- /dev/null +++ b/backend/app/utils/retry.py @@ -0,0 +1,238 @@ +""" +API调用重试机制 +用于处理LLM等外部API调用的重试逻辑 +""" + +import time +import random +import functools +from typing import Callable, Any, Optional, Type, Tuple +from ..utils.logger import get_logger + 
+logger = get_logger('mirofish.retry') + + +def retry_with_backoff( + max_retries: int = 3, + initial_delay: float = 1.0, + max_delay: float = 30.0, + backoff_factor: float = 2.0, + jitter: bool = True, + exceptions: Tuple[Type[Exception], ...] = (Exception,), + on_retry: Optional[Callable[[Exception, int], None]] = None +): + """ + 带指数退避的重试装饰器 + + Args: + max_retries: 最大重试次数 + initial_delay: 初始延迟(秒) + max_delay: 最大延迟(秒) + backoff_factor: 退避因子 + jitter: 是否添加随机抖动 + exceptions: 需要重试的异常类型 + on_retry: 重试时的回调函数 (exception, retry_count) + + Usage: + @retry_with_backoff(max_retries=3) + def call_llm_api(): + ... + """ + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Any: + last_exception = None + delay = initial_delay + + for attempt in range(max_retries + 1): + try: + return func(*args, **kwargs) + + except exceptions as e: + last_exception = e + + if attempt == max_retries: + logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}") + raise + + # 计算延迟 + current_delay = min(delay, max_delay) + if jitter: + current_delay = current_delay * (0.5 + random.random()) + + logger.warning( + f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, " + f"{current_delay:.1f}秒后重试..." + ) + + if on_retry: + on_retry(e, attempt + 1) + + time.sleep(current_delay) + delay *= backoff_factor + + raise last_exception + + return wrapper + return decorator + + +def retry_with_backoff_async( + max_retries: int = 3, + initial_delay: float = 1.0, + max_delay: float = 30.0, + backoff_factor: float = 2.0, + jitter: bool = True, + exceptions: Tuple[Type[Exception], ...] 
= (Exception,), + on_retry: Optional[Callable[[Exception, int], None]] = None +): + """ + 异步版本的重试装饰器 + """ + import asyncio + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + async def wrapper(*args, **kwargs) -> Any: + last_exception = None + delay = initial_delay + + for attempt in range(max_retries + 1): + try: + return await func(*args, **kwargs) + + except exceptions as e: + last_exception = e + + if attempt == max_retries: + logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}") + raise + + current_delay = min(delay, max_delay) + if jitter: + current_delay = current_delay * (0.5 + random.random()) + + logger.warning( + f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, " + f"{current_delay:.1f}秒后重试..." + ) + + if on_retry: + on_retry(e, attempt + 1) + + await asyncio.sleep(current_delay) + delay *= backoff_factor + + raise last_exception + + return wrapper + return decorator + + +class RetryableAPIClient: + """ + 可重试的API客户端封装 + """ + + def __init__( + self, + max_retries: int = 3, + initial_delay: float = 1.0, + max_delay: float = 30.0, + backoff_factor: float = 2.0 + ): + self.max_retries = max_retries + self.initial_delay = initial_delay + self.max_delay = max_delay + self.backoff_factor = backoff_factor + + def call_with_retry( + self, + func: Callable, + *args, + exceptions: Tuple[Type[Exception], ...] 
= (Exception,), + **kwargs + ) -> Any: + """ + 执行函数调用并在失败时重试 + + Args: + func: 要调用的函数 + *args: 函数参数 + exceptions: 需要重试的异常类型 + **kwargs: 函数关键字参数 + + Returns: + 函数返回值 + """ + last_exception = None + delay = self.initial_delay + + for attempt in range(self.max_retries + 1): + try: + return func(*args, **kwargs) + + except exceptions as e: + last_exception = e + + if attempt == self.max_retries: + logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(e)}") + raise + + current_delay = min(delay, self.max_delay) + current_delay = current_delay * (0.5 + random.random()) + + logger.warning( + f"API调用第 {attempt + 1} 次尝试失败: {str(e)}, " + f"{current_delay:.1f}秒后重试..." + ) + + time.sleep(current_delay) + delay *= self.backoff_factor + + raise last_exception + + def call_batch_with_retry( + self, + items: list, + process_func: Callable, + exceptions: Tuple[Type[Exception], ...] = (Exception,), + continue_on_failure: bool = True + ) -> Tuple[list, list]: + """ + 批量调用并对每个失败项单独重试 + + Args: + items: 要处理的项目列表 + process_func: 处理函数,接收单个item作为参数 + exceptions: 需要重试的异常类型 + continue_on_failure: 单项失败后是否继续处理其他项 + + Returns: + (成功结果列表, 失败项列表) + """ + results = [] + failures = [] + + for idx, item in enumerate(items): + try: + result = self.call_with_retry( + process_func, + item, + exceptions=exceptions + ) + results.append(result) + + except Exception as e: + logger.error(f"处理第 {idx + 1} 项失败: {str(e)}") + failures.append({ + "index": idx, + "item": item, + "error": str(e) + }) + + if not continue_on_failure: + raise + + return results, failures + diff --git a/backend/requirements.txt b/backend/requirements.txt index 0824a74..1a0d82d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -20,3 +20,7 @@ pydantic>=2.0.0 # 文件处理 werkzeug>=3.0.0 +# OASIS社交媒体模拟框架 +oasis-ai>=0.1.0 +camel-ai>=0.2.0 + diff --git a/backend/scripts/action_logger.py b/backend/scripts/action_logger.py new file mode 100644 index 0000000..9ecfd1c --- /dev/null +++ b/backend/scripts/action_logger.py @@ 
class ActionLogger:
    """Append-only JSONL logger for per-agent actions in an OASIS run.

    Every method appends exactly one JSON object per line to *log_path*,
    so the backend can tail the file to monitor simulation progress.
    """

    def __init__(self, log_path: str):
        """
        Args:
            log_path: path of the log file (.jsonl format)
        """
        self.log_path = log_path
        self._ensure_dir()

    def _ensure_dir(self):
        """Create the log file's parent directory if it does not exist."""
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    def _write(self, entry: Dict[str, Any]):
        """Serialize *entry* as one JSON line and append it to the log.

        Centralizes the open/dump/newline sequence that was previously
        repeated in every log_* method.
        """
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    def log_action(
        self,
        round_num: int,
        platform: str,
        agent_id: int,
        agent_name: str,
        action_type: str,
        action_args: Optional[Dict[str, Any]] = None,
        result: Optional[str] = None,
        success: bool = True
    ):
        """Record a single agent action.

        Args:
            round_num: simulation round number
            platform: platform name (twitter/reddit)
            agent_id: acting agent's id
            agent_name: acting agent's display name
            action_type: action type label
            action_args: action parameters (defaults to {})
            result: optional execution result
            success: whether the action succeeded
        """
        self._write({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "agent_id": agent_id,
            "agent_name": agent_name,
            "action_type": action_type,
            "action_args": action_args or {},
            "result": result,
            "success": success,
        })

    def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
        """Record the start of a simulation round."""
        self._write({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_start",
            "simulated_hour": simulated_hour,
        })

    def log_round_end(self, round_num: int, actions_count: int, platform: str):
        """Record the end of a simulation round and its action count."""
        self._write({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "round_end",
            "actions_count": actions_count,
        })

    def log_simulation_start(self, platform: str, config: Dict[str, Any]):
        """Record the start of a simulation.

        Fix: total_rounds is now derived from ``minutes_per_round`` like
        the runner scripts do, instead of hard-coding two rounds per hour
        (the old ``hours * 2`` was wrong for any other round length).
        """
        time_config = config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        # `or 30` guards against a configured 0 (would divide by zero).
        minutes_per_round = time_config.get("minutes_per_round", 30) or 30
        self._write({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "simulation_start",
            "total_rounds": (total_hours * 60) // minutes_per_round,
            "agents_count": len(config.get("agent_configs", [])),
        })

    def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
        """Record the end of a simulation with its totals."""
        self._write({
            "timestamp": datetime.now().isoformat(),
            "platform": platform,
            "event_type": "simulation_end",
            "total_rounds": total_rounds,
            "total_actions": total_actions,
        })


# Optional process-wide singleton, re-pointed whenever a path is supplied.
_global_logger: Optional[ActionLogger] = None


def get_logger(log_path: Optional[str] = None) -> ActionLogger:
    """Return the global ActionLogger.

    Passing *log_path* replaces the singleton with a logger on that path;
    with no path and no existing singleton, a default "actions.jsonl"
    logger in the current directory is created.
    """
    global _global_logger

    if log_path:
        _global_logger = ActionLogger(log_path)

    if _global_logger is None:
        _global_logger = ActionLogger("actions.jsonl")

    return _global_logger
# Actions available to Twitter agents.
TWITTER_ACTIONS = [
    ActionType.CREATE_POST,
    ActionType.LIKE_POST,
    ActionType.REPOST,
    ActionType.FOLLOW,
    ActionType.DO_NOTHING,
    ActionType.QUOTE_POST,
]

# Actions available to Reddit agents.
REDDIT_ACTIONS = [
    ActionType.LIKE_POST,
    ActionType.DISLIKE_POST,
    ActionType.CREATE_POST,
    ActionType.CREATE_COMMENT,
    ActionType.LIKE_COMMENT,
    ActionType.DISLIKE_COMMENT,
    ActionType.SEARCH_POSTS,
    ActionType.SEARCH_USER,
    ActionType.TREND,
    ActionType.REFRESH,
    ActionType.DO_NOTHING,
    ActionType.FOLLOW,
    ActionType.MUTE,
]


def load_config(config_path: str) -> Dict[str, Any]:
    """Read and parse the simulation config JSON."""
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def create_model(config: Dict[str, Any]):
    """Create the LLM via camel-ai's ModelFactory.

    Config keys:
      - llm_model: ModelFactory's model_type (default "gpt-4o-mini")
      - llm_base_url: if set, exported as OPENAI_API_BASE_URL, the env var
        camel-ai/OASIS reads for custom OpenAI-compatible endpoints;
        the API key itself comes from OPENAI_API_KEY in the environment.
    """
    llm_model = config.get("llm_model", "gpt-4o-mini")
    llm_base_url = config.get("llm_base_url", "")

    if llm_base_url:
        os.environ["OPENAI_API_BASE_URL"] = llm_base_url

    return ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=llm_model,
    )


def get_active_agents_for_round(
    env,
    config: Dict[str, Any],
    current_hour: int,
    round_num: int
) -> List:
    """Decide which agents act this round, based on time-of-day config.

    A target head-count is drawn from [agents_per_hour_min, max] and scaled
    by the peak/off-peak multiplier for *current_hour*; each agent whose
    active_hours cover the hour then passes an activity_level coin flip to
    become a candidate, and up to target_count candidates are sampled.

    Returns:
        List of (agent_id, agent) tuples; ids that cannot be resolved in
        the environment are silently skipped (best effort).
    """
    time_config = config.get("time_config", {})
    agent_configs = config.get("agent_configs", [])

    base_min = time_config.get("agents_per_hour_min", 5)
    base_max = time_config.get("agents_per_hour_max", 20)

    peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
    off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

    if current_hour in peak_hours:
        multiplier = time_config.get("peak_activity_multiplier", 1.5)
    elif current_hour in off_peak_hours:
        multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
    else:
        multiplier = 1.0

    target_count = int(random.uniform(base_min, base_max) * multiplier)

    candidates = []
    for cfg in agent_configs:
        agent_id = cfg.get("agent_id", 0)
        active_hours = cfg.get("active_hours", list(range(8, 23)))
        activity_level = cfg.get("activity_level", 0.5)

        if current_hour not in active_hours:
            continue

        if random.random() < activity_level:
            candidates.append(agent_id)

    selected_ids = random.sample(
        candidates,
        min(target_count, len(candidates))
    ) if candidates else []

    active_agents = []
    for agent_id in selected_ids:
        try:
            agent = env.agent_graph.get_agent(agent_id)
            active_agents.append((agent_id, agent))
        except Exception:
            pass  # unknown/stale agent id: skip rather than abort the round

    return active_agents


async def _run_platform(
    platform_key: str,
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Shared driver for one platform run ("twitter" or "reddit").

    run_twitter_simulation/run_reddit_simulation were near-identical
    copies; the real differences (profile file/format, graph factory,
    platform type, action set, display label) are resolved here once.
    This also fixes a bug in the old Twitter path, which overwrote
    earlier initial posts when one agent seeded several.
    """
    if platform_key == "twitter":
        label = "Twitter"
        profile_file = "twitter_profiles.csv"   # OASIS Twitter wants CSV
        db_file = "twitter_simulation.db"
        available_actions = TWITTER_ACTIONS
        generate_graph = generate_twitter_agent_graph
        platform_type = oasis.DefaultPlatformType.TWITTER
    else:
        label = "Reddit"
        profile_file = "reddit_profiles.json"   # OASIS Reddit wants JSON
        db_file = "reddit_simulation.db"
        available_actions = REDDIT_ACTIONS
        generate_graph = generate_reddit_agent_graph
        platform_type = oasis.DefaultPlatformType.REDDIT

    print(f"[{label}] 初始化...")

    model = create_model(config)

    profile_path = os.path.join(simulation_dir, profile_file)
    if not os.path.exists(profile_path):
        print(f"[{label}] 错误: Profile文件不存在: {profile_path}")
        return

    agent_graph = await generate_graph(
        profile_path=profile_path,
        model=model,
        available_actions=available_actions,
    )

    # Map agent_id -> display name for the action log.
    agent_names = {}
    for agent_id, agent in agent_graph.get_agents():
        agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    db_path = os.path.join(simulation_dir, db_file)
    if os.path.exists(db_path):
        os.remove(db_path)  # always start from a fresh database

    env = oasis.make(
        agent_graph=agent_graph,
        platform=platform_type,
        database_path=db_path,
    )

    await env.reset()
    print(f"[{label}] 环境已启动")

    if action_logger:
        action_logger.log_simulation_start(platform_key, config)

    total_actions = 0

    # Seed the platform with the configured initial posts.
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = env.agent_graph.get_agent(agent_id)
                action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # Accumulate into a list when one agent posts several
                # times (the old Twitter loop silently overwrote here).
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(action)
                else:
                    initial_actions[agent] = action

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        platform=platform_key,
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content[:100] + "..." if len(content) > 100 else content}
                    )
                total_actions += 1
            except Exception:
                pass  # best effort: a bad agent id must not abort the run

        if initial_actions:
            await env.step(initial_actions)
            print(f"[{label}] 已发布 {len(initial_actions)} 条初始帖子")

    # Main simulation loop.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    start_time = datetime.now()

    for round_num in range(total_rounds):
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            env, config, simulated_hour, round_num
        )

        if not active_agents:
            continue

        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour, platform_key)

        actions = {agent: LLMAction() for _, agent in active_agents}
        await env.step(actions)

        for agent_id, agent in active_agents:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    platform=platform_key,
                    agent_id=agent_id,
                    agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                    action_type="LLM_ACTION",
                    action_args={}
                )
            total_actions += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, len(active_agents), platform_key)

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            print(f"[{label}] Day {simulated_day}, {simulated_hour:02d}:00 "
                  f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    await env.close()

    if action_logger:
        action_logger.log_simulation_end(platform_key, total_rounds, total_actions)

    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"[{label}] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")


async def run_twitter_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Run the Twitter simulation (thin wrapper over _run_platform)."""
    await _run_platform("twitter", config, simulation_dir, action_logger)


async def run_reddit_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[ActionLogger] = None
):
    """Run the Reddit simulation (thin wrapper over _run_platform)."""
    await _run_platform("reddit", config, simulation_dir, action_logger)
class RedditSimulationRunner:
    """Drives a fully automated OASIS Reddit simulation from a JSON config."""

    # Actions the Reddit agents are allowed to take.
    AVAILABLE_ACTIONS = [
        ActionType.LIKE_POST,
        ActionType.DISLIKE_POST,
        ActionType.CREATE_POST,
        ActionType.CREATE_COMMENT,
        ActionType.LIKE_COMMENT,
        ActionType.DISLIKE_COMMENT,
        ActionType.SEARCH_POSTS,
        ActionType.SEARCH_USER,
        ActionType.TREND,
        ActionType.REFRESH,
        ActionType.DO_NOTHING,
        ActionType.FOLLOW,
        ActionType.MUTE,
    ]

    def __init__(self, config_path: str):
        """
        Args:
            config_path: path to the config file (simulation_config.json)
        """
        self.config_path = config_path
        self.config = self._load_config()
        # `or "."` keeps relative profile/db paths valid when a bare
        # filename is passed (matches run_parallel_simulation.py).
        self.simulation_dir = os.path.dirname(config_path) or "."

    def _load_config(self) -> Dict[str, Any]:
        """Read and parse the JSON config file."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Agent profile file path (OASIS Reddit expects JSON)."""
        return os.path.join(self.simulation_dir, "reddit_profiles.json")

    def _get_db_path(self) -> str:
        """SQLite database the simulation writes to."""
        return os.path.join(self.simulation_dir, "reddit_simulation.db")

    def _create_model(self):
        """Create the LLM via camel-ai's ModelFactory.

        A custom endpoint (llm_base_url) is exported as OPENAI_API_BASE_URL,
        the env var camel-ai/OASIS reads; the API key comes from
        OPENAI_API_KEY.  (The redundant function-local ``import os`` that
        shadowed the module-level import has been removed.)
        """
        llm_model = self.config.get("llm_model", "gpt-4o-mini")
        llm_base_url = self.config.get("llm_base_url", "")

        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url

        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """Pick the agents that act this round.

        A target head-count from [agents_per_hour_min, max] is scaled by
        the peak/off-peak multiplier; agents whose active_hours cover the
        hour then pass an activity_level coin flip to become candidates,
        and up to target_count candidates are sampled.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])

        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)

        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0

        target_count = int(random.uniform(base_min, base_max) * multiplier)

        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)

            if current_hour not in active_hours:
                continue

            if random.random() < activity_level:
                candidates.append(agent_id)

        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []

        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                pass  # unknown/stale agent id: skip, don't abort the round

        return active_agents

    async def run(self):
        """Run the whole simulation: setup, initial posts, main loop."""
        print("=" * 60)
        print("OASIS Reddit模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print("=" * 60)

        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        # `or 30` guards against a configured 0 (ZeroDivisionError below).
        minutes_per_round = time_config.get("minutes_per_round", 30) or 30
        total_rounds = (total_hours * 60) // minutes_per_round

        print(f"\n模拟参数:")
        print(f"  - 总模拟时长: {total_hours}小时")
        print(f"  - 每轮时间: {minutes_per_round}分钟")
        print(f"  - 总轮数: {total_rounds}")
        print(f"  - Agent数量: {len(self.config.get('agent_configs', []))}")

        print("\n初始化LLM模型...")
        model = self._create_model()

        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return

        agent_graph = await generate_reddit_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )

        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)  # always start from a fresh database
            print(f"已删除旧数据库: {db_path}")

        print("创建OASIS环境...")
        env = oasis.make(
            agent_graph=agent_graph,
            platform=oasis.DefaultPlatformType.REDDIT,
            database_path=db_path,
        )

        await env.reset()
        print("环境初始化完成\n")

        # Seed the platform with the configured initial posts.
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])

        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            initial_actions = {}
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = env.agent_graph.get_agent(agent_id)
                    action = ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    )
                    # One agent may seed several posts: collect into a list.
                    if agent in initial_actions:
                        if not isinstance(initial_actions[agent], list):
                            initial_actions[agent] = [initial_actions[agent]]
                        initial_actions[agent].append(action)
                    else:
                        initial_actions[agent] = action
                except Exception as e:
                    print(f"  警告: 无法为Agent {agent_id}创建初始帖子: {e}")

            if initial_actions:
                await env.step(initial_actions)
                print(f"  已发布 {len(initial_actions)} 条初始帖子")

        # Main simulation loop.
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            active_agents = self._get_active_agents_for_round(
                env, simulated_hour, round_num
            )

            if not active_agents:
                continue

            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }

            await env.step(actions)

            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f"  [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        await env.close()

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print(f"\n模拟完成!")
        print(f"  - 总耗时: {total_elapsed:.1f}秒")
        print(f"  - 数据库: {db_path}")
        print("=" * 60)


async def main():
    """CLI entry point: parse --config and run the simulation."""
    parser = argparse.ArgumentParser(description='OASIS Reddit模拟')
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )

    args = parser.parse_args()

    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    runner = RedditSimulationRunner(args.config)
    await runner.run()


if __name__ == "__main__":
    asyncio.run(main())
return json.load(f) + + def _get_profile_path(self) -> str: + """获取Profile文件路径(OASIS Twitter使用CSV格式)""" + return os.path.join(self.simulation_dir, "twitter_profiles.csv") + + def _get_db_path(self) -> str: + """获取数据库路径""" + return os.path.join(self.simulation_dir, "twitter_simulation.db") + + def _create_model(self): + """ + 创建LLM模型 + + OASIS使用camel-ai的ModelFactory,配置方式: + - 标准OpenAI: 只需设置 OPENAI_API_KEY 环境变量 + - 自定义API: 设置 OPENAI_API_KEY 和 OPENAI_API_BASE_URL 环境变量 + + 配置文件中的 llm_model 对应 model_type + """ + import os + + llm_model = self.config.get("llm_model", "gpt-4o-mini") + llm_base_url = self.config.get("llm_base_url", "") + + # 如果配置了base_url,设置环境变量(OASIS通过环境变量读取) + if llm_base_url: + os.environ["OPENAI_API_BASE_URL"] = llm_base_url + + return ModelFactory.create( + model_platform=ModelPlatformType.OPENAI, + model_type=llm_model, + ) + + def _get_active_agents_for_round( + self, + env, + current_hour: int, + round_num: int + ) -> List: + """ + 根据时间和配置决定本轮激活哪些Agent + + Args: + env: OASIS环境 + current_hour: 当前模拟小时(0-23) + round_num: 当前轮数 + + Returns: + 激活的Agent列表 + """ + time_config = self.config.get("time_config", {}) + agent_configs = self.config.get("agent_configs", []) + + # 基础激活数量 + base_min = time_config.get("agents_per_hour_min", 5) + base_max = time_config.get("agents_per_hour_max", 20) + + # 根据时段调整 + peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) + off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) + + if current_hour in peak_hours: + multiplier = time_config.get("peak_activity_multiplier", 1.5) + elif current_hour in off_peak_hours: + multiplier = time_config.get("off_peak_activity_multiplier", 0.3) + else: + multiplier = 1.0 + + target_count = int(random.uniform(base_min, base_max) * multiplier) + + # 根据每个Agent的配置计算激活概率 + candidates = [] + for cfg in agent_configs: + agent_id = cfg.get("agent_id", 0) + active_hours = cfg.get("active_hours", list(range(8, 23))) + activity_level = cfg.get("activity_level", 
0.5) + + # 检查是否在活跃时间 + if current_hour not in active_hours: + continue + + # 根据活跃度计算概率 + if random.random() < activity_level: + candidates.append(agent_id) + + # 随机选择 + selected_ids = random.sample( + candidates, + min(target_count, len(candidates)) + ) if candidates else [] + + # 转换为Agent对象 + active_agents = [] + for agent_id in selected_ids: + try: + agent = env.agent_graph.get_agent(agent_id) + active_agents.append((agent_id, agent)) + except Exception: + pass + + return active_agents + + async def run(self): + """运行Twitter模拟""" + print("=" * 60) + print("OASIS Twitter模拟") + print(f"配置文件: {self.config_path}") + print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") + print("=" * 60) + + # 加载时间配置 + time_config = self.config.get("time_config", {}) + total_hours = time_config.get("total_simulation_hours", 72) + minutes_per_round = time_config.get("minutes_per_round", 30) + + # 计算总轮数 + total_rounds = (total_hours * 60) // minutes_per_round + + print(f"\n模拟参数:") + print(f" - 总模拟时长: {total_hours}小时") + print(f" - 每轮时间: {minutes_per_round}分钟") + print(f" - 总轮数: {total_rounds}") + print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") + + # 创建模型 + print("\n初始化LLM模型...") + model = self._create_model() + + # 加载Agent图 + print("加载Agent Profile...") + profile_path = self._get_profile_path() + if not os.path.exists(profile_path): + print(f"错误: Profile文件不存在: {profile_path}") + return + + agent_graph = await generate_twitter_agent_graph( + profile_path=profile_path, + model=model, + available_actions=self.AVAILABLE_ACTIONS, + ) + + # 数据库路径 + db_path = self._get_db_path() + if os.path.exists(db_path): + os.remove(db_path) + print(f"已删除旧数据库: {db_path}") + + # 创建环境 + print("创建OASIS环境...") + env = oasis.make( + agent_graph=agent_graph, + platform=oasis.DefaultPlatformType.TWITTER, + database_path=db_path, + ) + + await env.reset() + print("环境初始化完成\n") + + # 执行初始事件 + event_config = self.config.get("event_config", {}) + initial_posts = 
event_config.get("initial_posts", []) + + if initial_posts: + print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") + initial_actions = {} + for post in initial_posts: + agent_id = post.get("poster_agent_id", 0) + content = post.get("content", "") + try: + agent = env.agent_graph.get_agent(agent_id) + initial_actions[agent] = ManualAction( + action_type=ActionType.CREATE_POST, + action_args={"content": content} + ) + except Exception as e: + print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") + + if initial_actions: + await env.step(initial_actions) + print(f" 已发布 {len(initial_actions)} 条初始帖子") + + # 主模拟循环 + print("\n开始模拟循环...") + start_time = datetime.now() + + for round_num in range(total_rounds): + # 计算当前模拟时间 + simulated_minutes = round_num * minutes_per_round + simulated_hour = (simulated_minutes // 60) % 24 + simulated_day = simulated_minutes // (60 * 24) + 1 + + # 获取本轮激活的Agent + active_agents = self._get_active_agents_for_round( + env, simulated_hour, round_num + ) + + if not active_agents: + continue + + # 构建动作 + actions = { + agent: LLMAction() + for _, agent in active_agents + } + + # 执行动作 + await env.step(actions) + + # 打印进度 + if (round_num + 1) % 10 == 0 or round_num == 0: + elapsed = (datetime.now() - start_time).total_seconds() + progress = (round_num + 1) / total_rounds * 100 + print(f" [Day {simulated_day}, {simulated_hour:02d}:00] " + f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) " + f"- {len(active_agents)} agents active " + f"- elapsed: {elapsed:.1f}s") + + # 关闭环境 + await env.close() + + total_elapsed = (datetime.now() - start_time).total_seconds() + print(f"\n模拟完成!") + print(f" - 总耗时: {total_elapsed:.1f}秒") + print(f" - 数据库: {db_path}") + print("=" * 60) + + +async def main(): + parser = argparse.ArgumentParser(description='OASIS Twitter模拟') + parser.add_argument( + '--config', + type=str, + required=True, + help='配置文件路径 (simulation_config.json)' + ) + + args = parser.parse_args() + + if not os.path.exists(args.config): + print(f"错误: 配置文件不存在: 
def test_profile_formats():
    """Validate generated profile files against the formats OASIS expects.

    Exercises the two serialization helpers of OasisProfileGenerator:
    1. Twitter profiles written as flat CSV rows.
    2. Reddit profiles written as detailed JSON objects.
    Prints the round-tripped data and a pass/fail note for the required
    fields of each format.
    """
    print("=" * 60)
    print("OASIS Profile格式测试")
    print("=" * 60)

    # Two representative in-memory profiles: a regular individual user and
    # an official/institutional account.
    test_profiles = [
        OasisAgentProfile(
            user_id=0,
            user_name="test_user_123",
            name="Test User",
            bio="A test user for validation",
            persona="Test User is an enthusiastic participant in social discussions.",
            karma=1500,
            friend_count=100,
            follower_count=200,
            statuses_count=500,
            age=25,
            gender="male",
            mbti="INTJ",
            country="China",
            profession="Student",
            interested_topics=["Technology", "Education"],
            source_entity_uuid="test-uuid-123",
            source_entity_type="Student",
        ),
        OasisAgentProfile(
            user_id=1,
            user_name="org_official_456",
            name="Official Organization",
            bio="Official account for Organization",
            persona="This is an official institutional account that communicates official positions.",
            karma=5000,
            friend_count=50,
            follower_count=10000,
            statuses_count=200,
            profession="Organization",
            interested_topics=["Public Policy", "Announcements"],
            source_entity_uuid="test-uuid-456",
            source_entity_type="University",
        ),
    ]

    # __new__ bypasses __init__ — presumably the constructor wires up
    # external dependencies not needed by the pure _save_* helpers;
    # TODO confirm against OasisProfileGenerator.__init__.
    generator = OasisProfileGenerator.__new__(OasisProfileGenerator)

    # Write into a throwaway directory so the test leaves no artifacts.
    with tempfile.TemporaryDirectory() as temp_dir:
        twitter_path = os.path.join(temp_dir, "twitter_profiles.csv")
        reddit_path = os.path.join(temp_dir, "reddit_profiles.json")

        # --- Twitter CSV format ---
        print("\n1. 测试Twitter Profile (CSV格式)")
        print("-" * 40)
        generator._save_twitter_csv(test_profiles, twitter_path)

        # Read the file back and inspect the header and first row.
        with open(twitter_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            rows = list(reader)

        print(f" 文件: {twitter_path}")
        print(f" 行数: {len(rows)}")
        print(f" 表头: {list(rows[0].keys())}")
        print(f"\n 示例数据 (第1行):")
        for key, value in rows[0].items():
            print(f" {key}: {value}")

        # Columns the OASIS Twitter loader requires.
        required_twitter_fields = ['user_id', 'user_name', 'name', 'bio',
                                   'friend_count', 'follower_count', 'statuses_count', 'created_at']
        missing = set(required_twitter_fields) - set(rows[0].keys())
        if missing:
            print(f"\n [错误] 缺少字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")

        # --- Reddit JSON format ---
        print("\n2. 测试Reddit Profile (JSON详细格式)")
        print("-" * 40)
        generator._save_reddit_json(test_profiles, reddit_path)

        # Read the file back and inspect the first entry.
        with open(reddit_path, 'r', encoding='utf-8') as f:
            reddit_data = json.load(f)

        print(f" 文件: {reddit_path}")
        print(f" 条目数: {len(reddit_data)}")
        print(f" 字段: {list(reddit_data[0].keys())}")
        print(f"\n 示例数据 (第1条):")
        print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4))

        # Keys the OASIS Reddit loader requires vs. merely understands.
        required_reddit_fields = ['realname', 'username', 'bio', 'persona']
        optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics']

        missing = set(required_reddit_fields) - set(reddit_data[0].keys())
        if missing:
            print(f"\n [错误] 缺少必需字段: {missing}")
        else:
            print(f"\n [通过] 所有必需字段都存在")

        present_optional = set(optional_reddit_fields) & set(reddit_data[0].keys())
        print(f" [信息] 可选字段: {present_optional}")

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)
def show_expected_formats():
    """Print reference examples of the profile formats OASIS expects."""
    divider = "=" * 60
    print("\n" + divider)
    print("OASIS 期望的Profile格式参考")
    print(divider)

    # --- Twitter: flat CSV, one row per agent ---
    print("\n1. Twitter Profile (CSV格式)")
    print("-" * 40)
    csv_sample = (
        "user_id,user_name,name,bio,friend_count,follower_count,"
        "statuses_count,created_at\n"
        "0,user0,User Zero,I am user zero with interests in technology.,"
        "100,150,500,2023-01-01\n"
        "1,user1,User One,Tech enthusiast and coffee lover.,"
        "200,250,1000,2023-01-02"
    )
    print(csv_sample)

    # --- Reddit: detailed JSON objects ---
    print("\n2. Reddit Profile (JSON详细格式)")
    print("-" * 40)
    json_sample = [{
        "realname": "James Miller",
        "username": "millerhospitality",
        "bio": "Passionate about hospitality & tourism.",
        "persona": "James is a seasoned professional in the Hospitality & Tourism industry...",
        "age": 40,
        "gender": "male",
        "mbti": "ESTJ",
        "country": "UK",
        "profession": "Hospitality & Tourism",
        "interested_topics": ["Economics", "Business"],
    }]
    print(json.dumps(json_sample, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    test_profile_formats()
    show_expected_formats()