Skip to content

Commit

Permalink
Add use_layer_norm flag to VectorEncoderFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jun 5, 2024
1 parent b769202 commit 7e72150
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
4 changes: 4 additions & 0 deletions d3rlpy/models/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class VectorEncoderFactory(EncoderFactory):
standard architecture with ``[256, 256]`` is used.
activation (str): activation function name.
use_batch_norm (bool): Flag to insert batch normalization layers.
use_layer_norm (bool): Flag to insert layer normalization layers.
dropout_rate (float): Dropout probability.
exclude_last_activation (bool): Flag to exclude activation function at
the last layer.
Expand All @@ -151,6 +152,7 @@ class VectorEncoderFactory(EncoderFactory):
hidden_units: List[int] = field(default_factory=lambda: [256, 256])
activation: str = "relu"
use_batch_norm: bool = False
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
exclude_last_activation: bool = False
last_activation: Optional[str] = None
Expand All @@ -161,6 +163,7 @@ def create(self, observation_shape: Shape) -> VectorEncoder:
observation_shape=cast_flat_shape(observation_shape),
hidden_units=self.hidden_units,
use_batch_norm=self.use_batch_norm,
use_layer_norm=self.use_layer_norm,
dropout_rate=self.dropout_rate,
activation=create_activation(self.activation),
exclude_last_activation=self.exclude_last_activation,
Expand All @@ -183,6 +186,7 @@ def create_with_action(
action_size=action_size,
hidden_units=self.hidden_units,
use_batch_norm=self.use_batch_norm,
use_layer_norm=self.use_layer_norm,
dropout_rate=self.dropout_rate,
discrete_action=discrete_action,
activation=create_activation(self.activation),
Expand Down
6 changes: 6 additions & 0 deletions d3rlpy/models/torch/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(
observation_shape: Sequence[int],
hidden_units: Optional[Sequence[int]] = None,
use_batch_norm: bool = False,
use_layer_norm: bool = False,
dropout_rate: Optional[float] = None,
activation: nn.Module = nn.ReLU(),
exclude_last_activation: bool = False,
Expand All @@ -220,6 +221,8 @@ def __init__(
layers.append(activation)
if use_batch_norm:
layers.append(nn.BatchNorm1d(out_unit))
if use_layer_norm:
layers.append(nn.LayerNorm(out_unit))
if dropout_rate is not None:
layers.append(nn.Dropout(dropout_rate))
self._layers = nn.Sequential(*layers)
Expand All @@ -240,6 +243,7 @@ def __init__(
action_size: int,
hidden_units: Optional[Sequence[int]] = None,
use_batch_norm: bool = False,
use_layer_norm: bool = False,
dropout_rate: Optional[float] = None,
discrete_action: bool = False,
activation: nn.Module = nn.ReLU(),
Expand Down Expand Up @@ -268,6 +272,8 @@ def __init__(
layers.append(activation)
if use_batch_norm:
layers.append(nn.BatchNorm1d(out_unit))
if use_layer_norm:
layers.append(nn.LayerNorm(out_unit))
if dropout_rate is not None:
layers.append(nn.Dropout(dropout_rate))
self._layers = nn.Sequential(*layers)
Expand Down
11 changes: 7 additions & 4 deletions reproductions/offline/rebrac.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def main() -> None:
d3rlpy.envs.seed_env(env, args.seed)

# deeper network
- encoder = d3rlpy.models.VectorEncoderFactory([256, 256, 256])
+ actor_encoder = d3rlpy.models.VectorEncoderFactory([256, 256, 256])
+ critic_encoder = d3rlpy.models.VectorEncoderFactory(
+ [256, 256, 256], use_layer_norm=True
+ )

actor_beta, critic_beta = 0.01, 0.01
for dataset_name, beta_from_paper in BETA_TABLE.items():
Expand All @@ -48,11 +51,11 @@ def main() -> None:

rebrac = d3rlpy.algos.ReBRACConfig(
actor_learning_rate=1e-3,
- critic_learning_rate=1e-1,
+ critic_learning_rate=1e-3,
batch_size=1024,
gamma=0.99,
- actor_encoder_factory=encoder,
- critic_encoder_factory=encoder,
+ actor_encoder_factory=actor_encoder,
+ critic_encoder_factory=critic_encoder,
target_smoothing_sigma=0.2,
target_smoothing_clip=0.5,
update_actor_interval=2,
Expand Down

0 comments on commit 7e72150

Please sign in to comment.