From 4dee692fa02412fe00be4049f1ae6c504dde251f Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 6 May 2024 18:07:48 +0900 Subject: [PATCH] Fix action scaling for d4rl --- d3rlpy/datasets.py | 20 ++++++++++++++++---- mypy.ini | 5 +++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/d3rlpy/datasets.py b/d3rlpy/datasets.py index dc11a4ae..bf383437 100644 --- a/d3rlpy/datasets.py +++ b/d3rlpy/datasets.py @@ -389,6 +389,7 @@ def get_d4rl( transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, render_mode: Optional[str] = None, + max_episode_steps: int = 1000, ) -> Tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]: """Returns d4rl dataset and envrironment. @@ -410,12 +411,17 @@ def get_d4rl( transition_picker: TransitionPickerProtocol object. trajectory_slicer: TrajectorySlicerProtocol object. render_mode: Mode of rendering (``human``, ``rgb_array``). + max_episode_steps: Maximum episode environmental steps. Returns: tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. """ try: - import d4rl # type: ignore + import d4rl + from d4rl.locomotion.wrappers import NormalizedBoxEnv + from d4rl.utils.wrappers import ( + NormalizedBoxEnv as NormalizedBoxEnvFromUtils, + ) env = gym.make(env_name) raw_dataset: Dict[str, NDArray] = env.get_dataset() # type: ignore @@ -436,11 +442,17 @@ def get_d4rl( trajectory_slicer=trajectory_slicer, ) - # wrapped by NormalizedBoxEnv that is incompatible with newer Gym - unwrapped_env: gym.Env[Any, Any] = env.env.env.env.wrapped_env # type: ignore + # remove incompatible wrappers + normalized_env = env.env.env.env # type: ignore + assert isinstance( + normalized_env, (NormalizedBoxEnv, NormalizedBoxEnvFromUtils) + ) + unwrapped_env: gym.Env[Any, Any] = normalized_env.wrapped_env unwrapped_env.render_mode = render_mode # overwrite - return dataset, TimeLimit(unwrapped_env, max_episode_steps=1000) + return dataset, TimeLimit( + normalized_env, max_episode_steps=max_episode_steps + ) except ImportError as e: raise ImportError( "d4rl is not installed.\n" "$ d3rlpy install d4rl" diff --git a/mypy.ini b/mypy.ini index b6f45301..0d910287 100644 --- a/mypy.ini +++ b/mypy.ini @@ -61,3 +61,8 @@ follow_imports_for_stubs = True [mypy-minari.*] ignore_missing_imports = True + +[mypy-d4rl.*] +ignore_missing_imports = True +follow_imports = skip +follow_imports_for_stubs = True