Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Discrete IQL #404

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 100 additions & 2 deletions d3rlpy/algos/qlearning/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
from ...constants import ActionSpace
from ...models.builders import (
create_continuous_q_function,
create_discrete_q_function,
create_categorical_policy,
create_normal_policy,
create_value_function,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import MeanQFunctionFactory
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.iql_impl import IQLImpl, IQLModules
from .torch.iql_impl import IQLImpl, IQLModules, DiscreteIQLImpl, DiscreteIQLModules

__all__ = ["IQLConfig", "IQL"]
__all__ = ["IQLConfig", "IQL", "DiscreteIQLConfig", "DiscreteIQL"]


@dataclasses.dataclass()
Expand Down Expand Up @@ -175,5 +178,100 @@ def inner_create_impl(
def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS

@dataclasses.dataclass()
class DiscreteIQLConfig(LearnableConfig):
actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4

q_func_factory: QFunctionFactory = make_q_func_field()
encoder_factory: EncoderFactory = make_encoder_field()
value_encoder_factory: EncoderFactory = make_encoder_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()

actor_encoder_factory: EncoderFactory = make_encoder_field()
actor_optim_factory: OptimizerFactory = make_optimizer_field()

batch_size: int = 256
gamma: float = 0.99
tau: float = 0.005
n_critics: int = 2
expectile: float = 0.7
weight_temp: float = 3.0
max_weight: float = 100.0

def create(self, device: DeviceArg = False) -> "DiscreteIQL":
return DiscreteIQL(self, device)

@staticmethod
def get_type() -> str:
return "discrete_iql"

class DiscreteIQL(QLearningAlgoBase[DiscreteIQLImpl, DiscreteIQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
policy = create_categorical_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
self._device,
)
q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove q_func_factory from config? Instead, please use MeanQFunctionFactory just like the continuous IQL?

n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
value_func = create_value_function(
observation_shape,
self._config.value_encoder_factory,
device=self._device,
)

q_func_params = list(q_funcs.named_modules())
v_func_params = list(value_func.named_modules())
critic_optim = self._config.critic_optim_factory.create(
q_func_params + v_func_params, lr=self._config.critic_learning_rate
)
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
)

modules = DiscreteIQLModules(
policy=policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
value_func=value_func,
actor_optim=actor_optim,
critic_optim=critic_optim,
)

self._impl = DiscreteIQLImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
expectile=self._config.expectile,
weight_temp=self._config.weight_temp,
max_weight=self._config.max_weight,
device=self._device,
)

def get_action_type(self) -> ActionSpace:
return ActionSpace.DISCRETE

register_learnable(IQLConfig)
register_learnable(DiscreteIQLConfig)
110 changes: 110 additions & 0 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,116 @@ def q_function_optim(self) -> Optimizer:
return self._modules.critic_optim


class DiscreteDDPGBaseImpl(
ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta
Mamba413 marked this conversation as resolved.
Show resolved Hide resolved
):
_modules: DDPGBaseModules
_gamma: float
_tau: float
_q_func_forwarder: ContinuousEnsembleQFunctionForwarder
_targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These need to be DiscreteEnsembnleQFunctionForwarder.

Copy link
Author

@Mamba413 Mamba413 Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you means DiscreteEnsembleQFunctionForwarder, I think this is solved now.


def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DDPGBaseModules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
device=device,
)
self._gamma = gamma
self._tau = tau
self._q_func_forwarder = q_func_forwarder
self._targ_q_func_forwarder = targ_q_func_forwarder
hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs)

def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
self._modules.critic_optim.zero_grad()
q_tpn = self.compute_target(batch)
loss = self.compute_critic_loss(batch, q_tpn)
loss.critic_loss.backward()
self._modules.critic_optim.step()
return asdict_as_float(loss)

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> DDPGBaseCriticLoss:
loss = self._q_func_forwarder.compute_error(
observations=batch.observations,
actions=batch.actions,
rewards=batch.rewards,
target=q_tpn,
terminals=batch.terminals,
gamma=self._gamma**batch.intervals,
)
return DDPGBaseCriticLoss(loss)

def update_actor(
self, batch: TorchMiniBatch, action: ActionOutput
) -> Dict[str, float]:
# Q function should be inference mode for stability
self._modules.q_funcs.eval()
self._modules.actor_optim.zero_grad()
loss = self.compute_actor_loss(batch, action)
loss.actor_loss.backward()
self._modules.actor_optim.step()
return asdict_as_float(loss)

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
metrics = {}
action = self._modules.policy(batch.observations)
metrics.update(self.update_critic(batch))
metrics.update(self.update_actor(batch, action))
self.update_critic_target()
return metrics

