Skip to content

Commit

Permalink
Use PB type for actions space and for ram states type, acquire new ca…
Browse files Browse the repository at this point in the history
…se keys for ram states, rename wrappers list key
  • Loading branch information
alexpalms committed Sep 14, 2023
1 parent fba8bed commit 0770bff
Show file tree
Hide file tree
Showing 18 changed files with 245 additions and 239 deletions.
4 changes: 3 additions & 1 deletion diambra/arena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from diambra.engine import SpaceType
from diambra.engine import model
from .make_env import make
from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs
from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs
43 changes: 21 additions & 22 deletions diambra/arena/arena_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from diambra.arena.engine.interface import DiambraEngine
from diambra.arena.env_settings import EnvironmentSettings1P, EnvironmentSettings2P
from typing import Union, Any, Dict, List
from diambra.engine import model, SpaceType

class DiambraGymBase(gym.Env):
"""Diambra Environment gymnasium base interface"""
Expand Down Expand Up @@ -84,8 +85,8 @@ def reset(self, seed: int = None, options: Dict[str, Any] = None):
if options is None:
options = {}
options["seed"] = seed
request = self.env_settings.update_variable_env_settings(options)
response = self.arena_engine.reset(request.variable_env_settings)
request = self.env_settings.update_episode_settings(options)
response = self.arena_engine.reset(request.episode_settings)
return self._get_obs(response), self._get_info(response)

# Rendering the environment
Expand Down Expand Up @@ -153,18 +154,17 @@ def _get_ram_states_obs_dict(self):
for k, v in self.env_info.ram_states.items():
if k.endswith("P1"):
target_dict = player_spec_dict
knew = "own_" + k[:-2]
knew = "own_" + k[:-3]
elif k.endswith("P2"):
target_dict = player_spec_dict
knew = "opp_" + k[:-2]
knew = "opp_" + k[:-3]
else:
target_dict = generic_dict
knew = k

# Discrete spaces (binary / categorical)
if v.type == 0 or v.type == 2:
if v.type == SpaceType.BINARY or v.type == SpaceType.DISCRETE:
target_dict[knew] = gym.spaces.Discrete(v.max + 1)
elif v.type == 1: # Box spaces
elif v.type == SpaceType.BOX:
target_dict[knew] = gym.spaces.Box(low=v.min, high=v.max, shape=(1,), dtype=np.int32)
else:
raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed")
Expand Down Expand Up @@ -193,20 +193,20 @@ def _player_specific_ram_states_integration(self, response, idx):
generic_dict = {}

# Adding env additional observations (side-specific)
player_role = self.env_settings.pb_model.variable_env_settings.player_env_settings[idx].role
player_role = self.env_settings.pb_model.episode_settings.player_settings[idx].role
for k, v in self.env_info.ram_states.items():
if (k.endswith("P1") or k.endswith("P2")):
target_dict = player_spec_dict
if k[-2:] == player_role:
knew = "own_" + k[:-2]
knew = "own_" + k[:-3]
else:
knew = "opp_" + k[:-2]
knew = "opp_" + k[:-3]
else:
target_dict = generic_dict
knew = k

# Box spaces
if v.type == 1:
if v.type == SpaceType.BOX:
target_dict[knew] = np.array([response.observation.ram_states[k]], dtype=np.int32)
else: # Discrete spaces (binary / categorical)
target_dict[knew] = response.observation.ram_states[k]
Expand Down Expand Up @@ -240,12 +240,11 @@ def __init__(self, env_settings):
# Discrete actions:
# - Arrows U Buttons -> One discrete set
# NB: use the convention NOOP = 0
if env_settings.action_space == "multi_discrete":
if env_settings.action_space == SpaceType.MULTI_DISCRETE:
self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
self.logger.debug("Using MultiDiscrete action space")
elif env_settings.action_space == "discrete":
elif env_settings.action_space == SpaceType.DISCRETE:
self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
self.logger.debug("Using Discrete action space")
self.logger.debug("Using {} action space".format(SpaceType.Name(env_settings.action_space)))

# Return the no-op action
def get_no_op_action(self):
Expand Down Expand Up @@ -297,17 +296,17 @@ def __init__(self, env_settings):

