Fix decision transformers initialization
takuseno committed Oct 8, 2023
1 parent 154f3c0 commit a78d734
Showing 1 changed file with 20 additions and 7 deletions.
d3rlpy/models/torch/transformers.py (27 changes: 20 additions & 7 deletions)
@@ -237,6 +237,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return h
 
 
+def _init_weights(module: nn.Module) -> None:
+    if isinstance(module, (nn.Linear, nn.Embedding)):
+        module.weight.data.normal_(mean=0.0, std=0.02)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.LayerNorm):
+        module.bias.data.zero_()
+        module.weight.data.fill_(1.0)
+
+
 class ContinuousDecisionTransformer(nn.Module):  # type: ignore
     _encoder: Encoder
     _position_encoding: PositionEncoding
@@ -260,10 +270,7 @@ def __init__(
         activation: nn.Module,
     ):
         super().__init__()
-        self._encoder = encoder
         self._position_encoding = position_encoding
-        self._action_embed = nn.Linear(action_size, feature_size)
-        self._rtg_embed = nn.Linear(1, feature_size)
         self._embed_ln = nn.LayerNorm(feature_size)
         self._gpt2 = GPT2(
             hidden_size=feature_size,
@@ -275,6 +282,11 @@ def __init__(
             embed_dropout=embed_dropout,
             activation=activation,
         )
+        self.apply(_init_weights)
+
+        self._encoder = encoder
+        self._rtg_embed = nn.Linear(1, feature_size)
+        self._action_embed = nn.Linear(action_size, feature_size)
         self._output = nn.Linear(feature_size, action_size)
 
     def forward(
@@ -337,11 +349,7 @@ def __init__(
         embed_activation: nn.Module,
     ):
         super().__init__()
-        self._encoder = encoder
         self._position_encoding = position_encoding
-        self._action_embed = nn.Embedding(action_size, feature_size)
-        nn.init.normal_(self._action_embed.weight, mean=0.0, std=0.02)
-        self._rtg_embed = nn.Linear(1, feature_size)
         self._gpt2 = GPT2(
             hidden_size=feature_size,
             num_heads=num_heads,
@@ -353,6 +361,11 @@ def __init__(
             activation=activation,
         )
         self._output = nn.Linear(feature_size, action_size, bias=False)
+        self._action_embed = nn.Embedding(action_size, feature_size)
+        self.apply(_init_weights)
+
+        self._encoder = encoder
+        self._rtg_embed = nn.Linear(1, feature_size)
         self._embed_activation = embed_activation
 
     def forward(
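Note: the module-creation order in both constructors is what makes this fix work. nn.Module.apply visits only the submodules that are registered at the moment it is called, so layers constructed after self.apply(_init_weights) keep PyTorch's default initialization while everything constructed before it receives the GPT-style init. A minimal sketch of that behavior (the Demo class is hypothetical, not part of the repository; _init_weights mirrors the function added in this commit):

from torch import nn

def _init_weights(module: nn.Module) -> None:
    # Same GPT-style initialization as in the commit: N(0, 0.02) weights,
    # zero biases, and identity-scale LayerNorm.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

class Demo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = nn.Linear(64, 64)   # registered before apply: re-initialized
        self.apply(_init_weights)
        self.outer = nn.Linear(64, 64)   # registered after apply: keeps default init

demo = Demo()
print(demo.inner.weight.std())  # close to 0.02
print(demo.outer.weight.std())  # PyTorch's default Kaiming-uniform scale, noticeably larger

This is why each constructor above creates the modules meant to receive the GPT-style initialization (the GPT2 block, and in the discrete case the action embedding and output head) before the apply call, and the encoder and remaining embeddings after it.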