[Port HIL-SERL] Add HF vision encoder option in SAC #651

Open · wants to merge 11 commits into base: user/adil-zouitine/2025-1-7-port-hil-serl-new
22 changes: 21 additions & 1 deletion lerobot/common/datasets/factory.py
@@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData

image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
default_tf = OmegaConf.create(
Author comment: Added this to have more flexibility when creating transforms. Lmk if you don't think we need it, or if you have a better idea.

{
"brightness": {"weight": 0.0, "min_max": None},
"contrast": {"weight": 0.0, "min_max": None},
"saturation": {"weight": 0.0, "min_max": None},
"hue": {"weight": 0.0, "min_max": None},
"sharpness": {"weight": 0.0, "min_max": None},
"max_num_transforms": None,
"random_order": False,
"image_size": None,
"interpolation": None,
"normalization_means": None,
"normalization_std": None,
}
)
cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)

image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
@@ -88,6 +104,10 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
interpolation=cfg_tf.interpolation,
normalization_means=cfg_tf.normalization_means,
normalization_std=cfg_tf.normalization_std,
)

if isinstance(cfg.dataset_repo_id, str):
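For context, a minimal sketch (not part of the diff) of how the OmegaConf.merge call above behaves: any field missing from the user's image_transforms config is filled from default_tf, so existing configs that predate the new fields keep working. The partial user_tf below is hypothetical.

from omegaconf import OmegaConf

# Defaults mirror a subset of the dict defined in factory.py above.
default_tf = OmegaConf.create(
    {
        "brightness": {"weight": 0.0, "min_max": None},
        "image_size": None,
        "interpolation": None,
    }
)
# Hypothetical partial user config that only sets the new fields.
user_tf = OmegaConf.create({"image_size": {"height": 224, "width": 224}, "interpolation": "bicubic"})
merged = OmegaConf.merge(default_tf, user_tf)
print(merged.brightness.weight)  # 0.0, filled in from the defaults
print(merged.image_size.height)  # 224, taken from the user config
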
25 changes: 25 additions & 0 deletions lerobot/common/datasets/transforms.py
@@ -150,6 +150,10 @@ def get_image_transforms(
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
interpolation: str | None = None,
image_size: tuple[int, int] | None = None,
normalization_means: list[float] | None = None,
normalization_std: list[float] | None = None,
):
def check_value(name, weight, min_max):
if min_max is not None:
@@ -170,6 +174,18 @@ def check_value(name, weight, min_max):

weights = []
transforms = []
if image_size is not None:
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
if interpolation is None:
# Use BICUBIC as default interpolation
interpolation_mode = v2.InterpolationMode.BICUBIC
elif interpolation in interpolations:
interpolation_mode = v2.InterpolationMode(interpolation)
else:
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +201,15 @@ def check_value(name, weight, min_max):
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
if normalization_means is not None and normalization_std is not None:
# Weight for normalization is always 1
weights.append(1.0)
transforms.append(
v2.Normalize(
mean=normalization_means,
std=normalization_std,
)
)

n_subset = len(transforms)
if max_num_transforms is not None:
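A usage sketch of the extended get_image_transforms signature (illustrative values; the jitter arguments not shown keep their defaults): resize to 128x128 with bicubic interpolation, then normalize with ImageNet statistics. The returned transform is what make_dataset passes to the dataset as image_transforms.

from lerobot.common.datasets.transforms import get_image_transforms

# Illustrative call exercising the new arguments added in this PR.
tf = get_image_transforms(
    image_size=(128, 128),
    interpolation="bicubic",  # must match a torchvision v2.InterpolationMode value
    normalization_means=[0.485, 0.456, 0.406],
    normalization_std=[0.229, 0.224, 0.225],
    max_num_transforms=None,
    random_order=False,
)
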
33 changes: 2 additions & 31 deletions lerobot/common/policies/sac/configuration_sac.py
@@ -19,34 +19,6 @@
from typing import Any


@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)

output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)

# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
from dataclasses import dataclass, field

@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
@@ -67,11 +39,10 @@ class SACConfig:
"observation.environment_state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})

vision_encoder_name: str = field(default="microsoft/resnet-18")
image_encoder_hidden_dim: int = 32
shared_encoder: bool = False
discount: float = 0.99
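For reference, a small sketch of how the new field can be set (illustrative; field names follow the diff). Setting vision_encoder_name to None falls back to the built-in convolutional encoder in modeling_sac.py.

from lerobot.common.policies.sac.configuration_sac import SACConfig

# Illustrative only: pick the default HF ResNet-18 backbone.
cfg = SACConfig(
    input_shapes={"observation.image": [3, 84, 84], "observation.state": [4]},
    output_shapes={"action": [2]},
    vision_encoder_name="microsoft/resnet-18",  # None -> built-in CNN encoder
)
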
79 changes: 57 additions & 22 deletions lerobot/common/policies/sac/modeling_sac.py
@@ -18,7 +18,7 @@
# TODO: (1) better device management

from collections import deque
from typing import Callable, Optional, Sequence, Tuple, Union
from typing import Callable, Optional, Tuple

import einops
import numpy as np
Expand All @@ -27,6 +27,7 @@
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from transformers import AutoModel
Contributor comment: Need to move it under the class definition as with the Reward Classifier - https://github.com/huggingface/lerobot/pull/565/files#diff-160b98695ab8295e4ac586d8b9e50cb8e849b2bc31b2daceb28ded10580ab574R46. The reason is that by default we don't install transformers, so the library will crash without hil-serl deps installation.

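A minimal sketch of the deferred-import pattern the comment refers to (assumptions: the method name comes from this diff; the exact placement of the import is up to the author):

def _load_pretrained_vision_encoder(self):
    """Set up the pretrained CNN encoder; transformers is imported lazily so the
    base install keeps working without the hil-serl extra dependencies."""
    from transformers import AutoModel  # deferred import, only needed on this code path

    self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
    ...
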
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
@@ -430,29 +431,41 @@ def __init__(self, config: SACConfig):
self.config = config

if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
if self.config.vision_encoder_name is not None:
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
self.freeze_encoder()
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
else:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0],
config.image_encoder_hidden_dim,
7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
@@ -466,6 +479,22 @@ def __init__(self, config: SACConfig):
nn.Tanh(),
)

def _load_pretrained_vision_encoder(self):
"""Set up CNN encoder"""
self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape

def freeze_encoder(self):
"""Freeze all parameters in the encoder"""
for param in self.image_enc_layers.parameters():
param.requires_grad = False
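Since freeze_encoder only flips requires_grad, a short illustrative sketch of how a training script could exclude the frozen backbone from optimization (encoder stands for a hypothetical instance of the observation encoder above; not part of this diff):

import torch

# Only parameters that still require grad (e.g. image_enc_proj and the state
# encoders) end up in the optimizer; the frozen HF backbone is skipped.
trainable_params = [p for p in encoder.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=3e-4)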

def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.

@@ -476,7 +505,13 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if self.config.vision_encoder_name is not None:
Contributor comment: Just a small refactoring - let's extract to a separate method has_pretrained_visual_encoder

enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
else:
enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])

feat.append(enc_feat)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
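Following the refactoring suggestion above, one possible shape for the extracted helper (a sketch, not committed code; the property name comes from the review comment):

@property
def has_pretrained_visual_encoder(self) -> bool:
    """True when an HF vision backbone is configured instead of the built-in CNN."""
    return self.config.vision_encoder_name is not None

forward() can then branch on self.has_pretrained_visual_encoder instead of checking self.config.vision_encoder_name is not None at each call site.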