Add function to retrieve agent names from configuration

- Introduced `get_agent_names_from_config` function to map agent IDs to their entity names from the simulation configuration, enhancing clarity in action representation.
- Updated simulation scripts to utilize this new function for fetching agent names, ensuring that real entity names are displayed instead of default identifiers.
- Improved handling of agent names by falling back to default names only if not specified in the configuration, maintaining consistency across simulations.
This commit is contained in:
666ghj 2025-12-04 19:19:16 +08:00
parent 88676e8207
commit 3c1d554152

View file

@ -176,6 +176,30 @@ ACTION_TYPE_MAP = {
} }
def get_agent_names_from_config(config: Dict[str, Any]) -> Dict[int, str]:
"""
simulation_config 中获取 agent_id -> entity_name 的映射
这样可以在 actions.jsonl 中显示真实的实体名称而不是 "Agent_0" 这样的代号
Args:
config: simulation_config.json 的内容
Returns:
agent_id -> entity_name 的映射字典
"""
agent_names = {}
agent_configs = config.get("agent_configs", [])
for agent_config in agent_configs:
agent_id = agent_config.get("agent_id")
entity_name = agent_config.get("entity_name", f"Agent_{agent_id}")
if agent_id is not None:
agent_names[agent_id] = entity_name
return agent_names
def fetch_new_actions_from_db( def fetch_new_actions_from_db(
db_path: str, db_path: str,
last_rowid: int, last_rowid: int,
@ -405,9 +429,11 @@ async def run_twitter_simulation(
available_actions=TWITTER_ACTIONS, available_actions=TWITTER_ACTIONS,
) )
# 获取Agent名称映射 # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X
agent_names = {} agent_names = get_agent_names_from_config(config)
# 如果配置中没有某个 agent则使用 OASIS 的默认名称
for agent_id, agent in agent_graph.get_agents(): for agent_id, agent in agent_graph.get_agents():
if agent_id not in agent_names:
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
db_path = os.path.join(simulation_dir, "twitter_simulation.db") db_path = os.path.join(simulation_dir, "twitter_simulation.db")
@ -550,9 +576,11 @@ async def run_reddit_simulation(
available_actions=REDDIT_ACTIONS, available_actions=REDDIT_ACTIONS,
) )
# 获取Agent名称映射 # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X
agent_names = {} agent_names = get_agent_names_from_config(config)
# 如果配置中没有某个 agent则使用 OASIS 的默认名称
for agent_id, agent in agent_graph.get_agents(): for agent_id, agent in agent_graph.get_agents():
if agent_id not in agent_names:
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
db_path = os.path.join(simulation_dir, "reddit_simulation.db") db_path = os.path.join(simulation_dir, "reddit_simulation.db")