[Port HIL-SERL] Add HF vision encoder option in SAC #651
base: user/adil-zouitine/2025-1-7-port-hil-serl-new
Conversation
@@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
    image_transforms = None
    if cfg.training.image_transforms.enable:
        cfg_tf = cfg.training.image_transforms
        default_tf = OmegaConf.create(
Added this to have more flexibility when creating transforms. Let me know if you think we don't need it, or if you have a better idea.
The code looks fine in general. I haven't checked it manually, as my env is a bit broken - I'll try to do it tomorrow.
@@ -27,6 +27,7 @@
import torch.nn.functional as F  # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from transformers import AutoModel
Need to move it under the class definition, as with the Reward Classifier - https://github.com/huggingface/lerobot/pull/565/files#diff-160b98695ab8295e4ac586d8b9e50cb8e849b2bc31b2daceb28ded10580ab574R46. The reason is that by default we don't install transformers, so the library will crash without the hil-serl deps installed.
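A sketch of the deferred-import pattern being suggested (the function name and error message are illustrative, not from this PR):

```python
def build_pretrained_image_encoder(vision_encoder_name: str):
    """Load an HF vision backbone lazily.

    Importing transformers inside the function instead of at module top
    means the module still imports cleanly when transformers is absent;
    only the code path that actually needs the backbone can fail.
    """
    try:
        from transformers import AutoModel  # deferred on purpose
    except ImportError as e:
        # Hypothetical error message for illustration.
        raise ImportError(
            "vision_encoder_name requires the transformers dependency "
            "(installed with the hil-serl extras)."
        ) from e
    return AutoModel.from_pretrained(vision_encoder_name)
```

With this, a plain install never touches the `transformers` import unless `vision_encoder_name` is actually set.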
@@ -476,7 +505,13 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
    # Concatenate all images along the channel dimension.
    image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
    for image_key in image_keys:
        feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
        if self.config.vision_encoder_name is not None:
Just a small refactoring - let's extract this check into a separate method, has_pretrained_visual_encoder.
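The suggested refactor could look roughly like this (the class and config here are a minimal stand-in skeleton, not the PR's actual code):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SACConfig:  # minimal stand-in for the real config
    vision_encoder_name: Optional[str] = None


class SACObservationEncoder:  # skeleton, not the real class
    def __init__(self, config: SACConfig):
        self.config = config

    @property
    def has_pretrained_visual_encoder(self) -> bool:
        # Centralizes the `vision_encoder_name is not None` check so
        # forward() and setup code read the intent, not the condition.
        return self.config.vision_encoder_name is not None


enc = SACObservationEncoder(SACConfig(vision_encoder_name="helper/dummy-model"))
print(enc.has_pretrained_visual_encoder)  # True
```

A named property also gives a single place to extend the condition later (e.g. if the check ever involves more than one config field).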
What this does
Adds a pretrained vision encoder from HF to the SAC modeling and allows for appropriate transforms when training.
How it was tested
Ran a SAC training with moss.
My SAC training script and associated modeling have slight differences, so it would be good to double-check with a training loop that incorporates real images and the current SAC training script.