From 4a6edc9530ffa8c33c0ec29db4057698b1449b09 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Fri, 4 Aug 2023 06:51:15 +0100 Subject: [PATCH 01/20] tracking of cql regularisation for continuous cql --- d3rlpy/algos/qlearning/cql.py | 3 ++- d3rlpy/algos/qlearning/torch/cql_impl.py | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 1825e1f3..4146b1de 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -206,8 +206,9 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: alpha_loss, alpha = self._impl.update_alpha(batch) metrics.update({"alpha_loss": alpha_loss, "alpha": alpha}) - critic_loss = self._impl.update_critic(batch) + critic_loss, cql_loss = self._impl.update_critic(batch) metrics.update({"critic_loss": critic_loss}) + metrics.update({"cql_loss": cql_loss}) actor_loss = self._impl.update_actor(batch) metrics.update({"actor_loss": actor_loss}) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index cf81b715..ed249a62 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -75,7 +75,23 @@ def compute_critic_loss( conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions, batch.next_observations ) - return loss + conservative_loss + return loss + conservative_loss, conservative_loss + + @train_api + def update_critic(self, batch: TorchMiniBatch) -> float: + self._critic_optim.zero_grad() + + q_tpn = self.compute_target(batch) + + loss, cql_loss = self.compute_critic_loss(batch, q_tpn) + + loss.backward() + self._critic_optim.step() + + res = np.array( + [loss.cpu().detach().numpy(), cql_loss.cpu().detach().numpy()] + ) + return res @train_api def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]: From 5b7185cf3b8565f5a55f5aa96d83efc48cf16bbe Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Fri, 4 Aug 2023 11:32:29 +0100 Subject: [PATCH 02/20] updated for linting and formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index ed249a62..491c2caf 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -76,9 +76,9 @@ def compute_critic_loss( batch.observations, batch.actions, batch.next_observations ) return loss + conservative_loss, conservative_loss - + @train_api - def update_critic(self, batch: TorchMiniBatch) -> float: + def update_critic(self, batch: TorchMiniBatch) -> np.array: self._critic_optim.zero_grad() q_tpn = self.compute_target(batch) @@ -88,9 +88,9 @@ def update_critic(self, batch: TorchMiniBatch) -> float: loss.backward() self._critic_optim.step() - res = np.array( - [loss.cpu().detach().numpy(), cql_loss.cpu().detach().numpy()] - ) + critic_loss = float(loss.cpu().detach().numpy()) + cql_loss = float(cql_loss.cpu().detach().numpy()) + res = np.array([critic_loss, cql_loss]) return res @train_api From 7fd0a37805895d23eecd472a58fe458c09add657 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Thu, 10 Aug 2023 18:02:33 +0100 Subject: [PATCH 03/20] overwriting dr3 pull and aligning cql logging --- d3rlpy/algos/qlearning/cql.py | 4 ++-- d3rlpy/algos/qlearning/torch/cql_impl.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/d3rlpy/algos/qlearning/cql.py 
b/d3rlpy/algos/qlearning/cql.py index 4146b1de..696eaf43 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -206,9 +206,9 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: alpha_loss, alpha = self._impl.update_alpha(batch) metrics.update({"alpha_loss": alpha_loss, "alpha": alpha}) - critic_loss, cql_loss = self._impl.update_critic(batch) + critic_loss, conservative_loss = self._impl.update_critic(batch) metrics.update({"critic_loss": critic_loss}) - metrics.update({"cql_loss": cql_loss}) + metrics.update({"conservative_loss": conservative_loss}) actor_loss = self._impl.update_actor(batch) metrics.update({"actor_loss": actor_loss}) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 491c2caf..4096b3e7 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -237,7 +237,8 @@ def compute_loss( conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions.long() ) - return loss + self._alpha * conservative_loss, conservative_loss + cql_loss = self._alpha * conservative_loss + return loss + cql_loss, cql_loss def _compute_conservative_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor From a52651eb6a1c6caaba086fe2be2523a4ddeb3eb0 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Thu, 10 Aug 2023 18:03:58 +0100 Subject: [PATCH 04/20] updated formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 4096b3e7..19686191 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -238,7 +238,7 @@ def compute_loss( batch.observations, batch.actions.long() ) cql_loss = self._alpha * conservative_loss - return loss + cql_loss, cql_loss + return loss + cql_loss, cql_loss def _compute_conservative_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor From a46d73f60d907a6c11594e43bb72d455971d3f51 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Thu, 10 Aug 2023 18:15:42 +0100 Subject: [PATCH 05/20] update gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 348562a3..3b61fe79 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ docs/d3rlpy*.rst docs/modules.rst docs/references/generated coverage.xml -.coverage +.coverage* .mypy_cache .ipynb_checkpoints build From 72551595633ddd40481246dcc9ad32e41c47d791 Mon Sep 17 00:00:00 2001 From: takuseno Date: Fri, 11 Aug 2023 22:58:32 +0900 Subject: [PATCH 06/20] Fix custom network docs --- docs/references/network_architectures.rst | 24 +++++++------- docs/tutorials/customize_neural_network.rst | 36 ++++++++++----------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/docs/references/network_architectures.rst b/docs/references/network_architectures.rst index 728f519f..2ac807e7 100644 --- a/docs/references/network_architectures.rst +++ b/docs/references/network_architectures.rst @@ -30,6 +30,7 @@ You can also build your own encoder factory. .. code-block:: python + import dataclasses import torch import torch.nn as nn @@ -48,17 +49,16 @@ You can also build your own encoder factory. 
return h # your own encoder factory + @dataclasses.dataclass() class CustomEncoderFactory(EncoderFactory): - TYPE = 'custom' # this is necessary - - def __init__(self, feature_size): - self.feature_size = feature_size + feature_size: int def create(self, observation_shape): return CustomEncoder(observation_shape, self.feature_size) - def get_params(self, deep=False): - return {'feature_size': self.feature_size} + @staticmethod + def get_type() -> str: + return "custom" dqn = d3rlpy.algos.DQNConfig( encoder_factory=CustomEncoderFactory(feature_size=64), @@ -83,11 +83,9 @@ controls. h = torch.relu(self.fc2(h)) return h + @dataclasses.dataclass() class CustomEncoderFactory(EncoderFactory): - TYPE = 'custom' # this is necessary - - def __init__(self, feature_size): - self.feature_size = feature_size + feature_size: int def create(self, observation_shape): return CustomEncoder(observation_shape, self.feature_size) @@ -95,8 +93,10 @@ controls. def create_with_action(observation_shape, action_size, discrete_action): return CustomEncoderWithAction(observation_shape, action_size, self.feature_size) - def get_params(self, deep=False): - return {'feature_size': self.feature_size} + @staticmethod + def get_type() -> str: + return "custom" + factory = CustomEncoderFactory(feature_size=64) diff --git a/docs/tutorials/customize_neural_network.rst b/docs/tutorials/customize_neural_network.rst index 020e38b0..f676f47a 100644 --- a/docs/tutorials/customize_neural_network.rst +++ b/docs/tutorials/customize_neural_network.rst @@ -33,25 +33,26 @@ If you're familiar with PyTorch, this step should be easy for you. Setup EncoderFactory -------------------- -Once you setup your PyTorch model, you need to setup ``EncoderFactory``. -In your ``EncoderFactory`` class, you need to define ``create`` and -``get_params`` methods as well as ``TYPE`` attribute. -``TYPE`` attribute and ``get_params`` method are used to serialize your -customized neural network configuration. +Once you setup your PyTorch model, you need to setup ``EncoderFactory`` as a +dataclass class. In your ``EncoderFactory`` class, you need to define +``create`` and ``get_type``. +``get_type`` method is used to serialize your customized neural network +configuration. .. code-block:: python - class CustomEncoderFactory(d3rlpy.models.encoders.EncoderFactory): - TYPE = "custom" # this is necessary + import dataclasses - def __init__(self, feature_size): - self.feature_size = feature_size + @dataclasses.dataclass() + class CustomEncoderFactory(d3rlpy.models.EncoderFactory): + feature_size: int def create(self, observation_shape): return CustomEncoder(observation_shape, self.feature_size) - def get_params(self, deep=False): - return {"feature_size": self.feature_size} + @staticmethod + def get_type() -> str: + return "custom" Now, you can use your model with d3rlpy. @@ -89,11 +90,9 @@ Finally, you can update your ``CustomEncoderFactory`` as follows. .. code-block:: python - class CustomEncoderFactory(EncoderFactory): - TYPE = "custom" - - def __init__(self, feature_size): - self.feature_size = feature_size + @dataclasses.dataclass() + class CustomEncoderFactory(d3rlpy.models.EncoderFactory): + feature_size: int def create(self, observation_shape): return CustomEncoder(observation_shape, self.feature_size) @@ -101,8 +100,9 @@ Finally, you can update your ``CustomEncoderFactory`` as follows. 
def create_with_action(self, observation_shape, action_size, discrete_action): return CustomEncoderWithAction(observation_shape, action_size, self.feature_size) - def get_params(self, deep=False): - return {"feature_size": self.feature_size} + @staticmethod + def get_type() -> str: + return "custom" Now, you can customize actor-critic algorithms. From 294544db27750686363f1946c2711e87fe6a60d5 Mon Sep 17 00:00:00 2001 From: asmith26 Date: Fri, 11 Aug 2023 23:43:33 +0100 Subject: [PATCH 07/20] Fix typo (#321) Fixed a small typo. Many thanks again! --- docs/references/network_architectures.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/references/network_architectures.rst b/docs/references/network_architectures.rst index 2ac807e7..8deb6633 100644 --- a/docs/references/network_architectures.rst +++ b/docs/references/network_architectures.rst @@ -11,7 +11,7 @@ Otherwise, the standard MLP architecture that consists with two linear layers with ``256`` hidden units. Furthermore, d3rlpy provides ``EncoderFactory`` that gives you flexible control -over this neural netowrk architectures. +over the neural network architectures. .. code-block:: python From b8263d4d2d392b1727a6fe30c24d3676f41cf213 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 12 Aug 2023 10:25:48 +0900 Subject: [PATCH 08/20] Add TPU example notebook --- docs/notebooks.rst | 1 + tutorials/tpu.ipynb | 218 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 tutorials/tpu.ipynb diff --git a/docs/notebooks.rst b/docs/notebooks.rst index d420e506..01a926b4 100644 --- a/docs/notebooks.rst +++ b/docs/notebooks.rst @@ -4,3 +4,4 @@ Jupyter Notebooks * `CartPole `_ * `CartPole (online) `_ * `Discrete Control with Atari `_ +* `TPU Example `_ diff --git a/tutorials/tpu.ipynb b/tutorials/tpu.ipynb new file mode 100644 index 00000000..7f1f4e7a --- /dev/null +++ b/tutorials/tpu.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "nWVdtW83Rz8y" + }, + "source": [ + "Setup rendering dependencies for Google Colaboratory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h56pgW5PRz81" + }, + "outputs": [], + "source": [ + "!apt-get install -y xvfb ffmpeg > /dev/null 2>&1\n", + "!pip install pyvirtualdisplay pygame moviepy > /dev/null 2>&1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PjIeDwAnRz82" + }, + "source": [ + "Install d3rlpy and PyTorch with TPU support! It likely fails to install the XLA dependency for the first time. If it's the case, simply restart the runtime and retry this." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "meFEBhUkRz83" + }, + "outputs": [], + "source": [ + "!pip install d3rlpy torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wZjK4rLiRz83" + }, + "source": [ + "Setup cartpole dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DqxnnKMNRz83" + }, + "outputs": [], + "source": [ + "import d3rlpy\n", + "\n", + "# get CartPole dataset\n", + "dataset, env = d3rlpy.datasets.get_cartpole()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UEkMtKIHRz84" + }, + "source": [ + "Setup data-driven deep reinforcement learning algorithm with TPU. Currently, it's super slow to train the small architecture. 
But, you can see it works at least." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7kyhhMjcRz84" + }, + "outputs": [], + "source": [ + "# get TPU device\n", + "import torch_xla.core.xla_model as xm\n", + "device = xm.xla_device()\n", + "\n", + "# setup CQL algorithm\n", + "cql = d3rlpy.algos.DiscreteCQLConfig().create(device=str(device))\n", + "\n", + "# start training\n", + "cql.fit(\n", + " dataset,\n", + " n_steps=10000,\n", + " n_steps_per_epoch=1000,\n", + " evaluators={\n", + " 'environment': d3rlpy.metrics.EnvironmentEvaluator(env), # evaluate with CartPole-v1 environment\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RWNyt9WARz84" + }, + "source": [ + "Setup rendering utilities for Google Colaboratory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nYKQAPhzRz85" + }, + "outputs": [], + "source": [ + "import glob\n", + "import io\n", + "import base64\n", + "\n", + "from gym.wrappers import RecordVideo\n", + "from IPython.display import HTML\n", + "from IPython import display as ipythondisplay\n", + "from pyvirtualdisplay import Display\n", + "\n", + "# start virtual display\n", + "display = Display()\n", + "display.start()\n", + "\n", + "# play recorded video\n", + "def show_video():\n", + " mp4list = glob.glob('video/*.mp4')\n", + " if len(mp4list) > 0:\n", + " mp4 = mp4list[0]\n", + " video = io.open(mp4, 'r+b').read()\n", + " encoded = base64.b64encode(video)\n", + " ipythondisplay.display(HTML(data='''\n", + " '''.format(encoded.decode('ascii'))))\n", + " else:\n", + " print(\"Could not find video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Wat4TgcYRz85" + }, + "source": [ + "Record video!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Pvf4zX5LRz85" + }, + "outputs": [], + "source": [ + "import gym\n", + "\n", + "# wrap RecordVideo wrapper\n", + "env = RecordVideo(gym.make(\"CartPole-v1\", render_mode=\"rgb_array\"), './video')\n", + "\n", + "# evaluate\n", + "d3rlpy.metrics.evaluate_qlearning_with_environment(cql, env)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NYKphVWYRz85" + }, + "source": [ + "Let's see how it works!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wcRJe7roRz86" + }, + "outputs": [], + "source": [ + "show_video()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + }, + "colab": { + "provenance": [] + }, + "accelerator": "TPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 3ec25d777a7b37d29b7df1f308f0317bc3b7877a Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 14 Aug 2023 09:28:50 +0900 Subject: [PATCH 09/20] Simpify Q functions with forwarders --- d3rlpy/algos/qlearning/awac.py | 17 +- d3rlpy/algos/qlearning/base.py | 11 +- d3rlpy/algos/qlearning/bcq.py | 40 ++- d3rlpy/algos/qlearning/bear.py | 17 +- d3rlpy/algos/qlearning/cql.py | 34 ++- d3rlpy/algos/qlearning/crr.py | 17 +- d3rlpy/algos/qlearning/ddpg.py | 17 +- d3rlpy/algos/qlearning/dqn.py | 34 ++- d3rlpy/algos/qlearning/iql.py | 17 +- d3rlpy/algos/qlearning/nfq.py | 17 +- d3rlpy/algos/qlearning/plas.py | 34 ++- d3rlpy/algos/qlearning/sac.py | 34 ++- d3rlpy/algos/qlearning/td3.py | 17 +- d3rlpy/algos/qlearning/td3_plus_bc.py | 17 +- d3rlpy/algos/qlearning/torch/awac_impl.py | 21 +- d3rlpy/algos/qlearning/torch/bcq_impl.py | 43 +++- d3rlpy/algos/qlearning/torch/bear_impl.py | 22 +- d3rlpy/algos/qlearning/torch/cql_impl.py | 47 +++- d3rlpy/algos/qlearning/torch/crr_impl.py | 29 ++- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 43 ++-- d3rlpy/algos/qlearning/torch/dqn_impl.py | 40 +-- d3rlpy/algos/qlearning/torch/iql_impl.py | 23 +- d3rlpy/algos/qlearning/torch/plas_impl.py | 35 ++- d3rlpy/algos/qlearning/torch/sac_impl.py | 59 +++-- d3rlpy/algos/qlearning/torch/td3_impl.py | 18 +- .../algos/qlearning/torch/td3_plus_bc_impl.py | 20 +- d3rlpy/algos/qlearning/torch/utility.py | 18 +- d3rlpy/models/builders.py | 42 +-- d3rlpy/models/q_functions.py | 64 +++-- d3rlpy/models/torch/q_functions/base.py | 95 ++++--- .../torch/q_functions/ensemble_q_function.py | 239 +++++++++++------- .../torch/q_functions/iqn_q_function.py | 161 +++++++----- .../torch/q_functions/mean_q_function.py | 107 +++++--- .../models/torch/q_functions/qr_q_function.py | 139 +++++----- d3rlpy/ope/fqe.py | 34 ++- d3rlpy/ope/torch/fqe_impl.py | 54 ++-- tests/models/test_builders.py | 24 +- tests/models/test_q_functions.py | 36 ++- tests/models/torch/model_test.py | 4 +- .../q_functions/test_ensemble_q_function.py | 150 +++++++---- .../torch/q_functions/test_iqn_q_function.py | 153 +++++++++-- .../torch/q_functions/test_mean_q_function.py | 96 +++++-- .../torch/q_functions/test_qr_q_function.py | 153 ++++++++--- tests/models/torch/test_q_functions.py | 4 +- 44 files changed, 1614 insertions(+), 682 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index 9547e324..17cb6fc2 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -106,7 +106,15 @@ def inner_create_impl( use_std_parameter=True, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = 
create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -119,13 +127,16 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) self._impl = AWACImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, policy=policy, actor_optim=actor_optim, critic_optim=critic_optim, diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index c5129bd8..64ec02c7 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -14,6 +14,7 @@ import numpy as np import torch +from torch import nn from tqdm.auto import tqdm, trange from typing_extensions import Self @@ -36,7 +37,7 @@ LoggerAdapterFactory, ) from ...metrics import EvaluatorProtocol, evaluate_qlearning_with_environment -from ...models.torch import EnsembleQFunction, Policy +from ...models.torch import Policy from ...torch_utility import ( TorchMiniBatch, convert_to_torch, @@ -119,15 +120,15 @@ def copy_policy_optim_from(self, impl: "QLearningAlgoImplBase") -> None: sync_optimizer_state(self.policy_optim, impl.policy_optim) @property - def q_function(self) -> EnsembleQFunction: + def q_function(self) -> nn.ModuleList: raise NotImplementedError def copy_q_function_from(self, impl: "QLearningAlgoImplBase") -> None: - q_func = self.q_function.q_funcs[0] - if not isinstance(impl.q_function.q_funcs[0], type(q_func)): + q_func = self.q_function[0] + if not isinstance(impl.q_function[0], type(q_func)): raise ValueError( f"Invalid Q-function type: expected={type(q_func)}," - f"actual={type(impl.q_function.q_funcs[0])}" + f"actual={type(impl.q_function[0])}" ) hard_sync(self.q_function, impl.q_function) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 83139293..a4710379 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -174,7 +174,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -196,7 +204,7 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) imitator_optim = self._config.imitator_optim_factory.create( imitator.parameters(), lr=self._config.imitator_learning_rate @@ -206,7 +214,10 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, actor_optim=actor_optim, critic_optim=critic_optim, @@ -323,7 +334,15 @@ class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]): 
def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -333,14 +352,14 @@ def inner_create_impl( ) # share convolutional layers if observation is pixel - if isinstance(q_func.q_funcs[0].encoder, PixelEncoder): + if isinstance(q_funcs[0].encoder, PixelEncoder): hidden_size = compute_output_size( [observation_shape], - q_func.q_funcs[0].encoder, + q_funcs[0].encoder, device=self._device, ) imitator = CategoricalPolicy( - encoder=q_func.q_funcs[0].encoder, + encoder=q_funcs[0].encoder, hidden_size=hidden_size, action_size=action_size, ) @@ -355,7 +374,7 @@ def inner_create_impl( # TODO: replace this with a cleaner way # retrieve unique elements - q_func_params = list(q_func.parameters()) + q_func_params = list(q_funcs.parameters()) imitator_params = list(imitator.parameters()) unique_dict = {} for param in q_func_params + imitator_params: @@ -368,7 +387,10 @@ def inner_create_impl( self._impl = DiscreteBCQImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, optim=optim, gamma=self._config.gamma, diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 1f8456c9..65ad4d8e 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -164,7 +164,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -194,7 +202,7 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) imitator_optim = self._config.imitator_optim_factory.create( imitator.parameters(), lr=self._config.imitator_learning_rate @@ -210,7 +218,10 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, log_temp=log_temp, log_alpha=log_alpha, diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index a15c5ba6..1410e6a5 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -141,7 +141,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_fowarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + 
n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -162,7 +170,7 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) temp_optim = self._config.temp_optim_factory.create( log_temp.parameters(), lr=self._config.temp_learning_rate @@ -175,7 +183,10 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_fowarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, log_temp=log_temp, log_alpha=log_alpha, actor_optim=actor_optim, @@ -279,7 +290,15 @@ class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -289,13 +308,16 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_func.parameters(), lr=self._config.learning_rate + q_funcs.parameters(), lr=self._config.learning_rate ) self._impl = DiscreteCQLImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=self._config.gamma, alpha=self._config.alpha, diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index 250607d7..a0aaf2ac 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -140,7 +140,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -153,14 +161,17 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) self._impl = CRRImpl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=self._config.gamma, diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 87919a50..92b8a905 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -101,7 +101,15 @@ def inner_create_impl( 
self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -114,14 +122,17 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) self._impl = DDPGImpl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=self._config.gamma, diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 5a15510a..31b8945f 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -68,7 +68,15 @@ class DQN(QLearningAlgoBase[DQNImpl, DQNConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -78,13 +86,16 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_func.parameters(), lr=self._config.learning_rate + q_funcs.parameters(), lr=self._config.learning_rate ) self._impl = DQNImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + q_func_forwarder=forwarder, + targ_q_func_forwarder=targ_forwarder, optim=optim, gamma=self._config.gamma, device=self._device, @@ -161,7 +172,15 @@ class DoubleDQN(DQN): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -171,13 +190,16 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_func.parameters(), lr=self._config.learning_rate + q_funcs.parameters(), lr=self._config.learning_rate ) self._impl = DoubleDQNImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + q_func_forwarder=forwarder, + targ_q_func_forwarder=targ_forwarder, optim=optim, gamma=self._config.gamma, device=self._device, diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index feb39639..a346718e 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -119,7 +119,15 @@ def inner_create_impl( use_std_parameter=True, device=self._device, ) - q_func = 
create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + MeanQFunctionFactory(), + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -136,7 +144,7 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.parameters(), lr=self._config.actor_learning_rate ) - q_func_params = list(q_func.parameters()) + q_func_params = list(q_funcs.parameters()) v_func_params = list(value_func.parameters()) critic_optim = self._config.critic_optim_factory.create( q_func_params + v_func_params, lr=self._config.critic_learning_rate @@ -146,7 +154,10 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, value_func=value_func, actor_optim=actor_optim, critic_optim=critic_optim, diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 6ff0a5b5..e6423cb1 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -70,7 +70,15 @@ class NFQ(QLearningAlgoBase[DQNImpl, NFQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -80,13 +88,16 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_func.parameters(), lr=self._config.learning_rate + q_funcs.parameters(), lr=self._config.learning_rate ) self._impl = DQNImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=self._config.gamma, device=self._device, diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 713e65fc..74a242cd 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -112,7 +112,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -134,7 +142,7 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) imitator_optim = self._config.critic_optim_factory.create( imitator.parameters(), lr=self._config.imitator_learning_rate @@ -144,7 +152,10 @@ def inner_create_impl( observation_shape=observation_shape, 
action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, actor_optim=actor_optim, critic_optim=critic_optim, @@ -242,7 +253,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -273,7 +292,7 @@ def inner_create_impl( params=parameters, lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) imitator_optim = self._config.critic_optim_factory.create( imitator.parameters(), lr=self._config.imitator_learning_rate @@ -283,7 +302,10 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, perturbation=perturbation, actor_optim=actor_optim, diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index e6384736..dc4c0af7 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -125,7 +125,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -143,7 +151,7 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) temp_optim = self._config.temp_optim_factory.create( log_temp.parameters(), lr=self._config.temp_learning_rate @@ -153,7 +161,10 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, log_temp=log_temp, actor_optim=actor_optim, critic_optim=critic_optim, @@ -265,7 +276,15 @@ class DiscreteSAC(QLearningAlgoBase[DiscreteSACImpl, DiscreteSACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -286,7 +305,7 @@ def 
inner_create_impl( ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) actor_optim = self._config.actor_optim_factory.create( policy.parameters(), lr=self._config.actor_learning_rate @@ -298,7 +317,10 @@ def inner_create_impl( self._impl = DiscreteSACImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, policy=policy, log_temp=log_temp, actor_optim=actor_optim, diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 00a5ac66..cfc88747 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -109,7 +109,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -122,14 +130,17 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) self._impl = TD3Impl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=self._config.gamma, diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 231d99f6..25ac1bb0 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -101,7 +101,15 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.critic_encoder_factory, @@ -114,14 +122,17 @@ def inner_create_impl( policy.parameters(), lr=self._config.actor_learning_rate ) critic_optim = self._config.critic_optim_factory.create( - q_func.parameters(), lr=self._config.critic_learning_rate + q_funcs.parameters(), lr=self._config.critic_learning_rate ) self._impl = TD3PlusBCImpl( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=self._config.gamma, diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index 9f44f81e..c0726b98 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -1,10 +1,11 @@ import 
torch import torch.nn.functional as F +from torch import nn from torch.optim import Adam, Optimizer from ....dataset import Shape from ....models.torch import ( - EnsembleContinuousQFunction, + ContinuousEnsembleQFunctionForwarder, NormalPolicy, Parameter, Policy, @@ -25,7 +26,10 @@ def __init__( self, observation_shape: Shape, action_size: int, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, policy: Policy, actor_optim: Optimizer, critic_optim: Optimizer, @@ -40,7 +44,10 @@ def __init__( super().__init__( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, policy=policy, actor_optim=actor_optim, critic_optim=critic_optim, @@ -70,7 +77,9 @@ def _compute_weights( batch_size = obs_t.shape[0] # compute action-value - q_values = self._q_func(obs_t, act_t, "min") + q_values = self._q_func_forwarder.compute_expected_q( + obs_t, act_t, "min" + ) # sample actions # (batch_size * N, action_size) @@ -89,7 +98,9 @@ def _compute_weights( flat_obs_t = repeated_obs_t.reshape(-1, *obs_t.shape[1:]) # compute state-value - flat_v_values = self._q_func(flat_obs_t, flat_actions, "min") + flat_v_values = self._q_func_forwarder.compute_expected_q( + flat_obs_t, flat_actions, "min" + ) reshaped_v_values = flat_v_values.view(obs_t.shape[0], -1, 1) v_values = reshaped_v_values.mean(dim=1) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 4ea28d8b..338af1ee 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -3,15 +3,16 @@ import torch import torch.nn.functional as F +from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( CategoricalPolicy, ConditionalVAE, + ContinuousEnsembleQFunctionForwarder, DeterministicResidualPolicy, - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, + DiscreteEnsembleQFunctionForwarder, compute_discrete_imitation_loss, compute_max_with_n_actions, compute_vae_error, @@ -39,7 +40,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: DeterministicResidualPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, imitator: ConditionalVAE, actor_optim: Optimizer, critic_optim: Optimizer, @@ -56,7 +60,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=gamma, @@ -83,9 +90,10 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: latent=clipped_latent, ) action = self._policy(batch.observations, sampled_action) - return -self._q_func(batch.observations, action.squashed_mu, "none")[ - 0 - ].mean() + value = self._q_func_forwarder.compute_expected_q( + batch.observations, action.squashed_mu, "none" + ) + return -value[0].mean() @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: @@ -143,7 +151,9 @@ def _predict_value( # (batch_size, 
n, action_size) -> (batch_size * n, action_size) flattend_action = action.view(-1, self.action_size) # estimate values - return self._q_func(flattened_x, flattend_action, "none") + return self._q_func_forwarder.compute_expected_q( + flattened_x, flattend_action, "none" + ) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: # TODO: this seems to be slow with image observation @@ -164,7 +174,10 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: actions = self._sample_repeated_action(repeated_x, True) values = compute_max_with_n_actions( - batch.next_observations, actions, self._targ_q_func, self._lam + batch.next_observations, + actions, + self._targ_q_func_forwarder, + self._lam, ) return values @@ -179,7 +192,10 @@ def __init__( self, observation_shape: Shape, action_size: int, - q_func: EnsembleDiscreteQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, imitator: CategoricalPolicy, optim: Optimizer, gamma: float, @@ -190,7 +206,10 @@ def __init__( super().__init__( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=gamma, device=device, @@ -216,7 +235,7 @@ def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: log_probs = F.log_softmax(dist.logits, dim=1) ratio = log_probs - log_probs.max(dim=1, keepdim=True).values mask = (ratio > math.log(self._action_flexibility)).float() - value = self._q_func(x) + value = self._q_func_forwarder.compute_expected_q(x) # add a small constant value to deal with the case where the all # actions except the min value are masked normalized_value = value - value.min(dim=1, keepdim=True).values + 1e-5 diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 0b79a110..2f6b5f4a 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -1,12 +1,13 @@ from typing import Dict import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( ConditionalVAE, - EnsembleContinuousQFunction, + ContinuousEnsembleQFunctionForwarder, NormalPolicy, Parameter, build_squashed_gaussian_distribution, @@ -54,7 +55,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: NormalPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, imitator: ConditionalVAE, log_temp: Parameter, log_alpha: Parameter, @@ -79,7 +83,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, log_temp=log_temp, actor_optim=actor_optim, critic_optim=critic_optim, @@ -218,7 +225,10 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: self._n_target_samples ) values, indices = compute_max_with_n_actions_and_indices( - batch.next_observations, actions, self._targ_q_func, self._lam + batch.next_observations, + actions, + self._targ_q_func_forwarder, + self._lam, ) # (batch, n, 1) -> (batch, 1) @@ -245,7 +255,9 @@ 
def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: flat_x = repeated_x.reshape(-1, *x.shape[1:]) # (batch * n, 1) - flat_values = self._q_func(flat_x, flat_actions, "none")[0] + flat_values = self._q_func_forwarder.compute_expected_q( + flat_x, flat_actions, "none" + )[0] # (batch, n) values = flat_values.view(x.shape[0], self._n_action_samples) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 6675d978..8f4e83ff 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -3,12 +3,13 @@ import torch import torch.nn.functional as F +from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, + ContinuousEnsembleQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, NormalPolicy, Parameter, build_squashed_gaussian_distribution, @@ -33,7 +34,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: NormalPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, log_temp: Parameter, log_alpha: Parameter, actor_optim: Optimizer, @@ -52,7 +56,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, log_temp=log_temp, actor_optim=actor_optim, critic_optim=critic_optim, @@ -80,7 +87,7 @@ def compute_critic_loss( @train_api def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability - self._q_func.eval() + self._q_funcs.eval() self._alpha_optim.zero_grad() @@ -121,7 +128,9 @@ def _compute_policy_is_values( flat_policy_acts = policy_actions.reshape(-1, self.action_size) # estimate action-values for policy actions - policy_values = self._q_func(flat_obs, flat_policy_acts, "none") + policy_values = self._q_func_forwarder.compute_expected_q( + flat_obs, flat_policy_acts, "none" + ) policy_values = policy_values.view( -1, obs_shape[0], self._n_action_samples ) @@ -142,7 +151,9 @@ def _compute_random_is_values(self, obs: torch.Tensor) -> torch.Tensor: flat_shape = (obs.shape[0] * self._n_action_samples, self._action_size) zero_tensor = torch.zeros(flat_shape, device=self._device) random_actions = zero_tensor.uniform_(-1.0, 1.0) - random_values = self._q_func(flat_obs, random_actions, "none") + random_values = self._q_func_forwarder.compute_expected_q( + flat_obs, random_actions, "none" + ) random_values = random_values.view( -1, obs.shape[0], self._n_action_samples ) @@ -166,7 +177,9 @@ def _compute_conservative_loss( logsumexp = torch.logsumexp(target_values, dim=2, keepdim=True) # estimate action-values for data actions - data_values = self._q_func(obs_t, act_t, "none") + data_values = self._q_func_forwarder.compute_expected_q( + obs_t, act_t, "none" + ) loss = logsumexp.mean(dim=0).mean() - data_values.mean(dim=0).mean() scaled_loss = self._conservative_weight * loss @@ -188,7 +201,7 @@ def _compute_deterministic_target( ) -> torch.Tensor: with torch.no_grad(): action = self._policy(batch.next_observations).squashed_mu - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, action, 
reduction="min", @@ -202,7 +215,10 @@ def __init__( self, observation_shape: Shape, action_size: int, - q_func: EnsembleDiscreteQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, optim: Optimizer, gamma: float, alpha: float, @@ -211,7 +227,10 @@ def __init__( super().__init__( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=gamma, device=device, @@ -222,12 +241,12 @@ def _compute_conservative_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor ) -> torch.Tensor: # compute logsumexp - policy_values = self._q_func(obs_t) - logsumexp = torch.logsumexp(policy_values, dim=1, keepdim=True) + values = self._q_func_forwarder.compute_expected_q(obs_t) + logsumexp = torch.logsumexp(values, dim=1, keepdim=True) # estimate action-values under data distribution one_hot = F.one_hot(act_t.view(-1), num_classes=self.action_size) - data_values = (self._q_func(obs_t) * one_hot).sum(dim=1, keepdim=True) + data_values = (values * one_hot).sum(dim=1, keepdim=True) return (logsumexp - data_values).mean() diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 14015bf7..82e5e603 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -1,10 +1,11 @@ import torch import torch.nn.functional as F +from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( - EnsembleContinuousQFunction, + ContinuousEnsembleQFunctionForwarder, NormalPolicy, build_gaussian_distribution, ) @@ -28,7 +29,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: NormalPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, actor_optim: Optimizer, critic_optim: Optimizer, gamma: float, @@ -44,7 +48,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=gamma, @@ -97,7 +104,9 @@ def _compute_advantage( # (batch_size, N, obs_size) -> (batch_size * N, obs_size) flat_obs_t = repeated_obs_t.reshape(-1, *obs_t.shape[1:]) - flat_values = self._q_func(flat_obs_t, flat_actions) + flat_values = self._q_func_forwarder.compute_expected_q( + flat_obs_t, flat_actions + ) reshaped_values = flat_values.view(obs_t.shape[0], -1, 1) if self._advantage_type == "mean": @@ -109,14 +118,16 @@ def _compute_advantage( f"invalid advantage type: {self._advantage_type}." 
) - return self._q_func(obs_t, act_t) - values + return ( + self._q_func_forwarder.compute_expected_q(obs_t, act_t) - values + ) def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): action = build_gaussian_distribution( self._targ_policy(batch.next_observations) ).sample() - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, action.clamp(-1.0, 1.0), reduction="min", @@ -141,7 +152,9 @@ def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: flat_obs_t = repeated_obs_t.reshape(-1, *x.shape[1:]) # (batch_size * N, 1) - flat_values = self._q_func(flat_obs_t, flat_actions) + flat_values = self._q_func_forwarder.compute_expected_q( + flat_obs_t, flat_actions + ) # (batch_size * N, 1) -> (batch_size, N) reshaped_values = flat_values.view(x.shape[0], -1) @@ -156,7 +169,7 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return dist.sample() def sync_critic_target(self) -> None: - hard_sync(self._targ_q_func, self._q_func) + hard_sync(self._targ_q_funcs, self._q_funcs) def sync_actor_target(self) -> None: hard_sync(self._targ_policy, self._policy) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index a4b9b6f3..e6a2d2fa 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -3,15 +3,12 @@ from typing import Dict import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape -from ....models.torch import ( - EnsembleContinuousQFunction, - EnsembleQFunction, - Policy, -) -from ....torch_utility import TorchMiniBatch, soft_sync, train_api +from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy +from ....torch_utility import TorchMiniBatch, hard_sync, soft_sync, train_api from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin @@ -23,9 +20,11 @@ class DDPGBaseImpl( ): _gamma: float _tau: float - _q_func: EnsembleContinuousQFunction + _q_funcs: nn.ModuleList + _q_func_forwarder: ContinuousEnsembleQFunctionForwarder _policy: Policy - _targ_q_func: EnsembleContinuousQFunction + _targ_q_funcs: nn.ModuleList + _targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder _targ_policy: Policy _actor_optim: Optimizer _critic_optim: Optimizer @@ -35,7 +34,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: Policy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, actor_optim: Optimizer, critic_optim: Optimizer, gamma: float, @@ -50,11 +52,14 @@ def __init__( self._gamma = gamma self._tau = tau self._policy = policy - self._q_func = q_func + self._q_funcs = q_funcs + self._q_func_forwarder = q_func_forwarder self._actor_optim = actor_optim self._critic_optim = critic_optim - self._targ_q_func = copy.deepcopy(q_func) + self._targ_q_funcs = targ_q_funcs + self._targ_q_func_forwarder = targ_q_func_forwarder self._targ_policy = copy.deepcopy(policy) + hard_sync(self._targ_q_funcs, self._q_funcs) @train_api def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: @@ -72,7 +77,7 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> torch.Tensor: - return self._q_func.compute_error( + return 
self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, @@ -84,7 +89,7 @@ def compute_critic_loss( @train_api def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability - self._q_func.eval() + self._q_funcs.eval() self._actor_optim.zero_grad() @@ -111,7 +116,7 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: pass def update_critic_target(self) -> None: - soft_sync(self._targ_q_func, self._q_func, self._tau) + soft_sync(self._targ_q_funcs, self._q_funcs, self._tau) def update_actor_target(self) -> None: soft_sync(self._targ_policy, self._policy, self._tau) @@ -125,8 +130,8 @@ def policy_optim(self) -> Optimizer: return self._actor_optim @property - def q_function(self) -> EnsembleQFunction: - return self._q_func + def q_function(self) -> nn.ModuleList: + return self._q_funcs @property def q_function_optim(self) -> Optimizer: @@ -136,13 +141,15 @@ def q_function_optim(self) -> Optimizer: class DDPGImpl(DDPGBaseImpl): def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: action = self._policy(batch.observations) - q_t = self._q_func(batch.observations, action.squashed_mu, "none")[0] + q_t = self._q_func_forwarder.compute_expected_q( + batch.observations, action.squashed_mu, "none" + )[0] return -q_t.mean() def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): action = self._targ_policy(batch.next_observations) - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, action.squashed_mu.clamp(-1.0, 1.0), reduction="min", diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index afc5d30e..54d4cbf5 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -1,11 +1,11 @@ -import copy from typing import Dict import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape -from ....models.torch import EnsembleDiscreteQFunction, EnsembleQFunction +from ....models.torch import DiscreteEnsembleQFunctionForwarder from ....torch_utility import TorchMiniBatch, hard_sync, train_api from ..base import QLearningAlgoImplBase from .utility import DiscreteQFunctionMixin @@ -15,15 +15,20 @@ class DQNImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _gamma: float - _q_func: EnsembleDiscreteQFunction - _targ_q_func: EnsembleDiscreteQFunction + _q_funcs: nn.ModuleList + _targ_q_funcs: nn.ModuleList + _q_func_forwarder: DiscreteEnsembleQFunctionForwarder + _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder _optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - q_func: EnsembleDiscreteQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, optim: Optimizer, gamma: float, device: str, @@ -34,9 +39,12 @@ def __init__( device=device, ) self._gamma = gamma - self._q_func = q_func + self._q_funcs = q_funcs + self._q_func_forwarder = q_func_forwarder + self._targ_q_funcs = targ_q_funcs + self._targ_q_func_forwarder = targ_q_func_forwarder self._optim = optim - self._targ_q_func = copy.deepcopy(q_func) + hard_sync(targ_q_funcs, q_funcs) @train_api def update(self, batch: TorchMiniBatch) -> Dict[str, float]: @@ -56,7 +64,7 @@ def compute_loss( batch: TorchMiniBatch, q_tpn: torch.Tensor, ) 
-> torch.Tensor: - return self._q_func.compute_error( + return self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions.long(), rewards=batch.rewards, @@ -67,26 +75,28 @@ def compute_loss( def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - next_actions = self._targ_q_func(batch.next_observations) + next_actions = self._targ_q_func_forwarder.compute_expected_q( + batch.next_observations + ) max_action = next_actions.argmax(dim=1) - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, max_action, reduction="min", ) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._q_func(x).argmax(dim=1) + return self._q_func_forwarder.compute_expected_q(x).argmax(dim=1) def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) def update_target(self) -> None: - hard_sync(self._targ_q_func, self._q_func) + hard_sync(self._targ_q_funcs, self._q_funcs) @property - def q_function(self) -> EnsembleQFunction: - return self._q_func + def q_function(self) -> nn.ModuleList: + return self._q_funcs @property def q_function_optim(self) -> Optimizer: @@ -97,7 +107,7 @@ class DoubleDQNImpl(DQNImpl): def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): action = self.inner_predict_best_action(batch.next_observations) - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, action, reduction="min", diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 12dcb7e7..80d8d806 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -1,11 +1,12 @@ from typing import Dict import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( - EnsembleContinuousQFunction, + ContinuousEnsembleQFunctionForwarder, NormalPolicy, ValueFunction, build_gaussian_distribution, @@ -28,7 +29,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: NormalPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, value_func: ValueFunction, actor_optim: Optimizer, critic_optim: Optimizer, @@ -43,7 +47,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=gamma, @@ -58,7 +65,7 @@ def __init__( def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> torch.Tensor: - return self._q_func.compute_error( + return self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, @@ -83,13 +90,17 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: return -(weight * log_probs).mean() def _compute_weight(self, batch: TorchMiniBatch) -> torch.Tensor: - q_t = self._targ_q_func(batch.observations, batch.actions, "min") + q_t = self._targ_q_func_forwarder.compute_expected_q( + batch.observations, batch.actions, "min" + ) v_t = self._value_func(batch.observations) adv = q_t - v_t 
return (self._weight_temp * adv).exp().clamp(max=self._max_weight) def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - q_t = self._targ_q_func(batch.observations, batch.actions, "min") + q_t = self._targ_q_func_forwarder.compute_expected_q( + batch.observations, batch.actions, "min" + ) v_t = self._value_func(batch.observations) diff = q_t.detach() - v_t weight = (self._expectile - (diff < 0.0).float()).abs().detach() diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index cdd11ae6..817754b2 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -2,14 +2,15 @@ from typing import Dict import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( ConditionalVAE, + ContinuousEnsembleQFunctionForwarder, DeterministicPolicy, DeterministicResidualPolicy, - EnsembleContinuousQFunction, compute_vae_error, forward_vae_decode, ) @@ -32,7 +33,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: DeterministicPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, imitator: ConditionalVAE, actor_optim: Optimizer, critic_optim: Optimizer, @@ -47,7 +51,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=gamma, @@ -80,7 +87,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: actions = forward_vae_decode( self._imitator, batch.observations, latent_actions ) - return -self._q_func(batch.observations, actions, "none")[0].mean() + return -self._q_func_forwarder.compute_expected_q( + batch.observations, actions, "none" + )[0].mean() def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: latent_actions = 2.0 * self._policy(x).squashed_mu @@ -97,7 +106,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: actions = forward_vae_decode( self._imitator, batch.next_observations, latent_actions ) - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, actions, "mix", @@ -114,7 +123,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: DeterministicPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, imitator: ConditionalVAE, perturbation: DeterministicResidualPolicy, actor_optim: Optimizer, @@ -130,7 +142,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, actor_optim=actor_optim, critic_optim=critic_optim, @@ -152,7 +167,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: residual_actions = self._perturbation( batch.observations, actions ).squashed_mu - q_value = self._q_func(batch.observations, residual_actions, "none") + q_value = 
self._q_func_forwarder.compute_expected_q(
+            batch.observations, residual_actions, "none"
+        )
         return -q_value[0].mean()
 
     def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
@@ -174,7 +191,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
             residual_actions = self._targ_perturbation(
                 batch.next_observations, actions
             )
-            return self._targ_q_func.compute_target(
+            return self._targ_q_func_forwarder.compute_target(
                 batch.next_observations,
                 residual_actions.squashed_mu,
                 reduction="mix",
diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py
index 38a3262b..ed1e4391 100644
--- a/d3rlpy/algos/qlearning/torch/sac_impl.py
+++ b/d3rlpy/algos/qlearning/torch/sac_impl.py
@@ -1,17 +1,16 @@
-import copy
 import math
 from typing import Dict
 
 import torch
 import torch.nn.functional as F
+from torch import nn
 from torch.optim import Optimizer
 
 from ....dataset import Shape
 from ....models.torch import (
     CategoricalPolicy,
-    EnsembleContinuousQFunction,
-    EnsembleDiscreteQFunction,
-    EnsembleQFunction,
+    ContinuousEnsembleQFunctionForwarder,
+    DiscreteEnsembleQFunctionForwarder,
     Parameter,
     Policy,
     build_squashed_gaussian_distribution,
@@ -33,7 +32,10 @@ def __init__(
         observation_shape: Shape,
         action_size: int,
         policy: Policy,
-        q_func: EnsembleContinuousQFunction,
+        q_funcs: nn.ModuleList,
+        q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
+        targ_q_funcs: nn.ModuleList,
+        targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
         log_temp: Parameter,
         actor_optim: Optimizer,
         critic_optim: Optimizer,
@@ -46,7 +48,10 @@ def __init__(
             observation_shape=observation_shape,
             action_size=action_size,
             policy=policy,
-            q_func=q_func,
+            q_funcs=q_funcs,
+            q_func_forwarder=q_func_forwarder,
+            targ_q_funcs=targ_q_funcs,
+            targ_q_func_forwarder=targ_q_func_forwarder,
             actor_optim=actor_optim,
             critic_optim=critic_optim,
             gamma=gamma,
@@ -62,7 +67,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
         )
         action, log_prob = dist.sample_with_log_prob()
         entropy = self._log_temp().exp() * log_prob
-        q_t = self._q_func(batch.observations, action, "min")
+        q_t = self._q_func_forwarder.compute_expected_q(
+            batch.observations, action, "min"
+        )
         return (entropy - q_t).mean()
 
     @train_api
@@ -96,7 +103,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
             )
             action, log_prob = dist.sample_with_log_prob()
             entropy = self._log_temp().exp() * log_prob
-            target = self._targ_q_func.compute_target(
+            target = self._targ_q_func_forwarder.compute_target(
                 batch.next_observations,
                 action,
                 reduction="min",
@@ -110,8 +117,10 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor:
 
 class DiscreteSACImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase):
     _policy: CategoricalPolicy
-    _q_func: EnsembleDiscreteQFunction
-    _targ_q_func: EnsembleDiscreteQFunction
+    _q_funcs: nn.ModuleList
+    _q_func_forwarder: DiscreteEnsembleQFunctionForwarder
+    _targ_q_funcs: nn.ModuleList
+    _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder
     _log_temp: Parameter
     _actor_optim: Optimizer
     _critic_optim: Optimizer
@@ -121,7 +130,10 @@ def __init__(
         self,
         observation_shape: Shape,
         action_size: int,
-        q_func: EnsembleDiscreteQFunction,
+        q_funcs: nn.ModuleList,
+        q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        targ_q_funcs: nn.ModuleList,
+        targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
         policy: CategoricalPolicy,
         log_temp: Parameter,
         actor_optim: Optimizer,
@@ -136,13 +148,16 @@ def __init__(
             device=device,
         )
         self._gamma = gamma
- self._q_func = q_func + self._q_funcs = q_funcs + self._q_func_forwarder = q_func_forwarder + self._targ_q_funcs = targ_q_funcs + self._targ_q_func_forwarder = targ_q_func_forwarder self._policy = policy self._log_temp = log_temp self._actor_optim = actor_optim self._critic_optim = critic_optim self._temp_optim = temp_optim - self._targ_q_func = copy.deepcopy(q_func) + hard_sync(targ_q_funcs, q_funcs) @train_api def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: @@ -162,7 +177,9 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: log_probs = dist.logits probs = dist.probs entropy = self._log_temp().exp() * log_probs - target = self._targ_q_func.compute_target(batch.next_observations) + target = self._targ_q_func_forwarder.compute_target( + batch.next_observations + ) keepdims = True if target.dim() == 3: entropy = entropy.unsqueeze(-1) @@ -175,7 +192,7 @@ def compute_critic_loss( batch: TorchMiniBatch, q_tpn: torch.Tensor, ) -> torch.Tensor: - return self._q_func.compute_error( + return self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions.long(), rewards=batch.rewards, @@ -187,7 +204,7 @@ def compute_critic_loss( @train_api def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability - self._q_func.eval() + self._q_funcs.eval() self._actor_optim.zero_grad() @@ -200,7 +217,9 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - q_t = self._q_func(batch.observations, reduction="min") + q_t = self._q_func_forwarder.compute_expected_q( + batch.observations, reduction="min" + ) dist = self._policy(batch.observations) log_probs = dist.logits probs = dist.probs @@ -241,7 +260,7 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return dist.sample() def update_target(self) -> None: - hard_sync(self._targ_q_func, self._q_func) + hard_sync(self._targ_q_funcs, self._q_funcs) @property def policy(self) -> Policy: @@ -252,8 +271,8 @@ def policy_optim(self) -> Optimizer: return self._actor_optim @property - def q_function(self) -> EnsembleQFunction: - return self._q_func + def q_function(self) -> nn.ModuleList: + return self._q_funcs @property def q_function_optim(self) -> Optimizer: diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 5cda68f4..dbd944a7 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -1,8 +1,12 @@ import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape -from ....models.torch import DeterministicPolicy, EnsembleContinuousQFunction +from ....models.torch import ( + ContinuousEnsembleQFunctionForwarder, + DeterministicPolicy, +) from ....torch_utility import TorchMiniBatch from .ddpg_impl import DDPGImpl @@ -18,7 +22,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: DeterministicPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, actor_optim: Optimizer, critic_optim: Optimizer, gamma: float, @@ -31,7 +38,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + 
targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=gamma, @@ -52,7 +62,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: ) smoothed_action = action.squashed_mu + clipped_noise clipped_action = smoothed_action.clamp(-1.0, 1.0) - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, clipped_action, reduction="min", diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index cca9a1a4..9cb75608 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -1,10 +1,14 @@ # pylint: disable=too-many-ancestors import torch +from torch import nn from torch.optim import Optimizer from ....dataset import Shape -from ....models.torch import DeterministicPolicy, EnsembleContinuousQFunction +from ....models.torch import ( + ContinuousEnsembleQFunctionForwarder, + DeterministicPolicy, +) from ....torch_utility import TorchMiniBatch from .td3_impl import TD3Impl @@ -19,7 +23,10 @@ def __init__( observation_shape: Shape, action_size: int, policy: DeterministicPolicy, - q_func: EnsembleContinuousQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, actor_optim: Optimizer, critic_optim: Optimizer, gamma: float, @@ -33,7 +40,10 @@ def __init__( observation_shape=observation_shape, action_size=action_size, policy=policy, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, actor_optim=actor_optim, critic_optim=critic_optim, gamma=gamma, @@ -46,6 +56,8 @@ def __init__( def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: action = self._policy(batch.observations).squashed_mu - q_t = self._q_func(batch.observations, action, "none")[0] + q_t = self._q_func_forwarder.compute_expected_q( + batch.observations, action, "none" + )[0] lam = self._alpha / (q_t.abs().mean()).detach() return lam * -q_t.mean() + ((batch.actions - action) ** 2).mean() diff --git a/d3rlpy/algos/qlearning/torch/utility.py b/d3rlpy/algos/qlearning/torch/utility.py index 9c43cb37..15ed5551 100644 --- a/d3rlpy/algos/qlearning/torch/utility.py +++ b/d3rlpy/algos/qlearning/torch/utility.py @@ -1,30 +1,27 @@ -from typing import Optional - import torch from typing_extensions import Protocol from ....models.torch import ( - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, + ContinuousEnsembleQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, ) __all__ = ["DiscreteQFunctionMixin", "ContinuousQFunctionMixin"] class _DiscreteQFunctionProtocol(Protocol): - _q_func: Optional[EnsembleDiscreteQFunction] + _q_func_forwarder: DiscreteEnsembleQFunctionForwarder class _ContinuousQFunctionProtocol(Protocol): - _q_func: Optional[EnsembleContinuousQFunction] + _q_func_forwarder: ContinuousEnsembleQFunctionForwarder class DiscreteQFunctionMixin: def inner_predict_value( self: _DiscreteQFunctionProtocol, x: torch.Tensor, action: torch.Tensor ) -> torch.Tensor: - assert self._q_func is not None - values = self._q_func(x, reduction="mean") + values = self._q_func_forwarder.compute_expected_q(x, reduction="mean") flat_action = action.reshape(-1) return values[torch.arange(0, x.size(0)), flat_action].reshape(-1) @@ -35,5 +32,6 @@ def 
inner_predict_value( x: torch.Tensor, action: torch.Tensor, ) -> torch.Tensor: - assert self._q_func is not None - return self._q_func(x, action, reduction="mean").reshape(-1) + return self._q_func_forwarder.compute_expected_q( + x, action, reduction="mean" + ).reshape(-1) diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 02e82ccd..6aa93d23 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -1,4 +1,4 @@ -from typing import Sequence, cast +from typing import Sequence, Tuple, cast import torch from torch import nn @@ -10,11 +10,11 @@ CategoricalPolicy, ConditionalVAE, ContinuousDecisionTransformer, + ContinuousEnsembleQFunctionForwarder, DeterministicPolicy, DeterministicResidualPolicy, DiscreteDecisionTransformer, - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, + DiscreteEnsembleQFunctionForwarder, GlobalPositionEncoding, NormalPolicy, Parameter, @@ -48,7 +48,7 @@ def create_discrete_q_function( q_func_factory: QFunctionFactory, device: str, n_ensembles: int = 1, -) -> EnsembleDiscreteQFunction: +) -> Tuple[nn.ModuleList, DiscreteEnsembleQFunctionForwarder]: if q_func_factory.share_encoder: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder, device) @@ -57,18 +57,24 @@ def create_discrete_q_function( p.register_hook(lambda grad: grad / n_ensembles) q_funcs = [] + forwarders = [] for _ in range(n_ensembles): if not q_func_factory.share_encoder: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size( [observation_shape], encoder, device ) - q_funcs.append( - q_func_factory.create_discrete(encoder, hidden_size, action_size) + q_func, forwarder = q_func_factory.create_discrete( + encoder, hidden_size, action_size ) - q_func = EnsembleDiscreteQFunction(q_funcs) - q_func.to(device) - return q_func + q_funcs.append(q_func) + forwarders.append(forwarder) + q_func_modules = nn.ModuleList(q_funcs) + q_func_modules.to(device) + ensemble_forwarder = DiscreteEnsembleQFunctionForwarder( + forwarders, action_size + ) + return q_func_modules, ensemble_forwarder def create_continuous_q_function( @@ -78,7 +84,7 @@ def create_continuous_q_function( q_func_factory: QFunctionFactory, device: str, n_ensembles: int = 1, -) -> EnsembleContinuousQFunction: +) -> Tuple[nn.ModuleList, ContinuousEnsembleQFunctionForwarder]: if q_func_factory.share_encoder: encoder = encoder_factory.create_with_action( observation_shape, action_size @@ -91,6 +97,7 @@ def create_continuous_q_function( p.register_hook(lambda grad: grad / n_ensembles) q_funcs = [] + forwarders = [] for _ in range(n_ensembles): if not q_func_factory.share_encoder: encoder = encoder_factory.create_with_action( @@ -99,12 +106,17 @@ def create_continuous_q_function( hidden_size = compute_output_size( [observation_shape, (action_size,)], encoder, device ) - q_funcs.append( - q_func_factory.create_continuous(encoder, hidden_size, action_size) + q_func, forwarder = q_func_factory.create_continuous( + encoder, hidden_size ) - q_func = EnsembleContinuousQFunction(q_funcs) - q_func.to(device) - return q_func + q_funcs.append(q_func) + forwarders.append(forwarder) + q_func_modules = nn.ModuleList(q_funcs) + q_func_modules.to(device) + ensemble_forwarder = ContinuousEnsembleQFunctionForwarder( + forwarders, action_size + ) + return q_func_modules, ensemble_forwarder def create_deterministic_policy( diff --git a/d3rlpy/models/q_functions.py b/d3rlpy/models/q_functions.py index 1a31628b..55681627 100644 --- 
a/d3rlpy/models/q_functions.py +++ b/d3rlpy/models/q_functions.py @@ -1,15 +1,24 @@ import dataclasses +from typing import Tuple from ..serializable_config import DynamicConfig, generate_config_registration from .torch import ( ContinuousIQNQFunction, + ContinuousIQNQFunctionForwarder, ContinuousMeanQFunction, + ContinuousMeanQFunctionForwarder, ContinuousQFunction, + ContinuousQFunctionForwarder, ContinuousQRQFunction, + ContinuousQRQFunctionForwarder, DiscreteIQNQFunction, + DiscreteIQNQFunctionForwarder, DiscreteMeanQFunction, + DiscreteMeanQFunctionForwarder, DiscreteQFunction, + DiscreteQFunctionForwarder, DiscreteQRQFunction, + DiscreteQRQFunctionForwarder, Encoder, EncoderWithAction, ) @@ -29,7 +38,7 @@ class QFunctionFactory(DynamicConfig): def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int - ) -> DiscreteQFunction: + ) -> Tuple[DiscreteQFunction, DiscreteQFunctionForwarder]: """Returns PyTorch's Q function module. Args: @@ -39,23 +48,22 @@ def create_discrete( action_size: Dimension of discrete action-space. Returns: - discrete Q function object. + Tuple of discrete Q function and its forwarder. """ raise NotImplementedError def create_continuous( - self, encoder: EncoderWithAction, hidden_size: int, action_size: int - ) -> ContinuousQFunction: + self, encoder: EncoderWithAction, hidden_size: int + ) -> Tuple[ContinuousQFunction, ContinuousQFunctionForwarder]: """Returns PyTorch's Q function module. Args: encoder: Encoder module that processes the observation and action to obtain feature representations. hidden_size: Dimension of encoder output. - action_size: Dimension of continuous actions. Returns: - continuous Q function object. + Tuple of continuous Q function and its forwarder. """ raise NotImplementedError @@ -90,16 +98,19 @@ def create_discrete( encoder: Encoder, hidden_size: int, action_size: int, - ) -> DiscreteMeanQFunction: - return DiscreteMeanQFunction(encoder, hidden_size, action_size) + ) -> Tuple[DiscreteMeanQFunction, DiscreteMeanQFunctionForwarder]: + q_func = DiscreteMeanQFunction(encoder, hidden_size, action_size) + forwarder = DiscreteMeanQFunctionForwarder(q_func, action_size) + return q_func, forwarder def create_continuous( self, encoder: EncoderWithAction, hidden_size: int, - action_size: int, - ) -> ContinuousMeanQFunction: - return ContinuousMeanQFunction(encoder, hidden_size, action_size) + ) -> Tuple[ContinuousMeanQFunction, ContinuousMeanQFunctionForwarder]: + q_func = ContinuousMeanQFunction(encoder, hidden_size) + forwarder = ContinuousMeanQFunctionForwarder(q_func) + return q_func, forwarder @staticmethod def get_type() -> str: @@ -123,26 +134,28 @@ class QRQFunctionFactory(QFunctionFactory): def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int - ) -> DiscreteQRQFunction: - return DiscreteQRQFunction( + ) -> Tuple[DiscreteQRQFunction, DiscreteQRQFunctionForwarder]: + q_func = DiscreteQRQFunction( encoder=encoder, hidden_size=hidden_size, action_size=action_size, n_quantiles=self.n_quantiles, ) + forwarder = DiscreteQRQFunctionForwarder(q_func, self.n_quantiles) + return q_func, forwarder def create_continuous( self, encoder: EncoderWithAction, hidden_size: int, - action_size: int, - ) -> ContinuousQRQFunction: - return ContinuousQRQFunction( + ) -> Tuple[ContinuousQRQFunction, ContinuousQRQFunctionForwarder]: + q_func = ContinuousQRQFunction( encoder=encoder, hidden_size=hidden_size, - action_size=action_size, n_quantiles=self.n_quantiles, ) + forwarder = 
ContinuousQRQFunctionForwarder(q_func, self.n_quantiles) + return q_func, forwarder @staticmethod def get_type() -> str: @@ -173,8 +186,8 @@ def create_discrete( encoder: Encoder, hidden_size: int, action_size: int, - ) -> DiscreteIQNQFunction: - return DiscreteIQNQFunction( + ) -> Tuple[DiscreteIQNQFunction, DiscreteIQNQFunctionForwarder]: + q_func = DiscreteIQNQFunction( encoder=encoder, hidden_size=hidden_size, action_size=action_size, @@ -182,18 +195,23 @@ def create_discrete( n_greedy_quantiles=self.n_greedy_quantiles, embed_size=self.embed_size, ) + forwarder = DiscreteIQNQFunctionForwarder(q_func, self.n_quantiles) + return q_func, forwarder def create_continuous( - self, encoder: EncoderWithAction, hidden_size: int, action_size: int - ) -> ContinuousIQNQFunction: - return ContinuousIQNQFunction( + self, encoder: EncoderWithAction, hidden_size: int + ) -> Tuple[ContinuousIQNQFunction, ContinuousIQNQFunctionForwarder]: + q_func = ContinuousIQNQFunction( encoder=encoder, hidden_size=hidden_size, - action_size=action_size, n_quantiles=self.n_quantiles, n_greedy_quantiles=self.n_greedy_quantiles, embed_size=self.embed_size, ) + forwarder = ContinuousIQNQFunctionForwarder( + q_func, self.n_greedy_quantiles + ) + return q_func, forwarder @staticmethod def get_type() -> str: diff --git a/d3rlpy/models/torch/q_functions/base.py b/d3rlpy/models/torch/q_functions/base.py index c3451234..f809350d 100644 --- a/d3rlpy/models/torch/q_functions/base.py +++ b/d3rlpy/models/torch/q_functions/base.py @@ -1,46 +1,49 @@ from abc import ABCMeta, abstractmethod -from typing import Optional +from typing import NamedTuple, Optional import torch +from torch import nn from ..encoders import Encoder, EncoderWithAction -__all__ = ["QFunction", "DiscreteQFunction", "ContinuousQFunction"] +__all__ = [ + "DiscreteQFunction", + "ContinuousQFunction", + "ContinuousQFunctionForwarder", + "DiscreteQFunctionForwarder", + "QFunctionOutput", +] -class QFunction(metaclass=ABCMeta): - @abstractmethod - def compute_error( - self, - observations: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - target: torch.Tensor, - terminals: torch.Tensor, - gamma: float = 0.99, - reduction: str = "mean", - ) -> torch.Tensor: - pass +class QFunctionOutput(NamedTuple): + q_value: torch.Tensor + quantiles: Optional[torch.Tensor] + taus: Optional[torch.Tensor] - @property + +class ContinuousQFunction(nn.Module, metaclass=ABCMeta): # type: ignore @abstractmethod - def action_size(self) -> int: + def forward(self, x: torch.Tensor, action: torch.Tensor) -> QFunctionOutput: pass + def __call__( + self, x: torch.Tensor, action: torch.Tensor + ) -> QFunctionOutput: + return super().__call__(x, action) # type: ignore -class DiscreteQFunction(QFunction): + @property @abstractmethod - def forward(self, x: torch.Tensor) -> torch.Tensor: + def encoder(self) -> EncoderWithAction: pass + +class DiscreteQFunction(nn.Module, metaclass=ABCMeta): # type: ignore @abstractmethod - def compute_target( - self, x: torch.Tensor, action: Optional[torch.Tensor] - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> QFunctionOutput: pass - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return self.forward(x) + def __call__(self, x: torch.Tensor) -> QFunctionOutput: + return super().__call__(x) # type: ignore @property @abstractmethod @@ -48,9 +51,24 @@ def encoder(self) -> Encoder: pass -class ContinuousQFunction(QFunction): +class ContinuousQFunctionForwarder(metaclass=ABCMeta): @abstractmethod - def forward(self, x: torch.Tensor, 
action: torch.Tensor) -> torch.Tensor: + def compute_expected_q( + self, x: torch.Tensor, action: torch.Tensor + ) -> torch.Tensor: + pass + + @abstractmethod + def compute_error( + self, + observations: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + target: torch.Tensor, + terminals: torch.Tensor, + gamma: float = 0.99, + reduction: str = "mean", + ) -> torch.Tensor: pass @abstractmethod @@ -59,10 +77,27 @@ def compute_target( ) -> torch.Tensor: pass - def __call__(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: - return self.forward(x, action) - @property +class DiscreteQFunctionForwarder(metaclass=ABCMeta): @abstractmethod - def encoder(self) -> EncoderWithAction: + def compute_expected_q(self, x: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def compute_error( + self, + observations: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + target: torch.Tensor, + terminals: torch.Tensor, + gamma: float = 0.99, + reduction: str = "mean", + ) -> torch.Tensor: + pass + + @abstractmethod + def compute_target( + self, x: torch.Tensor, action: Optional[torch.Tensor] = None + ) -> torch.Tensor: pass diff --git a/d3rlpy/models/torch/q_functions/ensemble_q_function.py b/d3rlpy/models/torch/q_functions/ensemble_q_function.py index cbff1b76..ce775b04 100644 --- a/d3rlpy/models/torch/q_functions/ensemble_q_function.py +++ b/d3rlpy/models/torch/q_functions/ensemble_q_function.py @@ -1,14 +1,12 @@ -from typing import List, Optional, Tuple, Union, cast +from typing import List, Optional, Sequence, Tuple, Union import torch -from torch import nn -from .base import ContinuousQFunction, DiscreteQFunction +from .base import ContinuousQFunctionForwarder, DiscreteQFunctionForwarder __all__ = [ - "EnsembleQFunction", - "EnsembleDiscreteQFunction", - "EnsembleContinuousQFunction", + "DiscreteEnsembleQFunctionForwarder", + "ContinuousEnsembleQFunctionForwarder", "compute_max_with_n_actions", "compute_max_with_n_actions_and_indices", ] @@ -74,17 +72,93 @@ def _reduce_quantile_ensemble( raise ValueError -class EnsembleQFunction(nn.Module): # type: ignore +def compute_ensemble_q_function_error( + forwarders: Union[ + Sequence[DiscreteQFunctionForwarder], + Sequence[ContinuousQFunctionForwarder], + ], + observations: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + target: torch.Tensor, + terminals: torch.Tensor, + gamma: float = 0.99, +) -> torch.Tensor: + assert target.ndim == 2 + td_sum = torch.tensor( + 0.0, + dtype=torch.float32, + device=observations.device, + ) + for forwarder in forwarders: + loss = forwarder.compute_error( + observations=observations, + actions=actions, + rewards=rewards, + target=target, + terminals=terminals, + gamma=gamma, + reduction="none", + ) + td_sum += loss.mean() + return td_sum + + +def compute_ensemble_q_function_target( + forwarders: Union[ + Sequence[DiscreteQFunctionForwarder], + Sequence[ContinuousQFunctionForwarder], + ], + action_size: int, + x: torch.Tensor, + action: Optional[torch.Tensor] = None, + reduction: str = "min", + lam: float = 0.75, +) -> torch.Tensor: + values_list: List[torch.Tensor] = [] + for forwarder in forwarders: + if isinstance(forwarder, ContinuousQFunctionForwarder): + assert action is not None + target = forwarder.compute_target(x, action) + else: + target = forwarder.compute_target(x, action) + values_list.append(target.reshape(1, x.shape[0], -1)) + + values = torch.cat(values_list, dim=0) + + if action is None: + # mean Q function + if values.shape[2] == action_size: + 
return _reduce_ensemble(values, reduction) + # distributional Q function + n_q_funcs = values.shape[0] + values = values.view(n_q_funcs, x.shape[0], action_size, -1) + return _reduce_quantile_ensemble(values, reduction) + + if values.shape[2] == 1: + return _reduce_ensemble(values, reduction, lam=lam) + + return _reduce_quantile_ensemble(values, reduction, lam=lam) + + +class DiscreteEnsembleQFunctionForwarder: + _forwarders: Sequence[DiscreteQFunctionForwarder] _action_size: int - _q_funcs: nn.ModuleList def __init__( - self, - q_funcs: Union[List[DiscreteQFunction], List[ContinuousQFunction]], + self, forwarders: Sequence[DiscreteQFunctionForwarder], action_size: int ): - super().__init__() - self._action_size = q_funcs[0].action_size - self._q_funcs = nn.ModuleList(q_funcs) + self._forwarders = forwarders + self._action_size = action_size + + def compute_expected_q( + self, x: torch.Tensor, reduction: str = "mean" + ) -> torch.Tensor: + values = [] + for forwarder in self._forwarders: + value = forwarder.compute_expected_q(x) + values.append(value.view(1, x.shape[0], self._action_size)) + return _reduce_ensemble(torch.cat(values, dim=0), reduction) def compute_error( self, @@ -95,92 +169,76 @@ def compute_error( terminals: torch.Tensor, gamma: float = 0.99, ) -> torch.Tensor: - assert target.ndim == 2 - - td_sum = torch.tensor( - 0.0, dtype=torch.float32, device=observations.device + return compute_ensemble_q_function_error( + forwarders=self._forwarders, + observations=observations, + actions=actions, + rewards=rewards, + target=target, + terminals=terminals, + gamma=gamma, ) - for q_func in self._q_funcs: - loss = q_func.compute_error( - observations=observations, - actions=actions, - rewards=rewards, - target=target, - terminals=terminals, - gamma=gamma, - reduction="none", - ) - td_sum += loss.mean() - return td_sum - - def _compute_target( + + def compute_target( self, x: torch.Tensor, action: Optional[torch.Tensor] = None, reduction: str = "min", lam: float = 0.75, ) -> torch.Tensor: - values_list: List[torch.Tensor] = [] - for q_func in self._q_funcs: - target = q_func.compute_target(x, action) - values_list.append(target.reshape(1, x.shape[0], -1)) - - values = torch.cat(values_list, dim=0) - - if action is None: - # mean Q function - if values.shape[2] == self._action_size: - return _reduce_ensemble(values, reduction) - # distributional Q function - n_q_funcs = values.shape[0] - values = values.view(n_q_funcs, x.shape[0], self._action_size, -1) - return _reduce_quantile_ensemble(values, reduction) - - if values.shape[2] == 1: - return _reduce_ensemble(values, reduction, lam=lam) - - return _reduce_quantile_ensemble(values, reduction, lam=lam) + return compute_ensemble_q_function_target( + forwarders=self._forwarders, + action_size=self._action_size, + x=x, + action=action, + reduction=reduction, + lam=lam, + ) @property - def q_funcs(self) -> nn.ModuleList: - return self._q_funcs - + def forwarders(self) -> Sequence[DiscreteQFunctionForwarder]: + return self._forwarders -class EnsembleDiscreteQFunction(EnsembleQFunction): - def forward(self, x: torch.Tensor, reduction: str = "mean") -> torch.Tensor: - values = [] - for q_func in self._q_funcs: - values.append(q_func(x).view(1, x.shape[0], self._action_size)) - return _reduce_ensemble(torch.cat(values, dim=0), reduction) - def __call__( - self, x: torch.Tensor, reduction: str = "mean" - ) -> torch.Tensor: - return cast(torch.Tensor, super().__call__(x, reduction)) +class ContinuousEnsembleQFunctionForwarder: + _forwarders: 
Sequence[ContinuousQFunctionForwarder] + _action_size: int - def compute_target( + def __init__( self, - x: torch.Tensor, - action: Optional[torch.Tensor] = None, - reduction: str = "min", - lam: float = 0.75, - ) -> torch.Tensor: - return self._compute_target(x, action, reduction, lam) - + forwarders: Sequence[ContinuousQFunctionForwarder], + action_size: int, + ): + self._forwarders = forwarders + self._action_size = action_size -class EnsembleContinuousQFunction(EnsembleQFunction): - def forward( + def compute_expected_q( self, x: torch.Tensor, action: torch.Tensor, reduction: str = "mean" ) -> torch.Tensor: values = [] - for q_func in self._q_funcs: - values.append(q_func(x, action).view(1, x.shape[0], 1)) + for forwarder in self._forwarders: + value = forwarder.compute_expected_q(x, action) + values.append(value.view(1, x.shape[0], 1)) return _reduce_ensemble(torch.cat(values, dim=0), reduction) - def __call__( - self, x: torch.Tensor, action: torch.Tensor, reduction: str = "mean" + def compute_error( + self, + observations: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + target: torch.Tensor, + terminals: torch.Tensor, + gamma: float = 0.99, ) -> torch.Tensor: - return cast(torch.Tensor, super().__call__(x, action, reduction)) + return compute_ensemble_q_function_error( + forwarders=self._forwarders, + observations=observations, + actions=actions, + rewards=rewards, + target=target, + terminals=terminals, + gamma=gamma, + ) def compute_target( self, @@ -189,13 +247,24 @@ def compute_target( reduction: str = "min", lam: float = 0.75, ) -> torch.Tensor: - return self._compute_target(x, action, reduction, lam) + return compute_ensemble_q_function_target( + forwarders=self._forwarders, + action_size=self._action_size, + x=x, + action=action, + reduction=reduction, + lam=lam, + ) + + @property + def forwarders(self) -> Sequence[ContinuousQFunctionForwarder]: + return self._forwarders def compute_max_with_n_actions_and_indices( x: torch.Tensor, actions: torch.Tensor, - q_func: EnsembleContinuousQFunction, + forwarder: ContinuousEnsembleQFunctionForwarder, lam: float, ) -> Tuple[torch.Tensor, torch.Tensor]: """Returns weighted target value from sampled actions. @@ -205,7 +274,7 @@ def compute_max_with_n_actions_and_indices( `actions` should be shaped with `(batch, N, dim_action)`. 
""" batch_size = actions.shape[0] - n_critics = len(q_func.q_funcs) + n_critics = len(forwarder.forwarders) n_actions = actions.shape[1] # (batch, observation) -> (batch, n, observation) @@ -216,7 +285,7 @@ def compute_max_with_n_actions_and_indices( flat_actions = actions.reshape(batch_size * n_actions, -1) # estimate values while taking care of quantiles - flat_values = q_func.compute_target(flat_x, flat_actions, "none") + flat_values = forwarder.compute_target(flat_x, flat_actions, "none") # reshape to (n_ensembles, batch_size, n, -1) transposed_values = flat_values.view(n_critics, batch_size, n_actions, -1) # (n_ensembles, batch_size, n, -1) -> (batch_size, n_ensembles, n, -1) @@ -254,7 +323,7 @@ def compute_max_with_n_actions_and_indices( def compute_max_with_n_actions( x: torch.Tensor, actions: torch.Tensor, - q_func: EnsembleContinuousQFunction, + forwarder: ContinuousEnsembleQFunctionForwarder, lam: float, ) -> torch.Tensor: - return compute_max_with_n_actions_and_indices(x, actions, q_func, lam)[0] + return compute_max_with_n_actions_and_indices(x, actions, forwarder, lam)[0] diff --git a/d3rlpy/models/torch/q_functions/iqn_q_function.py b/d3rlpy/models/torch/q_functions/iqn_q_function.py index 9e64eb37..8dbed24a 100644 --- a/d3rlpy/models/torch/q_functions/iqn_q_function.py +++ b/d3rlpy/models/torch/q_functions/iqn_q_function.py @@ -1,34 +1,45 @@ import math -from typing import Optional, cast +from typing import Optional import torch from torch import nn from ..encoders import Encoder, EncoderWithAction -from .base import ContinuousQFunction, DiscreteQFunction +from .base import ( + ContinuousQFunction, + ContinuousQFunctionForwarder, + DiscreteQFunction, + DiscreteQFunctionForwarder, + QFunctionOutput, +) from .utility import ( compute_quantile_loss, compute_reduce, pick_quantile_value_by_action, ) -__all__ = ["DiscreteIQNQFunction", "ContinuousIQNQFunction"] +__all__ = [ + "DiscreteIQNQFunction", + "ContinuousIQNQFunction", + "DiscreteIQNQFunctionForwarder", + "ContinuousIQNQFunctionForwarder", +] def _make_taus( - h: torch.Tensor, n_quantiles: int, training: bool + batch_size: int, n_quantiles: int, training: bool, device: torch.device ) -> torch.Tensor: if training: - taus = torch.rand(h.shape[0], n_quantiles, device=h.device) + taus = torch.rand(batch_size, n_quantiles, device=device) else: taus = torch.linspace( start=0, end=1, steps=n_quantiles, - device=h.device, + device=device, dtype=torch.float32, ) - taus = taus.view(1, -1).repeat(h.shape[0], 1) + taus = taus.view(1, -1).repeat(batch_size, 1) return taus @@ -49,7 +60,7 @@ def compute_iqn_feature( return h.view(h.shape[0], 1, -1) * phi -class DiscreteIQNQFunction(DiscreteQFunction, nn.Module): # type: ignore +class DiscreteIQNQFunction(DiscreteQFunction): _action_size: int _encoder: Encoder _fc: nn.Linear @@ -76,26 +87,46 @@ def __init__( self._embed_size = embed_size self._embed = nn.Linear(embed_size, hidden_size) - def _make_taus(self, h: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> QFunctionOutput: + h = self._encoder(x) + if self.training: n_quantiles = self._n_quantiles else: n_quantiles = self._n_greedy_quantiles - return _make_taus(h, n_quantiles, self.training) + taus = _make_taus( + batch_size=x.shape[0], + n_quantiles=n_quantiles, + training=self.training, + device=x.device, + ) - def _compute_quantiles( - self, h: torch.Tensor, taus: torch.Tensor - ) -> torch.Tensor: - # element-wise product on feature and phi (batch, quantile, feature) + # (batch, quantile, feature) prod = 
compute_iqn_feature(h, taus, self._embed, self._embed_size) - # (batch, quantile, feature) -> (batch, action, quantile) - return cast(torch.Tensor, self._fc(prod)).transpose(1, 2) + # (batch, quantile, action) -> (batch, action, quantile) + quantiles = self._fc(prod).transpose(1, 2) + + return QFunctionOutput( + q_value=quantiles.mean(dim=2), + quantiles=quantiles, + taus=taus, + ) + + @property + def encoder(self) -> Encoder: + return self._encoder - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self._encoder(x) - taus = self._make_taus(h) - quantiles = self._compute_quantiles(h, taus) - return quantiles.mean(dim=2) + +class DiscreteIQNQFunctionForwarder(DiscreteQFunctionForwarder): + _q_func: DiscreteIQNQFunction + _n_quantiles: int + + def __init__(self, q_func: DiscreteIQNQFunction, n_quantiles: int): + self._q_func = q_func + self._n_quantiles = n_quantiles + + def compute_expected_q(self, x: torch.Tensor) -> torch.Tensor: + return self._q_func(x).q_value def compute_error( self, @@ -110,9 +141,10 @@ def compute_error( assert target.shape == (observations.shape[0], self._n_quantiles) # extraect quantiles corresponding to act_t - h = self._encoder(observations) - taus = self._make_taus(h) - all_quantiles = self._compute_quantiles(h, taus) + output = self._q_func(observations) + taus = output.taus + all_quantiles = output.quantiles + assert taus is not None and all_quantiles is not None quantiles = pick_quantile_value_by_action(all_quantiles, actions) loss = compute_quantile_loss( @@ -129,24 +161,14 @@ def compute_error( def compute_target( self, x: torch.Tensor, action: Optional[torch.Tensor] = None ) -> torch.Tensor: - h = self._encoder(x) - taus = self._make_taus(h) - quantiles = self._compute_quantiles(h, taus) + quantiles = self._q_func(x).quantiles + assert quantiles is not None if action is None: return quantiles return pick_quantile_value_by_action(quantiles, action) - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> Encoder: - return self._encoder - class ContinuousIQNQFunction(ContinuousQFunction, nn.Module): # type: ignore - _action_size: int _encoder: EncoderWithAction _fc: nn.Linear _n_quantiles: int @@ -158,40 +180,60 @@ def __init__( self, encoder: EncoderWithAction, hidden_size: int, - action_size: int, n_quantiles: int, n_greedy_quantiles: int, embed_size: int, ): super().__init__() self._encoder = encoder - self._action_size = action_size self._fc = nn.Linear(hidden_size, 1) self._n_quantiles = n_quantiles self._n_greedy_quantiles = n_greedy_quantiles self._embed_size = embed_size self._embed = nn.Linear(embed_size, hidden_size) - def _make_taus(self, h: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, action: torch.Tensor) -> QFunctionOutput: + h = self._encoder(x, action) + if self.training: n_quantiles = self._n_quantiles else: n_quantiles = self._n_greedy_quantiles - return _make_taus(h, n_quantiles, self.training) + taus = _make_taus( + batch_size=x.shape[0], + n_quantiles=n_quantiles, + training=self.training, + device=x.device, + ) - def _compute_quantiles( - self, h: torch.Tensor, taus: torch.Tensor - ) -> torch.Tensor: # element-wise product on feature and phi (batch, quantile, feature) prod = compute_iqn_feature(h, taus, self._embed, self._embed_size) # (batch, quantile, feature) -> (batch, quantile) - return cast(torch.Tensor, self._fc(prod)).view(h.shape[0], -1) + quantiles = self._fc(prod).view(h.shape[0], -1) + + return QFunctionOutput( + q_value=quantiles.mean(dim=1, 
keepdim=True), + quantiles=quantiles, + taus=taus, + ) + + @property + def encoder(self) -> EncoderWithAction: + return self._encoder - def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: - h = self._encoder(x, action) - taus = self._make_taus(h) - quantiles = self._compute_quantiles(h, taus) - return quantiles.mean(dim=1, keepdim=True) + +class ContinuousIQNQFunctionForwarder(ContinuousQFunctionForwarder): + _q_func: ContinuousIQNQFunction + _n_quantiles: int + + def __init__(self, q_func: ContinuousIQNQFunction, n_quantiles: int): + self._q_func = q_func + self._n_quantiles = n_quantiles + + def compute_expected_q( + self, x: torch.Tensor, action: torch.Tensor + ) -> torch.Tensor: + return self._q_func(x, action).q_value def compute_error( self, @@ -205,9 +247,10 @@ def compute_error( ) -> torch.Tensor: assert target.shape == (observations.shape[0], self._n_quantiles) - h = self._encoder(observations, actions) - taus = self._make_taus(h) - quantiles = self._compute_quantiles(h, taus) + output = self._q_func(observations, actions) + taus = output.taus + quantiles = output.quantiles + assert taus is not None and quantiles is not None loss = compute_quantile_loss( quantiles=quantiles, @@ -223,14 +266,6 @@ def compute_error( def compute_target( self, x: torch.Tensor, action: torch.Tensor ) -> torch.Tensor: - h = self._encoder(x, action) - taus = self._make_taus(h) - return self._compute_quantiles(h, taus) - - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> EncoderWithAction: - return self._encoder + quantiles = self._q_func(x, action).quantiles + assert quantiles is not None + return quantiles diff --git a/d3rlpy/models/torch/q_functions/mean_q_function.py b/d3rlpy/models/torch/q_functions/mean_q_function.py index d3d5e7aa..43593889 100644 --- a/d3rlpy/models/torch/q_functions/mean_q_function.py +++ b/d3rlpy/models/torch/q_functions/mean_q_function.py @@ -1,29 +1,58 @@ -from typing import Optional, cast +from typing import Optional import torch import torch.nn.functional as F from torch import nn from ..encoders import Encoder, EncoderWithAction -from .base import ContinuousQFunction, DiscreteQFunction +from .base import ( + ContinuousQFunction, + ContinuousQFunctionForwarder, + DiscreteQFunction, + DiscreteQFunctionForwarder, + QFunctionOutput, +) from .utility import compute_huber_loss, compute_reduce, pick_value_by_action -__all__ = ["DiscreteMeanQFunction", "ContinuousMeanQFunction"] +__all__ = [ + "DiscreteMeanQFunction", + "ContinuousMeanQFunction", + "DiscreteMeanQFunctionForwarder", + "ContinuousMeanQFunctionForwarder", +] -class DiscreteMeanQFunction(DiscreteQFunction, nn.Module): # type: ignore - _action_size: int +class DiscreteMeanQFunction(DiscreteQFunction): _encoder: Encoder _fc: nn.Linear def __init__(self, encoder: Encoder, hidden_size: int, action_size: int): super().__init__() - self._action_size = action_size self._encoder = encoder self._fc = nn.Linear(hidden_size, action_size) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, self._fc(self._encoder(x))) + def forward(self, x: torch.Tensor) -> QFunctionOutput: + return QFunctionOutput( + q_value=self._fc(self._encoder(x)), + quantiles=None, + taus=None, + ) + + @property + def encoder(self) -> Encoder: + return self._encoder + + +class DiscreteMeanQFunctionForwarder(DiscreteQFunctionForwarder): + _q_func: DiscreteMeanQFunction + _action_size: int + + def __init__(self, q_func: DiscreteMeanQFunction, 
action_size: int): + self._q_func = q_func + self._action_size = action_size + + def compute_expected_q(self, x: torch.Tensor) -> torch.Tensor: + return self._q_func(x).q_value def compute_error( self, @@ -35,8 +64,8 @@ def compute_error( gamma: float = 0.99, reduction: str = "mean", ) -> torch.Tensor: - one_hot = F.one_hot(actions.view(-1), num_classes=self.action_size) - value = (self.forward(observations) * one_hot.float()).sum( + one_hot = F.one_hot(actions.view(-1), num_classes=self._action_size) + value = (self._q_func(observations).q_value * one_hot.float()).sum( dim=1, keepdim=True ) y = rewards + gamma * target * (1 - terminals) @@ -47,33 +76,43 @@ def compute_target( self, x: torch.Tensor, action: Optional[torch.Tensor] = None ) -> torch.Tensor: if action is None: - return self.forward(x) - return pick_value_by_action(self.forward(x), action, keepdim=True) - - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> Encoder: - return self._encoder + return self._q_func(x).q_value + return pick_value_by_action( + self._q_func(x).q_value, action, keepdim=True + ) -class ContinuousMeanQFunction(ContinuousQFunction, nn.Module): # type: ignore +class ContinuousMeanQFunction(ContinuousQFunction): _encoder: EncoderWithAction - _action_size: int _fc: nn.Linear - def __init__( - self, encoder: EncoderWithAction, hidden_size: int, action_size: int - ): + def __init__(self, encoder: EncoderWithAction, hidden_size: int): super().__init__() self._encoder = encoder - self._action_size = action_size self._fc = nn.Linear(hidden_size, 1) - def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, self._fc(self._encoder(x, action))) + def forward(self, x: torch.Tensor, action: torch.Tensor) -> QFunctionOutput: + return QFunctionOutput( + q_value=self._fc(self._encoder(x, action)), + quantiles=None, + taus=None, + ) + + @property + def encoder(self) -> EncoderWithAction: + return self._encoder + + +class ContinuousMeanQFunctionForwarder(ContinuousQFunctionForwarder): + _q_func: ContinuousMeanQFunction + + def __init__(self, q_func: ContinuousMeanQFunction): + self._q_func = q_func + + def compute_expected_q( + self, x: torch.Tensor, action: torch.Tensor + ) -> torch.Tensor: + return self._q_func(x, action).q_value def compute_error( self, @@ -85,7 +124,7 @@ def compute_error( gamma: float = 0.99, reduction: str = "mean", ) -> torch.Tensor: - value = self.forward(observations, actions) + value = self._q_func(observations, actions).q_value y = rewards + gamma * target * (1 - terminals) loss = F.mse_loss(value, y, reduction="none") return compute_reduce(loss, reduction) @@ -93,12 +132,4 @@ def compute_error( def compute_target( self, x: torch.Tensor, action: torch.Tensor ) -> torch.Tensor: - return self.forward(x, action) - - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> EncoderWithAction: - return self._encoder + return self._q_func(x, action).q_value diff --git a/d3rlpy/models/torch/q_functions/qr_q_function.py b/d3rlpy/models/torch/q_functions/qr_q_function.py index 448de90c..601d0567 100644 --- a/d3rlpy/models/torch/q_functions/qr_q_function.py +++ b/d3rlpy/models/torch/q_functions/qr_q_function.py @@ -1,27 +1,38 @@ -from typing import Optional, cast +from typing import Optional import torch from torch import nn from ..encoders import Encoder, EncoderWithAction -from .base import ContinuousQFunction, DiscreteQFunction +from .base import ( + 
ContinuousQFunction, + ContinuousQFunctionForwarder, + DiscreteQFunction, + DiscreteQFunctionForwarder, + QFunctionOutput, +) from .utility import ( compute_quantile_loss, compute_reduce, pick_quantile_value_by_action, ) -__all__ = ["DiscreteQRQFunction", "ContinuousQRQFunction"] +__all__ = [ + "DiscreteQRQFunction", + "ContinuousQRQFunction", + "ContinuousQRQFunctionForwarder", + "DiscreteQRQFunctionForwarder", +] -def _make_taus(h: torch.Tensor, n_quantiles: int) -> torch.Tensor: - steps = torch.arange(n_quantiles, dtype=torch.float32, device=h.device) +def _make_taus(n_quantiles: int, device: torch.device) -> torch.Tensor: + steps = torch.arange(n_quantiles, dtype=torch.float32, device=device) taus = ((steps + 1).float() / n_quantiles).view(1, -1) taus_dot = (steps.float() / n_quantiles).view(1, -1) return (taus + taus_dot) / 2.0 -class DiscreteQRQFunction(DiscreteQFunction, nn.Module): # type: ignore +class DiscreteQRQFunction(DiscreteQFunction): _action_size: int _encoder: Encoder _n_quantiles: int @@ -40,17 +51,30 @@ def __init__( self._n_quantiles = n_quantiles self._fc = nn.Linear(hidden_size, action_size * n_quantiles) - def _compute_quantiles( - self, h: torch.Tensor, taus: torch.Tensor - ) -> torch.Tensor: - h = cast(torch.Tensor, self._fc(h)) - return h.view(-1, self._action_size, self._n_quantiles) + def forward(self, x: torch.Tensor) -> QFunctionOutput: + quantiles = self._fc(self._encoder(x)) + quantiles = quantiles.view(-1, self._action_size, self._n_quantiles) + return QFunctionOutput( + q_value=quantiles.mean(dim=2), + quantiles=quantiles, + taus=_make_taus(self._n_quantiles, device=x.device), + ) + + @property + def encoder(self) -> Encoder: + return self._encoder + - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self._encoder(x) - taus = _make_taus(h, self._n_quantiles) - quantiles = self._compute_quantiles(h, taus) - return quantiles.mean(dim=2) +class DiscreteQRQFunctionForwarder(DiscreteQFunctionForwarder): + _q_func: DiscreteQRQFunction + _n_quantiles: int + + def __init__(self, q_func: DiscreteQRQFunction, n_quantiles: int): + self._q_func = q_func + self._n_quantiles = n_quantiles + + def compute_expected_q(self, x: torch.Tensor) -> torch.Tensor: + return self._q_func(x).q_value def compute_error( self, @@ -65,9 +89,10 @@ def compute_error( assert target.shape == (observations.shape[0], self._n_quantiles) # extraect quantiles corresponding to act_t - h = self._encoder(observations) - taus = _make_taus(h, self._n_quantiles) - all_quantiles = self._compute_quantiles(h, taus) + output = self._q_func(observations) + all_quantiles = output.quantiles + taus = output.taus + assert all_quantiles is not None and taus is not None quantiles = pick_quantile_value_by_action(all_quantiles, actions) loss = compute_quantile_loss( @@ -84,51 +109,54 @@ def compute_error( def compute_target( self, x: torch.Tensor, action: Optional[torch.Tensor] = None ) -> torch.Tensor: - h = self._encoder(x) - taus = _make_taus(h, self._n_quantiles) - quantiles = self._compute_quantiles(h, taus) + quantiles = self._q_func(x).quantiles + assert quantiles is not None if action is None: return quantiles return pick_quantile_value_by_action(quantiles, action) - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> Encoder: - return self._encoder - -class ContinuousQRQFunction(ContinuousQFunction, nn.Module): # type: ignore - _action_size: int +class ContinuousQRQFunction(ContinuousQFunction): _encoder: EncoderWithAction - _n_quantiles: int 
_fc: nn.Linear + _n_quantiles: int def __init__( self, encoder: EncoderWithAction, hidden_size: int, - action_size: int, n_quantiles: int, ): super().__init__() self._encoder = encoder - self._action_size = action_size - self._n_quantiles = n_quantiles self._fc = nn.Linear(hidden_size, n_quantiles) + self._n_quantiles = n_quantiles - def _compute_quantiles( - self, h: torch.Tensor, taus: torch.Tensor - ) -> torch.Tensor: - return cast(torch.Tensor, self._fc(h)) + def forward(self, x: torch.Tensor, action: torch.Tensor) -> QFunctionOutput: + quantiles = self._fc(self._encoder(x, action)) + return QFunctionOutput( + q_value=quantiles.mean(dim=1, keepdim=True), + quantiles=quantiles, + taus=_make_taus(self._n_quantiles, device=x.device), + ) - def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: - h = self._encoder(x, action) - taus = _make_taus(h, self._n_quantiles) - quantiles = self._compute_quantiles(h, taus) - return quantiles.mean(dim=1, keepdim=True) + @property + def encoder(self) -> EncoderWithAction: + return self._encoder + + +class ContinuousQRQFunctionForwarder(ContinuousQFunctionForwarder): + _q_func: ContinuousQRQFunction + _n_quantiles: int + + def __init__(self, q_func: ContinuousQRQFunction, n_quantiles: int): + self._q_func = q_func + self._n_quantiles = n_quantiles + + def compute_expected_q( + self, x: torch.Tensor, action: torch.Tensor + ) -> torch.Tensor: + return self._q_func(x, action).q_value def compute_error( self, @@ -142,9 +170,10 @@ def compute_error( ) -> torch.Tensor: assert target.shape == (observations.shape[0], self._n_quantiles) - h = self._encoder(observations, actions) - taus = _make_taus(h, self._n_quantiles) - quantiles = self._compute_quantiles(h, taus) + output = self._q_func(observations, actions) + quantiles = output.quantiles + taus = output.taus + assert quantiles is not None and taus is not None loss = compute_quantile_loss( quantiles=quantiles, @@ -160,14 +189,6 @@ def compute_error( def compute_target( self, x: torch.Tensor, action: torch.Tensor ) -> torch.Tensor: - h = self._encoder(x, action) - taus = _make_taus(h, self._n_quantiles) - return self._compute_quantiles(h, taus) - - @property - def action_size(self) -> int: - return self._action_size - - @property - def encoder(self) -> EncoderWithAction: - return self._encoder + quantiles = self._q_func(x, action).quantiles + assert quantiles is not None + return quantiles diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index 4b2a413d..d6a196a2 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -156,7 +156,15 @@ class FQE(_FQEBase): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_continuous_q_function( + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -165,12 +173,15 @@ def inner_create_impl( device=self._device, ) optim = self._config.optim_factory.create( - q_func.parameters(), lr=self._config.learning_rate + q_funcs.parameters(), lr=self._config.learning_rate ) self._impl = FQEImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, 
gamma=self._config.gamma, device=self._device, @@ -212,7 +223,15 @@ class DiscreteFQE(_FQEBase): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - q_func = create_discrete_q_function( + q_funcs, q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, action_size, self._config.encoder_factory, @@ -221,12 +240,15 @@ def inner_create_impl( device=self._device, ) optim = self._config.optim_factory.create( - q_func.parameters(), lr=self._config.learning_rate + q_funcs.parameters(), lr=self._config.learning_rate ) self._impl = DiscreteFQEImpl( observation_shape=observation_shape, action_size=action_size, - q_func=q_func, + q_funcs=q_funcs, + q_func_forwarder=q_func_forwarder, + targ_q_funcs=targ_q_funcs, + targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=self._config.gamma, device=self._device, diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index 49ac273c..cab2f4f0 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -1,6 +1,7 @@ -import copy +from typing import Union import torch +from torch import nn from torch.optim import Optimizer from ...algos.qlearning.base import QLearningAlgoImplBase @@ -10,9 +11,8 @@ ) from ...dataset import Shape from ...models.torch import ( - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, - EnsembleQFunction, + ContinuousEnsembleQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, ) from ...torch_utility import TorchMiniBatch, hard_sync, train_api @@ -21,15 +21,30 @@ class FQEBaseImpl(QLearningAlgoImplBase): _gamma: float - _q_func: EnsembleQFunction - _targ_q_func: EnsembleQFunction + _q_funcs: nn.ModuleList + _q_func_forwarder: Union[ + DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder + ] + _targ_q_funcs: nn.ModuleList + _targ_q_func_forwarder: Union[ + DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder + ] _optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - q_func: EnsembleQFunction, + q_funcs: nn.ModuleList, + q_func_forwarder: Union[ + DiscreteEnsembleQFunctionForwarder, + ContinuousEnsembleQFunctionForwarder, + ], + targ_q_funcs: nn.ModuleList, + targ_q_func_forwarder: Union[ + DiscreteEnsembleQFunctionForwarder, + ContinuousEnsembleQFunctionForwarder, + ], optim: Optimizer, gamma: float, device: str, @@ -40,9 +55,12 @@ def __init__( device=device, ) self._gamma = gamma - self._q_func = q_func - self._targ_q_func = copy.deepcopy(q_func) + self._q_funcs = q_funcs + self._q_func_forwarder = q_func_forwarder + self._targ_q_funcs = targ_q_funcs + self._targ_q_func_forwarder = targ_q_func_forwarder self._optim = optim + hard_sync(targ_q_funcs, q_funcs) @train_api def update( @@ -62,7 +80,7 @@ def compute_loss( batch: TorchMiniBatch, q_tpn: torch.Tensor, ) -> torch.Tensor: - return self._q_func.compute_error( + return self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, @@ -75,12 +93,12 @@ def compute_target( self, batch: TorchMiniBatch, next_actions: torch.Tensor ) -> torch.Tensor: with torch.no_grad(): - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, next_actions ) def update_target(self) -> None: - 
hard_sync(self._targ_q_func, self._q_func) + hard_sync(self._targ_q_funcs, self._q_funcs) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -90,20 +108,20 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: class FQEImpl(ContinuousQFunctionMixin, FQEBaseImpl): - _q_func: EnsembleContinuousQFunction - _targ_q_func: EnsembleContinuousQFunction + _q_func_forwarder: DiscreteEnsembleQFunctionForwarder + _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder class DiscreteFQEImpl(DiscreteQFunctionMixin, FQEBaseImpl): - _q_func: EnsembleDiscreteQFunction - _targ_q_func: EnsembleDiscreteQFunction + _q_func_forwarder: ContinuousEnsembleQFunctionForwarder + _targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor, ) -> torch.Tensor: - return self._q_func.compute_error( + return self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions.long(), rewards=batch.rewards, @@ -116,7 +134,7 @@ def compute_target( self, batch: TorchMiniBatch, next_actions: torch.Tensor ) -> torch.Tensor: with torch.no_grad(): - return self._targ_q_func.compute_target( + return self._targ_q_func_forwarder.compute_target( batch.next_observations, next_actions.long(), ) diff --git a/tests/models/test_builders.py b/tests/models/test_builders.py index 4a3e3f93..a59a2b80 100644 --- a/tests/models/test_builders.py +++ b/tests/models/test_builders.py @@ -20,8 +20,8 @@ from d3rlpy.models.encoders import DefaultEncoderFactory, EncoderFactory from d3rlpy.models.q_functions import MeanQFunctionFactory from d3rlpy.models.torch import ( - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, + ContinuousEnsembleQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, ) from d3rlpy.models.torch.imitators import ConditionalVAE from d3rlpy.models.torch.policies import ( @@ -140,7 +140,7 @@ def test_create_discrete_q_function( ) -> None: q_func_factory = MeanQFunctionFactory(share_encoder=share_encoder) - q_func = create_discrete_q_function( + q_funcs, forwarder = create_discrete_q_function( observation_shape, action_size, encoder_factory, @@ -149,18 +149,18 @@ def test_create_discrete_q_function( n_ensembles=n_ensembles, ) - assert isinstance(q_func, EnsembleDiscreteQFunction) + assert isinstance(forwarder, DiscreteEnsembleQFunctionForwarder) # check share_encoder - encoder = q_func.q_funcs[0].encoder - for q_func in q_func.q_funcs[1:]: + encoder = q_funcs[0].encoder + for q_func in q_funcs[1:]: if share_encoder: assert encoder is q_func.encoder else: assert encoder is not q_func.encoder x = torch.rand((batch_size, *observation_shape)) - y = q_func(x) + y = forwarder.compute_expected_q(x) assert y.shape == (batch_size, action_size) @@ -180,7 +180,7 @@ def test_create_continuous_q_function( ) -> None: q_func_factory = MeanQFunctionFactory(share_encoder=share_encoder) - q_func = create_continuous_q_function( + q_funcs, forwarder = create_continuous_q_function( observation_shape, action_size, encoder_factory, @@ -189,11 +189,11 @@ def test_create_continuous_q_function( n_ensembles=n_ensembles, ) - assert isinstance(q_func, EnsembleContinuousQFunction) + assert isinstance(forwarder, ContinuousEnsembleQFunctionForwarder) # check share_encoder - encoder = q_func.q_funcs[0].encoder - for q_func in q_func.q_funcs[1:]: + encoder = q_funcs[0].encoder + for q_func in q_funcs[1:]: if share_encoder: assert encoder is q_func.encoder else: @@ -201,7 +201,7 @@ def 
test_create_continuous_q_function( x = torch.rand((batch_size, *observation_shape)) action = torch.rand(batch_size, action_size) - y = q_func(x, action) + y = forwarder.compute_expected_q(x, action) assert y.shape == (batch_size, 1) diff --git a/tests/models/test_q_functions.py b/tests/models/test_q_functions.py index 3f1098c4..7bc7e00a 100644 --- a/tests/models/test_q_functions.py +++ b/tests/models/test_q_functions.py @@ -10,10 +10,16 @@ ) from d3rlpy.models.torch import ( ContinuousIQNQFunction, + ContinuousIQNQFunctionForwarder, ContinuousMeanQFunction, + ContinuousMeanQFunctionForwarder, ContinuousQRQFunction, + ContinuousQRQFunctionForwarder, DiscreteIQNQFunction, + DiscreteIQNQFunctionForwarder, DiscreteMeanQFunction, + DiscreteMeanQFunctionForwarder, + DiscreteQFunctionForwarder, DiscreteQRQFunction, compute_output_size, ) @@ -46,15 +52,19 @@ def test_mean_q_function_factory( hidden_size = compute_output_size( [observation_shape, (action_size,)], encoder_with_action, "cpu:0" ) - q_func = factory.create_continuous( - encoder_with_action, hidden_size, action_size + q_func, forwarder = factory.create_continuous( + encoder_with_action, hidden_size ) assert isinstance(q_func, ContinuousMeanQFunction) + assert isinstance(forwarder, ContinuousMeanQFunctionForwarder) encoder = _create_encoder(observation_shape) hidden_size = compute_output_size([observation_shape], encoder, "cpu:0") - discrete_q_func = factory.create_discrete(encoder, hidden_size, action_size) + discrete_q_func, discrete_forwarder = factory.create_discrete( + encoder, hidden_size, action_size + ) assert isinstance(discrete_q_func, DiscreteMeanQFunction) + assert isinstance(discrete_forwarder, DiscreteMeanQFunctionForwarder) # check serization and deserialization MeanQFunctionFactory.deserialize(factory.serialize()) @@ -74,15 +84,19 @@ def test_qr_q_function_factory( hidden_size = compute_output_size( [observation_shape, (action_size,)], encoder_with_action, "cpu:0" ) - q_func = factory.create_continuous( - encoder_with_action, hidden_size, action_size + q_func, forwarder = factory.create_continuous( + encoder_with_action, hidden_size ) assert isinstance(q_func, ContinuousQRQFunction) + assert isinstance(forwarder, ContinuousQRQFunctionForwarder) encoder = _create_encoder(observation_shape) hidden_size = compute_output_size([observation_shape], encoder, "cpu:0") - discrete_q_func = factory.create_discrete(encoder, hidden_size, action_size) + discrete_q_func, discrete_forwarder = factory.create_discrete( + encoder, hidden_size, action_size + ) assert isinstance(discrete_q_func, DiscreteQRQFunction) + assert isinstance(discrete_forwarder, DiscreteQFunctionForwarder) # check serization and deserialization QRQFunctionFactory.deserialize(factory.serialize()) @@ -102,15 +116,19 @@ def test_iqn_q_function_factory( hidden_size = compute_output_size( [observation_shape, (action_size,)], encoder_with_action, "cpu:0" ) - q_func = factory.create_continuous( - encoder_with_action, hidden_size, action_size + q_func, forwarder = factory.create_continuous( + encoder_with_action, hidden_size ) assert isinstance(q_func, ContinuousIQNQFunction) + assert isinstance(forwarder, ContinuousIQNQFunctionForwarder) encoder = _create_encoder(observation_shape) hidden_size = compute_output_size([observation_shape], encoder, "cpu:0") - discrete_q_func = factory.create_discrete(encoder, hidden_size, action_size) + discrete_q_func, discrete_forwarder = factory.create_discrete( + encoder, hidden_size, action_size + ) assert isinstance(discrete_q_func, 
DiscreteIQNQFunction) + assert isinstance(discrete_forwarder, DiscreteIQNQFunctionForwarder) # check serization and deserialization IQNQFunctionFactory.deserialize(factory.serialize()) diff --git a/tests/models/torch/model_test.py b/tests/models/torch/model_test.py index 01559411..c5f10a30 100644 --- a/tests/models/torch/model_test.py +++ b/tests/models/torch/model_test.py @@ -5,7 +5,7 @@ import torch from torch.optim import SGD -from d3rlpy.models.torch import ActionOutput +from d3rlpy.models.torch import ActionOutput, QFunctionOutput from d3rlpy.models.torch.encoders import Encoder, EncoderWithAction @@ -26,6 +26,8 @@ def check_parameter_updates( output = mu if logstd is not None: output = output + logstd + elif isinstance(output, QFunctionOutput): + output = output.q_value if isinstance(output, (list, tuple)): loss = 0.0 for y in output: diff --git a/tests/models/torch/q_functions/test_ensemble_q_function.py b/tests/models/torch/q_functions/test_ensemble_q_function.py index 8db2c969..4a1f59ce 100644 --- a/tests/models/torch/q_functions/test_ensemble_q_function.py +++ b/tests/models/torch/q_functions/test_ensemble_q_function.py @@ -4,27 +4,29 @@ import torch from d3rlpy.models.torch import ( + ContinuousEnsembleQFunctionForwarder, ContinuousIQNQFunction, + ContinuousIQNQFunctionForwarder, ContinuousMeanQFunction, - ContinuousQFunction, + ContinuousMeanQFunctionForwarder, + ContinuousQFunctionForwarder, ContinuousQRQFunction, + ContinuousQRQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, DiscreteIQNQFunction, + DiscreteIQNQFunctionForwarder, DiscreteMeanQFunction, - DiscreteQFunction, + DiscreteMeanQFunctionForwarder, + DiscreteQFunctionForwarder, DiscreteQRQFunction, - EnsembleContinuousQFunction, - EnsembleDiscreteQFunction, + DiscreteQRQFunctionForwarder, ) from d3rlpy.models.torch.q_functions.ensemble_q_function import ( _reduce_ensemble, _reduce_quantile_ensemble, ) -from ..model_test import ( - DummyEncoder, - DummyEncoderWithAction, - check_parameter_updates, -) +from ..model_test import DummyEncoder, DummyEncoderWithAction @pytest.mark.parametrize("n_ensembles", [2]) @@ -80,7 +82,7 @@ def test_reduce_quantile_ensemble( @pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn"]) @pytest.mark.parametrize("n_quantiles", [200]) @pytest.mark.parametrize("embed_size", [64]) -def test_ensemble_discrete_q_function( +def test_discrete_ensemble_q_function_forwarder( feature_size: int, action_size: int, batch_size: int, @@ -90,15 +92,18 @@ def test_ensemble_discrete_q_function( n_quantiles: int, embed_size: int, ) -> None: - q_funcs: List[DiscreteQFunction] = [] + forwarders: List[DiscreteQFunctionForwarder] = [] for _ in range(ensemble_size): encoder = DummyEncoder(feature_size) + forwarder: DiscreteQFunctionForwarder if q_func_factory == "mean": q_func = DiscreteMeanQFunction(encoder, feature_size, action_size) + forwarder = DiscreteMeanQFunctionForwarder(q_func, action_size) elif q_func_factory == "qr": q_func = DiscreteQRQFunction( encoder, feature_size, action_size, n_quantiles ) + forwarder = DiscreteQRQFunctionForwarder(q_func, n_quantiles) elif q_func_factory == "iqn": q_func = DiscreteIQNQFunction( encoder, @@ -108,17 +113,22 @@ def test_ensemble_discrete_q_function( n_quantiles, embed_size, ) - q_funcs.append(q_func) - q_func = EnsembleDiscreteQFunction(q_funcs) + forwarder = DiscreteIQNQFunctionForwarder(q_func, n_quantiles) + else: + raise ValueError + forwarders.append(forwarder) + ensemble_forwarder = DiscreteEnsembleQFunctionForwarder( + forwarders, action_size 
+ ) # check output shape x = torch.rand(batch_size, feature_size) - values = q_func(x, "none") + values = ensemble_forwarder.compute_expected_q(x, "none") assert values.shape == (ensemble_size, batch_size, action_size) # check compute_target action = torch.randint(high=action_size, size=(batch_size,)) - target = q_func.compute_target(x, action) + target = ensemble_forwarder.compute_target(x, action) if q_func_factory == "mean": assert target.shape == (batch_size, 1) min_values = values.min(dim=0).values @@ -129,7 +139,7 @@ def test_ensemble_discrete_q_function( assert target.shape == (batch_size, n_quantiles) # check compute_target with action=None - targets = q_func.compute_target(x) + targets = ensemble_forwarder.compute_target(x) if q_func_factory == "mean": assert targets.shape == (batch_size, action_size) else: @@ -137,9 +147,17 @@ def test_ensemble_discrete_q_function( # check reductions if q_func_factory != "iqn": - assert torch.allclose(values.min(dim=0).values, q_func(x, "min")) - assert torch.allclose(values.max(dim=0).values, q_func(x, "max")) - assert torch.allclose(values.mean(dim=0), q_func(x, "mean")) + assert torch.allclose( + values.min(dim=0).values, + ensemble_forwarder.compute_expected_q(x, "min"), + ) + assert torch.allclose( + values.max(dim=0).values, + ensemble_forwarder.compute_expected_q(x, "max"), + ) + assert torch.allclose( + values.mean(dim=0), ensemble_forwarder.compute_expected_q(x, "mean") + ) # check td computation obs_t = torch.rand(batch_size, feature_size) @@ -153,21 +171,26 @@ def test_ensemble_discrete_q_function( else: q_tp1 = torch.rand(batch_size, n_quantiles) ref_td_sum = 0.0 - for i in range(ensemble_size): - f = q_func.q_funcs[i] - ref_td_sum += f.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma + for forwarder in forwarders: + ref_td_sum += forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + gamma=gamma, ) - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma) + loss = ensemble_forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + gamma=gamma, + ) if q_func_factory != "iqn": assert torch.allclose(ref_td_sum, loss) - # check layer connection - check_parameter_updates( - q_func, - (obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma), - ) - @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @@ -187,36 +210,41 @@ def test_ensemble_continuous_q_function( n_quantiles: int, embed_size: int, ) -> None: - q_funcs: List[ContinuousQFunction] = [] + forwarders: List[ContinuousQFunctionForwarder] = [] for _ in range(ensemble_size): + forwarder: ContinuousQFunctionForwarder encoder = DummyEncoderWithAction(feature_size, action_size) if q_func_factory == "mean": - q_func = ContinuousMeanQFunction(encoder, feature_size, action_size) + q_func = ContinuousMeanQFunction(encoder, feature_size) + forwarder = ContinuousMeanQFunctionForwarder(q_func) elif q_func_factory == "qr": - q_func = ContinuousQRQFunction( - encoder, feature_size, action_size, n_quantiles - ) + q_func = ContinuousQRQFunction(encoder, feature_size, n_quantiles) + forwarder = ContinuousQRQFunctionForwarder(q_func, n_quantiles) elif q_func_factory == "iqn": q_func = ContinuousIQNQFunction( encoder, feature_size, - action_size, n_quantiles, n_quantiles, embed_size, ) - q_funcs.append(q_func) + forwarder = ContinuousIQNQFunctionForwarder(q_func, n_quantiles) + else: + raise ValueError + 
forwarders.append(forwarder) - q_func = EnsembleContinuousQFunction(q_funcs) + ensemble_forwarder = ContinuousEnsembleQFunctionForwarder( + forwarders, action_size + ) # check output shape x = torch.rand(batch_size, feature_size) action = torch.rand(batch_size, action_size) - values = q_func(x, action, "none") + values = ensemble_forwarder.compute_expected_q(x, action, "none") assert values.shape == (ensemble_size, batch_size, 1) # check compute_target - target = q_func.compute_target(x, action) + target = ensemble_forwarder.compute_target(x, action) if q_func_factory == "mean": assert target.shape == (batch_size, 1) min_values = values.min(dim=0).values @@ -226,9 +254,18 @@ def test_ensemble_continuous_q_function( # check reductions if q_func_factory != "iqn": - assert torch.allclose(values.min(dim=0)[0], q_func(x, action, "min")) - assert torch.allclose(values.max(dim=0)[0], q_func(x, action, "max")) - assert torch.allclose(values.mean(dim=0), q_func(x, action, "mean")) + assert torch.allclose( + values.min(dim=0)[0], + ensemble_forwarder.compute_expected_q(x, action, "min"), + ) + assert torch.allclose( + values.max(dim=0)[0], + ensemble_forwarder.compute_expected_q(x, action, "max"), + ) + assert torch.allclose( + values.mean(dim=0), + ensemble_forwarder.compute_expected_q(x, action, "mean"), + ) # check td computation obs_t = torch.rand(batch_size, feature_size) @@ -240,17 +277,22 @@ def test_ensemble_continuous_q_function( else: q_tp1 = torch.rand(batch_size, n_quantiles) ref_td_sum = 0.0 - for i in range(ensemble_size): - f = q_func.q_funcs[i] - ref_td_sum += f.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma + for forwarder in forwarders: + ref_td_sum += forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + gamma=gamma, ) - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma) + loss = ensemble_forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + gamma=gamma, + ) if q_func_factory != "iqn": assert torch.allclose(ref_td_sum, loss) - - # check layer connection - check_parameter_updates( - q_func, - (obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma), - ) diff --git a/tests/models/torch/q_functions/test_iqn_q_function.py b/tests/models/torch/q_functions/test_iqn_q_function.py index 553809d8..07147165 100644 --- a/tests/models/torch/q_functions/test_iqn_q_function.py +++ b/tests/models/torch/q_functions/test_iqn_q_function.py @@ -1,7 +1,12 @@ import pytest import torch -from d3rlpy.models.torch import ContinuousIQNQFunction, DiscreteIQNQFunction +from d3rlpy.models.torch import ( + ContinuousIQNQFunction, + ContinuousIQNQFunctionForwarder, + DiscreteIQNQFunction, + DiscreteIQNQFunctionForwarder, +) from ..model_test import ( DummyEncoder, @@ -37,22 +42,72 @@ def test_discrete_iqn_q_function( # check output shape x = torch.rand(batch_size, feature_size) y = q_func(x) - assert y.shape == (batch_size, action_size) + assert y.q_value.shape == (batch_size, action_size) + assert y.quantiles is not None + assert y.taus is not None + assert y.quantiles.shape == (batch_size, action_size, n_quantiles) + assert y.taus.shape == (batch_size, n_quantiles) + assert (y.q_value == y.quantiles.mean(dim=2)).all() # check eval mode q_func.eval() x = torch.rand(batch_size, feature_size) y = q_func(x) + assert y.q_value.shape == (batch_size, action_size) + assert y.quantiles is not None + assert y.taus is not None + assert y.quantiles.shape == 
(batch_size, action_size, n_greedy_quantiles) + assert y.taus.shape == (batch_size, n_greedy_quantiles) + q_func.train() + + # check layer connection + check_parameter_updates(q_func, (x,)) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("n_quantiles", [200]) +@pytest.mark.parametrize("n_greedy_quantiles", [32]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("embed_size", [64]) +def test_discrete_iqn_q_function_forwarder( + feature_size: int, + action_size: int, + n_quantiles: int, + n_greedy_quantiles: int, + batch_size: int, + embed_size: int, +) -> None: + encoder = DummyEncoder(feature_size) + q_func = DiscreteIQNQFunction( + encoder, + feature_size, + action_size, + n_quantiles, + n_greedy_quantiles, + embed_size, + ) + forwarder = DiscreteIQNQFunctionForwarder(q_func, n_quantiles) + + # check output shape + x = torch.rand(batch_size, feature_size) + y = forwarder.compute_expected_q(x) + assert y.shape == (batch_size, action_size) + + # check eval mode + q_func.eval() + x = torch.rand(batch_size, feature_size) + y = forwarder.compute_expected_q(x) assert y.shape == (batch_size, action_size) q_func.train() # check compute_target action = torch.randint(high=action_size, size=(batch_size,)) - target = q_func.compute_target(x, action) + target = forwarder.compute_target(x, action) assert target.shape == (batch_size, n_quantiles) # check compute_target with action=None - targets = q_func.compute_target(x) + targets = forwarder.compute_target(x) assert targets.shape == (batch_size, action_size, n_quantiles) # TODO: check quantile huber loss @@ -62,15 +117,23 @@ def test_discrete_iqn_q_function( q_tp1 = torch.rand(batch_size, n_quantiles) ter_tp1 = torch.randint(2, size=(batch_size, 1)) # check shape - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, reduction="none" + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + reduction="none", ) assert loss.shape == (batch_size, 1) # mean loss - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1) - - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + ) @pytest.mark.parametrize("feature_size", [100]) @@ -91,7 +154,6 @@ def test_continuous_iqn_q_function( q_func = ContinuousIQNQFunction( encoder, feature_size, - action_size, n_quantiles, n_greedy_quantiles, embed_size, @@ -101,17 +163,60 @@ def test_continuous_iqn_q_function( x = torch.rand(batch_size, feature_size) action = torch.rand(batch_size, action_size) y = q_func(x, action) - assert y.shape == (batch_size, 1) + assert y.q_value.shape == (batch_size, 1) + assert y.quantiles is not None + assert y.taus is not None + assert y.quantiles.shape == (batch_size, n_quantiles) + assert y.taus.shape == (batch_size, n_quantiles) + assert (y.q_value == y.quantiles.mean(dim=1, keepdim=True)).all() # check eval mode q_func.eval() x = torch.rand(batch_size, feature_size) action = torch.rand(batch_size, action_size) y = q_func(x, action) - assert y.shape == (batch_size, 1) + assert y.q_value.shape == (batch_size, 1) + assert y.quantiles is not None + assert y.taus is not None + assert y.quantiles.shape == (batch_size, n_greedy_quantiles) + assert y.taus.shape == (batch_size, n_greedy_quantiles) q_func.train() - target = 
q_func.compute_target(x, action) + # check layer connection + check_parameter_updates(q_func, (x, action)) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("n_quantiles", [200]) +@pytest.mark.parametrize("n_greedy_quantiles", [32]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("embed_size", [64]) +def test_continuous_iqn_q_function_forwarder( + feature_size: int, + action_size: int, + n_quantiles: int, + n_greedy_quantiles: int, + batch_size: int, + embed_size: int, +) -> None: + encoder = DummyEncoderWithAction(feature_size, action_size) + q_func = ContinuousIQNQFunction( + encoder, + feature_size, + n_quantiles, + n_greedy_quantiles, + embed_size, + ) + forwarder = ContinuousIQNQFunctionForwarder(q_func, n_quantiles) + + # check output shape + x = torch.rand(batch_size, feature_size) + action = torch.rand(batch_size, action_size) + y = forwarder.compute_expected_q(x, action) + assert y.shape == (batch_size, 1) + + target = forwarder.compute_target(x, action) assert target.shape == (batch_size, n_quantiles) # TODO: check quantile huber loss @@ -121,12 +226,20 @@ def test_continuous_iqn_q_function( q_tp1 = torch.rand(batch_size, n_quantiles) ter_tp1 = torch.randint(2, size=(batch_size, 1)) # check shape - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, reduction="none" + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + reduction="none", ) assert loss.shape == (batch_size, 1) # mean loss - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1) - - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + ) diff --git a/tests/models/torch/q_functions/test_mean_q_function.py b/tests/models/torch/q_functions/test_mean_q_function.py index 7c54e9a4..07700c41 100644 --- a/tests/models/torch/q_functions/test_mean_q_function.py +++ b/tests/models/torch/q_functions/test_mean_q_function.py @@ -2,7 +2,12 @@ import pytest import torch -from d3rlpy.models.torch import ContinuousMeanQFunction, DiscreteMeanQFunction +from d3rlpy.models.torch import ( + ContinuousMeanQFunction, + ContinuousMeanQFunctionForwarder, + DiscreteMeanQFunction, + DiscreteMeanQFunctionForwarder, +) from ..model_test import ( DummyEncoder, @@ -22,9 +27,8 @@ def filter_by_action( @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("gamma", [0.99]) def test_discrete_mean_q_function( - feature_size: int, action_size: int, batch_size: int, gamma: float + feature_size: int, action_size: int, batch_size: int ) -> None: encoder = DummyEncoder(feature_size) q_func = DiscreteMeanQFunction(encoder, feature_size, action_size) @@ -32,17 +36,40 @@ def test_discrete_mean_q_function( # check output shape x = torch.rand(batch_size, feature_size) y = q_func(x) + assert y.q_value.shape == (batch_size, action_size) + assert y.quantiles is None + assert y.taus is None + + # check layer connection + check_parameter_updates(q_func, (x,)) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("gamma", [0.99]) +def 
test_discrete_mean_q_function_forwarder( + feature_size: int, action_size: int, batch_size: int, gamma: float +) -> None: + encoder = DummyEncoder(feature_size) + q_func = DiscreteMeanQFunction(encoder, feature_size, action_size) + forwarder = DiscreteMeanQFunctionForwarder(q_func, action_size) + + # check output shape + x = torch.rand(batch_size, feature_size) + y = forwarder.compute_expected_q(x) assert y.shape == (batch_size, action_size) # check compute_target action = torch.randint(high=action_size, size=(batch_size,)) - target = q_func.compute_target(x, action) + target = forwarder.compute_target(x, action) assert target.shape == (batch_size, 1) assert torch.allclose(y[torch.arange(batch_size), action], target.view(-1)) # check compute_target with action=None - targets = q_func.compute_target(x) + targets = forwarder.compute_target(x) assert targets.shape == (batch_size, action_size) + assert (y == targets).all() # check td calculation q_tp1 = np.random.random((batch_size, 1)) @@ -52,44 +79,71 @@ def test_discrete_mean_q_function( obs_t = torch.rand(batch_size, feature_size) act_t = np.random.randint(action_size, size=(batch_size, 1)) - q_t = filter_by_action(q_func(obs_t).detach().numpy(), act_t, action_size) + q_t = filter_by_action( + q_func(obs_t).q_value.detach().numpy(), act_t, action_size + ) ref_loss = ref_huber_loss(q_t.reshape((-1, 1)), target) act_t = torch.tensor(act_t, dtype=torch.int64) rew_tp1 = torch.tensor(rew_tp1, dtype=torch.float32) q_tp1 = torch.tensor(q_tp1, dtype=torch.float32) ter_tp1 = torch.tensor(ter_tp1, dtype=torch.float32) - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma=gamma + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + gamma=gamma, ) - assert np.allclose(loss.detach().numpy(), ref_loss) + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("batch_size", [32]) +def test_continuous_mean_q_function( + feature_size: int, + action_size: int, + batch_size: int, +) -> None: + encoder = DummyEncoderWithAction(feature_size, action_size) + q_func = ContinuousMeanQFunction(encoder, feature_size) + + # check output shape + x = torch.rand(batch_size, feature_size) + action = torch.rand(batch_size, action_size) + y = q_func(x, action) + assert y.q_value.shape == (batch_size, 1) + assert y.quantiles is None + assert y.taus is None + # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) + check_parameter_updates(q_func, (x, action)) @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("gamma", [0.99]) -def test_continuous_mean_q_function( +def test_continuous_mean_q_function_forwarder( feature_size: int, action_size: int, batch_size: int, gamma: float, ) -> None: encoder = DummyEncoderWithAction(feature_size, action_size) - q_func = ContinuousMeanQFunction(encoder, feature_size, action_size) + q_func = ContinuousMeanQFunction(encoder, feature_size) + forwarder = ContinuousMeanQFunctionForwarder(q_func) # check output shape x = torch.rand(batch_size, feature_size) action = torch.rand(batch_size, action_size) - y = q_func(x, action) + y = forwarder.compute_expected_q(x, action) assert y.shape == (batch_size, 1) # check compute_target - target = q_func.compute_target(x, action) + target = forwarder.compute_target(x, action) assert 
target.shape == (batch_size, 1) assert (target == y).all() @@ -101,15 +155,19 @@ def test_continuous_mean_q_function( obs_t = torch.rand(batch_size, feature_size) act_t = torch.rand(batch_size, action_size) - q_t = q_func(obs_t, act_t).detach().numpy() + q_t = q_func(obs_t, act_t).q_value.detach().numpy() ref_loss = ((q_t - target) ** 2).mean() rew_tp1 = torch.tensor(rew_tp1, dtype=torch.float32) q_tp1 = torch.tensor(q_tp1, dtype=torch.float32) ter_tp1 = torch.tensor(ter_tp1, dtype=torch.float32) - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma) + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + gamma=gamma, + ) assert np.allclose(loss.detach().numpy(), ref_loss) - - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) diff --git a/tests/models/torch/q_functions/test_qr_q_function.py b/tests/models/torch/q_functions/test_qr_q_function.py index 586634b5..af4e4e63 100644 --- a/tests/models/torch/q_functions/test_qr_q_function.py +++ b/tests/models/torch/q_functions/test_qr_q_function.py @@ -3,8 +3,12 @@ import pytest import torch -from d3rlpy.models.torch import ContinuousQRQFunction, DiscreteQRQFunction -from d3rlpy.models.torch.q_functions.qr_q_function import _make_taus +from d3rlpy.models.torch import ( + ContinuousQRQFunction, + ContinuousQRQFunctionForwarder, + DiscreteQRQFunction, + DiscreteQRQFunctionForwarder, +) from d3rlpy.models.torch.q_functions.utility import ( pick_quantile_value_by_action, ) @@ -21,13 +25,11 @@ @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("n_quantiles", [200]) @pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("gamma", [0.99]) def test_discrete_qr_q_function( feature_size: int, action_size: int, n_quantiles: int, batch_size: int, - gamma: float, ) -> None: encoder = DummyEncoder(feature_size) q_func = DiscreteQRQFunction( @@ -37,21 +39,51 @@ def test_discrete_qr_q_function( # check output shape x = torch.rand(batch_size, feature_size) y = q_func(x) - assert y.shape == (batch_size, action_size) + assert y.q_value.shape == (batch_size, action_size) + assert y.quantiles is not None and y.taus is not None + assert y.quantiles.shape == (batch_size, action_size, n_quantiles) + assert y.taus.shape == (1, n_quantiles) + assert torch.allclose(y.q_value, y.quantiles.mean(dim=2)) # check taus - taus = _make_taus(encoder(x), n_quantiles) step = 1 / n_quantiles for i in range(n_quantiles): - assert np.allclose(taus[0][i].numpy(), i * step + step / 2.0) + assert np.allclose(y.taus[0][i].numpy(), i * step + step / 2.0) + + # check layer connection + check_parameter_updates(q_func, (x,)) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("n_quantiles", [200]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("gamma", [0.99]) +def test_discrete_qr_q_function_forwarder( + feature_size: int, + action_size: int, + n_quantiles: int, + batch_size: int, + gamma: float, +) -> None: + encoder = DummyEncoder(feature_size) + q_func = DiscreteQRQFunction( + encoder, feature_size, action_size, n_quantiles + ) + forwarder = DiscreteQRQFunctionForwarder(q_func, n_quantiles) + + # check output shape + x = torch.rand(batch_size, feature_size) + y = forwarder.compute_expected_q(x) + assert y.shape == (batch_size, action_size) # check compute_target action = torch.randint(high=action_size, 
size=(batch_size,)) - target = q_func.compute_target(x, action) + target = forwarder.compute_target(x, action) assert target.shape == (batch_size, n_quantiles) # check compute_target with action=None - targets = q_func.compute_target(x) + targets = forwarder.compute_target(x) assert targets.shape == (batch_size, action_size, n_quantiles) # check quantile huber loss @@ -61,61 +93,98 @@ def test_discrete_qr_q_function( q_tp1 = torch.rand(batch_size, n_quantiles) ter_tp1 = torch.randint(2, size=(batch_size, 1)) # shape check - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, reduction="none" + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + reduction="none", ) assert loss.shape == (batch_size, 1) # mean loss - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1) + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + ) target = rew_tp1.numpy() + gamma * q_tp1.numpy() * (1 - ter_tp1.numpy()) - y = pick_quantile_value_by_action( - q_func._compute_quantiles(encoder(obs_t), taus), act_t - ) + y = q_func(obs_t) + quantiles = y.quantiles + taus = y.taus + assert quantiles is not None + assert taus is not None + y = pick_quantile_value_by_action(quantiles, act_t) reshaped_target = np.reshape(target, (batch_size, -1, 1)) reshaped_y = np.reshape(y.detach().numpy(), (batch_size, 1, -1)) - reshaped_taus = np.reshape(taus, (1, 1, -1)) + reshaped_taus = np.reshape(taus.detach().numpy(), (1, 1, -1)) ref_loss = ref_quantile_huber_loss( reshaped_y, reshaped_target, reshaped_taus, n_quantiles ) assert np.allclose(loss.cpu().detach(), ref_loss.mean()) - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) - @pytest.mark.parametrize("feature_size", [100]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("n_quantiles", [200]) @pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("gamma", [0.99]) def test_continuous_qr_q_function( feature_size: int, action_size: int, n_quantiles: int, batch_size: int, - gamma: float, ) -> None: encoder = DummyEncoderWithAction(feature_size, action_size) - q_func = ContinuousQRQFunction( - encoder, feature_size, action_size, n_quantiles - ) + q_func = ContinuousQRQFunction(encoder, feature_size, n_quantiles) # check output shape x = torch.rand(batch_size, feature_size) action = torch.rand(batch_size, action_size) y = q_func(x, action) - assert y.shape == (batch_size, 1) + assert y.q_value.shape == (batch_size, 1) + assert y.quantiles is not None + assert y.quantiles.shape == (batch_size, n_quantiles) + assert torch.allclose(y.q_value, y.quantiles.mean(dim=1, keepdim=True)) + assert y.taus is not None + assert y.taus.shape == (1, n_quantiles) # check taus - taus = _make_taus(encoder(x, action), n_quantiles) step = 1 / n_quantiles for i in range(n_quantiles): - assert np.allclose(taus[0][i].numpy(), i * step + step / 2.0) + assert np.allclose(y.taus[0][i].numpy(), i * step + step / 2.0) + + # check layer connection + check_parameter_updates(q_func, (x, action)) + + +@pytest.mark.parametrize("feature_size", [100]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("n_quantiles", [200]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("gamma", [0.99]) +def test_continuous_qr_q_function_forwarder( + feature_size: int, + action_size: int, + n_quantiles: int, + batch_size: 
int, + gamma: float, +) -> None: + encoder = DummyEncoderWithAction(feature_size, action_size) + q_func = ContinuousQRQFunction(encoder, feature_size, n_quantiles) + forwarder = ContinuousQRQFunctionForwarder(q_func, n_quantiles) - target = q_func.compute_target(x, action) + # check output shape + x = torch.rand(batch_size, feature_size) + action = torch.rand(batch_size, action_size) + y = forwarder.compute_expected_q(x, action) + assert y.shape == (batch_size, 1) + + target = forwarder.compute_target(x, action) assert target.shape == (batch_size, n_quantiles) # check quantile huber loss @@ -125,24 +194,36 @@ def test_continuous_qr_q_function( q_tp1 = torch.rand(batch_size, n_quantiles) ter_tp1 = torch.randint(2, size=(batch_size, 1)) # check shape - loss = q_func.compute_error( - obs_t, act_t, rew_tp1, q_tp1, ter_tp1, reduction="none" + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + reduction="none", ) assert loss.shape == (batch_size, 1) # mean loss - loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1) + loss = forwarder.compute_error( + observations=obs_t, + actions=act_t, + rewards=rew_tp1, + target=q_tp1, + terminals=ter_tp1, + ) target = rew_tp1.numpy() + gamma * q_tp1.numpy() * (1 - ter_tp1.numpy()) - y = q_func._compute_quantiles(encoder(obs_t, act_t), taus).detach().numpy() + y = q_func(obs_t, act_t) + assert y.quantiles is not None + assert y.taus is not None + quantiles = y.quantiles.detach().numpy() + taus = y.taus.detach().numpy() reshaped_target = target.reshape((batch_size, -1, 1)) - reshaped_y = y.reshape((batch_size, 1, -1)) + reshaped_y = quantiles.reshape((batch_size, 1, -1)) reshaped_taus = taus.reshape((1, 1, -1)) ref_loss = ref_quantile_huber_loss( reshaped_y, reshaped_target, reshaped_taus, n_quantiles ) assert np.allclose(loss.cpu().detach(), ref_loss.mean()) - - # check layer connection - check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)) diff --git a/tests/models/torch/test_q_functions.py b/tests/models/torch/test_q_functions.py index 76560bfd..e13c0a23 100644 --- a/tests/models/torch/test_q_functions.py +++ b/tests/models/torch/test_q_functions.py @@ -33,7 +33,7 @@ def test_compute_max_with_n_actions( n_actions: int, lam: float, ) -> None: - q_func = create_continuous_q_function( + _, forwarder = create_continuous_q_function( observation_shape, action_size, encoder_factory, @@ -44,7 +44,7 @@ def test_compute_max_with_n_actions( x = torch.rand(batch_size, *observation_shape) actions = torch.rand(batch_size, n_actions, action_size) - y = compute_max_with_n_actions(x, actions, q_func, lam) + y = compute_max_with_n_actions(x, actions, forwarder, lam) if isinstance(q_func_factory, MeanQFunctionFactory): assert y.shape == (batch_size, 1) From bae6777603549901da165eba7a5615cb73f14822 Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 14 Aug 2023 14:35:33 +0900 Subject: [PATCH 10/20] Add Checkpointer --- d3rlpy/algos/qlearning/awac.py | 14 +++- d3rlpy/algos/qlearning/bc.py | 12 ++- d3rlpy/algos/qlearning/bcq.py | 27 ++++++- d3rlpy/algos/qlearning/bear.py | 20 ++++- d3rlpy/algos/qlearning/cql.py | 28 ++++++- d3rlpy/algos/qlearning/crr.py | 14 +++- d3rlpy/algos/qlearning/ddpg.py | 14 +++- d3rlpy/algos/qlearning/dqn.py | 22 +++++- d3rlpy/algos/qlearning/iql.py | 15 +++- d3rlpy/algos/qlearning/nfq.py | 12 ++- d3rlpy/algos/qlearning/plas.py | 31 +++++++- d3rlpy/algos/qlearning/sac.py | 30 +++++++- d3rlpy/algos/qlearning/td3.py | 14 +++- 
d3rlpy/algos/qlearning/td3_plus_bc.py | 14 +++- d3rlpy/algos/qlearning/torch/awac_impl.py | 4 +- d3rlpy/algos/qlearning/torch/bc_impl.py | 8 +- d3rlpy/algos/qlearning/torch/bcq_impl.py | 6 +- d3rlpy/algos/qlearning/torch/bear_impl.py | 4 +- d3rlpy/algos/qlearning/torch/cql_impl.py | 6 +- d3rlpy/algos/qlearning/torch/crr_impl.py | 4 +- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 10 ++- d3rlpy/algos/qlearning/torch/dqn_impl.py | 4 +- d3rlpy/algos/qlearning/torch/iql_impl.py | 4 +- d3rlpy/algos/qlearning/torch/plas_impl.py | 6 +- d3rlpy/algos/qlearning/torch/sac_impl.py | 6 +- d3rlpy/algos/qlearning/torch/td3_impl.py | 4 +- .../algos/qlearning/torch/td3_plus_bc_impl.py | 4 +- .../algos/transformer/decision_transformer.py | 8 +- .../torch/decision_transformer_impl.py | 9 ++- d3rlpy/base.py | 11 +-- d3rlpy/ope/fqe.py | 22 +++++- d3rlpy/ope/torch/fqe_impl.py | 4 +- d3rlpy/torch_utility.py | 41 +++++----- tests/algos/qlearning/algo_test.py | 24 +++++- tests/algos/transformer/algo_test.py | 14 +++- tests/test_torch_utility.py | 77 +++++++++++-------- 36 files changed, 456 insertions(+), 91 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index 17cb6fc2..9b6f4eef 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -11,7 +11,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.awac_impl import AWACImpl @@ -130,6 +130,17 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + }, + device=self._device, + ) + self._impl = AWACImpl( observation_shape=observation_shape, action_size=action_size, @@ -144,6 +155,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index b77bda48..a37781aa 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -11,7 +11,7 @@ ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.bc_impl import BCBaseImpl, BCImpl, DiscreteBCImpl @@ -97,11 +97,16 @@ def inner_create_impl( imitator.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={"imitator": imitator, "optim": optim}, device=self._device + ) + self._impl = BCImpl( observation_shape=observation_shape, action_size=action_size, imitator=imitator, optim=optim, + checkpointer=checkpointer, device=self._device, ) @@ -166,12 +171,17 @@ def inner_create_impl( imitator.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={"imitator": imitator, "optim": optim}, device=self._device + ) + self._impl = DiscreteBCImpl( observation_shape=observation_shape, action_size=action_size, imitator=imitator, optim=optim, beta=self._config.beta, + 
checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index a4710379..2ff3b3ef 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -15,7 +15,7 @@ from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...models.torch import CategoricalPolicy, PixelEncoder, compute_output_size -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.bcq_impl import BCQImpl, DiscreteBCQImpl @@ -210,6 +210,19 @@ def inner_create_impl( imitator.parameters(), lr=self._config.imitator_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "imitator": imitator, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + "imitator_optim": imitator_optim, + }, + device=self._device, + ) + self._impl = BCQImpl( observation_shape=observation_shape, action_size=action_size, @@ -228,6 +241,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, action_flexibility=self._config.action_flexibility, beta=self._config.beta, + checkpointer=checkpointer, device=self._device, ) @@ -384,6 +398,16 @@ def inner_create_impl( unique_params, lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "imitator": imitator, + "optim": optim, + }, + device=self._device, + ) + self._impl = DiscreteBCQImpl( observation_shape=observation_shape, action_size=action_size, @@ -396,6 +420,7 @@ def inner_create_impl( gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 65ad4d8e..fb1a0b08 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -14,7 +14,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.bear_impl import BEARImpl @@ -214,6 +214,23 @@ def inner_create_impl( log_alpha.parameters(), lr=self._config.actor_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "imitator": imitator, + "log_temp": log_temp, + "log_alpha": log_alpha, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + "imitator_optim": imitator_optim, + "temp_optim": temp_optim, + "alpha_optim": alpha_optim, + }, + device=self._device, + ) + self._impl = BEARImpl( observation_shape=observation_shape, action_size=action_size, @@ -240,6 +257,7 @@ def inner_create_impl( mmd_kernel=self._config.mmd_kernel, mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 1410e6a5..aa36b668 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -14,7 +14,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers 
import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.cql_impl import CQLImpl, DiscreteCQLImpl @@ -179,6 +179,21 @@ def inner_create_impl( log_alpha.parameters(), lr=self._config.alpha_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "log_temp": log_temp, + "log_alpha": log_alpha, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + "temp_optim": temp_optim, + "alpha_optim": alpha_optim, + }, + device=self._device, + ) + self._impl = CQLImpl( observation_shape=observation_shape, action_size=action_size, @@ -199,6 +214,7 @@ def inner_create_impl( conservative_weight=self._config.conservative_weight, n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, + checkpointer=checkpointer, device=self._device, ) @@ -311,6 +327,15 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "optim": optim, + }, + device=self._device, + ) + self._impl = DiscreteCQLImpl( observation_shape=observation_shape, action_size=action_size, @@ -321,6 +346,7 @@ def inner_create_impl( optim=optim, gamma=self._config.gamma, alpha=self._config.alpha, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index a0aaf2ac..ae671ea0 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -11,7 +11,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.crr_impl import CRRImpl @@ -164,6 +164,17 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + }, + device=self._device, + ) + self._impl = CRRImpl( observation_shape=observation_shape, action_size=action_size, @@ -181,6 +192,7 @@ def inner_create_impl( weight_type=self._config.weight_type, max_weight=self._config.max_weight, tau=self._config.tau, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 92b8a905..01b73c96 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -11,7 +11,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGImpl @@ -125,6 +125,17 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "actor_optim": 
actor_optim, + "critic_optim": critic_optim, + }, + device=self._device, + ) + self._impl = DDPGImpl( observation_shape=observation_shape, action_size=action_size, @@ -137,6 +148,7 @@ def inner_create_impl( critic_optim=critic_optim, gamma=self._config.gamma, tau=self._config.tau, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 31b8945f..21832152 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -8,7 +8,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.dqn_impl import DoubleDQNImpl, DQNImpl @@ -89,6 +89,15 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "optim": optim, + }, + device=self._device, + ) + self._impl = DQNImpl( observation_shape=observation_shape, action_size=action_size, @@ -98,6 +107,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, optim=optim, gamma=self._config.gamma, + checkpointer=checkpointer, device=self._device, ) @@ -193,6 +203,15 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "optim": optim, + }, + device=self._device, + ) + self._impl = DoubleDQNImpl( observation_shape=observation_shape, action_size=action_size, @@ -202,6 +221,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, optim=optim, gamma=self._config.gamma, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index a346718e..fa7ad2b9 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -12,7 +12,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import MeanQFunctionFactory -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.iql_impl import IQLImpl @@ -150,6 +150,18 @@ def inner_create_impl( q_func_params + v_func_params, lr=self._config.critic_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "value_func": value_func, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + }, + device=self._device, + ) + self._impl = IQLImpl( observation_shape=observation_shape, action_size=action_size, @@ -166,6 +178,7 @@ def inner_create_impl( expectile=self._config.expectile, weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index e6423cb1..2f00d4d1 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -8,7 +8,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from 
...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.dqn_impl import DQNImpl @@ -91,6 +91,15 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "optim": optim, + }, + device=self._device, + ) + self._impl = DQNImpl( observation_shape=observation_shape, action_size=action_size, @@ -100,6 +109,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=self._config.gamma, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 74a242cd..539ee4bb 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -13,7 +13,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.plas_impl import PLASImpl, PLASWithPerturbationImpl @@ -148,6 +148,19 @@ def inner_create_impl( imitator.parameters(), lr=self._config.imitator_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "imitator": imitator, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + "imitator_optim": imitator_optim, + }, + device=self._device, + ) + self._impl = PLASImpl( observation_shape=observation_shape, action_size=action_size, @@ -164,6 +177,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, beta=self._config.beta, + checkpointer=checkpointer, device=self._device, ) @@ -298,6 +312,20 @@ def inner_create_impl( imitator.parameters(), lr=self._config.imitator_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "imitator": imitator, + "perturbation": perturbation, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + "imitator_optim": imitator_optim, + }, + device=self._device, + ) + self._impl = PLASWithPerturbationImpl( observation_shape=observation_shape, action_size=action_size, @@ -315,6 +343,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, beta=self._config.beta, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index dc4c0af7..797eb7ef 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -15,7 +15,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.sac_impl import DiscreteSACImpl, SACImpl @@ -157,6 +157,19 @@ def inner_create_impl( log_temp.parameters(), lr=self._config.temp_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "log_temp": log_temp, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + "temp_optim": temp_optim, + }, + device=self._device, + ) 
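Note: the same wiring repeats across every inner_create_impl touched by this patch: the networks and optimizers that form the algorithm's persistent state are registered in a Checkpointer under stable string keys, and that Checkpointer is handed to the impl together with the device. A minimal standalone sketch of the pattern, assuming only the Checkpointer constructor added to d3rlpy/torch_utility.py in this series; the layers and key names below are illustrative stand-ins, not the exact set used by any one algorithm:

    import torch

    from d3rlpy.torch_utility import Checkpointer

    # stand-ins for the networks/optimizers built by inner_create_impl
    policy = torch.nn.Linear(8, 2)
    q_funcs = torch.nn.Linear(8, 1)
    actor_optim = torch.optim.Adam(policy.parameters())
    critic_optim = torch.optim.Adam(q_funcs.parameters())

    checkpointer = Checkpointer(
        modules={
            "policy": policy,
            "q_func": q_funcs,
            "actor_optim": actor_optim,
            "critic_optim": critic_optim,
        },
        device="cpu:0",
    )

Making the keys explicit is what keeps the checkpoint layout stable: one state_dict is saved per key and restored by the same name, instead of being discovered from impl attribute names as the removed get_state_dict/set_state_dict helpers did.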
+ self._impl = SACImpl( observation_shape=observation_shape, action_size=action_size, @@ -171,6 +184,7 @@ def inner_create_impl( temp_optim=temp_optim, gamma=self._config.gamma, tau=self._config.tau, + checkpointer=checkpointer, device=self._device, ) @@ -314,6 +328,19 @@ def inner_create_impl( log_temp.parameters(), lr=self._config.temp_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "log_temp": log_temp, + "critic_optim": critic_optim, + "actor_optim": actor_optim, + "temp_optim": temp_optim, + }, + device=self._device, + ) + self._impl = DiscreteSACImpl( observation_shape=observation_shape, action_size=action_size, @@ -327,6 +354,7 @@ def inner_create_impl( critic_optim=critic_optim, temp_optim=temp_optim, gamma=self._config.gamma, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index cfc88747..81d8522b 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -11,7 +11,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.td3_impl import TD3Impl @@ -133,6 +133,17 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + }, + device=self._device, + ) + self._impl = TD3Impl( observation_shape=observation_shape, action_size=action_size, @@ -147,6 +158,7 @@ def inner_create_impl( tau=self._config.tau, target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 25ac1bb0..49b3ecb2 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -11,7 +11,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch +from ...torch_utility import Checkpointer, TorchMiniBatch from .base import QLearningAlgoBase from .torch.td3_plus_bc_impl import TD3PlusBCImpl @@ -125,6 +125,17 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) + checkpointer = Checkpointer( + modules={ + "policy": policy, + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "actor_optim": actor_optim, + "critic_optim": critic_optim, + }, + device=self._device, + ) + self._impl = TD3PlusBCImpl( observation_shape=observation_shape, action_size=action_size, @@ -140,6 +151,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index c0726b98..65e2565f 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ 
b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -11,7 +11,7 @@ Policy, build_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch +from ....torch_utility import Checkpointer, TorchMiniBatch from .sac_impl import SACImpl __all__ = ["AWACImpl"] @@ -37,6 +37,7 @@ def __init__( tau: float, lam: float, n_action_samples: int, + checkpointer: Checkpointer, device: str, ): assert isinstance(policy, NormalPolicy) @@ -55,6 +56,7 @@ def __init__( temp_optim=Adam(dummy_log_temp.parameters(), lr=0.0), gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._lam = lam diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index b5f77814..ff72fe2e 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -14,7 +14,7 @@ compute_discrete_imitation_loss, compute_stochastic_imitation_loss, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, train_api from ..base import QLearningAlgoImplBase __all__ = ["BCImpl", "DiscreteBCImpl"] @@ -29,11 +29,13 @@ def __init__( observation_shape: Shape, action_size: int, optim: Optimizer, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, + checkpointer=checkpointer, device=device, ) self._optim = optim @@ -73,12 +75,14 @@ def __init__( action_size: int, imitator: Union[DeterministicPolicy, NormalPolicy], optim: Optimizer, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, optim=optim, + checkpointer=checkpointer, device=device, ) self._imitator = imitator @@ -118,12 +122,14 @@ def __init__( imitator: CategoricalPolicy, optim: Optimizer, beta: float, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, optim=optim, + checkpointer=checkpointer, device=device, ) self._imitator = imitator diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 338af1ee..1b5b5a5e 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -18,7 +18,7 @@ compute_vae_error, forward_vae_decode, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, train_api from .ddpg_impl import DDPGBaseImpl from .dqn_impl import DoubleDQNImpl @@ -54,6 +54,7 @@ def __init__( n_action_samples: int, action_flexibility: float, beta: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -68,6 +69,7 @@ def __init__( critic_optim=critic_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._lam = lam @@ -201,6 +203,7 @@ def __init__( gamma: float, action_flexibility: float, beta: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -212,6 +215,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=gamma, + checkpointer=checkpointer, device=device, ) self._action_flexibility = action_flexibility diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 2f6b5f4a..8a10aee6 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -15,7 +15,7 @@ compute_vae_error, forward_vae_sample_n, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import 
Checkpointer, TorchMiniBatch, train_api from .sac_impl import SACImpl __all__ = ["BEARImpl"] @@ -77,6 +77,7 @@ def __init__( mmd_kernel: str, mmd_sigma: float, vae_kl_weight: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -93,6 +94,7 @@ def __init__( temp_optim=temp_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._alpha_threshold = alpha_threshold diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 8f4e83ff..a232c5aa 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -14,7 +14,7 @@ Parameter, build_squashed_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, train_api from .dqn_impl import DoubleDQNImpl from .sac_impl import SACImpl @@ -50,6 +50,7 @@ def __init__( conservative_weight: float, n_action_samples: int, soft_q_backup: bool, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -66,6 +67,7 @@ def __init__( temp_optim=temp_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._alpha_threshold = alpha_threshold @@ -222,6 +224,7 @@ def __init__( optim: Optimizer, gamma: float, alpha: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -233,6 +236,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=gamma, + checkpointer=checkpointer, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 82e5e603..3adcb922 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -9,7 +9,7 @@ NormalPolicy, build_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch, hard_sync +from ....torch_utility import Checkpointer, TorchMiniBatch, hard_sync from .ddpg_impl import DDPGBaseImpl __all__ = ["CRRImpl"] @@ -42,6 +42,7 @@ def __init__( weight_type: str, max_weight: float, tau: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -56,6 +57,7 @@ def __init__( critic_optim=critic_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._beta = beta diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index e6a2d2fa..2add695b 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -8,7 +8,13 @@ from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy -from ....torch_utility import TorchMiniBatch, hard_sync, soft_sync, train_api +from ....torch_utility import ( + Checkpointer, + TorchMiniBatch, + hard_sync, + soft_sync, + train_api, +) from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin @@ -42,11 +48,13 @@ def __init__( critic_optim: Optimizer, gamma: float, tau: float, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, + checkpointer=checkpointer, device=device, ) self._gamma = gamma diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index 54d4cbf5..e894418d 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -6,7 +6,7 @@ from ....dataset import Shape from ....models.torch import DiscreteEnsembleQFunctionForwarder -from 
....torch_utility import TorchMiniBatch, hard_sync, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, hard_sync, train_api from ..base import QLearningAlgoImplBase from .utility import DiscreteQFunctionMixin @@ -31,11 +31,13 @@ def __init__( targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, optim: Optimizer, gamma: float, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, + checkpointer=checkpointer, device=device, ) self._gamma = gamma diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 80d8d806..aa39c95f 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -11,7 +11,7 @@ ValueFunction, build_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, train_api from .ddpg_impl import DDPGBaseImpl __all__ = ["IQLImpl"] @@ -41,6 +41,7 @@ def __init__( expectile: float, weight_temp: float, max_weight: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -55,6 +56,7 @@ def __init__( critic_optim=critic_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._expectile = expectile diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 817754b2..400d51ac 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -14,7 +14,7 @@ compute_vae_error, forward_vae_decode, ) -from ....torch_utility import TorchMiniBatch, soft_sync, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, soft_sync, train_api from .ddpg_impl import DDPGBaseImpl __all__ = ["PLASImpl", "PLASWithPerturbationImpl"] @@ -45,6 +45,7 @@ def __init__( tau: float, lam: float, beta: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -59,6 +60,7 @@ def __init__( critic_optim=critic_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._lam = lam @@ -136,6 +138,7 @@ def __init__( tau: float, lam: float, beta: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -154,6 +157,7 @@ def __init__( tau=tau, lam=lam, beta=beta, + checkpointer=checkpointer, device=device, ) self._perturbation = perturbation diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index ed1e4391..6a87b7f1 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -15,7 +15,7 @@ Policy, build_squashed_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch, hard_sync, train_api +from ....torch_utility import Checkpointer, TorchMiniBatch, hard_sync, train_api from ..base import QLearningAlgoImplBase from .ddpg_impl import DDPGBaseImpl from .utility import DiscreteQFunctionMixin @@ -42,6 +42,7 @@ def __init__( temp_optim: Optimizer, gamma: float, tau: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -56,6 +57,7 @@ def __init__( critic_optim=critic_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._log_temp = log_temp @@ -140,11 +142,13 @@ def __init__( critic_optim: Optimizer, temp_optim: Optimizer, gamma: float, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, + checkpointer=checkpointer, device=device, ) self._gamma = gamma 
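Note: because each impl now stores its Checkpointer, ImplBase.save_model and load_model reduce to checkpointer.save(f) and checkpointer.load(f), as the d3rlpy/base.py hunk further down shows. A minimal round-trip sketch using the save/load methods this series adds to d3rlpy/torch_utility.py; the layer sizes and key names are arbitrary:

    from io import BytesIO

    import torch

    from d3rlpy.torch_utility import Checkpointer

    fc = torch.nn.Linear(4, 4)
    optim = torch.optim.Adam(fc.parameters())
    checkpointer = Checkpointer(modules={"fc": fc, "optim": optim}, device="cpu:0")

    # save a single torch file holding one state_dict per registered key
    buf = BytesIO()
    checkpointer.save(buf)

    # load into freshly initialized modules registered under the same keys
    fc2 = torch.nn.Linear(4, 4)
    optim2 = torch.optim.Adam(fc2.parameters())
    restored = Checkpointer(modules={"fc": fc2, "optim": optim2}, device="cpu:0")
    restored.load(BytesIO(buf.getvalue()))

    assert torch.equal(fc.weight, fc2.weight)

This mirrors the test_checkpointer case added to tests/test_torch_utility.py later in this series.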
diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index dbd944a7..8868ecee 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -7,7 +7,7 @@ ContinuousEnsembleQFunctionForwarder, DeterministicPolicy, ) -from ....torch_utility import TorchMiniBatch +from ....torch_utility import Checkpointer, TorchMiniBatch from .ddpg_impl import DDPGImpl __all__ = ["TD3Impl"] @@ -32,6 +32,7 @@ def __init__( tau: float, target_smoothing_sigma: float, target_smoothing_clip: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -46,6 +47,7 @@ def __init__( critic_optim=critic_optim, gamma=gamma, tau=tau, + checkpointer=checkpointer, device=device, ) self._target_smoothing_sigma = target_smoothing_sigma diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 9cb75608..4ffdc5c0 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -9,7 +9,7 @@ ContinuousEnsembleQFunctionForwarder, DeterministicPolicy, ) -from ....torch_utility import TorchMiniBatch +from ....torch_utility import Checkpointer, TorchMiniBatch from .td3_impl import TD3Impl __all__ = ["TD3PlusBCImpl"] @@ -34,6 +34,7 @@ def __init__( target_smoothing_sigma: float, target_smoothing_clip: float, alpha: float, + checkpointer: Checkpointer, device: str, ): super().__init__( @@ -50,6 +51,7 @@ def __init__( tau=tau, target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, + checkpointer=checkpointer, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 2564b43b..ad34f362 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -13,7 +13,7 @@ make_optimizer_field, ) from ...models.builders import create_continuous_decision_transformer -from ...torch_utility import TorchTrajectoryMiniBatch +from ...torch_utility import Checkpointer, TorchTrajectoryMiniBatch from .base import TransformerAlgoBase, TransformerConfig from .torch.decision_transformer_impl import DecisionTransformerImpl @@ -113,6 +113,11 @@ def inner_create_impl( if self._config.compile: transformer = torch.compile(transformer, fullgraph=True) + checkpointer = Checkpointer( + modules={"transformer": transformer, "optim": optim}, + device=self._device, + ) + self._impl = DecisionTransformerImpl( observation_shape=observation_shape, action_size=action_size, @@ -120,6 +125,7 @@ def inner_create_impl( optim=optim, scheduler=scheduler, clip_grad_norm=self._config.clip_grad_norm, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index b1e4871a..2ea74ef5 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -5,7 +5,12 @@ from ....dataset import Shape from ....models.torch import ContinuousDecisionTransformer -from ....torch_utility import TorchTrajectoryMiniBatch, eval_api, train_api +from ....torch_utility import ( + Checkpointer, + TorchTrajectoryMiniBatch, + eval_api, + train_api, +) from ..base import TransformerAlgoImplBase from ..inputs import TorchTransformerInput @@ -26,11 +31,13 @@ def __init__( optim: Optimizer, scheduler: 
torch.optim.lr_scheduler.LambdaLR, clip_grad_norm: float, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, + checkpointer=checkpointer, device=device, ) self._transformer = transformer diff --git a/d3rlpy/base.py b/d3rlpy/base.py index c06a1e27..e72f9c35 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod from typing import BinaryIO, Generic, Optional, Type, TypeVar, Union -import torch from gym.spaces import Box, Discrete from gymnasium.spaces import Box as GymnasiumBox from gymnasium.spaces import Discrete as GymnasiumDiscrete @@ -24,7 +23,7 @@ make_reward_scaler_field, ) from .serializable_config import DynamicConfig, generate_config_registration -from .torch_utility import get_state_dict, map_location, set_state_dict +from .torch_utility import Checkpointer __all__ = [ "DeviceArg", @@ -49,24 +48,26 @@ class ImplBase(metaclass=ABCMeta): _observation_shape: Shape _action_size: int + _checkpointer: Checkpointer _device: str def __init__( self, observation_shape: Shape, action_size: int, + checkpointer: Checkpointer, device: str, ): self._observation_shape = observation_shape self._action_size = action_size + self._checkpointer = checkpointer self._device = device def save_model(self, f: BinaryIO) -> None: - torch.save(get_state_dict(self), f) + self._checkpointer.save(f) def load_model(self, f: BinaryIO) -> None: - chkpt = torch.load(f, map_location=map_location(self._device)) - set_state_dict(self, chkpt) + self._checkpointer.load(f) @property def observation_shape(self) -> Shape: diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index d6a196a2..a08b6c65 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -18,7 +18,7 @@ from ..models.encoders import EncoderFactory, make_encoder_field from ..models.optimizers import OptimizerFactory, make_optimizer_field from ..models.q_functions import QFunctionFactory, make_q_func_field -from ..torch_utility import TorchMiniBatch, convert_to_torch +from ..torch_utility import Checkpointer, TorchMiniBatch, convert_to_torch from .torch.fqe_impl import DiscreteFQEImpl, FQEBaseImpl, FQEImpl __all__ = ["FQEConfig", "FQE", "DiscreteFQE"] @@ -175,6 +175,16 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.parameters(), lr=self._config.learning_rate ) + + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "optim": optim, + }, + device=self._device, + ) + self._impl = FQEImpl( observation_shape=observation_shape, action_size=action_size, @@ -184,6 +194,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=self._config.gamma, + checkpointer=checkpointer, device=self._device, ) @@ -242,6 +253,14 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.parameters(), lr=self._config.learning_rate ) + checkpointer = Checkpointer( + modules={ + "q_func": q_funcs, + "targ_q_func": targ_q_funcs, + "optim": optim, + }, + device=self._device, + ) self._impl = DiscreteFQEImpl( observation_shape=observation_shape, action_size=action_size, @@ -251,6 +270,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, optim=optim, gamma=self._config.gamma, + checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index cab2f4f0..f1361371 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -14,7 +14,7 @@ 
ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, ) -from ...torch_utility import TorchMiniBatch, hard_sync, train_api +from ...torch_utility import Checkpointer, TorchMiniBatch, hard_sync, train_api __all__ = ["FQEBaseImpl", "FQEImpl", "DiscreteFQEImpl"] @@ -47,11 +47,13 @@ def __init__( ], optim: Optimizer, gamma: float, + checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, + checkpointer=checkpointer, device=device, ) self._gamma = gamma diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 11f280e1..6b58656b 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -1,6 +1,6 @@ import collections import dataclasses -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from typing import Any, BinaryIO, Dict, List, Optional, Sequence, TypeVar, Union import numpy as np import torch @@ -21,12 +21,11 @@ "to_device", "freeze", "unfreeze", - "get_state_dict", - "set_state_dict", "reset_optimizer_states", "map_location", "TorchMiniBatch", "TorchTrajectoryMiniBatch", + "Checkpointer", "convert_to_torch", "convert_to_torch_recursively", "eval_api", @@ -124,22 +123,6 @@ def unfreeze(impl: Any) -> None: p.requires_grad = True -def get_state_dict(impl: Any) -> Dict[str, Any]: - rets = {} - for key in _get_attributes(impl): - obj = getattr(impl, key) - if isinstance(obj, (torch.nn.Module, torch.optim.Optimizer)): - rets[key] = obj.state_dict() - return rets - - -def set_state_dict(impl: Any, chkpt: Dict[str, Any]) -> None: - for key in _get_attributes(impl): - obj = getattr(impl, key) - if isinstance(obj, (torch.nn.Module, torch.optim.Optimizer)): - obj.load_state_dict(chkpt[key]) - - def reset_optimizer_states(impl: Any) -> None: for key in _get_attributes(impl): obj = getattr(impl, key) @@ -283,6 +266,26 @@ def from_batch( ) +class Checkpointer: + _modules: Dict[str, Union[nn.Module, Optimizer]] + _device: str + + def __init__( + self, modules: Dict[str, Union[nn.Module, Optimizer]], device: str + ): + self._modules = modules + self._device = device + + def save(self, f: BinaryIO) -> None: + states = {k: v.state_dict() for k, v in self._modules.items()} + torch.save(states, f) + + def load(self, f: BinaryIO) -> None: + chkpt = torch.load(f, map_location=map_location(self._device)) + for k, v in self._modules.items(): + v.load_state_dict(chkpt[k]) + + TCallable = TypeVar("TCallable") diff --git a/tests/algos/qlearning/algo_test.py b/tests/algos/qlearning/algo_test.py index 6d259301..86a2018e 100644 --- a/tests/algos/qlearning/algo_test.py +++ b/tests/algos/qlearning/algo_test.py @@ -37,7 +37,12 @@ def algo_tester( load_learnable_tester(algo, observation_shape, action_size) predict_tester(algo, observation_shape, action_size) sample_action_tester(algo, observation_shape, action_size) - save_and_load_tester(algo, observation_shape, action_size) + save_and_load_tester( + algo, + observation_shape, + action_size, + deterministic_best_action=deterministic_best_action, + ) update_tester( algo, observation_shape, @@ -227,10 +232,25 @@ def save_and_load_tester( algo: QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig], observation_shape: Sequence[int], action_size: int, + deterministic_best_action: bool = True, ) -> None: algo.create_impl(observation_shape, action_size) algo.save_model(os.path.join("test_data", "model.pt")) - algo.load_model(os.path.join("test_data", "model.pt")) + + try: + algo2 = algo.config.create() + algo2.create_impl(observation_shape, 
action_size) + algo2.load_model(os.path.join("test_data", "model.pt")) + assert isinstance(algo2, QLearningAlgoBase) + + if deterministic_best_action: + observations = np.random.random((100, *observation_shape)) + action1 = algo.predict(observations) + action2 = algo2.predict(observations) + assert np.all(action1 == action2) + except NotImplementedError: + # check interface at least + algo.load_model(os.path.join("test_data", "model.pt")) def update_tester( diff --git a/tests/algos/transformer/algo_test.py b/tests/algos/transformer/algo_test.py index 5faa6e6f..f00805d1 100644 --- a/tests/algos/transformer/algo_test.py +++ b/tests/algos/transformer/algo_test.py @@ -120,7 +120,19 @@ def save_and_load_tester( ) -> None: algo.create_impl(observation_shape, action_size) algo.save_model(os.path.join("test_data", "model.pt")) - algo.load_model(os.path.join("test_data", "model.pt")) + + algo2 = algo.config.create() + algo2.create_impl(observation_shape, action_size) + algo2.load_model(os.path.join("test_data", "model.pt")) + assert isinstance(algo2, TransformerAlgoBase) + + actor1 = algo.as_stateful_wrapper(0) + actor2 = algo2.as_stateful_wrapper(0) + + observation = np.random.random(observation_shape) + action1 = actor1.predict(observation, 0) + action2 = actor2.predict(observation, 0) + assert np.all(action1 == action2) def update_tester( diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index 43f03bce..fb710851 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -1,4 +1,5 @@ import copy +from io import BytesIO from typing import Any, Dict, Sequence from unittest.mock import Mock @@ -8,18 +9,17 @@ from d3rlpy.dataset import TrajectoryMiniBatch, Transition, TransitionMiniBatch from d3rlpy.torch_utility import ( + Checkpointer, Swish, TorchMiniBatch, TorchTrajectoryMiniBatch, View, eval_api, freeze, - get_state_dict, hard_sync, map_location, reset_optimizer_states, set_eval_mode, - set_state_dict, set_train_mode, soft_sync, sync_optimizer_state, @@ -136,37 +136,6 @@ def check_if_same_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None: assert b[k] == v -def test_get_state_dict() -> None: - impl = DummyImpl() - - state_dict = get_state_dict(impl) - - check_if_same_dict(state_dict["fc1"], impl.fc1.state_dict()) - check_if_same_dict(state_dict["fc2"], impl.fc2.state_dict()) - check_if_same_dict(state_dict["optim"], impl.optim.state_dict()) - - -def test_set_state_dict() -> None: - impl1 = DummyImpl() - impl2 = DummyImpl() - - impl1.optim.step() - - assert not (impl1.fc1.weight == impl2.fc1.weight).all() - assert not (impl1.fc1.bias == impl2.fc1.bias).all() - assert not (impl1.fc2.weight == impl2.fc2.weight).all() - assert not (impl1.fc2.bias == impl2.fc2.bias).all() - - chkpt = get_state_dict(impl1) - - set_state_dict(impl2, chkpt) - - assert (impl1.fc1.weight == impl2.fc1.weight).all() - assert (impl1.fc1.bias == impl2.fc1.bias).all() - assert (impl1.fc2.weight == impl2.fc2.weight).all() - assert (impl1.fc2.bias == impl2.fc2.bias).all() - - def test_reset_optimizer_states() -> None: impl = DummyImpl() @@ -396,6 +365,48 @@ def test_torch_trajectory_mini_batch( assert np.all(torch_batch.terminals.numpy() == batch.terminals) +def test_checkpointer() -> None: + fc1 = torch.nn.Linear(100, 100) + fc2 = torch.nn.Linear(100, 100) + optim = torch.optim.Adam(fc1.parameters()) + checkpointer = Checkpointer( + modules={"fc1": fc1, "fc2": fc2, "optim": optim}, device="cpu:0" + ) + + # prepare reference bytes + ref_bytes = BytesIO() + states = { + "fc1": 
fc1.state_dict(), + "fc2": fc2.state_dict(), + "optim": optim.state_dict(), + } + torch.save(states, ref_bytes) + + # check saved bytes + saved_bytes = BytesIO() + checkpointer.save(saved_bytes) + assert ref_bytes.getvalue() == saved_bytes.getvalue() + + fc1_2 = torch.nn.Linear(100, 100) + fc2_2 = torch.nn.Linear(100, 100) + optim_2 = torch.optim.Adam(fc1_2.parameters()) + checkpointer = Checkpointer( + modules={"fc1": fc1_2, "fc2": fc2_2, "optim": optim_2}, device="cpu:0" + ) + + # check load + checkpointer.load(BytesIO(saved_bytes.getvalue())) + + # check output + x = torch.rand(32, 100) + y1_ref = fc1(x) + y2_ref = fc2(x) + y1 = fc1_2(x) + y2 = fc2_2(x) + assert torch.all(y1_ref == y1) + assert torch.all(y2_ref == y2) + + def test_train_api() -> None: impl = DummyImpl() impl.fc1.eval() From 817810d5308ac194a16f23cc2010ceb7124ba548 Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 14 Aug 2023 20:49:11 +0900 Subject: [PATCH 11/20] Refactor models with Modules --- d3rlpy/algos/qlearning/awac.py | 31 ++-- d3rlpy/algos/qlearning/base.py | 6 +- d3rlpy/algos/qlearning/bc.py | 26 ++-- d3rlpy/algos/qlearning/bcq.py | 64 ++++---- d3rlpy/algos/qlearning/bear.py | 37 ++--- d3rlpy/algos/qlearning/cql.py | 52 +++---- d3rlpy/algos/qlearning/crr.py | 33 ++--- d3rlpy/algos/qlearning/ddpg.py | 33 ++--- d3rlpy/algos/qlearning/dqn.py | 36 ++--- d3rlpy/algos/qlearning/iql.py | 29 ++-- d3rlpy/algos/qlearning/nfq.py | 20 +-- d3rlpy/algos/qlearning/plas.py | 88 +++++------ d3rlpy/algos/qlearning/sac.py | 66 ++++----- d3rlpy/algos/qlearning/td3.py | 32 ++-- d3rlpy/algos/qlearning/td3_plus_bc.py | 32 ++-- d3rlpy/algos/qlearning/torch/awac_impl.py | 36 ++--- d3rlpy/algos/qlearning/torch/bc_impl.py | 72 +++++---- d3rlpy/algos/qlearning/torch/bcq_impl.py | 83 +++++------ d3rlpy/algos/qlearning/torch/bear_impl.py | 82 +++++------ d3rlpy/algos/qlearning/torch/cql_impl.py | 72 ++++----- d3rlpy/algos/qlearning/torch/crr_impl.py | 52 +++---- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 101 ++++++++----- d3rlpy/algos/qlearning/torch/dqn_impl.py | 38 ++--- d3rlpy/algos/qlearning/torch/iql_impl.py | 50 +++---- d3rlpy/algos/qlearning/torch/plas_impl.py | 126 ++++++++-------- d3rlpy/algos/qlearning/torch/sac_impl.py | 137 ++++++++---------- d3rlpy/algos/qlearning/torch/td3_impl.py | 27 +--- .../algos/qlearning/torch/td3_plus_bc_impl.py | 26 +--- .../algos/transformer/decision_transformer.py | 17 ++- .../torch/decision_transformer_impl.py | 32 ++-- d3rlpy/base.py | 12 +- d3rlpy/dataclass_utils.py | 10 ++ d3rlpy/ope/fqe.py | 41 +++--- d3rlpy/ope/torch/fqe_impl.py | 34 ++--- d3rlpy/torch_utility.py | 76 +++++----- reproductions/finetuning/iql_finetune.py | 7 +- reproductions/offline/iql.py | 3 +- tests/test_dataclass_utils.py | 24 +++ tests/test_torch_utility.py | 69 ++++----- 39 files changed, 835 insertions(+), 977 deletions(-) create mode 100644 d3rlpy/dataclass_utils.py create mode 100644 tests/test_dataclass_utils.py diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index 9b6f4eef..f49f04a9 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -1,6 +1,8 @@ import dataclasses from typing import Dict +import torch + from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace from ...dataset import Shape @@ -11,9 +13,11 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import 
QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...models.torch import Parameter +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.awac_impl import AWACImpl +from .torch.sac_impl import SACModules __all__ = ["AWACConfig", "AWAC"] @@ -130,32 +134,27 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - }, - device=self._device, + dummy_log_temp = Parameter(torch.zeros(1)) + modules = SACModules( + policy=policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + log_temp=dummy_log_temp, + actor_optim=actor_optim, + critic_optim=critic_optim, + temp_optim=torch.optim.Adam(dummy_log_temp.parameters(), lr=0.0), ) self._impl = AWACImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - policy=policy, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=self._config.gamma, tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 64ec02c7..4033890c 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -43,11 +43,9 @@ convert_to_torch, convert_to_torch_recursively, eval_api, - freeze, hard_sync, reset_optimizer_states, sync_optimizer_state, - unfreeze, ) from ..utility import ( assert_action_space_with_dataset, @@ -196,7 +194,7 @@ def save_policy(self, fname: str) -> None: ) # workaround until version 1.6 - freeze(self._impl) + self._impl.modules.freeze() # dummy function to select best actions def _func(x: torch.Tensor) -> torch.Tensor: @@ -234,7 +232,7 @@ def _func(x: torch.Tensor) -> torch.Tensor: ) # workaround until version 1.6 - unfreeze(self._impl) + self._impl.modules.unfreeze() def predict(self, x: Observation) -> np.ndarray: """Returns greedy actions. 
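Note: from this patch onward the per-algorithm Modules containers (SACModules above, and BCModules, BCQModules, CQLModules and others below) bundle the networks and optimizers into a single dataclass, which is what lets base.py call self._impl.modules.freeze() in place of the removed freeze/unfreeze helpers. The actual base class lives in torch_utility.py and is not reproduced in these hunks; the sketch below is only a hypothetical stand-in whose freeze/unfreeze behaviour matches the requires_grad toggling of the old helpers:

    import dataclasses

    import torch
    from torch import nn
    from torch.optim import Optimizer


    @dataclasses.dataclass
    class ExampleModules:  # hypothetical stand-in, not the actual d3rlpy base class
        policy: nn.Module
        q_funcs: nn.Module
        actor_optim: Optimizer
        critic_optim: Optimizer

        def freeze(self) -> None:
            # disable gradients on every nn.Module field, as the old freeze() did
            for field in dataclasses.fields(self):
                obj = getattr(self, field.name)
                if isinstance(obj, nn.Module):
                    for p in obj.parameters():
                        p.requires_grad = False

        def unfreeze(self) -> None:
            for field in dataclasses.fields(self):
                obj = getattr(self, field.name)
                if isinstance(obj, nn.Module):
                    for p in obj.parameters():
                        p.requires_grad = True


    policy = nn.Linear(8, 2)
    q_funcs = nn.Linear(8, 1)
    modules = ExampleModules(
        policy=policy,
        q_funcs=q_funcs,
        actor_optim=torch.optim.Adam(policy.parameters()),
        critic_optim=torch.optim.Adam(q_funcs.parameters()),
    )
    modules.freeze()
    assert not any(p.requires_grad for p in policy.parameters())

Grouping the objects this way also shrinks every impl constructor to a single modules argument plus a few scalars, which is the bulk of the churn in the remaining hunks of this patch.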
diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index a37781aa..84c0321e 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -11,9 +11,15 @@ ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.bc_impl import BCBaseImpl, BCImpl, DiscreteBCImpl +from .torch.bc_impl import ( + BCBaseImpl, + BCImpl, + BCModules, + DiscreteBCImpl, + DiscreteBCModules, +) __all__ = ["BCConfig", "BC", "DiscreteBCConfig", "DiscreteBC"] @@ -97,16 +103,12 @@ def inner_create_impl( imitator.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={"imitator": imitator, "optim": optim}, device=self._device - ) + modules = BCModules(optim=optim, imitator=imitator) self._impl = BCImpl( observation_shape=observation_shape, action_size=action_size, - imitator=imitator, - optim=optim, - checkpointer=checkpointer, + modules=modules, device=self._device, ) @@ -171,17 +173,13 @@ def inner_create_impl( imitator.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={"imitator": imitator, "optim": optim}, device=self._device - ) + modules = DiscreteBCModules(optim=optim, imitator=imitator) self._impl = DiscreteBCImpl( observation_shape=observation_shape, action_size=action_size, - imitator=imitator, - optim=optim, + modules=modules, beta=self._config.beta, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 2ff3b3ef..c4b50a27 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -15,9 +15,14 @@ from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...models.torch import CategoricalPolicy, PixelEncoder, compute_output_size -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.bcq_impl import BCQImpl, DiscreteBCQImpl +from .torch.bcq_impl import ( + BCQImpl, + BCQModules, + DiscreteBCQImpl, + DiscreteBCQModules, +) __all__ = ["BCQConfig", "BCQ", "DiscreteBCQConfig", "DiscreteBCQ"] @@ -174,6 +179,13 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_deterministic_residual_policy( + observation_shape, + action_size, + self._config.action_flexibility, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -210,38 +222,29 @@ def inner_create_impl( imitator.parameters(), lr=self._config.imitator_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "imitator": imitator, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - "imitator_optim": imitator_optim, - }, - device=self._device, - ) - - self._impl = BCQImpl( - observation_shape=observation_shape, - action_size=action_size, + modules = BCQModules( policy=policy, + targ_policy=targ_policy, q_funcs=q_funcs, - q_func_forwarder=q_func_forwarder, targ_q_funcs=targ_q_funcs, - targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, actor_optim=actor_optim, critic_optim=critic_optim, 
imitator_optim=imitator_optim, + ) + + self._impl = BCQImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, action_flexibility=self._config.action_flexibility, beta=self._config.beta, - checkpointer=checkpointer, device=self._device, ) @@ -398,29 +401,22 @@ def inner_create_impl( unique_params, lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "imitator": imitator, - "optim": optim, - }, - device=self._device, + modules = DiscreteBCQModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + imitator=imitator, + optim=optim, ) self._impl = DiscreteBCQImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - imitator=imitator, - optim=optim, gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index fb1a0b08..5abe7e15 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -14,9 +14,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.bear_impl import BEARImpl +from .torch.bear_impl import BEARImpl, BEARModules __all__ = ["BEARConfig", "BEAR"] @@ -214,31 +214,10 @@ def inner_create_impl( log_alpha.parameters(), lr=self._config.actor_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "imitator": imitator, - "log_temp": log_temp, - "log_alpha": log_alpha, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - "imitator_optim": imitator_optim, - "temp_optim": temp_optim, - "alpha_optim": alpha_optim, - }, - device=self._device, - ) - - self._impl = BEARImpl( - observation_shape=observation_shape, - action_size=action_size, + modules = BEARModules( policy=policy, q_funcs=q_funcs, - q_func_forwarder=q_func_forwarder, targ_q_funcs=targ_q_funcs, - targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, log_temp=log_temp, log_alpha=log_alpha, @@ -247,6 +226,14 @@ def inner_create_impl( imitator_optim=imitator_optim, temp_optim=temp_optim, alpha_optim=alpha_optim, + ) + + self._impl = BEARImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, alpha_threshold=self._config.alpha_threshold, @@ -257,7 +244,6 @@ def inner_create_impl( mmd_kernel=self._config.mmd_kernel, mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, - checkpointer=checkpointer, device=self._device, ) @@ -284,7 +270,6 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: actor_loss = self._impl.update_actor(batch) metrics.update(actor_loss) - 
self._impl.update_actor_target() self._impl.update_critic_target() return metrics diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index aa36b668..de04b672 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -14,9 +14,10 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.cql_impl import CQLImpl, DiscreteCQLImpl +from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLImpl +from .torch.dqn_impl import DQNModules __all__ = ["CQLConfig", "CQL", "DiscreteCQLConfig", "DiscreteCQL"] @@ -179,42 +180,30 @@ def inner_create_impl( log_alpha.parameters(), lr=self._config.alpha_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "log_temp": log_temp, - "log_alpha": log_alpha, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - "temp_optim": temp_optim, - "alpha_optim": alpha_optim, - }, - device=self._device, - ) - - self._impl = CQLImpl( - observation_shape=observation_shape, - action_size=action_size, + modules = CQLModules( policy=policy, q_funcs=q_funcs, - q_func_forwarder=q_func_fowarder, targ_q_funcs=targ_q_funcs, - targ_q_func_forwarder=targ_q_func_forwarder, log_temp=log_temp, log_alpha=log_alpha, actor_optim=actor_optim, critic_optim=critic_optim, temp_optim=temp_optim, alpha_optim=alpha_optim, + ) + + self._impl = CQLImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_fowarder, + targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, alpha_threshold=self._config.alpha_threshold, conservative_weight=self._config.conservative_weight, n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, - checkpointer=checkpointer, device=self._device, ) @@ -235,7 +224,6 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: metrics.update(self._impl.update_actor(batch)) self._impl.update_critic_target() - self._impl.update_actor_target() return metrics @@ -327,26 +315,20 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "optim": optim, - }, - device=self._device, + modules = DQNModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + optim=optim, ) self._impl = DiscreteCQLImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - optim=optim, gamma=self._config.gamma, alpha=self._config.alpha, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index ae671ea0..840eded2 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -11,9 +11,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from 
...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.crr_impl import CRRImpl +from .torch.crr_impl import CRRImpl, CRRModules __all__ = ["CRRConfig", "CRR"] @@ -140,6 +140,12 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_normal_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -164,27 +170,21 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - }, - device=self._device, + modules = CRRModules( + policy=policy, + targ_policy=targ_policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + actor_optim=actor_optim, + critic_optim=critic_optim, ) self._impl = CRRImpl( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=self._config.gamma, beta=self._config.beta, n_action_samples=self._config.n_action_samples, @@ -192,7 +192,6 @@ def inner_create_impl( weight_type=self._config.weight_type, max_weight=self._config.max_weight, tau=self._config.tau, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 01b73c96..0fe94cf0 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -11,9 +11,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.ddpg_impl import DDPGImpl +from .torch.ddpg_impl import DDPGImpl, DDPGModules __all__ = ["DDPGConfig", "DDPG"] @@ -101,6 +101,12 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_deterministic_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -125,30 +131,23 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - }, - device=self._device, + modules = DDPGModules( + policy=policy, + targ_policy=targ_policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + actor_optim=actor_optim, + critic_optim=critic_optim, ) self._impl = DDPGImpl( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=self._config.gamma, tau=self._config.tau, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py 
b/d3rlpy/algos/qlearning/dqn.py index 21832152..95894d92 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -8,9 +8,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.dqn_impl import DoubleDQNImpl, DQNImpl +from .torch.dqn_impl import DoubleDQNImpl, DQNImpl, DQNModules __all__ = ["DQNConfig", "DQN", "DoubleDQNConfig", "DoubleDQN"] @@ -89,25 +89,19 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "optim": optim, - }, - device=self._device, + modules = DQNModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + optim=optim, ) self._impl = DQNImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, - targ_q_funcs=targ_q_funcs, q_func_forwarder=forwarder, targ_q_func_forwarder=targ_forwarder, - optim=optim, + modules=modules, gamma=self._config.gamma, - checkpointer=checkpointer, device=self._device, ) @@ -203,25 +197,19 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "optim": optim, - }, - device=self._device, + modules = DQNModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + optim=optim, ) self._impl = DoubleDQNImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, - targ_q_funcs=targ_q_funcs, + modules=modules, q_func_forwarder=forwarder, targ_q_func_forwarder=targ_forwarder, - optim=optim, gamma=self._config.gamma, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index fa7ad2b9..63f037ab 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -12,9 +12,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import MeanQFunctionFactory -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.iql_impl import IQLImpl +from .torch.iql_impl import IQLImpl, IQLModules __all__ = ["IQLConfig", "IQL"] @@ -150,35 +150,26 @@ def inner_create_impl( q_func_params + v_func_params, lr=self._config.critic_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "value_func": value_func, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - }, - device=self._device, + modules = IQLModules( + policy=policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + value_func=value_func, + actor_optim=actor_optim, + critic_optim=critic_optim, ) self._impl = IQLImpl( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - value_func=value_func, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=self._config.gamma, tau=self._config.tau, expectile=self._config.expectile, 
weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 2f00d4d1..b5c76a68 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -8,9 +8,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.dqn_impl import DQNImpl +from .torch.dqn_impl import DQNImpl, DQNModules __all__ = ["NFQConfig", "NFQ"] @@ -91,25 +91,19 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "optim": optim, - }, - device=self._device, + modules = DQNModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + optim=optim, ) self._impl = DQNImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - optim=optim, gamma=self._config.gamma, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 539ee4bb..629f546a 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -13,9 +13,14 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.plas_impl import PLASImpl, PLASWithPerturbationImpl +from .torch.plas_impl import ( + PLASImpl, + PLASModules, + PLASWithPerturbationImpl, + PLASWithPerturbationModules, +) __all__ = [ "PLASConfig", @@ -112,6 +117,12 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_deterministic_policy( + observation_shape, + 2 * action_size, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -148,36 +159,27 @@ def inner_create_impl( imitator.parameters(), lr=self._config.imitator_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "imitator": imitator, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - "imitator_optim": imitator_optim, - }, - device=self._device, - ) - - self._impl = PLASImpl( - observation_shape=observation_shape, - action_size=action_size, + modules = PLASModules( policy=policy, + targ_policy=targ_policy, q_funcs=q_funcs, - q_func_forwarder=q_func_forwarder, targ_q_funcs=targ_q_funcs, - targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, actor_optim=actor_optim, critic_optim=critic_optim, imitator_optim=imitator_optim, + ) + + self._impl = PLASImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, 
tau=self._config.tau, lam=self._config.lam, beta=self._config.beta, - checkpointer=checkpointer, device=self._device, ) @@ -267,6 +269,12 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_deterministic_policy( + observation_shape, + 2 * action_size, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -299,6 +307,13 @@ def inner_create_impl( encoder_factory=self._config.actor_encoder_factory, device=self._device, ) + targ_perturbation = create_deterministic_residual_policy( + observation_shape=observation_shape, + action_size=action_size, + scale=self._config.action_flexibility, + encoder_factory=self._config.actor_encoder_factory, + device=self._device, + ) parameters = list(policy.parameters()) parameters += list(perturbation.parameters()) @@ -312,38 +327,29 @@ def inner_create_impl( imitator.parameters(), lr=self._config.imitator_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "imitator": imitator, - "perturbation": perturbation, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - "imitator_optim": imitator_optim, - }, - device=self._device, - ) - - self._impl = PLASWithPerturbationImpl( - observation_shape=observation_shape, - action_size=action_size, + modules = PLASWithPerturbationModules( policy=policy, + targ_policy=targ_policy, q_funcs=q_funcs, - q_func_forwarder=q_func_forwarder, targ_q_funcs=targ_q_funcs, - targ_q_func_forwarder=targ_q_func_forwarder, imitator=imitator, perturbation=perturbation, + targ_perturbation=targ_perturbation, actor_optim=actor_optim, critic_optim=critic_optim, imitator_optim=imitator_optim, + ) + + self._impl = PLASWithPerturbationImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, lam=self._config.lam, beta=self._config.beta, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 797eb7ef..0748d3c0 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -15,9 +15,14 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase -from .torch.sac_impl import DiscreteSACImpl, SACImpl +from .torch.sac_impl import ( + DiscreteSACImpl, + DiscreteSACModules, + SACImpl, + SACModules, +) __all__ = ["SACConfig", "SAC", "DiscreteSACConfig", "DiscreteSAC"] @@ -157,34 +162,24 @@ def inner_create_impl( log_temp.parameters(), lr=self._config.temp_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "log_temp": log_temp, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - "temp_optim": temp_optim, - }, - device=self._device, - ) - - self._impl = SACImpl( - observation_shape=observation_shape, - action_size=action_size, + modules = SACModules( policy=policy, q_funcs=q_funcs, - q_func_forwarder=q_func_forwarder, targ_q_funcs=targ_q_funcs, - 
targ_q_func_forwarder=targ_q_func_forwarder, log_temp=log_temp, actor_optim=actor_optim, critic_optim=critic_optim, temp_optim=temp_optim, + ) + + self._impl = SACImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - checkpointer=checkpointer, device=self._device, ) @@ -200,7 +195,6 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: metrics.update(self._impl.update_critic(batch)) metrics.update(self._impl.update_actor(batch)) self._impl.update_critic_target() - self._impl.update_actor_target() return metrics @@ -328,33 +322,23 @@ def inner_create_impl( log_temp.parameters(), lr=self._config.temp_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "log_temp": log_temp, - "critic_optim": critic_optim, - "actor_optim": actor_optim, - "temp_optim": temp_optim, - }, - device=self._device, + modules = DiscreteSACModules( + policy=policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + log_temp=log_temp, + actor_optim=actor_optim, + critic_optim=critic_optim, + temp_optim=temp_optim, ) self._impl = DiscreteSACImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - policy=policy, - log_temp=log_temp, - actor_optim=actor_optim, - critic_optim=critic_optim, - temp_optim=temp_optim, gamma=self._config.gamma, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 81d8522b..eb92b0bb 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -11,8 +11,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase +from .torch.ddpg_impl import DDPGModules from .torch.td3_impl import TD3Impl __all__ = ["TD3Config", "TD3"] @@ -109,6 +110,12 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_deterministic_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -133,32 +140,25 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - }, - device=self._device, + modules = DDPGModules( + policy=policy, + targ_policy=targ_policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + actor_optim=actor_optim, + critic_optim=critic_optim, ) self._impl = TD3Impl( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=self._config.gamma, tau=self._config.tau, 
target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 49b3ecb2..590fda2a 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -11,8 +11,9 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import Checkpointer, TorchMiniBatch +from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase +from .torch.ddpg_impl import DDPGModules from .torch.td3_plus_bc_impl import TD3PlusBCImpl __all__ = ["TD3PlusBCConfig", "TD3PlusBC"] @@ -101,6 +102,12 @@ def inner_create_impl( self._config.actor_encoder_factory, device=self._device, ) + targ_policy = create_deterministic_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -125,33 +132,26 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - checkpointer = Checkpointer( - modules={ - "policy": policy, - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "actor_optim": actor_optim, - "critic_optim": critic_optim, - }, - device=self._device, + modules = DDPGModules( + policy=policy, + targ_policy=targ_policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + actor_optim=actor_optim, + critic_optim=critic_optim, ) self._impl = TD3PlusBCImpl( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=self._config.gamma, tau=self._config.tau, target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index 65e2565f..81646adc 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -1,24 +1,18 @@ import torch import torch.nn.functional as F -from torch import nn -from torch.optim import Adam, Optimizer from ....dataset import Shape from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, - NormalPolicy, - Parameter, - Policy, build_gaussian_distribution, ) -from ....torch_utility import Checkpointer, TorchMiniBatch -from .sac_impl import SACImpl +from ....torch_utility import TorchMiniBatch +from .sac_impl import SACImpl, SACModules __all__ = ["AWACImpl"] class AWACImpl(SACImpl): - _policy: NormalPolicy _lam: float _n_action_samples: int @@ -26,37 +20,23 @@ def __init__( self, observation_shape: Shape, action_size: int, - q_funcs: nn.ModuleList, + modules: SACModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - policy: Policy, - actor_optim: Optimizer, - critic_optim: Optimizer, gamma: float, tau: float, lam: float, n_action_samples: int, - checkpointer: Checkpointer, device: str, ): - assert 
isinstance(policy, NormalPolicy) - dummy_log_temp = Parameter(torch.zeros(1)) super().__init__( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - policy=policy, - actor_optim=actor_optim, - critic_optim=critic_optim, - log_temp=dummy_log_temp, - temp_optim=Adam(dummy_log_temp.parameters(), lr=0.0), gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._lam = lam @@ -64,7 +44,9 @@ def __init__( def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: # compute log probability - dist = build_gaussian_distribution(self._policy(batch.observations)) + dist = build_gaussian_distribution( + self._modules.policy(batch.observations) + ) log_probs = dist.log_prob(batch.actions) # compute exponential weight @@ -85,7 +67,7 @@ def _compute_weights( # sample actions # (batch_size * N, action_size) - dist = build_gaussian_distribution(self._policy(obs_t)) + dist = build_gaussian_distribution(self._modules.policy(obs_t)) policy_actions = dist.sample_n(self._n_action_samples) flat_actions = policy_actions.reshape(-1, self.action_size) @@ -113,5 +95,5 @@ def _compute_weights( return weights * adv_values.numel() def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: - dist = build_gaussian_distribution(self._policy(x)) + dist = build_gaussian_distribution(self._modules.policy(x)) return dist.sample() diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index ff72fe2e..3661acbd 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -1,3 +1,4 @@ +import dataclasses from abc import ABCMeta, abstractmethod from typing import Union @@ -14,40 +15,42 @@ compute_discrete_imitation_loss, compute_stochastic_imitation_loss, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, train_api +from ....torch_utility import Modules, TorchMiniBatch, train_api from ..base import QLearningAlgoImplBase -__all__ = ["BCImpl", "DiscreteBCImpl"] +__all__ = ["BCImpl", "DiscreteBCImpl", "BCModules", "DiscreteBCModules"] + + +@dataclasses.dataclass(frozen=True) +class BCBaseModules(Modules): + optim: Optimizer class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta): - _learning_rate: float - _optim: Optimizer + _modules: BCBaseModules def __init__( self, observation_shape: Shape, action_size: int, - optim: Optimizer, - checkpointer: Checkpointer, + modules: BCBaseModules, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - checkpointer=checkpointer, + modules=modules, device=device, ) - self._optim = optim @train_api def update_imitator(self, batch: TorchMiniBatch) -> float: - self._optim.zero_grad() + self._modules.optim.zero_grad() loss = self.compute_loss(batch.observations, batch.actions) loss.backward() - self._optim.step() + self._modules.optim.step() return float(loss.cpu().detach().numpy()) @@ -66,81 +69,86 @@ def inner_predict_value( raise NotImplementedError("BC does not support value estimation") +@dataclasses.dataclass(frozen=True) +class BCModules(BCBaseModules): + imitator: Union[DeterministicPolicy, NormalPolicy] + + class BCImpl(BCBaseImpl): - _imitator: Union[DeterministicPolicy, NormalPolicy] + _modules: BCModules def __init__( self, observation_shape: Shape, action_size: int, - imitator: Union[DeterministicPolicy, NormalPolicy], - optim: Optimizer, - checkpointer: 
Checkpointer, + modules: BCModules, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - optim=optim, - checkpointer=checkpointer, + modules=modules, device=device, ) - self._imitator = imitator def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._imitator(x).squashed_mu + return self._modules.imitator(x).squashed_mu def compute_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor ) -> torch.Tensor: - if isinstance(self._imitator, DeterministicPolicy): + if isinstance(self._modules.imitator, DeterministicPolicy): return compute_deterministic_imitation_loss( - self._imitator, obs_t, act_t + self._modules.imitator, obs_t, act_t ) else: return compute_stochastic_imitation_loss( - self._imitator, obs_t, act_t + self._modules.imitator, obs_t, act_t ) @property def policy(self) -> Policy: - return self._imitator + return self._modules.imitator @property def policy_optim(self) -> Optimizer: - return self._optim + return self._modules.optim + + +@dataclasses.dataclass(frozen=True) +class DiscreteBCModules(BCBaseModules): + imitator: CategoricalPolicy class DiscreteBCImpl(BCBaseImpl): + _modules: DiscreteBCModules _beta: float - _imitator: CategoricalPolicy def __init__( self, observation_shape: Shape, action_size: int, - imitator: CategoricalPolicy, - optim: Optimizer, + modules: DiscreteBCModules, beta: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - optim=optim, - checkpointer=checkpointer, + modules=modules, device=device, ) - self._imitator = imitator self._beta = beta def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._imitator(x).logits.argmax(dim=1) + return self._modules.imitator(x).logits.argmax(dim=1) def compute_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor ) -> torch.Tensor: return compute_discrete_imitation_loss( - policy=self._imitator, x=obs_t, action=act_t.long(), beta=self._beta + policy=self._modules.imitator, + x=obs_t, + action=act_t.long(), + beta=self._beta, ) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 1b5b5a5e..997257e8 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -1,9 +1,9 @@ +import dataclasses import math from typing import Dict, cast import torch import torch.nn.functional as F -from torch import nn from torch.optim import Optimizer from ....dataset import Shape @@ -18,66 +18,57 @@ compute_vae_error, forward_vae_decode, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, train_api -from .ddpg_impl import DDPGBaseImpl -from .dqn_impl import DoubleDQNImpl +from ....torch_utility import TorchMiniBatch, soft_sync, train_api +from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules +from .dqn_impl import DoubleDQNImpl, DQNModules -__all__ = ["BCQImpl", "DiscreteBCQImpl"] +__all__ = ["BCQImpl", "DiscreteBCQImpl", "BCQModules", "DiscreteBCQModules"] + + +@dataclasses.dataclass(frozen=True) +class BCQModules(DDPGBaseModules): + policy: DeterministicResidualPolicy + targ_policy: DeterministicResidualPolicy + imitator: ConditionalVAE + imitator_optim: Optimizer class BCQImpl(DDPGBaseImpl): + _modules: BCQModules _lam: float _n_action_samples: int _action_flexibility: float _beta: float - _policy: DeterministicResidualPolicy - _targ_policy: DeterministicResidualPolicy - _imitator: ConditionalVAE - _imitator_optim: Optimizer def __init__( self, 
observation_shape: Shape, action_size: int, - policy: DeterministicResidualPolicy, - q_funcs: nn.ModuleList, + modules: BCQModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - imitator: ConditionalVAE, - actor_optim: Optimizer, - critic_optim: Optimizer, - imitator_optim: Optimizer, gamma: float, tau: float, lam: float, n_action_samples: int, action_flexibility: float, beta: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._lam = lam self._n_action_samples = n_action_samples self._action_flexibility = action_flexibility self._beta = beta - self._imitator = imitator - self._imitator_optim = imitator_optim def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: latent = torch.randn( @@ -87,11 +78,11 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: ) clipped_latent = latent.clamp(-0.5, 0.5) sampled_action = forward_vae_decode( - vae=self._imitator, + vae=self._modules.imitator, x=batch.observations, latent=clipped_latent, ) - action = self._policy(batch.observations, sampled_action) + action = self._modules.policy(batch.observations, sampled_action) value = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" ) @@ -99,17 +90,17 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._imitator_optim.zero_grad() + self._modules.imitator_optim.zero_grad() loss = compute_vae_error( - vae=self._imitator, + vae=self._modules.imitator, x=batch.observations, action=batch.actions, beta=self._beta, ) loss.backward() - self._imitator_optim.step() + self._modules.imitator_optim.step() return {"imitator_loss": float(loss.cpu().detach().numpy())} @@ -131,12 +122,12 @@ def _sample_repeated_action( clipped_latent = latent.clamp(-0.5, 0.5) # sample action sampled_action = forward_vae_decode( - vae=self._imitator, + vae=self._modules.imitator, x=flattened_x, latent=clipped_latent, ) # add residual action - policy = self._targ_policy if target else self._policy + policy = self._modules.targ_policy if target else self._modules.policy action = policy(flattened_x, sampled_action) return action.squashed_mu.view( -1, self._n_action_samples, self._action_size @@ -174,60 +165,58 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): repeated_x = self._repeat_observation(batch.next_observations) actions = self._sample_repeated_action(repeated_x, True) - values = compute_max_with_n_actions( batch.next_observations, actions, self._targ_q_func_forwarder, self._lam, ) - return values + def update_actor_target(self) -> None: + soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) + + +@dataclasses.dataclass(frozen=True) +class DiscreteBCQModules(DQNModules): + imitator: CategoricalPolicy + class DiscreteBCQImpl(DoubleDQNImpl): + _modules: DiscreteBCQModules _action_flexibility: float _beta: float - _imitator: CategoricalPolicy def __init__( self, observation_shape: Shape, action_size: int, - q_funcs: nn.ModuleList, + modules: 
DiscreteBCQModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - imitator: CategoricalPolicy, - optim: Optimizer, gamma: float, action_flexibility: float, beta: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - optim=optim, gamma=gamma, - checkpointer=checkpointer, device=device, ) self._action_flexibility = action_flexibility self._beta = beta - self._imitator = imitator def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> torch.Tensor: loss = super().compute_loss(batch, q_tpn) imitator_loss = compute_discrete_imitation_loss( - policy=self._imitator, + policy=self._modules.imitator, x=batch.observations, action=batch.actions.long(), beta=self._beta, @@ -235,7 +224,7 @@ def compute_loss( return loss + imitator_loss def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - dist = self._imitator(x) + dist = self._modules.imitator(x) log_probs = F.log_softmax(dist.logits, dim=1) ratio = log_probs - log_probs.max(dim=1, keepdim=True).values mask = (ratio > math.log(self._action_flexibility)).float() diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 8a10aee6..10b83d99 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -1,24 +1,23 @@ +import dataclasses from typing import Dict import torch -from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( ConditionalVAE, ContinuousEnsembleQFunctionForwarder, - NormalPolicy, Parameter, build_squashed_gaussian_distribution, compute_max_with_n_actions_and_indices, compute_vae_error, forward_vae_sample_n, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, train_api -from .sac_impl import SACImpl +from ....torch_utility import TorchMiniBatch, train_api +from .sac_impl import SACImpl, SACModules -__all__ = ["BEARImpl"] +__all__ = ["BEARImpl", "BEARModules"] def _gaussian_kernel( @@ -35,8 +34,16 @@ def _laplacian_kernel( return (-(x - y).abs().sum(dim=3) / (2 * sigma)).exp() +@dataclasses.dataclass(frozen=True) +class BEARModules(SACModules): + imitator: ConditionalVAE + log_alpha: Parameter + imitator_optim: Optimizer + alpha_optim: Optimizer + + class BEARImpl(SACImpl): - _policy: NormalPolicy + _modules: BEARModules _alpha_threshold: float _lam: float _n_action_samples: int @@ -45,28 +52,14 @@ class BEARImpl(SACImpl): _mmd_kernel: str _mmd_sigma: float _vae_kl_weight: float - _imitator: ConditionalVAE - _imitator_optim: Optimizer - _log_alpha: Parameter - _alpha_optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - policy: NormalPolicy, - q_funcs: nn.ModuleList, + modules: BEARModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - imitator: ConditionalVAE, - log_temp: Parameter, - log_alpha: Parameter, - actor_optim: Optimizer, - critic_optim: Optimizer, - imitator_optim: Optimizer, - temp_optim: Optimizer, - alpha_optim: Optimizer, gamma: float, tau: float, alpha_threshold: float, @@ -77,24 +70,16 @@ def __init__( mmd_kernel: str, mmd_sigma: float, vae_kl_weight: float, - checkpointer: 
Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - log_temp=log_temp, - actor_optim=actor_optim, - critic_optim=critic_optim, - temp_optim=temp_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._alpha_threshold = alpha_threshold @@ -105,10 +90,6 @@ def __init__( self._mmd_kernel = mmd_kernel self._mmd_sigma = mmd_sigma self._vae_kl_weight = vae_kl_weight - self._imitator = imitator - self._log_alpha = log_alpha - self._imitator_optim = imitator_optim - self._alpha_optim = alpha_optim def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: loss = super().compute_actor_loss(batch) @@ -117,35 +98,35 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: @train_api def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._actor_optim.zero_grad() + self._modules.actor_optim.zero_grad() loss = self._compute_mmd_loss(batch.observations) loss.backward() - self._actor_optim.step() + self._modules.actor_optim.step() return {"actor_loss": float(loss.cpu().detach().numpy())} def _compute_mmd_loss(self, obs_t: torch.Tensor) -> torch.Tensor: mmd = self._compute_mmd(obs_t) - alpha = self._log_alpha().exp() + alpha = self._modules.log_alpha().exp() return (alpha * (mmd - self._alpha_threshold)).mean() @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._imitator_optim.zero_grad() + self._modules.imitator_optim.zero_grad() loss = self.compute_imitator_loss(batch) loss.backward() - self._imitator_optim.step() + self._modules.imitator_optim.step() return {"imitator_loss": float(loss.cpu().detach().numpy())} def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: return compute_vae_error( - vae=self._imitator, + vae=self._modules.imitator, x=batch.observations, action=batch.actions, beta=self._vae_kl_weight, @@ -155,14 +136,14 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: loss = -self._compute_mmd_loss(batch.observations) - self._alpha_optim.zero_grad() + self._modules.alpha_optim.zero_grad() loss.backward() - self._alpha_optim.step() + self._modules.alpha_optim.step() # clip for stability - self._log_alpha.data.clamp_(-5.0, 10.0) + self._modules.log_alpha.data.clamp_(-5.0, 10.0) - cur_alpha = self._log_alpha().exp().cpu().detach().numpy()[0][0] + cur_alpha = self._modules.log_alpha().exp().cpu().detach().numpy()[0][0] return { "alpha_loss": float(loss.cpu().detach().numpy()), @@ -172,9 +153,12 @@ def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: def _compute_mmd(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): behavior_actions = forward_vae_sample_n( - self._imitator, x, self._n_mmd_action_samples, with_squash=False + self._modules.imitator, + x, + self._n_mmd_action_samples, + with_squash=False, ) - dist = build_squashed_gaussian_distribution(self._policy(x)) + dist = build_squashed_gaussian_distribution(self._modules.policy(x)) policy_actions = dist.sample_n_without_squash( self._n_mmd_action_samples ) @@ -221,7 +205,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): # BCQ-like target computation dist = build_squashed_gaussian_distribution( - self._policy(batch.next_observations) + 
self._modules.policy(batch.next_observations) ) actions, log_probs = dist.sample_n_with_log_prob( self._n_target_samples @@ -237,12 +221,12 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: batch_size = batch.observations.shape[0] max_log_prob = log_probs[torch.arange(batch_size), indices] - return values - self._log_temp().exp() * max_log_prob + return values - self._modules.log_temp().exp() * max_log_prob def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): # (batch, n, action) - dist = build_squashed_gaussian_distribution(self._policy(x)) + dist = build_squashed_gaussian_distribution(self._modules.policy(x)) actions = dist.onnx_safe_sample_n(self._n_action_samples) # (batch, n, action) -> (batch * n, action) flat_actions = actions.reshape(-1, self._action_size) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index a232c5aa..39ab82a1 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -1,81 +1,67 @@ +import dataclasses import math from typing import Dict import torch import torch.nn.functional as F -from torch import nn from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, - NormalPolicy, Parameter, build_squashed_gaussian_distribution, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, train_api -from .dqn_impl import DoubleDQNImpl -from .sac_impl import SACImpl +from ....torch_utility import TorchMiniBatch, train_api +from .dqn_impl import DoubleDQNImpl, DQNModules +from .sac_impl import SACImpl, SACModules -__all__ = ["CQLImpl", "DiscreteCQLImpl"] +__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules"] + + +@dataclasses.dataclass(frozen=True) +class CQLModules(SACModules): + log_alpha: Parameter + alpha_optim: Optimizer class CQLImpl(SACImpl): + _modules: CQLModules _alpha_threshold: float _conservative_weight: float _n_action_samples: int _soft_q_backup: bool - _log_alpha: Parameter - _alpha_optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - policy: NormalPolicy, - q_funcs: nn.ModuleList, + modules: CQLModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - log_temp: Parameter, - log_alpha: Parameter, - actor_optim: Optimizer, - critic_optim: Optimizer, - temp_optim: Optimizer, - alpha_optim: Optimizer, gamma: float, tau: float, alpha_threshold: float, conservative_weight: float, n_action_samples: int, soft_q_backup: bool, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - log_temp=log_temp, - actor_optim=actor_optim, - critic_optim=critic_optim, - temp_optim=temp_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._alpha_threshold = alpha_threshold self._conservative_weight = conservative_weight self._n_action_samples = n_action_samples self._soft_q_backup = soft_q_backup - self._log_alpha = log_alpha - self._alpha_optim = alpha_optim def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor @@ -89,9 +75,9 @@ def compute_critic_loss( @train_api def update_alpha(self, batch: TorchMiniBatch) 
-> Dict[str, float]: # Q function should be inference mode for stability - self._q_funcs.eval() + self._modules.q_funcs.eval() - self._alpha_optim.zero_grad() + self._modules.alpha_optim.zero_grad() # the original implementation does scale the loss value loss = -self._compute_conservative_loss( @@ -99,9 +85,9 @@ def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: ) loss.backward() - self._alpha_optim.step() + self._modules.alpha_optim.step() - cur_alpha = self._log_alpha().exp().cpu().detach().numpy()[0][0] + cur_alpha = self._modules.log_alpha().exp().cpu().detach().numpy()[0][0] return { "alpha_loss": float(loss.cpu().detach().numpy()), @@ -113,7 +99,7 @@ def _compute_policy_is_values( ) -> torch.Tensor: with torch.no_grad(): dist = build_squashed_gaussian_distribution( - self._policy(policy_obs) + self._modules.policy(policy_obs) ) policy_actions, n_log_probs = dist.sample_n_with_log_prob( self._n_action_samples @@ -187,7 +173,7 @@ def _compute_conservative_loss( scaled_loss = self._conservative_weight * loss # clip for stability - clipped_alpha = self._log_alpha().exp().clamp(0, 1e6)[0][0] + clipped_alpha = self._modules.log_alpha().exp().clamp(0, 1e6)[0][0] return clipped_alpha * (scaled_loss - self._alpha_threshold) @@ -202,7 +188,7 @@ def _compute_deterministic_target( self, batch: TorchMiniBatch ) -> torch.Tensor: with torch.no_grad(): - action = self._policy(batch.next_observations).squashed_mu + action = self._modules.policy(batch.next_observations).squashed_mu return self._targ_q_func_forwarder.compute_target( batch.next_observations, action, @@ -217,26 +203,20 @@ def __init__( self, observation_shape: Shape, action_size: int, - q_funcs: nn.ModuleList, + modules: DQNModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - optim: Optimizer, gamma: float, alpha: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - optim=optim, gamma=gamma, - checkpointer=checkpointer, device=device, ) self._alpha = alpha @@ -256,9 +236,9 @@ def _compute_conservative_loss( @train_api def update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._optim is not None + assert self._modules.optim is not None - self._optim.zero_grad() + self._modules.optim.zero_grad() q_tpn = self.compute_target(batch) @@ -269,7 +249,7 @@ def update(self, batch: TorchMiniBatch) -> Dict[str, float]: loss = td_loss + self._alpha * conservative_loss loss.backward() - self._optim.step() + self._modules.optim.step() return { "loss": float(loss.cpu().detach().numpy()), diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 3adcb922..59183666 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -1,7 +1,7 @@ +import dataclasses + import torch import torch.nn.functional as F -from torch import nn -from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( @@ -9,32 +9,33 @@ NormalPolicy, build_gaussian_distribution, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, hard_sync -from .ddpg_impl import DDPGBaseImpl +from ....torch_utility import TorchMiniBatch, hard_sync, soft_sync +from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules + 
+__all__ = ["CRRImpl", "CRRModules"] + -__all__ = ["CRRImpl"] +@dataclasses.dataclass(frozen=True) +class CRRModules(DDPGBaseModules): + policy: NormalPolicy + targ_policy: NormalPolicy class CRRImpl(DDPGBaseImpl): + _modules: CRRModules _beta: float _n_action_samples: int _advantage_type: str _weight_type: str _max_weight: float - _policy: NormalPolicy - _targ_policy: NormalPolicy def __init__( self, observation_shape: Shape, action_size: int, - policy: NormalPolicy, - q_funcs: nn.ModuleList, + modules: CRRModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - actor_optim: Optimizer, - critic_optim: Optimizer, gamma: float, beta: float, n_action_samples: int, @@ -42,22 +43,16 @@ def __init__( weight_type: str, max_weight: float, tau: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._beta = beta @@ -68,7 +63,9 @@ def __init__( def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: # compute log probability - dist = build_gaussian_distribution(self._policy(batch.observations)) + dist = build_gaussian_distribution( + self._modules.policy(batch.observations) + ) log_probs = dist.log_prob(batch.actions) weight = self._compute_weight(batch.observations, batch.actions) @@ -92,7 +89,7 @@ def _compute_advantage( batch_size = obs_t.shape[0] # (batch_size, N, action) - dist = build_gaussian_distribution(self._policy(obs_t)) + dist = build_gaussian_distribution(self._modules.policy(obs_t)) policy_actions = dist.sample_n(self._n_action_samples) flat_actions = policy_actions.reshape(-1, self._action_size) @@ -127,7 +124,7 @@ def _compute_advantage( def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): action = build_gaussian_distribution( - self._targ_policy(batch.next_observations) + self._modules.targ_policy(batch.next_observations) ).sample() return self._targ_q_func_forwarder.compute_target( batch.next_observations, @@ -138,7 +135,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: # compute CWP - dist = build_gaussian_distribution(self._policy(x)) + dist = build_gaussian_distribution(self._modules.policy(x)) actions = dist.onnx_safe_sample_n(self._n_action_samples) # (batch_size, N, action_size) -> (batch_size * N, action_size) flat_actions = actions.reshape(-1, self._action_size) @@ -167,11 +164,14 @@ def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: return actions[torch.arange(x.shape[0]), indices.view(-1)] def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: - dist = build_gaussian_distribution(self._policy(x)) + dist = build_gaussian_distribution(self._modules.policy(x)) return dist.sample() def sync_critic_target(self) -> None: - hard_sync(self._targ_q_funcs, self._q_funcs) + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) def sync_actor_target(self) -> None: - hard_sync(self._targ_policy, self._policy) + hard_sync(self._modules.targ_policy, self._modules.policy) + + def update_actor_target(self) -> None: + soft_sync(self._modules.targ_policy, 
self._modules.policy, self._tau) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 2add695b..7fd9f521 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -1,4 +1,4 @@ -import copy +import dataclasses from abc import ABCMeta, abstractmethod from typing import Dict @@ -9,7 +9,7 @@ from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy from ....torch_utility import ( - Checkpointer, + Modules, TorchMiniBatch, hard_sync, soft_sync, @@ -18,67 +18,60 @@ from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin -__all__ = ["DDPGImpl"] +__all__ = ["DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules"] + + +@dataclasses.dataclass(frozen=True) +class DDPGBaseModules(Modules): + policy: Policy + q_funcs: nn.ModuleList + targ_q_funcs: nn.ModuleList + actor_optim: Optimizer + critic_optim: Optimizer class DDPGBaseImpl( ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): + _modules: DDPGBaseModules _gamma: float _tau: float - _q_funcs: nn.ModuleList _q_func_forwarder: ContinuousEnsembleQFunctionForwarder - _policy: Policy - _targ_q_funcs: nn.ModuleList _targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder - _targ_policy: Policy - _actor_optim: Optimizer - _critic_optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - policy: Policy, - q_funcs: nn.ModuleList, + modules: DDPGBaseModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - actor_optim: Optimizer, - critic_optim: Optimizer, gamma: float, tau: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - checkpointer=checkpointer, + modules=modules, device=device, ) self._gamma = gamma self._tau = tau - self._policy = policy - self._q_funcs = q_funcs self._q_func_forwarder = q_func_forwarder - self._actor_optim = actor_optim - self._critic_optim = critic_optim - self._targ_q_funcs = targ_q_funcs self._targ_q_func_forwarder = targ_q_func_forwarder - self._targ_policy = copy.deepcopy(policy) - hard_sync(self._targ_q_funcs, self._q_funcs) + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) @train_api def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._critic_optim.zero_grad() + self._modules.critic_optim.zero_grad() q_tpn = self.compute_target(batch) loss = self.compute_critic_loss(batch, q_tpn) loss.backward() - self._critic_optim.step() + self._modules.critic_optim.step() return {"critic_loss": float(loss.cpu().detach().numpy())} @@ -97,14 +90,14 @@ def compute_critic_loss( @train_api def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability - self._q_funcs.eval() + self._modules.q_funcs.eval() - self._actor_optim.zero_grad() + self._modules.actor_optim.zero_grad() loss = self.compute_actor_loss(batch) loss.backward() - self._actor_optim.step() + self._modules.actor_optim.step() return {"actor_loss": float(loss.cpu().detach().numpy())} @@ -117,38 +110,65 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: pass def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - return self._policy(x).squashed_mu + return self._modules.policy(x).squashed_mu @abstractmethod def inner_sample_action(self, x: torch.Tensor) 
-> torch.Tensor: pass def update_critic_target(self) -> None: - soft_sync(self._targ_q_funcs, self._q_funcs, self._tau) - - def update_actor_target(self) -> None: - soft_sync(self._targ_policy, self._policy, self._tau) + soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau) @property def policy(self) -> Policy: - return self._policy + return self._modules.policy @property def policy_optim(self) -> Optimizer: - return self._actor_optim + return self._modules.actor_optim @property def q_function(self) -> nn.ModuleList: - return self._q_funcs + return self._modules.q_funcs @property def q_function_optim(self) -> Optimizer: - return self._critic_optim + return self._modules.critic_optim + + +@dataclasses.dataclass(frozen=True) +class DDPGModules(DDPGBaseModules): + targ_policy: Policy class DDPGImpl(DDPGBaseImpl): + _modules: DDPGModules + + def __init__( + self, + observation_shape: Shape, + action_size: int, + modules: DDPGModules, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + gamma: float, + tau: float, + device: str, + ): + super().__init__( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=gamma, + tau=tau, + device=device, + ) + hard_sync(self._modules.targ_policy, self._modules.policy) + def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - action = self._policy(batch.observations) + action = self._modules.policy(batch.observations) q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" )[0] @@ -156,7 +176,7 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - action = self._targ_policy(batch.next_observations) + action = self._modules.targ_policy(batch.next_observations) return self._targ_q_func_forwarder.compute_target( batch.next_observations, action.squashed_mu.clamp(-1.0, 1.0), @@ -165,3 +185,6 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) + + def update_actor_target(self) -> None: + soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index e894418d..4d1fe0f0 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Dict import torch @@ -6,58 +7,57 @@ from ....dataset import Shape from ....models.torch import DiscreteEnsembleQFunctionForwarder -from ....torch_utility import Checkpointer, TorchMiniBatch, hard_sync, train_api +from ....torch_utility import Modules, TorchMiniBatch, hard_sync, train_api from ..base import QLearningAlgoImplBase from .utility import DiscreteQFunctionMixin -__all__ = ["DQNImpl", "DoubleDQNImpl"] +__all__ = ["DQNImpl", "DQNModules", "DoubleDQNImpl"] + + +@dataclasses.dataclass(frozen=True) +class DQNModules(Modules): + q_funcs: nn.ModuleList + targ_q_funcs: nn.ModuleList + optim: Optimizer class DQNImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): + _modules: DQNModules _gamma: float - _q_funcs: nn.ModuleList - _targ_q_funcs: nn.ModuleList _q_func_forwarder: DiscreteEnsembleQFunctionForwarder _targ_q_func_forwarder: 
DiscreteEnsembleQFunctionForwarder - _optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - q_funcs: nn.ModuleList, + modules: DQNModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - optim: Optimizer, gamma: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - checkpointer=checkpointer, + modules=modules, device=device, ) self._gamma = gamma - self._q_funcs = q_funcs self._q_func_forwarder = q_func_forwarder - self._targ_q_funcs = targ_q_funcs self._targ_q_func_forwarder = targ_q_func_forwarder - self._optim = optim - hard_sync(targ_q_funcs, q_funcs) + hard_sync(modules.targ_q_funcs, modules.q_funcs) @train_api def update(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._optim.zero_grad() + self._modules.optim.zero_grad() q_tpn = self.compute_target(batch) loss = self.compute_loss(batch, q_tpn) loss.backward() - self._optim.step() + self._modules.optim.step() return {"loss": float(loss.cpu().detach().numpy())} @@ -94,15 +94,15 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) def update_target(self) -> None: - hard_sync(self._targ_q_funcs, self._q_funcs) + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) @property def q_function(self) -> nn.ModuleList: - return self._q_funcs + return self._modules.q_funcs @property def q_function_optim(self) -> Optimizer: - return self._optim + return self._modules.optim class DoubleDQNImpl(DQNImpl): diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index aa39c95f..0bb068c3 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -1,8 +1,7 @@ +import dataclasses from typing import Dict import torch -from torch import nn -from torch.optim import Optimizer from ....dataset import Shape from ....models.torch import ( @@ -11,15 +10,20 @@ ValueFunction, build_gaussian_distribution, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, train_api -from .ddpg_impl import DDPGBaseImpl +from ....torch_utility import TorchMiniBatch, train_api +from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules -__all__ = ["IQLImpl"] +__all__ = ["IQLImpl", "IQLModules"] + + +@dataclasses.dataclass(frozen=True) +class IQLModules(DDPGBaseModules): + policy: NormalPolicy + value_func: ValueFunction class IQLImpl(DDPGBaseImpl): - _policy: NormalPolicy - _value_func: ValueFunction + _modules: IQLModules _expectile: float _weight_temp: float _max_weight: float @@ -28,41 +32,29 @@ def __init__( self, observation_shape: Shape, action_size: int, - policy: NormalPolicy, - q_funcs: nn.ModuleList, + modules: IQLModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - value_func: ValueFunction, - actor_optim: Optimizer, - critic_optim: Optimizer, gamma: float, tau: float, expectile: float, weight_temp: float, max_weight: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, - 
checkpointer=checkpointer, device=device, ) self._expectile = expectile self._weight_temp = weight_temp self._max_weight = max_weight - self._value_func = value_func def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor @@ -78,11 +70,13 @@ def compute_critic_loss( def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - return self._value_func(batch.next_observations) + return self._modules.value_func(batch.next_observations) def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: # compute log probability - dist = build_gaussian_distribution(self._policy(batch.observations)) + dist = build_gaussian_distribution( + self._modules.policy(batch.observations) + ) log_probs = dist.log_prob(batch.actions) # compute weight @@ -95,7 +89,7 @@ def _compute_weight(self, batch: TorchMiniBatch) -> torch.Tensor: q_t = self._targ_q_func_forwarder.compute_expected_q( batch.observations, batch.actions, "min" ) - v_t = self._value_func(batch.observations) + v_t = self._modules.value_func(batch.observations) adv = q_t - v_t return (self._weight_temp * adv).exp().clamp(max=self._max_weight) @@ -103,7 +97,7 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: q_t = self._targ_q_func_forwarder.compute_expected_q( batch.observations, batch.actions, "min" ) - v_t = self._value_func(batch.observations) + v_t = self._modules.value_func(batch.observations) diff = q_t.detach() - v_t weight = (self._expectile - (diff < 0.0).float()).abs().detach() return (weight * (diff**2)).mean() @@ -112,7 +106,7 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: def update_critic_and_state_value( self, batch: TorchMiniBatch ) -> Dict[str, float]: - self._critic_optim.zero_grad() + self._modules.critic_optim.zero_grad() # compute Q-function loss q_tpn = self.compute_target(batch) @@ -124,7 +118,7 @@ def update_critic_and_state_value( loss = q_loss + v_loss loss.backward() - self._critic_optim.step() + self._modules.critic_optim.step() return { "critic_loss": float(q_loss.cpu().detach().numpy()), @@ -132,5 +126,5 @@ def update_critic_and_state_value( } def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: - dist = build_gaussian_distribution(self._policy(x)) + dist = build_gaussian_distribution(self._modules.policy(x)) return dist.sample() diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 400d51ac..5d54ac49 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -1,8 +1,7 @@ -import copy +import dataclasses from typing import Dict import torch -from torch import nn from torch.optim import Optimizer from ....dataset import Shape @@ -14,88 +13,86 @@ compute_vae_error, forward_vae_decode, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, soft_sync, train_api -from .ddpg_impl import DDPGBaseImpl +from ....torch_utility import TorchMiniBatch, soft_sync, train_api +from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules -__all__ = ["PLASImpl", "PLASWithPerturbationImpl"] +__all__ = [ + "PLASImpl", + "PLASWithPerturbationImpl", + "PLASModules", + "PLASWithPerturbationModules", +] + + +@dataclasses.dataclass(frozen=True) +class PLASModules(DDPGBaseModules): + policy: DeterministicPolicy + targ_policy: DeterministicPolicy + imitator: ConditionalVAE + imitator_optim: Optimizer class PLASImpl(DDPGBaseImpl): + _modules: PLASModules _lam: float _beta: float - _policy: DeterministicPolicy - _targ_policy: 
DeterministicPolicy - _imitator: ConditionalVAE - _imitator_optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - policy: DeterministicPolicy, - q_funcs: nn.ModuleList, + modules: PLASModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - imitator: ConditionalVAE, - actor_optim: Optimizer, - critic_optim: Optimizer, - imitator_optim: Optimizer, gamma: float, tau: float, lam: float, beta: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._lam = lam self._beta = beta - self._imitator = imitator - self._imitator_optim = imitator_optim @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._imitator_optim.zero_grad() + self._modules.imitator_optim.zero_grad() loss = compute_vae_error( - vae=self._imitator, + vae=self._modules.imitator, x=batch.observations, action=batch.actions, beta=self._beta, ) loss.backward() - self._imitator_optim.step() + self._modules.imitator_optim.step() return {"imitator_loss": float(loss.cpu().detach().numpy())} def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - latent_actions = 2.0 * self._policy(batch.observations).squashed_mu + latent_actions = ( + 2.0 * self._modules.policy(batch.observations).squashed_mu + ) actions = forward_vae_decode( - self._imitator, batch.observations, latent_actions + self._modules.imitator, batch.observations, latent_actions ) return -self._q_func_forwarder.compute_expected_q( batch.observations, actions, "none" )[0].mean() def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - latent_actions = 2.0 * self._policy(x).squashed_mu - return forward_vae_decode(self._imitator, x, latent_actions) + latent_actions = 2.0 * self._modules.policy(x).squashed_mu + return forward_vae_decode(self._modules.imitator, x, latent_actions) def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) @@ -103,10 +100,11 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): latent_actions = ( - 2.0 * self._targ_policy(batch.next_observations).squashed_mu + 2.0 + * self._modules.targ_policy(batch.next_observations).squashed_mu ) actions = forward_vae_decode( - self._imitator, batch.next_observations, latent_actions + self._modules.imitator, batch.next_observations, latent_actions ) return self._targ_q_func_forwarder.compute_target( batch.next_observations, @@ -115,60 +113,53 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: self._lam, ) + def update_actor_target(self) -> None: + soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) + + +@dataclasses.dataclass(frozen=True) +class PLASWithPerturbationModules(PLASModules): + perturbation: DeterministicResidualPolicy + targ_perturbation: DeterministicResidualPolicy + class PLASWithPerturbationImpl(PLASImpl): - _perturbation: DeterministicResidualPolicy - _targ_perturbation: DeterministicResidualPolicy + _modules: PLASWithPerturbationModules def __init__( self, 
observation_shape: Shape, action_size: int, - policy: DeterministicPolicy, - q_funcs: nn.ModuleList, + modules: PLASWithPerturbationModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - imitator: ConditionalVAE, - perturbation: DeterministicResidualPolicy, - actor_optim: Optimizer, - critic_optim: Optimizer, - imitator_optim: Optimizer, gamma: float, tau: float, lam: float, beta: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - imitator=imitator, - actor_optim=actor_optim, - critic_optim=critic_optim, - imitator_optim=imitator_optim, gamma=gamma, tau=tau, lam=lam, beta=beta, - checkpointer=checkpointer, device=device, ) - self._perturbation = perturbation - self._targ_perturbation = copy.deepcopy(perturbation) def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - latent_actions = 2.0 * self._policy(batch.observations).squashed_mu + latent_actions = ( + 2.0 * self._modules.policy(batch.observations).squashed_mu + ) actions = forward_vae_decode( - self._imitator, batch.observations, latent_actions + self._modules.imitator, batch.observations, latent_actions ) - residual_actions = self._perturbation( + residual_actions = self._modules.perturbation( batch.observations, actions ).squashed_mu q_value = self._q_func_forwarder.compute_expected_q( @@ -177,9 +168,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: return -q_value[0].mean() def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - latent_actions = 2.0 * self._policy(x).squashed_mu - actions = forward_vae_decode(self._imitator, x, latent_actions) - return self._perturbation(x, actions).squashed_mu + latent_actions = 2.0 * self._modules.policy(x).squashed_mu + actions = forward_vae_decode(self._modules.imitator, x, latent_actions) + return self._modules.perturbation(x, actions).squashed_mu def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) @@ -187,12 +178,13 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): latent_actions = ( - 2.0 * self._targ_policy(batch.next_observations).squashed_mu + 2.0 + * self._modules.targ_policy(batch.next_observations).squashed_mu ) actions = forward_vae_decode( - self._imitator, batch.next_observations, latent_actions + self._modules.imitator, batch.next_observations, latent_actions ) - residual_actions = self._targ_perturbation( + residual_actions = self._modules.targ_perturbation( batch.next_observations, actions ) return self._targ_q_func_forwarder.compute_target( @@ -204,4 +196,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: def update_actor_target(self) -> None: super().update_actor_target() - soft_sync(self._targ_perturbation, self._perturbation, self._tau) + soft_sync( + self._modules.targ_perturbation, + self._modules.perturbation, + self._tau, + ) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index 6a87b7f1..37bc7997 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -1,3 +1,4 @@ +import dataclasses import math from 
typing import Dict @@ -11,64 +12,57 @@ CategoricalPolicy, ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, + NormalPolicy, Parameter, Policy, build_squashed_gaussian_distribution, ) -from ....torch_utility import Checkpointer, TorchMiniBatch, hard_sync, train_api +from ....torch_utility import Modules, TorchMiniBatch, hard_sync, train_api from ..base import QLearningAlgoImplBase -from .ddpg_impl import DDPGBaseImpl +from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules from .utility import DiscreteQFunctionMixin -__all__ = ["SACImpl", "DiscreteSACImpl"] +__all__ = ["SACImpl", "DiscreteSACImpl", "SACModules", "DiscreteSACModules"] + + +@dataclasses.dataclass(frozen=True) +class SACModules(DDPGBaseModules): + policy: NormalPolicy + log_temp: Parameter + temp_optim: Optimizer class SACImpl(DDPGBaseImpl): - _log_temp: Parameter - _temp_optim: Optimizer + _modules: SACModules def __init__( self, observation_shape: Shape, action_size: int, - policy: Policy, - q_funcs: nn.ModuleList, + modules: SACModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - log_temp: Parameter, - actor_optim: Optimizer, - critic_optim: Optimizer, - temp_optim: Optimizer, gamma: float, tau: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) - self._log_temp = log_temp - self._temp_optim = temp_optim def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: dist = build_squashed_gaussian_distribution( - self._policy(batch.observations) + self._modules.policy(batch.observations) ) action, log_prob = dist.sample_with_log_prob() - entropy = self._log_temp().exp() * log_prob + entropy = self._modules.log_temp().exp() * log_prob q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action, "min" ) @@ -76,22 +70,22 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: @train_api def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._temp_optim.zero_grad() + self._modules.temp_optim.zero_grad() with torch.no_grad(): dist = build_squashed_gaussian_distribution( - self._policy(batch.observations) + self._modules.policy(batch.observations) ) _, log_prob = dist.sample_with_log_prob() targ_temp = log_prob - self._action_size - loss = -(self._log_temp().exp() * targ_temp).mean() + loss = -(self._modules.log_temp().exp() * targ_temp).mean() loss.backward() - self._temp_optim.step() + self._modules.temp_optim.step() # current temperature value - cur_temp = self._log_temp().exp().cpu().detach().numpy()[0][0] + cur_temp = self._modules.log_temp().exp().cpu().detach().numpy()[0][0] return { "temp_loss": float(loss.cpu().detach().numpy()), @@ -101,10 +95,10 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): dist = build_squashed_gaussian_distribution( - self._policy(batch.next_observations) + self._modules.policy(batch.next_observations) ) action, log_prob = dist.sample_with_log_prob() - entropy = self._log_temp().exp() * log_prob + entropy = self._modules.log_temp().exp() * log_prob 
target = self._targ_q_func_forwarder.compute_target( batch.next_observations, action, @@ -113,74 +107,65 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: return target - entropy def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: - dist = build_squashed_gaussian_distribution(self._policy(x)) + dist = build_squashed_gaussian_distribution(self._modules.policy(x)) return dist.sample() +@dataclasses.dataclass(frozen=True) +class DiscreteSACModules(Modules): + policy: CategoricalPolicy + q_funcs: nn.ModuleList + targ_q_funcs: nn.ModuleList + log_temp: Parameter + actor_optim: Optimizer + critic_optim: Optimizer + temp_optim: Optimizer + + class DiscreteSACImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): - _policy: CategoricalPolicy - _q_funcss: nn.ModuleList + _modules: DiscreteSACModules _q_func_forwarder: DiscreteEnsembleQFunctionForwarder - _targ_q_funcs: nn.ModuleList _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder - _log_temp: Parameter - _actor_optim: Optimizer - _critic_optim: Optimizer - _temp_optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - q_funcs: nn.ModuleList, + modules: DiscreteSACModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - policy: CategoricalPolicy, - log_temp: Parameter, - actor_optim: Optimizer, - critic_optim: Optimizer, - temp_optim: Optimizer, gamma: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - checkpointer=checkpointer, + modules=modules, device=device, ) self._gamma = gamma - self._q_funcs = q_funcs self._q_func_forwarder = q_func_forwarder - self._targ_q_funcs = targ_q_funcs self._targ_q_func_forwarder = targ_q_func_forwarder - self._policy = policy - self._log_temp = log_temp - self._actor_optim = actor_optim - self._critic_optim = critic_optim - self._temp_optim = temp_optim - hard_sync(targ_q_funcs, q_funcs) + hard_sync(modules.targ_q_funcs, modules.q_funcs) @train_api def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._critic_optim.zero_grad() + self._modules.critic_optim.zero_grad() q_tpn = self.compute_target(batch) loss = self.compute_critic_loss(batch, q_tpn) loss.backward() - self._critic_optim.step() + self._modules.critic_optim.step() return {"critic_loss": float(loss.cpu().detach().numpy())} def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - dist = self._policy(batch.next_observations) + dist = self._modules.policy(batch.next_observations) log_probs = dist.logits probs = dist.probs - entropy = self._log_temp().exp() * log_probs + entropy = self._modules.log_temp().exp() * log_probs target = self._targ_q_func_forwarder.compute_target( batch.next_observations ) @@ -208,14 +193,14 @@ def compute_critic_loss( @train_api def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability - self._q_funcs.eval() + self._modules.q_funcs.eval() - self._actor_optim.zero_grad() + self._modules.actor_optim.zero_grad() loss = self.compute_actor_loss(batch) loss.backward() - self._actor_optim.step() + self._modules.actor_optim.step() return {"actor_loss": float(loss.cpu().detach().numpy())} @@ -224,31 +209,31 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, reduction="min" ) - dist = 
self._policy(batch.observations) + dist = self._modules.policy(batch.observations) log_probs = dist.logits probs = dist.probs - entropy = self._log_temp().exp() * log_probs + entropy = self._modules.log_temp().exp() * log_probs return (probs * (entropy - q_t)).sum(dim=1).mean() @train_api def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: - self._temp_optim.zero_grad() + self._modules.temp_optim.zero_grad() with torch.no_grad(): - dist = self._policy(batch.observations) + dist = self._modules.policy(batch.observations) log_probs = F.log_softmax(dist.logits, dim=1) probs = dist.probs expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True) entropy_target = 0.98 * (-math.log(1 / self.action_size)) targ_temp = expct_log_probs + entropy_target - loss = -(self._log_temp().exp() * targ_temp).mean() + loss = -(self._modules.log_temp().exp() * targ_temp).mean() loss.backward() - self._temp_optim.step() + self._modules.temp_optim.step() # current temperature value - cur_temp = self._log_temp().exp().cpu().detach().numpy()[0][0] + cur_temp = self._modules.log_temp().exp().cpu().detach().numpy()[0][0] return { "temp_loss": float(loss.cpu().detach().numpy()), @@ -256,28 +241,28 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: } def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: - dist = self._policy(x) + dist = self._modules.policy(x) return dist.probs.argmax(dim=1) def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: - dist = self._policy(x) + dist = self._modules.policy(x) return dist.sample() def update_target(self) -> None: - hard_sync(self._targ_q_funcs, self._q_funcs) + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) @property def policy(self) -> Policy: - return self._policy + return self._modules.policy @property def policy_optim(self) -> Optimizer: - return self._actor_optim + return self._modules.actor_optim @property def q_function(self) -> nn.ModuleList: - return self._q_funcs + return self._modules.q_funcs @property def q_function_optim(self) -> Optimizer: - return self._critic_optim + return self._modules.critic_optim diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 8868ecee..896f3221 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -1,14 +1,9 @@ import torch -from torch import nn -from torch.optim import Optimizer from ....dataset import Shape -from ....models.torch import ( - ContinuousEnsembleQFunctionForwarder, - DeterministicPolicy, -) -from ....torch_utility import Checkpointer, TorchMiniBatch -from .ddpg_impl import DDPGImpl +from ....models.torch import ContinuousEnsembleQFunctionForwarder +from ....torch_utility import TorchMiniBatch +from .ddpg_impl import DDPGImpl, DDPGModules __all__ = ["TD3Impl"] @@ -21,33 +16,23 @@ def __init__( self, observation_shape: Shape, action_size: int, - policy: DeterministicPolicy, - q_funcs: nn.ModuleList, + modules: DDPGModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - actor_optim: Optimizer, - critic_optim: Optimizer, gamma: float, tau: float, target_smoothing_sigma: float, target_smoothing_clip: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, 
targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, - checkpointer=checkpointer, device=device, ) self._target_smoothing_sigma = target_smoothing_sigma @@ -55,7 +40,7 @@ def __init__( def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - action = self._targ_policy(batch.next_observations) + action = self._modules.targ_policy(batch.next_observations) # smoothing target noise = torch.randn(action.mu.shape, device=batch.device) scaled_noise = self._target_smoothing_sigma * noise diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 4ffdc5c0..c3a93f8a 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -1,15 +1,11 @@ # pylint: disable=too-many-ancestors import torch -from torch import nn -from torch.optim import Optimizer from ....dataset import Shape -from ....models.torch import ( - ContinuousEnsembleQFunctionForwarder, - DeterministicPolicy, -) -from ....torch_utility import Checkpointer, TorchMiniBatch +from ....models.torch import ContinuousEnsembleQFunctionForwarder +from ....torch_utility import TorchMiniBatch +from .ddpg_impl import DDPGModules from .td3_impl import TD3Impl __all__ = ["TD3PlusBCImpl"] @@ -22,42 +18,32 @@ def __init__( self, observation_shape: Shape, action_size: int, - policy: DeterministicPolicy, - q_funcs: nn.ModuleList, + modules: DDPGModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - actor_optim: Optimizer, - critic_optim: Optimizer, gamma: float, tau: float, target_smoothing_sigma: float, target_smoothing_clip: float, alpha: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - policy=policy, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - actor_optim=actor_optim, - critic_optim=critic_optim, gamma=gamma, tau=tau, target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, - checkpointer=checkpointer, device=device, ) self._alpha = alpha def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: - action = self._policy(batch.observations).squashed_mu + action = self._modules.policy(batch.observations).squashed_mu q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action, "none" )[0] diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index ad34f362..73603cd5 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -13,9 +13,12 @@ make_optimizer_field, ) from ...models.builders import create_continuous_decision_transformer -from ...torch_utility import Checkpointer, TorchTrajectoryMiniBatch +from ...torch_utility import TorchTrajectoryMiniBatch from .base import TransformerAlgoBase, TransformerConfig -from .torch.decision_transformer_impl import DecisionTransformerImpl +from .torch.decision_transformer_impl import ( + DecisionTransformerImpl, + DecisionTransformerModules, +) __all__ = ["DecisionTransformerConfig", "DecisionTransformer"] @@ -113,19 +116,17 @@ def inner_create_impl( if self._config.compile: transformer = torch.compile(transformer, fullgraph=True) - 
checkpointer = Checkpointer( - modules={"transformer": transformer, "optim": optim}, - device=self._device, + modules = DecisionTransformerModules( + transformer=transformer, + optim=optim, ) self._impl = DecisionTransformerImpl( observation_shape=observation_shape, action_size=action_size, - transformer=transformer, - optim=optim, + modules=modules, scheduler=scheduler, clip_grad_norm=self._config.clip_grad_norm, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index 2ea74ef5..43345c5b 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Dict import torch @@ -6,7 +7,7 @@ from ....dataset import Shape from ....models.torch import ContinuousDecisionTransformer from ....torch_utility import ( - Checkpointer, + Modules, TorchTrajectoryMiniBatch, eval_api, train_api, @@ -14,12 +15,17 @@ from ..base import TransformerAlgoImplBase from ..inputs import TorchTransformerInput -__all__ = ["DecisionTransformerImpl"] +__all__ = ["DecisionTransformerImpl", "DecisionTransformerModules"] + + +@dataclasses.dataclass(frozen=True) +class DecisionTransformerModules(Modules): + transformer: ContinuousDecisionTransformer + optim: Optimizer class DecisionTransformerImpl(TransformerAlgoImplBase): - _transformer: ContinuousDecisionTransformer - _optim: torch.optim.Optimizer + _modules: DecisionTransformerModules _scheduler: torch.optim.lr_scheduler.LambdaLR _clip_grad_norm: float @@ -27,28 +33,24 @@ def __init__( self, observation_shape: Shape, action_size: int, - transformer: ContinuousDecisionTransformer, - optim: Optimizer, + modules: DecisionTransformerModules, scheduler: torch.optim.lr_scheduler.LambdaLR, clip_grad_norm: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - checkpointer=checkpointer, + modules=modules, device=device, ) - self._transformer = transformer - self._optim = optim self._scheduler = scheduler self._clip_grad_norm = clip_grad_norm @eval_api def predict(self, inpt: TorchTransformerInput) -> torch.Tensor: # (1, T, A) - action = self._transformer( + action = self._modules.transformer( inpt.observations, inpt.actions, inpt.returns_to_go, inpt.timesteps ) # (1, T, A) -> (A,) @@ -56,21 +58,21 @@ def predict(self, inpt: TorchTransformerInput) -> torch.Tensor: @train_api def update(self, batch: TorchTrajectoryMiniBatch) -> Dict[str, float]: - self._optim.zero_grad() + self._modules.optim.zero_grad() loss = self.compute_loss(batch) loss.backward() torch.nn.utils.clip_grad_norm_( - self._transformer.parameters(), self._clip_grad_norm + self._modules.transformer.parameters(), self._clip_grad_norm ) - self._optim.step() + self._modules.optim.step() self._scheduler.step() return {"loss": float(loss.cpu().detach().numpy())} def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor: - action = self._transformer( + action = self._modules.transformer( batch.observations, batch.actions, batch.returns_to_go, diff --git a/d3rlpy/base.py b/d3rlpy/base.py index e72f9c35..30143363 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -23,7 +23,7 @@ make_reward_scaler_field, ) from .serializable_config import DynamicConfig, generate_config_registration -from .torch_utility import Checkpointer +from .torch_utility import Checkpointer, Modules 
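The refactor above replaces per-impl Checkpointer wiring with a frozen Modules dataclass that owns every torch component and derives checkpointing, freezing, and train/eval switching from its fields. A minimal self-contained sketch of that pattern, assuming only torch and dataclasses; ToyModules, members, and state_dicts are illustrative names, not d3rlpy API, and the real Modules.create_checkpointer / asdict_without_copy shown in this series differ in detail.

import dataclasses
from typing import Any, Dict

import torch
from torch import nn
from torch.optim import Optimizer


@dataclasses.dataclass(frozen=True)
class ToyModules:
    # every learnable component of an impl lives here as a dataclass field
    policy: nn.Module
    optim: Optimizer

    def members(self) -> Dict[str, Any]:
        # shallow field map, analogous to asdict_without_copy (no deep copy)
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

    def state_dicts(self) -> Dict[str, Any]:
        # checkpointing falls out of the fields: no per-impl bookkeeping needed
        return {k: v.state_dict() for k, v in self.members().items()}

    def set_eval(self) -> None:
        # train/eval switching iterates the same field map
        for v in self.members().values():
            if isinstance(v, nn.Module):
                v.eval()


policy = nn.Linear(4, 2)
modules = ToyModules(policy=policy, optim=torch.optim.Adam(policy.parameters()))
modules.set_eval()
torch.save(modules.state_dicts(), "toy_checkpoint.pt")  # one state_dict per field

Keeping the container frozen means an impl cannot silently swap out a module after the checkpointer map is built, which is the design choice the series leans on.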
__all__ = [ "DeviceArg", @@ -48,6 +48,7 @@ class ImplBase(metaclass=ABCMeta): _observation_shape: Shape _action_size: int + _modules: Modules _checkpointer: Checkpointer _device: str @@ -55,12 +56,13 @@ def __init__( self, observation_shape: Shape, action_size: int, - checkpointer: Checkpointer, + modules: Modules, device: str, ): self._observation_shape = observation_shape self._action_size = action_size - self._checkpointer = checkpointer + self._modules = modules + self._checkpointer = modules.create_checkpointer(device) self._device = device def save_model(self, f: BinaryIO) -> None: @@ -81,6 +83,10 @@ def action_size(self) -> int: def device(self) -> str: return self._device + @property + def modules(self) -> Modules: + return self._modules + @dataclasses.dataclass() class LearnableConfig(DynamicConfig): diff --git a/d3rlpy/dataclass_utils.py b/d3rlpy/dataclass_utils.py new file mode 100644 index 00000000..a0b269c1 --- /dev/null +++ b/d3rlpy/dataclass_utils.py @@ -0,0 +1,10 @@ +import dataclasses +from typing import Any, Dict + +__all__ = ["asdict_without_copy"] + + +def asdict_without_copy(obj: Any) -> Dict[str, Any]: + assert dataclasses.is_dataclass(obj) + fields = dataclasses.fields(obj) + return {field.name: getattr(obj, field.name) for field in fields} diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index a08b6c65..d465e822 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -18,8 +18,13 @@ from ..models.encoders import EncoderFactory, make_encoder_field from ..models.optimizers import OptimizerFactory, make_optimizer_field from ..models.q_functions import QFunctionFactory, make_q_func_field -from ..torch_utility import Checkpointer, TorchMiniBatch, convert_to_torch -from .torch.fqe_impl import DiscreteFQEImpl, FQEBaseImpl, FQEImpl +from ..torch_utility import TorchMiniBatch, convert_to_torch +from .torch.fqe_impl import ( + DiscreteFQEImpl, + FQEBaseImpl, + FQEBaseModules, + FQEImpl, +) __all__ = ["FQEConfig", "FQE", "DiscreteFQE"] @@ -176,25 +181,19 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "optim": optim, - }, - device=self._device, + modules = FQEBaseModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + optim=optim, ) self._impl = FQEImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - optim=optim, gamma=self._config.gamma, - checkpointer=checkpointer, device=self._device, ) @@ -253,24 +252,18 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.parameters(), lr=self._config.learning_rate ) - checkpointer = Checkpointer( - modules={ - "q_func": q_funcs, - "targ_q_func": targ_q_funcs, - "optim": optim, - }, - device=self._device, + modules = FQEBaseModules( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + optim=optim, ) self._impl = DiscreteFQEImpl( observation_shape=observation_shape, action_size=action_size, - q_funcs=q_funcs, + modules=modules, q_func_forwarder=q_func_forwarder, - targ_q_funcs=targ_q_funcs, targ_q_func_forwarder=targ_q_func_forwarder, - optim=optim, gamma=self._config.gamma, - checkpointer=checkpointer, device=self._device, ) diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index f1361371..ab0b94bd 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -1,3 +1,4 @@ +import 
dataclasses from typing import Union import torch @@ -14,55 +15,54 @@ ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, ) -from ...torch_utility import Checkpointer, TorchMiniBatch, hard_sync, train_api +from ...torch_utility import Modules, TorchMiniBatch, hard_sync, train_api -__all__ = ["FQEBaseImpl", "FQEImpl", "DiscreteFQEImpl"] +__all__ = ["FQEBaseImpl", "FQEImpl", "DiscreteFQEImpl", "FQEBaseModules"] + + +@dataclasses.dataclass(frozen=True) +class FQEBaseModules(Modules): + q_funcs: nn.ModuleList + targ_q_funcs: nn.ModuleList + optim: Optimizer class FQEBaseImpl(QLearningAlgoImplBase): + _modules: FQEBaseModules _gamma: float - _q_funcs: nn.ModuleList _q_func_forwarder: Union[ DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder ] - _targ_q_funcs: nn.ModuleList _targ_q_func_forwarder: Union[ DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder ] - _optim: Optimizer def __init__( self, observation_shape: Shape, action_size: int, - q_funcs: nn.ModuleList, + modules: FQEBaseModules, q_func_forwarder: Union[ DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder, ], - targ_q_funcs: nn.ModuleList, targ_q_func_forwarder: Union[ DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder, ], - optim: Optimizer, gamma: float, - checkpointer: Checkpointer, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, - checkpointer=checkpointer, + modules=modules, device=device, ) self._gamma = gamma - self._q_funcs = q_funcs self._q_func_forwarder = q_func_forwarder - self._targ_q_funcs = targ_q_funcs self._targ_q_func_forwarder = targ_q_func_forwarder - self._optim = optim - hard_sync(targ_q_funcs, q_funcs) + hard_sync(modules.targ_q_funcs, modules.q_funcs) @train_api def update( @@ -71,9 +71,9 @@ def update( q_tpn = self.compute_target(batch, next_actions) loss = self.compute_loss(batch, q_tpn) - self._optim.zero_grad() + self._modules.optim.zero_grad() loss.backward() - self._optim.step() + self._modules.optim.step() return float(loss.cpu().detach().numpy()) @@ -100,7 +100,7 @@ def compute_target( ) def update_target(self) -> None: - hard_sync(self._targ_q_funcs, self._q_funcs) + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 6b58656b..a592709c 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -7,6 +7,7 @@ from torch import nn from torch.optim import Optimizer +from .dataclass_utils import asdict_without_copy from .dataset import TrajectoryMiniBatch, TransitionMiniBatch from .preprocessing import ActionScaler, ObservationScaler, RewardScaler @@ -14,18 +15,15 @@ "soft_sync", "hard_sync", "sync_optimizer_state", - "set_eval_mode", - "set_train_mode", "to_cuda", "to_cpu", "to_device", - "freeze", - "unfreeze", "reset_optimizer_states", "map_location", "TorchMiniBatch", "TorchTrajectoryMiniBatch", "Checkpointer", + "Modules", "convert_to_torch", "convert_to_torch_recursively", "eval_api", @@ -72,20 +70,6 @@ def sync_optimizer_state(targ_optim: Optimizer, optim: Optimizer) -> None: targ_optim.load_state_dict({"state": state, "param_groups": param_groups}) -def set_eval_mode(impl: Any) -> None: - for key in _get_attributes(impl): - module = getattr(impl, key) - if isinstance(module, torch.nn.Module): - module.eval() - - -def set_train_mode(impl: Any) -> None: - for key in 
_get_attributes(impl): - module = getattr(impl, key) - if isinstance(module, torch.nn.Module): - module.train() - - def to_cuda(impl: Any, device: str) -> None: for key in _get_attributes(impl): module = getattr(impl, key) @@ -107,22 +91,6 @@ def to_device(impl: Any, device: str) -> None: to_cpu(impl) -def freeze(impl: Any) -> None: - for key in _get_attributes(impl): - module = getattr(impl, key) - if isinstance(module, torch.nn.Module): - for p in module.parameters(): - p.requires_grad = False - - -def unfreeze(impl: Any) -> None: - for key in _get_attributes(impl): - module = getattr(impl, key) - if isinstance(module, torch.nn.Module): - for p in module.parameters(): - p.requires_grad = True - - def reset_optimizer_states(impl: Any) -> None: for key in _get_attributes(impl): obj = getattr(impl, key) @@ -285,13 +253,47 @@ def load(self, f: BinaryIO) -> None: for k, v in self._modules.items(): v.load_state_dict(chkpt[k]) + @property + def modules(self) -> Dict[str, Union[nn.Module, Optimizer]]: + return self._modules + + +@dataclasses.dataclass(frozen=True) +class Modules: + def create_checkpointer(self, device: str) -> Checkpointer: + return Checkpointer(modules=asdict_without_copy(self), device=device) + + def freeze(self) -> None: + for v in asdict_without_copy(self).values(): + if isinstance(v, nn.Module): + for p in v.parameters(): + p.requires_grad = False + + def unfreeze(self) -> None: + for v in asdict_without_copy(self).values(): + if isinstance(v, nn.Module): + for p in v.parameters(): + p.requires_grad = True + + def set_eval(self) -> None: + for v in asdict_without_copy(self).values(): + if isinstance(v, nn.Module): + v.eval() + + def set_train(self) -> None: + for v in asdict_without_copy(self).values(): + if isinstance(v, nn.Module): + v.train() + TCallable = TypeVar("TCallable") def eval_api(f: TCallable) -> TCallable: def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - set_eval_mode(self) + assert hasattr(self, "modules") + assert isinstance(self.modules, Modules) + self.modules.set_eval() return f(self, *args, **kwargs) # type: ignore return wrapper # type: ignore @@ -299,7 +301,9 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: def train_api(f: TCallable) -> TCallable: def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - set_train_mode(self) + assert hasattr(self, "modules") + assert isinstance(self.modules, Modules) + self.modules.set_train() return f(self, *args, **kwargs) # type: ignore return wrapper # type: ignore diff --git a/reproductions/finetuning/iql_finetune.py b/reproductions/finetuning/iql_finetune.py index 131626b6..6dead46e 100644 --- a/reproductions/finetuning/iql_finetune.py +++ b/reproductions/finetuning/iql_finetune.py @@ -36,7 +36,10 @@ def main() -> None: # workaround for learning scheduler iql.build_with_dataset(dataset) assert iql.impl - scheduler = CosineAnnealingLR(iql.impl._actor_optim, 1000000) + scheduler = CosineAnnealingLR( + iql.impl._modules.actor_optim, # pylint: disable=protected-access + 1000000, + ) def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None: scheduler.step() @@ -53,7 +56,7 @@ def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None: ) # reset learning rate - for g in iql.impl._actor_optim.param_groups: + for g in iql.impl._modules.actor_optim.param_groups: g["lr"] = iql.config.actor_learning_rate # prepare FIFO buffer filled with dataset episodes diff --git a/reproductions/offline/iql.py b/reproductions/offline/iql.py index 592371e3..e95d048e 100644 --- 
a/reproductions/offline/iql.py +++ b/reproductions/offline/iql.py @@ -36,7 +36,8 @@ def main() -> None: iql.build_with_dataset(dataset) assert iql.impl scheduler = CosineAnnealingLR( - iql.impl._actor_optim, 500000 # pylint: disable=protected-access + iql.impl._modules.actor_optim, # pylint: disable=protected-access + 500000, ) def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None: diff --git a/tests/test_dataclass_utils.py b/tests/test_dataclass_utils.py new file mode 100644 index 00000000..22094b76 --- /dev/null +++ b/tests/test_dataclass_utils.py @@ -0,0 +1,24 @@ +import dataclasses + +from d3rlpy.dataclass_utils import asdict_without_copy + + +@dataclasses.dataclass(frozen=True) +class A: + a: int + + +@dataclasses.dataclass(frozen=True) +class D: + a: A + b: float + c: str + + +def test_asdict_without_any() -> None: + a = A(1) + d = D(a, 2.0, "3") + dict_d = asdict_without_copy(d) + assert dict_d["a"] is a + assert dict_d["b"] == 2.0 + assert dict_d["c"] == "3" diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index fb710851..e13d0589 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -1,4 +1,5 @@ import copy +import dataclasses from io import BytesIO from typing import Any, Dict, Sequence from unittest.mock import Mock @@ -10,21 +11,18 @@ from d3rlpy.dataset import TrajectoryMiniBatch, Transition, TransitionMiniBatch from d3rlpy.torch_utility import ( Checkpointer, + Modules, Swish, TorchMiniBatch, TorchTrajectoryMiniBatch, View, eval_api, - freeze, hard_sync, map_location, reset_optimizer_states, - set_eval_mode, - set_train_mode, soft_sync, sync_optimizer_state, train_api, - unfreeze, ) from .dummy_scalers import ( @@ -115,17 +113,18 @@ def __init__(self) -> None: self.fc1 = torch.nn.Linear(100, 100) self.fc2 = torch.nn.Linear(100, 100) self.optim = torch.optim.Adam(self.fc1.parameters()) + self.modules = TestModules(self.fc1, self.optim) self.device = "cpu:0" @train_api def train_api_func(self) -> None: assert self.fc1.training - assert self.fc2.training + assert not self.fc2.training @eval_api def eval_api_func(self) -> None: assert not self.fc1.training - assert not self.fc2.training + assert self.fc2.training def check_if_same_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None: @@ -155,28 +154,6 @@ def test_reset_optimizer_states() -> None: assert not reset_state -def test_eval_mode() -> None: - impl = DummyImpl() - impl.fc1.train() - impl.fc2.train() - - set_eval_mode(impl) - - assert not impl.fc1.training - assert not impl.fc2.training - - -def test_train_mode() -> None: - impl = DummyImpl() - impl.fc1.eval() - impl.fc2.eval() - - set_train_mode(impl) - - assert impl.fc1.training - assert impl.fc2.training - - @pytest.mark.skip(reason="no way to test this") def test_to_cuda() -> None: pass @@ -187,26 +164,32 @@ def test_to_cpu() -> None: pass -def test_freeze() -> None: - impl = DummyImpl() - - freeze(impl) +@dataclasses.dataclass(frozen=True) +class TestModules(Modules): + fc: torch.nn.Linear + optim: torch.optim.Adam - for p in impl.fc1.parameters(): - assert not p.requires_grad - for p in impl.fc2.parameters(): - assert not p.requires_grad +def test_modules() -> None: + fc = torch.nn.Linear(100, 200) + optim = torch.optim.Adam(fc.parameters()) + modules = TestModules(fc, optim) -def test_unfreeze() -> None: - impl = DummyImpl() + # check checkpointer + checkpointer = modules.create_checkpointer("cpu:0") + assert "fc" in checkpointer.modules + assert "optim" in checkpointer.modules + assert 
checkpointer.modules["fc"] is fc + assert checkpointer.modules["optim"] is optim - freeze(impl) - unfreeze(impl) + # check freeze + modules.freeze() + for p in fc.parameters(): + assert not p.requires_grad - for p in impl.fc1.parameters(): - assert p.requires_grad - for p in impl.fc2.parameters(): + # check unfreeze + modules.unfreeze() + for p in fc.parameters(): assert p.requires_grad From 4ba297fc6cd62201f7cd4edb7759138182e4ce04 Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 14 Aug 2023 21:07:03 +0900 Subject: [PATCH 12/20] Refactor reset_optimizer_states --- d3rlpy/algos/qlearning/base.py | 3 +- d3rlpy/torch_utility.py | 51 ++++------------------------------ tests/test_torch_utility.py | 3 +- 3 files changed, 8 insertions(+), 49 deletions(-) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 4033890c..fd4bca14 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -44,7 +44,6 @@ convert_to_torch_recursively, eval_api, hard_sync, - reset_optimizer_states, sync_optimizer_state, ) from ..utility import ( @@ -144,7 +143,7 @@ def copy_q_function_optim_from(self, impl: "QLearningAlgoImplBase") -> None: sync_optimizer_state(self.q_function_optim, impl.q_function_optim) def reset_optimizer_states(self) -> None: - reset_optimizer_states(self) + self.modules.reset_optimizer_states() TQLearningImpl = TypeVar("TQLearningImpl", bound=QLearningAlgoImplBase) diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index a592709c..14ea9493 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -1,6 +1,6 @@ import collections import dataclasses -from typing import Any, BinaryIO, Dict, List, Optional, Sequence, TypeVar, Union +from typing import Any, BinaryIO, Dict, Optional, Sequence, TypeVar, Union import numpy as np import torch @@ -15,10 +15,6 @@ "soft_sync", "hard_sync", "sync_optimizer_state", - "to_cuda", - "to_cpu", - "to_device", - "reset_optimizer_states", "map_location", "TorchMiniBatch", "TorchTrajectoryMiniBatch", @@ -32,18 +28,6 @@ ] -IGNORE_LIST = [ - "policy", - "q_function", - "policy_optim", - "q_function_optim", -] # special properties - - -def _get_attributes(obj: Any) -> List[str]: - return [key for key in dir(obj) if key not in IGNORE_LIST] - - def soft_sync(targ_model: nn.Module, model: nn.Module, tau: float) -> None: with torch.no_grad(): params = model.parameters() @@ -70,34 +54,6 @@ def sync_optimizer_state(targ_optim: Optimizer, optim: Optimizer) -> None: targ_optim.load_state_dict({"state": state, "param_groups": param_groups}) -def to_cuda(impl: Any, device: str) -> None: - for key in _get_attributes(impl): - module = getattr(impl, key) - if isinstance(module, (torch.nn.Module, torch.nn.Parameter)): - module.cuda(device) - - -def to_cpu(impl: Any) -> None: - for key in _get_attributes(impl): - module = getattr(impl, key) - if isinstance(module, (torch.nn.Module, torch.nn.Parameter)): - module.cpu() - - -def to_device(impl: Any, device: str) -> None: - if device.startswith("cuda"): - to_cuda(impl, device) - else: - to_cpu(impl) - - -def reset_optimizer_states(impl: Any) -> None: - for key in _get_attributes(impl): - obj = getattr(impl, key) - if isinstance(obj, torch.optim.Optimizer): - obj.state = collections.defaultdict(dict) - - def map_location(device: str) -> Any: if "cuda" in device: return lambda storage, loc: storage.cuda(device) @@ -285,6 +241,11 @@ def set_train(self) -> None: if isinstance(v, nn.Module): v.train() + def reset_optimizer_states(self) -> None: + for v in 
asdict_without_copy(self).values(): + if isinstance(v, torch.optim.Optimizer): + v.state = collections.defaultdict(dict) + TCallable = TypeVar("TCallable") diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index e13d0589..33c080a4 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -19,7 +19,6 @@ eval_api, hard_sync, map_location, - reset_optimizer_states, soft_sync, sync_optimizer_state, train_api, @@ -147,7 +146,7 @@ def test_reset_optimizer_states() -> None: state = copy.deepcopy(impl.optim.state) assert state - reset_optimizer_states(impl) + impl.modules.reset_optimizer_states() # check if state is empty reset_state = impl.optim.state From 2d730efeff9624d9a67a826b82b44bc4974aa95b Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 16 Aug 2023 00:09:19 +0900 Subject: [PATCH 13/20] Refactor to support torch.compile --- d3rlpy/algos/qlearning/base.py | 2 +- d3rlpy/algos/qlearning/bc.py | 1 + d3rlpy/algos/qlearning/bcq.py | 7 +++++-- d3rlpy/algos/qlearning/torch/bc_impl.py | 9 +++++++-- d3rlpy/models/torch/policies.py | 10 ++++++---- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index fd4bca14..76ac3d02 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -195,7 +195,7 @@ def save_policy(self, fname: str) -> None: # workaround until version 1.6 self._impl.modules.freeze() - # dummy function to select best actions + # local function to select best actions def _func(x: torch.Tensor) -> torch.Tensor: assert self._impl diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index 84c0321e..90762272 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -109,6 +109,7 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, modules=modules, + policy_type=self._config.policy_type, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index c4b50a27..ac81ce7d 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -14,7 +14,7 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...models.torch import CategoricalPolicy, PixelEncoder, compute_output_size +from ...models.torch import CategoricalPolicy, compute_output_size from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.bcq_impl import ( @@ -327,6 +327,8 @@ class DiscreteBCQConfig(LearnableConfig): :math:`\tau`. beta (float): Reguralization term for imitation function. target_update_interval (int): Interval to update the target network. + share_encoder (bool): Flag to share encoder between Q-function and + imitation models. 
""" learning_rate: float = 6.25e-5 optim_factory: OptimizerFactory = make_optimizer_field() @@ -338,6 +340,7 @@ class DiscreteBCQConfig(LearnableConfig): action_flexibility: float = 0.3 beta: float = 0.5 target_update_interval: int = 8000 + share_encoder: bool = True def create(self, device: DeviceArg = False) -> "DiscreteBCQ": return DiscreteBCQ(self, device) @@ -369,7 +372,7 @@ def inner_create_impl( ) # share convolutional layers if observation is pixel - if isinstance(q_funcs[0].encoder, PixelEncoder): + if self._config.share_encoder: hidden_size = compute_output_size( [observation_shape], q_funcs[0].encoder, diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 3661acbd..309bf3cb 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -76,12 +76,14 @@ class BCModules(BCBaseModules): class BCImpl(BCBaseImpl): _modules: BCModules + _policy_type: str def __init__( self, observation_shape: Shape, action_size: int, modules: BCModules, + policy_type: str, device: str, ): super().__init__( @@ -90,6 +92,7 @@ def __init__( modules=modules, device=device, ) + self._policy_type = policy_type def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: return self._modules.imitator(x).squashed_mu @@ -97,14 +100,16 @@ def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: def compute_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor ) -> torch.Tensor: - if isinstance(self._modules.imitator, DeterministicPolicy): + if self._policy_type == "deterministic": return compute_deterministic_imitation_loss( self._modules.imitator, obs_t, act_t ) - else: + elif self._policy_type == "stochastic": return compute_stochastic_imitation_loss( self._modules.imitator, obs_t, act_t ) + else: + raise ValueError(f"invalid policy_type: {self._policy_type}") @property def policy(self) -> Policy: diff --git a/d3rlpy/models/torch/policies.py b/d3rlpy/models/torch/policies.py index de8baee2..e0cdf190 100644 --- a/d3rlpy/models/torch/policies.py +++ b/d3rlpy/models/torch/policies.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Any, NamedTuple, Optional, Union, cast +from typing import Any, NamedTuple, Optional, Union import torch from torch import nn @@ -85,7 +85,7 @@ def forward(self, x: torch.Tensor, *args: Any) -> ActionOutput: action = args[0] h = self._encoder(x, action) residual_action = self._scale * torch.tanh(self._fc(h)) - action = (action + cast(torch.Tensor, residual_action)).clamp(-1.0, 1.0) + action = (action + residual_action).clamp(-1.0, 1.0) return ActionOutput(mu=action, squashed_mu=action, logstd=None) @@ -125,11 +125,13 @@ def forward(self, x: torch.Tensor, *args: Any) -> ActionOutput: mu = self._mu(h) if self._use_std_parameter: - logstd = torch.sigmoid(cast(nn.Parameter, self._logstd)) + assert isinstance(self._logstd, nn.Parameter) + logstd = torch.sigmoid(self._logstd) base_logstd = self._max_logstd - self._min_logstd clipped_logstd = self._min_logstd + logstd * base_logstd else: - logstd = cast(nn.Linear, self._logstd)(h) + assert isinstance(self._logstd, nn.Linear) + logstd = self._logstd(h) clipped_logstd = logstd.clamp(self._min_logstd, self._max_logstd) return ActionOutput(mu, torch.tanh(mu), clipped_logstd) From 9d4f928fab1311a0f568e6422f36ec50f32e7932 Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 23 Aug 2023 18:38:24 +0900 Subject: [PATCH 14/20] Move update logics to impl --- d3rlpy/algos/qlearning/awac.py | 23 +----- 
d3rlpy/algos/qlearning/base.py | 26 +++---- d3rlpy/algos/qlearning/bc.py | 18 +---- d3rlpy/algos/qlearning/bcq.py | 30 +------- d3rlpy/algos/qlearning/bear.py | 32 +-------- d3rlpy/algos/qlearning/cql.py | 50 ++++--------- d3rlpy/algos/qlearning/crr.py | 27 +------ d3rlpy/algos/qlearning/ddpg.py | 13 +--- d3rlpy/algos/qlearning/dqn.py | 13 +--- d3rlpy/algos/qlearning/iql.py | 16 +---- d3rlpy/algos/qlearning/nfq.py | 11 +-- d3rlpy/algos/qlearning/plas.py | 24 +------ d3rlpy/algos/qlearning/sac.py | 55 ++++----------- d3rlpy/algos/qlearning/td3.py | 20 +----- d3rlpy/algos/qlearning/td3_plus_bc.py | 20 +----- d3rlpy/algos/qlearning/torch/bc_impl.py | 10 ++- d3rlpy/algos/qlearning/torch/bcq_impl.py | 43 ++++++++++-- d3rlpy/algos/qlearning/torch/bear_impl.py | 43 ++++++++++-- d3rlpy/algos/qlearning/torch/cql_impl.py | 70 ++++++++++++------- d3rlpy/algos/qlearning/torch/crr_impl.py | 28 ++++++++ d3rlpy/algos/qlearning/torch/ddpg_impl.py | 26 ++++--- d3rlpy/algos/qlearning/torch/dqn_impl.py | 30 +++++--- d3rlpy/algos/qlearning/torch/iql_impl.py | 12 +++- d3rlpy/algos/qlearning/torch/plas_impl.py | 23 +++++- d3rlpy/algos/qlearning/torch/sac_impl.py | 42 ++++++++--- d3rlpy/algos/qlearning/torch/td3_impl.py | 20 ++++++ .../algos/qlearning/torch/td3_plus_bc_impl.py | 2 + d3rlpy/algos/transformer/base.py | 29 ++++---- .../algos/transformer/decision_transformer.py | 7 -- .../torch/decision_transformer_impl.py | 12 ++-- d3rlpy/dataclass_utils.py | 17 ++++- d3rlpy/ope/fqe.py | 29 +++----- d3rlpy/ope/torch/fqe_impl.py | 40 +++++++---- d3rlpy/torch_utility.py | 7 +- tests/test_dataclass_utils.py | 18 ++++- tests/test_torch_utility.py | 6 +- 36 files changed, 442 insertions(+), 450 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index f49f04a9..c4879db5 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -1,10 +1,9 @@ import dataclasses -from typing import Dict import torch from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, @@ -14,7 +13,6 @@ from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...models.torch import Parameter -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.awac_impl import AWACImpl from .torch.sac_impl import SACModules @@ -72,7 +70,6 @@ class AWACConfig(LearnableConfig): n_action_samples (int): Number of sampled actions to calculate :math:`A^\pi(s_t, a_t)`. n_critics (int): Number of Q functions for ensemble. - update_actor_interval (int): Interval to update policy function. 
""" actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 @@ -87,7 +84,6 @@ class AWACConfig(LearnableConfig): lam: float = 1.0 n_action_samples: int = 1 n_critics: int = 2 - update_actor_interval: int = 1 def create(self, device: DeviceArg = False) -> "AWAC": return AWAC(self, device) @@ -134,7 +130,7 @@ def inner_create_impl( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - dummy_log_temp = Parameter(torch.zeros(1)) + dummy_log_temp = Parameter(torch.zeros(1, 1)) modules = SACModules( policy=policy, q_funcs=q_funcs, @@ -142,7 +138,7 @@ def inner_create_impl( log_temp=dummy_log_temp, actor_optim=actor_optim, critic_optim=critic_optim, - temp_optim=torch.optim.Adam(dummy_log_temp.parameters(), lr=0.0), + temp_optim=None, ) self._impl = AWACImpl( @@ -158,19 +154,6 @@ def inner_create_impl( device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - metrics.update(self._impl.update_critic(batch)) - - # delayed policy update - if self._grad_step % self._config.update_actor_interval == 0: - metrics.update(self._impl.update_actor(batch)) - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 76ac3d02..ad0073dc 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -45,6 +45,7 @@ eval_api, hard_sync, sync_optimizer_state, + train_api, ) from ..utility import ( assert_action_space_with_dataset, @@ -63,6 +64,16 @@ class QLearningAlgoImplBase(ImplBase): + @train_api + def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]: + return self.inner_update(batch, grad_step) + + @abstractmethod + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + pass + @eval_api def predict_best_action(self, x: torch.Tensor) -> torch.Tensor: return self.inner_predict_best_action(x) @@ -809,6 +820,7 @@ def update(self, batch: TransitionMiniBatch) -> Dict[str, float]: Returns: Dictionary of metrics. """ + assert self._impl, IMPL_NOT_INITIALIZED_ERROR torch_batch = TorchMiniBatch.from_batch( batch=batch, device=self._device, @@ -816,22 +828,10 @@ def update(self, batch: TransitionMiniBatch) -> Dict[str, float]: action_scaler=self._config.action_scaler, reward_scaler=self._config.reward_scaler, ) - loss = self.inner_update(torch_batch) + loss = self._impl.inner_update(torch_batch, self._grad_step) self._grad_step += 1 return loss - @abstractmethod - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - """Update parameters with PyTorch mini-batch. - - Args: - batch: PyTorch mini-batch data. - - Returns: - Dictionary of metrics. 
- """ - raise NotImplementedError - def copy_policy_from( self, algo: "QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]" ) -> None: diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index 90762272..7dbfffbf 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict, Generic, TypeVar from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_categorical_policy, @@ -11,7 +10,6 @@ ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.bc_impl import ( BCBaseImpl, @@ -24,16 +22,6 @@ __all__ = ["BCConfig", "BC", "DiscreteBCConfig", "DiscreteBC"] -TBCConfig = TypeVar("TBCConfig", bound="LearnableConfig") - - -class _BCBase(Generic[TBCConfig], QLearningAlgoBase[BCBaseImpl, TBCConfig]): - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - loss = self._impl.update_imitator(batch) - return {"loss": loss} - - @dataclasses.dataclass() class BCConfig(LearnableConfig): r"""Config of Behavior Cloning algorithm. @@ -76,7 +64,7 @@ def get_type() -> str: return "bc" -class BC(_BCBase[BCConfig]): +class BC(QLearningAlgoBase[BCBaseImpl, BCConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -159,7 +147,7 @@ def get_type() -> str: return "discrete_bc" -class DiscreteBC(_BCBase[DiscreteBCConfig]): +class DiscreteBC(QLearningAlgoBase[BCBaseImpl, DiscreteBCConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index ac81ce7d..df154087 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_categorical_policy, @@ -15,7 +14,6 @@ from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...models.torch import CategoricalPolicy, compute_output_size -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.bcq_impl import ( BCQImpl, @@ -245,26 +243,10 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, action_flexibility=self._config.action_flexibility, beta=self._config.beta, + rl_start_step=self._config.rl_start_step, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - metrics.update(self._impl.update_imitator(batch)) - - if self._grad_step >= self._config.rl_start_step: - metrics.update(self._impl.update_critic(batch)) - - if self._grad_step % self._config.update_actor_interval == 0: - metrics.update(self._impl.update_actor(batch)) - self._impl.update_actor_target() - self._impl.update_critic_target() - - return metrics - def get_action_type(self) 
-> ActionSpace: return ActionSpace.CONTINUOUS @@ -417,19 +399,13 @@ def inner_create_impl( modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - loss = self._impl.update(batch) - if self._grad_step % self._config.target_update_interval == 0: - self._impl.update_target() - return loss - def get_action_type(self) -> ActionSpace: return ActionSpace.DISCRETE diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 5abe7e15..fd890882 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -1,9 +1,8 @@ import dataclasses import math -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_conditional_vae, @@ -14,7 +13,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.bear_impl import BEARImpl, BEARModules @@ -244,36 +242,10 @@ def inner_create_impl( mmd_kernel=self._config.mmd_kernel, mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, + warmup_steps=self._config.warmup_steps, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - metrics.update(self._impl.update_imitator(batch)) - - # lagrangian parameter update for SAC temperature - if self._config.temp_learning_rate > 0: - metrics.update(self._impl.update_temp(batch)) - - # lagrangian parameter update for MMD loss weight - if self._config.alpha_learning_rate > 0: - metrics.update(self._impl.update_alpha(batch)) - - metrics.update(self._impl.update_critic(batch)) - - if self._grad_step < self._config.warmup_steps: - actor_loss = self._impl.warmup_actor(batch) - else: - actor_loss = self._impl.update_actor(batch) - metrics.update(actor_loss) - - self._impl.update_critic_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index de04b672..6dc27a48 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -1,9 +1,8 @@ import dataclasses import math -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, @@ -14,7 +13,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLImpl from 
.torch.dqn_impl import DQNModules @@ -173,12 +171,18 @@ def inner_create_impl( critic_optim = self._config.critic_optim_factory.create( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - temp_optim = self._config.temp_optim_factory.create( - log_temp.parameters(), lr=self._config.temp_learning_rate - ) - alpha_optim = self._config.alpha_optim_factory.create( - log_alpha.parameters(), lr=self._config.alpha_learning_rate - ) + if self._config.temp_learning_rate > 0: + temp_optim = self._config.temp_optim_factory.create( + log_temp.parameters(), lr=self._config.temp_learning_rate + ) + else: + temp_optim = None + if self._config.alpha_learning_rate > 0: + alpha_optim = self._config.alpha_optim_factory.create( + log_alpha.parameters(), lr=self._config.alpha_learning_rate + ) + else: + alpha_optim = None modules = CQLModules( policy=policy, @@ -207,26 +211,6 @@ def inner_create_impl( device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - # lagrangian parameter update for SAC temperature - if self._config.temp_learning_rate > 0: - metrics.update(self._impl.update_temp(batch)) - - # lagrangian parameter update for conservative loss weight - if self._config.alpha_learning_rate > 0: - metrics.update(self._impl.update_alpha(batch)) - - metrics.update(self._impl.update_critic(batch)) - metrics.update(self._impl.update_actor(batch)) - - self._impl.update_critic_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS @@ -327,18 +311,12 @@ def inner_create_impl( modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, alpha=self._config.alpha, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - loss = self._impl.update(batch) - if self._grad_step % self._config.target_update_interval == 0: - self._impl.update_target() - return loss - def get_action_type(self) -> ActionSpace: return ActionSpace.DISCRETE diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index 840eded2..b94bf139 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, @@ -11,7 +10,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.crr_impl import CRRImpl, CRRModules @@ -192,30 +190,11 @@ def inner_create_impl( weight_type=self._config.weight_type, max_weight=self._config.max_weight, tau=self._config.tau, + target_update_type=self._config.target_update_type, + target_update_interval=self._config.target_update_interval, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - 
metrics.update(self._impl.update_critic(batch)) - metrics.update(self._impl.update_actor(batch)) - - if self._config.target_update_type == "hard": - if self._grad_step % self._config.target_update_interval == 0: - self._impl.sync_critic_target() - self._impl.sync_actor_target() - elif self._config.target_update_type == "soft": - self._impl.update_critic_target() - self._impl.update_actor_target() - else: - raise ValueError( - f"invalid target_update_type: {self._config.target_update_type}" - ) - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 0fe94cf0..16e2bfc5 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, @@ -11,7 +10,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGImpl, DDPGModules @@ -151,15 +149,6 @@ def inner_create_impl( device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - metrics = {} - metrics.update(self._impl.update_critic(batch)) - metrics.update(self._impl.update_actor(batch)) - self._impl.update_critic_target() - self._impl.update_actor_target() - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 95894d92..7866930c 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -1,14 +1,12 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import create_discrete_q_function from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.dqn_impl import DoubleDQNImpl, DQNImpl, DQNModules @@ -100,18 +98,12 @@ def inner_create_impl( action_size=action_size, q_func_forwarder=forwarder, targ_q_func_forwarder=targ_forwarder, + target_update_interval=self._config.target_update_interval, modules=modules, gamma=self._config.gamma, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - loss = self._impl.update(batch) - if self._grad_step % self._config.target_update_interval == 0: - self._impl.update_target() - return loss - def get_action_type(self) -> ActionSpace: return ActionSpace.DISCRETE @@ -209,6 +201,7 @@ def inner_create_impl( modules=modules, q_func_forwarder=forwarder, targ_q_func_forwarder=targ_forwarder, + 
target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 63f037ab..8ce2fb2f 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, @@ -12,7 +11,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import MeanQFunctionFactory -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.iql_impl import IQLImpl, IQLModules @@ -173,18 +171,6 @@ def inner_create_impl( device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - metrics.update(self._impl.update_critic_and_state_value(batch)) - metrics.update(self._impl.update_actor(batch)) - - self._impl.update_critic_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index b5c76a68..bd3bc5e3 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -1,14 +1,12 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import create_discrete_q_function from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.dqn_impl import DQNImpl, DQNModules @@ -103,16 +101,11 @@ def inner_create_impl( modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + target_update_interval=1, gamma=self._config.gamma, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - loss = self._impl.update(batch) - self._impl.update_target() - return loss - def get_action_type(self) -> ActionSpace: return ActionSpace.DISCRETE diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 629f546a..e346d35c 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_conditional_vae, @@ -13,7 +12,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase 
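For orientation, the control flow that these per-algorithm diffs converge on can be shown in isolation. The following is a minimal, self-contained sketch with invented names (SketchImpl, a scalar stand-in for the Q-function); only the idea that the impl now owns target_update_interval and syncs its target from the grad_step passed into inner_update mirrors the patch:

from typing import Dict

class SketchImpl:
    """Toy stand-in for a Q-learning impl that owns its target-sync schedule."""

    def __init__(self, target_update_interval: int) -> None:
        self._target_update_interval = target_update_interval
        self._q = 0.0       # stand-in for the online Q-function parameters
        self._targ_q = 0.0  # stand-in for the target Q-function parameters

    def update_target(self) -> None:
        # hard sync, analogous to hard_sync() on the real target modules
        self._targ_q = self._q

    def inner_update(self, target: float, grad_step: int) -> Dict[str, float]:
        loss = (target - self._q) ** 2        # toy TD-style loss
        self._q += 0.1 * (target - self._q)   # toy optimizer step
        if grad_step % self._target_update_interval == 0:
            self.update_target()
        return {"loss": loss}

impl = SketchImpl(target_update_interval=2)
for step in range(4):
    print(step, impl.inner_update(target=1.0, grad_step=step))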
from .torch.plas_impl import ( PLASImpl, @@ -75,7 +73,6 @@ class PLASConfig(LearnableConfig): gamma (float): Discount factor. tau (float): Target network synchronization coefficiency. n_critics (int): Number of Q functions for ensemble. - update_actor_interval (int): Interval to update policy function. lam (float): Weight factor for critic ensemble. warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. @@ -94,7 +91,6 @@ class PLASConfig(LearnableConfig): gamma: float = 0.99 tau: float = 0.005 n_critics: int = 2 - update_actor_interval: int = 1 lam: float = 0.75 warmup_steps: int = 500000 beta: float = 0.5 @@ -180,25 +176,10 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, beta=self._config.beta, + warmup_steps=self._config.warmup_steps, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - if self._grad_step < self._config.warmup_steps: - metrics.update(self._impl.update_imitator(batch)) - else: - metrics.update(self._impl.update_critic(batch)) - if self._grad_step % self._config.update_actor_interval == 0: - metrics.update(self._impl.update_actor(batch)) - self._impl.update_actor_target() - self._impl.update_critic_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS @@ -350,6 +331,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, beta=self._config.beta, + warmup_steps=self._config.warmup_steps, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 0748d3c0..8ad40728 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -1,9 +1,8 @@ import dataclasses import math -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_categorical_policy, @@ -15,7 +14,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.sac_impl import ( DiscreteSACImpl, @@ -158,9 +156,12 @@ def inner_create_impl( critic_optim = self._config.critic_optim_factory.create( q_funcs.parameters(), lr=self._config.critic_learning_rate ) - temp_optim = self._config.temp_optim_factory.create( - log_temp.parameters(), lr=self._config.temp_learning_rate - ) + if self._config.temp_learning_rate > 0: + temp_optim = self._config.temp_optim_factory.create( + log_temp.parameters(), lr=self._config.temp_learning_rate + ) + else: + temp_optim = None modules = SACModules( policy=policy, @@ -183,21 +184,6 @@ def inner_create_impl( device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - # lagrangian parameter update for SAC temperature - if self._config.temp_learning_rate > 0: - metrics.update(self._impl.update_temp(batch)) - - metrics.update(self._impl.update_critic(batch)) - metrics.update(self._impl.update_actor(batch)) - self._impl.update_critic_target() - - return metrics - def get_action_type(self) -> ActionSpace: return 
ActionSpace.CONTINUOUS @@ -318,9 +304,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.parameters(), lr=self._config.actor_learning_rate ) - temp_optim = self._config.temp_optim_factory.create( - log_temp.parameters(), lr=self._config.temp_learning_rate - ) + if self._config.temp_learning_rate > 0: + temp_optim = self._config.temp_optim_factory.create( + log_temp.parameters(), lr=self._config.temp_learning_rate + ) + else: + temp_optim = None modules = DiscreteSACModules( policy=policy, @@ -338,27 +327,11 @@ def inner_create_impl( modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - # lagrangian parameter update for SAC temeprature - if self._config.temp_learning_rate > 0: - metrics.update(self._impl.update_temp(batch)) - - metrics.update(self._impl.update_critic(batch)) - metrics.update(self._impl.update_actor(batch)) - - if self._grad_step % self._config.target_update_interval == 0: - self._impl.update_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.DISCRETE diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index eb92b0bb..50f7d46a 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( create_continuous_q_function, @@ -11,7 +10,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGModules from .torch.td3_impl import TD3Impl @@ -159,24 +157,10 @@ def inner_create_impl( tau=self._config.tau, target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, + update_actor_interval=self._config.update_actor_interval, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - metrics.update(self._impl.update_critic(batch)) - - # delayed policy update - if self._grad_step % self._config.update_actor_interval == 0: - metrics.update(self._impl.update_actor(batch)) - self._impl.update_critic_target() - self._impl.update_actor_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 590fda2a..8b89b322 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Dict from ...base import DeviceArg, LearnableConfig, register_learnable -from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace +from ...constants import ActionSpace from ...dataset import Shape from ...models.builders import ( 
create_continuous_q_function, @@ -11,7 +10,6 @@ from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import QFunctionFactory, make_q_func_field -from ...torch_utility import TorchMiniBatch from .base import QLearningAlgoBase from .torch.ddpg_impl import DDPGModules from .torch.td3_plus_bc_impl import TD3PlusBCImpl @@ -152,24 +150,10 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, + update_actor_interval=self._config.update_actor_interval, device=self._device, ) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - - metrics = {} - - metrics.update(self._impl.update_critic(batch)) - - # delayed policy update - if self._grad_step % self._config.update_actor_interval == 0: - metrics.update(self._impl.update_actor(batch)) - self._impl.update_critic_target() - self._impl.update_actor_target() - - return metrics - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 309bf3cb..dbefeb55 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -1,6 +1,6 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Union +from typing import Dict, Union import torch from torch.optim import Optimizer @@ -15,7 +15,7 @@ compute_discrete_imitation_loss, compute_stochastic_imitation_loss, ) -from ....torch_utility import Modules, TorchMiniBatch, train_api +from ....torch_utility import Modules, TorchMiniBatch from ..base import QLearningAlgoImplBase __all__ = ["BCImpl", "DiscreteBCImpl", "BCModules", "DiscreteBCModules"] @@ -43,7 +43,6 @@ def __init__( device=device, ) - @train_api def update_imitator(self, batch: TorchMiniBatch) -> float: self._modules.optim.zero_grad() @@ -68,6 +67,11 @@ def inner_predict_value( ) -> torch.Tensor: raise NotImplementedError("BC does not support value estimation") + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + return {"loss": self.update_imitator(batch)} + @dataclasses.dataclass(frozen=True) class BCModules(BCBaseModules): diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 997257e8..fcedd24c 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -18,11 +18,17 @@ compute_vae_error, forward_vae_decode, ) -from ....torch_utility import TorchMiniBatch, soft_sync, train_api +from ....torch_utility import TorchMiniBatch, soft_sync from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules -from .dqn_impl import DoubleDQNImpl, DQNModules +from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules -__all__ = ["BCQImpl", "DiscreteBCQImpl", "BCQModules", "DiscreteBCQModules"] +__all__ = [ + "BCQImpl", + "DiscreteBCQImpl", + "BCQModules", + "DiscreteBCQModules", + "DiscreteBCQLoss", +] @dataclasses.dataclass(frozen=True) @@ -39,6 +45,7 @@ class BCQImpl(DDPGBaseImpl): _n_action_samples: int _action_flexibility: float _beta: float + _rl_start_step: float def __init__( self, @@ -53,6 +60,7 @@ def __init__( n_action_samples: int, action_flexibility: float, beta: float, + rl_start_step: int, device: str, ): super().__init__( @@ -69,6 +77,7 @@ def __init__( 
self._n_action_samples = n_action_samples self._action_flexibility = action_flexibility self._beta = beta + self._rl_start_step = rl_start_step def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: latent = torch.randn( @@ -88,7 +97,6 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: ) return -value[0].mean() - @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.imitator_optim.zero_grad() @@ -176,12 +184,28 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: def update_actor_target(self) -> None: soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + metrics.update(self.update_imitator(batch)) + if grad_step >= self._rl_start_step: + metrics.update(super().inner_update(batch, grad_step)) + self.update_actor_target() + return metrics + @dataclasses.dataclass(frozen=True) class DiscreteBCQModules(DQNModules): imitator: CategoricalPolicy +@dataclasses.dataclass(frozen=True) +class DiscreteBCQLoss(DQNLoss): + td_loss: torch.Tensor + imitator_loss: torch.Tensor + + class DiscreteBCQImpl(DoubleDQNImpl): _modules: DiscreteBCQModules _action_flexibility: float @@ -194,6 +218,7 @@ def __init__( modules: DiscreteBCQModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + target_update_interval: int, gamma: float, action_flexibility: float, beta: float, @@ -205,6 +230,7 @@ def __init__( modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + target_update_interval=target_update_interval, gamma=gamma, device=device, ) @@ -213,15 +239,18 @@ def __init__( def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> torch.Tensor: - loss = super().compute_loss(batch, q_tpn) + ) -> DiscreteBCQLoss: + td_loss = super().compute_loss(batch, q_tpn).loss imitator_loss = compute_discrete_imitation_loss( policy=self._modules.imitator, x=batch.observations, action=batch.actions.long(), beta=self._beta, ) - return loss + imitator_loss + loss = td_loss + imitator_loss + return DiscreteBCQLoss( + loss=loss, td_loss=td_loss, imitator_loss=imitator_loss + ) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: dist = self._modules.imitator(x) diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 10b83d99..aa2cc173 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict +from typing import Dict, Optional import torch from torch.optim import Optimizer @@ -14,7 +14,7 @@ compute_vae_error, forward_vae_sample_n, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import TorchMiniBatch from .sac_impl import SACImpl, SACModules __all__ = ["BEARImpl", "BEARModules"] @@ -39,7 +39,7 @@ class BEARModules(SACModules): imitator: ConditionalVAE log_alpha: Parameter imitator_optim: Optimizer - alpha_optim: Optimizer + alpha_optim: Optional[Optimizer] class BEARImpl(SACImpl): @@ -52,6 +52,7 @@ class BEARImpl(SACImpl): _mmd_kernel: str _mmd_sigma: float _vae_kl_weight: float + _warmup_steps: int def __init__( self, @@ -70,6 +71,7 @@ def __init__( mmd_kernel: str, mmd_sigma: float, vae_kl_weight: float, + warmup_steps: int, device: str, ): super().__init__( @@ -90,13 +92,13 @@ def __init__( 
self._mmd_kernel = mmd_kernel self._mmd_sigma = mmd_sigma self._vae_kl_weight = vae_kl_weight + self._warmup_steps = warmup_steps def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: loss = super().compute_actor_loss(batch) mmd_loss = self._compute_mmd_loss(batch.observations) return loss + mmd_loss - @train_api def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.actor_optim.zero_grad() @@ -112,7 +114,6 @@ def _compute_mmd_loss(self, obs_t: torch.Tensor) -> torch.Tensor: alpha = self._modules.log_alpha().exp() return (alpha * (mmd - self._alpha_threshold)).mean() - @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.imitator_optim.zero_grad() @@ -132,11 +133,12 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: beta=self._vae_kl_weight, ) - @train_api def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: + assert self._modules.alpha_optim + self._modules.alpha_optim.zero_grad() + loss = -self._compute_mmd_loss(batch.observations) - self._modules.alpha_optim.zero_grad() loss.backward() self._modules.alpha_optim.step() @@ -252,3 +254,30 @@ def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: max_indices = torch.argmax(values, dim=1) return actions[torch.arange(x.shape[0]), max_indices] + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + + metrics.update(self.update_imitator(batch)) + + # lagrangian parameter update for SAC temperature + if self._modules.temp_optim: + metrics.update(self.update_temp(batch)) + + # lagrangian parameter update for MMD loss weight + if self._modules.alpha_optim: + metrics.update(self.update_alpha(batch)) + + metrics.update(self.update_critic(batch)) + + if grad_step < self._warmup_steps: + actor_loss = self.warmup_actor(batch) + else: + actor_loss = self.update_actor(batch) + metrics.update(actor_loss) + + self.update_critic_target() + + return metrics diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 39ab82a1..a531abd6 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Dict +from typing import Dict, Optional import torch import torch.nn.functional as F @@ -13,17 +13,17 @@ Parameter, build_squashed_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch, train_api -from .dqn_impl import DoubleDQNImpl, DQNModules +from ....torch_utility import TorchMiniBatch +from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules from .sac_impl import SACImpl, SACModules -__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules"] +__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss"] @dataclasses.dataclass(frozen=True) class CQLModules(SACModules): log_alpha: Parameter - alpha_optim: Optimizer + alpha_optim: Optional[Optimizer] class CQLImpl(SACImpl): @@ -72,8 +72,9 @@ def compute_critic_loss( ) return loss + conservative_loss - @train_api def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: + assert self._modules.alpha_optim + # Q function should be inference mode for stability self._modules.q_funcs.eval() @@ -195,6 +196,32 @@ def _compute_deterministic_target( reduction="min", ) + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + + # lagrangian parameter update for SAC temperature + if self._modules.temp_optim: + 
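A stripped-down sketch of the Optional[Optimizer] convention used for the lagrangian parameters in these hunks, assuming nothing beyond PyTorch itself; variable names are illustrative, not the library's. The optimizer exists only when its learning rate is positive, the update helper asserts its presence, and callers key the step off the optimizer rather than the config:

from typing import Optional

import torch

log_alpha = torch.nn.Parameter(torch.zeros(1))
alpha_learning_rate = 1e-3  # assumed config value; 0.0 would disable the update

alpha_optim: Optional[torch.optim.Optimizer] = (
    torch.optim.Adam([log_alpha], lr=alpha_learning_rate)
    if alpha_learning_rate > 0
    else None
)

def update_alpha() -> float:
    assert alpha_optim  # only reachable when the optimizer was created
    alpha_optim.zero_grad()
    loss = -log_alpha.exp().mean()  # toy objective standing in for the real lagrangian term
    loss.backward()
    alpha_optim.step()
    return float(loss.detach())

# mirrors the `if self._modules.alpha_optim:` guard inside inner_update
if alpha_optim:
    print(update_alpha())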
metrics.update(self.update_temp(batch)) + + # lagrangian parameter update for conservative loss weight + if self._modules.alpha_optim: + metrics.update(self.update_alpha(batch)) + + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) + + self.update_critic_target() + + return metrics + + +@dataclasses.dataclass(frozen=True) +class DiscreteCQLLoss(DQNLoss): + td_loss: torch.Tensor + conservative_loss: torch.Tensor + class DiscreteCQLImpl(DoubleDQNImpl): _alpha: float @@ -206,6 +233,7 @@ def __init__( modules: DQNModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + target_update_interval: int, gamma: float, alpha: float, device: str, @@ -216,6 +244,7 @@ def __init__( modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + target_update_interval=target_update_interval, gamma=gamma, device=device, ) @@ -234,27 +263,16 @@ def _compute_conservative_loss( return (logsumexp - data_values).mean() - @train_api - def update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._modules.optim is not None - - self._modules.optim.zero_grad() - - q_tpn = self.compute_target(batch) - - td_loss = self.compute_loss(batch, q_tpn) + def compute_loss( + self, + batch: TorchMiniBatch, + q_tpn: torch.Tensor, + ) -> DiscreteCQLLoss: + td_loss = super().compute_loss(batch, q_tpn).loss conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions.long() ) loss = td_loss + self._alpha * conservative_loss - - loss.backward() - self._modules.optim.step() - - return { - "loss": float(loss.cpu().detach().numpy()), - "td_loss": float(td_loss.cpu().detach().numpy()), - "conservative_loss": float( - conservative_loss.cpu().detach().numpy() - ), - } + return DiscreteCQLLoss( + loss=loss, td_loss=td_loss, conservative_loss=conservative_loss + ) diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 59183666..38f2168d 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -1,4 +1,5 @@ import dataclasses +from typing import Dict import torch import torch.nn.functional as F @@ -28,6 +29,8 @@ class CRRImpl(DDPGBaseImpl): _advantage_type: str _weight_type: str _max_weight: float + _target_update_type: str + _target_update_interval: int def __init__( self, @@ -43,6 +46,8 @@ def __init__( weight_type: str, max_weight: float, tau: float, + target_update_type: str, + target_update_interval: int, device: str, ): super().__init__( @@ -60,6 +65,8 @@ def __init__( self._advantage_type = advantage_type self._weight_type = weight_type self._max_weight = max_weight + self._target_update_type = target_update_type + self._target_update_interval = target_update_interval def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: # compute log probability @@ -175,3 +182,24 @@ def sync_actor_target(self) -> None: def update_actor_target(self) -> None: soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) + + if self._target_update_type == "hard": + if grad_step % self._target_update_interval == 0: + self.sync_critic_target() + self.sync_actor_target() + elif self._target_update_type == "soft": + self.update_critic_target() + self.update_actor_target() + 
else: + raise ValueError( + f"invalid target_update_type: {self._target_update_type}" + ) + + return metrics diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 7fd9f521..a1535589 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -8,13 +8,7 @@ from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy -from ....torch_utility import ( - Modules, - TorchMiniBatch, - hard_sync, - soft_sync, - train_api, -) +from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin @@ -62,7 +56,6 @@ def __init__( self._targ_q_func_forwarder = targ_q_func_forwarder hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) - @train_api def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.zero_grad() @@ -87,7 +80,6 @@ def compute_critic_loss( gamma=self._gamma**batch.intervals, ) - @train_api def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() @@ -101,6 +93,15 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: return {"actor_loss": float(loss.cpu().detach().numpy())} + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) + self.update_critic_target() + return metrics + @abstractmethod def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: pass @@ -188,3 +189,10 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: def update_actor_target(self) -> None: soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = super().inner_update(batch, grad_step) + self.update_actor_target() + return metrics diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index 4d1fe0f0..e5caee13 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -5,13 +5,14 @@ from torch import nn from torch.optim import Optimizer +from ....dataclass_utils import asdict_as_float from ....dataset import Shape from ....models.torch import DiscreteEnsembleQFunctionForwarder -from ....torch_utility import Modules, TorchMiniBatch, hard_sync, train_api +from ....torch_utility import Modules, TorchMiniBatch, hard_sync from ..base import QLearningAlgoImplBase from .utility import DiscreteQFunctionMixin -__all__ = ["DQNImpl", "DQNModules", "DoubleDQNImpl"] +__all__ = ["DQNImpl", "DQNModules", "DQNLoss", "DoubleDQNImpl"] @dataclasses.dataclass(frozen=True) @@ -21,11 +22,17 @@ class DQNModules(Modules): optim: Optimizer +@dataclasses.dataclass(frozen=True) +class DQNLoss: + loss: torch.Tensor + + class DQNImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _modules: DQNModules _gamma: float _q_func_forwarder: DiscreteEnsembleQFunctionForwarder _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder + _target_update_interval: int def __init__( self, @@ -34,6 +41,7 @@ def __init__( modules: DQNModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + target_update_interval: int, gamma: float, device: str, ): @@ -46,27 
+54,32 @@ def __init__( self._gamma = gamma self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder + self._target_update_interval = target_update_interval hard_sync(modules.targ_q_funcs, modules.q_funcs) - @train_api - def update(self, batch: TorchMiniBatch) -> Dict[str, float]: + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.optim.zero_grad() q_tpn = self.compute_target(batch) loss = self.compute_loss(batch, q_tpn) - loss.backward() + loss.loss.backward() self._modules.optim.step() - return {"loss": float(loss.cpu().detach().numpy())} + if grad_step % self._target_update_interval == 0: + self.update_target() + + return asdict_as_float(loss) def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor, - ) -> torch.Tensor: - return self._q_func_forwarder.compute_error( + ) -> DQNLoss: + loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions.long(), rewards=batch.rewards, @@ -74,6 +87,7 @@ def compute_loss( terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) + return DQNLoss(loss=loss) def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 0bb068c3..880a4c09 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -10,7 +10,7 @@ ValueFunction, build_gaussian_distribution, ) -from ....torch_utility import TorchMiniBatch, train_api +from ....torch_utility import TorchMiniBatch from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules __all__ = ["IQLImpl", "IQLModules"] @@ -102,7 +102,6 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: weight = (self._expectile - (diff < 0.0).float()).abs().detach() return (weight * (diff**2)).mean() - @train_api def update_critic_and_state_value( self, batch: TorchMiniBatch ) -> Dict[str, float]: @@ -128,3 +127,12 @@ def update_critic_and_state_value( def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: dist = build_gaussian_distribution(self._modules.policy(x)) return dist.sample() + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + metrics.update(self.update_critic_and_state_value(batch)) + metrics.update(self.update_actor(batch)) + self.update_critic_target() + return metrics diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 5d54ac49..df3a081b 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -13,7 +13,7 @@ compute_vae_error, forward_vae_decode, ) -from ....torch_utility import TorchMiniBatch, soft_sync, train_api +from ....torch_utility import TorchMiniBatch, soft_sync from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules __all__ = [ @@ -36,6 +36,7 @@ class PLASImpl(DDPGBaseImpl): _modules: PLASModules _lam: float _beta: float + _warmup_steps: int def __init__( self, @@ -48,6 +49,7 @@ def __init__( tau: float, lam: float, beta: float, + warmup_steps: int, device: str, ): super().__init__( @@ -62,8 +64,8 @@ def __init__( ) self._lam = lam self._beta = beta + self._warmup_steps = warmup_steps - @train_api def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.imitator_optim.zero_grad() @@ -116,6 +118,21 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: def update_actor_target(self) 
-> None: soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + + if grad_step < self._warmup_steps: + metrics.update(self.update_imitator(batch)) + else: + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) + self.update_actor_target() + self.update_critic_target() + + return metrics + @dataclasses.dataclass(frozen=True) class PLASWithPerturbationModules(PLASModules): @@ -137,6 +154,7 @@ def __init__( tau: float, lam: float, beta: float, + warmup_steps: int, device: str, ): super().__init__( @@ -149,6 +167,7 @@ def __init__( tau=tau, lam=lam, beta=beta, + warmup_steps=warmup_steps, device=device, ) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index 37bc7997..b6612295 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Dict +from typing import Dict, Optional import torch import torch.nn.functional as F @@ -17,7 +17,7 @@ Policy, build_squashed_gaussian_distribution, ) -from ....torch_utility import Modules, TorchMiniBatch, hard_sync, train_api +from ....torch_utility import Modules, TorchMiniBatch, hard_sync from ..base import QLearningAlgoImplBase from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules from .utility import DiscreteQFunctionMixin @@ -29,7 +29,7 @@ class SACModules(DDPGBaseModules): policy: NormalPolicy log_temp: Parameter - temp_optim: Optimizer + temp_optim: Optional[Optimizer] class SACImpl(DDPGBaseImpl): @@ -68,8 +68,8 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: ) return (entropy - q_t).mean() - @train_api def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: + assert self._modules.temp_optim self._modules.temp_optim.zero_grad() with torch.no_grad(): @@ -106,6 +106,15 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: ) return target - entropy + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + if self._modules.temp_optim: + metrics.update(self.update_temp(batch)) + metrics.update(super().inner_update(batch, grad_step)) + return metrics + def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: dist = build_squashed_gaussian_distribution(self._modules.policy(x)) return dist.sample() @@ -119,13 +128,14 @@ class DiscreteSACModules(Modules): log_temp: Parameter actor_optim: Optimizer critic_optim: Optimizer - temp_optim: Optimizer + temp_optim: Optional[Optimizer] class DiscreteSACImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _modules: DiscreteSACModules _q_func_forwarder: DiscreteEnsembleQFunctionForwarder _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder + _target_update_interval: int def __init__( self, @@ -134,6 +144,7 @@ def __init__( modules: DiscreteSACModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + target_update_interval: int, gamma: float, device: str, ): @@ -146,9 +157,9 @@ def __init__( self._gamma = gamma self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder + self._target_update_interval = target_update_interval hard_sync(modules.targ_q_funcs, modules.q_funcs) - @train_api def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.zero_grad() @@ -190,7 +201,6 @@ 
def compute_critic_loss( gamma=self._gamma**batch.intervals, ) - @train_api def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() @@ -215,8 +225,8 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: entropy = self._modules.log_temp().exp() * log_probs return (probs * (entropy - q_t)).sum(dim=1).mean() - @train_api def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: + assert self._modules.temp_optim self._modules.temp_optim.zero_grad() with torch.no_grad(): @@ -240,6 +250,22 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: "temp": float(cur_temp), } + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + + # lagrangian parameter update for SAC temeprature + if self._modules.temp_optim: + metrics.update(self.update_temp(batch)) + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) + + if grad_step % self._target_update_interval == 0: + self.update_target() + + return metrics + def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: dist = self._modules.policy(x) return dist.probs.argmax(dim=1) diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 896f3221..34cb127c 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -1,3 +1,5 @@ +from typing import Dict + import torch from ....dataset import Shape @@ -11,6 +13,7 @@ class TD3Impl(DDPGImpl): _target_smoothing_sigma: float _target_smoothing_clip: float + _update_actor_interval: int def __init__( self, @@ -23,6 +26,7 @@ def __init__( tau: float, target_smoothing_sigma: float, target_smoothing_clip: float, + update_actor_interval: int, device: str, ): super().__init__( @@ -37,6 +41,7 @@ def __init__( ) self._target_smoothing_sigma = target_smoothing_sigma self._target_smoothing_clip = target_smoothing_clip + self._update_actor_interval = update_actor_interval def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): @@ -54,3 +59,18 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: clipped_action, reduction="min", ) + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + + metrics.update(self.update_critic(batch)) + + # delayed policy update + if grad_step % self._update_actor_interval == 0: + metrics.update(self.update_actor(batch)) + self.update_critic_target() + self.update_actor_target() + + return metrics diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index c3a93f8a..1edd9149 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -26,6 +26,7 @@ def __init__( target_smoothing_sigma: float, target_smoothing_clip: float, alpha: float, + update_actor_interval: int, device: str, ): super().__init__( @@ -38,6 +39,7 @@ def __init__( tau=tau, target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, + update_actor_interval=update_actor_interval, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py index 9da483ca..3f262617 100644 --- a/d3rlpy/algos/transformer/base.py +++ b/d3rlpy/algos/transformer/base.py @@ -24,7 +24,7 @@ LoggerAdapterFactory, ) from ...metrics import 
evaluate_transformer_with_environment -from ...torch_utility import TorchTrajectoryMiniBatch +from ...torch_utility import TorchTrajectoryMiniBatch, train_api from ..utility import ( assert_action_space_with_dataset, build_scalers_with_trajectory_slicer, @@ -44,6 +44,18 @@ class TransformerAlgoImplBase(ImplBase): def predict(self, inpt: TorchTransformerInput) -> torch.Tensor: ... + @train_api + def update( + self, batch: TorchTrajectoryMiniBatch, grad_step: int + ) -> Dict[str, float]: + return self.inner_update(batch, grad_step) + + @abstractmethod + def inner_update( + self, batch: TorchTrajectoryMiniBatch, grad_step: int + ) -> Dict[str, float]: + pass + @dataclasses.dataclass() class TransformerConfig(LearnableConfig): @@ -334,6 +346,7 @@ def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]: Returns: Dictionary of metrics. """ + assert self._impl, IMPL_NOT_INITIALIZED_ERROR torch_batch = TorchTrajectoryMiniBatch.from_batch( batch=batch, device=self._device, @@ -341,22 +354,10 @@ def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]: action_scaler=self._config.action_scaler, reward_scaler=self._config.reward_scaler, ) - loss = self.inner_update(torch_batch) + loss = self._impl.inner_update(torch_batch, self._grad_step) self._grad_step += 1 return loss - @abstractmethod - def inner_update(self, batch: TorchTrajectoryMiniBatch) -> Dict[str, float]: - """Update parameters with PyTorch mini-batch. - - Args: - batch: PyTorch mini-batch data. - - Returns: - Dictionary of metrics. - """ - raise NotImplementedError - def as_stateful_wrapper( self, target_return: float ) -> StatefulTransformerWrapper[TTransformerImpl, TTransformerConfig]: diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 73603cd5..0c9d3be0 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict import torch @@ -13,7 +12,6 @@ make_optimizer_field, ) from ...models.builders import create_continuous_decision_transformer -from ...torch_utility import TorchTrajectoryMiniBatch from .base import TransformerAlgoBase, TransformerConfig from .torch.decision_transformer_impl import ( DecisionTransformerImpl, @@ -130,11 +128,6 @@ def inner_create_impl( device=self._device, ) - def inner_update(self, batch: TorchTrajectoryMiniBatch) -> Dict[str, float]: - assert self._impl - loss = self._impl.update(batch) - return loss - def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index 43345c5b..28ab9504 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -6,12 +6,7 @@ from ....dataset import Shape from ....models.torch import ContinuousDecisionTransformer -from ....torch_utility import ( - Modules, - TorchTrajectoryMiniBatch, - eval_api, - train_api, -) +from ....torch_utility import Modules, TorchTrajectoryMiniBatch, eval_api from ..base import TransformerAlgoImplBase from ..inputs import TorchTransformerInput @@ -56,8 +51,9 @@ def predict(self, inpt: TorchTransformerInput) -> torch.Tensor: # (1, T, A) -> (A,) return action[0][-1] - @train_api - def update(self, batch: TorchTrajectoryMiniBatch) -> Dict[str, float]: + def inner_update( + self, batch: TorchTrajectoryMiniBatch, grad_step: int + ) -> 
Dict[str, float]: self._modules.optim.zero_grad() loss = self.compute_loss(batch) diff --git a/d3rlpy/dataclass_utils.py b/d3rlpy/dataclass_utils.py index a0b269c1..e3422279 100644 --- a/d3rlpy/dataclass_utils.py +++ b/d3rlpy/dataclass_utils.py @@ -1,10 +1,25 @@ import dataclasses from typing import Any, Dict -__all__ = ["asdict_without_copy"] +import torch + +__all__ = ["asdict_without_copy", "asdict_as_float"] def asdict_without_copy(obj: Any) -> Dict[str, Any]: assert dataclasses.is_dataclass(obj) fields = dataclasses.fields(obj) return {field.name: getattr(obj, field.name) for field in fields} + + +def asdict_as_float(obj: Any) -> Dict[str, float]: + assert dataclasses.is_dataclass(obj) + fields = dataclasses.fields(obj) + ret: Dict[str, float] = {} + for field in fields: + value = getattr(obj, field.name) + if isinstance(value, torch.Tensor): + ret[field.name] = float(value.cpu().detach().numpy()) + else: + ret[field.name] = float(value) + return ret diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index d465e822..d8281017 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -1,15 +1,11 @@ import dataclasses -from typing import Dict, Optional +from typing import Optional import numpy as np from ..algos.qlearning import QLearningAlgoBase, QLearningAlgoImplBase from ..base import DeviceArg, LearnableConfig, register_learnable -from ..constants import ( - ALGO_NOT_GIVEN_ERROR, - IMPL_NOT_INITIALIZED_ERROR, - ActionSpace, -) +from ..constants import ALGO_NOT_GIVEN_ERROR, ActionSpace from ..dataset import Observation, Shape from ..models.builders import ( create_continuous_q_function, @@ -18,7 +14,6 @@ from ..models.encoders import EncoderFactory, make_encoder_field from ..models.optimizers import OptimizerFactory, make_optimizer_field from ..models.q_functions import QFunctionFactory, make_q_func_field -from ..torch_utility import TorchMiniBatch, convert_to_torch from .torch.fqe_impl import ( DiscreteFQEImpl, FQEBaseImpl, @@ -114,18 +109,6 @@ def sample_action(self, x: Observation) -> np.ndarray: assert self._algo is not None, ALGO_NOT_GIVEN_ERROR return self._algo.sample_action(x) - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: - assert self._algo is not None, ALGO_NOT_GIVEN_ERROR - assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR - assert batch.numpy_batch - next_actions = self._algo.predict(batch.numpy_batch.next_observations) - loss = self._impl.update( - batch, convert_to_torch(next_actions, self._device) - ) - if self._grad_step % self._config.target_update_interval == 0: - self._impl.update_target() - return {"loss": loss} - @property def algo(self) -> QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]: return self._algo @@ -161,6 +144,8 @@ class FQE(_FQEBase): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + assert self._algo.impl, "The target algorithm is not initialized." 
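As a usage note for the asdict_as_float helper added above: paired with a frozen loss dataclass, it yields the flat Dict[str, float] that the metrics logging expects. The snippet below repeats the helper body inline so it runs standalone, and the dataclass name is illustrative:

import dataclasses
from typing import Any, Dict

import torch

def asdict_as_float(obj: Any) -> Dict[str, float]:
    # same behaviour as the helper added in d3rlpy/dataclass_utils.py
    assert dataclasses.is_dataclass(obj)
    ret: Dict[str, float] = {}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        if isinstance(value, torch.Tensor):
            ret[field.name] = float(value.cpu().detach().numpy())
        else:
            ret[field.name] = float(value)
    return ret

@dataclasses.dataclass(frozen=True)
class ExampleCriticLoss:
    loss: torch.Tensor
    td_loss: torch.Tensor
    conservative_loss: torch.Tensor

td, cons = torch.tensor(0.5), torch.tensor(0.25)
metrics = asdict_as_float(
    ExampleCriticLoss(loss=td + cons, td_loss=td, conservative_loss=cons)
)
print(metrics)  # {'loss': 0.75, 'td_loss': 0.5, 'conservative_loss': 0.25}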
+ q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, action_size, @@ -190,10 +175,12 @@ def inner_create_impl( self._impl = FQEImpl( observation_shape=observation_shape, action_size=action_size, + algo=self._algo.impl, modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, + target_update_interval=self._config.target_update_interval, device=self._device, ) @@ -233,6 +220,8 @@ class DiscreteFQE(_FQEBase): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + assert self._algo.impl, "The target algorithm is not initialized." + q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -260,10 +249,12 @@ def inner_create_impl( self._impl = DiscreteFQEImpl( observation_shape=observation_shape, action_size=action_size, + algo=self._algo.impl, modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, + target_update_interval=self._config.target_update_interval, device=self._device, ) diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index ab0b94bd..c98e3925 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Union +from typing import Dict, Union import torch from torch import nn @@ -15,7 +15,7 @@ ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, ) -from ...torch_utility import Modules, TorchMiniBatch, hard_sync, train_api +from ...torch_utility import Modules, TorchMiniBatch, hard_sync __all__ = ["FQEBaseImpl", "FQEImpl", "DiscreteFQEImpl", "FQEBaseModules"] @@ -28,6 +28,7 @@ class FQEBaseModules(Modules): class FQEBaseImpl(QLearningAlgoImplBase): + _algo: QLearningAlgoImplBase _modules: FQEBaseModules _gamma: float _q_func_forwarder: Union[ @@ -36,11 +37,13 @@ class FQEBaseImpl(QLearningAlgoImplBase): _targ_q_func_forwarder: Union[ DiscreteEnsembleQFunctionForwarder, ContinuousEnsembleQFunctionForwarder ] + _target_update_interval: int def __init__( self, observation_shape: Shape, action_size: int, + algo: QLearningAlgoImplBase, modules: FQEBaseModules, q_func_forwarder: Union[ DiscreteEnsembleQFunctionForwarder, @@ -51,6 +54,7 @@ def __init__( ContinuousEnsembleQFunctionForwarder, ], gamma: float, + target_update_interval: int, device: str, ): super().__init__( @@ -59,24 +63,13 @@ def __init__( modules=modules, device=device, ) + self._algo = algo self._gamma = gamma self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder + self._target_update_interval = target_update_interval hard_sync(modules.targ_q_funcs, modules.q_funcs) - @train_api - def update( - self, batch: TorchMiniBatch, next_actions: torch.Tensor - ) -> float: - q_tpn = self.compute_target(batch, next_actions) - loss = self.compute_loss(batch, q_tpn) - - self._modules.optim.zero_grad() - loss.backward() - self._modules.optim.step() - - return float(loss.cpu().detach().numpy()) - def compute_loss( self, batch: TorchMiniBatch, @@ -108,6 +101,23 @@ def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + next_actions = self._algo.predict_best_action(batch.next_observations) + + q_tpn = self.compute_target(batch, next_actions) + loss = self.compute_loss(batch, 
q_tpn) + + self._modules.optim.zero_grad() + loss.backward() + self._modules.optim.step() + + if grad_step % self._target_update_interval == 0: + self.update_target() + + return {"loss": float(loss.cpu().detach().numpy())} + class FQEImpl(ContinuousQFunctionMixin, FQEBaseImpl): _q_func_forwarder: DiscreteEnsembleQFunctionForwarder diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 14ea9493..a4651234 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -217,7 +217,12 @@ def modules(self) -> Dict[str, Union[nn.Module, Optimizer]]: @dataclasses.dataclass(frozen=True) class Modules: def create_checkpointer(self, device: str) -> Checkpointer: - return Checkpointer(modules=asdict_without_copy(self), device=device) + modules = { + k: v + for k, v in asdict_without_copy(self).items() + if isinstance(v, (nn.Module, torch.optim.Optimizer)) + } + return Checkpointer(modules=modules, device=device) def freeze(self) -> None: for v in asdict_without_copy(self).values(): diff --git a/tests/test_dataclass_utils.py b/tests/test_dataclass_utils.py index 22094b76..d710a4cf 100644 --- a/tests/test_dataclass_utils.py +++ b/tests/test_dataclass_utils.py @@ -1,6 +1,8 @@ import dataclasses -from d3rlpy.dataclass_utils import asdict_without_copy +import torch + +from d3rlpy.dataclass_utils import asdict_as_float, asdict_without_copy @dataclasses.dataclass(frozen=True) @@ -22,3 +24,17 @@ def test_asdict_without_any() -> None: assert dict_d["a"] is a assert dict_d["b"] == 2.0 assert dict_d["c"] == "3" + + +@dataclasses.dataclass(frozen=True) +class D2: + a: float + b: torch.Tensor + + +def test_asdict_as_float() -> None: + b = torch.rand([], dtype=torch.float32) + d = D2(a=1.0, b=b) + dict_d = asdict_as_float(d) + assert dict_d["a"] == 1.0 + assert dict_d["b"] == b.numpy() diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index 33c080a4..087c5507 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -112,7 +112,7 @@ def __init__(self) -> None: self.fc1 = torch.nn.Linear(100, 100) self.fc2 = torch.nn.Linear(100, 100) self.optim = torch.optim.Adam(self.fc1.parameters()) - self.modules = TestModules(self.fc1, self.optim) + self.modules = DummyModules(self.fc1, self.optim) self.device = "cpu:0" @train_api @@ -164,7 +164,7 @@ def test_to_cpu() -> None: @dataclasses.dataclass(frozen=True) -class TestModules(Modules): +class DummyModules(Modules): fc: torch.nn.Linear optim: torch.optim.Adam @@ -172,7 +172,7 @@ class TestModules(Modules): def test_modules() -> None: fc = torch.nn.Linear(100, 200) optim = torch.optim.Adam(fc.parameters()) - modules = TestModules(fc, optim) + modules = DummyModules(fc, optim) # check checkpointer checkpointer = modules.create_checkpointer("cpu:0") From 2846d2040931c000f624ec95bc80ed7a68eb20e4 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 11:57:27 +0100 Subject: [PATCH 15/20] corrected formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 23 ++++++++++++++--------- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 12 ++++++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 31ba9606..15525bed 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -15,14 +15,17 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch +from .ddpg_impl import DDPGCriticLoss from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules 
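The CQL changes in this patch, and the "generalised loss dataclass" patch that follows, move toward loss dataclasses that carry their components and compose the total via a get_loss() hook. A minimal sketch of that idea, with invented class names and placeholder tensors:

import dataclasses

import torch

@dataclasses.dataclass(frozen=True)
class SketchCriticLoss:
    td_loss: torch.Tensor

    def get_loss(self) -> torch.Tensor:
        return self.td_loss

@dataclasses.dataclass(frozen=True)
class SketchConservativeCriticLoss(SketchCriticLoss):
    conservative_loss: torch.Tensor

    def get_loss(self) -> torch.Tensor:
        # total critic objective = TD term + conservative regulariser,
        # while both components remain available for logging
        return super().get_loss() + self.conservative_loss

loss = SketchConservativeCriticLoss(
    td_loss=torch.tensor(0.5), conservative_loss=torch.tensor(0.25)
)
print(float(loss.get_loss()))  # 0.75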
from .sac_impl import SACImpl, SACModules -from .ddpg_impl import DDPGCriticLoss __all__ = [ - "CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss", - "CQLLoss" - ] + "CQLImpl", + "DiscreteCQLImpl", + "CQLModules", + "DiscreteCQLLoss", + "CQLLoss", +] @dataclasses.dataclass(frozen=True) @@ -82,9 +85,10 @@ def compute_critic_loss( batch.observations, batch.actions, batch.next_observations ) return CQLLoss( - loss=loss+conservative_loss, td_loss=loss, - conservative_loss=conservative_loss - ) + loss=loss + conservative_loss, + td_loss=loss, + conservative_loss=conservative_loss, + ) def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.zero_grad() @@ -300,6 +304,7 @@ def compute_loss( ) loss = td_loss + self._alpha * conservative_loss return DiscreteCQLLoss( - loss=loss, td_loss=td_loss, - conservative_loss=self._alpha * conservative_loss + loss=loss, + td_loss=td_loss, + conservative_loss=self._alpha * conservative_loss, ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index ce442908..26a09d7e 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -14,9 +14,12 @@ from .utility import ContinuousQFunctionMixin __all__ = [ - "DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules", - "DDPGCriticLoss" - ] + "DDPGImpl", + "DDPGBaseImpl", + "DDPGBaseModules", + "DDPGModules", + "DDPGCriticLoss", +] @dataclasses.dataclass(frozen=True) @@ -27,6 +30,7 @@ class DDPGBaseModules(Modules): actor_optim: Optimizer critic_optim: Optimizer + @dataclasses.dataclass(frozen=True) class DDPGCriticLoss: loss: torch.Tensor @@ -75,7 +79,7 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.step() return asdict_as_float(loss) - + def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> DDPGCriticLoss: From 8e5aec857f7e8983c896fe5ef646d349d39d716c Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 16:37:35 +0100 Subject: [PATCH 16/20] generalised loss dataclass --- d3rlpy/algos/qlearning/torch/cql_impl.py | 26 ++++++++--------------- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 12 +++-------- d3rlpy/algos/qlearning/torch/utility.py | 16 +++++++++++++- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 15525bed..2955de97 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -15,15 +15,14 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch -from .ddpg_impl import DDPGCriticLoss from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules from .sac_impl import SACImpl, SACModules +from .utility import CriticLoss __all__ = [ "CQLImpl", "DiscreteCQLImpl", "CQLModules", - "DiscreteCQLLoss", "CQLLoss", ] @@ -35,10 +34,12 @@ class CQLModules(SACModules): @dataclasses.dataclass(frozen=True) -class CQLLoss(DDPGCriticLoss): - td_loss: torch.Tensor +class CQLLoss(CriticLoss): conservative_loss: torch.Tensor + def get_loss(self): + return super().get_loss() + self.conservative_loss + class CQLImpl(SACImpl): _modules: CQLModules @@ -80,13 +81,12 @@ def __init__( def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> CQLLoss: - loss = super().compute_critic_loss(batch, q_tpn).loss + td_loss = super().compute_critic_loss(batch, q_tpn).loss conservative_loss = self._compute_conservative_loss( 
batch.observations, batch.actions, batch.next_observations ) return CQLLoss( - loss=loss + conservative_loss, - td_loss=loss, + td_loss=td_loss, conservative_loss=conservative_loss, ) @@ -247,12 +247,6 @@ def inner_update( return metrics -@dataclasses.dataclass(frozen=True) -class DiscreteCQLLoss(DQNLoss): - td_loss: torch.Tensor - conservative_loss: torch.Tensor - - class DiscreteCQLImpl(DoubleDQNImpl): _alpha: float @@ -297,14 +291,12 @@ def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor, - ) -> DiscreteCQLLoss: + ) -> CQLLoss: td_loss = super().compute_loss(batch, q_tpn).loss conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions.long() ) - loss = td_loss + self._alpha * conservative_loss - return DiscreteCQLLoss( - loss=loss, + return CQLLoss( td_loss=td_loss, conservative_loss=self._alpha * conservative_loss, ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 26a09d7e..5732524f 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -11,14 +11,13 @@ from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync from ..base import QLearningAlgoImplBase -from .utility import ContinuousQFunctionMixin +from .utility import ContinuousQFunctionMixin, CriticLoss __all__ = [ "DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules", - "DDPGCriticLoss", ] @@ -31,11 +30,6 @@ class DDPGBaseModules(Modules): critic_optim: Optimizer -@dataclasses.dataclass(frozen=True) -class DDPGCriticLoss: - loss: torch.Tensor - - class DDPGBaseImpl( ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): @@ -82,7 +76,7 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> DDPGCriticLoss: + ) -> CriticLoss: loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, @@ -91,7 +85,7 @@ def compute_critic_loss( terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) - return DDPGCriticLoss(loss=loss) + return CriticLoss(td_loss=loss) def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability diff --git a/d3rlpy/algos/qlearning/torch/utility.py b/d3rlpy/algos/qlearning/torch/utility.py index 15ed5551..056de6fc 100644 --- a/d3rlpy/algos/qlearning/torch/utility.py +++ b/d3rlpy/algos/qlearning/torch/utility.py @@ -1,3 +1,5 @@ +import dataclasses + import torch from typing_extensions import Protocol @@ -6,7 +8,7 @@ DiscreteEnsembleQFunctionForwarder, ) -__all__ = ["DiscreteQFunctionMixin", "ContinuousQFunctionMixin"] +__all__ = ["DiscreteQFunctionMixin", "ContinuousQFunctionMixin", "CriticLoss"] class _DiscreteQFunctionProtocol(Protocol): @@ -35,3 +37,15 @@ def inner_predict_value( return self._q_func_forwarder.compute_expected_q( x, action, reduction="mean" ).reshape(-1) + + +@dataclasses.dataclass(frozen=True) +class CriticLoss: + td_loss: torch.Tensor + loss: torch.Tensor = dataclasses.field(init=False) + + def __post_init__(self): + object.__setattr__(self, "loss", self.get_loss()) + + def get_loss(self): + return self.td_loss From fa90b6c7562be9219d9653472fff8a1fa90eea10 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 16:39:26 +0100 Subject: [PATCH 17/20] added method returns --- d3rlpy/algos/qlearning/torch/cql_impl.py | 2 +- 
d3rlpy/algos/qlearning/torch/utility.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 2955de97..7b5d6ca4 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -37,7 +37,7 @@ class CQLModules(SACModules): class CQLLoss(CriticLoss): conservative_loss: torch.Tensor - def get_loss(self): + def get_loss(self)->torch.Tensor: return super().get_loss() + self.conservative_loss diff --git a/d3rlpy/algos/qlearning/torch/utility.py b/d3rlpy/algos/qlearning/torch/utility.py index 056de6fc..96571fb5 100644 --- a/d3rlpy/algos/qlearning/torch/utility.py +++ b/d3rlpy/algos/qlearning/torch/utility.py @@ -44,8 +44,8 @@ class CriticLoss: td_loss: torch.Tensor loss: torch.Tensor = dataclasses.field(init=False) - def __post_init__(self): + def __post_init__(self)->None: object.__setattr__(self, "loss", self.get_loss()) - def get_loss(self): + def get_loss(self)->torch.Tensor: return self.td_loss From b1c1e62f5899e9ee001d7097c5792d9b47541373 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 16:59:31 +0100 Subject: [PATCH 18/20] updated formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 2 +- d3rlpy/algos/qlearning/torch/dqn_impl.py | 13 ++++--------- d3rlpy/algos/qlearning/torch/utility.py | 4 ++-- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 7b5d6ca4..c8b23371 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -37,7 +37,7 @@ class CQLModules(SACModules): class CQLLoss(CriticLoss): conservative_loss: torch.Tensor - def get_loss(self)->torch.Tensor: + def get_loss(self) -> torch.Tensor: return super().get_loss() + self.conservative_loss diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index e5caee13..e3e2af0c 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -10,9 +10,9 @@ from ....models.torch import DiscreteEnsembleQFunctionForwarder from ....torch_utility import Modules, TorchMiniBatch, hard_sync from ..base import QLearningAlgoImplBase -from .utility import DiscreteQFunctionMixin +from .utility import CriticLoss, DiscreteQFunctionMixin -__all__ = ["DQNImpl", "DQNModules", "DQNLoss", "DoubleDQNImpl"] +__all__ = ["DQNImpl", "DQNModules", "DoubleDQNImpl"] @dataclasses.dataclass(frozen=True) @@ -22,11 +22,6 @@ class DQNModules(Modules): optim: Optimizer -@dataclasses.dataclass(frozen=True) -class DQNLoss: - loss: torch.Tensor - - class DQNImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _modules: DQNModules _gamma: float @@ -78,7 +73,7 @@ def compute_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor, - ) -> DQNLoss: + ) -> CriticLoss: loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions.long(), @@ -87,7 +82,7 @@ def compute_loss( terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) - return DQNLoss(loss=loss) + return CriticLoss(td_loss=loss) def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): diff --git a/d3rlpy/algos/qlearning/torch/utility.py b/d3rlpy/algos/qlearning/torch/utility.py index 96571fb5..11909ad1 100644 --- a/d3rlpy/algos/qlearning/torch/utility.py +++ b/d3rlpy/algos/qlearning/torch/utility.py @@ -44,8 +44,8 @@ class CriticLoss: td_loss: torch.Tensor 
loss: torch.Tensor = dataclasses.field(init=False) - def __post_init__(self)->None: + def __post_init__(self) -> None: object.__setattr__(self, "loss", self.get_loss()) - def get_loss(self)->torch.Tensor: + def get_loss(self) -> torch.Tensor: return self.td_loss From 5bd59a1e9e1b6c687ec1fe8c6b3b6301e19f709b Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 17:02:56 +0100 Subject: [PATCH 19/20] updated bcq loss --- d3rlpy/algos/qlearning/torch/bcq_impl.py | 13 ++++++++----- d3rlpy/algos/qlearning/torch/cql_impl.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index fcedd24c..0e92254b 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -20,7 +20,8 @@ ) from ....torch_utility import TorchMiniBatch, soft_sync from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules -from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules +from .dqn_impl import DoubleDQNImpl, DQNModules +from .utility import CriticLoss __all__ = [ "BCQImpl", @@ -201,10 +202,13 @@ class DiscreteBCQModules(DQNModules): @dataclasses.dataclass(frozen=True) -class DiscreteBCQLoss(DQNLoss): - td_loss: torch.Tensor +class DiscreteBCQLoss(CriticLoss): imitator_loss: torch.Tensor + def get_loss(self) -> torch.Tensor: + return super().get_loss() + self.imitator_loss + + class DiscreteBCQImpl(DoubleDQNImpl): _modules: DiscreteBCQModules @@ -247,9 +251,8 @@ def compute_loss( action=batch.actions.long(), beta=self._beta, ) - loss = td_loss + imitator_loss return DiscreteBCQLoss( - loss=loss, td_loss=td_loss, imitator_loss=imitator_loss + td_loss=td_loss, imitator_loss=imitator_loss ) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index c8b23371..590c2e06 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -15,7 +15,7 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch -from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules +from .dqn_impl import DoubleDQNImpl, DQNModules from .sac_impl import SACImpl, SACModules from .utility import CriticLoss From 80b9579b7b5435cb6d9bc3555b444de59f24713f Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 17:03:18 +0100 Subject: [PATCH 20/20] updated formatting --- d3rlpy/algos/qlearning/torch/bcq_impl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 0e92254b..7827c33e 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -209,7 +209,6 @@ def get_loss(self) -> torch.Tensor: return super().get_loss() + self.imitator_loss - class DiscreteBCQImpl(DoubleDQNImpl): _modules: DiscreteBCQModules _action_flexibility: float @@ -251,9 +250,7 @@ def compute_loss( action=batch.actions.long(), beta=self._beta, ) - return DiscreteBCQLoss( - td_loss=td_loss, imitator_loss=imitator_loss - ) + return DiscreteBCQLoss(td_loss=td_loss, imitator_loss=imitator_loss) def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor: dist = self._modules.imitator(x)
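
The patch series converges on a single frozen CriticLoss dataclass: the base class carries the TD term and derives loss in __post_init__, while subclasses such as CQLLoss and DiscreteBCQLoss add their regulariser field and override get_loss(). The following is a minimal standalone sketch of that pattern, assuming only torch and dataclasses; loss_metrics is a hypothetical stand-in for the library's asdict_as_float helper and is not part of the patches above.

import dataclasses
from typing import Dict

import torch


@dataclasses.dataclass(frozen=True)
class CriticLoss:
    # Plain TD error; subclasses append their own regularisation terms.
    td_loss: torch.Tensor
    # The total is derived rather than passed in, so it cannot drift out
    # of sync with the individual terms.
    loss: torch.Tensor = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # Frozen dataclasses forbid normal assignment, hence object.__setattr__.
        object.__setattr__(self, "loss", self.get_loss())

    def get_loss(self) -> torch.Tensor:
        return self.td_loss


@dataclasses.dataclass(frozen=True)
class CQLLoss(CriticLoss):
    conservative_loss: torch.Tensor

    def get_loss(self) -> torch.Tensor:
        # Total critic loss = TD error + conservative regulariser.
        return super().get_loss() + self.conservative_loss


def loss_metrics(loss: CriticLoss) -> Dict[str, float]:
    # Rough stand-in for asdict_as_float: turn every tensor field into a
    # plain float so the loss terms can be logged as metrics.
    return {
        f.name: float(getattr(loss, f.name).detach().cpu().numpy())
        for f in dataclasses.fields(loss)
    }


if __name__ == "__main__":
    cql = CQLLoss(
        td_loss=torch.tensor(0.5), conservative_loss=torch.tensor(0.25)
    )
    print(loss_metrics(cql))
    # {'td_loss': 0.5, 'loss': 0.75, 'conservative_loss': 0.25}

Because __init__ assigns all declared fields before __post_init__ runs, the inherited __post_init__ sees the subclass's extra term and the overridden get_loss(), which is why the diffs above can drop the hand-built loss=... arguments from CQLLoss and DiscreteBCQLoss constructors.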