Skip to content

Commit

Permalink
Fix DiscreteDecisionTransformer training
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Oct 7, 2023
1 parent 96e39fb commit 468ce52
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions d3rlpy/algos/transformer/torch/decision_transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ def __init__(
@eval_api
def predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
# (1, T, A)
probs, _ = self._modules.transformer(
_, logits = self._modules.transformer(
inpt.observations, inpt.actions, inpt.returns_to_go, inpt.timesteps
)
# (1, T, A) -> (A,)
return probs[0][-1]
return logits[0][-1]

def inner_update(
self, batch: TorchTrajectoryMiniBatch, grad_step: int
Expand Down
2 changes: 1 addition & 1 deletion reproductions/offline/discrete_decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main() -> None:
num_layers=6,
attn_dropout=0.1,
embed_dropout=0.1,
optim_factory=d3rlpy.models.optimizers.AdamFactory(
optim_factory=d3rlpy.models.optimizers.AdamWFactory(
betas=(0.9, 0.95),
weight_decay=0.1,
),
Expand Down

0 comments on commit 468ce52

Please sign in to comment.