@abstractmethod
def compute_actor_loss(
self, batch: TorchMiniBatch, action: ActionOutput
) -> DDPGBaseActorLoss:
pass

@abstractmethod
def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
pass

def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
return torch.argmax(self._modules.policy(x).probs).unsqueeze(0)

@abstractmethod
def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
pass

def update_critic_target(self) -> None:
soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau)

@property
def policy(self) -> Policy:
return self._modules.policy

@property
def policy_optim(self) -> Optimizer:
return self._modules.actor_optim

@property
def q_function(self) -> nn.ModuleList:
return self._modules.q_funcs

@property
def q_function_optim(self) -> Optimizer:
return self._modules.critic_optim

@dataclasses.dataclass(frozen=True)
class DDPGModules(DDPGBaseModules):
targ_policy: Policy
Expand Down
106 changes: 104 additions & 2 deletions d3rlpy/algos/qlearning/torch/iql_impl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import dataclasses

import torch
import torch.nn.functional as F

from ....models.torch import (
ActionOutput,
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
NormalPolicy,
CategoricalPolicy,
ValueFunction,
build_gaussian_distribution,
)
Expand All @@ -15,11 +18,11 @@
DDPGBaseActorLoss,
DDPGBaseCriticLoss,
DDPGBaseImpl,
DiscreteDDPGBaseImpl,
DDPGBaseModules,
)

__all__ = ["IQLImpl", "IQLModules"]

__all__ = ["IQLImpl", "IQLModules", "DiscreteIQLImpl", "DiscreteIQLModules"]

@dataclasses.dataclass(frozen=True)
class IQLModules(DDPGBaseModules):
Expand Down Expand Up @@ -120,3 +123,102 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
dist = build_gaussian_distribution(self._modules.policy(x))
return dist.sample()


@dataclasses.dataclass(frozen=True)
class DiscreteIQLModules(DDPGBaseModules):
policy: CategoricalPolicy
value_func: ValueFunction


@dataclasses.dataclass(frozen=True)
class DiscreteIQLCriticLoss(DDPGBaseCriticLoss):
q_loss: torch.Tensor
v_loss: torch.Tensor


class DiscreteIQLImpl(DiscreteDDPGBaseImpl):
_modules: DiscreteIQLModules
_expectile: float
_weight_temp: float
_max_weight: float

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DiscreteIQLModules,
q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
gamma: float,
tau: float,
expectile: float,
weight_temp: float,
max_weight: float,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
device=device,
)
self._expectile = expectile
self._weight_temp = weight_temp
self._max_weight = max_weight

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> IQLCriticLoss:
q_loss = self._q_func_forwarder.compute_error(
observations=batch.observations,
actions=batch.actions.long(),
rewards=batch.rewards,
target=q_tpn,
terminals=batch.terminals,
gamma=self._gamma**batch.intervals,
)
v_loss = self.compute_value_loss(batch)
return IQLCriticLoss(
critic_loss=q_loss + v_loss,
q_loss=q_loss,
v_loss=v_loss,
)

def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
with torch.no_grad():
return self._modules.value_func(batch.next_observations)

def compute_actor_loss(self, batch: TorchMiniBatch, action) -> DDPGBaseActorLoss:
assert self._modules.policy
# compute weight
with torch.no_grad():
v = self._modules.value_func(batch.observations)
min_Q = self._targ_q_func_forwarder.compute_target(batch.observations, reduction="min").gather(
1, batch.actions.long()
)

exp_a = torch.exp((min_Q - v) * self._weight_temp).clamp(max=self._max_weight)
# compute log probability
dist = self._modules.policy(batch.observations)
log_probs = dist.log_prob(batch.actions.squeeze(-1)).unsqueeze(1)

return DDPGBaseActorLoss(-(exp_a * log_probs).mean())

def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
q_t = self._targ_q_func_forwarder.compute_expected_q(batch.observations)
one_hot = F.one_hot(batch.actions.long().view(-1), num_classes=self.action_size)
q_t = (q_t * one_hot).sum(dim=1, keepdim=True)

Mamba413 marked this conversation as resolved.
Show resolved Hide resolved
v_t = self._modules.value_func(batch.observations)
diff = q_t.detach() - v_t
weight = (self._expectile - (diff < 0.0).float()).abs().detach()
return (weight * (diff**2)).mean()

def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
dist = self._modules.policy(x)
return dist.sample()
Loading