-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Action Chunking with Transformers (ACT) to baselines #640
Open
ywchoi02
wants to merge
19
commits into
haosulab:main
Choose a base branch
from
ywchoi02:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 15 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
43d027b
Create train.py
ywchoi02 cc0cb60
Create train_rgb.py
ywchoi02 e154fe3
Create evaluate.py
ywchoi02 2e126e7
Create make_env.py
ywchoi02 7abc4cb
Create utils.py
ywchoi02 f48fdc0
Create backbone.py
ywchoi02 aaae4b7
Create detr_vae.py
ywchoi02 07b5741
Create position_encoding.py
ywchoi02 1f89c29
Create transformer.py
ywchoi02 479841f
Update evaluate.py
ywchoi02 bdd1660
Create README.md
ywchoi02 09d5411
Changed to absolute import
ywchoi02 622a36e
Update absolute import
ywchoi02 4932d9b
change to absolute import
ywchoi02 37eb2af
change to absolute import
ywchoi02 e5ebaa4
fix import
ywchoi02 e5e0aeb
fix import
ywchoi02 35979e9
fix import
ywchoi02 04972f3
Create setup.py
ywchoi02 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Action Chunking with Transformers (ACT) | ||
|
||
Code for running the ACT algorithm based on ["Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware"](https://arxiv.org/pdf/2304.13705). It is adapted from the [original code](https://github.com/tonyzhaozh/act). | ||
|
||
## Installation | ||
|
||
To get started, we recommend using conda/mamba to create a new environment and install the dependencies | ||
|
||
```bash | ||
conda create -n act-ms python=3.9 | ||
conda activate act-ms | ||
pip install -e . | ||
``` | ||
|
||
## Demonstration Download and Preprocessing | ||
|
||
By default for fast downloads and smaller file sizes, ManiSkill demonstrations are stored in a highly reduced/compressed format which includes not keeping any observation data. Run the command to download the demonstration and convert it to a format that includes observation data and the desired action space. | ||
|
||
```bash | ||
python -m mani_skill.utils.download_demo "PickCube-v1" | ||
``` | ||
|
||
```bash | ||
python -m mani_skill.trajectory.replay_trajectory \ | ||
--traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \ | ||
--use-first-env-state -c pd_ee_delta_pos -o state \ | ||
--save-traj --num-procs 10 | ||
``` | ||
|
||
Set -o to rgbd for RGBD observations. Note that the control mode can heavily influence how well Behavior Cloning performs. In the paper, they reported a degraded performance when using delta joint positions as actions instead of target joint positions. By default, we recommend using `pd_joint_delta_pos` for control mode as all tasks can be solved with that control mode, although it is harder to learn with BC than `pd_ee_delta_pos` or `pd_ee_delta_pose` for robots that have those control modes. Finally, the type of demonstration data used can also impact performance, with typically neural network generated demonstrations being easier to learn from than human/motion planning generated demonstrations. | ||
|
||
## Training | ||
|
||
We provide scripts to train ACT on demonstrations. Make sure to use the same sim backend as the backend the demonstrations were collected with. | ||
|
||
|
||
Note that some demonstrations are slow (e.g. motion planning or human teleoperated) and can exceed the default max episode steps which can be an issue as imitation learning algorithms learn to solve the task at the same speed the demonstrations solve it. In this case, you can use the `--max-episode-steps` flag to set a higher value so that the policy can solve the task in time. General recommendation is to set `--max-episode-steps` to about 2x the length of the mean demonstrations length you are using for training. We provide recommended numbers for demonstrations in the examples.sh script. | ||
|
||
Example training, learning from 100 demonstrations generated via motionplanning in the PickCube-v1 task | ||
```bash | ||
python train.py --env-id PickCube-v1 \ | ||
--demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5 \ | ||
--control-mode "pd_ee_delta_pos" --sim-backend "cpu" --num-demos 100 --max_episode_steps 100 \ | ||
--total_iters 30000 | ||
``` | ||
|
||
|
||
## Train and Evaluate with GPU Simulation | ||
|
||
You can also choose to train on trajectories generated in the GPU simulation and evaluate much faster with the GPU simulation. However as most demonstrations are usually generated in the CPU simulation (via motionplanning or teleoperation), you may observe worse performance when evaluating on the GPU simulation vs the CPU simulation. This can be partially alleviated by using the replay trajectory tool to try and replay trajectories back in the GPU simulation. | ||
|
||
It is also recommended to not save videos if you are using a lot of parallel environments as the video size can get very large. | ||
|
||
To replay trajectories in the GPU simulation, you can use the following command. Note that this can be a bit slow as the replay trajectory tool is currently not optimized for GPU parallelized environments. | ||
|
||
```bash | ||
python -m mani_skill.trajectory.replay_trajectory \ | ||
--traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \ | ||
--use-first-env-state -c pd_ee_delta_pos -o state \ | ||
--save-traj --num-procs 1 -b gpu --count 100 # process only 100 trajectories | ||
``` | ||
|
||
Once our GPU backend demonstration dataset is ready, you can use the following command to train and evaluate on the GPU simulation. | ||
|
||
```bash | ||
python train.py --env-id PickCube-v1 \ | ||
--demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5 \ | ||
--control-mode "pd_ee_delta_pos" --sim-backend "gpu" --num-demos 100 --max_episode_steps 100 \ | ||
--total_iters 30000 \ | ||
--num-eval-envs 100 --no-capture-video | ||
``` | ||
|
||
## Citation | ||
|
||
If you use this baseline please cite the following | ||
``` | ||
@inproceedings{DBLP:conf/rss/ZhaoKLF23, | ||
author = {Tony Z. Zhao and | ||
Vikash Kumar and | ||
Sergey Levine and | ||
Chelsea Finn}, | ||
editor = {Kostas E. Bekris and | ||
Kris Hauser and | ||
Sylvia L. Herbert and | ||
Jingjin Yu}, | ||
title = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware}, | ||
booktitle = {Robotics: Science and Systems XIX, Daegu, Republic of Korea, July | ||
10-14, 2023}, | ||
year = {2023}, | ||
url = {https://doi.org/10.15607/RSS.2023.XIX.016}, | ||
doi = {10.15607/RSS.2023.XIX.016}, | ||
timestamp = {Thu, 20 Jul 2023 15:37:49 +0200}, | ||
biburl = {https://dblp.org/rec/conf/rss/ZhaoKLF23.bib}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
""" | ||
Backbone modules. | ||
""" | ||
from collections import OrderedDict | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import torchvision | ||
from torch import nn | ||
from torchvision.models._utils import IntermediateLayerGetter | ||
from typing import Dict, List | ||
|
||
from examples.baselines.act.act.utils import NestedTensor, is_main_process | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. imports should be absolute and relative to act (which you pip install -e .) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated my code. |
||
from examples.baselines.act.act.detr.position_encoding import build_position_encoding | ||
|
||
import IPython | ||
e = IPython.embed | ||
|
||
class FrozenBatchNorm2d(torch.nn.Module): | ||
""" | ||
BatchNorm2d where the batch statistics and the affine parameters are fixed. | ||
|
||
Copy-paste from torchvision.misc.ops with added eps before rqsrt, | ||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] | ||
produce nans. | ||
""" | ||
|
||
def __init__(self, n): | ||
super(FrozenBatchNorm2d, self).__init__() | ||
self.register_buffer("weight", torch.ones(n)) | ||
self.register_buffer("bias", torch.zeros(n)) | ||
self.register_buffer("running_mean", torch.zeros(n)) | ||
self.register_buffer("running_var", torch.ones(n)) | ||
|
||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | ||
missing_keys, unexpected_keys, error_msgs): | ||
num_batches_tracked_key = prefix + 'num_batches_tracked' | ||
if num_batches_tracked_key in state_dict: | ||
del state_dict[num_batches_tracked_key] | ||
|
||
super(FrozenBatchNorm2d, self)._load_from_state_dict( | ||
state_dict, prefix, local_metadata, strict, | ||
missing_keys, unexpected_keys, error_msgs) | ||
|
||
def forward(self, x): | ||
# move reshapes to the beginning | ||
# to make it fuser-friendly | ||
w = self.weight.reshape(1, -1, 1, 1) | ||
b = self.bias.reshape(1, -1, 1, 1) | ||
rv = self.running_var.reshape(1, -1, 1, 1) | ||
rm = self.running_mean.reshape(1, -1, 1, 1) | ||
eps = 1e-5 | ||
scale = w * (rv + eps).rsqrt() | ||
bias = b - rm * scale | ||
return x * scale + bias | ||
|
||
|
||
class BackboneBase(nn.Module): | ||
|
||
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): | ||
super().__init__() | ||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? | ||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: | ||
# parameter.requires_grad_(False) | ||
if return_interm_layers: | ||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} | ||
else: | ||
return_layers = {'layer4': "0"} | ||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) | ||
self.num_channels = num_channels | ||
|
||
def forward(self, tensor): | ||
xs = self.body(tensor) | ||
return xs | ||
# out: Dict[str, NestedTensor] = {} | ||
# for name, x in xs.items(): | ||
# m = tensor_list.mask | ||
# assert m is not None | ||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] | ||
# out[name] = NestedTensor(x, mask) | ||
# return out | ||
|
||
|
||
class Backbone(BackboneBase): | ||
"""ResNet backbone with frozen BatchNorm.""" | ||
def __init__(self, name: str, | ||
train_backbone: bool, | ||
return_interm_layers: bool, | ||
dilation: bool): | ||
backbone = getattr(torchvision.models, name)( | ||
replace_stride_with_dilation=[False, False, dilation], | ||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? | ||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 | ||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers) | ||
|
||
|
||
class Joiner(nn.Sequential): | ||
def __init__(self, backbone, position_embedding): | ||
super().__init__(backbone, position_embedding) | ||
|
||
def forward(self, tensor_list: NestedTensor): | ||
xs = self[0](tensor_list) | ||
out: List[NestedTensor] = [] | ||
pos = [] | ||
for name, x in xs.items(): | ||
out.append(x) | ||
# position encoding | ||
pos.append(self[1](x).to(x.dtype)) | ||
|
||
return out, pos | ||
|
||
|
||
def build_backbone(args): | ||
position_embedding = build_position_encoding(args) | ||
train_backbone = args.lr_backbone > 0 | ||
return_interm_layers = args.masks | ||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) | ||
model = Joiner(backbone, position_embedding) | ||
model.num_channels = backbone.num_channels | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
""" | ||
DETR model and criterion classes. | ||
""" | ||
import torch | ||
from torch import nn | ||
from torch.autograd import Variable | ||
from examples.baselines.act.act.detr.transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer | ||
|
||
import numpy as np | ||
|
||
import IPython | ||
e = IPython.embed | ||
|
||
|
||
def reparametrize(mu, logvar): | ||
std = logvar.div(2).exp() | ||
eps = Variable(std.data.new(std.size()).normal_()) | ||
return mu + std * eps | ||
|
||
|
||
def get_sinusoid_encoding_table(n_position, d_hid): | ||
def get_position_angle_vec(position): | ||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] | ||
|
||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) | ||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | ||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | ||
|
||
return torch.FloatTensor(sinusoid_table).unsqueeze(0) | ||
|
||
|
||
class DETRVAE(nn.Module): | ||
""" This is the DETR module that performs object detection """ | ||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries): | ||
super().__init__() | ||
self.num_queries = num_queries | ||
self.transformer = transformer | ||
self.encoder = encoder | ||
hidden_dim = transformer.d_model | ||
self.action_head = nn.Linear(hidden_dim, action_dim) | ||
self.query_embed = nn.Embedding(num_queries, hidden_dim) | ||
if backbones is not None: | ||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) | ||
self.backbones = nn.ModuleList(backbones) | ||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) | ||
else: | ||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) | ||
self.backbones = None | ||
|
||
# encoder extra parameters | ||
self.latent_dim = 32 # size of latent z | ||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding | ||
self.encoder_state_proj = nn.Linear(state_dim, hidden_dim) # project state to embedding | ||
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding | ||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var | ||
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], state, actions | ||
|
||
# decoder extra parameters | ||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding | ||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for state and proprio | ||
|
||
def forward(self, obs, actions=None): | ||
is_training = actions is not None | ||
state = obs['state'] if self.backbones is not None else obs | ||
bs = state.shape[0] | ||
|
||
if is_training: | ||
# project CLS token, state sequence, and action sequence to embedding dim | ||
cls_embed = self.cls_embed.weight # (1, hidden_dim) | ||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) | ||
state_embed = self.encoder_state_proj(state) # (bs, hidden_dim) | ||
state_embed = torch.unsqueeze(state_embed, axis=1) # (bs, 1, hidden_dim) | ||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) | ||
# concat them together to form an input to the CVAE encoder | ||
encoder_input = torch.cat([cls_embed, state_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim) | ||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim) | ||
# no masking is applied to all parts of the CVAE encoder input | ||
is_pad = torch.full((bs, encoder_input.shape[0]), False).to(state.device) # False: not a padding | ||
# obtain position embedding | ||
pos_embed = self.pos_table.clone().detach() | ||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim) | ||
# query CVAE encoder | ||
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) | ||
encoder_output = encoder_output[0] # take cls output only | ||
latent_info = self.latent_proj(encoder_output) | ||
mu = latent_info[:, :self.latent_dim] | ||
logvar = latent_info[:, self.latent_dim:] | ||
latent_sample = reparametrize(mu, logvar) | ||
latent_input = self.latent_out_proj(latent_sample) | ||
else: | ||
mu = logvar = None | ||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(state.device) | ||
latent_input = self.latent_out_proj(latent_sample) | ||
|
||
# CVAE decoder | ||
if self.backbones is not None: | ||
vis_data = obs['rgb'] if "rgb" in obs else obs['rgbd'] | ||
num_cams = vis_data.shape[1] | ||
|
||
# Image observation features and position embeddings | ||
all_cam_features = [] | ||
all_cam_pos = [] | ||
for cam_id in range(num_cams): | ||
features, pos = self.backbones[0](vis_data[:, cam_id]) # HARDCODED | ||
features = features[0] # take the last layer feature # (batch, hidden_dim, H, W) | ||
pos = pos[0] # (1, hidden_dim, H, W) | ||
all_cam_features.append(self.input_proj(features)) | ||
all_cam_pos.append(pos) | ||
|
||
# proprioception features (state) | ||
proprio_input = self.input_proj_robot_state(state) | ||
# fold camera dimension into width dimension | ||
src = torch.cat(all_cam_features, axis=3) # (batch, hidden_dim, 4, 8) | ||
pos = torch.cat(all_cam_pos, axis=3) # (batch, hidden_dim, 4, 8) | ||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] # (batch, num_queries, hidden_dim) | ||
else: | ||
state = self.input_proj_robot_state(state) | ||
hs = self.transformer(None, None, self.query_embed.weight, None, latent_input, state, self.additional_pos_embed.weight)[0] | ||
|
||
a_hat = self.action_head(hs) | ||
return a_hat, [mu, logvar] | ||
|
||
|
||
def build_encoder(args): | ||
d_model = args.hidden_dim # 256 | ||
dropout = args.dropout # 0.1 | ||
nhead = args.nheads # 8 | ||
dim_feedforward = args.dim_feedforward # 2048 | ||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder | ||
normalize_before = args.pre_norm # False | ||
activation = "relu" | ||
|
||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, | ||
dropout, activation, normalize_before) | ||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None | ||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) | ||
|
||
return encoder |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so the conda env act-ms is created and you do a local pip install. However a simple setup.py file is still missing, can you create that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created a simple setup.py.