Implement CalQL
takuseno committed Mar 24, 2024
1 parent beabb9f commit 48407b0
Showing 5 changed files with 262 additions and 11 deletions.
163 changes: 163 additions & 0 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -0,0 +1,163 @@
import dataclasses
import math

from ...base import DeviceArg, register_learnable
from ...models.builders import (
create_continuous_q_function,
create_normal_policy,
create_parameter,
)
from ...types import Shape
from .cql import CQL, CQLConfig
from .torch.cal_ql_impl import CalQLImpl
from .torch.cql_impl import CQLModules

__all__ = ["CalQLConfig", "CalQL"]


@dataclasses.dataclass()
class CalQLConfig(CQLConfig):
r"""Config of Calibrated Q-Learning algorithm.
Cal-QL is an extension to CQL to mitigate issues in offline-to-online
fine-tuning.
The CQL regularizer is modified as follows:
.. math::
\mathbb{E}_{s \sim D, a \sim \pi} [\max{(Q(s, a), V(s))}]
- \mathbb{E}_{s, a \sim D} [Q(s, a)]
References:
* `Mitsuhiko et al., Cal-QL: Calibrated Offline RL Pre-Training for
Efficient Online Fine-Tuning. <https://arxiv.org/abs/2303.05479>`_
Args:
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
actor_learning_rate (float): Learning rate for policy function.
critic_learning_rate (float): Learning rate for Q functions.
temp_learning_rate (float):
Learning rate for temperature parameter of SAC.
alpha_learning_rate (float): Learning rate for :math:`\alpha`.
actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the actor.
critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the critic.
temp_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the temperature.
alpha_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for :math:`\alpha`.
actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the actor.
critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the critic.
q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
Q function factory.
batch_size (int): Mini-batch size.
gamma (float): Discount factor.
tau (float): Target network synchronization coefficiency.
n_critics (int): Number of Q functions for ensemble.
initial_temperature (float): Initial temperature value.
initial_alpha (float): Initial :math:`\alpha` value.
alpha_threshold (float): Threshold value described as :math:`\tau`.
conservative_weight (float): Constant weight to scale conservative loss.
n_action_samples (int): Number of sampled actions to compute
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
"""

    def create(self, device: DeviceArg = False) -> "CalQL":
        return CalQL(self, device)

    @staticmethod
    def get_type() -> str:
        return "cal_ql"


class CalQL(CQL):
    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        policy = create_normal_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        log_temp = create_parameter(
            (1, 1),
            math.log(self._config.initial_temperature),
            device=self._device,
        )
        log_alpha = create_parameter(
            (1, 1), math.log(self._config.initial_alpha), device=self._device
        )

        actor_optim = self._config.actor_optim_factory.create(
            policy.named_modules(), lr=self._config.actor_learning_rate
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(), lr=self._config.critic_learning_rate
        )
        if self._config.temp_learning_rate > 0:
            temp_optim = self._config.temp_optim_factory.create(
                log_temp.named_modules(), lr=self._config.temp_learning_rate
            )
        else:
            temp_optim = None
        if self._config.alpha_learning_rate > 0:
            alpha_optim = self._config.alpha_optim_factory.create(
                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
            )
        else:
            alpha_optim = None

        modules = CQLModules(
            policy=policy,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            log_temp=log_temp,
            log_alpha=log_alpha,
            actor_optim=actor_optim,
            critic_optim=critic_optim,
            temp_optim=temp_optim,
            alpha_optim=alpha_optim,
        )

        self._impl = CalQLImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=self._config.gamma,
            tau=self._config.tau,
            alpha_threshold=self._config.alpha_threshold,
            conservative_weight=self._config.conservative_weight,
            n_action_samples=self._config.n_action_samples,
            soft_q_backup=self._config.soft_q_backup,
            device=self._device,
        )


register_learnable(CalQLConfig)
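
To give a sense of how the new algorithm would be used, here is a rough, hypothetical sketch (not part of this commit). It builds a tiny random MDPDataset and runs a short offline training loop, assuming CalQL behaves like d3rlpy's other Q-learning algorithms with respect to MDPDataset and fit(); the dataset contents and hyperparameter values are placeholders.

import numpy as np

from d3rlpy.algos.qlearning.cal_ql import CalQLConfig
from d3rlpy.dataset import MDPDataset

# Hypothetical toy dataset: 100 transitions, 4-dim observations,
# 2-dim continuous actions, random values just to exercise the API.
observations = np.random.random((100, 4)).astype(np.float32)
actions = np.random.random((100, 2)).astype(np.float32)
rewards = np.random.random(100).astype(np.float32)
terminals = np.zeros(100, dtype=np.float32)
terminals[-1] = 1.0
dataset = MDPDataset(observations, actions, rewards, terminals)

cal_ql = CalQLConfig(conservative_weight=5.0).create()

# Offline pre-training; online fine-tuning would continue from the same
# algorithm object afterwards (e.g. with fit_online()).
cal_ql.fit(dataset, n_steps=1000, n_steps_per_epoch=100)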
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/torch/__init__.py
@@ -2,6 +2,7 @@
from .bc_impl import *
from .bcq_impl import *
from .bear_impl import *
from .cal_ql_impl import *
from .cql_impl import *
from .crr_impl import *
from .ddpg_impl import *
23 changes: 23 additions & 0 deletions d3rlpy/algos/qlearning/torch/cal_ql_impl.py
@@ -0,0 +1,23 @@
from typing import Tuple

import torch

from ....types import TorchObservation
from .cql_impl import CQLImpl

__all__ = ["CalQLImpl"]


class CalQLImpl(CQLImpl):
    def _compute_policy_is_values(
        self,
        policy_obs: TorchObservation,
        value_obs: TorchObservation,
        returns_to_go: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        values, log_probs = super()._compute_policy_is_values(
            policy_obs=policy_obs,
            value_obs=value_obs,
            returns_to_go=returns_to_go,
        )
        # Cal-QL calibration: lower-bound the sampled policy values by the
        # observed return-to-go before they enter the conservative term.
        return torch.maximum(values, returns_to_go), log_probs
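
To make the calibration concrete, here is a tiny hypothetical PyTorch example (not part of the commit) with toy tensors shaped (n_critics, batch, n_action_samples); the return-to-go acts as a lower bound on the importance-sampled policy values.

import torch

# Hypothetical toy shapes: 2 critics, batch of 1, 3 sampled actions.
values = torch.tensor([[[-1.0, 0.2, 0.5]],
                       [[-0.5, 0.1, 0.4]]])
returns_to_go = torch.tensor([[[0.3, 0.3, 0.3]]])  # broadcasts over critics

# Values below the observed return are clipped up to it; larger values pass through.
calibrated = torch.maximum(values, returns_to_go)
print(calibrated)
# tensor([[[0.3000, 0.3000, 0.5000]],
#         [[0.3000, 0.3000, 0.4000]]])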
44 changes: 33 additions & 11 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -1,6 +1,6 @@
import dataclasses
import math
from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
@@ -80,7 +80,10 @@ def compute_critic_loss(
    ) -> CQLCriticLoss:
        loss = super().compute_critic_loss(batch, q_tpn)
        conservative_loss = self._compute_conservative_loss(
            batch.observations, batch.actions, batch.next_observations
            obs_t=batch.observations,
            act_t=batch.actions,
            obs_tp1=batch.next_observations,
            returns_to_go=batch.returns_to_go,
        )
        if self._modules.alpha_optim:
            self.update_alpha(conservative_loss)
@@ -99,8 +102,11 @@ def update_alpha(self, conservative_loss: torch.Tensor) -> None:
        self._modules.alpha_optim.step()

    def _compute_policy_is_values(
        self, policy_obs: TorchObservation, value_obs: TorchObservation
    ) -> torch.Tensor:
        self,
        policy_obs: TorchObservation,
        value_obs: TorchObservation,
        returns_to_go: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            dist = build_squashed_gaussian_distribution(
                self._modules.policy(policy_obs)
@@ -133,9 +139,11 @@ def _compute_policy_is_values(
        log_probs = n_log_probs.view(1, -1, self._n_action_samples)

        # importance sampling
        return policy_values - log_probs
        return policy_values, log_probs

    def _compute_random_is_values(self, obs: TorchObservation) -> torch.Tensor:
    def _compute_random_is_values(
        self, obs: TorchObservation
    ) -> Tuple[torch.Tensor, float]:
        # (batch, observation) -> (batch, n, observation)
        repeated_obs = expand_and_repeat_recursively(
            obs, self._n_action_samples
@@ -160,22 +168,36 @@ def _compute_random_is_values(self, obs: TorchObservation) -> torch.Tensor:
        random_log_probs = math.log(0.5**self._action_size)

        # importance sampling
        return random_values - random_log_probs
        return random_values, random_log_probs

    def _compute_conservative_loss(
        self,
        obs_t: TorchObservation,
        act_t: torch.Tensor,
        obs_tp1: TorchObservation,
        returns_to_go: torch.Tensor,
    ) -> torch.Tensor:
        policy_values_t = self._compute_policy_is_values(obs_t, obs_t)
        policy_values_tp1 = self._compute_policy_is_values(obs_tp1, obs_t)
        random_values = self._compute_random_is_values(obs_t)
        policy_values_t, log_probs_t = self._compute_policy_is_values(
            policy_obs=obs_t,
            value_obs=obs_t,
            returns_to_go=returns_to_go,
        )
        policy_values_tp1, log_probs_tp1 = self._compute_policy_is_values(
            policy_obs=obs_tp1,
            value_obs=obs_t,
            returns_to_go=returns_to_go,
        )
        random_values, random_log_probs = self._compute_random_is_values(obs_t)

        # compute logsumexp
        # (n critics, batch, 3 * n samples) -> (n critics, batch, 1)
        target_values = torch.cat(
            [policy_values_t, policy_values_tp1, random_values], dim=2
            [
                policy_values_t - log_probs_t,
                policy_values_tp1 - log_probs_tp1,
                random_values - random_log_probs,
            ],
            dim=2,
        )
        logsumexp = torch.logsumexp(target_values, dim=2, keepdim=True)

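For context on the change above: subtracting the log-probabilities before the logsumexp is the importance-sampling estimate referred to by the ``# importance sampling`` comments. With :math:`N` actions :math:`a_i` drawn from a proposal :math:`\mu` (the current policy or the uniform distribution),

.. math::

    \log \int \exp Q(s, a) \, da
        \approx \log \frac{1}{N} \sum_{i=1}^{N} \frac{\exp Q(s, a_i)}{\mu(a_i \mid s)}
        = \log \sum_{i=1}^{N} \exp \bigl( Q(s, a_i) - \log \mu(a_i \mid s) \bigr) - \log N,

so ``torch.logsumexp(values - log_probs, ...)`` recovers the estimate up to the additive constant :math:`\log N`, which does not depend on the Q-function parameters and is omitted, as the existing CQL code already did.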
42 changes: 42 additions & 0 deletions tests/algos/qlearning/test_cal_ql.py
@@ -0,0 +1,42 @@
from typing import Optional

import pytest

from d3rlpy.algos.qlearning.cal_ql import CalQLConfig
from d3rlpy.models import (
    MeanQFunctionFactory,
    QFunctionFactory,
    QRQFunctionFactory,
)
from d3rlpy.types import Shape

from ...models.torch.model_test import DummyEncoderFactory
from ...testing_utils import create_scaler_tuple
from .algo_test import algo_tester


@pytest.mark.parametrize(
"observation_shape", [(100,), (4, 84, 84), ((100,), (200,))]
)
@pytest.mark.parametrize(
"q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
)
@pytest.mark.parametrize("scalers", [None, "min_max"])
def test_cal_ql(
observation_shape: Shape,
q_func_factory: QFunctionFactory,
scalers: Optional[str],
) -> None:
observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
scalers, observation_shape
)
config = CalQLConfig(
actor_encoder_factory=DummyEncoderFactory(),
critic_encoder_factory=DummyEncoderFactory(),
q_func_factory=q_func_factory,
observation_scaler=observation_scaler,
action_scaler=action_scaler,
reward_scaler=reward_scaler,
)
cal_ql = config.create()
algo_tester(cal_ql, observation_shape) # type: ignore
