Add more parameters to transformer models
takuseno committed Dec 27, 2023
1 parent b55c72d commit dcfd0e6
Showing 6 changed files with 99 additions and 58 deletions.

d3rlpy/models/builders.py (4 changes: 2 additions & 2 deletions)
@@ -298,7 +298,7 @@ def create_continuous_decision_transformer(
 
     transformer = ContinuousDecisionTransformer(
         encoder=encoder,
-        feature_size=hidden_size,
+        embed_size=hidden_size,
         position_encoding=position_encoding,
         action_size=action_size,
         num_heads=num_heads,
@@ -341,7 +341,7 @@ def create_discrete_decision_transformer(
 
     transformer = DiscreteDecisionTransformer(
         encoder=encoder,
-        feature_size=hidden_size,
+        embed_size=hidden_size,
         position_encoding=position_encoding,
         action_size=action_size,
         num_heads=num_heads,
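
Only the keyword forwarded to the torch-level constructors changes here (feature_size becomes embed_size); the builders keep their own hidden_size argument. As a hypothetical illustration (not part of the commit, and assuming the import path follows the file layout in this diff), the rename is visible on the constructor signature:

    import inspect

    from d3rlpy.models.torch.transformers import ContinuousDecisionTransformer

    # After this commit the constructor exposes embed_size instead of feature_size.
    params = inspect.signature(ContinuousDecisionTransformer.__init__).parameters
    assert "embed_size" in params
    assert "feature_size" not in params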

d3rlpy/models/torch/transformers.py (78 changes: 46 additions & 32 deletions)
@@ -37,8 +37,7 @@ class CausalSelfAttention(nn.Module): # type: ignore
 
     def __init__(
         self,
-        in_size: int,
-        out_size: int,
+        embed_size: int,
         num_heads: int,
         context_size: int,
         attn_dropout: float,
@@ -47,10 +46,10 @@ def __init__(
         super().__init__()
         self._num_heads = num_heads
         self._context_size = context_size
-        self._k = nn.Linear(in_size, out_size)
-        self._q = nn.Linear(in_size, out_size)
-        self._v = nn.Linear(in_size, out_size)
-        self._proj = nn.Linear(out_size, out_size)
+        self._k = nn.Linear(embed_size, embed_size)
+        self._q = nn.Linear(embed_size, embed_size)
+        self._v = nn.Linear(embed_size, embed_size)
+        self._proj = nn.Linear(embed_size, embed_size)
         self._attn_dropout = nn.Dropout(attn_dropout)
         self._proj_dropout = nn.Dropout(resid_dropout)
         mask = create_attention_mask(context_size)
@@ -91,11 +90,17 @@ class MLP(nn.Module): # type: ignore
     _activation: nn.Module
 
     def __init__(
-        self, in_size: int, out_size: int, dropout: float, activation: nn.Module
+        self,
+        in_size: int,
+        out_size: int,
+        pre_activation_hidden_size: int,
+        post_activation_hidden_size: int,
+        dropout: float,
+        activation: nn.Module,
     ):
         super().__init__()
-        self._l1 = nn.Linear(in_size, 4 * out_size)
-        self._l2 = nn.Linear(4 * out_size, out_size)
+        self._l1 = nn.Linear(in_size, pre_activation_hidden_size)
+        self._l2 = nn.Linear(post_activation_hidden_size, out_size)
         self._dropout = nn.Dropout(dropout)
         self._activation = activation
 
@@ -113,8 +118,9 @@ class Block(nn.Module): # type: ignore
 
     def __init__(
         self,
-        in_size: int,
-        out_size: int,
+        layer_width: int,
+        pre_activation_ff_hidden_size: int,
+        post_activation_ff_hidden_size: int,
         num_heads: int,
         context_size: int,
         attn_dropout: float,
@@ -123,21 +129,22 @@ def __init__(
     ):
         super().__init__()
         self._attention = CausalSelfAttention(
-            in_size=in_size,
-            out_size=out_size,
+            embed_size=layer_width,
             num_heads=num_heads,
             context_size=context_size,
             attn_dropout=attn_dropout,
             resid_dropout=resid_dropout,
         )
         self._mlp = MLP(
-            in_size=out_size,
-            out_size=out_size,
+            in_size=layer_width,
+            out_size=layer_width,
+            pre_activation_hidden_size=pre_activation_ff_hidden_size,
+            post_activation_hidden_size=post_activation_ff_hidden_size,
             dropout=resid_dropout,
             activation=activation,
         )
-        self._layer_norm1 = nn.LayerNorm(out_size, eps=0.003)
-        self._layer_norm2 = nn.LayerNorm(out_size, eps=0.003)
+        self._layer_norm1 = nn.LayerNorm(layer_width, eps=0.003)
+        self._layer_norm2 = nn.LayerNorm(layer_width, eps=0.003)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         norm_x = self._layer_norm1(x)
@@ -206,7 +213,9 @@ class GPT2(nn.Module): # type: ignore
 
     def __init__(
         self,
-        hidden_size: int,
+        layer_width: int,
+        pre_activation_ff_hidden_size: int,
+        post_activation_ff_hidden_size: int,
         num_heads: int,
         context_size: int,
         num_layers: int,
@@ -218,8 +227,9 @@ def __init__(
         super().__init__()
         blocks = [
             Block(
-                in_size=hidden_size,
-                out_size=hidden_size,
+                layer_width=layer_width,
+                pre_activation_ff_hidden_size=pre_activation_ff_hidden_size,
+                post_activation_ff_hidden_size=post_activation_ff_hidden_size,
                 num_heads=num_heads,
                 context_size=context_size,
                 attn_dropout=attn_dropout,
@@ -229,7 +239,7 @@ def __init__(
             for _ in range(num_layers)
         ]
         self._transformer = nn.Sequential(*blocks)
-        self._layer_norm = nn.LayerNorm(hidden_size, eps=0.003)
+        self._layer_norm = nn.LayerNorm(layer_width, eps=0.003)
         self._dropout = nn.Dropout(embed_dropout)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -260,7 +270,7 @@ class ContinuousDecisionTransformer(nn.Module): # type: ignore
     def __init__(
         self,
         encoder: Encoder,
-        feature_size: int,
+        embed_size: int,
         position_encoding: PositionEncoding,
         action_size: int,
         num_heads: int,
@@ -273,9 +283,11 @@ def __init__(
     ):
         super().__init__()
         self._position_encoding = position_encoding
-        self._embed_ln = nn.LayerNorm(feature_size)
+        self._embed_ln = nn.LayerNorm(embed_size)
         self._gpt2 = GPT2(
-            hidden_size=feature_size,
+            layer_width=embed_size,
+            pre_activation_ff_hidden_size=4 * embed_size,
+            post_activation_ff_hidden_size=4 * embed_size,
             num_heads=num_heads,
             context_size=3 * context_size,
             num_layers=num_layers,
@@ -287,9 +299,9 @@ def __init__(
         self.apply(_init_weights)
 
         self._encoder = encoder
-        self._rtg_embed = nn.Linear(1, feature_size)
-        self._action_embed = nn.Linear(action_size, feature_size)
-        self._output = nn.Linear(feature_size, action_size)
+        self._rtg_embed = nn.Linear(1, embed_size)
+        self._action_embed = nn.Linear(action_size, embed_size)
+        self._output = nn.Linear(embed_size, action_size)
 
     def forward(
         self,
@@ -342,7 +354,7 @@ class DiscreteDecisionTransformer(nn.Module): # type: ignore
     def __init__(
         self,
         encoder: Encoder,
-        feature_size: int,
+        embed_size: int,
         position_encoding: PositionEncoding,
         action_size: int,
         num_heads: int,
@@ -357,7 +369,9 @@ def __init__(
         super().__init__()
         self._position_encoding = position_encoding
         self._gpt2 = GPT2(
-            hidden_size=feature_size,
+            layer_width=embed_size,
+            pre_activation_ff_hidden_size=4 * embed_size,
+            post_activation_ff_hidden_size=4 * embed_size,
             num_heads=num_heads,
             context_size=3 * context_size,
             num_layers=num_layers,
@@ -366,12 +380,12 @@ def __init__(
             embed_dropout=embed_dropout,
             activation=activation,
         )
-        self._output = nn.Linear(feature_size, action_size, bias=False)
-        self._action_embed = nn.Embedding(action_size, feature_size)
+        self._output = nn.Linear(embed_size, action_size, bias=False)
+        self._action_embed = nn.Embedding(action_size, embed_size)
         self.apply(_init_weights)
 
         self._encoder = encoder
-        self._rtg_embed = nn.Linear(1, feature_size)
+        self._rtg_embed = nn.Linear(1, embed_size)
         self._embed_activation = embed_activation
 
     def forward(
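
The feed-forward block is the reason for the new pre/post activation sizes: MLP previously hard-coded a 4x expansion, and a single expansion factor only works when the activation preserves width. The GEGLU added in d3rlpy/torch_utility.py below halves its input's last dimension, so with GEGLU the pre-activation size has to be twice the post-activation size. A minimal sketch of that relationship (illustrative sizes, not taken from this commit; the Sequential stands in for MLP's _l1 / activation / _l2 chain):

    import torch
    import torch.nn.functional as F
    from torch import nn


    class GEGLU(nn.Module):
        """Same behavior as the GEGLU added in d3rlpy/torch_utility.py."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            assert x.shape[-1] % 2 == 0
            a, b = x.chunk(2, dim=-1)
            return a * F.gelu(b)


    embed_size = 128  # hypothetical layer width
    pre_activation_hidden_size = 2 * 4 * embed_size  # 1024, fed into GEGLU
    post_activation_hidden_size = 4 * embed_size  # 512, GEGLU's output width

    ff = nn.Sequential(
        nn.Linear(embed_size, pre_activation_hidden_size),  # role of MLP._l1
        GEGLU(),  # 1024 -> 512
        nn.Linear(post_activation_hidden_size, embed_size),  # role of MLP._l2
    )
    x = torch.randn(2, 10, embed_size)
    assert ff(x).shape == (2, 10, embed_size)

The decision transformer constructors in this diff pass 4 * embed_size for both sizes, which reproduces the old fixed 4x expansion and suits width-preserving activations such as ReLU or GELU.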

d3rlpy/models/utility.py (4 changes: 3 additions & 1 deletion)
@@ -1,6 +1,6 @@
 from torch import nn
 
-from ..torch_utility import Swish
+from ..torch_utility import GEGLU, Swish
 
 __all__ = ["create_activation"]
 
@@ -16,4 +16,6 @@ def create_activation(activation_type: str) -> nn.Module:
         return Swish()
     elif activation_type == "none":
         return nn.Identity()
+    elif activation_type == "geglu":
+        return GEGLU()
     raise ValueError("invalid activation_type.")
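
With the new branch, create_activation accepts "geglu" alongside the existing names, and unknown names still raise ValueError. A short usage sketch, assuming the import path follows the file layout in this diff:

    from d3rlpy.models.utility import create_activation

    geglu = create_activation("geglu")  # the new GEGLU module
    identity = create_activation("none")  # nn.Identity(), unchanged behavior

    try:
        create_activation("not-an-activation")
    except ValueError:
        pass  # unknown names still raise ValueError("invalid activation_type.")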

d3rlpy/torch_utility.py (8 changes: 8 additions & 0 deletions)
@@ -13,6 +13,7 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -380,3 +381,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Swish(nn.Module): # type: ignore
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(x)
+
+
+class GEGLU(nn.Module): # type: ignore
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] % 2 == 0
+        a, b = x.chunk(2, dim=-1)
+        return a * F.gelu(b)
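
GEGLU splits its input in half along the last dimension and multiplies one half by the GELU of the other, so the output width is half the input width (and the input width must be even, hence the assert). A short usage sketch:

    import torch

    from d3rlpy.torch_utility import GEGLU

    geglu = GEGLU()
    x = torch.randn(4, 256)  # last dimension must be even
    y = geglu(x)
    print(y.shape)  # torch.Size([4, 128]): half the input width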
(2 more changed files not shown.)
