
[Port HIL-SERL] Add HF vision encoder option in SAC #651

Open. Wants to merge 11 commits into base branch user/adil-zouitine/2025-1-7-port-hil-serl-new.
Conversation

@ChorntonYoel ChorntonYoel commented Jan 21, 2025

What this does

Adds a pretrained vision encoder from the Hugging Face Hub to the SAC modeling code, and enables the appropriate image transforms during training.
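A minimal sketch of what such a wrapper could look like (the class name, `freeze` flag, and pooling fallback are assumptions for illustration, not the PR's actual implementation; it assumes `torch` and the optional `transformers` dependency):

```python
import torch
from torch import nn


class HFVisionEncoderWrapper(nn.Module):
    """Hypothetical sketch: a (optionally frozen) Hugging Face vision
    backbone producing one flat feature vector per camera image."""

    def __init__(self, vision_encoder_name: str, freeze: bool = True):
        super().__init__()
        from transformers import AutoModel  # optional dependency, imported lazily

        self.backbone = AutoModel.from_pretrained(vision_encoder_name)
        if freeze:
            self.backbone.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        out = self.backbone(pixel_values=pixel_values)
        # Hub vision models expose either a pooled vector or a token grid;
        # fall back to mean pooling over tokens when no pooled output exists.
        if getattr(out, "pooler_output", None) is not None:
            return out.pooler_output
        return out.last_hidden_state.mean(dim=1)
```

The SAC observation encoder could then call this wrapper once per `observation.image*` key and concatenate the resulting feature vectors with the state features.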

How it was tested

Ran a SAC training with moss.
My SAC training script and associated modeling code differ slightly from this branch, so it would be good to double-check with a training loop that incorporates real images and the current SAC training script.

@ChorntonYoel ChorntonYoel marked this pull request as draft January 21, 2025 12:16
@@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData

image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
default_tf = OmegaConf.create(
@ChorntonYoel (Author)
Added this to have more flexibility when creating transforms. Let me know if you don't think we need it, or if you have a better idea.

@ChorntonYoel ChorntonYoel marked this pull request as ready for review January 21, 2025 15:29
@helper2424 (Contributor) left a comment

The code looks fine in general. I haven't checked it manually since my environment is a little broken; I will try to do it tomorrow.

@@ -27,6 +27,7 @@
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from transformers import AutoModel
@helper2424 (Contributor)
This import needs to be moved under the class definition, as was done for the reward classifier: https://github.com/huggingface/lerobot/pull/565/files#diff-160b98695ab8295e4ac586d8b9e50cb8e849b2bc31b2daceb28ded10580ab574R46. The reason is that transformers is not installed by default, so the library would crash on import unless the hil-serl dependencies are installed.
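The deferred-import pattern the reviewer suggests can be sketched as follows (the class name and error message are illustrative, not the repository's actual code):

```python
class PretrainedImageEncoder:
    """Sketch of the deferred-import pattern.

    Importing transformers inside __init__ rather than at module top level
    means this module can still be imported when the optional hil-serl
    dependencies are absent; the error only surfaces if a user actually
    requests a pretrained encoder.
    """

    def __init__(self, vision_encoder_name: str):
        try:
            from transformers import AutoModel  # optional dependency
        except ImportError as e:
            raise ImportError(
                "Using a pretrained vision encoder requires the optional "
                "`transformers` dependency (install the hil-serl extras)."
            ) from e
        self.model = AutoModel.from_pretrained(vision_encoder_name)
```

With the import at module scope instead, merely importing the SAC modeling file would fail in a default installation.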

@@ -476,7 +505,13 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if self.config.vision_encoder_name is not None:
@helper2424 (Contributor)
Just a small refactoring: let's extract this check into a separate method, has_pretrained_visual_encoder.
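The suggested extraction might look like this (the `SACConfig` dataclass here is a hypothetical, trimmed stand-in for the real config class; only the `vision_encoder_name` field is taken from the diff):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SACConfig:
    # Hypothetical subset of the real SAC config.
    vision_encoder_name: Optional[str] = None


class SACObservationEncoder:
    def __init__(self, config: SACConfig):
        self.config = config

    @property
    def has_pretrained_visual_encoder(self) -> bool:
        # Extracted check, per the review suggestion: True when a Hub
        # encoder name is configured, False for the built-in CNN path.
        return self.config.vision_encoder_name is not None
```

The forward pass can then branch on `self.has_pretrained_visual_encoder` instead of repeating the `is not None` comparison inline.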
