Commit

Refactor imitator models

takuseno committed Aug 11, 2023
1 parent fc739c7 commit 449bd78
Showing 10 changed files with 129 additions and 400 deletions.
14 changes: 6 additions & 8 deletions d3rlpy/algos/qlearning/bc.py
@@ -5,9 +5,9 @@
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_deterministic_regressor,
create_discrete_imitator,
create_probablistic_regressor,
create_categorical_policy,
create_deterministic_policy,
create_normal_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -75,14 +75,14 @@ def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
if self._config.policy_type == "deterministic":
imitator = create_deterministic_regressor(
imitator = create_deterministic_policy(
observation_shape,
action_size,
self._config.encoder_factory,
device=self._device,
)
elif self._config.policy_type == "stochastic":
imitator = create_probablistic_regressor(
imitator = create_normal_policy(
observation_shape,
action_size,
self._config.encoder_factory,
@@ -102,7 +102,6 @@ def inner_create_impl(
action_size=action_size,
imitator=imitator,
optim=optim,
policy_type=self._config.policy_type,
device=self._device,
)

@@ -156,10 +155,9 @@ class DiscreteBC(_BCBase[DiscreteBCConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
imitator = create_discrete_imitator(
imitator = create_categorical_policy(
observation_shape,
action_size,
self._config.beta,
self._config.encoder_factory,
device=self._device,
)
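
Two things to note in the bc.py diff: the builders now hand back the regular policy classes (create_deterministic_policy, create_normal_policy, create_categorical_policy), and DiscreteBC no longer passes self._config.beta into the model, since that weight is applied at loss time in bc_impl.py further down. As a rough mental model of what a caller gets back from create_normal_policy, here is a hypothetical sketch; the module, attribute, and size names are stand-ins, not the actual d3rlpy classes:

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class GaussianActionDistSketch:
    mu: torch.Tensor
    logstd: torch.Tensor

    @property
    def squashed_mu(self) -> torch.Tensor:
        # tanh-squashed mean; bc_impl.py below reads `.squashed_mu` for greedy actions
        return torch.tanh(self.mu)

class NormalPolicySketch(nn.Module):
    # Hypothetical stand-in for the module returned by create_normal_policy().
    def __init__(self, obs_dim: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden_size), nn.ReLU())
        self.mu = nn.Linear(hidden_size, action_size)
        self.logstd = nn.Linear(hidden_size, action_size)

    def forward(self, obs: torch.Tensor) -> GaussianActionDistSketch:
        h = self.encoder(obs)
        # clamp mirrors the min_logstd/max_logstd bounds seen in bc_impl.py
        return GaussianActionDistSketch(self.mu(h), self.logstd(h).clamp(-4.0, 15.0))

dist = NormalPolicySketch(obs_dim=8, action_size=3)(torch.randn(2, 8))
greedy_action = dist.squashed_mu
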
10 changes: 4 additions & 6 deletions d3rlpy/algos/qlearning/bcq.py
@@ -5,16 +5,16 @@
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...dataset import Shape
from ...models.builders import (
create_categorical_policy,
create_conditional_vae,
create_continuous_q_function,
create_deterministic_residual_policy,
create_discrete_imitator,
create_discrete_q_function,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...models.torch import DiscreteImitator, PixelEncoder, compute_output_size
from ...models.torch import CategoricalPolicy, PixelEncoder, compute_output_size
from ...torch_utility import TorchMiniBatch
from .base import QLearningAlgoBase
from .torch.bcq_impl import BCQImpl, DiscreteBCQImpl
@@ -342,18 +342,16 @@ def inner_create_impl(
q_func.q_funcs[0].encoder,
device=self._device,
)
imitator = DiscreteImitator(
imitator = CategoricalPolicy(
encoder=q_func.q_funcs[0].encoder,
hidden_size=hidden_size,
action_size=action_size,
beta=self._config.beta,
)
imitator.to(self._device)
else:
imitator = create_discrete_imitator(
imitator = create_categorical_policy(
observation_shape,
action_size,
self._config.beta,
self._config.encoder_factory,
device=self._device,
)
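
In the shared-encoder branch of DiscreteBCQ above, the imitator is now a CategoricalPolicy built on the Q-function's own encoder, so it only adds a logits head of its own. A minimal PyTorch sketch of that weight-sharing layout; names and sizes are illustrative, not the d3rlpy implementation:

import torch
import torch.nn as nn

class SharedEncoderBCQSketch(nn.Module):
    # One torso, two heads: Q-values and behavior-cloning logits.
    def __init__(self, obs_dim: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
        )
        self.q_head = nn.Linear(hidden_size, action_size)         # Q(s, .)
        self.imitator_head = nn.Linear(hidden_size, action_size)  # behavior logits

    def forward(self, obs: torch.Tensor):
        h = self.encoder(obs)
        return self.q_head(h), self.imitator_head(h)

q_values, logits = SharedEncoderBCQSketch(obs_dim=8, action_size=4)(torch.randn(32, 8))
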
91 changes: 33 additions & 58 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -1,36 +1,33 @@
from abc import ABCMeta
from abc import ABCMeta, abstractmethod
from typing import Union

import torch
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import (
CategoricalPolicy,
DeterministicPolicy,
DeterministicRegressor,
DiscreteImitator,
Imitator,
NormalPolicy,
Policy,
ProbablisticRegressor,
compute_output_size,
compute_deterministic_imitation_loss,
compute_discrete_imitation_loss,
compute_stochastic_imitation_loss,
)
from ....torch_utility import TorchMiniBatch, hard_sync, train_api
from ....torch_utility import TorchMiniBatch, train_api
from ..base import QLearningAlgoImplBase

__all__ = ["BCImpl", "DiscreteBCImpl"]


class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta):
_learning_rate: float
_imitator: Imitator
_optim: Optimizer

def __init__(
self,
observation_shape: Shape,
action_size: int,
imitator: Imitator,
optim: Optimizer,
device: str,
):
@@ -39,7 +36,6 @@ def __init__(
action_size=action_size,
device=device,
)
self._imitator = imitator
self._optim = optim

@train_api
@@ -53,13 +49,11 @@ def update_imitator(self, batch: TorchMiniBatch) -> float:

return float(loss.cpu().detach().numpy())

@abstractmethod
def compute_loss(
self, obs_t: torch.Tensor, act_t: torch.Tensor
) -> torch.Tensor:
return self._imitator.compute_error(obs_t, act_t)

def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
return self._imitator(x)
pass

def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor:
return self.inner_predict_best_action(x)
@@ -71,63 +65,42 @@ def inner_predict_value(


class BCImpl(BCBaseImpl):
_policy_type: str
_imitator: Union[DeterministicRegressor, ProbablisticRegressor]
_imitator: Union[DeterministicPolicy, NormalPolicy]

def __init__(
self,
observation_shape: Shape,
action_size: int,
imitator: Union[DeterministicRegressor, ProbablisticRegressor],
imitator: Union[DeterministicPolicy, NormalPolicy],
optim: Optimizer,
policy_type: str,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
imitator=imitator,
optim=optim,
device=device,
)
self._policy_type = policy_type
self._imitator = imitator

@property
def policy(self) -> Policy:
policy: Policy
hidden_size = compute_output_size(
[self._observation_shape, (self._action_size,)],
self._imitator.encoder,
device=self._device,
)
if self._policy_type == "deterministic":
hidden_size = compute_output_size(
[self._observation_shape, (self._action_size,)],
self._imitator.encoder,
device=self._device,
)
policy = DeterministicPolicy(
encoder=self._imitator.encoder,
hidden_size=hidden_size,
action_size=self._action_size,
)
elif self._policy_type == "stochastic":
return NormalPolicy(
encoder=self._imitator.encoder,
hidden_size=hidden_size,
action_size=self._action_size,
min_logstd=-4.0,
max_logstd=15.0,
use_std_parameter=False,
def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
return self._imitator(x).squashed_mu

def compute_loss(
self, obs_t: torch.Tensor, act_t: torch.Tensor
) -> torch.Tensor:
if isinstance(self._imitator, DeterministicPolicy):
return compute_deterministic_imitation_loss(
self._imitator, obs_t, act_t
)
else:
raise ValueError(f"invalid policy_type: {self._policy_type}")
policy.to(self._device)

# copy parameters
hard_sync(policy, self._imitator)
return compute_stochastic_imitation_loss(
self._imitator, obs_t, act_t
)

return policy
@property
def policy(self) -> Policy:
return self._imitator

@property
def policy_optim(self) -> Optimizer:
@@ -136,30 +109,32 @@ def policy_optim(self) -> Optimizer:

class DiscreteBCImpl(BCBaseImpl):
_beta: float
_imitator: DiscreteImitator
_imitator: CategoricalPolicy

def __init__(
self,
observation_shape: Shape,
action_size: int,
imitator: DiscreteImitator,
imitator: CategoricalPolicy,
optim: Optimizer,
beta: float,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
imitator=imitator,
optim=optim,
device=device,
)
self._imitator = imitator
self._beta = beta

def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
return self._imitator(x).argmax(dim=1)
return self._imitator(x).logits.argmax(dim=1)

def compute_loss(
self, obs_t: torch.Tensor, act_t: torch.Tensor
) -> torch.Tensor:
return self._imitator.compute_error(obs_t, act_t.long())
return compute_discrete_imitation_loss(
policy=self._imitator, x=obs_t, action=act_t.long(), beta=self._beta
)
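
The structural change in this file: Imitator.compute_error is gone, each subclass declares its own compute_loss, and beta is handed to the loss call instead of living inside the model. The helpers below are plausible shapes inferred from the call sites in this diff (MSE on the squashed mean, a likelihood term for the stochastic case, cross-entropy plus a beta-weighted logit penalty for the discrete case); the real definitions live in d3rlpy.models.torch and may differ in detail:

import torch.nn.functional as F

def compute_deterministic_imitation_loss_sketch(policy, obs, action):
    # Plain behavioral cloning: regress the policy's squashed mean onto the dataset action.
    return F.mse_loss(policy(obs).squashed_mu, action)

def compute_stochastic_imitation_loss_sketch(policy, obs, action):
    # One plausible choice: negative log-likelihood of dataset actions under the
    # policy's Gaussian; an MSE against a sampled action would be another option.
    dist = policy(obs)
    return -dist.log_prob(action).mean()

def compute_discrete_imitation_loss_sketch(policy, x, action, beta):
    # Cross-entropy on the categorical logits plus a beta-weighted logit penalty,
    # mirroring the regularizer the old DiscreteImitator applied internally.
    logits = policy(x).logits
    log_probs = F.log_softmax(logits, dim=1)
    return F.nll_loss(log_probs, action) + beta * (logits ** 2).mean()
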
18 changes: 12 additions & 6 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -2,15 +2,17 @@
from typing import cast

import torch
import torch.nn.functional as F
from torch.optim import Optimizer

from ....dataset import Shape
from ....models.torch import (
CategoricalPolicy,
ConditionalVAE,
DeterministicResidualPolicy,
DiscreteImitator,
EnsembleContinuousQFunction,
EnsembleDiscreteQFunction,
compute_discrete_imitation_loss,
compute_max_with_n_actions,
compute_vae_error,
forward_vae_decode,
@@ -171,14 +173,14 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
class DiscreteBCQImpl(DoubleDQNImpl):
_action_flexibility: float
_beta: float
_imitator: DiscreteImitator
_imitator: CategoricalPolicy

def __init__(
self,
observation_shape: Shape,
action_size: int,
q_func: EnsembleDiscreteQFunction,
imitator: DiscreteImitator,
imitator: CategoricalPolicy,
optim: Optimizer,
gamma: float,
action_flexibility: float,
@@ -201,13 +203,17 @@ def compute_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> torch.Tensor:
loss = super().compute_loss(batch, q_tpn)
imitator_loss = self._imitator.compute_error(
batch.observations, batch.actions.long()
imitator_loss = compute_discrete_imitation_loss(
policy=self._imitator,
x=batch.observations,
action=batch.actions.long(),
beta=self._beta,
)
return loss + imitator_loss

def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
log_probs = self._imitator(x)
dist = self._imitator(x)
log_probs = F.log_softmax(dist.logits, dim=1)
ratio = log_probs - log_probs.max(dim=1, keepdim=True).values
mask = (ratio > math.log(self._action_flexibility)).float()
value = self._q_func(x)
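
The new F.log_softmax call above is needed because CategoricalPolicy returns a distribution whose .logits are unnormalized scores, whereas the old DiscreteImitator emitted log-probabilities directly. A self-contained sketch of the BCQ action-filtering step with random inputs; the threshold and shapes are arbitrary:

import math
import torch
import torch.nn.functional as F

batch_size, action_size = 32, 6
action_flexibility = 0.3                          # BCQ threshold
logits = torch.randn(batch_size, action_size)     # imitator logits
q_values = torch.randn(batch_size, action_size)   # Q(s, a) for every action

log_probs = F.log_softmax(logits, dim=1)
ratio = log_probs - log_probs.max(dim=1, keepdim=True).values
mask = (ratio > math.log(action_flexibility)).float()

# keep only actions the behavior policy deems likely enough, then pick the best Q
masked_q = mask * q_values + (1.0 - mask) * q_values.min()
best_actions = masked_q.argmax(dim=1)
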
3 changes: 2 additions & 1 deletion d3rlpy/algos/qlearning/torch/sac_impl.py
@@ -3,6 +3,7 @@
from typing import Tuple

import torch
import torch.nn.functional as F
from torch.optim import Optimizer

from ....dataset import Shape
@@ -209,7 +210,7 @@ def update_temp(self, batch: TorchMiniBatch) -> Tuple[float, float]:

with torch.no_grad():
dist = self._policy(batch.observations)
log_probs = dist.logits
log_probs = F.log_softmax(dist.logits, dim=1)
probs = dist.probs
expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True)
entropy_target = 0.98 * (-math.log(1 / self.action_size))
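
Same root cause as in bcq_impl.py: dist.logits are raw scores, so using them directly as log-probabilities skewed the entropy estimate that drives the discrete SAC temperature update. A short sketch of the corrected term (the surrounding temperature loss is omitted):

import math
import torch
import torch.nn.functional as F

action_size = 4
logits = torch.randn(32, action_size)        # categorical policy output

probs = F.softmax(logits, dim=1)
log_probs = F.log_softmax(logits, dim=1)     # corrected: proper log-probabilities

# E[log pi(a|s)] per state; its negation is the policy entropy
expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True)
entropy_target = 0.98 * (-math.log(1 / action_size))
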
[Diffs for the remaining 5 changed files were not loaded.]