Skip to content

Commit

Permalink
Merge pull request #64 from Kautenja/4.0.1
Browse files Browse the repository at this point in the history
4.0.1
  • Loading branch information
Kautenja authored Sep 16, 2018
2 parents 0a6e0da + 2879cb8 commit 05efde4
Show file tree
Hide file tree
Showing 20 changed files with 284 additions and 282 deletions.
2 changes: 0 additions & 2 deletions gym_super_mario_bros/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Registration code of Gym environments in this package."""
from .smb_env import SuperMarioBrosEnv
from .smb_stage_env import SuperMarioBrosStageEnv
from ._registration import make


# define the outward facing API of this package
__all__ = [
make.__name__,
SuperMarioBrosEnv.__name__,
SuperMarioBrosStageEnv.__name__,
]
47 changes: 22 additions & 25 deletions gym_super_mario_bros/_registration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Registration code of Gym environments in this package."""
import gym
from ._rom_mode import RomMode


def _register_mario_env(id, **kwargs):
Expand Down Expand Up @@ -28,27 +27,27 @@ def _register_mario_env(id, **kwargs):


# Super Mario Bros. with standard frame skip
_register_mario_env('SuperMarioBros-v0', frameskip=4, rom_mode=RomMode.VANILLA)
_register_mario_env('SuperMarioBros-v1', frameskip=4, rom_mode=RomMode.DOWNSAMPLE)
_register_mario_env('SuperMarioBros-v2', frameskip=4, rom_mode=RomMode.PIXEL)
_register_mario_env('SuperMarioBros-v3', frameskip=4, rom_mode=RomMode.RECTANGLE)
_register_mario_env('SuperMarioBros-v0', frames_per_step=4, rom_mode='vanilla')
_register_mario_env('SuperMarioBros-v1', frames_per_step=4, rom_mode='downsample')
_register_mario_env('SuperMarioBros-v2', frames_per_step=4, rom_mode='pixel')
_register_mario_env('SuperMarioBros-v3', frames_per_step=4, rom_mode='rectangle')


# Super Mario Bros. with no frame skip
_register_mario_env('SuperMarioBrosNoFrameskip-v0', frameskip=1, rom_mode=RomMode.VANILLA)
_register_mario_env('SuperMarioBrosNoFrameskip-v1', frameskip=1, rom_mode=RomMode.DOWNSAMPLE)
_register_mario_env('SuperMarioBrosNoFrameskip-v2', frameskip=1, rom_mode=RomMode.PIXEL)
_register_mario_env('SuperMarioBrosNoFrameskip-v3', frameskip=1, rom_mode=RomMode.RECTANGLE)
_register_mario_env('SuperMarioBrosNoFrameskip-v0', frames_per_step=1, rom_mode='vanilla')
_register_mario_env('SuperMarioBrosNoFrameskip-v1', frames_per_step=1, rom_mode='downsample')
_register_mario_env('SuperMarioBrosNoFrameskip-v2', frames_per_step=1, rom_mode='pixel')
_register_mario_env('SuperMarioBrosNoFrameskip-v3', frames_per_step=1, rom_mode='rectangle')


# Super Mario Bros. 2 (Lost Levels) with standard frame skip
_register_mario_env('SuperMarioBros2-v0', lost_levels=True, frameskip=4, rom_mode=RomMode.VANILLA)
_register_mario_env('SuperMarioBros2-v1', lost_levels=True, frameskip=4, rom_mode=RomMode.DOWNSAMPLE)
_register_mario_env('SuperMarioBros2-v0', lost_levels=True, frames_per_step=4, rom_mode='vanilla')
_register_mario_env('SuperMarioBros2-v1', lost_levels=True, frames_per_step=4, rom_mode='downsample')


# Super Mario Bros. 2 (Lost Levels) with no frame skip
_register_mario_env('SuperMarioBros2NoFrameskip-v0', lost_levels=True, frameskip=1, rom_mode=RomMode.VANILLA)
_register_mario_env('SuperMarioBros2NoFrameskip-v1', lost_levels=True, frameskip=1, rom_mode=RomMode.DOWNSAMPLE)
_register_mario_env('SuperMarioBros2NoFrameskip-v0', lost_levels=True, frames_per_step=1, rom_mode='vanilla')
_register_mario_env('SuperMarioBros2NoFrameskip-v1', lost_levels=True, frames_per_step=1, rom_mode='downsample')


