Skip to content

Commit

Permalink
get_dataset and cli download ability (#322)
Browse files Browse the repository at this point in the history
* minari integration with test

* fix import order

* run formatter
  • Loading branch information
grahamannett authored Nov 3, 2023
1 parent f9bde31 commit 8e139fc
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 21 deletions.
38 changes: 18 additions & 20 deletions d3rlpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,28 +350,26 @@ def play(

@cli.command(short_help="Install additional packages.")
@click.argument("name")
def install(name: str) -> None:
    # Installs the optional dependency group selected by ``name``.
    # Supported names: "atari", "d4rl_atari", "d4rl" and "minari"; any other
    # value raises ValueError.
    # (Plain comments are used instead of a docstring so that the text click
    # derives for ``--help`` stays unchanged.)
    if name == "atari":
        _install_module(["gym[atari,accept-rom-license]"], upgrade=True)
    elif name == "d4rl_atari":
        # ``install`` is rebound to a click command object by the decorators
        # above, so ``install("atari")`` would re-enter click's argument
        # parsing rather than run the "atari" branch. Install the Atari
        # prerequisite directly instead.
        _install_module(["gym[atari,accept-rom-license]"], upgrade=True)
        _install_module(["git+https://github.com/takuseno/d4rl-atari"])
    elif name == "d4rl":
        _install_module(["git+https://github.com/Farama-Foundation/D4RL"])
        _install_module(["gym"], upgrade=True)
        # pybullet must be uninstalled here, as the pre-refactor command did.
        # ``_install_module`` always runs ``pip3 install``, so routing this
        # through it would execute ``pip3 install -U -y pybullet``.
        subprocess.run(["pip3", "uninstall", "-y", "pybullet"], check=True)
    elif name == "minari":
        _install_module(["minari==0.4.2"], upgrade=True)
    else:
        raise ValueError(f"Unsupported command: {name}")


def _install_module(
    name: list[str], upgrade: bool = False, check: bool = True
) -> None:
    """Run ``pip3 install`` with the given arguments.

    Args:
        name: Arguments appended to ``pip3 install`` (package specifiers).
        upgrade: If ``True``, ``-U`` is inserted before ``name``.
        check: Forwarded to :func:`subprocess.run` as ``check``.
    """
    command = ["pip3", "install"]
    if upgrade:
        command.append("-U")
    command.extend(name)
    subprocess.run(command, check=check)
64 changes: 64 additions & 0 deletions d3rlpy/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib import request

import gym
import gymnasium
import numpy as np
from gym.wrappers.time_limit import TimeLimit

Expand Down Expand Up @@ -442,6 +443,69 @@ def get_d4rl(
) from e


def get_minari(
    env_name: str,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gymnasium.Env[np.ndarray, np.ndarray]]:
    """Returns minari dataset and environment.

    The dataset is provided through minari.

    .. code-block:: python

        from d3rlpy.datasets import get_minari

        dataset, env = get_minari('door-cloned-v1')

    Args:
        env_name: environment id of minari dataset.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gymnasium
        environment.

    Raises:
        ImportError: if minari cannot be imported.
    """
    # Guard only the import: an ImportError raised later (while loading or
    # converting the dataset) must not be reported as "minari is not
    # installed".
    try:
        import minari
    except ImportError as e:
        raise ImportError(
            "minari is not installed.\n" "$ d3rlpy install minari"
        ) from e

    # The recovered environment is a gymnasium environment, so it is wrapped
    # with gymnasium's TimeLimit to match the annotated return type.
    from gymnasium.wrappers import TimeLimit as GymnasiumTimeLimit

    _dataset: minari.MinariDataset = minari.load_dataset(
        env_name, download=True
    )

    # Per-episode arrays, keyed by the episode attribute they are read from.
    data = {
        "observations": [],
        "actions": [],
        "rewards": [],
        "terminations": [],
        "truncations": [],
    }

    # NOTE(review): Minari episodes may hold one more observation than
    # actions (the final observation) -- confirm the concatenated arrays
    # stay aligned.
    for ep in _dataset:
        for key, values in data.items():
            values.append(getattr(ep, key))

    dataset = MDPDataset(
        observations=np.concatenate(data["observations"]),
        actions=np.concatenate(data["actions"]),
        rewards=np.concatenate(data["rewards"]),
        terminals=np.concatenate(data["terminations"]),
        timeouts=np.concatenate(data["truncations"]),
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
    )

    env = _dataset.recover_environment()
    unwrapped_env = env.unwrapped

    # Set the render mode on the bare environment, then re-apply only the
    # episode step limit taken from the recovered environment's spec.
    unwrapped_env.render_mode = render_mode
    return dataset, GymnasiumTimeLimit(
        unwrapped_env, max_episode_steps=env.spec.max_episode_steps
    )


ATARI_GAMES = [
"adventure",
"air-raid",
Expand Down
1 change: 1 addition & 0 deletions docs/references/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ learning algorithms.
d3rlpy.datasets.get_atari_transitions
d3rlpy.datasets.get_d4rl
d3rlpy.datasets.get_dataset
d3rlpy.datasets.get_minari
14 changes: 13 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from d3rlpy.datasets import get_cartpole, get_dataset, get_pendulum
from d3rlpy.datasets import get_cartpole, get_dataset, get_minari, get_pendulum


@pytest.mark.parametrize("dataset_type", ["replay", "random"])
Expand All @@ -23,3 +23,15 @@ def test_get_dataset(env_name: str) -> None:
assert env.unwrapped.spec.id == "CartPole-v1"
elif env_name == "pendulum-random":
assert env.unwrapped.spec.id == "Pendulum-v1"


@pytest.mark.parametrize(
    ("dataset_name", "env_name"),
    [
        ("door-cloned-v1", "AdroitHandDoor-v1"),
        ("relocate-expert-v1", "AdroitHandRelocate-v1"),
    ],
)
def test_get_minari(dataset_name: str, env_name: str) -> None:
    """Check that the environment returned for a dataset has the expected id."""
    env = get_minari(dataset_name)[1]
    recovered_id = env.unwrapped.spec.id
    assert recovered_id == env_name

0 comments on commit 8e139fc

Please sign in to comment.