How to run inference of the model with a single image and no proprioception data? #5

Open
alik-git opened this issue Oct 18, 2024 · 5 comments


@alik-git

Hi there!

Thank you for your great research and open-source contributions.

I just have a few questions about running your model.

What I am trying to do

I am trying to run RDT on SimplerEnv to see how it compares to other baselines, but I am struggling to understand how to run inference with the pre-trained checkpoint given only a single image and a text prompt as input, which is the format SimplerEnv provides. Is this possible with RDT? I understand that RDT requires more inputs than just an image and a text prompt (it also takes proprioception data, the control frequency, a noisy action chunk and the diffusion timestep), but I am wondering whether there is a way to pass blank (or default) values for those inputs when they are not available.

What I have already tried

First, I tried running agilex_inference.py as documented in the deployment section of your README. However, that script seems to expect its inputs via ROS, which does not match my setup.

What would help if possible

It would help a lot if you had something similar to the code snippet in the getting-started section of the OpenVLA README. Below is roughly what I have in mind; I took most of this code from your sample.py script, but I can't get it to run because I'm not sure how to pass only a single image and a text prompt to the model.

```python
# install dependencies as shown in the README here https://github.com/alik-git/RoboticsDiffusionTransformer?tab=readme-ov-file#installation
import torch

from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.rdt_runner import RDTRunner
# other imports (device, weight_dtype, args and config are defined elsewhere)

# Load vision encoder
vision_encoder = SiglipVisionTower(
    vision_tower="/path/to/siglip-so400m-patch14-384",
    args=None
)
vision_encoder.vision_tower.to(device, dtype=weight_dtype)
vision_encoder.vision_tower.eval()

# Get image embeddings (I'm not sure how to do this)
# Suppose I have the image saved like so:
image_path = "path/to/image.png"
# Then how can I properly do the following?
image = ...  # load the image at image_path
image_embeds = vision_encoder(image)

# Load language embeddings; I believe this can be done via the `encode_lang.py` script, which I have managed to run
lang_embeddings = torch.load("path/to/outs/handover_pan.pt", map_location=device)

# Load pretrained model
rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
rdt.to(device)
rdt.eval()

actions = rdt.sample(
    lang_tokens=lang_embeddings.to(device, dtype=weight_dtype),
    img_tokens=image_embeds,
    state_tokens=states,  # how can I get this?
    action_mask=...,  # how can I get this?
    ctrl_freqs=torch.tensor([25.0], device=device),  # would this default work?
    sample_steps=config["diffusion"]["sampling_steps"],  # what would be a good default value for this?
    guidance_scale=config["diffusion"]["guidance_scale"],  # what would be a good default value for this?
)

# And what is the format of the action that RDT outputs?
```

I hope it is clear what I'm trying to do, but please feel free to ask me to clarify any of the above if needed. Any suggestions on how to achieve this would be very appreciated. Thank you!

@LBG21

LBG21 commented Oct 18, 2024

Thank you for your question.

Our RDT model is pretrained and fine-tuned using proprioception data along with single or multiple images. If you intend to perform inference using only a single image without proprioception data, we recommend fine-tuning the RDT model with data in this format. For further details, please refer to the fine-tuning section of our documentation. However, it is important to note that we have not yet conducted experiments fine-tuning the RDT model without proprioception data, so we cannot guarantee the quality of the resulting performance.

@ethan-iai
Contributor

Thank you for your engagement.

We are currently working to make the inference process clearer, as you requested, and will let you know as soon as the updates are complete.

Feel free to let me know if you’d like any further adjustments!

@alik-git
Author

Thank you for the prompt response! That would really help our use case.

And yes, I should have been clearer: I do want to fine-tune, but before doing that I wanted to run a quick sanity check to make sure I could at least run inference in the environment. I do not expect the model to perform well without fine-tuning since, as you said, the input space is quite different without proprioception data.

Again, please don't hesitate to ask if you'd like me to clarify anything on my end. Thank you!

@ethan-iai
Contributor


Apologies for the late response. We have prepared a beta version of the minimal implementation for inference as requested. Please note that due to urgent circumstances, it has not been tested yet. We hope this version clarifies the inference process. Thank you for your understanding.

```python
# install dependencies as shown in the README here https://github.com/alik-git/RoboticsDiffusionTransformer?tab=readme-ov-file#installation
import yaml
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.rdt_runner import RDTRunner

config_path = "configs/base.yaml"   # default config
pretrained_model_name_or_path = "path/to/rdt-model"
device = torch.device('cuda:0')
dtype = torch.bfloat16  # recommended
cfg_scale = 2.0

# Suppose you control a 7-DoF arm in joint-position space, plus a gripper
STATE_INDICES = [
    STATE_VEC_IDX_MAPPING['arm_joint_0_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_1_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_2_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_3_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_4_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_5_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_6_pos'],
    STATE_VEC_IDX_MAPPING['gripper_open']
]

with open(config_path, "r") as fp:
    config = yaml.safe_load(fp)

# Load vision encoder
vision_encoder = SiglipVisionTower(
    vision_tower="/path/to/siglip-so400m-patch14-384",
    args=None
)
vision_encoder.to(device, dtype=dtype)
vision_encoder.eval()
image_processor = vision_encoder.image_processor

# Load pretrained model (in HF style)
rdt = RDTRunner.from_pretrained(pretrained_model_name_or_path)
rdt.to(device, dtype=dtype)
rdt.eval()

# Replace these placeholder paths with your own image files
previous_image_path = "path/to/previous_image.png"
# previous_image = None  # if t = 0
previous_image = Image.open(previous_image_path).convert("RGB")  # if t > 0

current_image_path = "path/to/current_image.png"
current_image = Image.open(current_image_path).convert("RGB")

# Here I suppose you only have an exterior image (e.g., a third-person view) and no state information.
# The images should be arranged in the sequence
# [exterior_image, right_wrist_image, left_wrist_image] * image_history_size (e.g., 2)
rgbs_lst = [
    [previous_image, None, None],
    [current_image, None, None]
]
# If you also have a right-wrist image, it should be:
# rgbs_lst = [
#     [previous_image, previous_right_wrist_image, None],
#     [current_image, current_right_wrist_image, None]
# ]

# Image pre-processing
# The background image used to pad missing views
background_color = np.array([
    int(x*255) for x in image_processor.image_mean
], dtype=np.uint8).reshape(1, 1, 3)
background_image = np.ones((
    image_processor.size["height"],
    image_processor.size["width"], 3), dtype=np.uint8
) * background_color


def expand2square(pil_img, background_color):
    # Pad a non-square image to a square canvas, keeping the original centered
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


image_tensor_list = []
for step in range(config["common"]["img_history_size"]):
    rgbs = rgbs_lst[step % len(rgbs_lst)]
    for rgb in rgbs:
        if rgb is None:
            # Replace a missing view with the background image
            image = Image.fromarray(background_image)
        else:
            # rgb is already a PIL RGB image (loaded with Image.open above)
            image = rgb

        if config["dataset"].get("auto_adjust_image_brightness", False):
            pixel_values = list(image.getdata())
            average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
            if average_brightness <= 0.15:
                image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

        if config["dataset"].get("image_aspect_ratio", "pad") == 'pad':
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
        image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image_tensor_list.append(image)

image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

# Encode images
image_embeds = vision_encoder(image_tensor).detach()
image_embeds = image_embeds.reshape(-1, vision_encoder.hidden_size).unsqueeze(0)

# Load the language embeddings produced by the `encode_lang.py` script
lang_embeddings = torch.load("path/to/outs/handover_pan.pt", map_location=device)
# ASSUMPTION (check encode_lang.py): if the file stores a dict, take its "embeddings" entry
if isinstance(lang_embeddings, dict):
    lang_embeddings = lang_embeddings["embeddings"]

# Suppose you do not have proprioception.
# This is somewhat tricky; I strongly suggest adding proprioception as an input and fine-tuning further.
B, N = 1, 1  # batch size and state history size
states = torch.zeros(
    (B, N, config["model"]["state_token_dim"]),
    device=device, dtype=dtype
)

# If you do have proprioception, fill it in like this, in the order
# [arm_joint_0_pos, ..., arm_joint_6_pos, gripper_open]
# (create it on the same device and dtype as `states`):
# proprio = torch.tensor([0, 1, 2, 3, 4, 5, 6, 0.5], device=device, dtype=dtype).reshape((1, 1, -1))
# states[:, :, STATE_INDICES] = proprio

# Mark which of the unified state/action dimensions this robot actually uses
state_elem_mask = torch.zeros(
    (B, config["model"]["state_token_dim"]),
    device=device, dtype=torch.bool
)
state_elem_mask[:, STATE_INDICES] = True
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
states = states[:, -1:, :]  # only use the last state

actions = rdt.predict_action(
    lang_tokens=lang_embeddings.to(device, dtype=dtype),
    lang_attn_mask=torch.ones(
        lang_embeddings.shape[:2], dtype=torch.bool,
        device=device
    ),
    img_tokens=image_embeds,
    state_tokens=states,  # (1, 1, state_token_dim); all zeros here since there is no proprioception
    action_mask=state_elem_mask.unsqueeze(1),  # (1, 1, state_token_dim); 1 at the STATE_INDICES dimensions
    ctrl_freqs=torch.tensor([25.0], device=device),  # control frequency of your robot (here 25 Hz)
)  # (1, chunk_size, 128)

# Select the meaningful action dimensions via STATE_INDICES
action = actions[:, :, STATE_INDICES]  # (1, chunk_size, len(STATE_INDICES)) = (1, chunk_size, 7 + 1)
```
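The snippet above annotates the output of `predict_action` as a `(1, chunk_size, 128)` tensor in the unified state/action space, from which `STATE_INDICES` picks the 7 arm joint positions and the gripper value. As a rough illustration of consuming such a chunk, here is a minimal NumPy sketch. Random numbers stand in for a real model output, and `STATE_INDICES`, `CHUNK_SIZE = 64` and `split_action_chunk` are hypothetical placeholders, not values taken from RDT's `configs/state_vec.py` or model config. With a real bfloat16 output you would first call `.float().cpu().numpy()`, since NumPy has no bfloat16 dtype.

```python
import numpy as np

# HYPOTHETICAL placeholders: the real indices come from STATE_VEC_IDX_MAPPING,
# and the real chunk size from the model config
STATE_INDICES = [0, 1, 2, 3, 4, 5, 6, 10]  # 7 arm joint positions + gripper_open
CHUNK_SIZE = 64
UNIFIED_DIM = 128


def split_action_chunk(actions, state_indices):
    """Select the dims this robot uses and split them into arm joints and gripper.

    actions: array of shape (1, chunk_size, UNIFIED_DIM)
    returns: (joint_targets, gripper_targets) with shapes
             (chunk_size, len(state_indices) - 1) and (chunk_size,)
    """
    # Index the batch first, then the columns; actions[0, :, state_indices]
    # would move the indexed axis to the front under NumPy's advanced-indexing rules
    selected = actions[0][:, state_indices]  # (chunk_size, len(state_indices))
    num_joints = len(state_indices) - 1
    joint_targets = selected[:, :num_joints]
    gripper_targets = selected[:, num_joints]
    return joint_targets, gripper_targets


rng = np.random.default_rng(0)
fake_actions = rng.standard_normal((1, CHUNK_SIZE, UNIFIED_DIM)).astype(np.float32)
joints, gripper = split_action_chunk(fake_actions, STATE_INDICES)

# Step through the first few targets of the chunk, as a controller loop might
for t in range(3):
    print(t, joints[t].shape, float(gripper[t]))
```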

@csuastt
Collaborator

csuastt commented Oct 19, 2024

> I am wondering if there is a way to somehow pass blank (or default) values for those inputs in situations where they are not available

Apologies for the delayed response; I just returned from my weekend trip. Yes, you can substitute padding values for the missing inputs, both when fine-tuning and when running RDT.
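Concretely, the beta inference snippet earlier in this thread pads in two places: each missing camera view is replaced by a solid image in the vision processor's mean color, and missing proprioception becomes an all-zero state vector, while a 0/1 mask still marks which of the 128 unified dimensions the robot controls. A minimal standalone sketch of that padding with NumPy and PIL follows. The 384x384 image size and the mean of (0.5, 0.5, 0.5) are assumptions about the SigLIP processor, the indices passed to `empty_state` are hypothetical, and `background_image`, `pad_views` and `empty_state` are illustrative helper names, not functions from the RDT repository.

```python
import numpy as np
from PIL import Image

IMG_SIZE = 384                # ASSUMPTION: siglip-so400m-patch14-384 input size
IMAGE_MEAN = (0.5, 0.5, 0.5)  # ASSUMPTION: processor image_mean
UNIFIED_DIM = 128             # width of the unified state/action vector


def background_image(size=IMG_SIZE, mean=IMAGE_MEAN):
    """Solid RGB image in the processor's mean color, used for missing views."""
    color = np.array([int(x * 255) for x in mean], dtype=np.uint8).reshape(1, 1, 3)
    return Image.fromarray(np.ones((size, size, 3), dtype=np.uint8) * color)


def pad_views(history):
    """history: one [exterior, right_wrist, left_wrist] list per timestep, None = missing.

    Returns a flat list of PIL images (oldest timestep first) with missing views padded.
    """
    bg = background_image()
    return [bg if view is None else view for step in history for view in step]


def empty_state(used_indices, dim=UNIFIED_DIM):
    """All-zero state (no proprioception) plus a 0/1 mask over the dims the robot uses."""
    state = np.zeros((1, 1, dim), dtype=np.float32)
    mask = np.zeros((1, 1, dim), dtype=np.float32)
    mask[..., used_indices] = 1.0
    return state, mask


# Usage with dummy exterior images and no wrist cameras
ext_prev = Image.new("RGB", (640, 480), (200, 10, 10))
ext_now = Image.new("RGB", (640, 480), (10, 200, 10))
views = pad_views([[ext_prev, None, None], [ext_now, None, None]])
state, mask = empty_state([0, 1, 2, 3, 4, 5, 6, 10])  # HYPOTHETICAL indices
print(len(views), views[1].size, views[1].getpixel((0, 0)), int(mask.sum()))
```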

We're also in the process of fine-tuning and running RDT on ManiSkill2/SimplerEnv. Once that work is completed, we'll be releasing everything. Stay tuned!
