Skip to content

Commit

Permalink
Fix padded timesteps
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Oct 8, 2023
1 parent 2874bf6 commit b48106f
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 18 deletions.
7 changes: 4 additions & 3 deletions d3rlpy/algos/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def inner_update(
@dataclasses.dataclass()
class TransformerConfig(LearnableConfig):
context_size: int = 20
max_timestep: int = 1000


TTransformerImpl = TypeVar("TTransformerImpl", bound=TransformerAlgoImplBase)
Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(
self._rewards = deque([], maxlen=context_size)
self._returns_to_go = deque([], maxlen=context_size)
self._timesteps = deque([], maxlen=context_size)
self._timestep = 0
self._timestep = 1

def predict(self, x: Observation, reward: float) -> Union[np.ndarray, int]:
r"""Returns action.
Expand All @@ -151,7 +152,7 @@ def predict(self, x: Observation, reward: float) -> Union[np.ndarray, int]:
action = self._action_sampler(self._algo.predict(inpt))
self._actions[-1] = action
self._actions.append(self._get_pad_action())
self._timestep += 1
self._timestep = min(self._timestep + 1, self._algo.config.max_timestep)
self._return_rest -= reward
return action

Expand All @@ -163,7 +164,7 @@ def reset(self) -> None:
self._returns_to_go.clear()
self._timesteps.clear()
self._actions.append(self._get_pad_action())
self._timestep = 0
self._timestep = 1
self._return_rest = self._target_return

@property
Expand Down
6 changes: 2 additions & 4 deletions d3rlpy/algos/transformer/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ class DecisionTransformerConfig(TransformerConfig):
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
context_size (int): Prior sequence length.
max_timestep (int): Maximum environmental timestep.
batch_size (int): Mini-batch size.
learning_rate (float): Learning rate.
encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory.
optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory.
num_heads (int): Number of attention heads.
max_timestep (int): Maximum environmental timestep.
num_layers (int): Number of attention blocks.
attn_dropout (float): Dropout probability for attentions.
resid_dropout (float): Dropout probability for residual connection.
Expand All @@ -73,7 +73,6 @@ class DecisionTransformerConfig(TransformerConfig):
encoder_factory: EncoderFactory = make_encoder_field()
optim_factory: OptimizerFactory = make_optimizer_field()
num_heads: int = 1
max_timestep: int = 1000
num_layers: int = 3
attn_dropout: float = 0.1
resid_dropout: float = 0.1
Expand Down Expand Up @@ -158,14 +157,14 @@ class DiscreteDecisionTransformerConfig(TransformerConfig):
Observation preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
context_size (int): Prior sequence length.
max_timestep (int): Maximum environmental timestep.
batch_size (int): Mini-batch size.
learning_rate (float): Learning rate.
encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory.
optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory.
num_heads (int): Number of attention heads.
max_timestep (int): Maximum environmental timestep.
num_layers (int): Number of attention blocks.
attn_dropout (float): Dropout probability for attentions.
resid_dropout (float): Dropout probability for residual connection.
Expand All @@ -186,7 +185,6 @@ class DiscreteDecisionTransformerConfig(TransformerConfig):
encoder_factory: EncoderFactory = make_encoder_field()
optim_factory: OptimizerFactory = make_optimizer_field()
num_heads: int = 8
max_timestep: int = 1000
num_layers: int = 6
attn_dropout: float = 0.1
resid_dropout: float = 0.1
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/transformer/torch/decision_transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor:
)
# (B, T, A) -> (B, T)
loss = ((action - batch.actions) ** 2).sum(dim=-1)
return (loss * batch.masks).sum() / batch.masks.sum()
return loss.mean()


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -181,4 +181,4 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor:
batch.actions.view(-1).long(),
reduction="none",
)
return (loss * batch.masks.view(-1)).sum() / batch.masks.sum()
return loss.mean()
4 changes: 2 additions & 2 deletions d3rlpy/dataset/trajectory_slicers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __call__(
returns_to_go = all_returns_to_go[:actual_size].reshape((-1, 1))

# prepare metadata
timesteps = np.arange(start, end)
timesteps = np.arange(start, end) + 1
masks = np.ones(end - start, dtype=np.float32)

# compute backward padding size
Expand Down Expand Up @@ -171,7 +171,7 @@ def __call__(
returns_to_go = all_returns_to_go[:actual_size].reshape((-1, 1))

# prepare metadata
timesteps = np.arange(start, end)
timesteps = np.arange(start, end) + 1
masks = np.ones(end - start, dtype=np.float32)

# compute backward padding size
Expand Down
12 changes: 8 additions & 4 deletions d3rlpy/models/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,12 @@ def create_continuous_decision_transformer(
hidden_size = compute_output_size([observation_shape], encoder)

if position_encoding_type == "simple":
position_encoding = SimplePositionEncoding(hidden_size, max_timestep)
position_encoding = SimplePositionEncoding(
hidden_size, max_timestep + 1
)
elif position_encoding_type == "global":
position_encoding = GlobalPositionEncoding(
hidden_size, max_timestep, context_size
hidden_size, max_timestep + 1, context_size
)
else:
raise ValueError(
Expand Down Expand Up @@ -313,10 +315,12 @@ def create_discrete_decision_transformer(
hidden_size = compute_output_size([observation_shape], encoder)

if position_encoding_type == "simple":
position_encoding = SimplePositionEncoding(hidden_size, max_timestep)
position_encoding = SimplePositionEncoding(
hidden_size, max_timestep + 1
)
elif position_encoding_type == "global":
position_encoding = GlobalPositionEncoding(
hidden_size, max_timestep, context_size
hidden_size, max_timestep + 1, context_size
)
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/models/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def create(
f"{module_name}.{param_name}" if module_name else param_name
)

if full_name not in decay and full_name not in no_decay:
if full_name not in params_dict:
params_dict[full_name] = param

if param_name.endswith("bias"):
Expand Down
4 changes: 2 additions & 2 deletions tests/dataset/test_trajectory_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_basic_trajectory_slicer(
assert np.all(traj.returns_to_go[pad_size:] == returns_to_go[start:end])
assert np.all(traj.returns_to_go[:pad_size] == 0.0)
assert np.all(traj.terminals == 0.0)
assert np.all(traj.timesteps[pad_size:] == np.arange(start, end))
assert np.all(traj.timesteps[pad_size:] == np.arange(start, end) + 1)
assert np.all(traj.timesteps[:pad_size] == 0.0)
assert np.all(traj.masks[pad_size:] == 1.0)
assert np.all(traj.masks[:pad_size] == 0.0)
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_frame_stack_trajectory_slicer(
assert np.all(traj.returns_to_go[pad_size:] == returns_to_go[start:end])
assert np.all(traj.returns_to_go[:pad_size] == 0.0)
assert np.all(traj.terminals == 0.0)
assert np.all(traj.timesteps[pad_size:] == np.arange(start, end))
assert np.all(traj.timesteps[pad_size:] == np.arange(start, end) + 1)
assert np.all(traj.timesteps[:pad_size] == 0.0)
assert np.all(traj.masks[pad_size:] == 1.0)
assert np.all(traj.masks[:pad_size] == 0.0)
Expand Down

0 comments on commit b48106f

Please sign in to comment.