diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py
index 4f1ce04c..cbdcb9cd 100644
--- a/d3rlpy/algos/qlearning/iql.py
+++ b/d3rlpy/algos/qlearning/iql.py
@@ -3,18 +3,29 @@
 from ...base import DeviceArg, LearnableConfig, register_learnable
 from ...constants import ActionSpace
 from ...models.builders import (
+    create_categorical_policy,
     create_continuous_q_function,
+    create_discrete_q_function,
     create_normal_policy,
     create_value_function,
 )
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
-from ...models.q_functions import MeanQFunctionFactory
+from ...models.q_functions import (
+    MeanQFunctionFactory,
+    QFunctionFactory,
+    make_q_func_field,
+)
 from ...types import Shape
 from .base import QLearningAlgoBase
-from .torch.iql_impl import IQLImpl, IQLModules
+from .torch.iql_impl import (
+    DiscreteIQLImpl,
+    DiscreteIQLModules,
+    IQLImpl,
+    IQLModules,
+)
 
-__all__ = ["IQLConfig", "IQL"]
+__all__ = ["IQLConfig", "IQL", "DiscreteIQLConfig", "DiscreteIQL"]
 
 
 @dataclasses.dataclass()
@@ -176,4 +187,165 @@ def get_action_type(self) -> ActionSpace:
         return ActionSpace.CONTINUOUS
 
 
+@dataclasses.dataclass()
+class DiscreteIQLConfig(LearnableConfig):
+    r"""Implicit Q-Learning algorithm for discrete action-space.
+
+    IQL is an offline RL algorithm that avoids ever querying values of unseen
+    actions while still being able to perform multi-step dynamic programming
+    updates.
+
+    There are three functions to train in IQL. First, the state-value function
+    is trained via expectile regression.
+
+    .. math::
+
+        L_V(\psi) = \mathbb{E}_{(s, a) \sim D}
+            [L_2^\tau (Q_\theta (s, a) - V_\psi (s))]
+
+    where :math:`L_2^\tau (u) = |\tau - \mathbb{1}(u < 0)|u^2`.
+
+    The Q-function is trained with the state-value function to avoid querying
+    out-of-distribution actions.
+
+    .. math::
+
+        L_Q(\theta) = \mathbb{E}_{(s, a, r, s') \sim D}
+            [(r + \gamma V_\psi(s') - Q_\theta(s, a))^2]
+
+    Finally, the policy function is trained by advantage-weighted regression.
+    Unlike `IQL`, a categorical policy is used here.
+
+    .. math::
+
+        L_\pi (\phi) = \mathbb{E}_{(s, a) \sim D}
+            [\exp(\beta (Q_\theta(s, a) - V_\psi(s))) \log \pi_\phi(a|s)]
+
+    References:
+        * `Kostrikov et al., Offline Reinforcement Learning with Implicit
+          Q-Learning. <https://arxiv.org/abs/2110.06169>`_
+
+    Args:
+        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
+            Observation preprocessor.
+        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
+        actor_learning_rate (float): Learning rate for policy function.
+        critic_learning_rate (float): Learning rate for Q functions.
+        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+            Optimizer factory for the actor.
+        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+            Optimizer factory for the critic.
+        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for the actor.
+        encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for the critic.
+        value_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for the value function.
+        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
+            Q function factory.
+        batch_size (int): Mini-batch size.
+        gamma (float): Discount factor.
+        tau (float): Target network synchronization coefficient.
+        n_critics (int): Number of Q functions for ensemble.
+        expectile (float): Expectile value for value function training.
+        weight_temp (float): Inverse temperature value represented as
+            :math:`\beta`.
+        max_weight (float): Maximum advantage weight value to clip.
+    """
+
+    actor_learning_rate: float = 3e-4
+    critic_learning_rate: float = 3e-4
+
+    q_func_factory: QFunctionFactory = make_q_func_field()
+    encoder_factory: EncoderFactory = make_encoder_field()
+    value_encoder_factory: EncoderFactory = make_encoder_field()
+    critic_optim_factory: OptimizerFactory = make_optimizer_field()
+
+    actor_encoder_factory: EncoderFactory = make_encoder_field()
+    actor_optim_factory: OptimizerFactory = make_optimizer_field()
+
+    batch_size: int = 256
+    gamma: float = 0.99
+    tau: float = 0.005
+    n_critics: int = 2
+    expectile: float = 0.7
+    weight_temp: float = 3.0
+    max_weight: float = 100.0
+
+    def create(self, device: DeviceArg = False) -> "DiscreteIQL":
+        return DiscreteIQL(self, device)
+
+    @staticmethod
+    def get_type() -> str:
+        return "discrete_iql"
+
+
+class DiscreteIQL(QLearningAlgoBase[DiscreteIQLImpl, DiscreteIQLConfig]):
+    def inner_create_impl(
+        self, observation_shape: Shape, action_size: int
+    ) -> None:
+        policy = create_categorical_policy(
+            observation_shape,
+            action_size,
+            self._config.actor_encoder_factory,
+            self._device,
+        )
+        q_funcs, q_func_forwarder = create_discrete_q_function(
+            observation_shape,
+            action_size,
+            self._config.encoder_factory,
+            self._config.q_func_factory,
+            n_ensembles=self._config.n_critics,
+            device=self._device,
+        )
+        targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function(
+            observation_shape,
+            action_size,
+            self._config.encoder_factory,
+            self._config.q_func_factory,
+            n_ensembles=self._config.n_critics,
+            device=self._device,
+        )
+        value_func = create_value_function(
+            observation_shape,
+            self._config.value_encoder_factory,
+            device=self._device,
+        )
+
+        q_func_params = list(q_funcs.named_modules())
+        v_func_params = list(value_func.named_modules())
+        critic_optim = self._config.critic_optim_factory.create(
+            q_func_params + v_func_params, lr=self._config.critic_learning_rate
+        )
+        actor_optim = self._config.actor_optim_factory.create(
+            policy.named_modules(), lr=self._config.actor_learning_rate
+        )
+
+        modules = DiscreteIQLModules(
+            policy=policy,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            value_func=value_func,
+            actor_optim=actor_optim,
+            critic_optim=critic_optim,
+        )
+
+        self._impl = DiscreteIQLImpl(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            q_func_forwarder=q_func_forwarder,
+            targ_q_func_forwarder=targ_q_func_forwarder,
+            gamma=self._config.gamma,
+            tau=self._config.tau,
+            expectile=self._config.expectile,
+            weight_temp=self._config.weight_temp,
+            max_weight=self._config.max_weight,
+            device=self._device,
+        )
+
+    def get_action_type(self) -> ActionSpace:
+        return ActionSpace.DISCRETE
+
+
 register_learnable(IQLConfig)
+register_learnable(DiscreteIQLConfig)
diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py
index 1ce7181f..e3587a6a 100644
--- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py
+++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py
@@ -10,12 +10,13 @@
 from ....models.torch import (
     ActionOutput,
     ContinuousEnsembleQFunctionForwarder,
+    DiscreteEnsembleQFunctionForwarder,
     Policy,
 )
 from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync
 from ....types import Shape, TorchObservation
 from ..base import QLearningAlgoImplBase
-from .utility import ContinuousQFunctionMixin
+from .utility import ContinuousQFunctionMixin, DiscreteQFunctionMixin
 
 __all__ = [
     "DDPGImpl",
@@ -157,6 +158,117 @@ def q_function_optim(self) -> Optimizer:
         return self._modules.critic_optim
 
 
+class DiscreteDDPGBaseImpl(
+    DiscreteQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta
+):
+    _modules: DDPGBaseModules
+    _gamma: float
+    _tau: float
+    _q_func_forwarder: DiscreteEnsembleQFunctionForwarder
+    _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder
+
+    def __init__(
+        self,
+        observation_shape: Shape,
+        action_size: int,
+        modules: DDPGBaseModules,
+        q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        gamma: float,
+        tau: float,
+        device: str,
+    ):
+        super().__init__(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            device=device,
+        )
+        self._gamma = gamma
+        self._tau = tau
+        self._q_func_forwarder = q_func_forwarder
+        self._targ_q_func_forwarder = targ_q_func_forwarder
+        hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs)
+
+    def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
+        self._modules.critic_optim.zero_grad()
+        q_tpn = self.compute_target(batch)
+        loss = self.compute_critic_loss(batch, q_tpn)
+        loss.critic_loss.backward()
+        self._modules.critic_optim.step()
+        return asdict_as_float(loss)
+
+    def compute_critic_loss(
+        self, batch: TorchMiniBatch, q_tpn: torch.Tensor
+    ) -> DDPGBaseCriticLoss:
+        loss = self._q_func_forwarder.compute_error(
+            observations=batch.observations,
+            actions=batch.actions,
+            rewards=batch.rewards,
+            target=q_tpn,
+            terminals=batch.terminals,
+            gamma=self._gamma**batch.intervals,
+        )
+        return DDPGBaseCriticLoss(loss)
+
+    def update_actor(
+        self, batch: TorchMiniBatch, action: ActionOutput
+    ) -> Dict[str, float]:
+        # Q functions should be in eval mode during the actor update for stability
+        self._modules.q_funcs.eval()
+        self._modules.actor_optim.zero_grad()
+        loss = self.compute_actor_loss(batch, None)
+        loss.actor_loss.backward()
+        self._modules.actor_optim.step()
+        return asdict_as_float(loss)
+
+    def inner_update(
+        self, batch: TorchMiniBatch, grad_step: int
+    ) -> Dict[str, float]:
+        metrics = {}
+        action = self._modules.policy(batch.observations)
+        metrics.update(self.update_critic(batch))
+        metrics.update(self.update_actor(batch, action))
+        self.update_critic_target()
+        return metrics
+
+    @abstractmethod
+    def compute_actor_loss(
+        self, batch: TorchMiniBatch, action: None
+    ) -> DDPGBaseActorLoss:
+        pass
+
+    @abstractmethod
+    def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
+        pass
+
+    def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
+        # argmax over the action dimension for each observation in the batch
+        return self._modules.policy(x).probs.argmax(dim=1)
+
+    @abstractmethod
+    def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
+        pass
+
+    def update_critic_target(self) -> None:
+        soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau)
+
+    @property
+    def policy(self) -> Policy:
+        return self._modules.policy
+
+    @property
+    def policy_optim(self) -> Optimizer:
+        return self._modules.actor_optim
+
+    @property
+    def q_function(self) -> nn.ModuleList:
+        return self._modules.q_funcs
+
+    @property
+    def q_function_optim(self) -> Optimizer:
+        return self._modules.critic_optim
+
+
 @dataclasses.dataclass(frozen=True)
 class DDPGModules(DDPGBaseModules):
     targ_policy: Policy
diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py
index 1dfd29d0..b2c54cb0 100644
--- a/d3rlpy/algos/qlearning/torch/iql_impl.py
+++ b/d3rlpy/algos/qlearning/torch/iql_impl.py
@@ -1,10 +1,13 @@
 import dataclasses
 
 import torch
+import torch.nn.functional as F
 
 from ....models.torch import (
     ActionOutput,
+    CategoricalPolicy,
     ContinuousEnsembleQFunctionForwarder,
+    DiscreteEnsembleQFunctionForwarder,
     NormalPolicy,
     ValueFunction,
     build_gaussian_distribution,
@@ -16,9 +19,10 @@
     DDPGBaseCriticLoss,
     DDPGBaseImpl,
     DDPGBaseModules,
+    DiscreteDDPGBaseImpl,
 )
 
-__all__ = ["IQLImpl", "IQLModules"]
+__all__ = ["IQLImpl", "IQLModules", "DiscreteIQLImpl", "DiscreteIQLModules"]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -120,3 +124,108 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
     def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
         dist = build_gaussian_distribution(self._modules.policy(x))
         return dist.sample()
+
+
+@dataclasses.dataclass(frozen=True)
+class DiscreteIQLModules(DDPGBaseModules):
+    policy: CategoricalPolicy
+    value_func: ValueFunction
+
+
+@dataclasses.dataclass(frozen=True)
+class DiscreteIQLCriticLoss(DDPGBaseCriticLoss):
+    q_loss: torch.Tensor
+    v_loss: torch.Tensor
+
+
+class DiscreteIQLImpl(DiscreteDDPGBaseImpl):
+    _modules: DiscreteIQLModules
+    _expectile: float
+    _weight_temp: float
+    _max_weight: float
+
+    def __init__(
+        self,
+        observation_shape: Shape,
+        action_size: int,
+        modules: DiscreteIQLModules,
+        q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        gamma: float,
+        tau: float,
+        expectile: float,
+        weight_temp: float,
+        max_weight: float,
+        device: str,
+    ):
+        super().__init__(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            q_func_forwarder=q_func_forwarder,
+            targ_q_func_forwarder=targ_q_func_forwarder,
+            gamma=gamma,
+            tau=tau,
+            device=device,
+        )
+        self._expectile = expectile
+        self._weight_temp = weight_temp
+        self._max_weight = max_weight
+
+    def compute_critic_loss(
+        self, batch: TorchMiniBatch, q_tpn: torch.Tensor
+    ) -> DiscreteIQLCriticLoss:
+        q_loss = self._q_func_forwarder.compute_error(
+            observations=batch.observations,
+            actions=batch.actions.long(),
+            rewards=batch.rewards,
+            target=q_tpn,
+            terminals=batch.terminals,
+            gamma=self._gamma**batch.intervals,
+        )
+        v_loss = self.compute_value_loss(batch)
+        return DiscreteIQLCriticLoss(
+            critic_loss=q_loss + v_loss,
+            q_loss=q_loss,
+            v_loss=v_loss,
+        )
+
+    def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
+        with torch.no_grad():
+            return self._modules.value_func(batch.next_observations)
+
+    def compute_actor_loss(
+        self, batch: TorchMiniBatch, action: None
+    ) -> DDPGBaseActorLoss:
+        assert self._modules.policy
+        # compute advantage weight from the frozen value function and target Q
+        with torch.no_grad():
+            v = self._modules.value_func(batch.observations)
+            min_q = self._targ_q_func_forwarder.compute_target(
+                batch.observations, reduction="min"
+            ).gather(1, batch.actions.long())
+
+        exp_a = torch.exp((min_q - v) * self._weight_temp).clamp(
+            max=self._max_weight
+        )
+        # compute log probability of the dataset actions
+        dist = self._modules.policy(batch.observations)
+        log_probs = dist.log_prob(batch.actions.squeeze(-1)).unsqueeze(1)
+        return DDPGBaseActorLoss(-(exp_a * log_probs).mean())
+
+    def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
+        q_t = self._targ_q_func_forwarder.compute_expected_q(
+            batch.observations
+        )
+        one_hot = F.one_hot(
+            batch.actions.long().view(-1), num_classes=self.action_size
+        )
+        q_t = (q_t * one_hot).sum(dim=1, keepdim=True)
+
+        v_t = self._modules.value_func(batch.observations)
+        diff = q_t.detach() - v_t
+        weight = (self._expectile - (diff < 0.0).float()).abs().detach()
+        return (weight * (diff**2)).mean()
+
+    def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
+        dist = self._modules.policy(x)
+        return dist.sample()
diff --git a/tests/algos/qlearning/test_iql.py b/tests/algos/qlearning/test_iql.py
index 32c36864..c7598029 100644
--- a/tests/algos/qlearning/test_iql.py
+++ b/tests/algos/qlearning/test_iql.py
@@ -2,8 +2,11 @@
 
 import pytest
 
-from d3rlpy.algos.qlearning.iql import IQLConfig
+from d3rlpy.algos.qlearning.iql import DiscreteIQLConfig, IQLConfig
+from d3rlpy.models import (
+    MeanQFunctionFactory,
+    QFunctionFactory,
+)
 from d3rlpy.types import Shape
 
 from ...models.torch.model_test import DummyEncoderFactory
 from ...testing_utils import create_scaler_tuple
@@ -28,3 +31,26 @@ def test_iql(observation_shape: Shape, scalers: Optional[str]) -> None:
     )
     iql = config.create()
     algo_tester(iql, observation_shape)  # type: ignore
+
+
+@pytest.mark.parametrize(
+    "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))]
+)
+@pytest.mark.parametrize("q_func_factory", [MeanQFunctionFactory()])
+@pytest.mark.parametrize("scalers", [None, "min_max"])
+def test_discrete_iql(
+    observation_shape: Shape,
+    q_func_factory: QFunctionFactory,
+    scalers: Optional[str],
+) -> None:
+    observation_scaler, _, reward_scaler = create_scaler_tuple(
+        scalers, observation_shape
+    )
+    config = DiscreteIQLConfig(
+        actor_encoder_factory=DummyEncoderFactory(),
+        encoder_factory=DummyEncoderFactory(),
+        value_encoder_factory=DummyEncoderFactory(),
+        q_func_factory=q_func_factory,
+        observation_scaler=observation_scaler,
+        reward_scaler=reward_scaler,
+    )
+    iql = config.create()
+    algo_tester(iql, observation_shape)  # type: ignore
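
For reference, below is a minimal usage sketch of the discrete variant added by this patch. It assumes the existing `d3rlpy.datasets.get_cartpole()` helper and `d3rlpy.metrics.EnvironmentEvaluator`; the hyperparameter values simply restate the `DiscreteIQLConfig` defaults and the step counts are illustrative, not tuned recommendations.

```python
from d3rlpy.algos.qlearning.iql import DiscreteIQLConfig
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics import EnvironmentEvaluator

# Discrete-action offline dataset plus a matching environment for evaluation.
dataset, env = get_cartpole()

# The values below mirror the DiscreteIQLConfig defaults; pass
# device="cuda:0" to create() for GPU training.
iql = DiscreteIQLConfig(
    expectile=0.7,
    weight_temp=3.0,
    max_weight=100.0,
).create(device=False)

# Standard d3rlpy offline training loop with periodic online evaluation.
iql.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"environment": EnvironmentEvaluator(env)},
)
```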