Skip to content

Commit

Permalink
Support goal-conditioned Minari datasets such as AntMaze
Browse files: browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2023
1 parent 2b4b67b commit b55db62
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions d3rlpy/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,28 +470,43 @@ def get_minari(
try:
import minari

_dataset: minari.MinariDataset = minari.load_dataset(
env_name, download=True
)
_dataset = minari.load_dataset(env_name, download=True)

data: Dict[str, List[NDArray]] = {
"observations": [],
"actions": [],
"rewards": [],
"terminations": [],
"truncations": [],
}
observations = []
actions = []
rewards = []
terminals = []
timeouts = []

for ep in _dataset:
for k, v in data.items():
v.append(getattr(ep, k))
if isinstance(ep.observations, dict):
if (
"desired_goal" in ep.observations
and "observation" in ep.observations
):
_observations = np.concatenate(
[
ep.observations["observation"],
ep.observations["desired_goal"],
],
axis=-1,
)
else:
raise ValueError("Unsupported observation format.")
else:
_observations = ep.observations
observations.append(_observations)
actions.append(ep.actions)
rewards.append(ep.rewards)
terminals.append(ep.terminations)
timeouts.append(ep.truncations)

dataset = MDPDataset(
observations=np.concatenate(data["observations"]),
actions=np.concatenate(data["actions"]),
rewards=np.concatenate(data["rewards"]),
terminals=np.concatenate(data["terminations"]),
timeouts=np.concatenate(data["truncations"]),
observations=np.concatenate(observations),
actions=np.concatenate(actions),
rewards=np.concatenate(rewards),
terminals=np.concatenate(terminals),
timeouts=np.concatenate(timeouts),
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
)
Expand Down

0 comments on commit b55db62

Please sign in to comment.