# Action space
# Dictionary
action_spaces_values = {"multi_discrete": gym.spaces.MultiDiscrete(self.n_actions),
"discrete": gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)}
action_space_dict = self._update_dict(action_spaces_values)
action_spaces_values = {SpaceType.MULTI_DISCRETE: gym.spaces.MultiDiscrete(self.n_actions),
SpaceType.DISCRETE: gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)}
action_space_dict = self._map_action_spaces_to_agents(action_spaces_values)
self.logger.debug("Using the following action spaces: {}".format(action_space_dict))
self.action_space = gym.spaces.Dict(action_space_dict)

# Return the no-op action
def get_no_op_action(self):
no_op_values = {"multi_discrete": [0, 0],
"discrete": 0}
return self._update_dict(no_op_values)
no_op_values = {SpaceType.MULTI_DISCRETE: [0, 0],
SpaceType.DISCRETE: 0}
return self._map_action_spaces_to_agents(no_op_values)

# Step the environment
def step(self, actions: Dict[str, Union[int, List[int]]]):
Expand All @@ -324,7 +323,7 @@ def step(self, actions: Dict[str, Union[int, List[int]]]):

return observation, response.reward, response.info.game_states["game_done"], False, self._get_info(response)

def _update_dict(self, values_dict):
def _map_action_spaces_to_agents(self, values_dict):
out_dict = {}
for idx, action_space in enumerate(self.env_settings.action_space):
out_dict["agent_{}".format(idx)] = values_dict[action_space]
Expand Down
4 changes: 2 additions & 2 deletions diambra/arena/engine/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def env_init(self, env_settings_pb):
return response

# Reset the environment [pb low level]
def reset(self, variable_env_settings):
return self.client.Reset(variable_env_settings)
def reset(self, episode_settings):
return self.client.Reset(episode_settings)

# Step the environment [pb low level]
def step(self, action_list):
Expand Down
51 changes: 25 additions & 26 deletions diambra/arena/env_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Union, List, Tuple, Any, Dict
from diambra.arena.utils.gym_utils import available_games
from diambra.arena import SpaceType
import numpy as np
import random
from diambra.engine import model
Expand All @@ -21,6 +22,10 @@ def check_val_in_list(key, value, valid_list):
assert (value in valid_list), error_message
assert (type(value)==type(valid_list[valid_list.index(value)])), error_message

def check_space_type(key, value, valid_list):
error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, SpaceType.Name(value), [SpaceType.Name(elem) for elem in valid_list])
assert (value in valid_list), error_message

@dataclass
class EnvironmentSettings:
"""Generic Environment Settings Class"""
Expand Down Expand Up @@ -50,7 +55,7 @@ class EnvironmentSettings:
_last_seed: int = None
pb_model: model = None

variable_env_settings = ["seed", "difficulty", "continue_game", "show_final", "tower", "role",
episode_settings = ["seed", "difficulty", "continue_game", "show_final", "tower", "role",
"characters", "outfits", "super_art", "fighting_style", "ultimate_style"]

# Transforming env settings dict to pb request
Expand All @@ -74,18 +79,18 @@ def get_pb_request(self, init=False):
if init is False:
self._process_random_values()

player_env_settings = self._get_player_specific_values()
player_settings = self._get_player_specific_values()

variable_env_settings = model.EnvSettings.VariableEnvSettings(
episode_settings = model.EnvSettings.EpisodeSettings(
random_seed=self.seed,
difficulty=self.difficulty,
continue_game=self.continue_game,
show_final=self.show_final,
tower=self.tower,
player_env_settings=player_env_settings,
player_settings=player_settings,
)
else:
variable_env_settings = model.EnvSettings.VariableEnvSettings()
episode_settings = model.EnvSettings.EpisodeSettings()

request = model.EnvSettings(
game_id=self.game_id,
Expand All @@ -96,7 +101,7 @@ def get_pb_request(self, init=False):
disable_joystick=self.disable_joystick,
rank=self.rank,
action_spaces=action_spaces,
variable_env_settings=variable_env_settings,
episode_settings=episode_settings,
)

self.pb_model = request
Expand All @@ -111,16 +116,16 @@ def finalize_init(self, env_info):
self.valid_characters = [character for character in self.env_info.characters_info.char_list \
if character not in self.env_info.characters_info.char_forbidden_list]

def update_variable_env_settings(self, options: Dict[str, Any] = None):
def update_episode_settings(self, options: Dict[str, Any] = None):
for k, v in options.items():
if k in self.variable_env_settings:
if k in self.episode_settings:
setattr(self, k, v)

self._sanity_check()

# Storing original attributes before sampling random ones
original_settings_values = {}
for k in self.variable_env_settings:
for k in self.episode_settings:
original_settings_values[k] = getattr(self, k)

request = self.get_pb_request()
Expand Down Expand Up @@ -188,7 +193,7 @@ class EnvironmentSettings1P(EnvironmentSettings):
role: str = "Random"
characters: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]] = ("Random", "Random", "Random")
outfits: int = 1
action_space: str = "multi_discrete"
action_space: int = SpaceType.MULTI_DISCRETE
super_art: Union[int, str] = "Random" # SFIII Specific
fighting_style: Union[int, str] = "Random" # KOF Specific
ultimate_style: Union[Tuple[str, str, str], Tuple[int, int, int]] = ("Random", "Random", "Random") # KOF Specific
Expand All @@ -203,7 +208,7 @@ def _sanity_check(self):
self.characters += ("Random", )

