From f4e74b002cfb82db55844279c23447c26c7d9b14 Mon Sep 17 00:00:00 2001
From: Florian Vahl <7vahl@informatik.uni-hamburg.de>
Date: Thu, 30 Jan 2025 18:15:03 +0100
Subject: [PATCH] Current WIP

---
 ddlitlab2024/ml/inference/plot.py |   2 +-
 ddlitlab2024/ml/inference/ros.py  | 131 ++++++++++++++++--------------
 2 files changed, 70 insertions(+), 63 deletions(-)

diff --git a/ddlitlab2024/ml/inference/plot.py b/ddlitlab2024/ml/inference/plot.py
index bc62bd8..10db21f 100644
--- a/ddlitlab2024/ml/inference/plot.py
+++ b/ddlitlab2024/ml/inference/plot.py
@@ -104,7 +104,7 @@
     noisy_trajectory = torch.randn_like(joint_targets).to(device)
     trajectory = noisy_trajectory
 
-    if params["distilled_decoder"]:
+    if params.get("distilled_decoder", False):
         # Directly predict the trajectory based on the noise
         with torch.no_grad():
             trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
diff --git a/ddlitlab2024/ml/inference/ros.py b/ddlitlab2024/ml/inference/ros.py
index b52406b..8cdf1f8 100644
--- a/ddlitlab2024/ml/inference/ros.py
+++ b/ddlitlab2024/ml/inference/ros.py
@@ -11,7 +11,6 @@
 from bitbots_tf_buffer import Buffer
 from cv_bridge import CvBridge
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from ema_pytorch import EMA
 from game_controller_hl_interfaces.msg import GameState
 from profilehooks import profile
 from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
@@ -42,21 +41,18 @@ def __init__(self, node_name, context):
             [rclpy.parameter.Parameter("use_sim_time", rclpy.Parameter.Type.BOOL, True)],
         )
 
+        checkpoint_path = (
+            "../training/destilled_trajectory_transformer_model_first_train_20_epoch_hyp.pth"
+            # "../training/trajectory_transformer_model_500_epoch_xmas_hyp.pth"
+        )
+        self.inference_denosing_timesteps = 30
+
         # Params
         self.sample_rate = DEFAULT_RESAMPLE_RATE_HZ
-        hidden_dim = 256
-        self.action_context_length = 100
-        self.trajectory_prediction_length = 10
-        train_denoising_timesteps = 1000
-        self.inference_denosing_timesteps = 10
-        self.image_context_length = 10
-        self.imu_context_length = 100
-        self.joint_state_context_length = 100
-        self.num_joints = 20
-        checkpoint = (
-            "/home/florian/ddlitlab/ddlitlab_repo/ddlitlab2024/ml/training/"
-            "trajectory_transformer_model_500_epoch_xmas.pth"
-        )
+
+        # Load the hyperparameters from the checkpoint
+        self.get_logger().info(f"Loading checkpoint '{checkpoint_path}'")
+        checkpoint = torch.load(checkpoint_path, weights_only=True)
+        self.hyper_params = checkpoint["hyperparams"]
 
         # Subscribe to all the input topics
         self.joint_state_sub = self.create_subscription(JointState, "/joint_states", self.joint_state_callback, 10)
@@ -89,12 +85,14 @@ def __init__(self, node_name, context):
         self.latest_game_state = None
 
         # Add default values to the buffers
-        self.image_embeddings = [torch.randn(3, 480, 480)] * self.image_context_length
-        self.imu_data = [torch.randn(4)] * self.imu_context_length
-        self.joint_state_data = [
-            torch.randn(len(JointStates.get_ordered_joint_names()))
-        ] * self.joint_state_context_length
-        self.joint_command_data = [torch.randn(self.num_joints)] * self.action_context_length
+        self.image_embeddings = [torch.randn(3, 480, 480)] * self.hyper_params["image_context_length"]
+        self.imu_data = [torch.randn(4)] * self.hyper_params["imu_context_length"]
+        self.joint_state_data = [torch.randn(len(JointStates.get_ordered_joint_names()))] * self.hyper_params[
+            "joint_state_context_length"
+        ]
+        self.joint_command_data = [torch.randn(self.hyper_params["num_joints"])] * self.hyper_params[
"action_context_length" + ] self.data_lock = Lock() @@ -106,43 +104,41 @@ def __init__(self, node_name, context): # Load model self.get_logger().info("Load model") self.model = End2EndDiffusionTransformer( - num_joints=self.num_joints, - hidden_dim=hidden_dim, - use_action_history=True, - num_action_history_encoder_layers=2, - max_action_context_length=self.action_context_length, - use_imu=True, - imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION, - num_imu_encoder_layers=2, - max_imu_context_length=self.imu_context_length, - use_joint_states=True, - joint_state_encoder_layers=2, - max_joint_state_context_length=self.joint_state_context_length, - use_images=True, - image_sequence_encoder_type=SequenceEncoderType.TRANSFORMER, - image_encoder_type=ImageEncoderType.RESNET18, - num_image_sequence_encoder_layers=1, - max_image_context_length=self.image_context_length, - num_decoder_layers=4, - trajectory_prediction_length=self.trajectory_prediction_length, + num_joints=self.hyper_params["num_joints"], + hidden_dim=self.hyper_params["hidden_dim"], + use_action_history=self.hyper_params["use_action_history"], + num_action_history_encoder_layers=self.hyper_params["num_action_history_encoder_layers"], + max_action_context_length=self.hyper_params["action_context_length"], + use_imu=self.hyper_params["use_imu"], + imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod( + self.hyper_params["imu_orientation_embedding_method"] + ), + num_imu_encoder_layers=self.hyper_params["num_imu_encoder_layers"], + imu_context_length=self.hyper_params["imu_context_length"], + use_joint_states=self.hyper_params["use_joint_states"], + joint_state_encoder_layers=self.hyper_params["joint_state_encoder_layers"], + joint_state_context_length=self.hyper_params["joint_state_context_length"], + use_images=self.hyper_params["use_images"], + image_sequence_encoder_type=SequenceEncoderType(self.hyper_params["image_sequence_encoder_type"]), + image_encoder_type=ImageEncoderType(self.hyper_params["image_encoder_type"]), + num_image_sequence_encoder_layers=self.hyper_params["num_image_sequence_encoder_layers"], + image_context_length=self.hyper_params["image_context_length"], + num_decoder_layers=self.hyper_params["num_decoder_layers"], + trajectory_prediction_length=self.hyper_params["trajectory_prediction_length"], ).to(device) - - self.og_model = self.model - self.normalizer = Normalizer(self.model.mean, self.model.std) - self.model = EMA(self.model) - self.model.load_state_dict(torch.load(checkpoint, weights_only=True)) + self.model.load_state_dict(checkpoint["model_state_dict"]) self.model.eval() print(self.normalizer.mean) # Create diffusion noise scheduler self.get_logger().info("Create diffusion noise scheduler") self.scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False) - self.scheduler.config["num_train_timesteps"] = train_denoising_timesteps + self.scheduler.config["num_train_timesteps"] = self.hyper_params["train_denoising_timesteps"] self.scheduler.set_timesteps(self.inference_denosing_timesteps) # Create control timer to run inference at a fixed rate - interval = 1 / self.sample_rate * self.trajectory_prediction_length + interval = 1 / self.sample_rate * (self.hyper_params["trajectory_prediction_length"]) # We want to run the inference in a separate thread to not block the callbacks, but we also want to make sure # that the inference is not running multiple times in parallel self.create_timer(interval, self.step, 
         self.create_timer(interval, self.step, callback_group=MutuallyExclusiveCallbackGroup())
 
@@ -217,10 +213,10 @@ def update_buffers(self):
         )
 
         # Remove the oldest data from the buffers
-        self.joint_state_data = self.joint_state_data[-self.joint_state_context_length :]
-        self.image_embeddings = self.image_embeddings[-self.image_context_length :]
-        self.imu_data = self.imu_data[-self.imu_context_length :]
-        self.joint_command_data = self.joint_command_data[-self.action_context_length :]
+        self.joint_state_data = self.joint_state_data[-self.hyper_params["joint_state_context_length"] :]
+        self.image_embeddings = self.image_embeddings[-self.hyper_params["image_context_length"] :]
+        self.imu_data = self.imu_data[-self.hyper_params["imu_context_length"] :]
+        self.joint_command_data = self.joint_command_data[-self.hyper_params["action_context_length"] :]
 
     @profile
     def step(self):
@@ -239,28 +235,39 @@ def step(self):
             % (2 * np.pi),
             # torch.stack(list(self.joint_command_data), dim=0).unsqueeze(0).to(device),
         }
 
+        print("Batch: ", batch["image_data"].shape)
+
         # Perform the denoising process
-        trajectory = torch.randn(1, self.trajectory_prediction_length, self.num_joints).to(device)
+        trajectory = torch.randn(
+            1, self.hyper_params["trajectory_prediction_length"], self.hyper_params["num_joints"]
+        ).to(device)
 
         start_ros_time = self.get_clock().now()
 
         ## Perform the embedding of the conditioning
         start = time.time()
-        embedded_input = self.og_model.encode_input_data(batch)
+        embedded_input = self.model.encode_input_data(batch)
         print("Time for embedding: ", time.time() - start)
 
         # Denoise the trajectory
         start = time.time()
-        self.scheduler.set_timesteps(self.inference_denosing_timesteps)
-        for t in self.scheduler.timesteps:
-            with torch.no_grad():
-                # Predict the noise residual
-                noise_pred = self.og_model.forward_with_context(
-                    embedded_input, trajectory, torch.tensor([t], device=device)
-                )
-                # Update the trajectory based on the predicted noise and the current step of the denoising process
-                trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample
+        if self.hyper_params.get("distilled_decoder", False):
+            # Directly predict the trajectory based on the noise
+            with torch.no_grad():
+                trajectory = self.model.forward_with_context(embedded_input, trajectory, torch.tensor([0], device=device))
+        else:
+            # Perform the denoising process
+            self.scheduler.set_timesteps(self.inference_denosing_timesteps)
+            for t in self.scheduler.timesteps:
+                with torch.no_grad():
+                    # Predict the noise residual
+                    noise_pred = self.model.forward_with_context(
+                        embedded_input, trajectory, torch.tensor([t], device=device)
+                    )
+
+                # Update the trajectory based on the predicted noise and the current step of the denoising process
+                trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample
         print("Time for forward: ", time.time() - start)
 
@@ -272,7 +279,7 @@ def step(self):
         trajectory_msg.header.stamp = Time.to_msg(start_ros_time)
         trajectory_msg.joint_names = JointStates.get_ordered_joint_names()
         trajectory_msg.points = []
-        for i in range(self.trajectory_prediction_length):
+        for i in range(self.hyper_params["trajectory_prediction_length"]):
             point = JointTrajectoryPoint()
             point.positions = trajectory[0, i].cpu().numpy() - np.pi
             point.time_from_start = Duration(nanoseconds=int(1e9 / self.sample_rate * i)).to_msg()
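
Reviewer note: the patched __init__ reads both the weights and the hyperparameters from a single checkpoint dict. Below is a minimal sketch of the layout that loading code appears to expect, assuming only the two keys the patch actually accesses ("hyperparams" and "model_state_dict"); the stand-in model, file name, and example hyperparameter values are illustrative, not the project's real training configuration.

import torch
from torch import nn

# Stand-in for End2EndDiffusionTransformer; only needed to produce a state_dict.
model = nn.Linear(4, 4)

# Example hyperparameters; the real training script defines the full set that
# ros.py indexes (context lengths, encoder layer counts, distilled_decoder, ...).
hyper_params = {
    "num_joints": 20,
    "trajectory_prediction_length": 10,
    "train_denoising_timesteps": 1000,
    "distilled_decoder": False,
}

# Bundle hyperparameters and weights into one file (the *_hyp.pth naming above).
torch.save(
    {"hyperparams": hyper_params, "model_state_dict": model.state_dict()},
    "example_checkpoint_hyp.pth",
)

# Loading mirrors what the patched __init__ does.
checkpoint = torch.load("example_checkpoint_hyp.pth", weights_only=True)
hyper_params = checkpoint["hyperparams"]
model.load_state_dict(checkpoint["model_state_dict"])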
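
The step() change switches between a distilled decoder (a single forward pass at timestep 0) and the iterative DDIM denoising loop. The following self-contained sketch only illustrates that control flow: denoise_step() is a hypothetical stand-in for model.forward_with_context(), and the tensor shapes are placeholders; the scheduler calls follow the diffusers API used in the patch.

import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler


def denoise_step(trajectory: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Placeholder for the conditioned transformer call; returns a tensor of the
    # same shape as the trajectory.
    return torch.zeros_like(trajectory)


def infer(distilled_decoder: bool, num_inference_steps: int = 30) -> torch.Tensor:
    # (batch, trajectory_prediction_length, num_joints); values are placeholders.
    trajectory = torch.randn(1, 10, 20)
    if distilled_decoder:
        # Distilled model: map noise to a trajectory in one call at t=0.
        with torch.no_grad():
            trajectory = denoise_step(trajectory, torch.tensor([0]))
    else:
        # Original diffusion model: iterative DDIM denoising.
        scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            with torch.no_grad():
                noise_pred = denoise_step(trajectory, torch.tensor([t]))
            # Use the predicted noise to take one denoising step.
            trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
    return trajectory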