Black Death Wrapper
The wrapper detects agents that are "dead" at reset by comparing the list of agents it expects (i.e., the list in `self.agents`) with the agents for which it receives observations from the environment's `reset` method. If an agent in `self.agents` does not have an entry in the returned observations, the wrapper considers that agent "dead" and assigns it a zero observation.
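As a standalone illustration of that detection rule (not the wrapper's actual code; the agent names and observation values here are made up):

```python
expected_agents = ["archer_0", "archer_1", "knight_0"]           # agents the wrapper expects
observations = {"archer_0": [0.1, 0.2], "knight_0": [0.3, 0.4]}  # reset() returned no entry for archer_1

# Agents that are expected but received no observation are treated as "dead"
dead_agents = [agent for agent in expected_agents if agent not in observations]
print(dead_agents)  # ['archer_1']
```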
Here's a more detailed breakdown of the process:
- Call the Environment's `reset` Method: The wrapper calls the `reset` method of the underlying environment (`self.env.reset(seed=seed, options=options)`), which returns a tuple of:
  - `observations`: a dictionary with agent identifiers as keys and their corresponding initial observations as values.
  - `infos`: a dictionary with agent identifiers as keys and their corresponding info dictionaries as values.
- Update Agents List: The wrapper updates its internal list of agents to match the agents currently active in the environment (`self.agents = self.env.agents[:]`).
- Validate Observation Spaces: The wrapper ensures that all agents have observation spaces of type `Box` by calling `_check_valid_for_black_death`.
- Detect Dead Agents: The wrapper checks which agents in `self.agents` are missing from the `observations` dictionary returned by the environment's `reset` method. These agents are considered dead because they have no initial observations.
- Create Zero Observations for Dead Agents: For each dead agent (i.e., each agent not present in the `observations` dictionary), the wrapper creates a zero-filled observation using `np.zeros_like(self.observation_space(agent).low)`.
- Combine Observations: The wrapper merges the actual observations from the environment with the zero observations for dead agents and returns the combined dictionary along with the `infos` dictionary.
Here is the code for the wrapper (including both the `reset` and `step` methods), with detailed comments explaining each step:
```python
import gymnasium
import numpy as np
from pettingzoo.utils.wrappers import BaseParallelWrapper
from supersuit.utils.wrapper_chooser import WrapperChooser


class BlackDeathParallelWrapper(BaseParallelWrapper):
    def __init__(self, env):
        super().__init__(env)

    def _check_valid_for_black_death(self):
        for agent in self.agents:
            space = self.observation_space(agent)
            assert isinstance(
                space, gymnasium.spaces.Box
            ), f"Observation spaces for black death must be Box spaces, but found {space}"

    def reset(self, seed=None, options=None):
        # Call the environment's reset method to get initial observations and infos
        observations, infos = self.env.reset(seed=seed, options=options)
        # Update the internal list of agents to match the environment's agents
        self.agents = self.env.agents[:]
        # Ensure that all agents have valid (Box) observation spaces
        self._check_valid_for_black_death()
        # Create zero observations for agents not present in the initial observations
        black_obs = {
            agent: np.zeros_like(self.observation_space(agent).low)
            for agent in self.agents
            if agent not in observations
        }
        # Combine the actual observations with the zero observations for dead agents
        combined_observations = {**observations, **black_obs}
        # Return the combined observations and infos
        return combined_observations, infos

    def step(self, actions):
        # Only send actions for agents currently in the environment
        active_actions = {agent: actions[agent] for agent in self.env.agents}
        observations, rewards, terminations, truncations, infos = self.env.step(active_actions)
        # Identify dead agents and create zero observations, zero rewards, and empty infos for them
        black_obs = {
            agent: np.zeros_like(self.observation_space(agent).low)
            for agent in self.agents
            if agent not in observations
        }
        black_rewards = {agent: 0.0 for agent in self.agents if agent not in observations}
        black_infos = {agent: {} for agent in self.agents if agent not in observations}
        # The episode is treated as over only when every agent is both terminated and truncated
        terminations_array = np.fromiter(terminations.values(), dtype=bool)
        truncations_array = np.fromiter(truncations.values(), dtype=bool)
        env_is_done = (terminations_array & truncations_array).all()
        # Merge real and placeholder entries so every agent in self.agents has a value
        total_obs = {**black_obs, **observations}
        total_rewards = {**black_rewards, **rewards}
        total_infos = {**black_infos, **infos}
        total_dones = {agent: env_is_done for agent in self.agents}
        if env_is_done:
            self.agents.clear()
        return total_obs, total_rewards, total_dones, total_dones, total_infos


black_death_v3 = WrapperChooser(parallel_wrapper=BlackDeathParallelWrapper)
```
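For context, a minimal usage sketch (assuming SuperSuit and PettingZoo are installed; `knights_archers_zombies` is just one example of a parallel environment where agents can die):

```python
import supersuit as ss
from pettingzoo.butterfly import knights_archers_zombies_v10

# Wrap a parallel environment so that dead agents keep receiving zero observations
env = ss.black_death_v3(knights_archers_zombies_v10.parallel_env())
observations, infos = env.reset(seed=42)

# Every agent in env.agents has an observation, even if the underlying env omitted it
assert set(observations) == set(env.agents)
```

Note that in the `step` method above, `env_is_done` is true only when every agent is both terminated and truncated; until then, dead agents continue to receive zero observations, zero rewards, and empty infos on every step.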
In summary:
- Initial Observations and Infos: retrieved by calling the environment's `reset` method.
- Updating Agents List: matches the current agents in the environment.
- Dead Agent Detection: agents missing from the initial observations are considered "dead."
- Zero Observations: zero-filled observations are created for dead agents (see the sketch after this list).
- Combining Results: the method combines the actual observations with zero observations for dead agents and returns them along with `infos`.
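Finally, a small sketch of why `np.zeros_like(self.observation_space(agent).low)` yields a valid zero observation: the `low` bound of a `Box` space is an array with the space's shape and dtype, so a zeros array built from it fits the space exactly (the space below is made up for illustration):

```python
import gymnasium
import numpy as np

space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
zero_obs = np.zeros_like(space.low)

print(zero_obs.shape, zero_obs.dtype)  # (4,) float32
assert space.contains(zero_obs)  # zeros lie within [-1, 1], so this is a valid observation
```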