Skip to content

Commit

Permalink
Add SimBaEncoderFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 4, 2024
1 parent 3e57c75 commit 1d28d0b
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 0 deletions.
53 changes: 53 additions & 0 deletions d3rlpy/models/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
VectorEncoder,
VectorEncoderWithAction,
)
from .torch.encoders import SimBaEncoder, SimBaEncoderWithAction
from .utility import create_activation

__all__ = [
"EncoderFactory",
"PixelEncoderFactory",
"VectorEncoderFactory",
"DefaultEncoderFactory",
"SimBaEncoderFactory",
"register_encoder_factory",
"make_encoder_field",
]
Expand Down Expand Up @@ -263,6 +265,56 @@ def get_type() -> str:
return "default"


@dataclass()
class SimBaEncoderFactory(EncoderFactory):
    """SimBa encoder factory class.

    This class implements the SimBa encoder architecture.

    References:
        * `Lee et al., SimBa: Simplicity Bias for Scaling Up Parameters in Deep
          Reinforcement Learning. <https://arxiv.org/abs/2410.09754>`_

    Args:
        feature_size (int): Feature unit size.
        hidden_size (int): Hidden expansion layer unit size.
        n_blocks (int): Number of SimBa residual blocks.
    """

    feature_size: int = 256
    hidden_size: int = 1024
    n_blocks: int = 1

    def create(self, observation_shape: Shape) -> SimBaEncoder:
        # SimBa only supports flat (1-D) vector observations.
        assert len(observation_shape) == 1
        return SimBaEncoder(
            observation_shape=cast_flat_shape(observation_shape),
            hidden_size=self.hidden_size,
            output_size=self.feature_size,
            n_blocks=self.n_blocks,
        )

    def create_with_action(
        self,
        observation_shape: Shape,
        action_size: int,
        discrete_action: bool = False,
    ) -> SimBaEncoderWithAction:
        # SimBa only supports flat (1-D) vector observations.
        assert len(observation_shape) == 1
        return SimBaEncoderWithAction(
            observation_shape=cast_flat_shape(observation_shape),
            action_size=action_size,
            hidden_size=self.hidden_size,
            output_size=self.feature_size,
            n_blocks=self.n_blocks,
            discrete_action=discrete_action,
        )

    @staticmethod
    def get_type() -> str:
        # Registry key used for (de)serialization of this factory.
        return "simba"


# Create the registration function and the config-field helper for encoder
# factories, with DefaultEncoderFactory as the fallback.
# NOTE(review): generate_config_registration's implementation is not visible
# here — presumably it builds a name-keyed (de)serialization registry; confirm.
register_encoder_factory, make_encoder_field = generate_config_registration(
    EncoderFactory, lambda: DefaultEncoderFactory()
)
Expand All @@ -271,3 +323,4 @@ def get_type() -> str:
# Register the built-in encoder factories so they can be looked up by their
# get_type() name (presumably during config deserialization — registry
# implementation not visible here).
register_encoder_factory(VectorEncoderFactory)
register_encoder_factory(PixelEncoderFactory)
register_encoder_factory(DefaultEncoderFactory)
register_encoder_factory(SimBaEncoderFactory)
68 changes: 68 additions & 0 deletions d3rlpy/models/torch/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"PixelEncoderWithAction",
"VectorEncoder",
"VectorEncoderWithAction",
"SimBaEncoder",
"SimBaEncoderWithAction",
"compute_output_size",
]

