Skip to content

Commit

Permalink
GarageEnv wraps BulletEnv (#1684)
Browse files Browse the repository at this point in the history
* GarageEnv wraps BulletEnv

* Add tests
  • Loading branch information
AiRuiChen authored Jul 3, 2020
1 parent d9aa7e1 commit c9e9817
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 43 deletions.
33 changes: 33 additions & 0 deletions src/garage/envs/bullet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,36 @@
from garage.envs.bullet.bullet_env import BulletEnv

__all__ = ['BulletEnv']


def _get_bullet_env_list():
    """Return a complete list of Bullet Gym environments.

    The ids reported by ``pybullet_envs.getList()`` are combined with a
    hardcoded set of ids that ``getList()`` does not report.

    Returns:
        list: a list of bullet environment id (str)

    """
    # Environment ids that pybullet_envs.getList() does not report.
    missing_ids = [
        'MinitaurExtendedEnv-v0',
        'MinitaurReactiveEnv-v0',
        'MinitaurBallGymEnv-v0',
        'MinitaurTrottingEnv-v0',
        'MinitaurStandGymEnv-v0',
        'MinitaurAlternatingLegsEnv-v0',
        'MinitaurFourLegStandEnv-v0',
        'KukaDiverseObjectGrasping-v0',
    ]

    # Remove every '- ' occurrence from the entries returned by getList()
    # so that only the bare environment id remains.
    listed_ids = []
    for entry in pybullet_envs.getList():
        listed_ids.append(entry.replace('- ', ''))

    return listed_ids + missing_ids


def _get_unsupported_env_list():
"""Return a list of unsupported Bullet Gym environments.
See https://github.com/rlworkgroup/garage/issues/1668
Returns:
list: a list of bullet environment id (str)
"""
return [
'MinitaurExtendedEnv-v0', 'MinitaurReactiveEnv-v0',
'MinitaurBallGymEnv-v0', 'MinitaurTrottingEnv-v0',
'MinitaurStandGymEnv-v0', 'MinitaurAlternatingLegsEnv-v0',
'MinitaurFourLegStandEnv-v0', 'KukaDiverseObjectGrasping-v0'
]
142 changes: 131 additions & 11 deletions src/garage/envs/bullet/bullet_env.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,126 @@
"""Wrappers for py_bullet environments."""
import inspect

import akro
import gym
from pybullet_envs.bullet.minitaur_duck_gym_env import MinitaurBulletDuckEnv
from pybullet_envs.bullet.minitaur_gym_env import MinitaurBulletEnv
from pybullet_envs.env_bases import MJCFBaseBulletEnv

from garage.envs import GarageEnv
from garage.envs.env_spec import EnvSpec


class BulletEnv(GarageEnv):
class BulletEnv(gym.Wrapper):
"""Binding for py_bullet environments."""

def __init__(self, env=None, env_name='', is_image=False):
    """Returns a Garage wrapper class for bullet-based gym.Env.

    Args:
        env (gym.wrappers.time_limit): A gym.wrappers.time_limit.TimeLimit
            object wrapping a gym.Env created via gym.make().
        env_name (str): If the env_name is specified, a gym environment
            with that name will be created. If such an environment does not
            exist, a `gym.error` is thrown.
        is_image (bool): True if observations contain pixel values,
            false otherwise. Setting this to true converts a gym.Spaces.Box
            obs space to an akro.Image and normalizes pixel values.

    """
    if not env:
        make_kwargs = {}
        # 'RacecarZedBulletEnv-v0' environment enables rendering by
        # default, while pybullet allows only one GUI connection at a time.
        # Setting renders to False avoids potential error when multiple
        # of these envs are tested at the same time.
        if env_name == 'RacecarZedBulletEnv-v0':
            make_kwargs['renders'] = False
        env = gym.make(env_name, **make_kwargs)

    # Needed for deserialization
    self._env = env
    self._env_name = env_name

    super().__init__(env)

    # Expose akro versions of the wrapped environment's gym spaces.
    wrapped_env = self.env
    self.action_space = akro.from_gym(wrapped_env.action_space)
    self.observation_space = akro.from_gym(wrapped_env.observation_space,
                                           is_image=is_image)
    self._spec = EnvSpec(action_space=self.action_space,
                         observation_space=self.observation_space)

@property
def spec(self):
    """Return the environment specification.

    This property needs to exist, since it's defined as a property in
    gym.Wrapper in a way that makes it difficult to overwrite.

    Returns:
        garage.envs.env_spec.EnvSpec: The environment specification.

    """
    # The spec is built once in __init__ from the akro spaces.
    env_spec = self._spec
    return env_spec

def close(self):
    """Close the wrapped env.

    For 'RacecarZedBulletEnv-v0' the bullet client is disconnected first
    (when it is still connected), then the wrapped env's own close() is
    called as for every other environment.
    """
    inner_env = self.env.env
    # RacecarZedBulletEnv-v0 environment doesn't disconnect from bullet
    # server in its close() method.
    # Note that disconnect() disconnects the environment from the physics
    # server, whereas the GUI window will not be destroyed.
    if inner_env.spec.id == 'RacecarZedBulletEnv-v0':
        # pylint: disable=protected-access
        bullet_client = inner_env._p
        if bullet_client.isConnected():
            bullet_client.disconnect()
    self.env.close()

def reset(self, **kwargs):
    """Call reset on wrapped env.

    This method is necessary to suppress a deprecated warning
    thrown by gym.Wrapper.

    Args:
        kwargs: Keyword args

    Returns:
        object: The initial observation.

    """
    wrapped_env = self.env
    return wrapped_env.reset(**kwargs)

def step(self, action):
    """Call step on wrapped env.

    This method is necessary to suppress a deprecated warning
    thrown by gym.Wrapper.

    Args:
        action (np.ndarray): An action provided by the agent.

    Returns:
        np.ndarray: Agent's observation of the current environment
        float: Amount of reward returned after previous action
        bool: Whether the episode has ended, in which case further step()
            calls will return undefined results
        dict: Contains auxiliary diagnostic information (helpful for
            debugging, and sometimes learning)

    """
    obs, rew, terminated, env_info = self.env.step(action)

    try:
        truncated = env_info['TimeLimit.truncated']
    except KeyError:
        # No time-limit information was reported: forward the wrapped
        # env's result untouched.
        return obs, rew, terminated, env_info

    # gym envs that are wrapped in TimeLimit wrapper modify the
    # done/termination signal to be true whenever a time limit expiration
    # occurs. Here the done signal is made True only if caused by an
    # environment termination, and not a time limit termination. The time
    # limit termination signal is saved inside env_infos as
    # 'BulletEnv.TimeLimitTerminated'.
    env_info['BulletEnv.TimeLimitTerminated'] = terminated
    return obs, rew, not truncated, env_info

def __getstate__(self):
"""See `Object.__getstate__.
Expand All @@ -37,10 +147,20 @@ def __getstate__(self):
args['robot'] = env.robot
param_names.remove('robot')

# Create param name -> param value mapping
# Create param name -> param value mapping for the wrapped environment
args = {key: env.__dict__['_' + key] for key in param_names}
args['class_type'] = type(env)
args['_env_name'] = self._env_name

# Only one local in-process GUI connection is allowed. Thus pickled
# BulletEnv shouldn't enable rendering. New BulletEnv will connect in
# DIRECT mode.
for key in args.keys():
if 'render' in key:
args[key] = False

# Add BulletEnv class specific params
# env id is saved to help gym.make() in __setstate__
args['id'] = env.spec.id
args['env_name'] = self._env_name

return args

Expand All @@ -53,11 +173,11 @@ def __setstate__(self, state):
state (dict): The instance’s __init__() arguments.
"""
class_type = state['class_type']
env_name = state['_env_name']
# Create a new class instance via constructor arguments
del state['class_type']
del state['_env_name']
env = class_type(**state)
env_id = state['id']
env_name = state['env_name']
# Create a environment via constructor arguments
del state['id']
del state['env_name']
env = gym.make(env_id, **state)

self.__init__(env, env_name)
78 changes: 62 additions & 16 deletions src/garage/envs/garage_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import akro
import gym
from gym.wrappers.time_limit import TimeLimit

from garage.envs.bullet import _get_bullet_env_list, BulletEnv
from garage.envs.env_spec import EnvSpec

# The gym environments using one of the packages in the following lists as
Expand Down Expand Up @@ -35,18 +37,62 @@ class GarageEnv(gym.Wrapper):
convert action_space and observation_space from gym.Spaces to
akro.spaces.
Args:
env (gym.Env): An env that will be wrapped
env_name (str): If the env_name is specified, a gym environment
with that name will be created. If such an environment does not
exist, a `gym.error` is thrown.
is_image (bool): True if observations contain pixel values,
false otherwise. Setting this to true converts a gym.Spaces.Box
obs space to an akro.Image and normalizes pixel values.
GarageEnv handles all environments created by gym.make().
It returns a different wrapper class instance if the input environment
requires special handling.
Current supported wrapper classes are:
garage.envs.bullet.BulletEnv for Bullet-based gym environments.
See __new__() for details.
"""

def __new__(cls, *args, **kwargs):
    """Returns environment specific wrapper based on input environment type.

    The arguments mirror those of __init__: ``env``, ``env_name`` and
    ``is_image``, given either positionally or by keyword.

    Args:
        args: positional arguments
        kwargs: keyword arguments

    Returns:
        garage.envs.bullet.BulletEnv: if the environment is a bullet-based
            environment. Else returns a garage.envs.GarageEnv

    """
    # Locate the env argument. A positional env is only inspected when it
    # was created by gym.make(), i.e. when it has type TimeLimit.
    env = None
    if 'env' in kwargs:
        env = kwargs['env']
    elif args and isinstance(args[0], TimeLimit):
        env = args[0]

    # Locate the env_name argument.
    if 'env_name' in kwargs:
        env_name = kwargs['env_name']
    elif len(args) >= 2:
        env_name = args[1]
    else:
        env_name = ''

    # Locate the is_image argument so that it is not lost when the call is
    # redirected to BulletEnv.
    if 'is_image' in kwargs:
        is_image = kwargs['is_image']
    elif len(args) >= 3:
        is_image = args[2]
    else:
        is_image = False

    # An env passed by keyword is not guaranteed to be a TimeLimit wrapper
    # and a gym env may have no spec, so resolve the id defensively
    # instead of dereferencing env.env.spec.id directly.
    inner_spec = getattr(getattr(env, 'env', None), 'spec', None)
    env_id = getattr(inner_spec, 'id', None)

    if env_id is not None or env_name != '':
        # Build the id list once and reuse it for both checks.
        bullet_ids = _get_bullet_env_list()
        if env_id is not None and env_id in bullet_ids:
            return BulletEnv(env, is_image=is_image)
        if env_name != '' and env_name in bullet_ids:
            # Let BulletEnv create the environment itself so that its
            # per-environment gym.make() handling (e.g. disabling
            # rendering for 'RacecarZedBulletEnv-v0') is applied.
            return BulletEnv(env_name=env_name, is_image=is_image)

    return super(GarageEnv, cls).__new__(cls)

def __init__(self, env=None, env_name='', is_image=False):
"""Initializes a GarageEnv.
Args:
env (gym.wrappers.time_limit): A gym.wrappers.time_limit.TimeLimit
object wrapping a gym.Env created via gym.make().
env_name (str): If the env_name is specified, a gym environment
with that name will be created. If such an environment does not
exist, a `gym.error` is thrown.
is_image (bool): True if observations contain pixel values,
false otherwise. Setting this to true converts a gym.Spaces.Box
obs space to an akro.Image and normalizes pixel values.
"""
# Needed for deserialization
self._env_name = env_name
self._env = env
Expand All @@ -59,8 +105,8 @@ def __init__(self, env=None, env_name='', is_image=False):
self.action_space = akro.from_gym(self.env.action_space)
self.observation_space = akro.from_gym(self.env.observation_space,
is_image=is_image)
self.__spec = EnvSpec(action_space=self.action_space,
observation_space=self.observation_space)
self._spec = EnvSpec(action_space=self.action_space,
observation_space=self.observation_space)

@property
def spec(self):
Expand All @@ -73,7 +119,7 @@ def spec(self):
garage.envs.env_spec.EnvSpec: The environment specification.
"""
return self.__spec
return self._spec

def close(self):
"""Close the wrapped env."""
Expand Down Expand Up @@ -140,12 +186,12 @@ def step(self, action):
thrown by gym.Wrapper.
Args:
action (object): An action provided by the agent.
action (np.ndarray): An action provided by the agent.
Returns:
object: Agent's observation of the current environment
float : Amount of reward returned after previous action
bool : Whether the episode has ended, in which case further step()
np.ndarray: Agent's observation of the current environment
float: Amount of reward returned after previous action
bool: Whether the episode has ended, in which case further step()
calls will return undefined results
dict: Contains auxiliary diagnostic information (helpful for
debugging, and sometimes learning)
Expand Down
Loading

0 comments on commit c9e9817

Please sign in to comment.