Add function to retrieve agent names from configuration
- Introduced `get_agent_names_from_config` function to map agent IDs to their entity names from the simulation configuration, enhancing clarity in action representation. - Updated simulation scripts to utilize this new function for fetching agent names, ensuring that real entity names are displayed instead of default identifiers. - Improved handling of agent names by falling back to default names only if not specified in the configuration, maintaining consistency across simulations.
This commit is contained in:
parent
88676e8207
commit
3c1d554152
1 changed files with 34 additions and 6 deletions
|
|
@ -176,6 +176,30 @@ ACTION_TYPE_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_names_from_config(config: Dict[str, Any]) -> Dict[int, str]:
|
||||||
|
"""
|
||||||
|
从 simulation_config 中获取 agent_id -> entity_name 的映射
|
||||||
|
|
||||||
|
这样可以在 actions.jsonl 中显示真实的实体名称,而不是 "Agent_0" 这样的代号
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: simulation_config.json 的内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
agent_id -> entity_name 的映射字典
|
||||||
|
"""
|
||||||
|
agent_names = {}
|
||||||
|
agent_configs = config.get("agent_configs", [])
|
||||||
|
|
||||||
|
for agent_config in agent_configs:
|
||||||
|
agent_id = agent_config.get("agent_id")
|
||||||
|
entity_name = agent_config.get("entity_name", f"Agent_{agent_id}")
|
||||||
|
if agent_id is not None:
|
||||||
|
agent_names[agent_id] = entity_name
|
||||||
|
|
||||||
|
return agent_names
|
||||||
|
|
||||||
|
|
||||||
def fetch_new_actions_from_db(
|
def fetch_new_actions_from_db(
|
||||||
db_path: str,
|
db_path: str,
|
||||||
last_rowid: int,
|
last_rowid: int,
|
||||||
|
|
@ -405,10 +429,12 @@ async def run_twitter_simulation(
|
||||||
available_actions=TWITTER_ACTIONS,
|
available_actions=TWITTER_ACTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取Agent名称映射
|
# 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X)
|
||||||
agent_names = {}
|
agent_names = get_agent_names_from_config(config)
|
||||||
|
# 如果配置中没有某个 agent,则使用 OASIS 的默认名称
|
||||||
for agent_id, agent in agent_graph.get_agents():
|
for agent_id, agent in agent_graph.get_agents():
|
||||||
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
if agent_id not in agent_names:
|
||||||
|
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
||||||
|
|
||||||
db_path = os.path.join(simulation_dir, "twitter_simulation.db")
|
db_path = os.path.join(simulation_dir, "twitter_simulation.db")
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
|
|
@ -550,10 +576,12 @@ async def run_reddit_simulation(
|
||||||
available_actions=REDDIT_ACTIONS,
|
available_actions=REDDIT_ACTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取Agent名称映射
|
# 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X)
|
||||||
agent_names = {}
|
agent_names = get_agent_names_from_config(config)
|
||||||
|
# 如果配置中没有某个 agent,则使用 OASIS 的默认名称
|
||||||
for agent_id, agent in agent_graph.get_agents():
|
for agent_id, agent in agent_graph.get_agents():
|
||||||
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
if agent_id not in agent_names:
|
||||||
|
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
||||||
|
|
||||||
db_path = os.path.join(simulation_dir, "reddit_simulation.db")
|
db_path = os.path.join(simulation_dir, "reddit_simulation.db")
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue