-
Notifications
You must be signed in to change notification settings - Fork 867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Port HIL-SERL] Add HF vision encoder option in SAC #651
base: user/adil-zouitine/2025-1-7-port-hil-serl-new
Are you sure you want to change the base?
Changes from 7 commits
3bc3fe3
6836b3a
b4d18e7
8f2678b
c62052a
f3d6f97
ff79a27
3d56189
933b374
3893069
f2f77ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
# TODO: (1) better device management | ||
|
||
from collections import deque | ||
from typing import Callable, Optional, Sequence, Tuple, Union | ||
from typing import Callable, Optional, Tuple | ||
|
||
import einops | ||
import numpy as np | ||
|
@@ -27,6 +27,7 @@ | |
import torch.nn.functional as F # noqa: N812 | ||
from huggingface_hub import PyTorchModelHubMixin | ||
from torch import Tensor | ||
from transformers import AutoModel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to move it under the class definition as with the |
||
|
||
from lerobot.common.policies.normalize import Normalize, Unnormalize | ||
from lerobot.common.policies.sac.configuration_sac import SACConfig | ||
|
@@ -430,29 +431,41 @@ def __init__(self, config: SACConfig): | |
self.config = config | ||
|
||
if "observation.image" in config.input_shapes: | ||
self.image_enc_layers = nn.Sequential( | ||
nn.Conv2d( | ||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 | ||
), | ||
nn.ReLU(), | ||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), | ||
nn.ReLU(), | ||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), | ||
nn.ReLU(), | ||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), | ||
nn.ReLU(), | ||
) | ||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) | ||
with torch.inference_mode(): | ||
out_shape = self.image_enc_layers(dummy_batch).shape[1:] | ||
self.image_enc_layers.extend( | ||
nn.Sequential( | ||
nn.Flatten(), | ||
nn.Linear(np.prod(out_shape), config.latent_dim), | ||
if self.config.vision_encoder_name is not None: | ||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder() | ||
self.freeze_encoder() | ||
self.image_enc_proj = nn.Sequential( | ||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), | ||
nn.LayerNorm(config.latent_dim), | ||
nn.Tanh(), | ||
) | ||
) | ||
else: | ||
self.image_enc_layers = nn.Sequential( | ||
nn.Conv2d( | ||
config.input_shapes["observation.image"][0], | ||
config.image_encoder_hidden_dim, | ||
7, | ||
stride=2, | ||
), | ||
nn.ReLU(), | ||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), | ||
nn.ReLU(), | ||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), | ||
nn.ReLU(), | ||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), | ||
nn.ReLU(), | ||
) | ||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) | ||
with torch.inference_mode(): | ||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:] | ||
self.image_enc_layers.extend( | ||
nn.Sequential( | ||
nn.Flatten(), | ||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), | ||
nn.LayerNorm(config.latent_dim), | ||
nn.Tanh(), | ||
) | ||
) | ||
if "observation.state" in config.input_shapes: | ||
self.state_enc_layers = nn.Sequential( | ||
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim), | ||
|
@@ -466,6 +479,22 @@ def __init__(self, config: SACConfig): | |
nn.Tanh(), | ||
) | ||
|
||
def _load_pretrained_vision_encoder(self): | ||
"""Set up CNN encoder""" | ||
self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name) | ||
if hasattr(self.image_enc_layers.config, "hidden_sizes"): | ||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension | ||
elif hasattr(self.image_enc_layers, "fc"): | ||
self.image_enc_out_shape = self.image_enc_layers.fc.in_features | ||
else: | ||
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN") | ||
return self.image_enc_layers, self.image_enc_out_shape | ||
|
||
def freeze_encoder(self): | ||
"""Freeze all parameters in the encoder""" | ||
for param in self.image_enc_layers.parameters(): | ||
param.requires_grad = False | ||
|
||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: | ||
"""Encode the image and/or state vector. | ||
|
||
|
@@ -476,7 +505,13 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: | |
# Concatenate all images along the channel dimension. | ||
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")] | ||
for image_key in image_keys: | ||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])) | ||
if self.config.vision_encoder_name is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a small refactoring - let's extract to a separate method |
||
enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output | ||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1)) | ||
else: | ||
enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]) | ||
|
||
feat.append(enc_feat) | ||
if "observation.environment_state" in self.config.input_shapes: | ||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) | ||
if "observation.state" in self.config.input_shapes: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this to have more flexibility when creating transforms. Lmk if you don't think we need it, or if you have a better idea