check_num_in_range("n_players", self.n_players, [1, 1])
check_val_in_list("action_space", self.action_space, ["discrete", "multi_discrete"])
check_space_type("action_space", self.action_space, [SpaceType.DISCRETE, SpaceType.MULTI_DISCRETE])
check_val_in_list("role", self.role, ["P1", "P2", "Random"])
# Check for characters
char_list = list(self.env_info.characters_info.char_list)
Expand All @@ -217,10 +222,7 @@ def _sanity_check(self):
check_val_in_list("ultimate_style[{}]".format(idx), self.ultimate_style[idx], ["Random", 1, 2])

def _get_action_spaces(self):
action_space = model.EnvSettings.ActionSpace.ACTION_SPACE_DISCRETE if self.action_space == "discrete" else \
model.EnvSettings.ActionSpace.ACTION_SPACE_MULTI_DISCRETE

return [action_space]
return [self.action_space]

def _process_random_values(self):
super()._process_random_values()
Expand All @@ -243,7 +245,7 @@ def _process_random_values(self):
self.ultimate_style = tuple([random.choice(list(range(1, 3))) if self.ultimate_style[idx] == "Random" else self.ultimate_style[idx] for idx in range(3)])

def _get_player_specific_values(self):
player_env_settings = model.EnvSettings.VariableEnvSettings.PlayerEnvSettings(
player_settings = model.EnvSettings.EpisodeSettings.PlayerSettings(
role=self.role,
characters=[self.characters[idx] for idx in range(self.env_info.characters_info.chars_to_select)],
outfits=self.outfits,
Expand All @@ -252,7 +254,7 @@ def _get_player_specific_values(self):
ultimate_style={"dash": self.ultimate_style[0], "evade": self.ultimate_style[1], "bar": self.ultimate_style[2]}
)

return [player_env_settings]
return [player_settings]

@dataclass
class EnvironmentSettings2P(EnvironmentSettings):
Expand All @@ -263,7 +265,7 @@ class EnvironmentSettings2P(EnvironmentSettings):
Tuple[Tuple[str, str, str], Tuple[str, str, str]]] =\
(("Random", "Random", "Random"), ("Random", "Random", "Random"))
outfits: Tuple[int, int] = (1, 1)
action_space: Tuple[str, str] = ("multi_discrete", "multi_discrete")
action_space: Tuple[int, int] = (SpaceType.MULTI_DISCRETE, SpaceType.MULTI_DISCRETE)
super_art: Union[Tuple[str, str], Tuple[int, int], Tuple[str, int], Tuple[int, str]] = ("Random", "Random") # SFIII Specific
fighting_style: Union[Tuple[str, str], Tuple[int, int], Tuple[str, int], Tuple[int, str]] = ("Random", "Random") # KOF Specific
ultimate_style: Union[Tuple[Tuple[str, str, str], Tuple[str, str, str]], Tuple[Tuple[int, int, int], Tuple[int, int, int]]] =\
Expand All @@ -286,7 +288,7 @@ def _sanity_check(self):
char_list = list(self.env_info.characters_info.char_list)
char_list.append("Random")
for idx in range(2):
check_val_in_list("action_space[{}]".format(idx), self.action_space[idx], ["discrete", "multi_discrete"])
check_space_type("action_space[{}]".format(idx), self.action_space[idx], [SpaceType.DISCRETE, SpaceType.MULTI_DISCRETE])
check_val_in_list("role[{}]".format(idx), self.role[idx], ["P1", "P2", "Random"])
for jdx in range(3):
check_val_in_list("characters[{}][{}]".format(idx, jdx), self.characters[idx][jdx], char_list)
Expand Down Expand Up @@ -325,17 +327,14 @@ def _process_random_values(self):
self.ultimate_style = tuple([[random.choice(list(range(1, 3))) if self.ultimate_style[idx][jdx] == "Random" else self.ultimate_style[idx][jdx] for jdx in range(3)] for idx in range(2)])

