diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py
new file mode 100644
index 00000000..316c04a9
--- /dev/null
+++ b/d3rlpy/algos/qlearning/cal_ql.py
@@ -0,0 +1,163 @@
+import dataclasses
+import math
+
+from ...base import DeviceArg, register_learnable
+from ...models.builders import (
+    create_continuous_q_function,
+    create_normal_policy,
+    create_parameter,
+)
+from ...types import Shape
+from .cql import CQL, CQLConfig
+from .torch.cal_ql_impl import CalQLImpl
+from .torch.cql_impl import CQLModules
+
+__all__ = ["CalQLConfig", "CalQL"]
+
+
+@dataclasses.dataclass()
+class CalQLConfig(CQLConfig):
+    r"""Config of Calibrated Q-Learning algorithm.
+
+    Cal-QL is an extension to CQL to mitigate issues in offline-to-online
+    fine-tuning.
+
+    The CQL regularizer is modified as follows:
+
+    .. math::
+
+        \mathbb{E}_{s \sim D, a \sim \pi} [\max{(Q(s, a), V(s))}]
+            - \mathbb{E}_{s, a \sim D} [Q(s, a)]
+
+    References:
+        * `Nakamoto et al., Cal-QL: Calibrated Offline RL Pre-Training for
+          Efficient Online Fine-Tuning. <https://arxiv.org/abs/2303.05479>`_
+
+    Args:
+        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
+            Observation preprocessor.
+        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
+        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
+        actor_learning_rate (float): Learning rate for policy function.
+        critic_learning_rate (float): Learning rate for Q functions.
+        temp_learning_rate (float):
+            Learning rate for temperature parameter of SAC.
+        alpha_learning_rate (float): Learning rate for :math:`\alpha`.
+        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+            Optimizer factory for the actor.
+        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+            Optimizer factory for the critic.
+        temp_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+            Optimizer factory for the temperature.
+        alpha_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+            Optimizer factory for :math:`\alpha`.
+        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for the actor.
+        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for the critic.
+        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
+            Q function factory.
+        batch_size (int): Mini-batch size.
+        gamma (float): Discount factor.
+        tau (float): Target network synchronization coefficient.
+        n_critics (int): Number of Q functions for ensemble.
+        initial_temperature (float): Initial temperature value.
+        initial_alpha (float): Initial :math:`\alpha` value.
+        alpha_threshold (float): Threshold value described as :math:`\tau`.
+        conservative_weight (float): Constant weight to scale conservative loss.
+        n_action_samples (int): Number of sampled actions to compute
+            :math:`\log{\sum_a \exp{Q(s, a)}}`.
+        soft_q_backup (bool): Flag to use SAC-style backup.
+    """
+
+    def create(self, device: DeviceArg = False) -> "CalQL":
+        return CalQL(self, device)
+
+    @staticmethod
+    def get_type() -> str:
+        return "cal_ql"
+
+
+class CalQL(CQL):
+    def inner_create_impl(
+        self, observation_shape: Shape, action_size: int
+    ) -> None:
+        policy = create_normal_policy(
+            observation_shape,
+            action_size,
+            self._config.actor_encoder_factory,
+            device=self._device,
+        )
+        q_funcs, q_func_forwarder = create_continuous_q_function(
+            observation_shape,
+            action_size,
+            self._config.critic_encoder_factory,
+            self._config.q_func_factory,
+            n_ensembles=self._config.n_critics,
+            device=self._device,
+        )
+        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
+            observation_shape,
+            action_size,
+            self._config.critic_encoder_factory,
+            self._config.q_func_factory,
+            n_ensembles=self._config.n_critics,
+            device=self._device,
+        )
+        log_temp = create_parameter(
+            (1, 1),
+            math.log(self._config.initial_temperature),
+            device=self._device,
+        )
+        log_alpha = create_parameter(
+            (1, 1), math.log(self._config.initial_alpha), device=self._device
+        )
+
+        actor_optim = self._config.actor_optim_factory.create(
+            policy.named_modules(), lr=self._config.actor_learning_rate
+        )
+        critic_optim = self._config.critic_optim_factory.create(
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+        )
+        if self._config.temp_learning_rate > 0:
+            temp_optim = self._config.temp_optim_factory.create(
+                log_temp.named_modules(), lr=self._config.temp_learning_rate
+            )
+        else:
+            temp_optim = None
+        if self._config.alpha_learning_rate > 0:
+            alpha_optim = self._config.alpha_optim_factory.create(
+                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
+            )
+        else:
+            alpha_optim = None
+
+        modules = CQLModules(
+            policy=policy,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            log_temp=log_temp,
+            log_alpha=log_alpha,
+            actor_optim=actor_optim,
+            critic_optim=critic_optim,
+            temp_optim=temp_optim,
+            alpha_optim=alpha_optim,
+        )
+
+        self._impl = CalQLImpl(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            q_func_forwarder=q_func_forwarder,
+            targ_q_func_forwarder=targ_q_func_forwarder,
+            gamma=self._config.gamma,
+            tau=self._config.tau,
+            alpha_threshold=self._config.alpha_threshold,
+            conservative_weight=self._config.conservative_weight,
+            n_action_samples=self._config.n_action_samples,
+            soft_q_backup=self._config.soft_q_backup,
+            device=self._device,
+        )
+
+
+register_learnable(CalQLConfig)
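Note for reviewers: the regularizer above clips the Q-values of policy actions from below by a reference value V(s), which Cal-QL estimates from the Monte-Carlo return-to-go of the dataset trajectory. A minimal end-to-end usage sketch follows. It is not part of this PR; the dataset helper (get_pendulum) and the fit()/fit_online() arguments follow d3rlpy's generic Q-learning API and may need adjustment, and it assumes the replay buffer can provide a return-to-go for every sampled transition.

# Sketch only: exercises CalQLConfig added above. Import path matches the new
# module directly; dataset and step counts are illustrative.
import d3rlpy
from d3rlpy.algos.qlearning.cal_ql import CalQLConfig

# placeholder continuous-action task shipped with d3rlpy
dataset, env = d3rlpy.datasets.get_pendulum()

cal_ql = CalQLConfig().create()

# offline pre-training with the calibrated conservative loss
cal_ql.fit(dataset, n_steps=10_000, n_steps_per_epoch=1_000)

# online fine-tuning; Cal-QL targets a smooth hand-over to this phase
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100_000, env=env)
cal_ql.fit_online(env, buffer=buffer, n_steps=10_000)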
+ """ + + def create(self, device: DeviceArg = False) -> "CalQL": + return CalQL(self, device) + + @staticmethod + def get_type() -> str: + return "cal_ql" + + +class CalQL(CQL): + def inner_create_impl( + self, observation_shape: Shape, action_size: int + ) -> None: + policy = create_normal_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + ) + q_funcs, q_func_fowarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + log_temp = create_parameter( + (1, 1), + math.log(self._config.initial_temperature), + device=self._device, + ) + log_alpha = create_parameter( + (1, 1), math.log(self._config.initial_alpha), device=self._device + ) + + actor_optim = self._config.actor_optim_factory.create( + policy.named_modules(), lr=self._config.actor_learning_rate + ) + critic_optim = self._config.critic_optim_factory.create( + q_funcs.named_modules(), lr=self._config.critic_learning_rate + ) + if self._config.temp_learning_rate > 0: + temp_optim = self._config.temp_optim_factory.create( + log_temp.named_modules(), lr=self._config.temp_learning_rate + ) + else: + temp_optim = None + if self._config.alpha_learning_rate > 0: + alpha_optim = self._config.alpha_optim_factory.create( + log_alpha.named_modules(), lr=self._config.alpha_learning_rate + ) + else: + alpha_optim = None + + modules = CQLModules( + policy=policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + log_temp=log_temp, + log_alpha=log_alpha, + actor_optim=actor_optim, + critic_optim=critic_optim, + temp_optim=temp_optim, + alpha_optim=alpha_optim, + ) + + self._impl = CalQLImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_fowarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=self._config.gamma, + tau=self._config.tau, + alpha_threshold=self._config.alpha_threshold, + conservative_weight=self._config.conservative_weight, + n_action_samples=self._config.n_action_samples, + soft_q_backup=self._config.soft_q_backup, + device=self._device, + ) + + +register_learnable(CalQLConfig) diff --git a/d3rlpy/algos/qlearning/torch/__init__.py b/d3rlpy/algos/qlearning/torch/__init__.py index 7c5ab2b7..89c7cb13 100644 --- a/d3rlpy/algos/qlearning/torch/__init__.py +++ b/d3rlpy/algos/qlearning/torch/__init__.py @@ -2,6 +2,7 @@ from .bc_impl import * from .bcq_impl import * from .bear_impl import * +from .cal_ql_impl import * from .cql_impl import * from .crr_impl import * from .ddpg_impl import * diff --git a/d3rlpy/algos/qlearning/torch/cal_ql_impl.py b/d3rlpy/algos/qlearning/torch/cal_ql_impl.py new file mode 100644 index 00000000..8079ed25 --- /dev/null +++ b/d3rlpy/algos/qlearning/torch/cal_ql_impl.py @@ -0,0 +1,23 @@ +from typing import Tuple + +import torch + +from ....types import TorchObservation +from .cql_impl import CQLImpl + +__all__ = ["CalQLImpl"] + + +class CalQLImpl(CQLImpl): + def _compute_policy_is_values( + self, + policy_obs: TorchObservation, + value_obs: TorchObservation, + returns_to_go: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + values, log_probs = super()._compute_policy_is_values( + 
diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py
index 57405cb9..e86432c5 100644
--- a/d3rlpy/algos/qlearning/torch/cql_impl.py
+++ b/d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -1,6 +1,6 @@
 import dataclasses
 import math
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -80,7 +80,10 @@ def compute_critic_loss(
     ) -> CQLCriticLoss:
         loss = super().compute_critic_loss(batch, q_tpn)
         conservative_loss = self._compute_conservative_loss(
-            batch.observations, batch.actions, batch.next_observations
+            obs_t=batch.observations,
+            act_t=batch.actions,
+            obs_tp1=batch.next_observations,
+            returns_to_go=batch.returns_to_go,
         )
         if self._modules.alpha_optim:
             self.update_alpha(conservative_loss)
@@ -99,8 +102,11 @@ def update_alpha(self, conservative_loss: torch.Tensor) -> None:
         self._modules.alpha_optim.step()
 
     def _compute_policy_is_values(
-        self, policy_obs: TorchObservation, value_obs: TorchObservation
-    ) -> torch.Tensor:
+        self,
+        policy_obs: TorchObservation,
+        value_obs: TorchObservation,
+        returns_to_go: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         with torch.no_grad():
             dist = build_squashed_gaussian_distribution(
                 self._modules.policy(policy_obs)
             )
@@ -133,9 +139,11 @@ def _compute_policy_is_values(
         log_probs = n_log_probs.view(1, -1, self._n_action_samples)
 
         # importance sampling
-        return policy_values - log_probs
+        return policy_values, log_probs
 
-    def _compute_random_is_values(self, obs: TorchObservation) -> torch.Tensor:
+    def _compute_random_is_values(
+        self, obs: TorchObservation
+    ) -> Tuple[torch.Tensor, float]:
         # (batch, observation) -> (batch, n, observation)
         repeated_obs = expand_and_repeat_recursively(
             obs, self._n_action_samples
         )
@@ -160,22 +168,36 @@ def _compute_random_is_values(self, obs: TorchObservation) -> torch.Tensor:
         random_log_probs = math.log(0.5**self._action_size)
 
         # importance sampling
-        return random_values - random_log_probs
+        return random_values, random_log_probs
 
     def _compute_conservative_loss(
         self,
         obs_t: TorchObservation,
         act_t: torch.Tensor,
         obs_tp1: TorchObservation,
+        returns_to_go: torch.Tensor,
     ) -> torch.Tensor:
-        policy_values_t = self._compute_policy_is_values(obs_t, obs_t)
-        policy_values_tp1 = self._compute_policy_is_values(obs_tp1, obs_t)
-        random_values = self._compute_random_is_values(obs_t)
+        policy_values_t, log_probs_t = self._compute_policy_is_values(
+            policy_obs=obs_t,
+            value_obs=obs_t,
+            returns_to_go=returns_to_go,
+        )
+        policy_values_tp1, log_probs_tp1 = self._compute_policy_is_values(
+            policy_obs=obs_tp1,
+            value_obs=obs_t,
+            returns_to_go=returns_to_go,
+        )
+        random_values, random_log_probs = self._compute_random_is_values(obs_t)
 
         # compute logsumexp
         # (n critics, batch, 3 * n samples) -> (n critics, batch, 1)
         target_values = torch.cat(
-            [policy_values_t, policy_values_tp1, random_values], dim=2
+            [
+                policy_values_t - log_probs_t,
+                policy_values_tp1 - log_probs_tp1,
+                random_values - random_log_probs,
+            ],
+            dim=2,
+        )
         logsumexp = torch.logsumexp(target_values, dim=2, keepdim=True)
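The refactor above is behavior-preserving for plain CQL: the importance-sampling helpers now return raw Q-values and log-probabilities separately, and the subtraction happens just before the concatenation, so the calibrated subclass can clip the values first. A worked sketch of the assembled conservative term follows; the shapes, the conservative weight, and the dataset Q-values are illustrative assumptions (the remainder of _compute_conservative_loss is unchanged and not shown in this hunk).

# Worked example of the conservative term (not part of the patch).
import math
import torch

n_critics, batch_size, n = 2, 4, 10
action_size = 3
conservative_weight = 5.0  # assumed value for illustration

policy_values_t = torch.randn(n_critics, batch_size, n)
policy_values_tp1 = torch.randn(n_critics, batch_size, n)
random_values = torch.randn(n_critics, batch_size, n)
log_probs_t = torch.randn(1, batch_size, n)
log_probs_tp1 = torch.randn(1, batch_size, n)
random_log_probs = math.log(0.5**action_size)  # uniform density on [-1, 1]^3

# (n critics, batch, 3 * n samples) -> (n critics, batch, 1)
target_values = torch.cat(
    [
        policy_values_t - log_probs_t,
        policy_values_tp1 - log_probs_tp1,
        random_values - random_log_probs,
    ],
    dim=2,
)
logsumexp = torch.logsumexp(target_values, dim=2, keepdim=True)

# stand-in for Q(s, a) of dataset actions, compared against the logsumexp term
data_values = torch.randn(n_critics, batch_size, 1)
conservative_loss = conservative_weight * (logsumexp - data_values).mean()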
diff --git a/tests/algos/qlearning/test_cal_ql.py b/tests/algos/qlearning/test_cal_ql.py
new file mode 100644
index 00000000..08ae982e
--- /dev/null
+++ b/tests/algos/qlearning/test_cal_ql.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+import pytest
+
+from d3rlpy.algos.qlearning.cal_ql import CalQLConfig
+from d3rlpy.models import (
+    MeanQFunctionFactory,
+    QFunctionFactory,
+    QRQFunctionFactory,
+)
+from d3rlpy.types import Shape
+
+from ...models.torch.model_test import DummyEncoderFactory
+from ...testing_utils import create_scaler_tuple
+from .algo_test import algo_tester
+
+
+@pytest.mark.parametrize(
+    "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))]
+)
+@pytest.mark.parametrize(
+    "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
+)
+@pytest.mark.parametrize("scalers", [None, "min_max"])
+def test_cal_ql(
+    observation_shape: Shape,
+    q_func_factory: QFunctionFactory,
+    scalers: Optional[str],
+) -> None:
+    observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
+        scalers, observation_shape
+    )
+    config = CalQLConfig(
+        actor_encoder_factory=DummyEncoderFactory(),
+        critic_encoder_factory=DummyEncoderFactory(),
+        q_func_factory=q_func_factory,
+        observation_scaler=observation_scaler,
+        action_scaler=action_scaler,
+        reward_scaler=reward_scaler,
+    )
+    cal_ql = config.create()
+    algo_tester(cal_ql, observation_shape)  # type: ignore
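The new test plugs Cal-QL into the shared algo_tester harness across observation shapes, Q-function factories, and scalers. To run just this module locally, a plain pytest invocation from the repository root works; the sketch below is an equivalent programmatic runner and assumes pytest is installed and the repository root is the working directory.

# Convenience runner for the new test module (equivalent to running pytest
# from the command line).
import sys

import pytest

sys.exit(pytest.main(["tests/algos/qlearning/test_cal_ql.py", "-v"]))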