diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py
index d8a9fb5e..37c064ee 100644
--- a/d3rlpy/algos/qlearning/bcq.py
+++ b/d3rlpy/algos/qlearning/bcq.py
@@ -14,7 +14,7 @@
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
 from ...models.q_functions import QFunctionFactory, make_q_func_field
-from ...models.torch import DiscreteImitator, PixelEncoder
+from ...models.torch import DiscreteImitator, PixelEncoder, compute_output_size
 from ...torch_utility import TorchMiniBatch
 from .base import QLearningAlgoBase
 from .torch.bcq_impl import BCQImpl, DiscreteBCQImpl
@@ -338,8 +338,16 @@ def inner_create_impl(
         # share convolutional layers if observation is pixel
         if isinstance(q_func.q_funcs[0].encoder, PixelEncoder):
+            hidden_size = compute_output_size(
+                [observation_shape],
+                q_func.q_funcs[0].encoder,
+                device=self._device,
+            )
             imitator = DiscreteImitator(
-                q_func.q_funcs[0].encoder, action_size, self._config.beta
+                encoder=q_func.q_funcs[0].encoder,
+                hidden_size=hidden_size,
+                action_size=action_size,
+                beta=self._config.beta,
             )
             imitator.to(self._device)
         else:
diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py
index 5f497279..60467ac7 100644
--- a/d3rlpy/algos/qlearning/torch/bc_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -13,6 +13,7 @@
     Policy,
     ProbablisticRegressor,
     SquashedNormalPolicy,
+    compute_output_size,
 )
 from ....torch_utility import TorchMiniBatch, hard_sync, train_api
 from ..base import QLearningAlgoImplBase
@@ -94,14 +95,27 @@ def __init__(
     @property
     def policy(self) -> Policy:
         policy: Policy
+        hidden_size = compute_output_size(
+            [self._observation_shape, (self._action_size,)],
+            self._imitator.encoder,
+            device=self._device,
+        )
         if self._policy_type == "deterministic":
+            hidden_size = compute_output_size(
+                [self._observation_shape, (self._action_size,)],
+                self._imitator.encoder,
+                device=self._device,
+            )
             policy = DeterministicPolicy(
-                self._imitator.encoder, self._action_size
+                encoder=self._imitator.encoder,
+                hidden_size=hidden_size,
+                action_size=self._action_size,
             )
         elif self._policy_type == "stochastic":
             return SquashedNormalPolicy(
-                self._imitator.encoder,
-                self._action_size,
+                encoder=self._imitator.encoder,
+                hidden_size=hidden_size,
+                action_size=self._action_size,
                 min_logstd=-4.0,
                 max_logstd=15.0,
                 use_std_parameter=False,
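The pattern above repeats throughout this patch: the encoder's output width is measured once with compute_output_size (defined later in this diff, in d3rlpy/models/torch/encoders.py) and passed explicitly as hidden_size, replacing the removed Encoder.get_feature_size() hook. A minimal sketch against the patched API; the nn.Sequential encoder below is only a stand-in, not part of the patch:

    from torch import nn
    from d3rlpy.models.torch import compute_output_size

    # stand-in encoder: any module mapping (batch, 8) -> (batch, N)
    encoder = nn.Sequential(nn.Linear(8, 256), nn.ReLU())
    hidden_size = compute_output_size([(8,)], encoder, device="cpu")  # measured: 256
    head = nn.Linear(hidden_size, 4)  # output heads are now sized explicitly
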
diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py
index af674d80..f4486752 100644
--- a/d3rlpy/models/builders.py
+++ b/d3rlpy/models/builders.py
@@ -24,6 +24,7 @@
     SimplePositionEncoding,
     SquashedNormalPolicy,
     ValueFunction,
+    compute_output_size,
 )
 from .utility import create_activation
@@ -56,6 +57,7 @@ def create_discrete_q_function(
 ) -> EnsembleDiscreteQFunction:
     if q_func_factory.share_encoder:
         encoder = encoder_factory.create(observation_shape)
+        hidden_size = compute_output_size([observation_shape], encoder, device)
         # normalize gradient scale by ensemble size
         for p in cast(nn.Module, encoder).parameters():
             p.register_hook(lambda grad: grad / n_ensembles)
@@ -64,7 +66,12 @@
     for _ in range(n_ensembles):
         if not q_func_factory.share_encoder:
             encoder = encoder_factory.create(observation_shape)
-        q_funcs.append(q_func_factory.create_discrete(encoder, action_size))
+        hidden_size = compute_output_size(
+            [observation_shape], encoder, device
+        )
+        q_funcs.append(
+            q_func_factory.create_discrete(encoder, hidden_size, action_size)
+        )
     q_func = EnsembleDiscreteQFunction(q_funcs)
     q_func.to(device)
     return q_func
@@ -82,6 +89,9 @@ def create_continuous_q_function(
         encoder = encoder_factory.create_with_action(
             observation_shape, action_size
         )
+        hidden_size = compute_output_size(
+            [observation_shape, (action_size,)], encoder, device
+        )
         # normalize gradient scale by ensemble size
         for p in cast(nn.Module, encoder).parameters():
             p.register_hook(lambda grad: grad / n_ensembles)
@@ -92,7 +102,12 @@
             encoder = encoder_factory.create_with_action(
                 observation_shape, action_size
             )
-        q_funcs.append(q_func_factory.create_continuous(encoder))
+        hidden_size = compute_output_size(
+            [observation_shape, (action_size,)], encoder, device
+        )
+        q_funcs.append(
+            q_func_factory.create_continuous(encoder, hidden_size, action_size)
+        )
     q_func = EnsembleContinuousQFunction(q_funcs)
     q_func.to(device)
     return q_func
@@ -105,7 +120,12 @@ def create_deterministic_policy(
     device: str,
 ) -> DeterministicPolicy:
     encoder = encoder_factory.create(observation_shape)
-    policy = DeterministicPolicy(encoder, action_size)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
+    policy = DeterministicPolicy(
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
+    )
     policy.to(device)
     return policy
@@ -118,7 +138,15 @@ def create_deterministic_residual_policy(
     device: str,
 ) -> DeterministicResidualPolicy:
     encoder = encoder_factory.create_with_action(observation_shape, action_size)
-    policy = DeterministicResidualPolicy(encoder, scale)
+    hidden_size = compute_output_size(
+        [observation_shape, (action_size,)], encoder, device
+    )
+    policy = DeterministicResidualPolicy(
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
+        scale=scale,
+    )
     policy.to(device)
     return policy
@@ -133,9 +161,11 @@ def create_squashed_normal_policy(
     use_std_parameter: bool = False,
 ) -> SquashedNormalPolicy:
     encoder = encoder_factory.create(observation_shape)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
     policy = SquashedNormalPolicy(
-        encoder,
-        action_size,
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
         min_logstd=min_logstd,
         max_logstd=max_logstd,
         use_std_parameter=use_std_parameter,
@@ -154,9 +184,11 @@ def create_non_squashed_normal_policy(
     use_std_parameter: bool = False,
 ) -> NonSquashedNormalPolicy:
     encoder = encoder_factory.create(observation_shape)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
     policy = NonSquashedNormalPolicy(
-        encoder,
-        action_size,
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
         min_logstd=min_logstd,
         max_logstd=max_logstd,
         use_std_parameter=use_std_parameter,
@@ -172,7 +204,10 @@ def create_categorical_policy(
     device: str,
 ) -> CategoricalPolicy:
     encoder = encoder_factory.create(observation_shape)
-    policy = CategoricalPolicy(encoder, action_size)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
+    policy = CategoricalPolicy(
+        encoder=encoder, hidden_size=hidden_size, action_size=action_size
+    )
     policy.to(device)
     return policy
@@ -193,10 +228,16 @@ def create_conditional_vae(
     decoder_encoder = encoder_factory.create_with_action(
         observation_shape, latent_size
     )
+    hidden_size = compute_output_size(
+        [observation_shape, (action_size,)], encoder_encoder, device
+    )
     policy = ConditionalVAE(
-        encoder_encoder,
-        decoder_encoder,
-        beta,
+        encoder_encoder=encoder_encoder,
+        decoder_encoder=decoder_encoder,
+        hidden_size=hidden_size,
+        latent_size=latent_size,
+        action_size=action_size,
+        beta=beta,
         min_logstd=min_logstd,
         max_logstd=max_logstd,
     )
@@ -212,7 +253,13 @@ def create_discrete_imitator(
     device: str,
 ) -> DiscreteImitator:
     encoder = encoder_factory.create(observation_shape)
-    imitator = DiscreteImitator(encoder, action_size, beta)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
+    imitator = DiscreteImitator(
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
+        beta=beta,
+    )
     imitator.to(device)
     return imitator
@@ -224,7 +271,12 @@ def create_deterministic_regressor(
     device: str,
 ) -> DeterministicRegressor:
     encoder = encoder_factory.create(observation_shape)
-    regressor = DeterministicRegressor(encoder, action_size)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
+    regressor = DeterministicRegressor(
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
+    )
     regressor.to(device)
     return regressor
@@ -238,8 +290,13 @@ def create_probablistic_regressor(
     max_logstd: float = 2.0,
 ) -> ProbablisticRegressor:
     encoder = encoder_factory.create(observation_shape)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
     regressor = ProbablisticRegressor(
-        encoder, action_size, min_logstd=min_logstd, max_logstd=max_logstd
+        encoder=encoder,
+        hidden_size=hidden_size,
+        action_size=action_size,
+        min_logstd=min_logstd,
+        max_logstd=max_logstd,
     )
     regressor.to(device)
     return regressor
@@ -249,7 +306,8 @@ def create_value_function(
     observation_shape: Shape, encoder_factory: EncoderFactory, device: str
 ) -> ValueFunction:
     encoder = encoder_factory.create(observation_shape)
-    value_func = ValueFunction(encoder)
+    hidden_size = compute_output_size([observation_shape], encoder, device)
+    value_func = ValueFunction(encoder, hidden_size)
     value_func.to(device)
     return value_func
@@ -279,7 +337,7 @@ def create_continuous_decision_transformer(
     device: str,
 ) -> ContinuousDecisionTransformer:
     encoder = encoder_factory.create(observation_shape)
-    hidden_size = encoder.get_feature_size()
+    hidden_size = compute_output_size([observation_shape], encoder, device)

     if position_encoding_type == "simple":
         position_encoding = SimplePositionEncoding(hidden_size, max_timestep)
@@ -294,6 +352,7 @@
     transformer = ContinuousDecisionTransformer(
         encoder=encoder,
+        feature_size=hidden_size,
         position_encoding=position_encoding,
         action_size=action_size,
         num_heads=num_heads,
@@ -324,7 +383,7 @@ def create_discrete_decision_transformer(
     device: str,
 ) -> DiscreteDecisionTransformer:
     encoder = encoder_factory.create(observation_shape)
-    hidden_size = encoder.get_feature_size()
+    hidden_size = compute_output_size([observation_shape], encoder, device)

     if position_encoding_type == "simple":
         position_encoding = SimplePositionEncoding(hidden_size, max_timestep)
@@ -339,6 +398,7 @@
     transformer = DiscreteDecisionTransformer(
         encoder=encoder,
+        feature_size=hidden_size,
         position_encoding=position_encoding,
         action_size=action_size,
         num_heads=num_heads,
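Callers that construct these torch modules directly, outside of the builders, now have to supply hidden_size themselves. A hedged sketch assuming the patched constructors above, using ValueFunction as the smallest example; shapes are illustrative:

    import torch
    from d3rlpy.models.encoders import VectorEncoderFactory
    from d3rlpy.models.torch import ValueFunction, compute_output_size

    observation_shape = (8,)
    encoder = VectorEncoderFactory().create(observation_shape)
    hidden_size = compute_output_size([observation_shape], encoder, "cpu")  # 256 by default
    v_func = ValueFunction(encoder, hidden_size)
    value = v_func(torch.rand(2, 8))  # shape (2, 1)
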
        standard architecture with ``[256, 256]`` is used.
         activation (str): activation function name.
         use_batch_norm (bool): Flag to insert batch normalization layers.
-        use_dense (bool): Flag to use DenseNet architecture.
         dropout_rate (float): Dropout probability.
         exclude_last_activation (bool): Flag to exclude activation function
             at the last layer.
     """

     activation: str = "relu"
     use_batch_norm: bool = False
     dropout_rate: Optional[float] = None
-    use_dense: bool = False
     exclude_last_activation: bool = False

     def create(self, observation_shape: Shape) -> VectorEncoder:
@@ -150,7 +147,6 @@ def create(self, observation_shape: Shape) -> VectorEncoder:
             hidden_units=self.hidden_units,
             use_batch_norm=self.use_batch_norm,
             dropout_rate=self.dropout_rate,
-            use_dense=self.use_dense,
             activation=create_activation(self.activation),
             exclude_last_activation=self.exclude_last_activation,
         )
@@ -168,7 +164,6 @@ def create_with_action(
             hidden_units=self.hidden_units,
             use_batch_norm=self.use_batch_norm,
             dropout_rate=self.dropout_rate,
-            use_dense=self.use_dense,
             discrete_action=discrete_action,
             activation=create_activation(self.activation),
             exclude_last_activation=self.exclude_last_activation,
         )
@@ -239,72 +234,6 @@ def get_type() -> str:
         return "default"


-@dataclass()
-class DenseEncoderFactory(EncoderFactory):
-    """DenseNet encoder factory class.
-
-    This is an alias for DenseNet architecture proposed in D2RL.
-    This class does exactly same as follows.
-
-    .. code-block:: python
-
-        from d3rlpy.encoders import VectorEncoderFactory
-
-        factory = VectorEncoderFactory(hidden_units=[256, 256, 256, 256],
-                                       use_dense=True)
-
-    For now, this only supports vector observations.
-
-    References:
-        * `Sinha et al., D2RL: Deep Dense Architectures in Reinforcement
-          Learning. `_
-
-    Args:
-        activation (str): activation function name.
-        use_batch_norm (bool): flag to insert batch normalization layers.
-        dropout_rate (float): dropout probability.
- """ - - activation: str = "relu" - use_batch_norm: bool = False - dropout_rate: Optional[float] = None - - def create(self, observation_shape: Shape) -> VectorEncoder: - if len(observation_shape) == 3: - raise NotImplementedError("pixel observation is not supported.") - factory = VectorEncoderFactory( - hidden_units=[256, 256, 256, 256], - activation=self.activation, - use_dense=True, - use_batch_norm=self.use_batch_norm, - dropout_rate=self.dropout_rate, - ) - return factory.create(observation_shape) - - def create_with_action( - self, - observation_shape: Shape, - action_size: int, - discrete_action: bool = False, - ) -> VectorEncoderWithAction: - if len(observation_shape) == 3: - raise NotImplementedError("pixel observation is not supported.") - factory = VectorEncoderFactory( - hidden_units=[256, 256, 256, 256], - activation=self.activation, - use_dense=True, - use_batch_norm=self.use_batch_norm, - dropout_rate=self.dropout_rate, - ) - return factory.create_with_action( - observation_shape, action_size, discrete_action - ) - - @staticmethod - def get_type() -> str: - return "dense" - - register_encoder_factory, make_encoder_field = generate_config_registration( EncoderFactory, lambda: DefaultEncoderFactory() ) @@ -313,4 +242,3 @@ def get_type() -> str: register_encoder_factory(VectorEncoderFactory) register_encoder_factory(PixelEncoderFactory) register_encoder_factory(DefaultEncoderFactory) -register_encoder_factory(DenseEncoderFactory) diff --git a/d3rlpy/models/q_functions.py b/d3rlpy/models/q_functions.py index 3b904781..1a31628b 100644 --- a/d3rlpy/models/q_functions.py +++ b/d3rlpy/models/q_functions.py @@ -2,12 +2,10 @@ from ..serializable_config import DynamicConfig, generate_config_registration from .torch import ( - ContinuousFQFQFunction, ContinuousIQNQFunction, ContinuousMeanQFunction, ContinuousQFunction, ContinuousQRQFunction, - DiscreteFQFQFunction, DiscreteIQNQFunction, DiscreteMeanQFunction, DiscreteQFunction, @@ -30,14 +28,15 @@ class QFunctionFactory(DynamicConfig): share_encoder: bool = False def create_discrete( - self, encoder: Encoder, action_size: int + self, encoder: Encoder, hidden_size: int, action_size: int ) -> DiscreteQFunction: """Returns PyTorch's Q function module. Args: - encoder: an encoder module that processes the observation to + encoder: Encoder that processes the observation to obtain feature representations. - action_size: dimension of discrete action-space. + hidden_size: Dimension of encoder output. + action_size: Dimension of discrete action-space. Returns: discrete Q function object. @@ -45,13 +44,15 @@ def create_discrete( raise NotImplementedError def create_continuous( - self, encoder: EncoderWithAction + self, encoder: EncoderWithAction, hidden_size: int, action_size: int ) -> ContinuousQFunction: """Returns PyTorch's Q function module. Args: - encoder: an encoder module that processes the observation and + encoder: Encoder module that processes the observation and action to obtain feature representations. + hidden_size: Dimension of encoder output. + action_size: Dimension of continuous actions. Returns: continuous Q function object. 
@@ -87,15 +88,18 @@ class MeanQFunctionFactory(QFunctionFactory):
     def create_discrete(
         self,
         encoder: Encoder,
+        hidden_size: int,
         action_size: int,
     ) -> DiscreteMeanQFunction:
-        return DiscreteMeanQFunction(encoder, action_size)
+        return DiscreteMeanQFunction(encoder, hidden_size, action_size)

     def create_continuous(
         self,
         encoder: EncoderWithAction,
+        hidden_size: int,
+        action_size: int,
     ) -> ContinuousMeanQFunction:
-        return ContinuousMeanQFunction(encoder)
+        return ContinuousMeanQFunction(encoder, hidden_size, action_size)

     @staticmethod
     def get_type() -> str:
@@ -118,15 +122,27 @@ class QRQFunctionFactory(QFunctionFactory):
     n_quantiles: int = 32

     def create_discrete(
-        self, encoder: Encoder, action_size: int
+        self, encoder: Encoder, hidden_size: int, action_size: int
     ) -> DiscreteQRQFunction:
-        return DiscreteQRQFunction(encoder, action_size, self.n_quantiles)
+        return DiscreteQRQFunction(
+            encoder=encoder,
+            hidden_size=hidden_size,
+            action_size=action_size,
+            n_quantiles=self.n_quantiles,
+        )

     def create_continuous(
         self,
         encoder: EncoderWithAction,
+        hidden_size: int,
+        action_size: int,
     ) -> ContinuousQRQFunction:
-        return ContinuousQRQFunction(encoder, self.n_quantiles)
+        return ContinuousQRQFunction(
+            encoder=encoder,
+            hidden_size=hidden_size,
+            action_size=action_size,
+            n_quantiles=self.n_quantiles,
+        )

     @staticmethod
     def get_type() -> str:
@@ -155,10 +171,12 @@ class IQNQFunctionFactory(QFunctionFactory):
     def create_discrete(
         self,
         encoder: Encoder,
+        hidden_size: int,
         action_size: int,
     ) -> DiscreteIQNQFunction:
         return DiscreteIQNQFunction(
             encoder=encoder,
+            hidden_size=hidden_size,
             action_size=action_size,
             n_quantiles=self.n_quantiles,
             n_greedy_quantiles=self.n_greedy_quantiles,
@@ -166,11 +184,12 @@ def create_discrete(
         )

     def create_continuous(
-        self,
-        encoder: EncoderWithAction,
+        self, encoder: EncoderWithAction, hidden_size: int, action_size: int
     ) -> ContinuousIQNQFunction:
         return ContinuousIQNQFunction(
             encoder=encoder,
+            hidden_size=hidden_size,
+            action_size=action_size,
             n_quantiles=self.n_quantiles,
             n_greedy_quantiles=self.n_greedy_quantiles,
             embed_size=self.embed_size,
@@ -181,55 +200,6 @@ def get_type() -> str:
         return "iqn"


-@dataclasses.dataclass()
-class FQFQFunctionFactory(QFunctionFactory):
-    """Fully parameterized Quantile Function Q function factory.
-
-    References:
-        * `Yang et al., Fully parameterized quantile function for
-          distributional reinforcement learning.
-          `_
-
-    Args:
-        share_encoder (bool): flag to share encoder over multiple Q functions.
-        n_quantiles: the number of quantiles.
-        embed_size: the embedding size.
-        entropy_coeff: the coefficiency of entropy penalty term.
- """ - - n_quantiles: int = 32 - embed_size: int = 64 - entropy_coeff: float = 0.0 - - def create_discrete( - self, - encoder: Encoder, - action_size: int, - ) -> DiscreteFQFQFunction: - return DiscreteFQFQFunction( - encoder=encoder, - action_size=action_size, - n_quantiles=self.n_quantiles, - embed_size=self.embed_size, - entropy_coeff=self.entropy_coeff, - ) - - def create_continuous( - self, - encoder: EncoderWithAction, - ) -> ContinuousFQFQFunction: - return ContinuousFQFQFunction( - encoder=encoder, - n_quantiles=self.n_quantiles, - embed_size=self.embed_size, - entropy_coeff=self.entropy_coeff, - ) - - @staticmethod - def get_type() -> str: - return "fqf" - - register_q_func_factory, make_q_func_field = generate_config_registration( QFunctionFactory, lambda: MeanQFunctionFactory() ) @@ -238,4 +208,3 @@ def get_type() -> str: register_q_func_factory(MeanQFunctionFactory) register_q_func_factory(QRQFunctionFactory) register_q_func_factory(IQNQFunctionFactory) -register_q_func_factory(FQFQFunctionFactory) diff --git a/d3rlpy/models/torch/encoders.py b/d3rlpy/models/torch/encoders.py index 0ceb5fe3..612c9769 100644 --- a/d3rlpy/models/torch/encoders.py +++ b/d3rlpy/models/torch/encoders.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from torch import nn +from ...dataset import Shape from ...itertools import last_flag __all__ = [ @@ -14,64 +15,31 @@ "PixelEncoderWithAction", "VectorEncoder", "VectorEncoderWithAction", + "compute_output_size", ] -class Encoder(metaclass=ABCMeta): +class Encoder(nn.Module, metaclass=ABCMeta): # type: ignore @abstractmethod def forward(self, x: torch.Tensor) -> torch.Tensor: pass - @abstractmethod - def get_feature_size(self) -> int: - pass - - @property - @abstractmethod - def observation_shape(self) -> Sequence[int]: - pass - - @abstractmethod def __call__(self, x: torch.Tensor) -> torch.Tensor: - pass + return super().__call__(x) -class EncoderWithAction(metaclass=ABCMeta): +class EncoderWithAction(nn.Module, metaclass=ABCMeta): # type: ignore @abstractmethod def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: pass - @abstractmethod - def get_feature_size(self) -> int: - pass - - @property - @abstractmethod - def action_size(self) -> int: - pass - - @property - @abstractmethod - def observation_shape(self) -> Sequence[int]: - pass - - @abstractmethod def __call__(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: - pass + return super().__call__(x, action) -class _PixelEncoder(nn.Module): # type: ignore - _observation_shape: Sequence[int] - _feature_size: int - _use_batch_norm: bool - _dropout_rate: Optional[float] - _activation: nn.Module - _convs: nn.ModuleList - _conv_bns: nn.ModuleList - _fc: nn.Linear - _fc_bn: nn.BatchNorm1d - _dropouts: nn.ModuleList - _exclude_last_activation: bool +class PixelEncoder(Encoder): + _cnn_layers: nn.Module + _last_layers: nn.Module def __init__( self, @@ -91,86 +59,53 @@ def __init__( if feature_size is None: feature_size = 512 - self._observation_shape = observation_shape - self._use_batch_norm = use_batch_norm - self._dropout_rate = dropout_rate - self._activation = activation - self._feature_size = feature_size - self._exclude_last_activation = exclude_last_activation - # convolutional layers + cnn_layers = [] in_channels = [observation_shape[0]] + [f[0] for f in filters[:-1]] - self._convs = nn.ModuleList() - self._conv_bns = nn.ModuleList() - self._dropouts = nn.ModuleList() for in_channel, f in zip(in_channels, filters): out_channel, kernel_size, stride = f conv = 
nn.Conv2d(
                 in_channel, out_channel, kernel_size=kernel_size, stride=stride
             )
-            self._convs.append(conv)
+            cnn_layers.append(conv)
+            cnn_layers.append(activation)

             # use batch normalization layer
             if use_batch_norm:
-                self._conv_bns.append(nn.BatchNorm2d(out_channel))
+                cnn_layers.append(nn.BatchNorm2d(out_channel))

             # use dropout layer
             if dropout_rate is not None:
-                self._dropouts.append(nn.Dropout2d(dropout_rate))
+                cnn_layers.append(nn.Dropout2d(dropout_rate))
+        self._cnn_layers = nn.Sequential(*cnn_layers)
+
+        # compute output shape of CNN layers
+        x = torch.rand((1,) + tuple(observation_shape))
+        with torch.no_grad():
+            cnn_output_size = self._cnn_layers(x).view(1, -1).shape[1]

         # last dense layer
-        self._fc = nn.Linear(self._get_linear_input_size(), feature_size)
+        layers: List[nn.Module] = []
+        layers.append(nn.Linear(cnn_output_size, feature_size))
+        if not exclude_last_activation:
+            layers.append(activation)
         if use_batch_norm:
-            self._fc_bn = nn.BatchNorm1d(feature_size)
+            layers.append(nn.BatchNorm1d(feature_size))
         if dropout_rate is not None:
-            self._dropouts.append(nn.Dropout(dropout_rate))
-
-    def _get_linear_input_size(self) -> int:
-        x = torch.rand((1,) + tuple(self._observation_shape))
-        with torch.no_grad():
-            return self._conv_encode(x).view(1, -1).shape[1]  # type: ignore
-
-    def _get_last_conv_shape(self) -> Sequence[int]:
-        x = torch.rand((1,) + tuple(self._observation_shape))
-        with torch.no_grad():
-            return self._conv_encode(x).shape  # type: ignore
-
-    def _conv_encode(self, x: torch.Tensor) -> torch.Tensor:
-        h = x
-        for i, conv in enumerate(self._convs):
-            h = self._activation(conv(h))
-            if self._use_batch_norm:
-                h = self._conv_bns[i](h)
-            if self._dropout_rate is not None:
-                h = self._dropouts[i](h)
-        return h
-
-    def get_feature_size(self) -> int:
-        return self._feature_size
+            layers.append(nn.Dropout(dropout_rate))

-    @property
-    def observation_shape(self) -> Sequence[int]:
-        return self._observation_shape
+        self._last_layers = nn.Sequential(*layers)

-
-class PixelEncoder(_PixelEncoder, Encoder):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        h = self._conv_encode(x)
-
-        h = self._fc(h.view(h.shape[0], -1))
-        if not self._exclude_last_activation:
-            h = self._activation(h)
-        if self._use_batch_norm:
-            h = self._fc_bn(h)
-        if self._dropout_rate is not None:
-            h = self._dropouts[-1](h)
+        h = self._cnn_layers(x)
+        return self._last_layers(h.reshape(x.shape[0], -1))

-        return h
-

-class PixelEncoderWithAction(_PixelEncoder, EncoderWithAction):
-    _action_size: int
+class PixelEncoderWithAction(EncoderWithAction):
+    _cnn_layers: nn.Module
+    _last_layers: nn.Module
     _discrete_action: bool
+    _action_size: int

     def __init__(
         self,
@@ -179,29 +114,59 @@ def __init__(
         filters: Optional[List[List[int]]] = None,
         feature_size: int = 512,
         use_batch_norm: bool = False,
-        dropout_rate: Optional[float] = None,
+        dropout_rate: Optional[float] = False,
         discrete_action: bool = False,
         activation: nn.Module = nn.ReLU(),
         exclude_last_activation: bool = False,
     ):
-        self._action_size = action_size
+        super().__init__()
         self._discrete_action = discrete_action
-        super().__init__(
-            observation_shape=observation_shape,
-            filters=filters,
-            feature_size=feature_size,
-            use_batch_norm=use_batch_norm,
-            dropout_rate=dropout_rate,
-            activation=activation,
-            exclude_last_activation=exclude_last_activation,
-        )
+        self._action_size = action_size

-    def _get_linear_input_size(self) -> int:
-        size = super()._get_linear_input_size()
-        return size + self._action_size
+        # default architecture is based on 
Nature DQN paper.
+        if filters is None:
+            filters = [[32, 8, 4], [64, 4, 2], [64, 3, 1]]
+        if feature_size is None:
+            feature_size = 512
+
+        # convolutional layers
+        cnn_layers = []
+        in_channels = [observation_shape[0]] + [f[0] for f in filters[:-1]]
+        for in_channel, f in zip(in_channels, filters):
+            out_channel, kernel_size, stride = f
+            conv = nn.Conv2d(
+                in_channel, out_channel, kernel_size=kernel_size, stride=stride
+            )
+            cnn_layers.append(conv)
+            cnn_layers.append(activation)
+
+            # use batch normalization layer
+            if use_batch_norm:
+                cnn_layers.append(nn.BatchNorm2d(out_channel))
+
+            # use dropout layer
+            if dropout_rate is not None:
+                cnn_layers.append(nn.Dropout2d(dropout_rate))
+        self._cnn_layers = nn.Sequential(*cnn_layers)
+
+        # compute output shape of CNN layers
+        x = torch.rand((1,) + tuple(observation_shape))
+        with torch.no_grad():
+            cnn_output_size = self._cnn_layers(x).view(1, -1).shape[1]
+
+        # last dense layer
+        layers: List[nn.Module] = []
+        layers.append(nn.Linear(cnn_output_size + action_size, feature_size))
+        if not exclude_last_activation:
+            layers.append(activation)
+        if use_batch_norm:
+            layers.append(nn.BatchNorm1d(feature_size))
+        if dropout_rate is not None:
+            layers.append(nn.Dropout(dropout_rate))
+        self._last_layers = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
-        h = self._conv_encode(x)
+        h = self._cnn_layers(x)

         if self._discrete_action:
             action = F.one_hot(
@@ -209,32 +174,13 @@ def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
             ).float()

         # cocat feature and action
-        h = self._fc(torch.cat([h.view(h.shape[0], -1), action], dim=1))
-        if not self._exclude_last_activation:
-            h = self._activation(h)
-        if self._use_batch_norm:
-            h = self._fc_bn(h)
-        if self._dropout_rate is not None:
-            h = self._dropouts[-1](h)
-
-        return h
-
-    @property
-    def action_size(self) -> int:
-        return self._action_size
-
-
-class _VectorEncoder(nn.Module):  # type: ignore
-    _observation_shape: Sequence[int]
-    _use_batch_norm: bool
-    _dropout_rate: Optional[float]
-    _use_dense: bool
-    _activation: nn.Module
-    _feature_size: int
-    _fcs: nn.ModuleList
-    _bns: nn.ModuleList
-    _dropouts: nn.ModuleList
-    _exclude_last_activation: bool
+        h = torch.cat([h.reshape(h.shape[0], -1), action], dim=1)
+
+        return self._last_layers(h)
+
+
+class VectorEncoder(Encoder):
+    _layers: nn.Module

     def __init__(
         self,
@@ -242,69 +188,34 @@ def __init__(
         hidden_units: Optional[Sequence[int]] = None,
         use_batch_norm: bool = False,
         dropout_rate: Optional[float] = None,
-        use_dense: bool = False,
         activation: nn.Module = nn.ReLU(),
         exclude_last_activation: bool = False,
     ):
         super().__init__()
-        self._observation_shape = observation_shape

         if hidden_units is None:
             hidden_units = [256, 256]

-        self._use_batch_norm = use_batch_norm
-        self._dropout_rate = dropout_rate
-        self._feature_size = hidden_units[-1]
-        self._activation = activation
-        self._use_dense = use_dense
-        self._exclude_last_activation = exclude_last_activation
-
+        layers = []
         in_units = [observation_shape[0]] + list(hidden_units[:-1])
-        self._fcs = nn.ModuleList()
-        self._bns = nn.ModuleList()
-        self._dropouts = nn.ModuleList()
-        for i, (in_unit, out_unit) in enumerate(zip(in_units, hidden_units)):
-            if use_dense and i > 0:
-                in_unit += observation_shape[0]
-            self._fcs.append(nn.Linear(in_unit, out_unit))
+        for is_last, (in_unit, out_unit) in last_flag(
+            zip(in_units, hidden_units)
+        ):
+            layers.append(nn.Linear(in_unit, out_unit))
+            if not is_last or not exclude_last_activation:
+                layers.append(activation)
             if use_batch_norm:
-                self._bns.append(nn.BatchNorm1d(out_unit))
+                layers.append(nn.BatchNorm1d(out_unit))
             if dropout_rate is not None:
-                self._dropouts.append(nn.Dropout(dropout_rate))
-
-    def _fc_encode(self, x: torch.Tensor) -> torch.Tensor:
-        h = x
-        for is_last, (i, fc) in last_flag(enumerate(self._fcs)):
-            if self._use_dense and i > 0:
-                h = torch.cat([h, x], dim=1)
-            h = fc(h)
-            if not is_last or not self._exclude_last_activation:
-                h = self._activation(h)
-            if self._use_batch_norm:
-                h = self._bns[i](h)
-            if self._dropout_rate is not None:
-                h = self._dropouts[i](h)
-        return h
-
-    def get_feature_size(self) -> int:
-        return self._feature_size
-
-    @property
-    def observation_shape(self) -> Sequence[int]:
-        return self._observation_shape
-
-
-class VectorEncoder(_VectorEncoder, Encoder):
+                layers.append(nn.Dropout(dropout_rate))
+        self._layers = nn.Sequential(*layers)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        h = self._fc_encode(x)
-        if self._use_batch_norm:
-            h = self._bns[-1](h)
-        if self._dropout_rate is not None:
-            h = self._dropouts[-1](h)
-        return h
+        return self._layers(x)


-class VectorEncoderWithAction(_VectorEncoder, EncoderWithAction):
+class VectorEncoderWithAction(EncoderWithAction):
+    _layers: nn.Module
     _action_size: int
     _discrete_action: bool

@@ -315,38 +226,51 @@ def __init__(
         hidden_units: Optional[Sequence[int]] = None,
         use_batch_norm: bool = False,
         dropout_rate: Optional[float] = None,
-        use_dense: bool = False,
         discrete_action: bool = False,
         activation: nn.Module = nn.ReLU(),
         exclude_last_activation: bool = False,
     ):
+        super().__init__()
         self._action_size = action_size
         self._discrete_action = discrete_action
-        concat_shape = (observation_shape[0] + action_size,)
-        super().__init__(
-            observation_shape=concat_shape,
-            hidden_units=hidden_units,
-            use_batch_norm=use_batch_norm,
-            use_dense=use_dense,
-            dropout_rate=dropout_rate,
-            activation=activation,
-            exclude_last_activation=exclude_last_activation,
+
+        if hidden_units is None:
+            hidden_units = [256, 256]
+
+        layers = []
+        in_units = [observation_shape[0] + action_size] + list(
+            hidden_units[:-1]
         )
-        self._observation_shape = observation_shape
+        for is_last, (in_unit, out_unit) in last_flag(
+            zip(in_units, hidden_units)
+        ):
+            layers.append(nn.Linear(in_unit, out_unit))
+            if not is_last or not exclude_last_activation:
+                layers.append(activation)
+            if use_batch_norm:
+                layers.append(nn.BatchNorm1d(out_unit))
+            if dropout_rate is not None:
+                layers.append(nn.Dropout(dropout_rate))
+        self._layers = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
         if self._discrete_action:
             action = F.one_hot(
-                action.view(-1).long(), num_classes=self.action_size
+                action.view(-1).long(), num_classes=self._action_size
             ).float()
         x = torch.cat([x, action], dim=1)
-        h = self._fc_encode(x)
-        if self._use_batch_norm:
-            h = self._bns[-1](h)
-        if self._dropout_rate is not None:
-            h = self._dropouts[-1](h)
-        return h
-
-    @property
-    def action_size(self) -> int:
-        return self._action_size
+        return self._layers(x)
+
+
+def compute_output_size(
+    input_shapes: Sequence[Shape], encoder: nn.Module, device: str
+) -> int:
+    with torch.no_grad():
+        inputs = []
+        for shape in input_shapes:
+            if isinstance(shape[0], (list, tuple)):
+                inputs.append([torch.rand(1, *s, device=device) for s in shape])
+            else:
+                inputs.append(torch.rand(1, *shape, device=device))
+        y = encoder(*inputs)
+        return int(y.shape[1])
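With Encoder and EncoderWithAction now plain nn.Module subclasses, and get_feature_size/observation_shape gone, a custom encoder only has to implement forward; its feature size is measured by compute_output_size rather than declared. A minimal sketch against the patched base class (the CustomEncoder name and sizes are illustrative):

    import torch
    from torch import nn
    from d3rlpy.models.torch.encoders import Encoder, compute_output_size

    class CustomEncoder(Encoder):  # Encoder is an nn.Module now
        def __init__(self, observation_size: int, feature_size: int):
            super().__init__()
            self._fc = nn.Linear(observation_size, feature_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self._fc(x))

    encoder = CustomEncoder(8, 64)
    feature_size = compute_output_size([(8,)], encoder, "cpu")  # 64, measured not declared
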
diff --git a/d3rlpy/models/torch/imitators.py b/d3rlpy/models/torch/imitators.py
index 225cc4d9..374775c2 100644
--- a/d3rlpy/models/torch/imitators.py
+++ b/d3rlpy/models/torch/imitators.py
@@ -35,6 +35,9 @@ def __init__(
         self,
         encoder_encoder: EncoderWithAction,
         decoder_encoder: EncoderWithAction,
+        hidden_size: int,
+        latent_size: int,
+        action_size: int,
         beta: float,
         min_logstd: float = -20.0,
         max_logstd: float = 2.0,
@@ -46,20 +49,14 @@
         self._min_logstd = min_logstd
         self._max_logstd = max_logstd
-        self._action_size = encoder_encoder.action_size
-        self._latent_size = decoder_encoder.action_size
+        self._action_size = action_size
+        self._latent_size = latent_size

         # encoder
-        self._mu = nn.Linear(
-            encoder_encoder.get_feature_size(), self._latent_size
-        )
-        self._logstd = nn.Linear(
-            encoder_encoder.get_feature_size(), self._latent_size
-        )
+        self._mu = nn.Linear(hidden_size, self._latent_size)
+        self._logstd = nn.Linear(hidden_size, self._latent_size)

         # decoder
-        self._fc = nn.Linear(
-            decoder_encoder.get_feature_size(), self._action_size
-        )
+        self._fc = nn.Linear(hidden_size, self._action_size)

     def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
         dist = self.encode(x, action)
@@ -151,11 +148,13 @@ class DiscreteImitator(Imitator):
     _beta: float
     _fc: nn.Linear

-    def __init__(self, encoder: Encoder, action_size: int, beta: float):
+    def __init__(
+        self, encoder: Encoder, hidden_size: int, action_size: int, beta: float
+    ):
         super().__init__()
         self._encoder = encoder
         self._beta = beta
-        self._fc = nn.Linear(encoder.get_feature_size(), action_size)
+        self._fc = nn.Linear(hidden_size, action_size)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.compute_log_probs_with_logits(x)[0]
@@ -184,10 +183,10 @@ class DeterministicRegressor(Imitator):
     _encoder: Encoder
     _fc: nn.Linear

-    def __init__(self, encoder: Encoder, action_size: int):
+    def __init__(self, encoder: Encoder, hidden_size: int, action_size: int):
         super().__init__()
         self._encoder = encoder
-        self._fc = nn.Linear(encoder.get_feature_size(), action_size)
+        self._fc = nn.Linear(hidden_size, action_size)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h = self._encoder(x)
@@ -214,6 +213,7 @@ class ProbablisticRegressor(Imitator):
     def __init__(
         self,
         encoder: Encoder,
+        hidden_size: int,
         action_size: int,
         min_logstd: float,
         max_logstd: float,
@@ -222,8 +222,8 @@ def __init__(
         self._min_logstd = min_logstd
         self._max_logstd = max_logstd
         self._encoder = encoder
-        self._mu = nn.Linear(encoder.get_feature_size(), action_size)
-        self._logstd = nn.Linear(encoder.get_feature_size(), action_size)
+        self._mu = nn.Linear(hidden_size, action_size)
+        self._logstd = nn.Linear(hidden_size, action_size)

     def dist(self, x: torch.Tensor) -> Normal:
         h = self._encoder(x)
diff --git a/d3rlpy/models/torch/policies.py b/d3rlpy/models/torch/policies.py
index a2611bd1..3cfa10a2 100644
--- a/d3rlpy/models/torch/policies.py
+++ b/d3rlpy/models/torch/policies.py
@@ -59,10 +59,10 @@ class DeterministicPolicy(Policy):
     _encoder: Encoder
     _fc: nn.Linear

-    def __init__(self, encoder: Encoder, action_size: int):
+    def __init__(self, encoder: Encoder, hidden_size: int, action_size: int):
         super().__init__()
         self._encoder = encoder
-        self._fc = nn.Linear(encoder.get_feature_size(), action_size)
+        self._fc = nn.Linear(hidden_size, action_size)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h = self._encoder(x)
@@ -94,11 +94,17 @@ class DeterministicResidualPolicy(Policy):
     _scale: float
     _fc: nn.Linear

-    def __init__(self, encoder: EncoderWithAction, scale: 
float): + def __init__( + self, + encoder: EncoderWithAction, + hidden_size: int, + action_size: int, + scale: float, + ): super().__init__() self._scale = scale self._encoder = encoder - self._fc = nn.Linear(encoder.get_feature_size(), encoder.action_size) + self._fc = nn.Linear(hidden_size, action_size) def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: h = self._encoder(x, action) @@ -145,6 +151,7 @@ class NormalPolicy(Policy): def __init__( self, encoder: Encoder, + hidden_size: int, action_size: int, min_logstd: float, max_logstd: float, @@ -158,12 +165,12 @@ def __init__( self._max_logstd = max_logstd self._use_std_parameter = use_std_parameter self._squash_distribution = squash_distribution - self._mu = nn.Linear(encoder.get_feature_size(), action_size) + self._mu = nn.Linear(hidden_size, action_size) if use_std_parameter: initial_logstd = torch.zeros(1, action_size, dtype=torch.float32) self._logstd = nn.Parameter(initial_logstd) else: - self._logstd = nn.Linear(encoder.get_feature_size(), action_size) + self._logstd = nn.Linear(hidden_size, action_size) def _compute_logstd(self, h: torch.Tensor) -> torch.Tensor: if self._use_std_parameter: @@ -264,6 +271,7 @@ class SquashedNormalPolicy(NormalPolicy): def __init__( self, encoder: Encoder, + hidden_size: int, action_size: int, min_logstd: float, max_logstd: float, @@ -271,6 +279,7 @@ def __init__( ): super().__init__( encoder=encoder, + hidden_size=hidden_size, action_size=action_size, min_logstd=min_logstd, max_logstd=max_logstd, @@ -283,6 +292,7 @@ class NonSquashedNormalPolicy(NormalPolicy): def __init__( self, encoder: Encoder, + hidden_size: int, action_size: int, min_logstd: float, max_logstd: float, @@ -290,6 +300,7 @@ def __init__( ): super().__init__( encoder=encoder, + hidden_size=hidden_size, action_size=action_size, min_logstd=min_logstd, max_logstd=max_logstd, @@ -302,10 +313,10 @@ class CategoricalPolicy(Policy): _encoder: Encoder _fc: nn.Linear - def __init__(self, encoder: Encoder, action_size: int): + def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): super().__init__() self._encoder = encoder - self._fc = nn.Linear(encoder.get_feature_size(), action_size) + self._fc = nn.Linear(hidden_size, action_size) def dist(self, x: torch.Tensor) -> Categorical: h = self._encoder(x) diff --git a/d3rlpy/models/torch/q_functions/__init__.py b/d3rlpy/models/torch/q_functions/__init__.py index da366c48..8f43f950 100644 --- a/d3rlpy/models/torch/q_functions/__init__.py +++ b/d3rlpy/models/torch/q_functions/__init__.py @@ -1,6 +1,5 @@ from .base import * from .ensemble_q_function import * -from .fqf_q_function import * from .iqn_q_function import * from .mean_q_function import * from .qr_q_function import * diff --git a/d3rlpy/models/torch/q_functions/fqf_q_function.py b/d3rlpy/models/torch/q_functions/fqf_q_function.py deleted file mode 100644 index 17611a74..00000000 --- a/d3rlpy/models/torch/q_functions/fqf_q_function.py +++ /dev/null @@ -1,278 +0,0 @@ -from typing import Optional, Tuple, cast - -import torch -from torch import nn - -from ..encoders import Encoder, EncoderWithAction -from .base import ContinuousQFunction, DiscreteQFunction -from .iqn_q_function import compute_iqn_feature -from .utility import ( - compute_quantile_loss, - compute_reduce, - pick_quantile_value_by_action, -) - -__all__ = ["DiscreteFQFQFunction", "ContinuousFQFQFunction"] - - -def _make_taus( - h: torch.Tensor, - proposal: nn.Linear, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - 
proposals = proposal(h.detach()) - - # tau_i+1 - log_probs = torch.log_softmax(proposals, dim=1) - probs = log_probs.exp() - taus = torch.cumsum(probs, dim=1) - # tau_i - pads = torch.zeros(h.shape[0], 1, device=h.device) - taus_minus = torch.cat([pads, taus[:, :-1]], dim=1) - # tau^ - taus_prime = (taus + taus_minus) / 2 - # entropy for penalty - entropies = -(log_probs * probs).sum(dim=1) - - return taus, taus_minus, taus_prime, entropies - - -class DiscreteFQFQFunction(DiscreteQFunction, nn.Module): # type: ignore - _action_size: int - _entropy_coeff: float - _encoder: Encoder - _fc: nn.Linear - _n_quantiles: int - _embed_size: int - _embed: nn.Linear - _proposal: nn.Linear - - def __init__( - self, - encoder: Encoder, - action_size: int, - n_quantiles: int, - embed_size: int, - entropy_coeff: float = 0.0, - ): - super().__init__() - self._encoder = encoder - self._action_size = action_size - self._fc = nn.Linear(encoder.get_feature_size(), self._action_size) - self._entropy_coeff = entropy_coeff - self._n_quantiles = n_quantiles - self._embed_size = embed_size - self._embed = nn.Linear(embed_size, encoder.get_feature_size()) - self._proposal = nn.Linear(encoder.get_feature_size(), n_quantiles) - - def _compute_quantiles( - self, h: torch.Tensor, taus: torch.Tensor - ) -> torch.Tensor: - # element-wise product on feature and phi (batch, quantile, feature) - prod = compute_iqn_feature(h, taus, self._embed, self._embed_size) - # (batch, quantile, feature) -> (batch, action, quantile) - return cast(torch.Tensor, self._fc(prod)).transpose(1, 2) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self._encoder(x) - taus, taus_minus, taus_prime, _ = _make_taus(h, self._proposal) - quantiles = self._compute_quantiles(h, taus_prime.detach()) - weight = (taus - taus_minus).view(-1, 1, self._n_quantiles).detach() - return (weight * quantiles).sum(dim=2) - - def compute_error( - self, - observations: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - target: torch.Tensor, - terminals: torch.Tensor, - gamma: float = 0.99, - reduction: str = "mean", - ) -> torch.Tensor: - assert target.shape == (observations.shape[0], self._n_quantiles) - - # compute quantiles - h = self._encoder(observations) - taus, _, taus_prime, entropies = _make_taus(h, self._proposal) - all_quantiles = self._compute_quantiles(h, taus_prime.detach()) - quantiles = pick_quantile_value_by_action(all_quantiles, actions) - - quantile_loss = compute_quantile_loss( - quantiles=quantiles, - rewards=rewards, - target=target, - terminals=terminals, - taus=taus_prime.detach(), - gamma=gamma, - ) - - # compute proposal network loss - # original paper explicitly separates the optimization process - # but, it's combined here - proposal_loss = self._compute_proposal_loss( - h, actions, taus, taus_prime - ) - proposal_params = list(self._proposal.parameters()) - proposal_grads = torch.autograd.grad( - outputs=proposal_loss.mean(), - inputs=proposal_params, - retain_graph=True, - ) - # directly apply gradients - for param, grad in zip(list(proposal_params), proposal_grads): - param.grad = 1e-4 * grad - - loss = quantile_loss - self._entropy_coeff * entropies - - return compute_reduce(loss, reduction) - - def _compute_proposal_loss( - self, - h: torch.Tensor, - actions: torch.Tensor, - taus: torch.Tensor, - taus_prime: torch.Tensor, - ) -> torch.Tensor: - q_taus = self._compute_quantiles(h.detach(), taus) - q_taus_prime = self._compute_quantiles(h.detach(), taus_prime) - batch_steps = torch.arange(h.shape[0]) - # 
(batch, n_quantiles - 1) - q_taus = q_taus[batch_steps, actions.view(-1)][:, :-1] - # (batch, n_quantiles) - q_taus_prime = q_taus_prime[batch_steps, actions.view(-1)] - - # compute gradients - proposal_grad = 2 * q_taus - q_taus_prime[:, :-1] - q_taus_prime[:, 1:] - - return proposal_grad.sum(dim=1) - - def compute_target( - self, x: torch.Tensor, action: Optional[torch.Tensor] = None - ) -> torch.Tensor: - h = self._encoder(x) - _, _, taus_prime, _ = _make_taus(h, self._proposal) - quantiles = self._compute_quantiles(h, taus_prime.detach()) - if action is None: - return quantiles - return pick_quantile_value_by_action(quantiles, action) - - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> Encoder: - return self._encoder - - -class ContinuousFQFQFunction(ContinuousQFunction, nn.Module): # type: ignore - _action_size: int - _entropy_coeff: float - _encoder: EncoderWithAction - _fc: nn.Linear - _n_quantiles: int - _embed_size: int - _embed: nn.Linear - _proposal: nn.Linear - - def __init__( - self, - encoder: EncoderWithAction, - n_quantiles: int, - embed_size: int, - entropy_coeff: float = 0.0, - ): - super().__init__() - self._encoder = encoder - self._action_size = encoder.action_size - self._fc = nn.Linear(encoder.get_feature_size(), 1) - self._entropy_coeff = entropy_coeff - self._n_quantiles = n_quantiles - self._embed_size = embed_size - self._embed = nn.Linear(embed_size, encoder.get_feature_size()) - self._proposal = nn.Linear(encoder.get_feature_size(), n_quantiles) - - def _compute_quantiles( - self, h: torch.Tensor, taus: torch.Tensor - ) -> torch.Tensor: - # element-wise product on feature and phi (batch, quantile, feature) - prod = compute_iqn_feature(h, taus, self._embed, self._embed_size) - # (batch, quantile, feature) -> (batch, quantile) - return cast(torch.Tensor, self._fc(prod)).view(h.shape[0], -1) - - def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: - h = self._encoder(x, action) - taus, taus_minus, taus_prime, _ = _make_taus(h, self._proposal) - quantiles = self._compute_quantiles(h, taus_prime.detach()) - weight = (taus - taus_minus).detach() - return (weight * quantiles).sum(dim=1, keepdim=True) - - def compute_error( - self, - observations: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - target: torch.Tensor, - terminals: torch.Tensor, - gamma: float = 0.99, - reduction: str = "mean", - ) -> torch.Tensor: - assert target.shape == (observations.shape[0], self._n_quantiles) - - h = self._encoder(observations, actions) - taus, _, taus_prime, entropies = _make_taus(h, self._proposal) - quantiles = self._compute_quantiles(h, taus_prime.detach()) - - quantile_loss = compute_quantile_loss( - quantiles=quantiles, - rewards=rewards, - target=target, - terminals=terminals, - taus=taus_prime.detach(), - gamma=gamma, - ) - - # compute proposal network loss - # original paper explicitly separates the optimization process - # but, it's combined here - proposal_loss = self._compute_proposal_loss(h, taus, taus_prime) - proposal_params = list(self._proposal.parameters()) - proposal_grads = torch.autograd.grad( - outputs=proposal_loss.mean(), - inputs=proposal_params, - retain_graph=True, - ) - # directly apply gradients - for param, grad in zip(list(proposal_params), proposal_grads): - param.grad = 1e-4 * grad - - loss = quantile_loss - self._entropy_coeff * entropies - - return compute_reduce(loss, reduction) - - def _compute_proposal_loss( - self, h: torch.Tensor, taus: 
torch.Tensor, taus_prime: torch.Tensor - ) -> torch.Tensor: - # (batch, n_quantiles - 1) - q_taus = self._compute_quantiles(h.detach(), taus)[:, :-1] - # (batch, n_quantiles) - q_taus_prime = self._compute_quantiles(h.detach(), taus_prime) - - # compute gradients - proposal_grad = 2 * q_taus - q_taus_prime[:, :-1] - q_taus_prime[:, 1:] - return proposal_grad.sum(dim=1) - - def compute_target( - self, x: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: - h = self._encoder(x, action) - _, _, taus_prime, _ = _make_taus(h, self._proposal) - return self._compute_quantiles(h, taus_prime.detach()) - - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> EncoderWithAction: - return self._encoder diff --git a/d3rlpy/models/torch/q_functions/iqn_q_function.py b/d3rlpy/models/torch/q_functions/iqn_q_function.py index be2700e5..9e64eb37 100644 --- a/d3rlpy/models/torch/q_functions/iqn_q_function.py +++ b/d3rlpy/models/torch/q_functions/iqn_q_function.py @@ -61,6 +61,7 @@ class DiscreteIQNQFunction(DiscreteQFunction, nn.Module): # type: ignore def __init__( self, encoder: Encoder, + hidden_size: int, action_size: int, n_quantiles: int, n_greedy_quantiles: int, @@ -69,11 +70,11 @@ def __init__( super().__init__() self._encoder = encoder self._action_size = action_size - self._fc = nn.Linear(encoder.get_feature_size(), self._action_size) + self._fc = nn.Linear(hidden_size, self._action_size) self._n_quantiles = n_quantiles self._n_greedy_quantiles = n_greedy_quantiles self._embed_size = embed_size - self._embed = nn.Linear(embed_size, encoder.get_feature_size()) + self._embed = nn.Linear(embed_size, hidden_size) def _make_taus(self, h: torch.Tensor) -> torch.Tensor: if self.training: @@ -156,18 +157,20 @@ class ContinuousIQNQFunction(ContinuousQFunction, nn.Module): # type: ignore def __init__( self, encoder: EncoderWithAction, + hidden_size: int, + action_size: int, n_quantiles: int, n_greedy_quantiles: int, embed_size: int, ): super().__init__() self._encoder = encoder - self._action_size = encoder.action_size - self._fc = nn.Linear(encoder.get_feature_size(), 1) + self._action_size = action_size + self._fc = nn.Linear(hidden_size, 1) self._n_quantiles = n_quantiles self._n_greedy_quantiles = n_greedy_quantiles self._embed_size = embed_size - self._embed = nn.Linear(embed_size, encoder.get_feature_size()) + self._embed = nn.Linear(embed_size, hidden_size) def _make_taus(self, h: torch.Tensor) -> torch.Tensor: if self.training: diff --git a/d3rlpy/models/torch/q_functions/mean_q_function.py b/d3rlpy/models/torch/q_functions/mean_q_function.py index 60fbaa16..d3d5e7aa 100644 --- a/d3rlpy/models/torch/q_functions/mean_q_function.py +++ b/d3rlpy/models/torch/q_functions/mean_q_function.py @@ -16,11 +16,11 @@ class DiscreteMeanQFunction(DiscreteQFunction, nn.Module): # type: ignore _encoder: Encoder _fc: nn.Linear - def __init__(self, encoder: Encoder, action_size: int): + def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): super().__init__() self._action_size = action_size self._encoder = encoder - self._fc = nn.Linear(encoder.get_feature_size(), action_size) + self._fc = nn.Linear(hidden_size, action_size) def forward(self, x: torch.Tensor) -> torch.Tensor: return cast(torch.Tensor, self._fc(self._encoder(x))) @@ -64,11 +64,13 @@ class ContinuousMeanQFunction(ContinuousQFunction, nn.Module): # type: ignore _action_size: int _fc: nn.Linear - def __init__(self, encoder: EncoderWithAction): + def __init__( + self, encoder: 
EncoderWithAction, hidden_size: int, action_size: int + ): super().__init__() self._encoder = encoder - self._action_size = encoder.action_size - self._fc = nn.Linear(encoder.get_feature_size(), 1) + self._action_size = action_size + self._fc = nn.Linear(hidden_size, 1) def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: return cast(torch.Tensor, self._fc(self._encoder(x, action))) diff --git a/d3rlpy/models/torch/q_functions/qr_q_function.py b/d3rlpy/models/torch/q_functions/qr_q_function.py index 7c5337e8..448de90c 100644 --- a/d3rlpy/models/torch/q_functions/qr_q_function.py +++ b/d3rlpy/models/torch/q_functions/qr_q_function.py @@ -27,14 +27,18 @@ class DiscreteQRQFunction(DiscreteQFunction, nn.Module): # type: ignore _n_quantiles: int _fc: nn.Linear - def __init__(self, encoder: Encoder, action_size: int, n_quantiles: int): + def __init__( + self, + encoder: Encoder, + hidden_size: int, + action_size: int, + n_quantiles: int, + ): super().__init__() self._encoder = encoder self._action_size = action_size self._n_quantiles = n_quantiles - self._fc = nn.Linear( - encoder.get_feature_size(), action_size * n_quantiles - ) + self._fc = nn.Linear(hidden_size, action_size * n_quantiles) def _compute_quantiles( self, h: torch.Tensor, taus: torch.Tensor @@ -102,12 +106,18 @@ class ContinuousQRQFunction(ContinuousQFunction, nn.Module): # type: ignore _n_quantiles: int _fc: nn.Linear - def __init__(self, encoder: EncoderWithAction, n_quantiles: int): + def __init__( + self, + encoder: EncoderWithAction, + hidden_size: int, + action_size: int, + n_quantiles: int, + ): super().__init__() self._encoder = encoder - self._action_size = encoder.action_size + self._action_size = action_size self._n_quantiles = n_quantiles - self._fc = nn.Linear(encoder.get_feature_size(), n_quantiles) + self._fc = nn.Linear(hidden_size, n_quantiles) def _compute_quantiles( self, h: torch.Tensor, taus: torch.Tensor diff --git a/d3rlpy/models/torch/transformers.py b/d3rlpy/models/torch/transformers.py index 32a36404..81eb3824 100644 --- a/d3rlpy/models/torch/transformers.py +++ b/d3rlpy/models/torch/transformers.py @@ -248,6 +248,7 @@ class ContinuousDecisionTransformer(nn.Module): # type: ignore def __init__( self, encoder: Encoder, + feature_size: int, position_encoding: PositionEncoding, action_size: int, num_heads: int, @@ -261,11 +262,11 @@ def __init__( super().__init__() self._encoder = encoder self._position_encoding = position_encoding - self._action_embed = nn.Linear(action_size, encoder.get_feature_size()) - self._rtg_embed = nn.Linear(1, encoder.get_feature_size()) - self._embed_ln = nn.LayerNorm(encoder.get_feature_size()) + self._action_embed = nn.Linear(action_size, feature_size) + self._rtg_embed = nn.Linear(1, feature_size) + self._embed_ln = nn.LayerNorm(feature_size) self._gpt2 = GPT2( - hidden_size=encoder.get_feature_size(), + hidden_size=feature_size, num_heads=num_heads, context_size=3 * context_size, num_layers=num_layers, @@ -274,7 +275,7 @@ def __init__( embed_dropout=embed_dropout, activation=activation, ) - self._output = nn.Linear(encoder.get_feature_size(), action_size) + self._output = nn.Linear(feature_size, action_size) def forward( self, @@ -322,6 +323,7 @@ class DiscreteDecisionTransformer(nn.Module): # type: ignore def __init__( self, encoder: Encoder, + feature_size: int, position_encoding: PositionEncoding, action_size: int, num_heads: int, @@ -335,12 +337,10 @@ def __init__( super().__init__() self._encoder = encoder self._position_encoding = position_encoding 
- self._action_embed = nn.Embedding( - action_size, encoder.get_feature_size() - ) - self._rtg_embed = nn.Linear(1, encoder.get_feature_size()) + self._action_embed = nn.Embedding(action_size, feature_size) + self._rtg_embed = nn.Linear(1, feature_size) self._gpt2 = GPT2( - hidden_size=encoder.get_feature_size(), + hidden_size=feature_size, num_heads=num_heads, context_size=3 * context_size, num_layers=num_layers, @@ -349,9 +349,7 @@ def __init__( embed_dropout=embed_dropout, activation=activation, ) - self._output = nn.Linear( - encoder.get_feature_size(), action_size, bias=False - ) + self._output = nn.Linear(feature_size, action_size, bias=False) def forward( self, diff --git a/d3rlpy/models/torch/v_functions.py b/d3rlpy/models/torch/v_functions.py index 58e57eaf..6c0d0cef 100644 --- a/d3rlpy/models/torch/v_functions.py +++ b/d3rlpy/models/torch/v_functions.py @@ -13,10 +13,10 @@ class ValueFunction(nn.Module): # type: ignore _encoder: Encoder _fc: nn.Linear - def __init__(self, encoder: Encoder): + def __init__(self, encoder: Encoder, hidden_size: int): super().__init__() self._encoder = encoder - self._fc = nn.Linear(encoder.get_feature_size(), 1) + self._fc = nn.Linear(hidden_size, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: h = self._encoder(x) diff --git a/docs/references/network_architectures.rst b/docs/references/network_architectures.rst index faa74a17..728f519f 100644 --- a/docs/references/network_architectures.rst +++ b/docs/references/network_architectures.rst @@ -47,10 +47,6 @@ You can also build your own encoder factory. h = torch.relu(self.fc2(h)) return h - # THIS IS IMPORTANT! - def get_feature_size(self): - return self.feature_size - # your own encoder factory class CustomEncoderFactory(EncoderFactory): TYPE = 'custom' # this is necessary @@ -87,9 +83,6 @@ controls. h = torch.relu(self.fc2(h)) return h - def get_feature_size(self): - return self.feature_size - class CustomEncoderFactory(EncoderFactory): TYPE = 'custom' # this is necessary @@ -133,4 +126,3 @@ your encoder configuration, you need to register your encoder factory. d3rlpy.models.DefaultEncoderFactory d3rlpy.models.PixelEncoderFactory d3rlpy.models.VectorEncoderFactory - d3rlpy.models.DenseEncoderFactory diff --git a/docs/tutorials/customize_neural_network.rst b/docs/tutorials/customize_neural_network.rst index 0110d270..020e38b0 100644 --- a/docs/tutorials/customize_neural_network.rst +++ b/docs/tutorials/customize_neural_network.rst @@ -10,8 +10,6 @@ Prepare PyTorch Model --------------------- If you're familiar with PyTorch, this step should be easy for you. -Please note that your model must have ``get_feature_size`` method to tell the -feature size to the final layer. .. code-block:: python @@ -31,9 +29,6 @@ feature size to the final layer. h = torch.relu(self.fc2(h)) return h - # THIS IS IMPORTANT! - def get_feature_size(self): - return self.feature_size Setup EncoderFactory -------------------- @@ -90,9 +85,6 @@ you need to prepare an action-conditioned model. h = torch.relu(self.fc2(h)) return h - def get_feature_size(self): - return self.feature_size - Finally, you can update your ``CustomEncoderFactory`` as follows. .. 
code-block:: python diff --git a/tests/models/test_encoders.py b/tests/models/test_encoders.py index 0635739c..e931b56e 100644 --- a/tests/models/test_encoders.py +++ b/tests/models/test_encoders.py @@ -5,7 +5,6 @@ from d3rlpy.models.encoders import ( DefaultEncoderFactory, - DenseEncoderFactory, PixelEncoderFactory, VectorEncoderFactory, ) @@ -105,32 +104,3 @@ def test_default_encoder_factory( # check serization and deserialization DefaultEncoderFactory.deserialize(factory.serialize()) - - -@pytest.mark.parametrize("observation_shape", [(100,)]) -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("discrete_action", [False, True]) -def test_dense_encoder_factory( - observation_shape: Sequence[int], - action_size: int, - discrete_action: bool, -) -> None: - factory = DenseEncoderFactory() - - # test state encoder - encoder = factory.create(observation_shape) - assert isinstance(encoder, VectorEncoder) - assert encoder._use_dense - - # test state-action encoder - encoder = factory.create_with_action( - observation_shape, action_size, discrete_action - ) - assert isinstance(encoder, VectorEncoderWithAction) - assert encoder._discrete_action == discrete_action - assert encoder._use_dense - - assert factory.get_type() == "dense" - - # check serization and deserialization - DenseEncoderFactory.deserialize(factory.serialize()) diff --git a/tests/models/test_q_functions.py b/tests/models/test_q_functions.py index a5c6a339..3f1098c4 100644 --- a/tests/models/test_q_functions.py +++ b/tests/models/test_q_functions.py @@ -4,20 +4,18 @@ from d3rlpy.models.encoders import VectorEncoderFactory from d3rlpy.models.q_functions import ( - FQFQFunctionFactory, IQNQFunctionFactory, MeanQFunctionFactory, QRQFunctionFactory, ) from d3rlpy.models.torch import ( - ContinuousFQFQFunction, ContinuousIQNQFunction, ContinuousMeanQFunction, ContinuousQRQFunction, - DiscreteFQFQFunction, DiscreteIQNQFunction, DiscreteMeanQFunction, DiscreteQRQFunction, + compute_output_size, ) from d3rlpy.models.torch.encoders import Encoder, EncoderWithAction @@ -45,11 +43,17 @@ def test_mean_q_function_factory( encoder_with_action = _create_encoder_with_action( observation_shape, action_size ) - q_func = factory.create_continuous(encoder_with_action) + hidden_size = compute_output_size( + [observation_shape, (action_size,)], encoder_with_action, "cpu:0" + ) + q_func = factory.create_continuous( + encoder_with_action, hidden_size, action_size + ) assert isinstance(q_func, ContinuousMeanQFunction) encoder = _create_encoder(observation_shape) - discrete_q_func = factory.create_discrete(encoder, action_size) + hidden_size = compute_output_size([observation_shape], encoder, "cpu:0") + discrete_q_func = factory.create_discrete(encoder, hidden_size, action_size) assert isinstance(discrete_q_func, DiscreteMeanQFunction) # check serization and deserialization @@ -67,11 +71,17 @@ def test_qr_q_function_factory( encoder_with_action = _create_encoder_with_action( observation_shape, action_size ) - q_func = factory.create_continuous(encoder_with_action) + hidden_size = compute_output_size( + [observation_shape, (action_size,)], encoder_with_action, "cpu:0" + ) + q_func = factory.create_continuous( + encoder_with_action, hidden_size, action_size + ) assert isinstance(q_func, ContinuousQRQFunction) encoder = _create_encoder(observation_shape) - discrete_q_func = factory.create_discrete(encoder, action_size) + hidden_size = compute_output_size([observation_shape], encoder, "cpu:0") + discrete_q_func = 
factory.create_discrete(encoder, hidden_size, action_size) assert isinstance(discrete_q_func, DiscreteQRQFunction) # check serization and deserialization @@ -89,34 +99,18 @@ def test_iqn_q_function_factory( encoder_with_action = _create_encoder_with_action( observation_shape, action_size ) - q_func = factory.create_continuous(encoder_with_action) + hidden_size = compute_output_size( + [observation_shape, (action_size,)], encoder_with_action, "cpu:0" + ) + q_func = factory.create_continuous( + encoder_with_action, hidden_size, action_size + ) assert isinstance(q_func, ContinuousIQNQFunction) encoder = _create_encoder(observation_shape) - discrete_q_func = factory.create_discrete(encoder, action_size) + hidden_size = compute_output_size([observation_shape], encoder, "cpu:0") + discrete_q_func = factory.create_discrete(encoder, hidden_size, action_size) assert isinstance(discrete_q_func, DiscreteIQNQFunction) # check serization and deserialization IQNQFunctionFactory.deserialize(factory.serialize()) - - -@pytest.mark.parametrize("observation_shape", [(100,)]) -@pytest.mark.parametrize("action_size", [2]) -def test_fqf_q_function_factory( - observation_shape: Sequence[int], action_size: int -) -> None: - factory = FQFQFunctionFactory() - assert factory.get_type() == "fqf" - - encoder_with_action = _create_encoder_with_action( - observation_shape, action_size - ) - q_func = factory.create_continuous(encoder_with_action) - assert isinstance(q_func, ContinuousFQFQFunction) - - encoder = _create_encoder(observation_shape) - discrete_q_func = factory.create_discrete(encoder, action_size) - assert isinstance(discrete_q_func, DiscreteFQFQFunction) - - # check serization and deserialization - FQFQFunctionFactory.deserialize(factory.serialize()) diff --git a/tests/models/torch/model_test.py b/tests/models/torch/model_test.py index 18b34f9c..1568fda2 100644 --- a/tests/models/torch/model_test.py +++ b/tests/models/torch/model_test.py @@ -56,14 +56,14 @@ def ref_quantile_huber_loss( return element_wise_loss.sum(axis=2).mean(axis=1) -class DummyEncoder(torch.nn.Module, Encoder): # type: ignore +class DummyEncoder(Encoder): def __init__(self, feature_size: int): super().__init__() self.feature_size = feature_size self._observation_shape = (feature_size,) - def __call__(self, *args: Any) -> torch.Tensor: - return args[0] + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x @property def observation_shape(self) -> Sequence[int]: @@ -73,15 +73,15 @@ def get_feature_size(self) -> int: return self.feature_size -class DummyEncoderWithAction(torch.nn.Module, EncoderWithAction): # type: ignore +class DummyEncoderWithAction(EncoderWithAction): def __init__(self, feature_size: int, action_size: int): super().__init__() self.feature_size = feature_size self._observation_shape = (feature_size,) self._action_size = action_size - def __call__(self, *args: Any) -> torch.Tensor: - return torch.cat([args[0][:, : -args[1].shape[1]], args[1]], dim=1) + def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + return torch.cat([x[:, : -action.shape[1]], action], dim=1) def get_feature_size(self) -> int: return self.feature_size diff --git a/tests/models/torch/q_functions/test_ensemble_q_function.py b/tests/models/torch/q_functions/test_ensemble_q_function.py index 242439a7..8db2c969 100644 --- a/tests/models/torch/q_functions/test_ensemble_q_function.py +++ b/tests/models/torch/q_functions/test_ensemble_q_function.py @@ -4,12 +4,10 @@ import torch from d3rlpy.models.torch import ( - 
ContinuousFQFQFunction, ContinuousIQNQFunction, ContinuousMeanQFunction, ContinuousQFunction, ContinuousQRQFunction, - DiscreteFQFQFunction, DiscreteIQNQFunction, DiscreteMeanQFunction, DiscreteQFunction, @@ -79,7 +77,7 @@ def test_reduce_quantile_ensemble( @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("gamma", [0.99]) @pytest.mark.parametrize("ensemble_size", [5]) -@pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn", "fqf"]) +@pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn"]) @pytest.mark.parametrize("n_quantiles", [200]) @pytest.mark.parametrize("embed_size", [64]) def test_ensemble_discrete_q_function( @@ -96,16 +94,19 @@ def test_ensemble_discrete_q_function( for _ in range(ensemble_size): encoder = DummyEncoder(feature_size) if q_func_factory == "mean": - q_func = DiscreteMeanQFunction(encoder, action_size) + q_func = DiscreteMeanQFunction(encoder, feature_size, action_size) elif q_func_factory == "qr": - q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles) + q_func = DiscreteQRQFunction( + encoder, feature_size, action_size, n_quantiles + ) elif q_func_factory == "iqn": q_func = DiscreteIQNQFunction( - encoder, action_size, n_quantiles, n_quantiles, embed_size - ) - elif q_func_factory == "fqf": - q_func = DiscreteFQFQFunction( - encoder, action_size, n_quantiles, embed_size + encoder, + feature_size, + action_size, + n_quantiles, + n_quantiles, + embed_size, ) q_funcs.append(q_func) q_func = EnsembleDiscreteQFunction(q_funcs) @@ -174,7 +175,7 @@ def test_ensemble_discrete_q_function( @pytest.mark.parametrize("gamma", [0.99]) @pytest.mark.parametrize("ensemble_size", [5]) @pytest.mark.parametrize("n_quantiles", [200]) -@pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn", "fqf"]) +@pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn"]) @pytest.mark.parametrize("embed_size", [64]) def test_ensemble_continuous_q_function( feature_size: int, @@ -190,15 +191,20 @@ def test_ensemble_continuous_q_function( for _ in range(ensemble_size): encoder = DummyEncoderWithAction(feature_size, action_size) if q_func_factory == "mean": - q_func = ContinuousMeanQFunction(encoder) + q_func = ContinuousMeanQFunction(encoder, feature_size, action_size) elif q_func_factory == "qr": - q_func = ContinuousQRQFunction(encoder, n_quantiles) + q_func = ContinuousQRQFunction( + encoder, feature_size, action_size, n_quantiles + ) elif q_func_factory == "iqn": q_func = ContinuousIQNQFunction( - encoder, n_quantiles, n_quantiles, embed_size + encoder, + feature_size, + action_size, + n_quantiles, + n_quantiles, + embed_size, ) - elif q_func_factory == "fqf": - q_func = ContinuousFQFQFunction(encoder, n_quantiles, embed_size) q_funcs.append(q_func) q_func = EnsembleContinuousQFunction(q_funcs) diff --git a/tests/models/torch/q_functions/test_fqf_q_function.py b/tests/models/torch/q_functions/test_fqf_q_function.py deleted file mode 100644 index 7ec5f82b..00000000 --- a/tests/models/torch/q_functions/test_fqf_q_function.py +++ /dev/null @@ -1,99 +0,0 @@ -import pytest -import torch - -from d3rlpy.models.torch import ContinuousFQFQFunction, DiscreteFQFQFunction - -from ..model_test import ( - DummyEncoder, - DummyEncoderWithAction, - check_parameter_updates, -) - - -@pytest.mark.parametrize("feature_size", [100]) -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("n_quantiles", [200]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("embed_size", [64]) -def test_discrete_fqf_q_function( - 
feature_size: int, - action_size: int, - n_quantiles: int, - batch_size: int, - embed_size: int, -) -> None: - encoder = DummyEncoder(feature_size) - q_func = DiscreteFQFQFunction(encoder, action_size, n_quantiles, embed_size) - - # check output shape - x = torch.rand(batch_size, feature_size) - y = q_func(x) - assert y.shape == (batch_size, action_size) - - # check compute_target - action = torch.randint(high=action_size, size=(batch_size,)) - target = q_func.compute_target(x, action) - assert target.shape == (batch_size, n_quantiles) - - # check compute_target - targets = q_func.compute_target(x) - assert targets.shape == (batch_size, action_size, n_quantiles) - - # TODO: check quantile huber loss - obs_t = torch.rand(batch_size, feature_size) - act_t = torch.randint(action_size, size=(batch_size,)) - rew_tp1 = torch.rand(batch_size, 1) - q_tp1 = torch.rand(batch_size, n_quantiles) - ter_tp1 = torch.randint(2, size=(batch_size, 1)) - # check shape - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, reduction="none" - ) - assert loss.shape == (batch_size, 1) - # mean loss - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1) - - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) - - -@pytest.mark.parametrize("feature_size", [100]) -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("n_quantiles", [200]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("embed_size", [64]) -def test_continuous_fqf_q_function( - feature_size: int, - action_size: int, - n_quantiles: int, - batch_size: int, - embed_size: int, -) -> None: - encoder = DummyEncoderWithAction(feature_size, action_size) - q_func = ContinuousFQFQFunction(encoder, n_quantiles, embed_size) - - # check output shape - x = torch.rand(batch_size, feature_size) - action = torch.rand(batch_size, action_size) - y = q_func(x, action) - assert y.shape == (batch_size, 1) - - target = q_func.compute_target(x, action) - assert target.shape == (batch_size, n_quantiles) - - # TODO: check quantile huber loss - obs_t = torch.rand(batch_size, feature_size) - act_t = torch.rand(batch_size, action_size) - rew_tp1 = torch.rand(batch_size, 1) - q_tp1 = torch.rand(batch_size, n_quantiles) - ter_tp1 = torch.randint(2, size=(batch_size, 1)) - # check shape - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, reduction="none" - ) - assert loss.shape == (batch_size, 1) - # mean loss - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1) - - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) diff --git a/tests/models/torch/q_functions/test_iqn_q_function.py b/tests/models/torch/q_functions/test_iqn_q_function.py index ab4f4cab..553809d8 100644 --- a/tests/models/torch/q_functions/test_iqn_q_function.py +++ b/tests/models/torch/q_functions/test_iqn_q_function.py @@ -26,7 +26,12 @@ def test_discrete_iqn_q_function( ) -> None: encoder = DummyEncoder(feature_size) q_func = DiscreteIQNQFunction( - encoder, action_size, n_quantiles, n_greedy_quantiles, embed_size + encoder, + feature_size, + action_size, + n_quantiles, + n_greedy_quantiles, + embed_size, ) # check output shape @@ -84,7 +89,12 @@ def test_continuous_iqn_q_function( ) -> None: encoder = DummyEncoderWithAction(feature_size, action_size) q_func = ContinuousIQNQFunction( - encoder, n_quantiles, n_greedy_quantiles, embed_size + encoder, + feature_size, + action_size, + n_quantiles, + 
n_greedy_quantiles, + embed_size, ) # check output shape diff --git a/tests/models/torch/q_functions/test_mean_q_function.py b/tests/models/torch/q_functions/test_mean_q_function.py index 3583df3f..7c54e9a4 100644 --- a/tests/models/torch/q_functions/test_mean_q_function.py +++ b/tests/models/torch/q_functions/test_mean_q_function.py @@ -27,7 +27,7 @@ def test_discrete_mean_q_function( feature_size: int, action_size: int, batch_size: int, gamma: float ) -> None: encoder = DummyEncoder(feature_size) - q_func = DiscreteMeanQFunction(encoder, action_size) + q_func = DiscreteMeanQFunction(encoder, feature_size, action_size) # check output shape x = torch.rand(batch_size, feature_size) @@ -80,7 +80,7 @@ def test_continuous_mean_q_function( gamma: float, ) -> None: encoder = DummyEncoderWithAction(feature_size, action_size) - q_func = ContinuousMeanQFunction(encoder) + q_func = ContinuousMeanQFunction(encoder, feature_size, action_size) # check output shape x = torch.rand(batch_size, feature_size) diff --git a/tests/models/torch/q_functions/test_qr_q_function.py b/tests/models/torch/q_functions/test_qr_q_function.py index 6d854ae9..586634b5 100644 --- a/tests/models/torch/q_functions/test_qr_q_function.py +++ b/tests/models/torch/q_functions/test_qr_q_function.py @@ -30,7 +30,9 @@ def test_discrete_qr_q_function( gamma: float, ) -> None: encoder = DummyEncoder(feature_size) - q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles) + q_func = DiscreteQRQFunction( + encoder, feature_size, action_size, n_quantiles + ) # check output shape x = torch.rand(batch_size, feature_size) @@ -97,7 +99,9 @@ def test_continuous_qr_q_function( gamma: float, ) -> None: encoder = DummyEncoderWithAction(feature_size, action_size) - q_func = ContinuousQRQFunction(encoder, n_quantiles) + q_func = ContinuousQRQFunction( + encoder, feature_size, action_size, n_quantiles + ) # check output shape x = torch.rand(batch_size, feature_size) diff --git a/tests/models/torch/test_encoders.py b/tests/models/torch/test_encoders.py index ab9f6ee9..a6b12f13 100644 --- a/tests/models/torch/test_encoders.py +++ b/tests/models/torch/test_encoders.py @@ -20,7 +20,7 @@ @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("use_batch_norm", [False, True]) @pytest.mark.parametrize("dropout_rate", [None, 0.2]) -@pytest.mark.parametrize("activation", [torch.relu]) +@pytest.mark.parametrize("activation", [torch.nn.ReLU()]) def test_pixel_encoder( shapes: Tuple[Sequence[int], int], filters: List[List[int]], @@ -30,7 +30,7 @@ def test_pixel_encoder( dropout_rate: Optional[float], activation: torch.nn.Module, ) -> None: - observation_shape, linear_input_size = shapes + observation_shape, _ = shapes encoder = PixelEncoder( observation_shape=observation_shape, @@ -44,7 +44,6 @@ def test_pixel_encoder( y = encoder(x) # check output shape - assert encoder._get_linear_input_size() == linear_input_size assert y.shape == (batch_size, feature_size) # check use of batch norm @@ -67,7 +66,7 @@ def test_pixel_encoder( @pytest.mark.parametrize("use_batch_norm", [False, True]) @pytest.mark.parametrize("dropout_rate", [None, 0.2]) @pytest.mark.parametrize("discrete_action", [False, True]) -@pytest.mark.parametrize("activation", [torch.relu]) +@pytest.mark.parametrize("activation", [torch.nn.ReLU()]) def test_pixel_encoder_with_action( shapes: Tuple[Sequence[int], int], action_size: int, @@ -79,7 +78,7 @@ def test_pixel_encoder_with_action( discrete_action: bool, activation: torch.nn.Module, ) -> None: - observation_shape, 
linear_input_size = shapes + observation_shape, _ = shapes encoder = PixelEncoderWithAction( observation_shape=observation_shape, @@ -99,7 +98,6 @@ def test_pixel_encoder_with_action( y = encoder(x, action) # check output shape - assert encoder._get_linear_input_size() == linear_input_size + action_size assert y.shape == (batch_size, feature_size) # check use of batch norm @@ -119,15 +117,13 @@ def test_pixel_encoder_with_action( @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("use_batch_norm", [False, True]) @pytest.mark.parametrize("dropout_rate", [None, 0.2]) -@pytest.mark.parametrize("use_dense", [False, True]) -@pytest.mark.parametrize("activation", [torch.relu]) +@pytest.mark.parametrize("activation", [torch.nn.ReLU()]) def test_vector_encoder( observation_shape: Sequence[int], hidden_units: Sequence[int], batch_size: int, use_batch_norm: bool, dropout_rate: Optional[float], - use_dense: bool, activation: torch.nn.Module, ) -> None: encoder = VectorEncoder( @@ -135,7 +131,6 @@ def test_vector_encoder( hidden_units=hidden_units, use_batch_norm=use_batch_norm, dropout_rate=dropout_rate, - use_dense=use_dense, activation=activation, ) @@ -143,7 +138,6 @@ def test_vector_encoder( y = encoder(x) # check output shape - assert encoder.get_feature_size() == hidden_units[-1] assert y.shape == (batch_size, hidden_units[-1]) # check use of batch norm @@ -164,9 +158,8 @@ def test_vector_encoder( @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("use_batch_norm", [False, True]) @pytest.mark.parametrize("dropout_rate", [None, 0.2]) -@pytest.mark.parametrize("use_dense", [False, True]) @pytest.mark.parametrize("discrete_action", [False, True]) -@pytest.mark.parametrize("activation", [torch.relu]) +@pytest.mark.parametrize("activation", [torch.nn.ReLU()]) def test_vector_encoder_with_action( observation_shape: Sequence[int], action_size: int, @@ -174,7 +167,6 @@ def test_vector_encoder_with_action( batch_size: int, use_batch_norm: bool, dropout_rate: Optional[float], - use_dense: bool, discrete_action: bool, activation: torch.nn.Module, ) -> None: @@ -184,7 +176,6 @@ def test_vector_encoder_with_action( hidden_units=hidden_units, use_batch_norm=use_batch_norm, dropout_rate=dropout_rate, - use_dense=use_dense, discrete_action=discrete_action, activation=activation, ) @@ -197,7 +188,6 @@ def test_vector_encoder_with_action( y = encoder(x, action) # check output shape - assert encoder.get_feature_size() == hidden_units[-1] assert y.shape == (batch_size, hidden_units[-1]) # check use of batch norm diff --git a/tests/models/torch/test_imitators.py b/tests/models/torch/test_imitators.py index 3e3f678a..44b72b26 100644 --- a/tests/models/torch/test_imitators.py +++ b/tests/models/torch/test_imitators.py @@ -32,7 +32,14 @@ def test_conditional_vae( ) -> None: encoder_encoder = DummyEncoderWithAction(feature_size, action_size) decoder_encoder = DummyEncoderWithAction(feature_size, latent_size) - vae = ConditionalVAE(encoder_encoder, decoder_encoder, beta) + vae = ConditionalVAE( + encoder_encoder=encoder_encoder, + decoder_encoder=decoder_encoder, + hidden_size=feature_size, + latent_size=latent_size, + action_size=action_size, + beta=beta, + ) # check output shape x = torch.rand(batch_size, feature_size) @@ -76,7 +83,12 @@ def test_discrete_imitator( feature_size: int, action_size: int, beta: float, batch_size: int ) -> None: encoder = DummyEncoder(feature_size) - imitator = DiscreteImitator(encoder, action_size, beta) + imitator = DiscreteImitator( + 
encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + beta=beta, + ) # check output shape x = torch.rand(batch_size, feature_size) @@ -101,7 +113,11 @@ def test_deterministic_regressor( feature_size: int, action_size: int, batch_size: int ) -> None: encoder = DummyEncoder(feature_size) - imitator = DeterministicRegressor(encoder, action_size) + imitator = DeterministicRegressor( + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + ) x = torch.rand(batch_size, feature_size) y = imitator(x) @@ -124,7 +140,11 @@ def test_probablistic_regressor( ) -> None: encoder = DummyEncoder(feature_size) imitator = ProbablisticRegressor( - encoder, action_size, min_logstd=-20, max_logstd=2 + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + min_logstd=-20, + max_logstd=2, ) x = torch.rand(batch_size, feature_size) diff --git a/tests/models/torch/test_policies.py b/tests/models/torch/test_policies.py index f8cb82cc..f2afccd2 100644 --- a/tests/models/torch/test_policies.py +++ b/tests/models/torch/test_policies.py @@ -27,7 +27,11 @@ def test_deterministic_policy( feature_size: int, action_size: int, batch_size: int ) -> None: encoder = DummyEncoder(feature_size) - policy = DeterministicPolicy(encoder, action_size) + policy = DeterministicPolicy( + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + ) # check output shape x = torch.rand(batch_size, feature_size) @@ -50,7 +54,12 @@ def test_deterministic_residual_policy( feature_size: int, action_size: int, scale: float, batch_size: int ) -> None: encoder = DummyEncoderWithAction(feature_size, action_size) - policy = DeterministicResidualPolicy(encoder, scale) + policy = DeterministicResidualPolicy( + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + scale=scale, + ) # check output shape x = torch.rand(batch_size, feature_size) @@ -88,7 +97,12 @@ def test_squashed_normal_policy( ) -> None: encoder = DummyEncoder(feature_size) policy = SquashedNormalPolicy( - encoder, action_size, min_logstd, max_logstd, use_std_parameter + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + min_logstd=min_logstd, + max_logstd=max_logstd, + use_std_parameter=use_std_parameter, ) # check output shape @@ -137,7 +151,12 @@ def test_non_squashed_normal_policy( ) -> None: encoder = DummyEncoder(feature_size) policy = NonSquashedNormalPolicy( - encoder, action_size, min_logstd, max_logstd, use_std_parameter + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + min_logstd=min_logstd, + max_logstd=max_logstd, + use_std_parameter=use_std_parameter, ) # check output shape @@ -176,7 +195,11 @@ def test_categorical_policy( feature_size: int, action_size: int, batch_size: int, n: int ) -> None: encoder = DummyEncoder(feature_size) - policy = CategoricalPolicy(encoder, action_size) + policy = CategoricalPolicy( + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + ) # check output shape x = torch.rand(batch_size, feature_size) diff --git a/tests/models/torch/test_transformers.py b/tests/models/torch/test_transformers.py index bc82f88c..e2592be1 100644 --- a/tests/models/torch/test_transformers.py +++ b/tests/models/torch/test_transformers.py @@ -211,6 +211,7 @@ def test_continuous_decision_transformer( model = ContinuousDecisionTransformer( encoder=encoder, + feature_size=hidden_size, position_encoding=SimplePositionEncoding(hidden_size, max_timestep), action_size=action_size, num_heads=num_heads, @@ -257,6 +258,7 
@@ def test_discrete_decision_transformer( model = DiscreteDecisionTransformer( encoder=encoder, + feature_size=hidden_size, position_encoding=SimplePositionEncoding(hidden_size, max_timestep), action_size=action_size, num_heads=num_heads, diff --git a/tests/models/torch/test_v_functions.py b/tests/models/torch/test_v_functions.py index 0ad582ff..04d252e8 100644 --- a/tests/models/torch/test_v_functions.py +++ b/tests/models/torch/test_v_functions.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("batch_size", [32]) def test_value_function(feature_size: int, batch_size: int) -> None: encoder = DummyEncoder(feature_size) - v_func = ValueFunction(encoder) + v_func = ValueFunction(encoder, feature_size) # check output shape x = torch.rand(batch_size, feature_size)
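
Note for reviewers: the constructor changes above replace the removed ``Encoder.get_feature_size()`` with an explicit width argument measured by ``compute_output_size``. The sketch below mirrors the calls used in the updated tests and builders (``compute_output_size``, ``DiscreteMeanQFunction(encoder, hidden_size, action_size)``, ``ValueFunction(encoder, hidden_size)``); it illustrates the intended post-patch wiring and is not itself part of the patch.

.. code-block:: python

    import torch

    from d3rlpy.models.encoders import VectorEncoderFactory
    from d3rlpy.models.torch import (
        DiscreteMeanQFunction,
        ValueFunction,
        compute_output_size,
    )

    observation_shape = (100,)
    action_size = 2

    # build an encoder as before
    encoder = VectorEncoderFactory().create(observation_shape)

    # the output width is now measured explicitly instead of being
    # reported by the removed Encoder.get_feature_size() method
    hidden_size = compute_output_size([observation_shape], encoder, "cpu:0")

    # downstream heads receive the width as a constructor argument
    q_func = DiscreteMeanQFunction(encoder, hidden_size, action_size)
    v_func = ValueFunction(encoder, hidden_size)

    x = torch.rand(32, *observation_shape)
    assert q_func(x).shape == (32, action_size)
    assert v_func(x).shape == (32, 1)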
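
The documentation hunks above drop the ``get_feature_size`` requirement from the custom-encoder tutorial. As a quick illustration of what a custom encoder looks like after this change, a ``forward`` method is now sufficient and the feature width is measured at build time. ``CustomEncoder`` here is the hypothetical class from the tutorial, not library API, and calling ``compute_output_size`` directly on it is an assumption about how the final layer will be sized.

.. code-block:: python

    import torch
    import torch.nn as nn

    from d3rlpy.models.torch import compute_output_size


    class CustomEncoder(nn.Module):
        def __init__(self, observation_shape, feature_size):
            super().__init__()
            self.fc1 = nn.Linear(observation_shape[0], feature_size)
            self.fc2 = nn.Linear(feature_size, feature_size)

        # note: no get_feature_size() method anymore
        def forward(self, x):
            h = torch.relu(self.fc1(x))
            h = torch.relu(self.fc2(h))
            return h


    encoder = CustomEncoder((100,), feature_size=64)
    hidden_size = compute_output_size([(100,)], encoder, "cpu:0")
    assert hidden_size == 64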