Black Death Wrapper

HBP1969 edited this page Oct 10, 2024 · 2 revisions

Notes:

At reset, the wrapper treats an agent as "dead" when the agent is listed in self.agents (which is copied from the environment's agent list) but has no entry in the observations returned by the environment's reset method. Each such agent is assigned a zero-filled observation.

Here's a more detailed breakdown of the process:

  1. Call the Environment's reset Method: The wrapper calls the reset method of the underlying environment (self.env.reset(seed=seed, options=options)), which returns a tuple of:

    • observations: A dictionary with agent identifiers as keys and their corresponding initial observations as values.
    • infos: A dictionary with agent identifiers as keys and their corresponding info dictionaries as values.
  2. Update Agents List: The wrapper updates its internal list of agents to match the agents currently active in the environment (self.agents = self.env.agents[:]).

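  3. Validate Observation Spaces: The wrapper calls _check_valid_for_black_death, which asserts that every agent's observation space is a gymnasium.spaces.Box. The zero observation in step 5 is built from the space's low array, which Box spaces provide.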

  4. Detect Dead Agents: The wrapper detects dead agents by checking which agents in self.agents are missing from the observations dictionary returned by the environment's reset method. These agents are considered dead because they do not have initial observations.

  5. Create Zero Observations for Dead Agents: For each dead agent (agents not present in the observations dictionary), the wrapper creates a zero-filled observation using np.zeros_like(self.observation_space(agent).low).

  6. Combine Observations: The wrapper combines the actual observations from the environment with the zero observations for dead agents and returns this combined dictionary along with the infos dictionary.
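Steps 4 to 6 can be tried in isolation. The following is a minimal sketch, not the wrapper's own code: FakeBox is a hypothetical stand-in for gymnasium.spaces.Box that models only the low attribute, and pad_missing_observations, the agent names, and the observation values are made up for illustration.

```python
import numpy as np


class FakeBox:
    """Hypothetical stand-in for gymnasium.spaces.Box; only `low` is modeled."""

    def __init__(self, low):
        self.low = np.asarray(low, dtype=np.float32)


def pad_missing_observations(agents, observations, space_for):
    """Return the observations plus a zero observation for every agent lacking one."""
    zero_obs = {
        agent: np.zeros_like(space_for(agent).low)
        for agent in agents
        if agent not in observations
    }
    # Real observations first, then the zero placeholders for the missing agents
    return {**observations, **zero_obs}


# Illustrative data: two agents are listed, but only one observation is reported
spaces = {
    "agent_0": FakeBox([-1.0, -1.0, -1.0]),
    "agent_1": FakeBox([-1.0, -1.0, -1.0]),
}
agents = ["agent_0", "agent_1"]
observations = {"agent_0": np.array([0.2, -0.5, 0.9], dtype=np.float32)}

padded = pad_missing_observations(agents, observations, spaces.__getitem__)
```

Here padded contains the reported observation for agent_0 unchanged and an all-zero array for agent_1 with the same shape and dtype as that agent's low bound.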

Here is the code for the wrapper, with comments explaining each step of the reset and step methods:

import gymnasium
import numpy as np
from pettingzoo.utils.wrappers import BaseParallelWrapper
from supersuit.utils.wrapper_chooser import WrapperChooser

class BlackDeathParallelWrapper(BaseParallelWrapper):
    def __init__(self, env):
        super().__init__(env)

    def _check_valid_for_black_death(self):
        for agent in self.agents:
            space = self.observation_space(agent)
            assert isinstance(
                space, gymnasium.spaces.Box
            ), f"Observation spaces for black death must be Box spaces, but found {space}"

    def reset(self, seed=None, options=None):
        # Call the environment's reset method to get initial observations and infos
        observations, infos = self.env.reset(seed=seed, options=options)

        # Update the internal list of agents to match the environment's agents
        self.agents = self.env.agents[:]
        
        # Ensure that all agents have valid observation spaces
        self._check_valid_for_black_death()
        
        # Create zero observations for agents not present in the initial observations
        black_obs = {
            agent: np.zeros_like(self.observation_space(agent).low)
            for agent in self.agents
            if agent not in observations
        }
        
        # Combine the actual observations with the zero observations for dead agents
        combined_observations = {**observations, **black_obs}
        
        # Return the combined observations and infos
        return combined_observations, infos

    def step(self, actions):
        # Only send actions for agents currently in the environment
        active_actions = {agent: actions[agent] for agent in self.env.agents}
        observations, rewards, terminations, truncations, infos = self.env.step(active_actions)
        
        # Identify dead agents and create zero observations, zero rewards, and empty infos for them
        black_obs = {
            agent: np.zeros_like(self.observation_space(agent).low)
            for agent in self.agents
            if agent not in observations
        }
        black_rewards = {agent: 0.0 for agent in self.agents if agent not in observations}
        black_infos = {agent: {} for agent in self.agents if agent not in observations}

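        # The episode counts as finished only when every agent the environment reported
        # is flagged as both terminated and truncated; empty dicts also evaluate to True
        terminations_array = np.fromiter(terminations.values(), dtype=bool)
        truncations_array = np.fromiter(truncations.values(), dtype=bool)
        env_is_done = (terminations_array & truncations_array).all()
        
        # Placeholders come first so that real values from the environment override them
        total_obs = {**black_obs, **observations}
        total_rewards = {**black_rewards, **rewards}
        total_infos = {**black_infos, **infos}
        # Every agent, dead or alive, receives the same episode-level done flag
        total_dones = {agent: env_is_done for agent in self.agents}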

        if env_is_done:
            self.agents.clear()
        
        # The same done flags are returned for both terminations and truncations
        return total_obs, total_rewards, total_dones, total_dones, total_infos

black_death_v3 = WrapperChooser(parallel_wrapper=BlackDeathParallelWrapper)
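
The done check in the step method above has a few edge cases that are easy to miss. The sketch below reproduces only that expression with NumPy; the helper name env_is_done and the agent keys are illustrative assumptions, and the inputs are hand-written dictionaries rather than output from a real environment.

```python
import numpy as np


def env_is_done(terminations, truncations):
    """Mirror the check in step(): AND the two flags per agent, then require all."""
    terminated = np.fromiter(terminations.values(), dtype=bool)
    truncated = np.fromiter(truncations.values(), dtype=bool)
    return bool((terminated & truncated).all())


# Every agent terminated, none truncated: the AND leaves every entry False
only_terminated = env_is_done({"a": True, "b": True}, {"a": False, "b": False})

# Every agent both terminated and truncated
both_flags = env_is_done({"a": True, "b": True}, {"a": True, "b": True})

# The environment reports no agents at all: all() over an empty array is True
no_agents = env_is_done({}, {})
```

With these inputs, only_terminated is False while both_flags and no_agents are True. As the code is written, an episode in which agents are only terminated is therefore not reported as done until the environment stops reporting agents or sets both flags.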

Summary:

  • Initial Observations and Infos: Retrieved by calling the environment's reset method.
  • Updating Agents List: Matches the current agents in the environment.
  • Dead Agent Detection: Agents missing from the initial observations are considered "dead."
  • Zero Observations: Zero-filled observations are created for dead agents.
  • Combining Results: The method combines actual observations with zero observations for dead agents and returns them along with infos.