
Commit

Merge branch 'feature/webots_runner' of github.com:bit-bots/ddlitlab2024 into feature/webots_runner
Flova committed Jan 30, 2025
2 parents f767c12 + ea60f24 commit 891cdad
Showing 3 changed files with 215 additions and 156 deletions.
273 changes: 122 additions & 151 deletions ddlitlab2024/ml/inference/ros.py
@@ -1,4 +1,6 @@
import time
from threading import Lock
from typing import Optional

import cv2
import numpy as np
@@ -11,26 +13,21 @@
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from game_controller_hl_interfaces.msg import GameState
from profilehooks import profile
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
from rclpy.duration import Duration
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from rclpy.time import Time
from sensor_msgs.msg import Image, JointState
from trajectory_msgs.msg import JointTrajectory, JointTrajectoryPoint

from ddlitlab2024.dataset.pytorch import Normalizer
from ddlitlab2024.ml.model import End2EndDiffusionTransformer
from ddlitlab2024.ml.model.encoder.image import ImageEncoderType, SequenceEncoderType
from ddlitlab2024.ml.model.encoder.imu import IMUEncoder
from ddlitlab2024.utils.utils import JOINT_NAMES_ORDER

from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset, Normalizer, worker_init_fn
from ddlitlab2024.ml import logger
from ddlitlab2024.ml.model import End2EndDiffusionTransformer
from ddlitlab2024.ml.model.encoder.image import ImageEncoderType, SequenceEncoderType
from ddlitlab2024.ml.model.encoder.imu import IMUEncoder

from torch.utils.data import DataLoader


# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -50,20 +47,24 @@ def __init__(self, node_name, context):
self.action_context_length = 100
self.trajectory_prediction_length = 10
train_denoising_timesteps = 1000
self.inference_denosing_timesteps = 30
self.inference_denosing_timesteps = 10
self.image_context_length = 10
self.imu_context_length = 100
self.joint_state_context_length = 100
self.num_joints = 20
checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/trajectory_transformer_model_500_epoch_xmas.pth"
checkpoint = "/home/florian/ddlitlab/ddlitlab_repo/ddlitlab2024/ml/training/trajectory_transformer_model_500_epoch_xmas.pth"

# Subscribe to all the input topics
self.joint_state_sub = self.create_subscription(JointState, "/joint_states", self.joint_state_callback, 10)
self.img_sub = self.create_subscription(Image, "/camera/image_proc", self.img_callback, 10)
self.gamestate_sub = self.create_subscription(GameState, "/gamestate", self.gamestate_callback, 10)
self.motor_command_sub = self.create_subscription(
JointCommand, "/DynamixelController/command", self.motor_command_callback, 10
)

# Publisher for the output topic
self.joint_state_pub = self.create_publisher(JointCommand, "/DynamixelController/command", 10)
# self.joint_state_pub = self.create_publisher(JointCommand, "/DynamixelController/command", 10)
self.trajectory_pub = self.create_publisher(JointTrajectory, "/traj", 10)

# Image embedding buffer
self.latest_image = None
@@ -73,52 +74,23 @@ def __init__(self, node_name, context):
self.imu_data = []

# Joint state buffer
self.latest_joint_state = None
self.latest_joint_state: Optional[JointState] = None
self.joint_state_data = []

# Joint command buffer
self.latest_motor_command: Optional[JointCommand] = None
self.joint_command_data = []

# Gamestate
self.latest_game_state = None

# Add default values to the buffers
#self.image_embeddings = [torch.zeros(3, 480, 480)] * self.image_context_length
#self.imu_data = [torch.tensor([0.0, 0.0, 0.0, 1.0])] * self.imu_context_length
#self.joint_state_data = [torch.zeros(len(JOINT_NAMES_ORDER))] * self.joint_state_context_length
#self.joint_command_data = [torch.zeros(self.num_joints)] * self.action_context_length
self.image_embeddings = [torch.randn(3, 480, 480)] * self.image_context_length
self.imu_data = [torch.randn(4)] * self.imu_context_length
self.joint_state_data = [torch.randn(len(JOINT_NAMES_ORDER))] * self.joint_state_context_length
self.joint_command_data = [torch.randn(self.num_joints)] * self.action_context_length

# Create Dataset object
dataset = DDLITLab2024Dataset(
num_joints=self.num_joints,
num_frames_video=self.image_context_length,
num_samples_joint_trajectory_future=self.trajectory_prediction_length,
num_samples_joint_trajectory=self.action_context_length,
num_samples_imu=self.imu_context_length,
num_samples_joint_states=self.joint_state_context_length,
)

# Create DataLoader object
num_workers = 5
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=True,
collate_fn=DDLITLab2024Dataset.collate_fn,
persistent_workers=num_workers > 1,
num_workers=num_workers,
worker_init_fn=worker_init_fn,
)

dataloader = iter(dataloader)

batch = next(dataloader)

# Add the data from the dataset to the buffers
self.image_embeddings = [x for x in batch.image_data.squeeze(0)]
self.imu_data = [x for x in batch.rotation.squeeze(0)]
self.joint_state_data = [x for x in batch.joint_state.squeeze(0)]
self.joint_command_data = [x for x in batch.joint_command_history.squeeze(0)]
self.data_lock = Lock()
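# Guards the shared context buffers (joint states, joint commands, images, IMU), which are
# written by the buffer-update timer and read by the inference step on different executor threads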

# TF buffer to estimate imu similarly to the way we fixed the dataset
self.tf_buffer = Buffer(self, Duration(seconds=10))
@@ -164,8 +136,12 @@ def __init__(self, node_name, context):
self.scheduler.set_timesteps(self.inference_denosing_timesteps)

# Create control timer to run inference at a fixed rate
interval = 1 / self.sample_rate # * self.trajectory_prediction_length
self.create_timer(interval, self.step)
interval = 1 / self.sample_rate * self.trajectory_prediction_length
# We want to run the inference in a separate thread to not block the callbacks, but we also want to make sure
# that the inference is not running multiple times in parallel
self.create_timer(interval, self.step, callback_group=MutuallyExclusiveCallbackGroup())
interval = 1 / self.sample_rate
self.create_timer(interval, self.update_buffers)
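# Note: update_buffers stays in the node's default callback group, so the buffers keep
# filling while a step is still denoising. Hypothetical example: with sample_rate = 50 Hz,
# buffers would update every 20 ms and a new 10-sample trajectory every 200 ms.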

def joint_state_callback(self, msg: JointState):
self.latest_joint_state = msg
@@ -176,88 +152,99 @@ def img_callback(self, msg: Image):
def gamestate_callback(self, msg: GameState):
self.latest_game_state = msg

def step(self):
self.get_logger().info("Step")

# First we want to fill the buffers
if self.latest_joint_state is not None:
# Joint names are not in the correct order, so we need to reorder them
joint_state = torch.zeros(len(JOINT_NAMES_ORDER))
for i, joint_name in enumerate(JOINT_NAMES_ORDER):
idx = self.latest_joint_state.name.index(joint_name)
joint_state[i] = self.latest_joint_state.position[idx]
self.get_logger().info("Storing joint state")
self.joint_state_data.append(joint_state)

self.get_logger().info("Calculating image embeddings")
if self.latest_image is not None:
# Here we don't just want to put the image in the buffer, but calculate the embedding first
# But for now the model dos not support the direct use of embeddings so we
# calculate them every timestep for the whole sequence.
# This is not efficient and should be changed in the future TODO

# Deserialize the image
img = self.cv_bridge.imgmsg_to_cv2(self.latest_image, desired_encoding="rgb8")

# Resize the image
img = cv2.resize(img, (480, 480))

# Make chw from hwc
img = np.moveaxis(img, -1, 0)

# Convert the image to a tensor
img = torch.tensor(img, dtype=torch.float32)

self.image_embeddings.append(img)

self.get_logger().info("Calculating IMU data")
# Due to a bug in the recordings of the bit-bots we can not use the imu data directly,
# but instead need to derive it from the tf tree
imu_transform = self.tf_buffer.lookup_transform("base_footprint", "base_link", Time())

self.get_logger().info("Storing IMU data")

# Store imu data as np array in the form wxyz
self.imu_data.append(
torch.tensor(
[
imu_transform.transform.rotation.x,
imu_transform.transform.rotation.y,
imu_transform.transform.rotation.z,
imu_transform.transform.rotation.w,
]
def motor_command_callback(self, msg: JointCommand):
self.latest_motor_command = msg

def update_buffers(self):
with self.data_lock:
# First we want to fill the buffers
if self.latest_joint_state is not None:
# Joint names are not in the correct order, so we need to reorder them
joint_state = torch.zeros(len(JOINT_NAMES_ORDER))
for i, joint_name in enumerate(JOINT_NAMES_ORDER):
idx = self.latest_joint_state.name.index(joint_name)
joint_state[i] = self.latest_joint_state.position[idx]
self.joint_state_data.append(joint_state)
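# Note: list.index raises ValueError if a joint name is missing, so this assumes every
# JointState message contains all joints in JOINT_NAMES_ORDER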

if self.latest_motor_command is not None:
# Joint names are not in the correct order, so we need to reorder them
joint_state = torch.zeros(len(JOINT_NAMES_ORDER))
for i, joint_name in enumerate(JOINT_NAMES_ORDER):
idx = self.latest_motor_command.joint_names.index(joint_name)
joint_state[i] = self.latest_motor_command.positions[idx]
self.joint_command_data.append(joint_state)

if self.latest_image is not None:
# Here we don't just want to put the image in the buffer, but calculate the embedding first
# But for now the model does not support the direct use of embeddings, so we
# calculate them every timestep for the whole sequence.
# This is not efficient and should be changed in the future TODO

# Deserialize the image
img = self.cv_bridge.imgmsg_to_cv2(self.latest_image, desired_encoding="rgb8")

# Resize the image
img = cv2.resize(img, (480, 480))

# Make chw from hwc
img = np.moveaxis(img, -1, 0)

# Convert the image to a tensor
img = torch.tensor(img, dtype=torch.float32)

self.image_embeddings.append(img)
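# Note: pixel values stay in [0, 255] as float32 here; any further scaling is presumably
# handled inside the image encoder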

# Due to a bug in the Bit-Bots recordings we cannot use the IMU data directly,
# but instead need to derive it from the tf tree
imu_transform = self.tf_buffer.lookup_transform("base_footprint", "base_link", Time())
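# Assumption: base_footprint is the gravity-aligned frame beneath the robot, so the rotation
# of this transform approximates the torso orientation an IMU would measure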

# Store the rotation quaternion as a tensor in xyzw order
self.imu_data.append(
torch.tensor(
[
imu_transform.transform.rotation.x,
imu_transform.transform.rotation.y,
imu_transform.transform.rotation.z,
imu_transform.transform.rotation.w,
]
)
)
)

print(self.imu_data[-1])
print(len(self.imu_data))
print(self.joint_state_data[-1])
# Remove the oldest data from the buffers
self.joint_state_data = self.joint_state_data[-self.joint_state_context_length :]
self.image_embeddings = self.image_embeddings[-self.image_context_length :]
self.imu_data = self.imu_data[-self.imu_context_length :]
self.joint_command_data = self.joint_command_data[-self.action_context_length :]

# Remove the oldest data from the buffers
self.joint_state_data = self.joint_state_data[-self.joint_state_context_length :]
self.image_embeddings = self.image_embeddings[-self.image_context_length :]
self.imu_data = self.imu_data[-self.imu_context_length :]
self.joint_command_data = self.joint_command_data[-self.action_context_length :]
@profile
def step(self):
self.get_logger().info("Step")

# Prepare the data for inference
batch = {
"joint_state": (torch.stack(list(self.joint_state_data), dim=0).unsqueeze(0).to(device) + 3 * np.pi)
% (2 * np.pi),
"image_data": torch.stack(list(self.image_embeddings), dim=0).unsqueeze(0).to(device),
"rotation": torch.stack(list(self.imu_data), dim=0).unsqueeze(0).to(device),
"joint_command_history": (torch.stack(list(self.joint_state_data), dim=0).unsqueeze(0).to(device) + 3 * np.pi)
% (2 * np.pi) # torch.stack(list(self.joint_command_data), dim=0).unsqueeze(0).to(device),
}

with self.data_lock:
batch = {
"joint_state": (torch.stack(list(self.joint_state_data), dim=0).unsqueeze(0).to(device) + 3 * np.pi)
% (2 * np.pi),
"image_data": torch.stack(list(self.image_embeddings), dim=0).unsqueeze(0).to(device),
"rotation": torch.stack(list(self.imu_data), dim=0).unsqueeze(0).to(device),
"joint_command_history": (
torch.stack(list(self.joint_command_data), dim=0).unsqueeze(0).to(device) + 3 * np.pi
)
% (2 * np.pi), # torch.stack(list(self.joint_command_data), dim=0).unsqueeze(0).to(device),
}
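# Adding 3*pi and taking mod 2*pi shifts each joint angle by +pi, mapping (-pi, pi] to
# (0, 2*pi], presumably to match the representation used during training; the "- np.pi"
# applied before publishing below undoes this shift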

# Perform the denoising process
trajectory = torch.randn(1, self.trajectory_prediction_length, self.num_joints).to(device)

self.get_logger().info("Performing denoising process")
start_ros_time = self.get_clock().now()

## Perform the embedding of the conditioning
start = time.time()
embedded_input = self.og_model.encode_input_data(batch)
# print("Time for embedding: ", time.time() - start)

# Denoise the trajectory
start = time.time()
self.scheduler.set_timesteps(self.inference_denosing_timesteps)
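# DDIM sampling allows far fewer denoising steps at inference time than the 1000 used
# during training; here only 10 timesteps are used, trading some fidelity for speed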
for t in self.scheduler.timesteps:
with torch.no_grad():
@@ -269,47 +256,31 @@ def step(self):
# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample

# print("Time for forward: ", time.time() - start)

# Undo the normalization
trajectory = self.normalizer.denormalize(trajectory)

self.get_logger().info("Publishing trajectory")

# Store the trajectory in the joint command buffer (action history)
self.joint_command_data.append(trajectory[0, -1].cpu())

# Publish the trajectory
self.joint_state_pub.publish(
JointCommand(
joint_names=JOINT_NAMES_ORDER,
velocities=[-1.0] * len(JOINT_NAMES_ORDER),
accelerations=[-1.0] * len(JOINT_NAMES_ORDER),
max_currents=[-1.0] * len(JOINT_NAMES_ORDER),
positions=trajectory[0, -1].cpu().numpy() - np.pi,
)
)

# Store the trajectory in the joint command buffer (action history)
#for i in range(self.trajectory_prediction_length):
# self.joint_command_data.append(trajectory[0, i].cpu())

## Publish the trajectory one by one
#for i in range(self.trajectory_prediction_length):
# time.sleep(1 / self.sample_rate)
# self.joint_state_pub.publish(
# JointCommand(
# joint_names=JOINT_NAMES_ORDER,
# velocities=[-1.0] * len(JOINT_NAMES_ORDER),
# accelerations=[-1.0] * len(JOINT_NAMES_ORDER),
# max_currents=[-1.0] * len(JOINT_NAMES_ORDER),
# positions=trajectory[0, i].cpu().numpy() - np.pi,
# )
# )
trajectory_msg = JointTrajectory()
trajectory_msg.header.stamp = Time.to_msg(start_ros_time)
trajectory_msg.joint_names = JOINT_NAMES_ORDER
trajectory_msg.points = []
for i in range(self.trajectory_prediction_length):
point = JointTrajectoryPoint()
point.positions = trajectory[0, i].cpu().numpy() - np.pi
point.time_from_start = Duration(nanoseconds=int(1e9 / self.sample_rate * i)).to_msg()
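# Assumption: -1.0 acts as a "leave unset / use controller default" sentinel for velocity,
# acceleration, and effort, while the first two joints get an explicit velocity of 3.0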
point.velocities = [3.0] * 2 + [-1.0] * (len(JOINT_NAMES_ORDER) - 2)
point.accelerations = [-1.0] * len(JOINT_NAMES_ORDER)
point.effort = [-1.0] * len(JOINT_NAMES_ORDER)
trajectory_msg.points.append(point)
self.trajectory_pub.publish(trajectory_msg)


def main(args=None):
rclpy.init(args=args)
node = Inference("inference", None)
executor = rclpy.executors.MultiThreadedExecutor()
executor = MultiThreadedExecutor(num_threads=5)
executor.add_node(node)
executor.spin()
rclpy.shutdown()
