From 8d419aa9e0571bdd5d48055541366b8c17c1a91e Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 18 Nov 2023 19:57:00 +0900 Subject: [PATCH] Minimize redundant computes --- d3rlpy/algos/qlearning/torch/awac_impl.py | 22 +++--- d3rlpy/algos/qlearning/torch/bcq_impl.py | 50 +++++++------ d3rlpy/algos/qlearning/torch/bear_impl.py | 64 +++++++---------- d3rlpy/algos/qlearning/torch/cql_impl.py | 60 +++++----------- d3rlpy/algos/qlearning/torch/crr_impl.py | 18 ++--- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 67 ++++++++++++------ d3rlpy/algos/qlearning/torch/iql_impl.py | 68 +++++++----------- d3rlpy/algos/qlearning/torch/plas_impl.py | 30 ++++---- d3rlpy/algos/qlearning/torch/sac_impl.py | 70 ++++++++++--------- d3rlpy/algos/qlearning/torch/td3_impl.py | 3 +- .../algos/qlearning/torch/td3_plus_bc_impl.py | 22 ++++-- tests/algos/qlearning/test_bear.py | 1 + 12 files changed, 237 insertions(+), 238 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index af73d020..712eeab7 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -2,12 +2,13 @@ import torch.nn.functional as F from ....models.torch import ( + ActionOutput, ContinuousEnsembleQFunctionForwarder, build_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch from ....types import Shape -from .sac_impl import SACImpl, SACModules +from .sac_impl import SACActorLoss, SACImpl, SACModules __all__ = ["AWACImpl"] @@ -42,17 +43,22 @@ def __init__( self._lam = lam self._n_action_samples = n_action_samples - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> SACActorLoss: # compute log probability - dist = build_gaussian_distribution( - self._modules.policy(batch.observations) - ) + dist = build_gaussian_distribution(action) log_probs = dist.log_prob(batch.actions) - # compute exponential weight weights = self._compute_weights(batch.observations, batch.actions) - - return -(log_probs * weights).sum() + loss = -(log_probs * weights).sum() + return SACActorLoss( + actor_loss=loss, + temp_loss=torch.tensor( + 0.0, dtype=torch.float32, device=loss.device + ), + temp=torch.tensor(0.0, dtype=torch.float32, device=loss.device), + ) def _compute_weights( self, obs_t: torch.Tensor, act_t: torch.Tensor diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 51f9bf14..0528532d 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -7,6 +7,7 @@ from torch.optim import Optimizer from ....models.torch import ( + ActionOutput, CategoricalPolicy, ConditionalVAE, ContinuousEnsembleQFunctionForwarder, @@ -19,7 +20,7 @@ ) from ....torch_utility import TorchMiniBatch, soft_sync from ....types import Shape -from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules +from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules __all__ = [ @@ -79,37 +80,24 @@ def __init__( self._beta = beta self._rl_start_step = rl_start_step - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - latent = torch.randn( - batch.observations.shape[0], - 2 * self._action_size, - device=self._device, - ) - clipped_latent = latent.clamp(-0.5, 0.5) - sampled_action = forward_vae_decode( - vae=self._modules.imitator, - x=batch.observations, - latent=clipped_latent, - ) - action = 
self._modules.policy(batch.observations, sampled_action) + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: value = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" ) - return -value[0].mean() + return DDPGBaseActorLoss(-value[0].mean()) def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.imitator_optim.zero_grad() - loss = compute_vae_error( vae=self._modules.imitator, x=batch.observations, action=batch.actions, beta=self._beta, ) - loss.backward() self._modules.imitator_optim.step() - return {"imitator_loss": float(loss.cpu().detach().numpy())} def _repeat_observation(self, x: torch.Tensor) -> torch.Tensor: @@ -188,10 +176,30 @@ def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} + metrics.update(self.update_imitator(batch)) - if grad_step >= self._rl_start_step: - metrics.update(super().inner_update(batch, grad_step)) - self.update_actor_target() + if grad_step < self._rl_start_step: + return metrics + + # forward policy + latent = torch.randn( + batch.observations.shape[0], + 2 * self._action_size, + device=self._device, + ) + clipped_latent = latent.clamp(-0.5, 0.5) + sampled_action = forward_vae_decode( + vae=self._modules.imitator, + x=batch.observations, + latent=clipped_latent, + ) + action = self._modules.policy(batch.observations, sampled_action) + + # update models + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch, action)) + self.update_critic_target() + self.update_actor_target() return metrics diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index c16f3501..7cab6d99 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -5,6 +5,7 @@ from torch.optim import Optimizer from ....models.torch import ( + ActionOutput, ConditionalVAE, ContinuousEnsembleQFunctionForwarder, Parameter, @@ -15,7 +16,7 @@ ) from ....torch_utility import TorchMiniBatch from ....types import Shape -from .sac_impl import SACImpl, SACModules +from .sac_impl import SACActorLoss, SACImpl, SACModules __all__ = ["BEARImpl", "BEARModules"] @@ -42,6 +43,12 @@ class BEARModules(SACModules): alpha_optim: Optional[Optimizer] +@dataclasses.dataclass(frozen=True) +class BEARActorLoss(SACActorLoss): + mmd_loss: torch.Tensor + alpha: torch.Tensor + + class BEARImpl(SACImpl): _modules: BEARModules _alpha_threshold: float @@ -94,19 +101,26 @@ def __init__( self._vae_kl_weight = vae_kl_weight self._warmup_steps = warmup_steps - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - loss = super().compute_actor_loss(batch) + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> BEARActorLoss: + loss = super().compute_actor_loss(batch, action) mmd_loss = self._compute_mmd_loss(batch.observations) - return loss + mmd_loss + if self._modules.alpha_optim: + self.update_alpha(mmd_loss) + return BEARActorLoss( + actor_loss=loss.actor_loss + mmd_loss, + temp_loss=loss.temp_loss, + temp=loss.temp, + mmd_loss=mmd_loss, + alpha=self._modules.log_alpha().exp(), + ) def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.actor_optim.zero_grad() - loss = self._compute_mmd_loss(batch.observations) - loss.backward() self._modules.actor_optim.step() - return {"actor_loss": float(loss.cpu().detach().numpy())} def _compute_mmd_loss(self, obs_t: torch.Tensor) -> 
torch.Tensor: @@ -116,13 +130,9 @@ def _compute_mmd_loss(self, obs_t: torch.Tensor) -> torch.Tensor: def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.imitator_optim.zero_grad() - loss = self.compute_imitator_loss(batch) - loss.backward() - self._modules.imitator_optim.step() - return {"imitator_loss": float(loss.cpu().detach().numpy())} def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: @@ -133,25 +143,15 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: beta=self._vae_kl_weight, ) - def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_alpha(self, mmd_loss: torch.Tensor) -> None: assert self._modules.alpha_optim self._modules.alpha_optim.zero_grad() - - loss = -self._compute_mmd_loss(batch.observations) - - loss.backward() + loss = -mmd_loss + loss.backward(retain_graph=True) self._modules.alpha_optim.step() - # clip for stability self._modules.log_alpha.data.clamp_(-5.0, 10.0) - cur_alpha = self._modules.log_alpha().exp().cpu().detach().numpy()[0][0] - - return { - "alpha_loss": float(loss.cpu().detach().numpy()), - "alpha": float(cur_alpha), - } - def _compute_mmd(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): behavior_actions = forward_vae_sample_n( @@ -259,25 +259,13 @@ def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_imitator(batch)) - - # lagrangian parameter update for SAC temperature - if self._modules.temp_optim: - metrics.update(self.update_temp(batch)) - - # lagrangian parameter update for MMD loss weight - if self._modules.alpha_optim: - metrics.update(self.update_alpha(batch)) - metrics.update(self.update_critic(batch)) - if grad_step < self._warmup_steps: actor_loss = self.warmup_actor(batch) else: - actor_loss = self.update_actor(batch) + action = self._modules.policy(batch.observations) + actor_loss = self.update_actor(batch, action) metrics.update(actor_loss) - self.update_critic_target() - return metrics diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index a63586e3..bd7a4af1 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Dict, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -14,6 +14,7 @@ ) from ....torch_utility import TorchMiniBatch from ....types import Shape +from .ddpg_impl import DDPGBaseCriticLoss from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules from .sac_impl import SACImpl, SACModules @@ -26,6 +27,12 @@ class CQLModules(SACModules): alpha_optim: Optional[Optimizer] +@dataclasses.dataclass(frozen=True) +class CQLCriticLoss(DDPGBaseCriticLoss): + conservative_loss: torch.Tensor + alpha: torch.Tensor + + class CQLImpl(SACImpl): _modules: CQLModules _alpha_threshold: float @@ -65,36 +72,27 @@ def __init__( def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> torch.Tensor: + ) -> CQLCriticLoss: loss = super().compute_critic_loss(batch, q_tpn) conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions, batch.next_observations ) - return loss + conservative_loss + if self._modules.alpha_optim: + self.update_alpha(conservative_loss) + return CQLCriticLoss( + critic_loss=loss.critic_loss + conservative_loss, + conservative_loss=conservative_loss, + alpha=self._modules.log_alpha().exp(), + ) - def 
update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_alpha(self, conservative_loss: torch.Tensor) -> None: assert self._modules.alpha_optim - - # Q function should be inference mode for stability - self._modules.q_funcs.eval() - self._modules.alpha_optim.zero_grad() - # the original implementation does scale the loss value - loss = -self._compute_conservative_loss( - batch.observations, batch.actions, batch.next_observations - ) - - loss.backward() + loss = -conservative_loss + loss.backward(retain_graph=True) self._modules.alpha_optim.step() - cur_alpha = self._modules.log_alpha().exp().cpu().detach().numpy()[0][0] - - return { - "alpha_loss": float(loss.cpu().detach().numpy()), - "alpha": float(cur_alpha), - } - def _compute_policy_is_values( self, policy_obs: torch.Tensor, value_obs: torch.Tensor ) -> torch.Tensor: @@ -196,26 +194,6 @@ def _compute_deterministic_target( reduction="min", ) - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: - metrics = {} - - # lagrangian parameter update for SAC temperature - if self._modules.temp_optim: - metrics.update(self.update_temp(batch)) - - # lagrangian parameter update for conservative loss weight - if self._modules.alpha_optim: - metrics.update(self.update_alpha(batch)) - - metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch)) - - self.update_critic_target() - - return metrics - @dataclasses.dataclass(frozen=True) class DiscreteCQLLoss(DQNLoss): diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 33d00507..923e2c85 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -5,13 +5,14 @@ import torch.nn.functional as F from ....models.torch import ( + ActionOutput, ContinuousEnsembleQFunctionForwarder, NormalPolicy, build_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch, hard_sync, soft_sync from ....types import Shape -from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules +from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules __all__ = ["CRRImpl", "CRRModules"] @@ -68,16 +69,14 @@ def __init__( self._target_update_type = target_update_type self._target_update_interval = target_update_interval - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: # compute log probability - dist = build_gaussian_distribution( - self._modules.policy(batch.observations) - ) + dist = build_gaussian_distribution(action) log_probs = dist.log_prob(batch.actions) - weight = self._compute_weight(batch.observations, batch.actions) - - return -(log_probs * weight).mean() + return DDPGBaseActorLoss(-(log_probs * weight).mean()) def _compute_weight( self, obs_t: torch.Tensor, act_t: torch.Tensor @@ -187,8 +186,9 @@ def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} + action = self._modules.policy(batch.observations) metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch)) + metrics.update(self.update_actor(batch, action)) if self._target_update_type == "hard": if grad_step % self._target_update_interval == 0: diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 58396a65..4ab13a35 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -6,13 +6,25 @@ from 
torch import nn from torch.optim import Optimizer -from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy +from ....dataclass_utils import asdict_as_float +from ....models.torch import ( + ActionOutput, + ContinuousEnsembleQFunctionForwarder, + Policy, +) from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync from ....types import Shape from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin -__all__ = ["DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules"] +__all__ = [ + "DDPGImpl", + "DDPGBaseImpl", + "DDPGBaseModules", + "DDPGModules", + "DDPGBaseActorLoss", + "DDPGBaseCriticLoss", +] @dataclasses.dataclass(frozen=True) @@ -24,6 +36,16 @@ class DDPGBaseModules(Modules): critic_optim: Optimizer +@dataclasses.dataclass(frozen=True) +class DDPGBaseActorLoss: + actor_loss: torch.Tensor + + +@dataclasses.dataclass(frozen=True) +class DDPGBaseCriticLoss: + critic_loss: torch.Tensor + + class DDPGBaseImpl( ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): @@ -58,20 +80,16 @@ def __init__( def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.zero_grad() - q_tpn = self.compute_target(batch) - loss = self.compute_critic_loss(batch, q_tpn) - - loss.backward() + loss.critic_loss.backward() self._modules.critic_optim.step() - - return {"critic_loss": float(loss.cpu().detach().numpy())} + return asdict_as_float(loss) def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> torch.Tensor: - return self._q_func_forwarder.compute_error( + ) -> DDPGBaseCriticLoss: + loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, @@ -79,31 +97,33 @@ def compute_critic_loss( terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) + return DDPGBaseCriticLoss(loss) - def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_actor( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> Dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() - self._modules.actor_optim.zero_grad() - - loss = self.compute_actor_loss(batch) - - loss.backward() + loss = self.compute_actor_loss(batch, action) + loss.actor_loss.backward() self._modules.actor_optim.step() - - return {"actor_loss": float(loss.cpu().detach().numpy())} + return asdict_as_float(loss) def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} + action = self._modules.policy(batch.observations) metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch)) + metrics.update(self.update_actor(batch, action)) self.update_critic_target() return metrics @abstractmethod - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: pass @abstractmethod @@ -168,12 +188,13 @@ def __init__( ) hard_sync(self._modules.targ_policy, self._modules.policy) - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - action = self._modules.policy(batch.observations) + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" )[0] - return -q_t.mean() + return DDPGBaseActorLoss(-q_t.mean()) def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: 
with torch.no_grad(): diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index f2acf2c4..d6155ca6 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -1,9 +1,9 @@ import dataclasses -from typing import Dict import torch from ....models.torch import ( + ActionOutput, ContinuousEnsembleQFunctionForwarder, NormalPolicy, ValueFunction, @@ -11,7 +11,12 @@ ) from ....torch_utility import TorchMiniBatch from ....types import Shape -from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules +from .ddpg_impl import ( + DDPGBaseActorLoss, + DDPGBaseCriticLoss, + DDPGBaseImpl, + DDPGBaseModules, +) __all__ = ["IQLImpl", "IQLModules"] @@ -22,6 +27,12 @@ class IQLModules(DDPGBaseModules): value_func: ValueFunction +@dataclasses.dataclass(frozen=True) +class IQLCriticLoss(DDPGBaseCriticLoss): + q_loss: torch.Tensor + v_loss: torch.Tensor + + class IQLImpl(DDPGBaseImpl): _modules: IQLModules _expectile: float @@ -58,8 +69,8 @@ def __init__( def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> torch.Tensor: - return self._q_func_forwarder.compute_error( + ) -> IQLCriticLoss: + q_loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, @@ -67,23 +78,27 @@ def compute_critic_loss( terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) + v_loss = self.compute_value_loss(batch) + return IQLCriticLoss( + critic_loss=q_loss + v_loss, + q_loss=q_loss, + v_loss=v_loss, + ) def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): return self._modules.value_func(batch.next_observations) - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: # compute log probability - dist = build_gaussian_distribution( - self._modules.policy(batch.observations) - ) + dist = build_gaussian_distribution(action) log_probs = dist.log_prob(batch.actions) - # compute weight with torch.no_grad(): weight = self._compute_weight(batch) - - return -(weight * log_probs).mean() + return DDPGBaseActorLoss(-(weight * log_probs).mean()) def _compute_weight(self, batch: TorchMiniBatch) -> torch.Tensor: q_t = self._targ_q_func_forwarder.compute_expected_q( @@ -102,37 +117,6 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: weight = (self._expectile - (diff < 0.0).float()).abs().detach() return (weight * (diff**2)).mean() - def update_critic_and_state_value( - self, batch: TorchMiniBatch - ) -> Dict[str, float]: - self._modules.critic_optim.zero_grad() - - # compute Q-function loss - q_tpn = self.compute_target(batch) - q_loss = self.compute_critic_loss(batch, q_tpn) - - # compute value function loss - v_loss = self.compute_value_loss(batch) - - loss = q_loss + v_loss - - loss.backward() - self._modules.critic_optim.step() - - return { - "critic_loss": float(q_loss.cpu().detach().numpy()), - "v_loss": float(v_loss.cpu().detach().numpy()), - } - def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: dist = build_gaussian_distribution(self._modules.policy(x)) return dist.sample() - - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: - metrics = {} - metrics.update(self.update_critic_and_state_value(batch)) - metrics.update(self.update_actor(batch)) - self.update_critic_target() - return metrics diff --git 
a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 7930540b..23cee209 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -5,6 +5,7 @@ from torch.optim import Optimizer from ....models.torch import ( + ActionOutput, ConditionalVAE, ContinuousEnsembleQFunctionForwarder, DeterministicPolicy, @@ -14,7 +15,7 @@ ) from ....torch_utility import TorchMiniBatch, soft_sync from ....types import Shape -from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules +from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules __all__ = [ "PLASImpl", @@ -68,29 +69,27 @@ def __init__( def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.imitator_optim.zero_grad() - loss = compute_vae_error( vae=self._modules.imitator, x=batch.observations, action=batch.actions, beta=self._beta, ) - loss.backward() self._modules.imitator_optim.step() - return {"imitator_loss": float(loss.cpu().detach().numpy())} - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - latent_actions = ( - 2.0 * self._modules.policy(batch.observations).squashed_mu - ) + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: + latent_actions = 2.0 * action.squashed_mu actions = forward_vae_decode( self._modules.imitator, batch.observations, latent_actions ) - return -self._q_func_forwarder.compute_expected_q( + loss = -self._q_func_forwarder.compute_expected_q( batch.observations, actions, "none" )[0].mean() + return DDPGBaseActorLoss(loss) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: latent_actions = 2.0 * self._modules.policy(x).squashed_mu @@ -126,8 +125,9 @@ def inner_update( if grad_step < self._warmup_steps: metrics.update(self.update_imitator(batch)) else: + action = self._modules.policy(batch.observations) metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch)) + metrics.update(self.update_actor(batch, action)) self.update_actor_target() self.update_critic_target() @@ -171,10 +171,10 @@ def __init__( device=device, ) - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - latent_actions = ( - 2.0 * self._modules.policy(batch.observations).squashed_mu - ) + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: + latent_actions = 2.0 * action.squashed_mu actions = forward_vae_decode( self._modules.imitator, batch.observations, latent_actions ) @@ -184,7 +184,7 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: q_value = self._q_func_forwarder.compute_expected_q( batch.observations, residual_actions, "none" ) - return -q_value[0].mean() + return DDPGBaseActorLoss(-q_value[0].mean()) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: latent_actions = 2.0 * self._modules.policy(x).squashed_mu diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index ad141ee5..e0805f50 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -8,6 +8,7 @@ from torch.optim import Optimizer from ....models.torch import ( + ActionOutput, CategoricalPolicy, ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, @@ -19,10 +20,16 @@ from ....torch_utility import Modules, TorchMiniBatch, hard_sync from ....types import Shape from ..base import QLearningAlgoImplBase -from .ddpg_impl import DDPGBaseImpl, 
DDPGBaseModules +from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules from .utility import DiscreteQFunctionMixin -__all__ = ["SACImpl", "DiscreteSACImpl", "SACModules", "DiscreteSACModules"] +__all__ = [ + "SACImpl", + "DiscreteSACImpl", + "SACModules", + "DiscreteSACModules", + "SACActorLoss", +] @dataclasses.dataclass(frozen=True) @@ -32,6 +39,12 @@ class SACModules(DDPGBaseModules): temp_optim: Optional[Optimizer] +@dataclasses.dataclass(frozen=True) +class SACActorLoss(DDPGBaseActorLoss): + temp: torch.Tensor + temp_loss: torch.Tensor + + class SACImpl(DDPGBaseImpl): _modules: SACModules @@ -57,40 +70,38 @@ def __init__( device=device, ) - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - dist = build_squashed_gaussian_distribution( - self._modules.policy(batch.observations) - ) - action, log_prob = dist.sample_with_log_prob() + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> SACActorLoss: + dist = build_squashed_gaussian_distribution(action) + sampled_action, log_prob = dist.sample_with_log_prob() + + if self._modules.temp_optim: + temp_loss = self.update_temp(log_prob) + else: + temp_loss = torch.tensor( + 0.0, dtype=torch.float32, device=sampled_action.device + ) + entropy = self._modules.log_temp().exp() * log_prob q_t = self._q_func_forwarder.compute_expected_q( - batch.observations, action, "min" + batch.observations, sampled_action, "min" + ) + return SACActorLoss( + actor_loss=(entropy - q_t).mean(), + temp_loss=temp_loss, + temp=self._modules.log_temp().exp(), ) - return (entropy - q_t).mean() - def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: assert self._modules.temp_optim self._modules.temp_optim.zero_grad() - with torch.no_grad(): - dist = build_squashed_gaussian_distribution( - self._modules.policy(batch.observations) - ) - _, log_prob = dist.sample_with_log_prob() targ_temp = log_prob - self._action_size - loss = -(self._modules.log_temp().exp() * targ_temp).mean() - loss.backward() self._modules.temp_optim.step() - - # current temperature value - cur_temp = self._modules.log_temp().exp().cpu().detach().numpy()[0][0] - - return { - "temp_loss": float(loss.cpu().detach().numpy()), - "temp": float(cur_temp), - } + return loss def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): @@ -106,15 +117,6 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: ) return target - entropy - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: - metrics = {} - if self._modules.temp_optim: - metrics.update(self.update_temp(batch)) - metrics.update(super().inner_update(batch, grad_step)) - return metrics - def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: dist = build_squashed_gaussian_distribution(self._modules.policy(x)) return dist.sample() diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 36e74816..33fdd235 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -69,7 +69,8 @@ def inner_update( # delayed policy update if grad_step % self._update_actor_interval == 0: - metrics.update(self.update_actor(batch)) + action = self._modules.policy(batch.observations) + metrics.update(self.update_actor(batch, action)) self.update_critic_target() self.update_actor_target() diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py 
b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 85bb92fd..6f73b0de 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -1,16 +1,22 @@ # pylint: disable=too-many-ancestors +import dataclasses import torch -from ....models.torch import ContinuousEnsembleQFunctionForwarder +from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder from ....torch_utility import TorchMiniBatch from ....types import Shape -from .ddpg_impl import DDPGModules +from .ddpg_impl import DDPGBaseActorLoss, DDPGModules from .td3_impl import TD3Impl __all__ = ["TD3PlusBCImpl"] +@dataclasses.dataclass(frozen=True) +class TD3PlusBCActorLoss(DDPGBaseActorLoss): + bc_loss: torch.Tensor + + class TD3PlusBCImpl(TD3Impl): _alpha: float @@ -44,10 +50,14 @@ def __init__( ) self._alpha = alpha - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - action = self._modules.policy(batch.observations).squashed_mu + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> TD3PlusBCActorLoss: q_t = self._q_func_forwarder.compute_expected_q( - batch.observations, action, "none" + batch.observations, action.squashed_mu, "none" )[0] lam = self._alpha / (q_t.abs().mean()).detach() - return lam * -q_t.mean() + ((batch.actions - action) ** 2).mean() + bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean() + return TD3PlusBCActorLoss( + actor_loss=lam * -q_t.mean() + bc_loss, bc_loss=bc_loss + ) diff --git a/tests/algos/qlearning/test_bear.py b/tests/algos/qlearning/test_bear.py index 30e884c2..ec6d040f 100644 --- a/tests/algos/qlearning/test_bear.py +++ b/tests/algos/qlearning/test_bear.py @@ -31,6 +31,7 @@ def test_bear( observation_scaler=observation_scaler, action_scaler=action_scaler, reward_scaler=reward_scaler, + warmup_steps=0, ) bear = config.create() algo_tester(
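
A minimal sketch (not part of the patch) of a subclass written against the refactored interface above: the actor-loss hook now returns a frozen dataclass whose extra fields are logged as metrics via asdict_as_float, and it consumes the single ActionOutput that inner_update computes once per gradient step instead of re-running the policy. MyActorLoss and MyImpl are hypothetical names introduced only for illustration; the remaining abstract hooks of DDPGBaseImpl (e.g. compute_target) are omitted for brevity, so this fragment defines but does not instantiate the class.

import dataclasses

import torch

from d3rlpy.algos.qlearning.torch.ddpg_impl import (
    DDPGBaseActorLoss,
    DDPGBaseImpl,
)
from d3rlpy.models.torch import ActionOutput
from d3rlpy.torch_utility import TorchMiniBatch


@dataclasses.dataclass(frozen=True)
class MyActorLoss(DDPGBaseActorLoss):
    # extra fields beyond actor_loss are converted to floats for logging
    # by asdict_as_float inside DDPGBaseImpl.update_actor
    bc_loss: torch.Tensor


class MyImpl(DDPGBaseImpl):
    def compute_actor_loss(
        self, batch: TorchMiniBatch, action: ActionOutput
    ) -> MyActorLoss:
        # `action` is the policy output already computed in inner_update,
        # so the policy network is not evaluated a second time here
        q_t = self._q_func_forwarder.compute_expected_q(
            batch.observations, action.squashed_mu, "none"
        )[0]
        bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
        return MyActorLoss(actor_loss=-q_t.mean() + bc_loss, bc_loss=bc_loss)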