Continuous CQL loss logging and aligning with discrete logging #317

Closed
wants to merge 21 commits
3 changes: 2 additions & 1 deletion .gitignore
@@ -14,10 +14,11 @@ docs/d3rlpy*.rst
docs/modules.rst
docs/references/generated
coverage.xml
.coverage
.coverage*
Owner:

Is this change necessary?

Contributor Author:

When running tests, I seem to get .coverage... files with references to my local system. Maybe it's the way I'm running the test?

.mypy_cache
.ipynb_checkpoints
build
dist
/.idea/
*.egg-info
*.DS_Store
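
Regarding the `.coverage*` comment above: coverage.py in parallel mode (for example pytest-cov together with pytest-xdist, or `parallel = true` in the coverage config) writes one data file per process, named like `.coverage.<hostname>.<pid>.<random>`, which would explain data files that embed the local machine name; the broader ignore pattern catches them. A minimal, hypothetical sketch of merging such files back into a single `.coverage` (illustration only, not part of this PR):

```python
# Hypothetical illustration: merge parallel-mode coverage data files
# (.coverage.<hostname>.<pid>.<random>) into a single .coverage file.
import coverage

cov = coverage.Coverage()  # picks up .coveragerc / pyproject settings if present
cov.combine()              # merges any .coverage.* data files in the current directory
cov.save()                 # writes the combined result to .coverage
```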
53 changes: 29 additions & 24 deletions d3rlpy/algos/qlearning/awac.py
@@ -1,8 +1,9 @@
import dataclasses
from typing import Dict

import torch

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_continuous_q_function,
Expand All @@ -11,9 +12,10 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...torch_utility import TorchMiniBatch
from ...models.torch import Parameter
from .base import QLearningAlgoBase
from .torch.awac_impl import AWACImpl
from .torch.sac_impl import SACModules

__all__ = ["AWACConfig", "AWAC"]

@@ -68,7 +70,6 @@ class AWACConfig(LearnableConfig):
n_action_samples (int): Number of sampled actions to calculate
:math:`A^\pi(s_t, a_t)`.
n_critics (int): Number of Q functions for ensemble.
update_actor_interval (int): Interval to update policy function.
"""
actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
@@ -83,7 +84,6 @@ class AWACConfig(LearnableConfig):
lam: float = 1.0
n_action_samples: int = 1
n_critics: int = 2
update_actor_interval: int = 1

def create(self, device: DeviceArg = False) -> "AWAC":
return AWAC(self, device)
@@ -106,7 +106,15 @@ def inner_create_impl(
use_std_parameter=True,
device=self._device,
)
q_func = create_continuous_q_function(
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
@@ -119,36 +127,33 @@ def inner_create_impl(
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_func.parameters(), lr=self._config.critic_learning_rate
q_funcs.parameters(), lr=self._config.critic_learning_rate
)

self._impl = AWACImpl(
observation_shape=observation_shape,
action_size=action_size,
q_func=q_func,
dummy_log_temp = Parameter(torch.zeros(1, 1))
modules = SACModules(
policy=policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
log_temp=dummy_log_temp,
actor_optim=actor_optim,
critic_optim=critic_optim,
temp_optim=None,
)

self._impl = AWACImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
device=self._device,
)

def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR

metrics = {}

metrics.update(self._impl.update_critic(batch))

# delayed policy update
if self._grad_step % self._config.update_actor_interval == 0:
metrics.update(self._impl.update_actor(batch))

return metrics

def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS

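Note that the `inner_update` override and the `update_actor_interval` option disappear from `awac.py` because the per-step update logic now lives on the torch impl side (see the `inner_update` hook added to `QLearningAlgoImplBase` in `base.py` below). A rough, hypothetical sketch of what an impl-side update could look like, not the actual `AWACImpl` source:

```python
from typing import Dict


class AWACImplSketch:
    """Hypothetical stand-in for the torch-side impl, for illustration only."""

    def update_critic(self, batch) -> Dict[str, float]:
        # placeholder: the real impl computes the TD loss and steps the critic optimizer
        return {"critic_loss": 0.0}

    def update_actor(self, batch) -> Dict[str, float]:
        # placeholder: the real impl computes the advantage-weighted actor loss
        return {"actor_loss": 0.0}

    def inner_update(self, batch, grad_step: int) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        metrics.update(self.update_critic(batch))
        metrics.update(self.update_actor(batch))
        return metrics
```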
48 changes: 23 additions & 25 deletions d3rlpy/algos/qlearning/base.py
@@ -14,6 +14,7 @@

import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm, trange
from typing_extensions import Self

@@ -36,17 +37,15 @@
LoggerAdapterFactory,
)
from ...metrics import EvaluatorProtocol, evaluate_qlearning_with_environment
from ...models.torch import EnsembleQFunction, Policy
from ...models.torch import Policy
from ...torch_utility import (
TorchMiniBatch,
convert_to_torch,
convert_to_torch_recursively,
eval_api,
freeze,
hard_sync,
reset_optimizer_states,
sync_optimizer_state,
unfreeze,
train_api,
)
from ..utility import (
assert_action_space_with_dataset,
@@ -65,6 +64,16 @@


class QLearningAlgoImplBase(ImplBase):
@train_api
def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]:
return self.inner_update(batch, grad_step)

@abstractmethod
def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
pass

@eval_api
def predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
return self.inner_predict_best_action(x)
@@ -119,15 +128,15 @@
sync_optimizer_state(self.policy_optim, impl.policy_optim)

@property
def q_function(self) -> EnsembleQFunction:
def q_function(self) -> nn.ModuleList:
raise NotImplementedError

def copy_q_function_from(self, impl: "QLearningAlgoImplBase") -> None:
q_func = self.q_function.q_funcs[0]
if not isinstance(impl.q_function.q_funcs[0], type(q_func)):
q_func = self.q_function[0]
if not isinstance(impl.q_function[0], type(q_func)):
raise ValueError(
f"Invalid Q-function type: expected={type(q_func)},"
f"actual={type(impl.q_function.q_funcs[0])}"
f"actual={type(impl.q_function[0])}"
)
hard_sync(self.q_function, impl.q_function)

@@ -145,7 +154,7 @@
sync_optimizer_state(self.q_function_optim, impl.q_function_optim)

def reset_optimizer_states(self) -> None:
reset_optimizer_states(self)
self.modules.reset_optimizer_states()


TQLearningImpl = TypeVar("TQLearningImpl", bound=QLearningAlgoImplBase)
@@ -195,9 +204,9 @@
)

# workaround until version 1.6
freeze(self._impl)
self._impl.modules.freeze()

# dummy function to select best actions
# local function to select best actions
def _func(x: torch.Tensor) -> torch.Tensor:
assert self._impl

@@ -233,7 +242,7 @@
)

# workaround until version 1.6
unfreeze(self._impl)
self._impl.modules.unfreeze()

def predict(self, x: Observation) -> np.ndarray:
"""Returns greedy actions.
@@ -811,29 +820,18 @@
Returns:
Dictionary of metrics.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
torch_batch = TorchMiniBatch.from_batch(
batch=batch,
device=self._device,
observation_scaler=self._config.observation_scaler,
action_scaler=self._config.action_scaler,
reward_scaler=self._config.reward_scaler,
)
loss = self.inner_update(torch_batch)
loss = self._impl.inner_update(torch_batch, self._grad_step)
self._grad_step += 1
return loss

@abstractmethod
def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
"""Update parameters with PyTorch mini-batch.

Args:
batch: PyTorch mini-batch data.

Returns:
Dictionary of metrics.
"""
raise NotImplementedError

def copy_policy_from(
self, algo: "QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]"
) -> None:
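The `self._impl.modules.freeze()`, `unfreeze()`, and `self.modules.reset_optimizer_states()` calls in the diff above assume the new `Modules`-style container that groups networks and optimizers (compare `SACModules` in `awac.py` and `BCModules` below). A hypothetical sketch of such a container; the names and details here are assumptions, not the actual d3rlpy implementation:

```python
import collections
import dataclasses

from torch import nn
from torch.optim import Optimizer


@dataclasses.dataclass(frozen=True)
class ModulesSketch:
    """Hypothetical container grouping networks and optimizers, for illustration only."""

    policy: nn.Module
    q_funcs: nn.ModuleList
    policy_optim: Optimizer
    q_func_optim: Optimizer

    def freeze(self) -> None:
        # disable gradients on every wrapped network (used as a scripting workaround)
        for module in (self.policy, self.q_funcs):
            for param in module.parameters():
                param.requires_grad_(False)

    def unfreeze(self) -> None:
        for module in (self.policy, self.q_funcs):
            for param in module.parameters():
                param.requires_grad_(True)

    def reset_optimizer_states(self) -> None:
        # drop momentum / Adam statistics, e.g. before fine-tuning on new data
        for optim in (self.policy_optim, self.q_func_optim):
            optim.state = collections.defaultdict(dict)
```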
37 changes: 17 additions & 20 deletions d3rlpy/algos/qlearning/bc.py
@@ -1,8 +1,7 @@
import dataclasses
from typing import Dict, Generic, TypeVar

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...constants import ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_categorical_policy,
@@ -11,23 +10,18 @@
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...torch_utility import TorchMiniBatch
from .base import QLearningAlgoBase
from .torch.bc_impl import BCBaseImpl, BCImpl, DiscreteBCImpl
from .torch.bc_impl import (
BCBaseImpl,
BCImpl,
BCModules,
DiscreteBCImpl,
DiscreteBCModules,
)

__all__ = ["BCConfig", "BC", "DiscreteBCConfig", "DiscreteBC"]


TBCConfig = TypeVar("TBCConfig", bound="LearnableConfig")


class _BCBase(Generic[TBCConfig], QLearningAlgoBase[BCBaseImpl, TBCConfig]):
def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
loss = self._impl.update_imitator(batch)
return {"loss": loss}


@dataclasses.dataclass()
class BCConfig(LearnableConfig):
r"""Config of Behavior Cloning algorithm.
@@ -70,7 +64,7 @@ def get_type() -> str:
return "bc"


class BC(_BCBase[BCConfig]):
class BC(QLearningAlgoBase[BCBaseImpl, BCConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -97,11 +91,13 @@ def inner_create_impl(
imitator.parameters(), lr=self._config.learning_rate
)

modules = BCModules(optim=optim, imitator=imitator)

self._impl = BCImpl(
observation_shape=observation_shape,
action_size=action_size,
imitator=imitator,
optim=optim,
modules=modules,
policy_type=self._config.policy_type,
device=self._device,
)

@@ -151,7 +147,7 @@ def get_type() -> str:
return "discrete_bc"


class DiscreteBC(_BCBase[DiscreteBCConfig]):
class DiscreteBC(QLearningAlgoBase[BCBaseImpl, DiscreteBCConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -166,11 +162,12 @@ def inner_create_impl(
imitator.parameters(), lr=self._config.learning_rate
)

modules = DiscreteBCModules(optim=optim, imitator=imitator)

self._impl = DiscreteBCImpl(
observation_shape=observation_shape,
action_size=action_size,
imitator=imitator,
optim=optim,
modules=modules,
beta=self._config.beta,
device=self._device,
)