Skip to content

Commit

Permalink
Fix interval values of TransitionMiniBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Oct 24, 2023
1 parent 45c9c65 commit eab9e9f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/dataset/mini_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def from_transitions(
[-1, 1],
)
intervals = np.reshape(
np.array([transition.terminal for transition in transitions]),
np.array([transition.interval for transition in transitions]),
[-1, 1],
)
return TransitionMiniBatch(
Expand Down
27 changes: 26 additions & 1 deletion tests/dataset/test_mini_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,50 @@ def test_transition_mini_batch(
) -> None:
transitions = []
for _ in range(batch_size):
transition = create_transition(observation_shape, action_size)
transition = create_transition(
observation_shape,
action_size,
terminated=bool(np.random.randint(2)),
)
transitions.append(transition)

batch = TransitionMiniBatch.from_transitions(transitions)

ref_actions = np.array([t.action for t in transitions])
ref_rewards = np.array([t.reward for t in transitions])
ref_terminals = np.array([[t.terminal] for t in transitions])
ref_intervals = np.array([[t.interval] for t in transitions])

if isinstance(observation_shape[0], tuple):
for i, shape in enumerate(observation_shape):
ref_observations = np.array([t.observation[i] for t in transitions])
ref_next_observations = np.array(
[t.next_observation[i] for t in transitions]
)
assert isinstance(shape, tuple)
assert batch.observations[i].shape == (batch_size, *shape)
assert batch.next_observations[i].shape == (batch_size, *shape)
assert np.all(batch.observations[i] == ref_observations)
assert np.all(batch.next_observations[i] == ref_next_observations)
else:
ref_observations = np.array([t.observation for t in transitions])
ref_next_observations = np.array(
[t.next_observation for t in transitions]
)
assert isinstance(batch.observations, np.ndarray)
assert isinstance(batch.next_observations, np.ndarray)
assert batch.observations.shape == (batch_size, *observation_shape)
assert batch.next_observations.shape == (batch_size, *observation_shape)
assert np.all(batch.observations == ref_observations)
assert np.all(batch.next_observations == ref_next_observations)
assert batch.actions.shape == (batch_size, action_size)
assert batch.rewards.shape == (batch_size, 1)
assert batch.terminals.shape == (batch_size, 1)
assert batch.intervals.shape == (batch_size, 1)
assert np.all(batch.actions == ref_actions)
assert np.all(batch.rewards == ref_rewards)
assert np.all(batch.terminals == ref_terminals)
assert np.all(batch.intervals == ref_intervals)


@pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
Expand Down
16 changes: 11 additions & 5 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,25 @@ def create_transition(
observation: Observation
next_observation: Observation
if isinstance(observation_shape[0], (list, tuple)):
observation = [np.random.random(shape) for shape in observation_shape]
observation = [
np.random.random(shape).astype(np.float32)
for shape in observation_shape
]
next_observation = [
np.random.random(shape) for shape in observation_shape
np.random.random(shape).astype(np.float32)
for shape in observation_shape
]
else:
observation = np.random.random(observation_shape)
next_observation = np.random.random(observation_shape)
observation = np.random.random(observation_shape).astype(np.float32)
next_observation = np.random.random(observation_shape).astype(
np.float32
)

action: NDArray
if discrete_action:
action = np.random.randint(action_size, size=(1,))
else:
action = np.random.random(action_size)
action = np.random.random(action_size).astype(np.float32)

return Transition(
observation=observation,
Expand Down

0 comments on commit eab9e9f

Please sign in to comment.