def _register_mario_stage_env(id, **kwargs):
Expand All @@ -57,7 +56,7 @@ def _register_mario_stage_env(id, **kwargs):
Args:
id (str): id for the env to register
kwargs (dict): keyword arguments for the SuperMarioBrosStageEnv initializer
kwargs (dict): keyword arguments for the SuperMarioBrosEnv initializer
Returns:
None
Expand All @@ -67,7 +66,7 @@ def _register_mario_stage_env(id, **kwargs):
# register the environment
gym.envs.registration.register(
id=id,
entry_point='gym_super_mario_bros:SuperMarioBrosStageEnv',
entry_point='gym_super_mario_bros:SuperMarioBrosEnv',
max_episode_steps=9999999,
reward_threshold=32000,
kwargs=kwargs,
Expand All @@ -79,29 +78,27 @@ def _register_mario_stage_env(id, **kwargs):
_ID_TEMPLATE = 'SuperMarioBros{}-{}-{}-v{}'
# iterate over all the rom modes, worlds (1-8), and stages (1-4)
_ROM_MODES = [
RomMode.VANILLA,
RomMode.DOWNSAMPLE,
RomMode.PIXEL,
RomMode.RECTANGLE
'vanilla',
'downsample',
'pixel',
'rectangle'
]
for version, rom_mode in enumerate(_ROM_MODES):
for world in range(1, 9):
for stage in range(1, 5):
# setup the frame-skipping environment
env_id = _ID_TEMPLATE.format('', world, stage, version)
_register_mario_stage_env(env_id,
frameskip=4,
frames_per_step=4,
rom_mode=rom_mode,
target_world=world,
target_stage=stage
target=(world, stage),
)
# setup the no frame-skipping environment
env_id = _ID_TEMPLATE.format('NoFrameskip', world, stage, version)
_register_mario_stage_env(env_id,
frameskip=1,
frames_per_step=1,
rom_mode=rom_mode,
target_world=world,
target_stage=stage
target=(world, stage),
)


Expand Down
17 changes: 0 additions & 17 deletions gym_super_mario_bros/_rom_mode.py

This file was deleted.

10 changes: 10 additions & 0 deletions gym_super_mario_bros/_roms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Methods for ROM file management."""
from .decode_target import decode_target
from .rom_path import rom_path


# explicitly define the outward facing API of this package
__all__ = [
decode_target.__name__,
rom_path.__name__,
]
71 changes: 71 additions & 0 deletions gym_super_mario_bros/_roms/decode_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""A method to decode target values for a ROM stage environment."""


def decode_target(target, lost_levels):
"""
Return the target area for target world and target stage.
Args:
target_world (None, int): the world to target
target_stage (None, int): the stage to target
lost_levels (bool): whether to use lost levels game
Returns (int):
the area to target to load the target world and stage
"""
# Type and value check the lost levels parameter
if not isinstance(lost_levels, bool):
raise TypeError('lost_levels must be of type: bool')
# if there is no target, the world, stage, and area targets are all None
if target is None:
return None, None, None
elif not isinstance(target, tuple):
raise TypeError('target must be of type tuple')
# unwrap the target world and stage
target_world, target_stage = target
# Type and value check the target world parameter
if not isinstance(target_world, int):
raise TypeError('target_world must be of type: int')
else:
if lost_levels:
if not 1 <= target_world <= 12:
raise ValueError('target_world must be in {1, ..., 12}')
elif not 1 <= target_world <= 8:
raise ValueError('target_world must be in {1, ..., 8}')
# Type and value check the target level parameter
if not isinstance(target_stage, int):
raise TypeError('target_stage must be of type: int')
else:
if not 1 <= target_stage <= 4:
raise ValueError('target_stage must be in {1, ..., 4}')

# no target are defined for no target world or stage situations
if target_world is None or target_stage is None:
return None
# setup target area if target world and stage are specified
target_area = target_stage
# setup the target area depending on whether this is SMB 1 or 2
if lost_levels:
# setup the target area depending on the target world and stage
if target_world in {1, 3}:
if target_stage >= 2:
target_area = target_area + 1
elif target_world >= 5:
# TODO: figure out why all worlds greater than 5 fail.
# target_area = target_area + 1
# for now just raise a value error
worlds = set(range(5, 12 + 1))
msg = 'lost levels worlds {} not supported'.format(worlds)
raise ValueError(msg)
else:
# setup the target area depending on the target world and stage
if target_world in {1, 2, 4, 7}:
if target_stage >= 2:
target_area = target_area + 1

return target_world, target_stage, target_area


# explicitly define the outward facing API of this module
__all__ = [decode_target.__name__]
53 changes: 53 additions & 0 deletions gym_super_mario_bros/_roms/rom_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""A method to load a ROM path."""
import os


# a dictionary mapping ROM paths first by lost levels, then by ROM hack mode
_ROM_PATHS = {
# the dictionary of lost level ROM paths
True: {
'vanilla': 'super-mario-bros-2.nes',
'downsample': 'super-mario-bros-2-downsample.nes',
},
# the dictionary of Super Mario Bros. 1 ROM paths
False: {
'vanilla': 'super-mario-bros.nes',
'pixel': 'super-mario-bros-pixel.nes',
'rectangle': 'super-mario-bros-rectangle.nes',
'downsample': 'super-mario-bros-downsample.nes',
}
}


def rom_path(lost_levels, rom_mode):
"""
Return the ROM filename for a game and ROM mode.
Args:
lost_levels (bool): whether to use the lost levels ROM
rom_mode (str): the mode of the ROM hack to use as one of:
- 'vanilla'
- 'pixel'
- 'downsample'
- 'vanilla'
Returns (str):
the ROM path based on the input parameters
"""
# Type and value check the lost levels parameter
if not isinstance(lost_levels, bool):
raise TypeError('lost_levels must be of type: bool')
# try the unwrap the ROM path from the dictionary
try:
rom = _ROM_PATHS[lost_levels][rom_mode]
except KeyError:
raise ValueError('rom_mode ({}) not supported!'.format(rom_mode))
# get the absolute path for the ROM
rom = os.path.join(os.path.dirname(os.path.abspath(__file__)), rom)

return rom


# explicitly define the outward facing API of this module
__all__ = [rom_path.__name__]
43 changes: 0 additions & 43 deletions gym_super_mario_bros/roms/_remove_backgrounds.py

This file was deleted.

Loading

0 comments on commit 05efde4

Please sign in to comment.