Simplify Q functions with forwarders
takuseno committed Aug 14, 2023
1 parent b8263d4 commit 3ec25d7
Showing 44 changed files with 1,614 additions and 682 deletions.
17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/awac.py
@@ -106,7 +106,15 @@ def inner_create_impl(
use_std_parameter=True,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -119,13 +127,16 @@ def inner_create_impl(
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)

self._impl = AWACImpl(
observation_shape=observation_shape,
action_size=action_size,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
policy=policy,
actor_optim=actor_optim,
critic_optim=critic_optim,
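The awac.py hunk shows the pattern repeated throughout this commit: the factory now returns a plain nn.ModuleList of Q-functions together with a forwarder object (presumably wrapping the ensemble computations), and a second call builds a separate target ensemble. A minimal self-contained sketch of that shape, using hypothetical stand-ins (MeanForwarder, make_continuous_q_ensemble) rather than d3rlpy's actual create_continuous_q_function internals:

import torch
from torch import nn

# Hypothetical stand-ins mirroring the (q_funcs, forwarder) return shape;
# illustrative only, not d3rlpy's actual classes.
class MeanForwarder:
    """Averages the outputs of an ensemble of continuous Q-functions."""

    def __init__(self, q_funcs: nn.ModuleList):
        self._q_funcs = q_funcs

    def compute_expected_q(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        values = [q(torch.cat([x, action], dim=1)) for q in self._q_funcs]
        return torch.stack(values, dim=0).mean(dim=0)


def make_continuous_q_ensemble(obs_dim: int, action_size: int, n_ensembles: int):
    q_funcs = nn.ModuleList(
        nn.Sequential(
            nn.Linear(obs_dim + action_size, 256), nn.ReLU(), nn.Linear(256, 1)
        )
        for _ in range(n_ensembles)
    )
    return q_funcs, MeanForwarder(q_funcs)


# Usage mirroring the diff: separate online and target ensembles, and one critic
# optimizer that covers every ensemble member via q_funcs.parameters().
q_funcs, q_func_forwarder = make_continuous_q_ensemble(17, 6, n_ensembles=2)
targ_q_funcs, targ_q_func_forwarder = make_continuous_q_ensemble(17, 6, n_ensembles=2)
critic_optim = torch.optim.Adam(q_funcs.parameters(), lr=3e-4)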
11 changes: 6 additions & 5 deletions d3rlpy/algos/qlearning/base.py
@@ -14,6 +14,7 @@

import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm, trange
from typing_extensions import Self

@@ -36,7 +37,7 @@
LoggerAdapterFactory,
)
from ...metrics import EvaluatorProtocol, evaluate_qlearning_with_environment
from ...models.torch import EnsembleQFunction, Policy
from ...models.torch import Policy
from ...torch_utility import (
TorchMiniBatch,
convert_to_torch,
@@ -119,15 +120,15 @@ def copy_policy_optim_from(self, impl: "QLearningAlgoImplBase") -> None:
sync_optimizer_state(self.policy_optim, impl.policy_optim)

@property
def q_function(self) -> EnsembleQFunction:
def q_function(self) -> nn.ModuleList:
raise NotImplementedError

def copy_q_function_from(self, impl: "QLearningAlgoImplBase") -> None:
q_func = self.q_function.q_funcs[0]
if not isinstance(impl.q_function.q_funcs[0], type(q_func)):
q_func = self.q_function[0]
if not isinstance(impl.q_function[0], type(q_func)):
raise ValueError(
f"Invalid Q-function type: expected={type(q_func)},"
f"actual={type(impl.q_function.q_funcs[0])}"
f"actual={type(impl.q_function[0])}"
)
hard_sync(self.q_function, impl.q_function)

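With q_function now typed as nn.ModuleList, copy_q_function_from in base.py indexes the list directly instead of reaching into EnsembleQFunction.q_funcs. A minimal sketch of that check-then-sync logic, assuming the hard sync simply copies parameters from source to destination (hard_sync_modules below is an illustrative helper, not d3rlpy's hard_sync):

from torch import nn

def hard_sync_modules(dst: nn.Module, src: nn.Module) -> None:
    # Illustrative parameter copy; d3rlpy's hard_sync may differ in details.
    dst.load_state_dict(src.state_dict())


def copy_q_function_from(dst_q_funcs: nn.ModuleList, src_q_funcs: nn.ModuleList) -> None:
    # The ensemble is a plain list of modules, so the type check compares the
    # first member of each list before copying weights across.
    if not isinstance(src_q_funcs[0], type(dst_q_funcs[0])):
        raise ValueError(
            f"Invalid Q-function type: expected={type(dst_q_funcs[0])}, "
            f"actual={type(src_q_funcs[0])}"
        )
    hard_sync_modules(dst_q_funcs, src_q_funcs)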
40 changes: 31 additions & 9 deletions d3rlpy/algos/qlearning/bcq.py
@@ -174,7 +174,15 @@ def inner_create_impl(
self._config.actor_encoder_factory,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -196,7 +204,7 @@
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)
imitator_optim = self._config.imitator_optim_factory.create(
imitator.parameters(), lr=self._config.imitator_learning_rate
@@ -206,7 +214,10 @@
observation_shape=observation_shape,
action_size=action_size,
policy=policy,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
imitator=imitator,
actor_optim=actor_optim,
critic_optim=critic_optim,
@@ -323,7 +334,15 @@ class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
q_func = create_discrete_q_function(
q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
@@ -333,14 +352,14 @@ def inner_create_impl(
)

# share convolutional layers if observation is pixel
if isinstance(q_func.q_funcs[0].encoder, PixelEncoder):
if isinstance(q_funcs[0].encoder, PixelEncoder):
hidden_size = compute_output_size(
[observation_shape],
q_func.q_funcs[0].encoder,
q_funcs[0].encoder,
device=self._device,
)
imitator = CategoricalPolicy(
encoder=q_func.q_funcs[0].encoder,
encoder=q_funcs[0].encoder,
hidden_size=hidden_size,
action_size=action_size,
)
@@ -355,7 +374,7 @@ def inner_create_impl(

# TODO: replace this with a cleaner way
# retrieve unique elements
q_func_params = list(q_func.parameters())
q_func_params = list(q_funcs.parameters())
imitator_params = list(imitator.parameters())
unique_dict = {}
for param in q_func_params + imitator_params:
@@ -368,7 +387,10 @@ def inner_create_impl(
self._impl = DiscreteBCQImpl(
observation_shape=observation_shape,
action_size=action_size,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
imitator=imitator,
optim=optim,
gamma=self._config.gamma,
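In the DiscreteBCQ hunks above, the imitator reuses the Q-function encoder for pixel observations, so q_funcs.parameters() and imitator.parameters() overlap and the combined list is deduplicated before a single optimizer is created. A small sketch of that idea, completing the truncated loop in one plausible way (keying by parameter identity); the toy modules are placeholders, not the real PixelEncoder or CategoricalPolicy:

import torch
from torch import nn

# Shared encoder stands in for the PixelEncoder that DiscreteBCQ reuses.
encoder = nn.Sequential(nn.Linear(8, 64), nn.ReLU())
q_funcs = nn.ModuleList([nn.Sequential(encoder, nn.Linear(64, 4))])
imitator = nn.Sequential(encoder, nn.Linear(64, 4))  # shares the encoder weights

q_func_params = list(q_funcs.parameters())
imitator_params = list(imitator.parameters())

# Deduplicate by identity so each shared weight is handed to the optimizer once.
unique_dict = {}
for param in q_func_params + imitator_params:
    unique_dict[id(param)] = param
unique_params = list(unique_dict.values())

optim = torch.optim.Adam(unique_params, lr=6.25e-5)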
17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/bear.py
@@ -164,7 +164,15 @@ def inner_create_impl(
self._config.actor_encoder_factory,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -194,7 +202,7 @@
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)
imitator_optim = self._config.imitator_optim_factory.create(
imitator.parameters(), lr=self._config.imitator_learning_rate
@@ -210,7 +218,10 @@
observation_shape=observation_shape,
action_size=action_size,
policy=policy,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
imitator=imitator,
log_temp=log_temp,
log_alpha=log_alpha,
34 changes: 28 additions & 6 deletions d3rlpy/algos/qlearning/cql.py
@@ -141,7 +141,15 @@ def inner_create_impl(
self._config.actor_encoder_factory,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_fowarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -162,7 +170,7 @@
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)
temp_optim = self._config.temp_optim_factory.create(
log_temp.parameters(), lr=self._config.temp_learning_rate
@@ -175,7 +183,10 @@
observation_shape=observation_shape,
action_size=action_size,
policy=policy,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_fowarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
log_temp=log_temp,
log_alpha=log_alpha,
actor_optim=actor_optim,
@@ -279,7 +290,15 @@ class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
q_func = create_discrete_q_function(
q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
@@ -289,13 +308,16 @@ def inner_create_impl(
)

optim = self._config.optim_factory.create(
q_func.parameters(), lr=self._config.learning_rate
q_funcs.parameters(), lr=self._config.learning_rate
)

self._impl = DiscreteCQLImpl(
observation_shape=observation_shape,
action_size=action_size,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
optim=optim,
gamma=self._config.gamma,
alpha=self._config.alpha,
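For the discrete algorithms such as DiscreteCQL, the forwarder is presumably what turns the nn.ModuleList ensemble back into a single value when a reduction is needed, for example when forming a TD target. A hypothetical sketch of that role (DiscreteMinForwarder and compute_target are illustrative names, not d3rlpy's API):

import torch
from torch import nn

class DiscreteMinForwarder:
    """Reduces an ensemble of obs -> per-action Q-networks with an elementwise min."""

    def __init__(self, q_funcs: nn.ModuleList):
        self._q_funcs = q_funcs

    def compute_target(self, x: torch.Tensor) -> torch.Tensor:
        values = torch.stack([q(x) for q in self._q_funcs], dim=0)  # (n_ensembles, batch, actions)
        return values.min(dim=0).values


q_funcs = nn.ModuleList(
    nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)) for _ in range(2)
)
forwarder = DiscreteMinForwarder(q_funcs)
target = forwarder.compute_target(torch.randn(32, 4))  # shape (32, 2)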
17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/crr.py
@@ -140,7 +140,15 @@ def inner_create_impl(
self._config.actor_encoder_factory,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -153,14 +161,17 @@ def inner_create_impl(
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)

self._impl = CRRImpl(
observation_shape=observation_shape,
action_size=action_size,
policy=policy,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
actor_optim=actor_optim,
critic_optim=critic_optim,
gamma=self._config.gamma,
17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -101,7 +101,15 @@ def inner_create_impl(
self._config.actor_encoder_factory,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -114,14 +122,17 @@
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)

self._impl = DDPGImpl(
observation_shape=observation_shape,
action_size=action_size,
policy=policy,
q_func=q_func,
q_funcs=q_funcs,
q_func_forwarder=q_func_forwarder,
targ_q_funcs=targ_q_funcs,
targ_q_func_forwarder=targ_q_func_forwarder,
actor_optim=actor_optim,
critic_optim=critic_optim,
gamma=self._config.gamma,
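With targ_q_funcs built as a second, independent nn.ModuleList, the implementation has to keep it in step with the online ensemble during training. A minimal Polyak-style update sketch; the rate tau and whether DDPGImpl uses a soft or hard sync internally are assumptions here:

import torch
from torch import nn

@torch.no_grad()
def soft_update(targ_q_funcs: nn.ModuleList, q_funcs: nn.ModuleList, tau: float = 0.005) -> None:
    # Move each target parameter a small step toward its online counterpart.
    for targ_param, param in zip(targ_q_funcs.parameters(), q_funcs.parameters()):
        targ_param.mul_(1.0 - tau).add_(param, alpha=tau)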
