Merge branch 'minimize_redundant_computes'
takuseno committed Nov 30, 2023
2 parents b6f17bd + 8d419aa commit f57da4e
Showing 12 changed files with 237 additions and 238 deletions.
22 changes: 14 additions & 8 deletions d3rlpy/algos/qlearning/torch/awac_impl.py
@@ -2,12 +2,13 @@
 import torch.nn.functional as F
 
 from ....models.torch import (
+    ActionOutput,
     ContinuousEnsembleQFunctionForwarder,
     build_gaussian_distribution,
 )
 from ....torch_utility import TorchMiniBatch
 from ....types import Shape
-from .sac_impl import SACImpl, SACModules
+from .sac_impl import SACActorLoss, SACImpl, SACModules
 
 __all__ = ["AWACImpl"]
 
@@ -42,17 +43,22 @@ def __init__(
         self._lam = lam
         self._n_action_samples = n_action_samples
 
-    def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
+    def compute_actor_loss(
+        self, batch: TorchMiniBatch, action: ActionOutput
+    ) -> SACActorLoss:
         # compute log probability
-        dist = build_gaussian_distribution(
-            self._modules.policy(batch.observations)
-        )
+        dist = build_gaussian_distribution(action)
         log_probs = dist.log_prob(batch.actions)
 
         # compute exponential weight
         weights = self._compute_weights(batch.observations, batch.actions)
 
-        return -(log_probs * weights).sum()
+        loss = -(log_probs * weights).sum()
+        return SACActorLoss(
+            actor_loss=loss,
+            temp_loss=torch.tensor(
+                0.0, dtype=torch.float32, device=loss.device
+            ),
+            temp=torch.tensor(0.0, dtype=torch.float32, device=loss.device),
+        )
 
     def _compute_weights(
         self, obs_t: torch.Tensor, act_t: torch.Tensor
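
The same pattern repeats across the changed files: the trainer forwards the policy once per update step, passes the resulting ActionOutput into compute_actor_loss, and the loss method returns a frozen dataclass instead of a bare tensor. A minimal self-contained sketch of that pattern, under the assumption that the caller owns the forward pass (the ActorLoss class, compute_actor_loss helper, and toy linear policy below are illustrative, not d3rlpy's API):

import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class ActorLoss:
    actor_loss: torch.Tensor  # extra diagnostic fields can ride along here

policy = torch.nn.Linear(4, 2)  # stand-in for the actual policy network

def compute_actor_loss(batch_actions: torch.Tensor, action: torch.Tensor) -> ActorLoss:
    # 'action' is the precomputed policy output; the loss never re-runs the policy.
    return ActorLoss(actor_loss=((action - batch_actions) ** 2).mean())

observations = torch.randn(8, 4)
batch_actions = torch.randn(8, 2)
action = policy(observations)                     # single forward pass per step
loss = compute_actor_loss(batch_actions, action)  # reuses that output
loss.actor_loss.backward()
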
50 changes: 29 additions & 21 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -7,6 +7,7 @@
 from torch.optim import Optimizer
 
 from ....models.torch import (
+    ActionOutput,
     CategoricalPolicy,
     ConditionalVAE,
     ContinuousEnsembleQFunctionForwarder,
@@ -19,7 +20,7 @@
 )
 from ....torch_utility import TorchMiniBatch, soft_sync
 from ....types import Shape
-from .ddpg_impl import DDPGBaseImpl, DDPGBaseModules
+from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules
 from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules
 
 __all__ = [
@@ -79,37 +80,24 @@ def __init__(
         self._beta = beta
         self._rl_start_step = rl_start_step
 
-    def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
-        latent = torch.randn(
-            batch.observations.shape[0],
-            2 * self._action_size,
-            device=self._device,
-        )
-        clipped_latent = latent.clamp(-0.5, 0.5)
-        sampled_action = forward_vae_decode(
-            vae=self._modules.imitator,
-            x=batch.observations,
-            latent=clipped_latent,
-        )
-        action = self._modules.policy(batch.observations, sampled_action)
+    def compute_actor_loss(
+        self, batch: TorchMiniBatch, action: ActionOutput
+    ) -> DDPGBaseActorLoss:
         value = self._q_func_forwarder.compute_expected_q(
             batch.observations, action.squashed_mu, "none"
         )
-        return -value[0].mean()
+        return DDPGBaseActorLoss(-value[0].mean())
 
     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
         self._modules.imitator_optim.zero_grad()
-
         loss = compute_vae_error(
             vae=self._modules.imitator,
             x=batch.observations,
             action=batch.actions,
             beta=self._beta,
         )
-
         loss.backward()
         self._modules.imitator_optim.step()
-
         return {"imitator_loss": float(loss.cpu().detach().numpy())}
 
     def _repeat_observation(self, x: torch.Tensor) -> torch.Tensor:
@@ -188,10 +176,30 @@ def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
     ) -> Dict[str, float]:
         metrics = {}
 
         metrics.update(self.update_imitator(batch))
-        if grad_step >= self._rl_start_step:
-            metrics.update(super().inner_update(batch, grad_step))
-            self.update_actor_target()
+        if grad_step < self._rl_start_step:
+            return metrics
+
+        # forward policy
+        latent = torch.randn(
+            batch.observations.shape[0],
+            2 * self._action_size,
+            device=self._device,
+        )
+        clipped_latent = latent.clamp(-0.5, 0.5)
+        sampled_action = forward_vae_decode(
+            vae=self._modules.imitator,
+            x=batch.observations,
+            latent=clipped_latent,
+        )
+        action = self._modules.policy(batch.observations, sampled_action)
+
+        # update models
+        metrics.update(self.update_critic(batch))
+        metrics.update(self.update_actor(batch, action))
+        self.update_critic_target()
+        self.update_actor_target()
         return metrics
 
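
For BCQ, sampling the perturbed action moves out of compute_actor_loss and into inner_update, which now early-returns until rl_start_step is reached and only then forwards the imitator and policy once for the RL updates. A rough, self-contained sketch of that control flow (the stub update functions, sample_action helper, and RL_START_STEP constant are placeholders, not the library's code):

from typing import Dict
import torch

RL_START_STEP = 1000  # placeholder for the configured rl_start_step

def update_imitator(obs) -> Dict[str, float]:      # stub
    return {"imitator_loss": 0.0}

def update_critic(obs) -> Dict[str, float]:        # stub
    return {"critic_loss": 0.0}

def update_actor(obs, action) -> Dict[str, float]: # stub
    return {"actor_loss": 0.0}

def sample_action(obs: torch.Tensor) -> torch.Tensor:
    # stub for "decode a clipped VAE latent, then perturb it with the policy"
    return torch.zeros(obs.shape[0], 2)

def inner_update(obs: torch.Tensor, grad_step: int) -> Dict[str, float]:
    metrics = {}
    metrics.update(update_imitator(obs))       # imitator always trains
    if grad_step < RL_START_STEP:              # early return replaces the nested if
        return metrics
    action = sample_action(obs)                # policy path forwarded once ...
    metrics.update(update_critic(obs))
    metrics.update(update_actor(obs, action))  # ... and reused by the actor update
    return metrics

print(inner_update(torch.zeros(4, 3), grad_step=0))     # imitator loss only
print(inner_update(torch.zeros(4, 3), grad_step=2000))  # all three losses
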
64 changes: 26 additions & 38 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -5,6 +5,7 @@
 from torch.optim import Optimizer
 
 from ....models.torch import (
+    ActionOutput,
     ConditionalVAE,
     ContinuousEnsembleQFunctionForwarder,
     Parameter,
@@ -15,7 +16,7 @@
 )
 from ....torch_utility import TorchMiniBatch
 from ....types import Shape
-from .sac_impl import SACImpl, SACModules
+from .sac_impl import SACActorLoss, SACImpl, SACModules
 
 __all__ = ["BEARImpl", "BEARModules"]
 
@@ -42,6 +43,12 @@ class BEARModules(SACModules):
     alpha_optim: Optional[Optimizer]
 
 
+@dataclasses.dataclass(frozen=True)
+class BEARActorLoss(SACActorLoss):
+    mmd_loss: torch.Tensor
+    alpha: torch.Tensor
+
+
 class BEARImpl(SACImpl):
     _modules: BEARModules
     _alpha_threshold: float
@@ -94,19 +101,26 @@ def __init__(
         self._vae_kl_weight = vae_kl_weight
         self._warmup_steps = warmup_steps
 
-    def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
-        loss = super().compute_actor_loss(batch)
+    def compute_actor_loss(
+        self, batch: TorchMiniBatch, action: ActionOutput
+    ) -> BEARActorLoss:
+        loss = super().compute_actor_loss(batch, action)
         mmd_loss = self._compute_mmd_loss(batch.observations)
-        return loss + mmd_loss
+        if self._modules.alpha_optim:
+            self.update_alpha(mmd_loss)
+        return BEARActorLoss(
+            actor_loss=loss.actor_loss + mmd_loss,
+            temp_loss=loss.temp_loss,
+            temp=loss.temp,
+            mmd_loss=mmd_loss,
+            alpha=self._modules.log_alpha().exp(),
+        )
 
     def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
         self._modules.actor_optim.zero_grad()
-
         loss = self._compute_mmd_loss(batch.observations)
-
         loss.backward()
         self._modules.actor_optim.step()
-
         return {"actor_loss": float(loss.cpu().detach().numpy())}
 
     def _compute_mmd_loss(self, obs_t: torch.Tensor) -> torch.Tensor:
@@ -116,13 +130,9 @@ def _compute_mmd_loss(self, obs_t: torch.Tensor) -> torch.Tensor:
 
     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
         self._modules.imitator_optim.zero_grad()
-
         loss = self.compute_imitator_loss(batch)
-
         loss.backward()
-
         self._modules.imitator_optim.step()
-
         return {"imitator_loss": float(loss.cpu().detach().numpy())}
 
     def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
@@ -133,25 +143,15 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
             beta=self._vae_kl_weight,
         )
 
-    def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_alpha(self, mmd_loss: torch.Tensor) -> None:
         assert self._modules.alpha_optim
         self._modules.alpha_optim.zero_grad()
-
-        loss = -self._compute_mmd_loss(batch.observations)
-
-        loss.backward()
+        loss = -mmd_loss
+        loss.backward(retain_graph=True)
         self._modules.alpha_optim.step()
-
         # clip for stability
         self._modules.log_alpha.data.clamp_(-5.0, 10.0)
-
-        cur_alpha = self._modules.log_alpha().exp().cpu().detach().numpy()[0][0]
-
-        return {
-            "alpha_loss": float(loss.cpu().detach().numpy()),
-            "alpha": float(cur_alpha),
-        }
 
     def _compute_mmd(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             behavior_actions = forward_vae_sample_n(
@@ -259,25 +259,13 @@ def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
     ) -> Dict[str, float]:
         metrics = {}
-
         metrics.update(self.update_imitator(batch))
-
-        # lagrangian parameter update for SAC temperature
-        if self._modules.temp_optim:
-            metrics.update(self.update_temp(batch))
-
-        # lagrangian parameter update for MMD loss weight
-        if self._modules.alpha_optim:
-            metrics.update(self.update_alpha(batch))
-
         metrics.update(self.update_critic(batch))
-
         if grad_step < self._warmup_steps:
             actor_loss = self.warmup_actor(batch)
         else:
-            actor_loss = self.update_actor(batch)
+            action = self._modules.policy(batch.observations)
+            actor_loss = self.update_actor(batch, action)
         metrics.update(actor_loss)
-
         self.update_critic_target()
-
         return metrics
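
BEAR's (and, below, CQL's) Lagrangian alpha update no longer recomputes its penalty term; it reuses the mmd_loss or conservative_loss tensor that the main loss already produced, which is why backward now needs retain_graph=True. A minimal illustration of that mechanic, detached from d3rlpy's classes (the toy penalty and parameter names are assumptions for the sketch):

import torch

# The same graph is differentiated twice in one update step: once (negated) to
# adjust the Lagrange multiplier, once again as part of the actor/critic loss.
weights = torch.nn.Parameter(torch.randn(4))    # stand-in for policy/Q parameters
log_alpha = torch.nn.Parameter(torch.zeros(1))  # Lagrange multiplier
alpha_optim = torch.optim.SGD([log_alpha], lr=1e-3)

penalty = log_alpha.exp() * weights.pow(2).sum()  # e.g. alpha * (MMD - threshold)

alpha_optim.zero_grad()
(-penalty).backward(retain_graph=True)  # keep the graph alive ...
alpha_optim.step()

penalty.backward()                      # ... so the main loss can still backprop through it
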
60 changes: 19 additions & 41 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -1,6 +1,6 @@
 import dataclasses
 import math
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -14,6 +14,7 @@
 )
 from ....torch_utility import TorchMiniBatch
 from ....types import Shape
+from .ddpg_impl import DDPGBaseCriticLoss
 from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules
 from .sac_impl import SACImpl, SACModules
 
@@ -26,6 +27,12 @@ class CQLModules(SACModules):
     alpha_optim: Optional[Optimizer]
 
 
+@dataclasses.dataclass(frozen=True)
+class CQLCriticLoss(DDPGBaseCriticLoss):
+    conservative_loss: torch.Tensor
+    alpha: torch.Tensor
+
+
 class CQLImpl(SACImpl):
     _modules: CQLModules
     _alpha_threshold: float
@@ -65,36 +72,27 @@
 
     def compute_critic_loss(
         self, batch: TorchMiniBatch, q_tpn: torch.Tensor
-    ) -> torch.Tensor:
+    ) -> CQLCriticLoss:
         loss = super().compute_critic_loss(batch, q_tpn)
         conservative_loss = self._compute_conservative_loss(
             batch.observations, batch.actions, batch.next_observations
         )
-        return loss + conservative_loss
+        if self._modules.alpha_optim:
+            self.update_alpha(conservative_loss)
+        return CQLCriticLoss(
+            critic_loss=loss.critic_loss + conservative_loss,
+            conservative_loss=conservative_loss,
+            alpha=self._modules.log_alpha().exp(),
+        )
 
-    def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_alpha(self, conservative_loss: torch.Tensor) -> None:
         assert self._modules.alpha_optim
-
-        # Q function should be inference mode for stability
-        self._modules.q_funcs.eval()
-
         self._modules.alpha_optim.zero_grad()
-
-        # the original implementation does scale the loss value
-        loss = -self._compute_conservative_loss(
-            batch.observations, batch.actions, batch.next_observations
-        )
-
-        loss.backward()
+        loss = -conservative_loss
+        loss.backward(retain_graph=True)
         self._modules.alpha_optim.step()
-
-        cur_alpha = self._modules.log_alpha().exp().cpu().detach().numpy()[0][0]
-
-        return {
-            "alpha_loss": float(loss.cpu().detach().numpy()),
-            "alpha": float(cur_alpha),
-        }
 
     def _compute_policy_is_values(
         self, policy_obs: torch.Tensor, value_obs: torch.Tensor
     ) -> torch.Tensor:
@@ -196,26 +194,6 @@ def _compute_deterministic_target(
             reduction="min",
         )
 
-    def inner_update(
-        self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
-        metrics = {}
-
-        # lagrangian parameter update for SAC temperature
-        if self._modules.temp_optim:
-            metrics.update(self.update_temp(batch))
-
-        # lagrangian parameter update for conservative loss weight
-        if self._modules.alpha_optim:
-            metrics.update(self.update_alpha(batch))
-
-        metrics.update(self.update_critic(batch))
-        metrics.update(self.update_actor(batch))
-
-        self.update_critic_target()
-
-        return metrics
-
 
 @dataclasses.dataclass(frozen=True)
 class DiscreteCQLLoss(DQNLoss):
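
CQL follows the same scheme on the critic side: compute_critic_loss now returns a CQLCriticLoss dataclass whose extra fields (conservative_loss, alpha) carry the values that the deleted inner_update/update_alpha methods used to return as a metrics dict, presumably so the caller can log them from the loss object. A small sketch of how such a frozen loss dataclass can be flattened into scalars (the to_metrics helper and class names are assumptions, not necessarily how d3rlpy reports metrics):

import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class CriticLoss:
    critic_loss: torch.Tensor

@dataclasses.dataclass(frozen=True)
class ConservativeCriticLoss(CriticLoss):
    conservative_loss: torch.Tensor
    alpha: torch.Tensor

def to_metrics(loss: CriticLoss) -> dict:
    # Every field of the (frozen) loss dataclass becomes one logged scalar.
    return {
        name: float(value.detach().cpu())
        for name, value in dataclasses.asdict(loss).items()
    }

loss = ConservativeCriticLoss(
    critic_loss=torch.tensor(1.5),
    conservative_loss=torch.tensor(0.3),
    alpha=torch.tensor(5.0),
)
print(to_metrics(loss))  # {'critic_loss': 1.5, 'conservative_loss': 0.3, 'alpha': 5.0}
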