def _get_action_spaces(self):
action_spaces = [model.EnvSettings.ActionSpace.ACTION_SPACE_DISCRETE if action_space == "discrete" else \
model.EnvSettings.ActionSpace.ACTION_SPACE_MULTI_DISCRETE for action_space in self.action_space]

return action_spaces
return [action_space for action_space in self.action_space]

def _get_player_specific_values(self):
players_env_settings = []

for idx in range(2):

player_env_settings = model.EnvSettings.VariableEnvSettings.PlayerEnvSettings(
player_settings = model.EnvSettings.EpisodeSettings.PlayerSettings(
role=self.role[idx],
characters=[self.characters[idx][jdx] for jdx in range(self.env_info.characters_info.chars_to_select)],
outfits=self.outfits[idx],
Expand All @@ -344,7 +343,7 @@ def _get_player_specific_values(self):
ultimate_style={"dash": self.ultimate_style[idx][0], "evade": self.ultimate_style[idx][1], "bar": self.ultimate_style[idx][2]}
)

players_env_settings.append(player_env_settings)
players_env_settings.append(player_settings)

return players_env_settings

Expand All @@ -366,7 +365,7 @@ class WrappersSettings:
frame_shape: Tuple[int, int, int] = (0, 0, 0)
flatten: bool = False
filter_keys: List[str] = None
additional_wrappers_list: List[List[Any]] = None
wrappers: List[List[Any]] = None

def sanity_check(self):
check_num_in_range("no_op_max", self.no_op_max, [0, 12])
Expand Down
13 changes: 6 additions & 7 deletions diambra/arena/stable_baselines/make_sb_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def make_sb_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={},

# Add the conversion from gymnasium to gym
old_gym_wrapper = [OldGymWrapper, {}]
if 'additional_wrappers_list' in wrappers_settings:
wrappers_settings['additional_wrappers_list'].insert(0, old_gym_wrapper)
if 'wrappers' in wrappers_settings:
wrappers_settings['wrappers'].insert(0, old_gym_wrapper)
else:
# If it's not present, add the key with a new list containing your custom element
wrappers_settings['additional_wrappers_list'] = [old_gym_wrapper]
wrappers_settings['wrappers'] = [old_gym_wrapper]

def _make_sb_env(rank):
def _init():
Expand Down Expand Up @@ -78,12 +78,11 @@ def __init__(self, env):
:param env: (Gym<=0.21 Environment) the resulting environment
"""
gym.Wrapper.__init__(self, env)
if self.env_settings.action_space == "multi_discrete":
if self.env_settings.action_space == diambra.arena.SpaceType.MULTI_DISCRETE:
self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
self.logger.debug("Using MultiDiscrete action space")
elif self.env_settings.action_space == "discrete":
elif self.env_settings.action_space == diambra.arena.SpaceType.DISCRETE:
self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
self.logger.debug("Using Discrete action space")
self.logger.debug("Using {} action space".format(diambra.arena.SpaceType.Name(self.env_settings.action_space)))

def reset(self, **kwargs):
obs, _ = self.env.reset(**kwargs)
Expand Down
Loading

0 comments on commit 0770bff

Please sign in to comment.