Commit

Refactor DQN variants
takuseno committed Nov 4, 2024
1 parent 5400c22 commit 9a77f40
Showing 8 changed files with 335 additions and 164 deletions.
44 changes: 34 additions & 10 deletions d3rlpy/algos/qlearning/bcq.py
@@ -16,12 +16,15 @@
 from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
 from ...types import Shape
 from .base import QLearningAlgoBase
+from .functional import FunctionalQLearningAlgoImplBase
 from .torch.bcq_impl import (
     BCQImpl,
     BCQModules,
-    DiscreteBCQImpl,
+    DiscreteBCQActionSampler,
+    DiscreteBCQLossFn,
     DiscreteBCQModules,
 )
+from .torch.dqn_impl import DQNUpdater, DQNValuePredictor
 
 __all__ = ["BCQConfig", "BCQ", "DiscreteBCQConfig", "DiscreteBCQ"]
 
@@ -363,7 +366,7 @@ def get_type() -> str:
         return "discrete_bcq"
 
 
-class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]):
+class DiscreteBCQ(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DiscreteBCQConfig]):
     def inner_create_impl(
         self, observation_shape: Shape, action_size: int
     ) -> None:
@@ -422,17 +425,38 @@ def inner_create_impl(
             optim=optim,
         )
 
-        self._impl = DiscreteBCQImpl(
-            observation_shape=observation_shape,
-            action_size=action_size,
+        # build functional components
+        updater = DQNUpdater(
             modules=modules,
-            q_func_forwarder=q_func_forwarder,
-            targ_q_func_forwarder=targ_q_func_forwarder,
+            dqn_loss_fn=DiscreteBCQLossFn(
+                modules=modules,
+                q_func_forwarder=q_func_forwarder,
+                targ_q_func_forwarder=targ_q_func_forwarder,
+                gamma=self._config.gamma,
+                beta=self._config.beta,
+            ),
             target_update_interval=self._config.target_update_interval,
-            gamma=self._config.gamma,
-            action_flexibility=self._config.action_flexibility,
-            beta=self._config.beta,
             compiled=self.compiled,
+        )
+        action_sampler = DiscreteBCQActionSampler(
+            modules=modules,
+            q_func_forwarder=q_func_forwarder,
+            action_flexibility=self._config.action_flexibility,
+        )
+        value_predictor = DQNValuePredictor(q_func_forwarder)
+
+        self._impl = FunctionalQLearningAlgoImplBase(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            updater=updater,
+            exploit_action_sampler=action_sampler,
+            explore_action_sampler=action_sampler,
+            value_predictor=value_predictor,
+            q_function=q_funcs,
+            q_function_optim=optim.optim,
+            policy=None,
+            policy_optim=None,
             device=self._device,
         )
 
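Note on the DiscreteBCQActionSampler wired in above: its implementation lives in d3rlpy/algos/qlearning/torch/bcq_impl.py and is not part of this excerpt. As a rough, hypothetical sketch of the rule discrete BCQ applies, the sampler only considers actions whose imitator probability is close enough to the most likely action (controlled by action_flexibility) and then takes the greedy Q-value action among them. Names and signature below are illustrative only, not the library's API:

import torch


def bcq_masked_argmax(
    q_values: torch.Tensor,            # (batch, action_size) Q-values
    imitator_log_probs: torch.Tensor,  # (batch, action_size) log-probs from the imitator
    action_flexibility: float,         # threshold tau in discrete BCQ
) -> torch.Tensor:
    # probability of each action relative to the most likely action
    ratio = (imitator_log_probs - imitator_log_probs.max(dim=1, keepdim=True).values).exp()
    # mask out actions the imitator considers too unlikely
    mask = ratio >= action_flexibility
    masked_q = q_values.masked_fill(~mask, float("-inf"))
    return masked_q.argmax(dim=1)
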
44 changes: 34 additions & 10 deletions d3rlpy/algos/qlearning/cql.py
@@ -14,8 +14,14 @@
 from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
 from ...types import Shape
 from .base import QLearningAlgoBase
-from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLImpl
-from .torch.dqn_impl import DQNModules
+from .functional import FunctionalQLearningAlgoImplBase
+from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLLossFn
+from .torch.dqn_impl import (
+    DQNActionSampler,
+    DQNModules,
+    DQNUpdater,
+    DQNValuePredictor,
+)
 
 __all__ = ["CQLConfig", "CQL", "DiscreteCQLConfig", "DiscreteCQL"]
 
@@ -304,7 +310,7 @@ def get_type() -> str:
         return "discrete_cql"
 
 
-class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]):
+class DiscreteCQL(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DiscreteCQLConfig]):
     def inner_create_impl(
         self, observation_shape: Shape, action_size: int
     ) -> None:
@@ -339,16 +345,34 @@ def inner_create_impl(
             optim=optim,
         )
 
-        self._impl = DiscreteCQLImpl(
-            observation_shape=observation_shape,
-            action_size=action_size,
+        # build functional components
+        updater = DQNUpdater(
             modules=modules,
-            q_func_forwarder=q_func_forwarder,
-            targ_q_func_forwarder=targ_q_func_forwarder,
+            dqn_loss_fn=DiscreteCQLLossFn(
+                action_size=action_size,
+                q_func_forwarder=q_func_forwarder,
+                targ_q_func_forwarder=targ_q_func_forwarder,
+                gamma=self._config.gamma,
+                alpha=self._config.alpha,
+            ),
             target_update_interval=self._config.target_update_interval,
-            gamma=self._config.gamma,
-            alpha=self._config.alpha,
             compiled=self.compiled,
+        )
+        action_sampler = DQNActionSampler(q_func_forwarder)
+        value_predictor = DQNValuePredictor(q_func_forwarder)
+
+        self._impl = FunctionalQLearningAlgoImplBase(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            updater=updater,
+            exploit_action_sampler=action_sampler,
+            explore_action_sampler=action_sampler,
+            value_predictor=value_predictor,
+            q_function=q_funcs,
+            q_function_optim=optim.optim,
+            policy=None,
+            policy_optim=None,
             device=self._device,
         )
 
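DiscreteCQLLossFn is defined in d3rlpy/algos/qlearning/torch/cql_impl.py and is not shown in this excerpt. For context, discrete CQL adds a conservative penalty, weighted by alpha, on top of the usual TD loss so that Q-values for actions outside the dataset are pushed down relative to the dataset action. A minimal sketch of that term, with illustrative names only:

import torch


def conservative_penalty(
    q_values: torch.Tensor,  # (batch, action_size) Q-values from the online network
    actions: torch.Tensor,   # (batch,) actions actually taken in the dataset
    alpha: float,            # conservative weight (self._config.alpha above)
) -> torch.Tensor:
    # log-sum-exp over all actions minus the Q-value of the dataset action
    logsumexp = torch.logsumexp(q_values, dim=1)
    data_q = q_values.gather(1, actions.long().view(-1, 1)).squeeze(1)
    return alpha * (logsumexp - data_q).mean()
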
70 changes: 56 additions & 14 deletions d3rlpy/algos/qlearning/dqn.py
@@ -8,7 +8,15 @@
 from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
 from ...types import Shape
 from .base import QLearningAlgoBase
-from .torch.dqn_impl import DoubleDQNImpl, DQNImpl, DQNModules
+from .functional import FunctionalQLearningAlgoImplBase
+from .torch.dqn_impl import (
+    DoubleDQNLossFn,
+    DQNActionSampler,
+    DQNLossFn,
+    DQNModules,
+    DQNUpdater,
+    DQNValuePredictor,
+)
 
 __all__ = ["DQNConfig", "DQN", "DoubleDQNConfig", "DoubleDQN"]
 
@@ -66,7 +74,7 @@ def get_type() -> str:
         return "dqn"
 
 
-class DQN(QLearningAlgoBase[DQNImpl, DQNConfig]):
+class DQN(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DQNConfig]):
     def inner_create_impl(
         self, observation_shape: Shape, action_size: int
     ) -> None:
@@ -101,15 +109,32 @@ def inner_create_impl(
             optim=optim,
         )
 
-        self._impl = DQNImpl(
+        # build functional components
+        updater = DQNUpdater(
+            modules=modules,
+            dqn_loss_fn=DQNLossFn(
+                q_func_forwarder=forwarder,
+                targ_q_func_forwarder=targ_forwarder,
+                gamma=self._config.gamma,
+            ),
+            target_update_interval=self._config.target_update_interval,
+            compiled=self.compiled,
+        )
+        action_sampler = DQNActionSampler(forwarder)
+        value_predictor = DQNValuePredictor(forwarder)
+
+        self._impl = FunctionalQLearningAlgoImplBase(
             observation_shape=observation_shape,
             action_size=action_size,
-            q_func_forwarder=forwarder,
-            targ_q_func_forwarder=targ_forwarder,
-            target_update_interval=self._config.target_update_interval,
             modules=modules,
-            gamma=self._config.gamma,
-            compiled=self.compiled,
+            updater=updater,
+            exploit_action_sampler=action_sampler,
+            explore_action_sampler=action_sampler,
+            value_predictor=value_predictor,
+            q_function=q_funcs,
+            q_function_optim=optim.optim,
+            policy=None,
+            policy_optim=None,
             device=self._device,
         )
 
@@ -212,15 +237,32 @@ def inner_create_impl(
             optim=optim,
         )
 
-        self._impl = DoubleDQNImpl(
-            observation_shape=observation_shape,
-            action_size=action_size,
+        # build functional components
+        updater = DQNUpdater(
             modules=modules,
-            q_func_forwarder=forwarder,
-            targ_q_func_forwarder=targ_forwarder,
+            dqn_loss_fn=DoubleDQNLossFn(
+                q_func_forwarder=forwarder,
+                targ_q_func_forwarder=targ_forwarder,
+                gamma=self._config.gamma,
+            ),
             target_update_interval=self._config.target_update_interval,
-            gamma=self._config.gamma,
             compiled=self.compiled,
+        )
+        action_sampler = DQNActionSampler(forwarder)
+        value_predictor = DQNValuePredictor(forwarder)
+
+        self._impl = FunctionalQLearningAlgoImplBase(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            updater=updater,
+            exploit_action_sampler=action_sampler,
+            explore_action_sampler=action_sampler,
+            value_predictor=value_predictor,
+            q_function=q_funcs,
+            q_function_optim=optim.optim,
+            policy=None,
+            policy_optim=None,
             device=self._device,
         )
 
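After this refactor, DQN and DoubleDQN share DQNUpdater, DQNActionSampler, and DQNValuePredictor and differ only in the loss function object passed to the updater. The actual DQNLossFn and DoubleDQNLossFn live in d3rlpy/algos/qlearning/torch/dqn_impl.py and are not shown in this excerpt; the essential difference is how the TD target is formed, sketched below with hypothetical helper functions:

import torch


def dqn_target(
    targ_q: torch.Tensor,    # (batch, action_size) target-network Q-values at the next state
    reward: torch.Tensor,    # (batch, 1)
    terminal: torch.Tensor,  # (batch, 1), 1.0 if the episode ended
    gamma: float,
) -> torch.Tensor:
    # vanilla DQN: the target network both selects and evaluates the next action
    next_value = targ_q.max(dim=1, keepdim=True).values
    return reward + gamma * (1.0 - terminal) * next_value


def double_dqn_target(
    online_q: torch.Tensor,  # (batch, action_size) online-network Q-values at the next state
    targ_q: torch.Tensor,    # (batch, action_size) target-network Q-values at the next state
    reward: torch.Tensor,
    terminal: torch.Tensor,
    gamma: float,
) -> torch.Tensor:
    # Double DQN: the online network selects the action, the target network evaluates it
    next_action = online_q.argmax(dim=1, keepdim=True)
    next_value = targ_q.gather(1, next_action)
    return reward + gamma * (1.0 - terminal) * next_value
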
92 changes: 92 additions & 0 deletions d3rlpy/algos/qlearning/functional.py
@@ -0,0 +1,92 @@
from typing import Optional, Protocol

import torch
from torch import nn

from ...models.torch.policies import Policy
from ...torch_utility import Modules, TorchMiniBatch
from ...types import Shape, TorchObservation
from .base import QLearningAlgoImplBase

__all__ = ["Updater", "ActionSampler", "ValuePredictor", "FunctionalQLearningAlgoImplBase"]


class Updater(Protocol):
    def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]:
        ...


class ActionSampler(Protocol):
    def __call__(self, x: TorchObservation) -> torch.Tensor:
        ...


class ValuePredictor(Protocol):
    def __call__(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor:
        ...


class FunctionalQLearningAlgoImplBase(QLearningAlgoImplBase):
    def __init__(
        self,
        observation_shape: Shape,
        action_size: int,
        modules: Modules,
        updater: Updater,
        exploit_action_sampler: ActionSampler,
        explore_action_sampler: ActionSampler,
        value_predictor: ValuePredictor,
        q_function: nn.ModuleList,
        q_function_optim: torch.optim.Optimizer,
        policy: Optional[Policy],
        policy_optim: Optional[torch.optim.Optimizer],
        device: str,
    ):
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            device=device,
        )
        self._updater = updater
        self._exploit_action_sampler = exploit_action_sampler
        self._explore_action_sampler = explore_action_sampler
        self._value_predictor = value_predictor
        self._q_function = q_function
        self._q_function_optim = q_function_optim
        self._policy = policy
        self._policy_optim = policy_optim

    def inner_update(
        self, batch: TorchMiniBatch, grad_step: int
    ) -> dict[str, float]:
        return self._updater(batch, grad_step)

    def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
        return self._exploit_action_sampler(x)

    def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
        return self._explore_action_sampler(x)

    def inner_predict_value(
        self, x: TorchObservation, action: torch.Tensor
    ) -> torch.Tensor:
        return self._value_predictor(x, action)

    @property
    def policy(self) -> Policy:
        assert self._policy
        return self._policy

    @property
    def policy_optim(self) -> torch.optim.Optimizer:
        assert self._policy_optim
        return self._policy_optim

    @property
    def q_function(self) -> nn.ModuleList:
        return self._q_function

    @property
    def q_function_optim(self) -> torch.optim.Optimizer:
        return self._q_function_optim
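
Since Updater, ActionSampler, and ValuePredictor are plain callable protocols, any object with a matching __call__ can be plugged into FunctionalQLearningAlgoImplBase without subclassing. As a hypothetical illustration (not part of this commit), an epsilon-greedy exploration sampler could wrap a greedy sampler and be passed as explore_action_sampler while the greedy one stays as exploit_action_sampler:

import torch


class EpsilonGreedySampler:
    # Wraps a greedy ActionSampler with epsilon-greedy exploration (illustrative only).
    def __init__(self, greedy_sampler, action_size: int, epsilon: float = 0.1):
        self._greedy_sampler = greedy_sampler
        self._action_size = action_size
        self._epsilon = epsilon

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        greedy = self._greedy_sampler(x)
        random_actions = torch.randint(
            0, self._action_size, greedy.shape, device=greedy.device
        )
        explore = torch.rand(greedy.shape, device=greedy.device) < self._epsilon
        return torch.where(explore, random_actions, greedy)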
