Current WIP
Flova committed Jan 30, 2025
1 parent 34bcf83 commit f4e74b0
Showing 2 changed files with 70 additions and 63 deletions.
2 changes: 1 addition & 1 deletion ddlitlab2024/ml/inference/plot.py
@@ -104,7 +104,7 @@
noisy_trajectory = torch.randn_like(joint_targets).to(device)
trajectory = noisy_trajectory

if params["distilled_decoder"]:
if params.get("distilled_decoder", False):
# Directly predict the trajectory based on the noise
with torch.no_grad():
trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
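
Note on the plot.py change above: switching from params["distilled_decoder"] to params.get("distilled_decoder", False) presumably keeps older checkpoints usable, since hyperparameter dicts saved before the distillation flag existed do not contain the key. A minimal, self-contained sketch with illustrative values (not taken from any actual checkpoint):

old_params = {"hidden_dim": 256, "num_joints": 20}           # saved before the flag existed
new_params = {"hidden_dim": 256, "distilled_decoder": True}  # saved by a distillation run

# Indexing old_params["distilled_decoder"] would raise a KeyError;
# .get() with a default of False falls back to the iterative diffusion path.
assert old_params.get("distilled_decoder", False) is False
assert new_params.get("distilled_decoder", False) is True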
131 changes: 69 additions & 62 deletions ddlitlab2024/ml/inference/ros.py
@@ -11,7 +11,6 @@
from bitbots_tf_buffer import Buffer
from cv_bridge import CvBridge
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from game_controller_hl_interfaces.msg import GameState
from profilehooks import profile
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
@@ -42,21 +41,18 @@ def __init__(self, node_name, context):
[rclpy.parameter.Parameter("use_sim_time", rclpy.Parameter.Type.BOOL, True)],
)

checkpoint_path = (
"../training/destilled_trajectory_transformer_model_first_train_20_epoch_hyp.pth"
#"../training/trajectory_transformer_model_500_epoch_xmas_hyp.pth"
)
self.inference_denosing_timesteps = 30

# Params
self.sample_rate = DEFAULT_RESAMPLE_RATE_HZ
hidden_dim = 256
self.action_context_length = 100
self.trajectory_prediction_length = 10
train_denoising_timesteps = 1000
self.inference_denosing_timesteps = 10
self.image_context_length = 10
self.imu_context_length = 100
self.joint_state_context_length = 100
self.num_joints = 20
checkpoint = (
"/home/florian/ddlitlab/ddlitlab_repo/ddlitlab2024/ml/training/"
"trajectory_transformer_model_500_epoch_xmas.pth"
)
# Load the hyperparameters from the checkpoint
self.get_logger().info(f"Loading checkpoint '{checkpoint_path}'")
checkpoint = torch.load(checkpoint_path, weights_only=True)
self.hyper_params = checkpoint["hyperparams"]

# Subscribe to all the input topics
self.joint_state_sub = self.create_subscription(JointState, "/joint_states", self.joint_state_callback, 10)
@@ -89,12 +85,14 @@ def __init__(self, node_name, context):
self.latest_game_state = None

# Add default values to the buffers
self.image_embeddings = [torch.randn(3, 480, 480)] * self.image_context_length
self.imu_data = [torch.randn(4)] * self.imu_context_length
self.joint_state_data = [
torch.randn(len(JointStates.get_ordered_joint_names()))
] * self.joint_state_context_length
self.joint_command_data = [torch.randn(self.num_joints)] * self.action_context_length
self.image_embeddings = [torch.randn(3, 480, 480)] * self.hyper_params["image_context_length"]
self.imu_data = [torch.randn(4)] * self.hyper_params["imu_context_length"]
self.joint_state_data = [torch.randn(len(JointStates.get_ordered_joint_names()))] * self.hyper_params[
"joint_state_context_length"
]
self.joint_command_data = [torch.randn(self.hyper_params["num_joints"])] * self.hyper_params[
"action_context_length"
]

self.data_lock = Lock()

@@ -106,43 +104,41 @@ def __init__(self, node_name, context):
# Load model
self.get_logger().info("Load model")
self.model = End2EndDiffusionTransformer(
num_joints=self.num_joints,
hidden_dim=hidden_dim,
use_action_history=True,
num_action_history_encoder_layers=2,
max_action_context_length=self.action_context_length,
use_imu=True,
imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION,
num_imu_encoder_layers=2,
max_imu_context_length=self.imu_context_length,
use_joint_states=True,
joint_state_encoder_layers=2,
max_joint_state_context_length=self.joint_state_context_length,
use_images=True,
image_sequence_encoder_type=SequenceEncoderType.TRANSFORMER,
image_encoder_type=ImageEncoderType.RESNET18,
num_image_sequence_encoder_layers=1,
max_image_context_length=self.image_context_length,
num_decoder_layers=4,
trajectory_prediction_length=self.trajectory_prediction_length,
num_joints=self.hyper_params["num_joints"],
hidden_dim=self.hyper_params["hidden_dim"],
use_action_history=self.hyper_params["use_action_history"],
num_action_history_encoder_layers=self.hyper_params["num_action_history_encoder_layers"],
max_action_context_length=self.hyper_params["action_context_length"],
use_imu=self.hyper_params["use_imu"],
imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod(
self.hyper_params["imu_orientation_embedding_method"]
),
num_imu_encoder_layers=self.hyper_params["num_imu_encoder_layers"],
imu_context_length=self.hyper_params["imu_context_length"],
use_joint_states=self.hyper_params["use_joint_states"],
joint_state_encoder_layers=self.hyper_params["joint_state_encoder_layers"],
joint_state_context_length=self.hyper_params["joint_state_context_length"],
use_images=self.hyper_params["use_images"],
image_sequence_encoder_type=SequenceEncoderType(self.hyper_params["image_sequence_encoder_type"]),
image_encoder_type=ImageEncoderType(self.hyper_params["image_encoder_type"]),
num_image_sequence_encoder_layers=self.hyper_params["num_image_sequence_encoder_layers"],
image_context_length=self.hyper_params["image_context_length"],
num_decoder_layers=self.hyper_params["num_decoder_layers"],
trajectory_prediction_length=self.hyper_params["trajectory_prediction_length"],
).to(device)

self.og_model = self.model

self.normalizer = Normalizer(self.model.mean, self.model.std)
self.model = EMA(self.model)
self.model.load_state_dict(torch.load(checkpoint, weights_only=True))
self.model.load_state_dict(checkpoint["model_state_dict"])
self.model.eval()
print(self.normalizer.mean)

# Create diffusion noise scheduler
self.get_logger().info("Create diffusion noise scheduler")
self.scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
self.scheduler.config["num_train_timesteps"] = train_denoising_timesteps
self.scheduler.config["num_train_timesteps"] = self.hyper_params["train_denoising_timesteps"]
self.scheduler.set_timesteps(self.inference_denosing_timesteps)

# Create control timer to run inference at a fixed rate
interval = 1 / self.sample_rate * self.trajectory_prediction_length
interval = 1 / self.sample_rate * (self.hyper_params["trajectory_prediction_length"])
# We want to run the inference in a separate thread to not block the callbacks, but we also want to make sure
# that the inference is not running multiple times in parallel
self.create_timer(interval, self.step, callback_group=MutuallyExclusiveCallbackGroup())
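
The model hyperparameters and weights are now read from a single checkpoint dict (checkpoint["hyperparams"] and checkpoint["model_state_dict"]). A hypothetical sketch of the save-side counterpart, which is not part of this commit; the stand-in module and the hyperparameter values are illustrative only:

import torch

model = torch.nn.Linear(4, 4)  # stand-in for End2EndDiffusionTransformer
hyperparams = {
    "num_joints": 20,
    "hidden_dim": 256,
    "trajectory_prediction_length": 10,
    "train_denoising_timesteps": 1000,
}
torch.save(
    {"hyperparams": hyperparams, "model_state_dict": model.state_dict()},
    "example_checkpoint.pth",
)

# Loading mirrors ros.py: dicts of primitives and tensors stay compatible
# with torch.load(..., weights_only=True).
checkpoint = torch.load("example_checkpoint.pth", weights_only=True)
assert checkpoint["hyperparams"]["num_joints"] == 20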
@@ -217,10 +213,10 @@ def update_buffers(self):
)

# Remove the oldest data from the buffers
self.joint_state_data = self.joint_state_data[-self.joint_state_context_length :]
self.image_embeddings = self.image_embeddings[-self.image_context_length :]
self.imu_data = self.imu_data[-self.imu_context_length :]
self.joint_command_data = self.joint_command_data[-self.action_context_length :]
self.joint_state_data = self.joint_state_data[-self.hyper_params["joint_state_context_length"] :]
self.image_embeddings = self.image_embeddings[-self.hyper_params["image_context_length"] :]
self.imu_data = self.imu_data[-self.hyper_params["imu_context_length"] :]
self.joint_command_data = self.joint_command_data[-self.hyper_params["action_context_length"] :]

@profile
def step(self):
Expand All @@ -239,28 +235,39 @@ def step(self):
% (2 * np.pi), # torch.stack(list(self.joint_command_data), dim=0).unsqueeze(0).to(device),
}

print("Batch: ", batch["image_data"].shape)

# Perform the denoising process
trajectory = torch.randn(1, self.trajectory_prediction_length, self.num_joints).to(device)
trajectory = torch.randn(
1, self.hyper_params["trajectory_prediction_length"], self.hyper_params["num_joints"]
).to(device)

start_ros_time = self.get_clock().now()

## Perform the embedding of the conditioning
start = time.time()
embedded_input = self.og_model.encode_input_data(batch)
embedded_input = self.model.encode_input_data(batch)
print("Time for embedding: ", time.time() - start)

# Denoise the trajectory
start = time.time()
self.scheduler.set_timesteps(self.inference_denosing_timesteps)
for t in self.scheduler.timesteps:
with torch.no_grad():
# Predict the noise residual
noise_pred = self.og_model.forward_with_context(
embedded_input, trajectory, torch.tensor([t], device=device)
)

# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample
if self.hyper_params.get("distilled_decoder", False):
# Directly predict the trajectory based on the noise
with torch.no_grad():
trajectory = self.model.forward_with_context(embedded_input, trajectory, torch.tensor([0], device=device))
else:
# Perform the denoising process
self.scheduler.set_timesteps(self.inference_denosing_timesteps)
for t in self.scheduler.timesteps:
with torch.no_grad():
# Predict the noise residual
noise_pred = self.model.forward_with_context(
embedded_input, trajectory, torch.tensor([t], device=device)
)

# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample

print("Time for forward: ", time.time() - start)

@@ -272,7 +279,7 @@ def step(self):
trajectory_msg.header.stamp = Time.to_msg(start_ros_time)
trajectory_msg.joint_names = JointStates.get_ordered_joint_names()
trajectory_msg.points = []
for i in range(self.trajectory_prediction_length):
for i in range(self.hyper_params["trajectory_prediction_length"]):
point = JointTrajectoryPoint()
point.positions = trajectory[0, i].cpu().numpy() - np.pi
point.time_from_start = Duration(nanoseconds=int(1e9 / self.sample_rate * i)).to_msg()
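
For reference, a standalone sketch of the two inference paths added in step(): a distilled decoder produces the trajectory in a single forward pass at timestep 0, while the default path iterates over DDIM timesteps. The stand-in model and the tensor shapes are hypothetical; only the scheduler usage mirrors the code above:

import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

def model(trajectory, t):
    # Stand-in for self.model.forward_with_context(embedded_input, trajectory, t)
    return torch.zeros_like(trajectory)

trajectory = torch.randn(1, 10, 20)  # (batch, trajectory_prediction_length, num_joints)
distilled_decoder = False            # hypothetical hyperparameter value

if distilled_decoder:
    # One-shot prediction of the trajectory from noise
    trajectory = model(trajectory, torch.tensor([0]))
else:
    # Iterative denoising with DDIM
    scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
    scheduler.set_timesteps(30)  # inference_denosing_timesteps in the node
    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_pred = model(trajectory, torch.tensor([t]))
        trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample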
