Simplify ConditionalVAE
takuseno committed Aug 6, 2023
1 parent 39a3312 commit 4abdbd3
Showing 13 changed files with 280 additions and 155 deletions.
1 change: 0 additions & 1 deletion d3rlpy/algos/qlearning/bcq.py
@@ -186,7 +186,6 @@ def inner_create_impl(
  observation_shape=observation_shape,
  action_size=action_size,
  latent_size=2 * action_size,
- beta=self._config.beta,
  min_logstd=-4.0,
  max_logstd=15.0,
  encoder_factory=self._config.imitator_encoder_factory,
1 change: 0 additions & 1 deletion d3rlpy/algos/qlearning/bear.py
@@ -176,7 +176,6 @@ def inner_create_impl(
  observation_shape=observation_shape,
  action_size=action_size,
  latent_size=2 * action_size,
- beta=self._config.vae_kl_weight,
  min_logstd=-4.0,
  max_logstd=15.0,
  encoder_factory=self._config.imitator_encoder_factory,
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/plas.py
@@ -124,7 +124,6 @@ def inner_create_impl(
  observation_shape=observation_shape,
  action_size=action_size,
  latent_size=2 * action_size,
- beta=self._config.beta,
  min_logstd=-4.0,
  max_logstd=15.0,
  encoder_factory=self._config.imitator_encoder_factory,
@@ -153,6 +152,7 @@ def inner_create_impl(
  gamma=self._config.gamma,
  tau=self._config.tau,
  lam=self._config.lam,
+ beta=self._config.beta,
  device=self._device,
  )

@@ -257,7 +257,6 @@ def inner_create_impl(
  observation_shape=observation_shape,
  action_size=action_size,
  latent_size=2 * action_size,
- beta=self._config.beta,
  min_logstd=-4.0,
  max_logstd=15.0,
  encoder_factory=self._config.imitator_encoder_factory,
@@ -296,6 +295,7 @@ def inner_create_impl(
  gamma=self._config.gamma,
  tau=self._config.tau,
  lam=self._config.lam,
+ beta=self._config.beta,
  device=self._device,
  )

21 changes: 17 additions & 4 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -12,6 +12,8 @@
  EnsembleContinuousQFunction,
  EnsembleDiscreteQFunction,
  compute_max_with_n_actions,
+ compute_vae_error,
+ forward_vae_decode,
  )
  from ....torch_utility import TorchMiniBatch, train_api
  from .ddpg_impl import DDPGBaseImpl
@@ -73,8 +75,10 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
  device=self._device,
  )
  clipped_latent = latent.clamp(-0.5, 0.5)
- sampled_action = self._imitator.decode(
-     batch.observations, clipped_latent
+ sampled_action = forward_vae_decode(
+     vae=self._imitator,
+     x=batch.observations,
+     latent=clipped_latent,
  )
  action = self._policy(batch.observations, sampled_action)
  return -self._q_func(batch.observations, action.squashed_mu, "none")[
@@ -85,7 +89,12 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
  def update_imitator(self, batch: TorchMiniBatch) -> float:
  self._imitator_optim.zero_grad()

- loss = self._imitator.compute_error(batch.observations, batch.actions)
+ loss = compute_vae_error(
+     vae=self._imitator,
+     x=batch.observations,
+     action=batch.actions,
+     beta=self._beta,
+ )

  loss.backward()
  self._imitator_optim.step()
@@ -109,7 +118,11 @@ def _sample_repeated_action(
  )
  clipped_latent = latent.clamp(-0.5, 0.5)
  # sample action
- sampled_action = self._imitator.decode(flattened_x, clipped_latent)
+ sampled_action = forward_vae_decode(
+     vae=self._imitator,
+     x=flattened_x,
+     latent=clipped_latent,
+ )
  # add residual action
  policy = self._targ_policy if target else self._policy
  action = policy(flattened_x, sampled_action)
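The BCQ implementation now calls module-level VAE helpers instead of methods on ConditionalVAE. Below is a minimal sketch of that pattern; the wrapper functions are illustrative only, and the import location is an assumption (the hunk shows just the imported names, presumably exposed via d3rlpy.models.torch).

import torch

# Assumed import location; only the names appear in the hunk above.
from d3rlpy.models.torch import compute_vae_error, forward_vae_decode


def update_imitator(vae, optimizer, observations, actions, beta):
    # VAE loss with the KL weight passed per call instead of stored on the module.
    optimizer.zero_grad()
    loss = compute_vae_error(vae=vae, x=observations, action=actions, beta=beta)
    loss.backward()
    optimizer.step()
    return float(loss.detach().cpu().numpy())


def sample_bcq_action(vae, observations, action_size):
    # BCQ-style sampling: draw a latent, clip it to [-0.5, 0.5], then decode.
    latent = torch.randn(
        observations.shape[0], 2 * action_size, device=observations.device
    )
    return forward_vae_decode(
        vae=vae, x=observations, latent=latent.clamp(-0.5, 0.5)
    )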
13 changes: 10 additions & 3 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -11,6 +11,8 @@
  Parameter,
  build_squashed_gaussian_distribution,
  compute_max_with_n_actions_and_indices,
+ compute_vae_error,
+ forward_vae_sample_n,
  )
  from ....torch_utility import TorchMiniBatch, train_api
  from .sac_impl import SACImpl
@@ -133,7 +135,12 @@ def update_imitator(self, batch: TorchMiniBatch) -> float:
  return float(loss.cpu().detach().numpy())

  def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
- return self._imitator.compute_error(batch.observations, batch.actions)
+ return compute_vae_error(
+     vae=self._imitator,
+     x=batch.observations,
+     action=batch.actions,
+     beta=self._vae_kl_weight,
+ )

  @train_api
  def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]:
@@ -152,8 +159,8 @@ def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]:

  def _compute_mmd(self, x: torch.Tensor) -> torch.Tensor:
  with torch.no_grad():
- behavior_actions = self._imitator.sample_n_without_squash(
-     x, self._n_mmd_action_samples
+ behavior_actions = forward_vae_sample_n(
+     self._imitator, x, self._n_mmd_action_samples, with_squash=False
  )
  dist = build_squashed_gaussian_distribution(self._policy(x))
  policy_actions = dist.sample_n_without_squash(
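BEAR's MMD penalty needs several behavior-action samples per state; the method call is replaced by forward_vae_sample_n with an explicit with_squash flag. A short sketch of the call as used above; the wrapper function and the import path are assumptions.

import torch

# Assumed import location; only the name appears in the hunk above.
from d3rlpy.models.torch import forward_vae_sample_n


def sample_behavior_actions(vae, observations, n_samples):
    # Draw n pre-tanh action samples per observation, as in _compute_mmd above.
    with torch.no_grad():
        return forward_vae_sample_n(
            vae, observations, n_samples, with_squash=False
        )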
37 changes: 27 additions & 10 deletions d3rlpy/algos/qlearning/torch/plas_impl.py
@@ -9,6 +9,8 @@
  DeterministicPolicy,
  DeterministicResidualPolicy,
  EnsembleContinuousQFunction,
+ compute_vae_error,
+ forward_vae_decode,
  )
  from ....torch_utility import TorchMiniBatch, soft_sync, train_api
  from .ddpg_impl import DDPGBaseImpl
@@ -37,6 +39,7 @@ def __init__(
  gamma: float,
  tau: float,
  lam: float,
+ beta: float,
  device: str,
  ):
  super().__init__(
@@ -51,14 +54,20 @@ def __init__(
  device=device,
  )
  self._lam = lam
+ self._beta = beta
  self._imitator = imitator
  self._imitator_optim = imitator_optim

  @train_api
  def update_imitator(self, batch: TorchMiniBatch) -> float:
  self._imitator_optim.zero_grad()

- loss = self._imitator.compute_error(batch.observations, batch.actions)
+ loss = compute_vae_error(
+     vae=self._imitator,
+     x=batch.observations,
+     action=batch.actions,
+     beta=self._beta,
+ )

  loss.backward()
  self._imitator_optim.step()
@@ -67,11 +76,14 @@ def update_imitator(self, batch: TorchMiniBatch) -> float:

  def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
  latent_actions = 2.0 * self._policy(batch.observations).squashed_mu
- actions = self._imitator.decode(batch.observations, latent_actions)
+ actions = forward_vae_decode(
+     self._imitator, batch.observations, latent_actions
+ )
  return -self._q_func(batch.observations, actions, "none")[0].mean()

  def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
- return self._imitator.decode(x, 2.0 * self._policy(x).squashed_mu)
+ latent_actions = 2.0 * self._policy(x).squashed_mu
+ return forward_vae_decode(self._imitator, x, latent_actions)

  def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor:
  return self.inner_predict_best_action(x)
@@ -81,8 +93,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
  latent_actions = (
  2.0 * self._targ_policy(batch.next_observations).squashed_mu
  )
- actions = self._imitator.decode(
-     batch.next_observations, latent_actions
+ actions = forward_vae_decode(
+     self._imitator, batch.next_observations, latent_actions
  )
  return self._targ_q_func.compute_target(
  batch.next_observations,
@@ -110,6 +122,7 @@ def __init__(
  gamma: float,
  tau: float,
  lam: float,
+ beta: float,
  device: str,
  ):
  super().__init__(
@@ -124,23 +137,27 @@ def __init__(
  gamma=gamma,
  tau=tau,
  lam=lam,
+ beta=beta,
  device=device,
  )
  self._perturbation = perturbation
  self._targ_perturbation = copy.deepcopy(perturbation)

  def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
  latent_actions = 2.0 * self._policy(batch.observations).squashed_mu
- actions = self._imitator.decode(batch.observations, latent_actions)
+ actions = forward_vae_decode(
+     self._imitator, batch.observations, latent_actions
+ )
  residual_actions = self._perturbation(
  batch.observations, actions
  ).squashed_mu
  q_value = self._q_func(batch.observations, residual_actions, "none")
  return -q_value[0].mean()

  def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
- action = self._imitator.decode(x, 2.0 * self._policy(x).squashed_mu)
- return self._perturbation(x, action).squashed_mu
+ latent_actions = 2.0 * self._policy(x).squashed_mu
+ actions = forward_vae_decode(self._imitator, x, latent_actions)
+ return self._perturbation(x, actions).squashed_mu

  def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor:
  return self.inner_predict_best_action(x)
@@ -150,8 +167,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
  latent_actions = (
  2.0 * self._targ_policy(batch.next_observations).squashed_mu
  )
- actions = self._imitator.decode(
-     batch.next_observations, latent_actions
+ actions = forward_vae_decode(
+     self._imitator, batch.next_observations, latent_actions
  )
  residual_actions = self._targ_perturbation(
  batch.next_observations, actions
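In PLAS the deterministic policy acts in the VAE's latent space, and its output is decoded into an environment action. A sketch of that action-selection path after this change; the standalone function and import path are assumptions, and the 2.0 scaling is taken from the hunks above.

import torch

# Assumed import location; only the name appears in the hunk above.
from d3rlpy.models.torch import forward_vae_decode


def plas_select_action(policy, vae, observations):
    # The latent policy outputs a squashed mean in [-1, 1]; scaling by 2.0
    # (as in the hunks above) produces the latent that is decoded through the
    # shared VAE decoder.
    latent_actions = 2.0 * policy(observations).squashed_mu
    return forward_vae_decode(vae, observations, latent_actions)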
15 changes: 5 additions & 10 deletions d3rlpy/algos/qlearning/torch/sac_impl.py
@@ -13,7 +13,6 @@
  EnsembleQFunction,
  Parameter,
  Policy,
- build_categorical_distribution,
  build_squashed_gaussian_distribution,
  )
  from ....torch_utility import TorchMiniBatch, hard_sync, train_api
@@ -155,9 +154,7 @@ def update_critic(self, batch: TorchMiniBatch) -> float:

  def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
  with torch.no_grad():
- dist = build_categorical_distribution(
-     self._policy(batch.next_observations)
- )
+ dist = self._policy(batch.next_observations)
  log_probs = dist.logits
  probs = dist.probs
  entropy = self._log_temp().exp() * log_probs
@@ -200,7 +197,7 @@ def update_actor(self, batch: TorchMiniBatch) -> float:
  def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
  with torch.no_grad():
  q_t = self._q_func(batch.observations, reduction="min")
- dist = build_categorical_distribution(self._policy(batch.observations))
+ dist = self._policy(batch.observations)
  log_probs = dist.logits
  probs = dist.probs
  entropy = self._log_temp().exp() * log_probs
@@ -211,9 +208,7 @@ def update_temp(self, batch: TorchMiniBatch) -> Tuple[float, float]:
  self._temp_optim.zero_grad()

  with torch.no_grad():
- dist = build_categorical_distribution(
-     self._policy(batch.observations)
- )
+ dist = self._policy(batch.observations)
  log_probs = dist.logits
  probs = dist.probs
  expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True)
@@ -231,11 +226,11 @@ def update_temp(self, batch: TorchMiniBatch) -> Tuple[float, float]:
  return float(loss.cpu().detach().numpy()), float(cur_temp)

  def inner_predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
- dist = build_categorical_distribution(self._policy(x))
+ dist = self._policy(x)
  return dist.probs.argmax(dim=1)

  def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor:
- dist = build_categorical_distribution(self._policy(x))
+ dist = self._policy(x)
  return dist.sample()

  def update_target(self) -> None:
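In the discrete SAC implementation the policy forward pass now returns a categorical distribution directly, so the build_categorical_distribution wrapper is dropped. The snippet below only illustrates the distribution interface the code relies on (.logits, .probs, .sample()) using a plain torch.distributions.Categorical as a stand-in; the actual policy class is not shown in these hunks.

import torch

# Stand-in for dist = self._policy(x): any object exposing .logits, .probs,
# and .sample() works the same way.
logits = torch.randn(4, 6)  # batch of 4 states, 6 discrete actions
dist = torch.distributions.Categorical(logits=logits)

log_probs = dist.logits                      # normalized log-probabilities
probs = dist.probs
greedy_actions = probs.argmax(dim=1)         # as in inner_predict_best_action
sampled_actions = dist.sample()              # as in inner_sample_action
expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True)  # as in update_temp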
16 changes: 10 additions & 6 deletions d3rlpy/models/builders.py
@@ -22,6 +22,8 @@
  Parameter,
  ProbablisticRegressor,
  SimplePositionEncoding,
+ VAEDecoder,
+ VAEEncoder,
  ValueFunction,
  compute_output_size,
  )
@@ -191,7 +193,6 @@ def create_conditional_vae(
  observation_shape: Shape,
  action_size: int,
  latent_size: int,
- beta: float,
  encoder_factory: EncoderFactory,
  device: str,
  min_logstd: float = -20.0,
@@ -206,16 +207,19 @@
  hidden_size = compute_output_size(
  [observation_shape, (action_size,)], encoder_encoder, device
  )
- policy = ConditionalVAE(
-     encoder_encoder=encoder_encoder,
-     decoder_encoder=decoder_encoder,
+ encoder = VAEEncoder(
+     encoder=encoder_encoder,
  hidden_size=hidden_size,
  latent_size=latent_size,
- action_size=action_size,
- beta=beta,
  min_logstd=min_logstd,
  max_logstd=max_logstd,
  )
+ decoder = VAEDecoder(
+     encoder=decoder_encoder,
+     hidden_size=hidden_size,
+     action_size=action_size,
+ )
+ policy = ConditionalVAE(encoder=encoder, decoder=decoder)
  policy.to(device)
  return policy

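After this change the builder assembles a VAEEncoder/VAEDecoder pair, wraps them in ConditionalVAE, and no longer takes beta. A minimal usage sketch with hypothetical problem sizes; the import paths (d3rlpy.models.builders for create_conditional_vae, d3rlpy.models.encoders for VectorEncoderFactory) are assumptions, since neither import is shown in this diff.

# Assumed import locations.
from d3rlpy.models.builders import create_conditional_vae
from d3rlpy.models.encoders import VectorEncoderFactory

# Hypothetical problem sizes for illustration.
observation_shape = (8,)
action_size = 3

# beta is gone from the builder; the KL weight is now supplied at loss time
# (see compute_vae_error in the implementation hunks above).
vae = create_conditional_vae(
    observation_shape=observation_shape,
    action_size=action_size,
    latent_size=2 * action_size,
    min_logstd=-4.0,
    max_logstd=15.0,
    encoder_factory=VectorEncoderFactory(),
    device="cpu",
)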
(Diffs for the remaining 5 changed files are not shown.)