Adding Action Chunking with Transformers (ACT) to baselines #640

Open
wants to merge 19 commits into main
122 changes: 122 additions & 0 deletions examples/baselines/act/act/detr/backbone.py
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from ..utils import NestedTensor, is_main_process
Member:

use absolute imports when possible, it is just the style choice this repo uses.

Author:

I changed them to absolute imports, but I'm not sure they are correct; please let me know if they need to be fixed. I also added a README file.
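For reference, an absolute-import spelling of this file's relative imports might look like the lines below, assuming the package is importable as `act` (inferred from the directory layout, not confirmed by this diff):

from act.utils import NestedTensor, is_main_process
from act.detr.position_encoding import build_position_encoding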


from .position_encoding import build_position_encoding

import IPython
e = IPython.embed  # debugging hook: calling e() drops into an IPython shell

class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any model other than torchvision.models.resnet[18,34,50,101]
    produces NaNs.
"""

def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]

super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
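        # algebraically: y = (x - running_mean) / sqrt(running_var + eps) * weight + bias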
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias


class BackboneBase(nn.Module):

def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels

def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out


class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained weights fetched on the main process only # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)

def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))

return out, pos


def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
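For reference, a minimal sketch of the `args` fields `build_backbone` reads; the names are inferred from the attribute accesses above, and the values are illustrative assumptions rather than this PR's actual config:

import argparse

args = argparse.Namespace(
    hidden_dim=256,             # transformer width; the position encoding uses hidden_dim // 2
    position_embedding='sine',  # 'sine'/'v2' or 'learned'/'v3'
    lr_backbone=1e-5,           # > 0 marks the backbone as trainable
    masks=False,                # True would return intermediate ResNet layers
    backbone='resnet18',
    dilation=False,
)
model = build_backbone(args)
print(model.num_channels)       # 512 for resnet18/34, 2048 for deeper ResNets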
139 changes: 139 additions & 0 deletions examples/baselines/act/act/detr/detr_vae.py
@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np

import IPython
e = IPython.embed


def reparametrize(mu, logvar):
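    # Reparameterization trick: z = mu + sigma * eps with sigma = exp(logvar / 2)
    # and eps ~ N(0, I), so the sample stays differentiable w.r.t. mu and logvar.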
std = logvar.div(2).exp()
    eps = torch.randn_like(std)  # eps ~ N(0, I); replaces the deprecated Variable API
return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
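    # Fixed sinusoidal table (Vaswani et al.): PE[pos, 2i] = sin(pos / 10000^(2i/d_hid)),
    # PE[pos, 2i+1] = cos(pos / 10000^(2i/d_hid)); returned with shape (1, n_position, d_hid).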
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1

return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries):
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.encoder = encoder
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.backbones = None

# encoder extra parameters
self.latent_dim = 32 # size of latent z
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_state_proj = nn.Linear(state_dim, hidden_dim) # project state to embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent mean and log-variance
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], state, actions

# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for state and proprio

def forward(self, obs, actions=None):
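        # Training: forward(obs, actions) returns (a_hat, [mu, logvar]) for the
        # reconstruction and KL losses; at inference, forward(obs) returns a_hat
        # with the latent fixed to zero.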
is_training = actions is not None
state = obs['state'] if self.backbones is not None else obs
bs = state.shape[0]

if is_training:
# project CLS token, state sequence, and action sequence to embedding dim
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
state_embed = self.encoder_state_proj(state) # (bs, hidden_dim)
state_embed = torch.unsqueeze(state_embed, axis=1) # (bs, 1, hidden_dim)
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
# concat them together to form an input to the CVAE encoder
encoder_input = torch.cat([cls_embed, state_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim)
            # no masking is applied to any part of the CVAE encoder input
is_pad = torch.full((bs, encoder_input.shape[0]), False).to(state.device) # False: not a padding
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim)
# query CVAE encoder
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
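            # inference: there is no action sequence to encode, so fix the latent to the prior mean (z = 0)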
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(state.device)
latent_input = self.latent_out_proj(latent_sample)

# CVAE decoder
if self.backbones is not None:
vis_data = obs['rgb'] if "rgb" in obs else obs['rgbd']
num_cams = vis_data.shape[1]

# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id in range(num_cams):
                features, pos = self.backbones[0](vis_data[:, cam_id]) # HARDCODED: a single shared backbone is used for all cameras
features = features[0] # take the last layer feature # (batch, hidden_dim, H, W)
pos = pos[0] # (1, hidden_dim, H, W)
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)

# proprioception features (state)
proprio_input = self.input_proj_robot_state(state)
# fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3) # (batch, hidden_dim, H, num_cams * W)
            pos = torch.cat(all_cam_pos, axis=3) # (1, hidden_dim, H, num_cams * W)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] # (batch, num_queries, hidden_dim)
else:
state = self.input_proj_robot_state(state)
hs = self.transformer(None, None, self.query_embed.weight, None, latent_input, state, self.additional_pos_embed.weight)[0]

a_hat = self.action_head(hs)
return a_hat, [mu, logvar]


def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"

encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

return encoder
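As a hedged usage sketch, the fields `build_encoder` expects mirror the inline comments above (the values below are the noted defaults, not necessarily this PR's config), and the resulting encoder consumes sequence-first tensors:

import argparse
import torch

args = argparse.Namespace(hidden_dim=256, dropout=0.1, nheads=8,
                          dim_feedforward=2048, enc_layers=4, pre_norm=False)
encoder = build_encoder(args)
src = torch.zeros(10, 2, 256)  # (seq, batch, d_model) placeholder input
out = encoder(src)             # hypothetical bare call; DETRVAE also passes pos and a padding mask
print(out.shape)               # torch.Size([10, 2, 256])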
93 changes: 93 additions & 0 deletions examples/baselines/act/act/detr/position_encoding.py
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from ..utils import NestedTensor

import IPython
e = IPython.embed

class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale

def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask

not_mask = torch.ones_like(x[0, [0]])
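        # not_mask marks every pixel as valid (no padding); the cumsums below then
        # give 1..H row coordinates (y_embed) and 1..W column coordinates (x_embed)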
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos


class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()

def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)

def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos


def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
        raise ValueError(f"unsupported position embedding: {args.position_embedding}")

return position_embedding
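A quick sanity check on the channel math, as a hedged sketch: `N_steps = hidden_dim // 2`, and the sine embedding concatenates y and x features back up to `hidden_dim` channels:

import argparse
import torch

args = argparse.Namespace(hidden_dim=256, position_embedding='sine')
pe = build_position_encoding(args)  # PositionEmbeddingSine(num_pos_feats=128)
x = torch.zeros(2, 3, 8, 8)         # (batch, C, H, W); only the spatial shape matters
pos = pe(x)
print(pos.shape)                    # torch.Size([1, 256, 8, 8]), broadcast over the batch later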