From 449bd7829bc1ecbfb70b82ae1bb2042f1569d0e1 Mon Sep 17 00:00:00 2001 From: takuseno Date: Fri, 11 Aug 2023 16:44:23 +0900 Subject: [PATCH] Refactor imitator models --- d3rlpy/algos/qlearning/bc.py | 14 +- d3rlpy/algos/qlearning/bcq.py | 10 +- d3rlpy/algos/qlearning/torch/bc_impl.py | 91 +++++-------- d3rlpy/algos/qlearning/torch/bcq_impl.py | 18 ++- d3rlpy/algos/qlearning/torch/sac_impl.py | 3 +- d3rlpy/models/builders.py | 63 --------- d3rlpy/models/torch/distributions.py | 9 +- d3rlpy/models/torch/imitators.py | 162 +++++------------------ tests/models/test_builders.py | 75 +---------- tests/models/torch/test_imitators.py | 84 ++++++------ 10 files changed, 129 insertions(+), 400 deletions(-) diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index ff711488..b77bda48 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -5,9 +5,9 @@ from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace from ...dataset import Shape from ...models.builders import ( - create_deterministic_regressor, - create_discrete_imitator, - create_probablistic_regressor, + create_categorical_policy, + create_deterministic_policy, + create_normal_policy, ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field @@ -75,14 +75,14 @@ def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: if self._config.policy_type == "deterministic": - imitator = create_deterministic_regressor( + imitator = create_deterministic_policy( observation_shape, action_size, self._config.encoder_factory, device=self._device, ) elif self._config.policy_type == "stochastic": - imitator = create_probablistic_regressor( + imitator = create_normal_policy( observation_shape, action_size, self._config.encoder_factory, @@ -102,7 +102,6 @@ def inner_create_impl( action_size=action_size, imitator=imitator, optim=optim, - policy_type=self._config.policy_type, device=self._device, ) @@ -156,10 +155,9 @@ class DiscreteBC(_BCBase[DiscreteBCConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - imitator = create_discrete_imitator( + imitator = create_categorical_policy( observation_shape, action_size, - self._config.beta, self._config.encoder_factory, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 3fb46292..8a859dfa 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -5,16 +5,16 @@ from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace from ...dataset import Shape from ...models.builders import ( + create_categorical_policy, create_conditional_vae, create_continuous_q_function, create_deterministic_residual_policy, - create_discrete_imitator, create_discrete_q_function, ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...models.torch import DiscreteImitator, PixelEncoder, compute_output_size +from ...models.torch import CategoricalPolicy, PixelEncoder, compute_output_size from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.bcq_impl import BCQImpl, DiscreteBCQImpl @@ -342,18 +342,16 @@ def inner_create_impl( q_func.q_funcs[0].encoder, device=self._device, ) - imitator = DiscreteImitator( + imitator = CategoricalPolicy( encoder=q_func.q_funcs[0].encoder, 
hidden_size=hidden_size, action_size=action_size, - beta=self._config.beta, ) imitator.to(self._device) else: - imitator = create_discrete_imitator( + imitator = create_categorical_policy( observation_shape, action_size, - self._config.beta, self._config.encoder_factory, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 087466ea..b5f77814 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -1,4 +1,4 @@ -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from typing import Union import torch @@ -6,16 +6,15 @@ from ....dataset import Shape from ....models.torch import ( + CategoricalPolicy, DeterministicPolicy, - DeterministicRegressor, - DiscreteImitator, - Imitator, NormalPolicy, Policy, - ProbablisticRegressor, - compute_output_size, + compute_deterministic_imitation_loss, + compute_discrete_imitation_loss, + compute_stochastic_imitation_loss, ) -from ....torch_utility import TorchMiniBatch, hard_sync, train_api +from ....torch_utility import TorchMiniBatch, train_api from ..base import QLearningAlgoImplBase __all__ = ["BCImpl", "DiscreteBCImpl"] @@ -23,14 +22,12 @@ class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta): _learning_rate: float - _imitator: Imitator _optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - imitator: Imitator, optim: Optimizer, device: str, ): @@ -39,7 +36,6 @@ def __init__( action_size=action_size, device=device, ) - self._imitator = imitator self._optim = optim @train_api @@ -53,13 +49,11 @@ def update_imitator(self, batch: TorchMiniBatch) -> float: return float(loss.cpu().detach().numpy()) + @abstractmethod def compute_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor ) -> torch.Tensor: - return self._imitator.compute_error(obs_t, act_t) - - def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._imitator(x) + pass def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) @@ -71,63 +65,42 @@ def inner_predict_value( class BCImpl(BCBaseImpl): - _policy_type: str - _imitator: Union[DeterministicRegressor, ProbablisticRegressor] + _imitator: Union[DeterministicPolicy, NormalPolicy] def __init__( self, observation_shape: Shape, action_size: int, - imitator: Union[DeterministicRegressor, ProbablisticRegressor], + imitator: Union[DeterministicPolicy, NormalPolicy], optim: Optimizer, - policy_type: str, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - imitator=imitator, optim=optim, device=device, ) - self._policy_type = policy_type + self._imitator = imitator - @property - def policy(self) -> Policy: - policy: Policy - hidden_size = compute_output_size( - [self._observation_shape, (self._action_size,)], - self._imitator.encoder, - device=self._device, - ) - if self._policy_type == "deterministic": - hidden_size = compute_output_size( - [self._observation_shape, (self._action_size,)], - self._imitator.encoder, - device=self._device, - ) - policy = DeterministicPolicy( - encoder=self._imitator.encoder, - hidden_size=hidden_size, - action_size=self._action_size, - ) - elif self._policy_type == "stochastic": - return NormalPolicy( - encoder=self._imitator.encoder, - hidden_size=hidden_size, - action_size=self._action_size, - min_logstd=-4.0, - max_logstd=15.0, - use_std_parameter=False, + def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: + return 
self._imitator(x).squashed_mu + + def compute_loss( + self, obs_t: torch.Tensor, act_t: torch.Tensor + ) -> torch.Tensor: + if isinstance(self._imitator, DeterministicPolicy): + return compute_deterministic_imitation_loss( + self._imitator, obs_t, act_t ) else: - raise ValueError(f"invalid policy_type: {self._policy_type}") - policy.to(self._device) - - # copy parameters - hard_sync(policy, self._imitator) + return compute_stochastic_imitation_loss( + self._imitator, obs_t, act_t + ) - return policy + @property + def policy(self) -> Policy: + return self._imitator @property def policy_optim(self) -> Optimizer: @@ -136,13 +109,13 @@ def policy_optim(self) -> Optimizer: class DiscreteBCImpl(BCBaseImpl): _beta: float - _imitator: DiscreteImitator + _imitator: CategoricalPolicy def __init__( self, observation_shape: Shape, action_size: int, - imitator: DiscreteImitator, + imitator: CategoricalPolicy, optim: Optimizer, beta: float, device: str, @@ -150,16 +123,18 @@ def __init__( super().__init__( observation_shape=observation_shape, action_size=action_size, - imitator=imitator, optim=optim, device=device, ) + self._imitator = imitator self._beta = beta def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._imitator(x).argmax(dim=1) + return self._imitator(x).logits.argmax(dim=1) def compute_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor ) -> torch.Tensor: - return self._imitator.compute_error(obs_t, act_t.long()) + return compute_discrete_imitation_loss( + policy=self._imitator, x=obs_t, action=act_t.long(), beta=self._beta + ) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 9eaebca2..76e8469d 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -2,15 +2,17 @@ from typing import cast import torch +import torch.nn.functional as F from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( + CategoricalPolicy, ConditionalVAE, DeterministicResidualPolicy, - DiscreteImitator, EnsembleContinuousQFunction, EnsembleDiscreteQFunction, + compute_discrete_imitation_loss, compute_max_with_n_actions, compute_vae_error, forward_vae_decode, @@ -171,14 +173,14 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: class DiscreteBCQImpl(DoubleDQNImpl): _action_flexibility: float _beta: float - _imitator: DiscreteImitator + _imitator: CategoricalPolicy def __init__( self, observation_shape: Shape, action_size: int, q_func: EnsembleDiscreteQFunction, - imitator: DiscreteImitator, + imitator: CategoricalPolicy, optim: Optimizer, gamma: float, action_flexibility: float, @@ -201,13 +203,17 @@ def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> torch.Tensor: loss = super().compute_loss(batch, q_tpn) - imitator_loss = self._imitator.compute_error( - batch.observations, batch.actions.long() + imitator_loss = compute_discrete_imitation_loss( + policy=self._imitator, + x=batch.observations, + action=batch.actions.long(), + beta=self._beta, ) return loss + imitator_loss def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - log_probs = self._imitator(x) + dist = self._imitator(x) + log_probs = F.log_softmax(dist.logits, dim=1) ratio = log_probs - log_probs.max(dim=1, keepdim=True).values mask = (ratio > math.log(self._action_flexibility)).float() value = self._q_func(x) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index bfa298fb..c0806dca 100644 
--- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -3,6 +3,7 @@ from typing import Tuple import torch +import torch.nn.functional as F from torch.optim import Optimizer from ....dataset import Shape @@ -209,7 +210,7 @@ def update_temp(self, batch: TorchMiniBatch) -> Tuple[float, float]: with torch.no_grad(): dist = self._policy(batch.observations) - log_probs = dist.logits + log_probs = F.log_softmax(dist.logits, dim=1) probs = dist.probs expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True) entropy_target = 0.98 * (-math.log(1 / self.action_size)) diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 4f7ab087..02e82ccd 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -11,16 +11,13 @@ ConditionalVAE, ContinuousDecisionTransformer, DeterministicPolicy, - DeterministicRegressor, DeterministicResidualPolicy, DiscreteDecisionTransformer, - DiscreteImitator, EnsembleContinuousQFunction, EnsembleDiscreteQFunction, GlobalPositionEncoding, NormalPolicy, Parameter, - ProbablisticRegressor, SimplePositionEncoding, VAEDecoder, VAEEncoder, @@ -37,9 +34,6 @@ "create_categorical_policy", "create_normal_policy", "create_conditional_vae", - "create_discrete_imitator", - "create_deterministic_regressor", - "create_probablistic_regressor", "create_value_function", "create_parameter", "create_continuous_decision_transformer", @@ -224,63 +218,6 @@ def create_conditional_vae( return policy -def create_discrete_imitator( - observation_shape: Shape, - action_size: int, - beta: float, - encoder_factory: EncoderFactory, - device: str, -) -> DiscreteImitator: - encoder = encoder_factory.create(observation_shape) - hidden_size = compute_output_size([observation_shape], encoder, device) - imitator = DiscreteImitator( - encoder=encoder, - hidden_size=hidden_size, - action_size=action_size, - beta=beta, - ) - imitator.to(device) - return imitator - - -def create_deterministic_regressor( - observation_shape: Shape, - action_size: int, - encoder_factory: EncoderFactory, - device: str, -) -> DeterministicRegressor: - encoder = encoder_factory.create(observation_shape) - hidden_size = compute_output_size([observation_shape], encoder, device) - regressor = DeterministicRegressor( - encoder=encoder, - hidden_size=hidden_size, - action_size=action_size, - ) - regressor.to(device) - return regressor - - -def create_probablistic_regressor( - observation_shape: Shape, - action_size: int, - encoder_factory: EncoderFactory, - device: str, - min_logstd: float = -20.0, - max_logstd: float = 2.0, -) -> ProbablisticRegressor: - encoder = encoder_factory.create(observation_shape) - hidden_size = compute_output_size([observation_shape], encoder, device) - regressor = ProbablisticRegressor( - encoder=encoder, - hidden_size=hidden_size, - action_size=action_size, - min_logstd=min_logstd, - max_logstd=max_logstd, - ) - regressor.to(device) - return regressor - - def create_value_function( observation_shape: Shape, encoder_factory: EncoderFactory, device: str ) -> ValueFunction: diff --git a/d3rlpy/models/torch/distributions.py b/d3rlpy/models/torch/distributions.py index 4278124c..e930cb42 100644 --- a/d3rlpy/models/torch/distributions.py +++ b/d3rlpy/models/torch/distributions.py @@ -1,6 +1,6 @@ import math from abc import ABCMeta, abstractmethod -from typing import Optional, Tuple +from typing import Tuple import torch import torch.nn.functional as F @@ -49,14 +49,13 @@ class GaussianDistribution(Distribution): def __init__( self, - loc: 
torch.Tensor, + loc: torch.Tensor, # squashed mean std: torch.Tensor, - raw_loc: Optional[torch.Tensor] = None, + raw_loc: torch.Tensor, ): self._mean = loc self._std = std - if raw_loc is not None: - self._raw_loc = raw_loc + self._raw_loc = raw_loc self._dist = Normal(self._mean, self._std) def sample(self) -> torch.Tensor: diff --git a/d3rlpy/models/torch/imitators.py b/d3rlpy/models/torch/imitators.py index 116e1e2e..ade1add5 100644 --- a/d3rlpy/models/torch/imitators.py +++ b/d3rlpy/models/torch/imitators.py @@ -1,5 +1,4 @@ -from abc import ABCMeta, abstractmethod -from typing import Tuple, cast +from typing import cast import torch import torch.nn.functional as F @@ -7,7 +6,13 @@ from torch.distributions import Normal from torch.distributions.kl import kl_divergence -from .encoders import Encoder, EncoderWithAction +from .encoders import EncoderWithAction +from .policies import ( + CategoricalPolicy, + DeterministicPolicy, + NormalPolicy, + build_gaussian_distribution, +) __all__ = [ "VAEEncoder", @@ -18,10 +23,9 @@ "forward_vae_sample", "forward_vae_sample_n", "compute_vae_error", - "Imitator", - "DiscreteImitator", - "DeterministicRegressor", - "ProbablisticRegressor", + "compute_discrete_imitation_loss", + "compute_deterministic_imitation_loss", + "compute_stochastic_imitation_loss", ] @@ -172,132 +176,26 @@ def compute_vae_error( return F.mse_loss(y, action) + cast(torch.Tensor, beta * kl_loss) -class Imitator(nn.Module, metaclass=ABCMeta): # type: ignore - @abstractmethod - def forward(self, x: torch.Tensor) -> torch.Tensor: - pass - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, super().__call__(x)) - - @abstractmethod - def compute_error( - self, x: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: - pass - - @property - @abstractmethod - def encoder(self) -> Encoder: - pass - - -class DiscreteImitator(Imitator): - _encoder: Encoder - _beta: float - _fc: nn.Linear - - def __init__( - self, encoder: Encoder, hidden_size: int, action_size: int, beta: float - ): - super().__init__() - self._encoder = encoder - self._beta = beta - self._fc = nn.Linear(hidden_size, action_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.compute_log_probs_with_logits(x)[0] - - def compute_log_probs_with_logits( - self, x: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - h = self._encoder(x) - logits = self._fc(h) - log_probs = F.log_softmax(logits, dim=1) - return log_probs, logits - - def compute_error( - self, x: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: - log_probs, logits = self.compute_log_probs_with_logits(x) - penalty = (logits**2).mean() - return F.nll_loss(log_probs, action.view(-1)) + self._beta * penalty - - @property - def encoder(self) -> Encoder: - return self._encoder - - -class DeterministicRegressor(Imitator): - _encoder: Encoder - _fc: nn.Linear - - def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): - super().__init__() - self._encoder = encoder - self._fc = nn.Linear(hidden_size, action_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self._encoder(x) - h = self._fc(h) - return torch.tanh(h) - - def compute_error( - self, x: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: - return F.mse_loss(self.forward(x), action) - - @property - def encoder(self) -> Encoder: - return self._encoder - - -class ProbablisticRegressor(Imitator): - _min_logstd: float - _max_logstd: float - _encoder: Encoder - _mu: nn.Linear - _logstd: nn.Linear - - def __init__( 
- self, - encoder: Encoder, - hidden_size: int, - action_size: int, - min_logstd: float, - max_logstd: float, - ): - super().__init__() - self._min_logstd = min_logstd - self._max_logstd = max_logstd - self._encoder = encoder - self._mu = nn.Linear(hidden_size, action_size) - self._logstd = nn.Linear(hidden_size, action_size) - - def dist(self, x: torch.Tensor) -> Normal: - h = self._encoder(x) - mu = self._mu(h) - logstd = self._logstd(h) - clipped_logstd = logstd.clamp(self._min_logstd, self._max_logstd) - return Normal(mu, clipped_logstd.exp()) +def compute_discrete_imitation_loss( + policy: CategoricalPolicy, + x: torch.Tensor, + action: torch.Tensor, + beta: float, +) -> torch.Tensor: + dist = policy(x) + penalty = (dist.logits**2).mean() + log_probs = F.log_softmax(dist.logits, dim=1) + return F.nll_loss(log_probs, action.view(-1)) + beta * penalty - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self._encoder(x) - mu = self._mu(h) - return torch.tanh(mu) - def sample_n(self, x: torch.Tensor, n: int) -> torch.Tensor: - dist = self.dist(x) - actions = cast(torch.Tensor, dist.rsample((n,))) - # (n, batch, action) -> (batch, n, action) - return actions.transpose(0, 1) +def compute_deterministic_imitation_loss( + policy: DeterministicPolicy, x: torch.Tensor, action: torch.Tensor +) -> torch.Tensor: + return F.mse_loss(policy(x).squashed_mu, action) - def compute_error( - self, x: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: - dist = self.dist(x) - return F.mse_loss(torch.tanh(dist.rsample()), action) - @property - def encoder(self) -> Encoder: - return self._encoder +def compute_stochastic_imitation_loss( + policy: NormalPolicy, x: torch.Tensor, action: torch.Tensor +) -> torch.Tensor: + dist = build_gaussian_distribution(policy(x)) + return F.mse_loss(dist.sample(), action) diff --git a/tests/models/test_builders.py b/tests/models/test_builders.py index 2b98c859..4a3e3f93 100644 --- a/tests/models/test_builders.py +++ b/tests/models/test_builders.py @@ -10,14 +10,11 @@ create_continuous_decision_transformer, create_continuous_q_function, create_deterministic_policy, - create_deterministic_regressor, create_deterministic_residual_policy, create_discrete_decision_transformer, - create_discrete_imitator, create_discrete_q_function, create_normal_policy, create_parameter, - create_probablistic_regressor, create_value_function, ) from d3rlpy.models.encoders import DefaultEncoderFactory, EncoderFactory @@ -26,12 +23,7 @@ EnsembleContinuousQFunction, EnsembleDiscreteQFunction, ) -from d3rlpy.models.torch.imitators import ( - ConditionalVAE, - DeterministicRegressor, - DiscreteImitator, - ProbablisticRegressor, -) +from d3rlpy.models.torch.imitators import ConditionalVAE from d3rlpy.models.torch.policies import ( CategoricalPolicy, DeterministicPolicy, @@ -241,71 +233,6 @@ def test_create_conditional_vae( assert y.shape == (batch_size, action_size) -@pytest.mark.parametrize("observation_shape", [(4, 84, 84), (100,)]) -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("beta", [1e-2]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()]) -def test_create_discrete_imitator( - observation_shape: Sequence[int], - action_size: int, - beta: float, - batch_size: int, - encoder_factory: EncoderFactory, -) -> None: - imitator = create_discrete_imitator( - observation_shape, action_size, beta, encoder_factory, device="cpu:0" - ) - - assert isinstance(imitator, DiscreteImitator) - - x = 
torch.rand((batch_size, *observation_shape)) - y = imitator(x) - assert y.shape == (batch_size, action_size) - - -@pytest.mark.parametrize("observation_shape", [(4, 84, 84), (100,)]) -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()]) -def test_create_deterministic_regressor( - observation_shape: Sequence[int], - action_size: int, - batch_size: int, - encoder_factory: EncoderFactory, -) -> None: - imitator = create_deterministic_regressor( - observation_shape, action_size, encoder_factory, device="cpu:0" - ) - - assert isinstance(imitator, DeterministicRegressor) - - x = torch.rand((batch_size, *observation_shape)) - y = imitator(x) - assert y.shape == (batch_size, action_size) - - -@pytest.mark.parametrize("observation_shape", [(4, 84, 84), (100,)]) -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()]) -def test_create_probablistic_regressor( - observation_shape: Sequence[int], - action_size: int, - batch_size: int, - encoder_factory: EncoderFactory, -) -> None: - imitator = create_probablistic_regressor( - observation_shape, action_size, encoder_factory, device="cpu:0" - ) - - assert isinstance(imitator, ProbablisticRegressor) - - x = torch.rand((batch_size, *observation_shape)) - y = imitator(x) - assert y.shape == (batch_size, action_size) - - @pytest.mark.parametrize("observation_shape", [(100,), (4, 84, 84)]) @pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()]) @pytest.mark.parametrize("batch_size", [32]) diff --git a/tests/models/torch/test_imitators.py b/tests/models/torch/test_imitators.py index 1802efde..533f5265 100644 --- a/tests/models/torch/test_imitators.py +++ b/tests/models/torch/test_imitators.py @@ -1,20 +1,24 @@ import pytest import torch -import torch.nn.functional as F from d3rlpy.models.torch.imitators import ( ConditionalVAE, - DeterministicRegressor, - DiscreteImitator, - ProbablisticRegressor, VAEDecoder, VAEEncoder, + compute_deterministic_imitation_loss, + compute_discrete_imitation_loss, + compute_stochastic_imitation_loss, compute_vae_error, forward_vae_decode, forward_vae_encode, forward_vae_sample, forward_vae_sample_n, ) +from d3rlpy.models.torch.policies import ( + CategoricalPolicy, + DeterministicPolicy, + NormalPolicy, +) from .model_test import ( DummyEncoder, @@ -137,86 +141,72 @@ def test_conditional_vae( @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("beta", [1e-2]) @pytest.mark.parametrize("batch_size", [32]) -def test_discrete_imitator( - feature_size: int, action_size: int, beta: float, batch_size: int +@pytest.mark.parametrize("beta", [0.5]) +def test_compute_discrete_imitation_loss( + feature_size: int, action_size: int, batch_size: int, beta: float ) -> None: encoder = DummyEncoder(feature_size) - imitator = DiscreteImitator( + policy = CategoricalPolicy( encoder=encoder, hidden_size=feature_size, action_size=action_size, - beta=beta, ) # check output shape x = torch.rand(batch_size, feature_size) - y = imitator(x) - assert torch.allclose(y.exp().sum(dim=1), torch.ones(batch_size)) - y, logits = imitator.compute_log_probs_with_logits(x) - assert torch.allclose(y, F.log_softmax(logits, dim=1)) - - action = torch.randint(low=0, high=action_size - 1, size=(batch_size,)) - loss = imitator.compute_error(x, action) - penalty = 
(logits**2).mean() - assert torch.allclose(loss, F.nll_loss(y, action) + beta * penalty) - - # check layer connections - check_parameter_updates(imitator, (x, action)) + action = torch.randint(low=0, high=action_size, size=(batch_size,)) + loss = compute_discrete_imitation_loss(policy, x, action, beta) + assert loss.ndim == 0 @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("batch_size", [32]) -def test_deterministic_regressor( +def test_compute_deterministic_imitation_loss( feature_size: int, action_size: int, batch_size: int ) -> None: encoder = DummyEncoder(feature_size) - imitator = DeterministicRegressor( + policy = DeterministicPolicy( encoder=encoder, hidden_size=feature_size, action_size=action_size, ) + # check output shape x = torch.rand(batch_size, feature_size) - y = imitator(x) - assert y.shape == (batch_size, action_size) - action = torch.rand(batch_size, action_size) - loss = imitator.compute_error(x, action) - assert torch.allclose(F.mse_loss(y, action), loss) - - # check layer connections - check_parameter_updates(imitator, (x, action)) + loss = compute_deterministic_imitation_loss(policy, x, action) + assert loss.ndim == 0 + assert loss == ((policy(x).squashed_mu - action) ** 2).mean() @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("n", [10]) -def test_probablistic_regressor( - feature_size: int, action_size: int, batch_size: int, n: int +@pytest.mark.parametrize("min_logstd", [-20.0]) +@pytest.mark.parametrize("max_logstd", [2.0]) +@pytest.mark.parametrize("use_std_parameter", [True, False]) +def test_compute_stochastic_imitation_loss( + feature_size: int, + action_size: int, + batch_size: int, + min_logstd: float, + max_logstd: float, + use_std_parameter: bool, ) -> None: encoder = DummyEncoder(feature_size) - imitator = ProbablisticRegressor( + policy = NormalPolicy( encoder=encoder, hidden_size=feature_size, action_size=action_size, - min_logstd=-20, - max_logstd=2, + min_logstd=min_logstd, + max_logstd=max_logstd, + use_std_parameter=use_std_parameter, ) + # check output shape x = torch.rand(batch_size, feature_size) - y = imitator(x) - assert y.shape == (batch_size, action_size) - action = torch.rand(batch_size, action_size) - loss = imitator.compute_error(x, action) + loss = compute_stochastic_imitation_loss(policy, x, action) assert loss.ndim == 0 - - y = imitator.sample_n(x, n) - assert y.shape == (batch_size, n, action_size) - - # check layer connections - check_parameter_updates(imitator, (x, action))
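
Usage sketch of the refactored imitation API, under the assumption that the exports and constructor signatures are exactly as shown in the hunks above (CategoricalPolicy, NormalPolicy, and compute_output_size from d3rlpy.models.torch; DefaultEncoderFactory from d3rlpy.models.encoders; the new compute_*_imitation_loss helpers in d3rlpy/models/torch/imitators.py). The removed DiscreteImitator and ProbablisticRegressor modules are replaced by plain policies whose imitation losses are computed functionally:

# A minimal sketch, assuming the modules and signatures shown in this patch.
import torch

from d3rlpy.models.encoders import DefaultEncoderFactory
from d3rlpy.models.torch import (
    CategoricalPolicy,
    NormalPolicy,
    compute_output_size,
)
from d3rlpy.models.torch.imitators import (
    compute_discrete_imitation_loss,
    compute_stochastic_imitation_loss,
)

observation_shape = (100,)
action_size = 2
batch_size = 32

# build an encoder the same way the builders in d3rlpy/models/builders.py do
encoder = DefaultEncoderFactory().create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder, "cpu:0")

# discrete behavior cloning: replaces DiscreteImitator.compute_error
discrete_policy = CategoricalPolicy(
    encoder=encoder,
    hidden_size=hidden_size,
    action_size=action_size,
)
x = torch.rand(batch_size, *observation_shape)
discrete_action = torch.randint(low=0, high=action_size, size=(batch_size,))
loss = compute_discrete_imitation_loss(
    policy=discrete_policy, x=x, action=discrete_action, beta=0.5
)
assert loss.ndim == 0  # scalar loss, ready for loss.backward()

# continuous behavior cloning: replaces ProbablisticRegressor.compute_error
# (the encoder is shared here only to keep the sketch short)
continuous_policy = NormalPolicy(
    encoder=encoder,
    hidden_size=hidden_size,
    action_size=action_size,
    min_logstd=-20.0,
    max_logstd=2.0,
    use_std_parameter=False,
)
continuous_action = torch.rand(batch_size, action_size)
loss = compute_stochastic_imitation_loss(continuous_policy, x, continuous_action)
assert loss.ndim == 0

This mirrors what BCImpl, DiscreteBCImpl, and DiscreteBCQImpl now do internally: each holds an ordinary policy and calls the matching loss function, instead of delegating to a dedicated Imitator module.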