Expand Down Expand Up @@ -290,6 +292,72 @@ def forward(
return self._layers(x)


class SimBaBlock(nn.Module):  # type: ignore
    """Residual feedforward block used by the SimBa architecture.

    Applies pre-layer-normalization followed by a two-layer MLP with a
    ReLU activation, then adds the result back onto the input (skip
    connection).
    """

    def __init__(self, input_size: int, hidden_size: int, out_size: int):
        super().__init__()
        self._layers = nn.Sequential(
            nn.LayerNorm(input_size),
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual add: the skip connection requires out_size == input_size
        # (or broadcast-compatible shapes).
        return x + self._layers(x)


class SimBaEncoder(Encoder):
    """SimBa state encoder.

    Projects a flat observation to ``output_size``, runs it through
    ``n_blocks`` SimBa residual blocks, and applies a final layer norm.
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        hidden_size: int,
        output_size: int,
        n_blocks: int,
    ):
        super().__init__()
        blocks = [
            SimBaBlock(output_size, hidden_size, output_size)
            for _ in range(n_blocks)
        ]
        self._layers = nn.Sequential(
            nn.Linear(observation_shape[0], output_size),
            *blocks,
            nn.LayerNorm(output_size),
        )

    def forward(self, x: TorchObservation) -> torch.Tensor:
        # Only flat tensor observations are supported (no multi-modal input).
        assert isinstance(x, torch.Tensor)
        return self._layers(x)


class SimBaEncoderWithAction(EncoderWithAction):
    """SimBa state-action encoder.

    Concatenates the action (one-hot encoded when discrete) to the flat
    observation before feeding the SimBa trunk.
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        hidden_size: int,
        output_size: int,
        n_blocks: int,
        discrete_action: bool,
    ):
        super().__init__()
        in_size = observation_shape[0] + action_size
        blocks = [
            SimBaBlock(output_size, hidden_size, output_size)
            for _ in range(n_blocks)
        ]
        self._layers = nn.Sequential(
            nn.Linear(in_size, output_size),
            *blocks,
            nn.LayerNorm(output_size),
        )
        self._action_size = action_size
        self._discrete_action = discrete_action

    def forward(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor:
        assert isinstance(x, torch.Tensor)
        if self._discrete_action:
            # Discrete actions arrive as integer indices; expand to one-hot
            # floats so they can be concatenated with the observation.
            action = F.one_hot(
                action.view(-1).long(), num_classes=self._action_size
            ).float()
        h = torch.cat([x, action], dim=1)
        return self._layers(h)


def compute_output_size(
input_shapes: Sequence[Shape], encoder: nn.Module
) -> int:
Expand Down
33 changes: 33 additions & 0 deletions tests/models/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from d3rlpy.models.encoders import (
DefaultEncoderFactory,
PixelEncoderFactory,
SimBaEncoderFactory,
VectorEncoderFactory,
)
from d3rlpy.models.torch.encoders import (
PixelEncoder,
PixelEncoderWithAction,
SimBaEncoder,
SimBaEncoderWithAction,
VectorEncoder,
VectorEncoderWithAction,
)
Expand Down Expand Up @@ -104,3 +107,33 @@ def test_default_encoder_factory(

# check serization and deserialization
DefaultEncoderFactory.deserialize(factory.serialize())


@pytest.mark.parametrize("observation_shape", [(100,)])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("discrete_action", [False, True])
def test_simba_encoder_factory(
    observation_shape: Sequence[int],
    action_size: int,
    discrete_action: bool,
) -> None:
    """SimBaEncoderFactory builds the expected encoder types and round-trips its config."""
    factory = SimBaEncoderFactory()

    # test state encoder
    encoder = factory.create(observation_shape)
    assert isinstance(encoder, SimBaEncoder)

    # test state-action encoder
    encoder = factory.create_with_action(
        observation_shape, action_size, discrete_action
    )
    assert isinstance(encoder, SimBaEncoderWithAction)
    assert encoder._discrete_action == discrete_action

    assert factory.get_type() == "simba"

    # check serialization and deserialization
    new_factory = SimBaEncoderFactory.deserialize(factory.serialize())
    assert new_factory.hidden_size == factory.hidden_size
    assert new_factory.feature_size == factory.feature_size
    assert new_factory.n_blocks == factory.n_blocks
70 changes: 70 additions & 0 deletions tests/models/torch/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from d3rlpy.models.torch.encoders import (
PixelEncoder,
PixelEncoderWithAction,
SimBaEncoder,
SimBaEncoderWithAction,
VectorEncoder,
VectorEncoderWithAction,
)
Expand Down Expand Up @@ -212,3 +214,71 @@ def test_vector_encoder_with_action(

# check layer connection
check_parameter_updates(encoder, (x, action))


@pytest.mark.parametrize("observation_shape", [(100,)])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("output_size", [256])
@pytest.mark.parametrize("n_blocks", [2])
@pytest.mark.parametrize("batch_size", [32])
def test_simba_encoder(
    observation_shape: Sequence[int],
    hidden_size: int,
    output_size: int,
    n_blocks: int,
    batch_size: int
) -> None:
    """SimBaEncoder yields (batch, output_size) features and updates all parameters."""
    encoder = SimBaEncoder(
        observation_shape=observation_shape,
        hidden_size=hidden_size,
        output_size=output_size,
        n_blocks=n_blocks,
    )

    observations = torch.rand((batch_size, *observation_shape))
    features = encoder(observations)

    # each sample must map to a flat feature vector of output_size
    assert features.shape == (batch_size, output_size)

    # every parameter should receive a gradient update
    check_parameter_updates(encoder, (observations,))


@pytest.mark.parametrize("observation_shape", [(100,)])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("output_size", [256])
@pytest.mark.parametrize("n_blocks", [2])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("discrete_action", [False, True])
def test_simba_encoder_with_action(
    observation_shape: Sequence[int],
    action_size: int,
    hidden_size: int,
    output_size: int,
    n_blocks: int,
    batch_size: int,
    discrete_action: bool,
) -> None:
    """SimBaEncoderWithAction handles both action types and updates all parameters."""
    encoder = SimBaEncoderWithAction(
        observation_shape=observation_shape,
        action_size=action_size,
        hidden_size=hidden_size,
        output_size=output_size,
        n_blocks=n_blocks,
        discrete_action=discrete_action,
    )

    observations = torch.rand((batch_size, *observation_shape))
    # discrete actions are integer indices; continuous are float vectors
    if discrete_action:
        actions = torch.randint(0, action_size, size=(batch_size, 1))
    else:
        actions = torch.rand(batch_size, action_size)
    features = encoder(observations, actions)

    # each sample must map to a flat feature vector of output_size
    assert features.shape == (batch_size, output_size)

    # every parameter should receive a gradient update
    check_parameter_updates(encoder, (observations, actions))

0 comments on commit 1d28d0b

Please sign in to comment.