HIL-SERL fixes cleanup #642

Draft · wants to merge 4 commits into base: user/michel-aractingi/2024-11-27-port-hil-serl

Changes from all commits
168 changes: 110 additions & 58 deletions lerobot/common/policies/sac/modeling_sac.py
@@ -57,12 +57,20 @@ def __init__(
)
else:
self.normalize_inputs = nn.Identity()
# HACK: we need to pass the dataset_stats to the normalization functions
dataset_stats = dataset_stats or {
"action": {
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
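For context, a minimal standalone sketch of what min/max normalization does with the fallback stats above, assuming the output normalization mode is min/max scaling into [-1, 1] (the function names are illustrative, not the library API):

```python
import torch

stats = {
    "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
    "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}

def normalize_min_max(x: torch.Tensor) -> torch.Tensor:
    # Map x from [min, max] into [-1, 1].
    return 2.0 * (x - stats["min"]) / (stats["max"] - stats["min"]) - 1.0

def unnormalize_min_max(x: torch.Tensor) -> torch.Tensor:
    # Inverse map from [-1, 1] back into [min, max].
    return (x + 1.0) / 2.0 * (stats["max"] - stats["min"]) + stats["min"]

action = torch.tensor([0.5, -0.25, 1.0, -1.0])
assert torch.allclose(unnormalize_min_max(normalize_min_max(action)), action)
```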

encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
# Define networks
@@ -78,7 +86,8 @@ def __init__(
critic_nets.append(critic_net)

self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
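The change above builds the target critics as independent modules (rather than a deepcopy) and then copies the online weights into them. A minimal sketch of the same idiom, with hypothetical network names:

```python
import torch
import torch.nn as nn

def make_critic() -> nn.Module:
    # Stand-in for a critic network (hypothetical architecture).
    return nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))

online = nn.ModuleList([make_critic() for _ in range(2)])
target = nn.ModuleList([make_critic() for _ in range(2)])  # fresh, different weights
target.load_state_dict(online.state_dict())                # now identical to online

# Assumption about the surrounding training loop: targets change only through
# soft updates, never via the optimizer.
for p in target.parameters():
    p.requires_grad_(False)

assert all(
    torch.equal(a, b)
    for a, b in zip(online.state_dict().values(), target.state_dict().values())
)
```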

self.actor = Policy(
encoder=encoder_actor,
@@ -136,19 +145,23 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:

Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# We have to refresh the value of the temperature because log_alpha may have been updated in the previous optimization step
self.temperature = self.log_alpha.exp().item()
temperature = self.temperature

batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where index 0 along dim 1 is the current observation
# and index 1 is the next observation, used to compute the TD target.
# actions = batch["action"][:, 0]
actions = batch["action"]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
for k in batch:
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]

# TODO: perform image augmentation on the observations
done = batch["next.done"]

# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
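For reference, a toy sketch of how the stacked (b, 2, ...) layout above is split into current and next observations (tensor shapes are illustrative):

```python
import torch

b = 8
batch = {
    "observation.state": torch.randn(b, 2, 6),   # (batch, [current, next], state_dim)
    "action": torch.randn(b, 4),
    "next.reward": torch.randn(b, 2),
    "next.done": torch.zeros(b, dtype=torch.bool),
}
observations = {k: v[:, 0] for k, v in batch.items() if k.startswith("observation.")}
next_observations = {k: v[:, 1] for k, v in batch.items() if k.startswith("observation.")}
rewards = batch["next.reward"][:, 0]

assert observations["observation.state"].shape == (b, 6)
assert next_observations["observation.state"].shape == (b, 6)
```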
@@ -166,11 +179,11 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
indices = indices[:self.config.num_subsample_critics]
q_targets = q_targets[indices]

# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation

# compute td target
if self.config.use_backup_entropy:
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done

# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
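As a reference for the section above, a self-contained sketch of the clipped double-Q TD target with the optional entropy backup term (all names and default values here are illustrative):

```python
import torch

def td_target(rewards, q_targets, next_log_probs, done,
              discount=0.99, temperature=0.1, use_backup_entropy=True):
    # q_targets: (num_critics, batch); clipped double-Q takes the min over critics.
    min_q, _ = q_targets.min(dim=0)
    if use_backup_entropy:
        min_q = min_q - temperature * next_log_probs
    # y = r + gamma * (1 - done) * (min_i Q_i(s', a') - alpha * log pi(a'|s'))
    return rewards + discount * min_q * ~done

y = td_target(
    rewards=torch.randn(8),
    q_targets=torch.randn(2, 8),
    next_log_probs=torch.randn(8),
    done=torch.zeros(8, dtype=torch.bool),
)
assert y.shape == (8,)
```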
@@ -183,33 +196,10 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
reduction="none"
).sum(0).mean()

# critics_loss = (
# F.mse_loss(
# q_preds,
# einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
# reduction="none",
# ).sum(0) # sum over ensemble
# # `q_preds_ensemble` depends on the first observation and the actions.
# * ~batch["observation.state_is_pad"][0]
# * ~batch["action_is_pad"]
# # q_targets depends on the reward and the next observations.
# * ~batch["next.reward_is_pad"]
# * ~batch["observation.state_is_pad"][1:]
# ).sum(0).mean()

# calculate actor's loss
# 1- get actions (batch_size, action_dim) and log probs (batch_size,)
actions_pi, log_probs, _ = self.actor(observations)
# 2- get q-value predictions
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
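The actor objective used above (and spelled out again in compute_loss_actor below) minimizes alpha * log_prob minus the minimum ensemble Q value; a tiny standalone sketch:

```python
import torch

def actor_loss(min_q_preds: torch.Tensor, log_probs: torch.Tensor, temperature: float) -> torch.Tensor:
    # Minimizing alpha * log_prob - Q pushes the policy toward high-value,
    # high-entropy actions.
    return ((temperature * log_probs) - min_q_preds).mean()

print(actor_loss(min_q_preds=torch.randn(8), log_probs=torch.randn(8), temperature=0.1))
```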


# calculate temperature loss
@@ -223,29 +213,91 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
loss = critics_loss + actor_loss + temperature_loss

return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}

# TODO (from the removed update() stub): implement the UTD (update-to-data) schedule:
# first update only the critics and critic targets for utd_ratio - 1 steps,
# then update the critics, critic targets, actor, and temperature together.

def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
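A quick numeric check of the exponential moving average used here, assuming critic_target_update_weight plays the role of tau:

```python
import torch

tau = 0.005  # stands in for config.critic_target_update_weight
online = torch.tensor(1.0)
target = torch.tensor(0.0)
for _ in range(2000):
    # target <- tau * online + (1 - tau) * target
    target = tau * online + (1 - tau) * target
print(target)  # approaches 1.0: the target slowly tracks the online parameter
```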

def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)

# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)

# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]

# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (temperature * next_log_probs)

td_target = rewards + (1 - done) * self.config.discount * min_q

# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)

# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute the mean loss over the batch for each critic, then sum across critics to obtain the final loss
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
return critics_loss
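A small check that this reduction (per-critic mean over the batch, then sum over critics) matches summing each critic's average TD error:

```python
import torch
import torch.nn.functional as F

q_preds = torch.randn(2, 8)   # (num_critics, batch)
td_target = torch.randn(8)
td_dup = td_target.unsqueeze(0).expand_as(q_preds)

loss_a = F.mse_loss(q_preds, td_dup, reduction="none").mean(1).sum()
loss_b = sum(F.mse_loss(q_preds[i], td_target) for i in range(q_preds.shape[0]))
assert torch.allclose(loss_a, loss_b)
```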

def compute_loss_temperature(self, observations) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
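For reference, a standalone sketch of the automatic temperature tuning objective used here; the target_entropy value below is an assumption (a common heuristic is -action_dim):

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)
log_probs = torch.randn(8)   # sampled under no_grad, as above
target_entropy = -4.0        # assumption: -action_dim heuristic

temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy)).mean()
temperature_loss.backward()
# The gradient is negative when the policy's entropy is below target,
# so gradient descent increases alpha (more exploration), and vice versa.
print(log_alpha.grad)
```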

def compute_loss_actor(self, observations) -> Tensor:
temperature = self.log_alpha.exp().item()

actions_pi, log_probs, _ = self.actor(observations)

q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]

actor_loss = ((temperature * log_probs) - min_q_preds).mean()
return actor_loss
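A hypothetical sketch of how these loss methods could be wired together in a trainer; the function and optimizer names are assumptions, since the training loop lives outside this file:

```python
def training_step(policy, batch, optimizers) -> None:
    """Hypothetical trainer wiring for the loss methods above (an assumption,
    not part of this diff)."""
    observations = {k: v[:, 0] for k, v in batch.items() if k.startswith("observation.")}
    next_observations = {k: v[:, 1] for k, v in batch.items() if k.startswith("observation.")}

    loss = policy.compute_loss_critic(
        observations, batch["action"], batch["next.reward"][:, 0],
        next_observations, batch["next.done"],
    )
    optimizers["critic"].zero_grad()
    loss.backward()
    optimizers["critic"].step()

    loss = policy.compute_loss_actor(observations)
    optimizers["actor"].zero_grad()
    loss.backward()
    optimizers["actor"].step()

    loss = policy.compute_loss_temperature(observations)
    optimizers["temperature"].zero_grad()
    loss.backward()
    optimizers["temperature"].step()

    policy.update_target_networks()
```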


class MLP(nn.Module):
def __init__(
self,