From 4abdbd3ffad773c915829658b065ffeca8e248ca Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 6 Aug 2023 23:30:52 +0900 Subject: [PATCH] Simplify ConditionalVAE --- d3rlpy/algos/qlearning/bcq.py | 1 - d3rlpy/algos/qlearning/bear.py | 1 - d3rlpy/algos/qlearning/plas.py | 4 +- d3rlpy/algos/qlearning/torch/bcq_impl.py | 21 ++- d3rlpy/algos/qlearning/torch/bear_impl.py | 13 +- d3rlpy/algos/qlearning/torch/plas_impl.py | 37 +++-- d3rlpy/algos/qlearning/torch/sac_impl.py | 15 +- d3rlpy/models/builders.py | 16 +- d3rlpy/models/torch/imitators.py | 193 ++++++++++++++-------- d3rlpy/models/torch/policies.py | 9 +- tests/models/test_builders.py | 7 +- tests/models/torch/test_imitators.py | 102 +++++++++--- tests/models/torch/test_policies.py | 16 +- 13 files changed, 280 insertions(+), 155 deletions(-) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 37c064ee..3fb46292 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -186,7 +186,6 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, latent_size=2 * action_size, - beta=self._config.beta, min_logstd=-4.0, max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 2481662c..62a88924 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -176,7 +176,6 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, latent_size=2 * action_size, - beta=self._config.vae_kl_weight, min_logstd=-4.0, max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 79ac130f..9f146364 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -124,7 +124,6 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, latent_size=2 * action_size, - beta=self._config.beta, min_logstd=-4.0, max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, @@ -153,6 +152,7 @@ def inner_create_impl( gamma=self._config.gamma, tau=self._config.tau, lam=self._config.lam, + beta=self._config.beta, device=self._device, ) @@ -257,7 +257,6 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, latent_size=2 * action_size, - beta=self._config.beta, min_logstd=-4.0, max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, @@ -296,6 +295,7 @@ def inner_create_impl( gamma=self._config.gamma, tau=self._config.tau, lam=self._config.lam, + beta=self._config.beta, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 95ad8865..9eaebca2 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -12,6 +12,8 @@ EnsembleContinuousQFunction, EnsembleDiscreteQFunction, compute_max_with_n_actions, + compute_vae_error, + forward_vae_decode, ) from ....torch_utility import TorchMiniBatch, train_api from .ddpg_impl import DDPGBaseImpl @@ -73,8 +75,10 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: device=self._device, ) clipped_latent = latent.clamp(-0.5, 0.5) - sampled_action = self._imitator.decode( - batch.observations, clipped_latent + sampled_action = forward_vae_decode( + vae=self._imitator, + x=batch.observations, + latent=clipped_latent, ) action = self._policy(batch.observations, sampled_action) return 
-self._q_func(batch.observations, action.squashed_mu, "none")[ @@ -85,7 +89,12 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: def update_imitator(self, batch: TorchMiniBatch) -> float: self._imitator_optim.zero_grad() - loss = self._imitator.compute_error(batch.observations, batch.actions) + loss = compute_vae_error( + vae=self._imitator, + x=batch.observations, + action=batch.actions, + beta=self._beta, + ) loss.backward() self._imitator_optim.step() @@ -109,7 +118,11 @@ def _sample_repeated_action( ) clipped_latent = latent.clamp(-0.5, 0.5) # sample action - sampled_action = self._imitator.decode(flattened_x, clipped_latent) + sampled_action = forward_vae_decode( + vae=self._imitator, + x=flattened_x, + latent=clipped_latent, + ) # add residual action policy = self._targ_policy if target else self._policy action = policy(flattened_x, sampled_action) diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 8870f1c5..6a17d3d2 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -11,6 +11,8 @@ Parameter, build_squashed_gaussian_distribution, compute_max_with_n_actions_and_indices, + compute_vae_error, + forward_vae_sample_n, ) from ....torch_utility import TorchMiniBatch, train_api from .sac_impl import SACImpl @@ -133,7 +135,12 @@ def update_imitator(self, batch: TorchMiniBatch) -> float: return float(loss.cpu().detach().numpy()) def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - return self._imitator.compute_error(batch.observations, batch.actions) + return compute_vae_error( + vae=self._imitator, + x=batch.observations, + action=batch.actions, + beta=self._vae_kl_weight, + ) @train_api def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]: @@ -152,8 +159,8 @@ def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]: def _compute_mmd(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): - behavior_actions = self._imitator.sample_n_without_squash( - x, self._n_mmd_action_samples + behavior_actions = forward_vae_sample_n( + self._imitator, x, self._n_mmd_action_samples, with_squash=False ) dist = build_squashed_gaussian_distribution(self._policy(x)) policy_actions = dist.sample_n_without_squash( diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index bece9f5b..06760a3b 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -9,6 +9,8 @@ DeterministicPolicy, DeterministicResidualPolicy, EnsembleContinuousQFunction, + compute_vae_error, + forward_vae_decode, ) from ....torch_utility import TorchMiniBatch, soft_sync, train_api from .ddpg_impl import DDPGBaseImpl @@ -37,6 +39,7 @@ def __init__( gamma: float, tau: float, lam: float, + beta: float, device: str, ): super().__init__( @@ -51,6 +54,7 @@ def __init__( device=device, ) self._lam = lam + self._beta = beta self._imitator = imitator self._imitator_optim = imitator_optim @@ -58,7 +62,12 @@ def __init__( def update_imitator(self, batch: TorchMiniBatch) -> float: self._imitator_optim.zero_grad() - loss = self._imitator.compute_error(batch.observations, batch.actions) + loss = compute_vae_error( + vae=self._imitator, + x=batch.observations, + action=batch.actions, + beta=self._beta, + ) loss.backward() self._imitator_optim.step() @@ -67,11 +76,14 @@ def update_imitator(self, batch: TorchMiniBatch) -> float: def compute_actor_loss(self, batch: TorchMiniBatch) -> 
torch.Tensor: latent_actions = 2.0 * self._policy(batch.observations).squashed_mu - actions = self._imitator.decode(batch.observations, latent_actions) + actions = forward_vae_decode( + self._imitator, batch.observations, latent_actions + ) return -self._q_func(batch.observations, actions, "none")[0].mean() def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._imitator.decode(x, 2.0 * self._policy(x).squashed_mu) + latent_actions = 2.0 * self._policy(x).squashed_mu + return forward_vae_decode(self._imitator, x, latent_actions) def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) @@ -81,8 +93,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: latent_actions = ( 2.0 * self._targ_policy(batch.next_observations).squashed_mu ) - actions = self._imitator.decode( - batch.next_observations, latent_actions + actions = forward_vae_decode( + self._imitator, batch.next_observations, latent_actions ) return self._targ_q_func.compute_target( batch.next_observations, @@ -110,6 +122,7 @@ def __init__( gamma: float, tau: float, lam: float, + beta: float, device: str, ): super().__init__( @@ -124,6 +137,7 @@ def __init__( gamma=gamma, tau=tau, lam=lam, + beta=beta, device=device, ) self._perturbation = perturbation @@ -131,7 +145,9 @@ def __init__( def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: latent_actions = 2.0 * self._policy(batch.observations).squashed_mu - actions = self._imitator.decode(batch.observations, latent_actions) + actions = forward_vae_decode( + self._imitator, batch.observations, latent_actions + ) residual_actions = self._perturbation( batch.observations, actions ).squashed_mu @@ -139,8 +155,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: return -q_value[0].mean() def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - action = self._imitator.decode(x, 2.0 * self._policy(x).squashed_mu) - return self._perturbation(x, action).squashed_mu + latent_actions = 2.0 * self._policy(x).squashed_mu + actions = forward_vae_decode(self._imitator, x, latent_actions) + return self._perturbation(x, actions).squashed_mu def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) @@ -150,8 +167,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: latent_actions = ( 2.0 * self._targ_policy(batch.next_observations).squashed_mu ) - actions = self._imitator.decode( - batch.next_observations, latent_actions + actions = forward_vae_decode( + self._imitator, batch.next_observations, latent_actions ) residual_actions = self._targ_perturbation( batch.next_observations, actions diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index b5573355..bfa298fb 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -13,7 +13,6 @@ EnsembleQFunction, Parameter, Policy, - build_categorical_distribution, build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch, hard_sync, train_api @@ -155,9 +154,7 @@ def update_critic(self, batch: TorchMiniBatch) -> float: def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - dist = build_categorical_distribution( - self._policy(batch.next_observations) - ) + dist = self._policy(batch.next_observations) log_probs = dist.logits probs = dist.probs entropy = self._log_temp().exp() * log_probs @@ -200,7 +197,7 @@ def 
update_actor(self, batch: TorchMiniBatch) -> float: def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): q_t = self._q_func(batch.observations, reduction="min") - dist = build_categorical_distribution(self._policy(batch.observations)) + dist = self._policy(batch.observations) log_probs = dist.logits probs = dist.probs entropy = self._log_temp().exp() * log_probs @@ -211,9 +208,7 @@ def update_temp(self, batch: TorchMiniBatch) -> Tuple[float, float]: self._temp_optim.zero_grad() with torch.no_grad(): - dist = build_categorical_distribution( - self._policy(batch.observations) - ) + dist = self._policy(batch.observations) log_probs = dist.logits probs = dist.probs expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True) @@ -231,11 +226,11 @@ def update_temp(self, batch: TorchMiniBatch) -> Tuple[float, float]: return float(loss.cpu().detach().numpy()), float(cur_temp) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - dist = build_categorical_distribution(self._policy(x)) + dist = self._policy(x) return dist.probs.argmax(dim=1) def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: - dist = build_categorical_distribution(self._policy(x)) + dist = self._policy(x) return dist.sample() def update_target(self) -> None: diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 65715cd7..4f7ab087 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -22,6 +22,8 @@ Parameter, ProbablisticRegressor, SimplePositionEncoding, + VAEDecoder, + VAEEncoder, ValueFunction, compute_output_size, ) @@ -191,7 +193,6 @@ def create_conditional_vae( observation_shape: Shape, action_size: int, latent_size: int, - beta: float, encoder_factory: EncoderFactory, device: str, min_logstd: float = -20.0, @@ -206,16 +207,19 @@ def create_conditional_vae( hidden_size = compute_output_size( [observation_shape, (action_size,)], encoder_encoder, device ) - policy = ConditionalVAE( - encoder_encoder=encoder_encoder, - decoder_encoder=decoder_encoder, + encoder = VAEEncoder( + encoder=encoder_encoder, hidden_size=hidden_size, latent_size=latent_size, - action_size=action_size, - beta=beta, min_logstd=min_logstd, max_logstd=max_logstd, ) + decoder = VAEDecoder( + encoder=decoder_encoder, + hidden_size=hidden_size, + action_size=action_size, + ) + policy = ConditionalVAE(encoder=encoder, decoder=decoder) policy.to(device) return policy diff --git a/d3rlpy/models/torch/imitators.py b/d3rlpy/models/torch/imitators.py index 374775c2..116e1e2e 100644 --- a/d3rlpy/models/torch/imitators.py +++ b/d3rlpy/models/torch/imitators.py @@ -10,7 +10,14 @@ from .encoders import Encoder, EncoderWithAction __all__ = [ + "VAEEncoder", + "VAEDecoder", "ConditionalVAE", + "forward_vae_encode", + "forward_vae_decode", + "forward_vae_sample", + "forward_vae_sample_n", + "compute_vae_error", "Imitator", "DiscreteImitator", "DeterministicRegressor", @@ -18,109 +25,151 @@ ] -class ConditionalVAE(nn.Module): # type: ignore - _encoder_encoder: EncoderWithAction - _decoder_encoder: EncoderWithAction - _beta: float +class VAEEncoder(nn.Module): # type: ignore + _encoder: EncoderWithAction + _mu: nn.Module + _logstd: nn.Module _min_logstd: float _max_logstd: float - - _action_size: int _latent_size: int - _mu: nn.Linear - _logstd: nn.Linear - _fc: nn.Linear def __init__( self, - encoder_encoder: EncoderWithAction, - decoder_encoder: EncoderWithAction, + encoder: EncoderWithAction, hidden_size: int, latent_size: int, - action_size: int, - beta: float, 
min_logstd: float = -20.0,
         max_logstd: float = 2.0,
     ):
         super().__init__()
-        self._encoder_encoder = encoder_encoder
-        self._decoder_encoder = decoder_encoder
-        self._beta = beta
+        self._encoder = encoder
+        self._mu = nn.Linear(hidden_size, latent_size)
+        self._logstd = nn.Linear(hidden_size, latent_size)
         self._min_logstd = min_logstd
         self._max_logstd = max_logstd
+        self._latent_size = latent_size
+
+    def forward(self, x: torch.Tensor, action: torch.Tensor) -> Normal:
+        h = self._encoder(x, action)
+        mu = self._mu(h)
+        logstd = self._logstd(h)
+        clipped_logstd = logstd.clamp(self._min_logstd, self._max_logstd)
+        return Normal(mu, clipped_logstd.exp())
+
+    def __call__(self, x: torch.Tensor, action: torch.Tensor) -> Normal:
+        return super().__call__(x, action)
+
+    @property
+    def latent_size(self) -> int:
+        return self._latent_size
+
+
+class VAEDecoder(nn.Module):  # type: ignore
+    _encoder: EncoderWithAction
+    _fc: nn.Linear
+    _action_size: int
+
+    def __init__(
+        self, encoder: EncoderWithAction, hidden_size: int, action_size: int
+    ):
+        super().__init__()
+        self._encoder = encoder
+        self._fc = nn.Linear(hidden_size, action_size)
         self._action_size = action_size
-        self._latent_size = latent_size
 
-        # encoder
-        self._mu = nn.Linear(hidden_size, self._latent_size)
-        self._logstd = nn.Linear(hidden_size, self._latent_size)
-        # decoder
-        self._fc = nn.Linear(hidden_size, self._action_size)
+    def forward(
+        self, x: torch.Tensor, latent: torch.Tensor, with_squash: bool
+    ) -> torch.Tensor:
+        h = self._encoder(x, latent)
+        if with_squash:
+            return torch.tanh(self._fc(h))
+        return self._fc(h)
+
+    def __call__(
+        self, x: torch.Tensor, latent: torch.Tensor, with_squash: bool = True
+    ) -> torch.Tensor:
+        return super().__call__(x, latent, with_squash)
+
+    @property
+    def action_size(self) -> int:
+        return self._action_size
+
+
+class ConditionalVAE(nn.Module):  # type: ignore
+    _encoder: VAEEncoder
+    _decoder: VAEDecoder
+
+    def __init__(self, encoder: VAEEncoder, decoder: VAEDecoder):
+        super().__init__()
+        self._encoder = encoder
+        self._decoder = decoder
 
     def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
-        dist = self.encode(x, action)
-        return self.decode(x, dist.rsample())
+        dist = self._encoder(x, action)
+        return self._decoder(x, dist.rsample())
 
     def __call__(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
         return cast(torch.Tensor, super().__call__(x, action))
 
-    def encode(self, x: torch.Tensor, action: torch.Tensor) -> Normal:
-        h = self._encoder_encoder(x, action)
-        mu = self._mu(h)
-        logstd = self._logstd(h)
-        clipped_logstd = logstd.clamp(self._min_logstd, self._max_logstd)
-        return Normal(mu, clipped_logstd.exp())
+    @property
+    def encoder(self) -> VAEEncoder:
+        return self._encoder
 
-    def decode(self, x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
-        h = self._decoder_encoder(x, latent)
-        return torch.tanh(self._fc(h))
+    @property
+    def decoder(self) -> VAEDecoder:
+        return self._decoder
 
-    def decode_without_squash(
-        self, x: torch.Tensor, latent: torch.Tensor
-    ) -> torch.Tensor:
-        h = self._decoder_encoder(x, latent)
-        return self._fc(h)
 
-    def compute_error(
-        self, x: torch.Tensor, action: torch.Tensor
-    ) -> torch.Tensor:
-        dist = self.encode(x, action)
-        kl_loss = kl_divergence(dist, Normal(0.0, 1.0)).mean()
-        y = self.decode(x, dist.rsample())
-        return F.mse_loss(y, action) + cast(torch.Tensor, self._beta * kl_loss)
-
-    def sample(self, x: torch.Tensor) -> torch.Tensor:
-        latent = torch.randn((x.shape[0], self._latent_size),
device=x.device) - # to prevent extreme numbers - return self.decode(x, latent.clamp(-0.5, 0.5)) - - def sample_n( - self, x: torch.Tensor, n: int, with_squash: bool = True - ) -> torch.Tensor: - flat_latent_shape = (n * x.shape[0], self._latent_size) - flat_latent = torch.randn(flat_latent_shape, device=x.device) - # to prevent extreme numbers - clipped_latent = flat_latent.clamp(-0.5, 0.5) +def forward_vae_encode( + vae: ConditionalVAE, x: torch.Tensor, action: torch.Tensor +) -> Normal: + return vae.encoder(x, action) - # (batch, obs) -> (n, batch, obs) - repeated_x = x.expand((n, *x.shape)) - # (n, batch, obs) -> (n * batch, obs) - flat_x = repeated_x.reshape(-1, *x.shape[1:]) - if with_squash: - flat_actions = self.decode(flat_x, clipped_latent) - else: - flat_actions = self.decode_without_squash(flat_x, clipped_latent) +def forward_vae_decode( + vae: ConditionalVAE, x: torch.Tensor, latent: torch.Tensor +) -> torch.Tensor: + return vae.decoder(x, latent) - # (n * batch, action) -> (n, batch, action) - actions = flat_actions.view(n, x.shape[0], -1) - # (n, batch, action) -> (batch, n, action) - return actions.transpose(0, 1) +def forward_vae_sample( + vae: ConditionalVAE, x: torch.Tensor, with_squash: bool = True +) -> torch.Tensor: + latent = torch.randn((x.shape[0], vae.encoder.latent_size), device=x.device) + # to prevent extreme numbers + return vae.decoder(x, latent.clamp(-0.5, 0.5), with_squash=with_squash) + + +def forward_vae_sample_n( + vae: ConditionalVAE, x: torch.Tensor, n: int, with_squash: bool = True +) -> torch.Tensor: + flat_latent_shape = (n * x.shape[0], vae.encoder.latent_size) + flat_latent = torch.randn(flat_latent_shape, device=x.device) + # to prevent extreme numbers + clipped_latent = flat_latent.clamp(-0.5, 0.5) + + # (batch, obs) -> (n, batch, obs) + repeated_x = x.expand((n, *x.shape)) + # (n, batch, obs) -> (n * batch, obs) + flat_x = repeated_x.reshape(-1, *x.shape[1:]) + + flat_actions = vae.decoder(flat_x, clipped_latent, with_squash=with_squash) + + # (n * batch, action) -> (n, batch, action) + actions = flat_actions.view(n, x.shape[0], -1) + + # (n, batch, action) -> (batch, n, action) + return actions.transpose(0, 1) + - def sample_n_without_squash(self, x: torch.Tensor, n: int) -> torch.Tensor: - return self.sample_n(x, n, with_squash=False) +def compute_vae_error( + vae: ConditionalVAE, x: torch.Tensor, action: torch.Tensor, beta: float +) -> torch.Tensor: + dist = vae.encoder(x, action) + kl_loss = kl_divergence(dist, Normal(0.0, 1.0)).mean() + y = vae.decoder(x, dist.rsample()) + return F.mse_loss(y, action) + cast(torch.Tensor, beta * kl_loss) class Imitator(nn.Module, metaclass=ABCMeta): # type: ignore diff --git a/d3rlpy/models/torch/policies.py b/d3rlpy/models/torch/policies.py index f8df406b..97f61aa8 100644 --- a/d3rlpy/models/torch/policies.py +++ b/d3rlpy/models/torch/policies.py @@ -16,7 +16,6 @@ "CategoricalPolicy", "build_gaussian_distribution", "build_squashed_gaussian_distribution", - "build_categorical_distribution", "ActionOutput", ] @@ -145,9 +144,5 @@ def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): self._encoder = encoder self._fc = nn.Linear(hidden_size, action_size) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self._fc(self._encoder(x)) - - -def build_categorical_distribution(logits: torch.Tensor) -> Categorical: - return Categorical(probs=torch.softmax(logits, dim=1)) + def forward(self, x: torch.Tensor) -> Categorical: + return Categorical(logits=self._fc(self._encoder(x))) diff 
--git a/tests/models/test_builders.py b/tests/models/test_builders.py index e9e81022..2b98c859 100644 --- a/tests/models/test_builders.py +++ b/tests/models/test_builders.py @@ -128,8 +128,8 @@ def test_create_categorical_policy( assert isinstance(policy, CategoricalPolicy) x = torch.rand((batch_size, *observation_shape)) - y = policy(x) - assert y.shape == (batch_size, action_size) + dist = policy(x) + assert dist.probs.shape == (batch_size, action_size) @pytest.mark.parametrize("observation_shape", [(4, 84, 84), (100,)]) @@ -216,14 +216,12 @@ def test_create_continuous_q_function( @pytest.mark.parametrize("observation_shape", [(4, 84, 84), (100,)]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("latent_size", [32]) -@pytest.mark.parametrize("beta", [1.0]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()]) def test_create_conditional_vae( observation_shape: Sequence[int], action_size: int, latent_size: int, - beta: float, batch_size: int, encoder_factory: EncoderFactory, ) -> None: @@ -231,7 +229,6 @@ def test_create_conditional_vae( observation_shape, action_size, latent_size, - beta, encoder_factory, device="cpu:0", ) diff --git a/tests/models/torch/test_imitators.py b/tests/models/torch/test_imitators.py index 44b72b26..1802efde 100644 --- a/tests/models/torch/test_imitators.py +++ b/tests/models/torch/test_imitators.py @@ -7,6 +7,13 @@ DeterministicRegressor, DiscreteImitator, ProbablisticRegressor, + VAEDecoder, + VAEEncoder, + compute_vae_error, + forward_vae_decode, + forward_vae_encode, + forward_vae_sample, + forward_vae_sample_n, ) from .model_test import ( @@ -19,26 +26,83 @@ @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("latent_size", [32]) -@pytest.mark.parametrize("beta", [0.5]) +@pytest.mark.parametrize("batch_size", [32]) +def test_vae_encoder( + feature_size: int, + action_size: int, + latent_size: int, + batch_size: int, +) -> None: + encoder = DummyEncoderWithAction(feature_size, action_size) + vae_encoder = VAEEncoder( + encoder=encoder, + hidden_size=feature_size, + latent_size=latent_size, + ) + + # check output shape + x = torch.rand(batch_size, feature_size) + action = torch.rand(batch_size, action_size) + dist = vae_encoder(x, action) + assert dist.mean.shape == (batch_size, latent_size) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("latent_size", [32]) +@pytest.mark.parametrize("batch_size", [32]) +def test_vae_decoder( + feature_size: int, + action_size: int, + latent_size: int, + batch_size: int, +) -> None: + encoder = DummyEncoderWithAction(feature_size, latent_size) + vae_decoder = VAEDecoder( + encoder=encoder, + hidden_size=feature_size, + action_size=action_size, + ) + + # check output shape + x = torch.rand(batch_size, feature_size) + latent = torch.rand(batch_size, latent_size) + action = vae_decoder(x, latent) + assert action.shape == (batch_size, action_size) + + # check layer connections + check_parameter_updates(vae_decoder, (x, latent)) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("latent_size", [32]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("n", [100]) +@pytest.mark.parametrize("beta", [0.5]) def test_conditional_vae( feature_size: int, action_size: int, latent_size: int, - beta: float, batch_size: int, n: 
int, + beta: float, ) -> None: encoder_encoder = DummyEncoderWithAction(feature_size, action_size) decoder_encoder = DummyEncoderWithAction(feature_size, latent_size) - vae = ConditionalVAE( - encoder_encoder=encoder_encoder, - decoder_encoder=decoder_encoder, + vae_encoder = VAEEncoder( + encoder=encoder_encoder, hidden_size=feature_size, latent_size=latent_size, + ) + vae_decoder = VAEDecoder( + encoder=decoder_encoder, + hidden_size=feature_size, action_size=action_size, - beta=beta, + ) + vae = ConditionalVAE( + encoder=vae_encoder, + decoder=vae_decoder, ) # check output shape @@ -47,29 +111,25 @@ def test_conditional_vae( y = vae(x, action) assert y.shape == (batch_size, action_size) - # check encode - dist = vae.encode(x, action) - assert isinstance(dist, torch.distributions.Normal) + # test encode + dist = forward_vae_encode(vae, x, action) assert dist.mean.shape == (batch_size, latent_size) - # check decode - latent = torch.rand(batch_size, latent_size) - y = vae.decode(x, latent) + # test decode + y = forward_vae_decode(vae, x, dist.sample()) assert y.shape == (batch_size, action_size) - # check sample - y = vae.sample(x) + # test decode sample + y = forward_vae_sample(vae, x) assert y.shape == (batch_size, action_size) - # check sample_n - y = vae.sample_n(x, n) - assert y.shape == (batch_size, n, action_size) - - # check sample_n_without_squash - y = vae.sample_n_without_squash(x, n) + # test decode sample n + y = forward_vae_sample_n(vae, x, n) assert y.shape == (batch_size, n, action_size) - # TODO: test vae.compute_likelihood_loss(x, action) + # test compute error + error = compute_vae_error(vae, x, action, beta) + assert error.ndim == 0 # check layer connections check_parameter_updates(vae, (x, action)) diff --git a/tests/models/torch/test_policies.py b/tests/models/torch/test_policies.py index 9ecc2dcf..d7e27f7b 100644 --- a/tests/models/torch/test_policies.py +++ b/tests/models/torch/test_policies.py @@ -11,7 +11,6 @@ DeterministicPolicy, DeterministicResidualPolicy, NormalPolicy, - build_categorical_distribution, build_gaussian_distribution, build_squashed_gaussian_distribution, ) @@ -136,8 +135,9 @@ def test_categorical_policy( # check output shape x = torch.rand(batch_size, feature_size) - y = policy(x) - assert y.shape == (batch_size, action_size) + dist = policy(x) + assert dist.probs.shape == (batch_size, action_size) + assert dist.sample().shape == (batch_size,) @pytest.mark.parametrize("action_size", [2]) @@ -170,13 +170,3 @@ def test_build_squashed_gaussian_distribution( assert torch.all(dist.mean == torch.tanh(mu)) assert torch.all(dist.std == logstd.exp()) - - -@pytest.mark.parametrize("action_size", [2]) -@pytest.mark.parametrize("batch_size", [32]) -def test_build_categorical_distribution( - action_size: int, batch_size: int -) -> None: - logits = torch.rand(batch_size, action_size) - dist = build_categorical_distribution(logits) - assert torch.allclose(dist.probs, torch.softmax(logits, dim=1))
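

A minimal usage sketch of the refactored API follows, for reference only (it
is not part of the patch). ToyEncoderWithAction is a hypothetical stand-in for
a real EncoderWithAction, which create_conditional_vae would normally build
from an EncoderFactory; all sizes here are arbitrary.

    import torch
    import torch.nn as nn

    from d3rlpy.models.torch.imitators import (
        ConditionalVAE,
        VAEDecoder,
        VAEEncoder,
        compute_vae_error,
        forward_vae_decode,
        forward_vae_sample_n,
    )


    class ToyEncoderWithAction(nn.Module):
        """Hypothetical stand-in: concatenate inputs, one hidden layer."""

        def __init__(self, obs_size: int, action_size: int, hidden_size: int):
            super().__init__()
            self._fc = nn.Linear(obs_size + action_size, hidden_size)

        def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
            return torch.relu(self._fc(torch.cat([x, action], dim=1)))


    obs_size, action_size, latent_size, hidden_size = 8, 2, 4, 32

    vae = ConditionalVAE(
        # the encoder consumes (observation, action) pairs
        encoder=VAEEncoder(
            encoder=ToyEncoderWithAction(obs_size, action_size, hidden_size),
            hidden_size=hidden_size,
            latent_size=latent_size,
        ),
        # the decoder consumes (observation, latent) pairs
        decoder=VAEDecoder(
            encoder=ToyEncoderWithAction(obs_size, latent_size, hidden_size),
            hidden_size=hidden_size,
            action_size=action_size,
        ),
    )

    x = torch.rand(16, obs_size)
    action = torch.rand(16, action_size)

    # training loss: reconstruction + beta * KL, with beta supplied by the
    # caller instead of being stored on ConditionalVAE
    loss = compute_vae_error(vae=vae, x=x, action=action, beta=0.5)
    loss.backward()

    # BCQ-style decoding of clipped random latents -> (16, 2)
    latent = torch.randn(16, latent_size).clamp(-0.5, 0.5)
    sampled = forward_vae_decode(vae=vae, x=x, latent=latent)

    # n squashed action candidates per observation -> (16, 10, 2)
    candidates = forward_vae_sample_n(vae, x, n=10)

The net effect of the refactoring: beta and the squashing flag move from
module state to the call sites (compute_vae_error, VAEDecoder.__call__), so a
single ConditionalVAE instance can serve BCQ, BEAR, and PLAS with
algorithm-specific settings.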