diff --git a/ddlitlab2024/dataset/vizualization.ipynb b/ddlitlab2024/dataset/vizualization.ipynb index 8926670..97ef3b9 100644 --- a/ddlitlab2024/dataset/vizualization.ipynb +++ b/ddlitlab2024/dataset/vizualization.ipynb @@ -6,11 +6,11 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", "import sqlite3\n", "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", "con = sqlite3.connect(\"./db.sqlite3\")" ] }, diff --git a/ddlitlab2024/ml/inference/plot.py b/ddlitlab2024/ml/inference/plot.py index 99411ce..f47b938 100644 --- a/ddlitlab2024/ml/inference/plot.py +++ b/ddlitlab2024/ml/inference/plot.py @@ -1,3 +1,4 @@ +import argparse from dataclasses import asdict import matplotlib.pyplot as plt @@ -5,7 +6,6 @@ import torch import torch.nn.functional as F # noqa from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from ema_pytorch import EMA from torch.utils.data import DataLoader from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset, Normalizer, worker_init_fn @@ -22,73 +22,67 @@ logger.info("Starting") logger.info(f"Using device {device}") - hidden_dim = 256 - num_layers = 4 - num_heads = 4 - action_context_length = 100 - trajectory_prediction_length = 10 - batch_size = 1 - lr = 1e-4 - train_denoising_timesteps = 1000 - image_context_length = 10 - action_context_length = 100 - imu_context_length = 100 - joint_state_context_length = 100 - num_normalization_samples = 50 - num_joints = 20 - checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/trajectory_transformer_model.pth" + # Parse the command line arguments + parser = argparse.ArgumentParser(description="Inference Plot") + parser.add_argument("checkpoint", type=str, help="Path to the checkpoint to load") + parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps") + parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate") + args = parser.parse_args() + + # Load the hyperparameters from the checkpoint + logger.info(f"Loading checkpoint '{args.checkpoint}'") + checkpoint = torch.load(args.checkpoint, weights_only=True) + params = checkpoint["hyperparams"] logger.info("Load model") model = End2EndDiffusionTransformer( - num_joints=20, - hidden_dim=hidden_dim, - use_action_history=True, - num_action_history_encoder_layers=2, - max_action_context_length=action_context_length, - use_imu=True, - imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION, - num_imu_encoder_layers=2, - max_imu_context_length=imu_context_length, - use_joint_states=True, - joint_state_encoder_layers=2, - max_joint_state_context_length=joint_state_context_length, - use_images=True, - image_sequence_encoder_type=SequenceEncoderType.TRANSFORMER, - image_encoder_type=ImageEncoderType.RESNET18, - num_image_sequence_encoder_layers=1, - max_image_context_length=image_context_length, - num_decoder_layers=4, - trajectory_prediction_length=trajectory_prediction_length, + num_joints=params["num_joints"], + hidden_dim=params["hidden_dim"], + use_action_history=params["use_action_history"], + num_action_history_encoder_layers=params["num_action_history_encoder_layers"], + max_action_context_length=params["action_context_length"], + use_imu=params["use_imu"], + imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod( + params["imu_orientation_embedding_method"] + ), + num_imu_encoder_layers=params["num_imu_encoder_layers"], + 
imu_context_length=params["imu_context_length"], + use_joint_states=params["use_joint_states"], + joint_state_encoder_layers=params["joint_state_encoder_layers"], + joint_state_context_length=params["joint_state_context_length"], + use_images=params["use_images"], + image_sequence_encoder_type=SequenceEncoderType(params["image_sequence_encoder_type"]), + image_encoder_type=ImageEncoderType(params["image_encoder_type"]), + num_image_sequence_encoder_layers=params["num_image_sequence_encoder_layers"], + image_context_length=params["image_context_length"], + num_decoder_layers=params["num_decoder_layers"], + trajectory_prediction_length=params["trajectory_prediction_length"], ).to(device) normalizer = Normalizer(model.mean, model.std) - model = EMA(model) - model.load_state_dict(torch.load(checkpoint, weights_only=True)) + model.load_state_dict(checkpoint["model_state_dict"]) model.eval() print(normalizer.mean) - num_samples = 10 - inference_denosing_timesteps = 30 - # Create diffusion noise scheduler scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False) - scheduler.config["num_train_timesteps"] = train_denoising_timesteps - scheduler.set_timesteps(inference_denosing_timesteps) + scheduler.config["num_train_timesteps"] = params["train_denoising_timesteps"] + scheduler.set_timesteps(args.steps) # Create Dataset object dataset = DDLITLab2024Dataset( - num_joints=num_joints, - num_frames_video=image_context_length, - num_samples_joint_trajectory_future=trajectory_prediction_length, - num_samples_joint_trajectory=action_context_length, - num_samples_imu=imu_context_length, - num_samples_joint_states=joint_state_context_length, + num_joints=params["num_joints"], + num_frames_video=params["image_context_length"], + num_samples_joint_trajectory_future=params["trajectory_prediction_length"], + num_samples_joint_trajectory=params["action_context_length"], + num_samples_imu=params["imu_context_length"], + num_samples_joint_states=params["joint_state_context_length"], ) # Create DataLoader object num_workers = 5 dataloader = DataLoader( dataset, - batch_size=batch_size, + batch_size=1, shuffle=True, collate_fn=DDLITLab2024Dataset.collate_fn, persistent_workers=num_workers > 1, @@ -98,7 +92,7 @@ dataloader = iter(dataloader) - for _ in range(num_samples): + for _ in range(args.num_samples): batch = next(dataloader) # Move the data to the device batch = {k: v.to(device) for k, v in asdict(batch).items()} @@ -111,7 +105,7 @@ trajectory = noisy_trajectory # Perform the denoising process - scheduler.set_timesteps(inference_denosing_timesteps) + scheduler.set_timesteps(args.steps) for t in scheduler.timesteps: with torch.no_grad(): # Predict the noise residual @@ -121,30 +115,35 @@ trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample # Undo the normalization - print(normalizer.mean) trajectory = normalizer.denormalize(trajectory) noisy_trajectory = normalizer.denormalize(noisy_trajectory) # Plot the trajectory context, the noisy trajectory, the denoised trajectory # and the target trajectory for each joint plt.figure(figsize=(10, 10)) - for j in range(num_joints): + for j in range(params["num_joints"]): plt.subplot(5, 4, j + 1) joint_command_context = batch["joint_command_history"][0, :, j].cpu().numpy() plt.plot(np.arange(len(joint_command_context)), joint_command_context, label="Context") plt.plot( - np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), + np.arange( + len(joint_command_context), len(joint_command_context) + 
params["trajectory_prediction_length"] + ), noisy_trajectory[0, :, j].cpu().numpy(), label="Noisy Trajectory", ) plt.plot( - np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), + np.arange( + len(joint_command_context), len(joint_command_context) + params["trajectory_prediction_length"] + ), joint_targets[0, :, j].cpu().numpy(), label="Target Trajectory", ) plt.plot( - np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), + np.arange( + len(joint_command_context), len(joint_command_context) + params["trajectory_prediction_length"] + ), trajectory[0, :, j].cpu().numpy(), label="Denoised Trajectory", ) diff --git a/ddlitlab2024/ml/model/model.py b/ddlitlab2024/ml/model/model.py index 3a48f96..19b85d0 100644 --- a/ddlitlab2024/ml/model/model.py +++ b/ddlitlab2024/ml/model/model.py @@ -19,15 +19,15 @@ def __init__( use_imu: bool, imu_orientation_embedding_method: IMUEncoder.OrientationEmbeddingMethod, num_imu_encoder_layers: int, - max_imu_context_length: int, + imu_context_length: int, use_joint_states: bool, joint_state_encoder_layers: int, - max_joint_state_context_length: int, + joint_state_context_length: int, use_images: bool, image_encoder_type: ImageEncoderType, image_sequence_encoder_type: SequenceEncoderType, num_image_sequence_encoder_layers: int, - max_image_context_length: int, + image_context_length: int, num_decoder_layers: int = 4, trajectory_prediction_length: int = 8, ): @@ -58,7 +58,7 @@ def __init__( hidden_dim=hidden_dim, num_layers=num_imu_encoder_layers, num_heads=4, - max_seq_len=max_imu_context_length, + max_seq_len=imu_context_length, ) if use_imu else None @@ -71,7 +71,7 @@ def __init__( hidden_dim=hidden_dim, num_layers=joint_state_encoder_layers, num_heads=4, - max_seq_len=max_joint_state_context_length, + max_seq_len=joint_state_context_length, ) if use_joint_states else None @@ -84,7 +84,7 @@ def __init__( image_encoder_type=image_encoder_type, hidden_dim=hidden_dim, num_layers=num_image_sequence_encoder_layers, - max_seq_len=max_image_context_length, + max_seq_len=image_context_length, ) if use_images else None diff --git a/ddlitlab2024/ml/training/config/default.yaml b/ddlitlab2024/ml/training/config/default.yaml new file mode 100644 index 0000000..f38f5b0 --- /dev/null +++ b/ddlitlab2024/ml/training/config/default.yaml @@ -0,0 +1,25 @@ +hidden_dim: 256 +action_context_length: 100 +trajectory_prediction_length: 10 +epochs: 500 +batch_size: 16 +lr: 1.e-4 +train_denoising_timesteps: 1000 +image_context_length: 10 +imu_context_length: 100 +num_imu_encoder_layers: 2 +joint_state_context_length: 100 +num_normalization_samples: 1000 +num_joints: 20 +use_action_history: True +num_action_history_encoder_layers: 2 +use_imu: True +imu_orientation_embedding_method: "quaternion" +num_imu_encoder_layers: 2 +use_joint_states: True +joint_state_encoder_layers: 2 +use_images: True +image_sequence_encoder_type: "transformer" +image_encoder_type: "resnet18" +num_image_sequence_encoder_layers: 1 +num_decoder_layers: 4 \ No newline at end of file diff --git a/ddlitlab2024/ml/training/train.py b/ddlitlab2024/ml/training/train.py index 7b5a840..eb98328 100644 --- a/ddlitlab2024/ml/training/train.py +++ b/ddlitlab2024/ml/training/train.py @@ -1,11 +1,12 @@ +import argparse from dataclasses import asdict from functools import partial import numpy as np import torch import torch.nn.functional as F # noqa +import yaml from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from ema_pytorch 
import EMA from torch.utils.data import DataLoader from tqdm import tqdm @@ -25,38 +26,61 @@ logger.info("Starting training") logger.info(f"Using device {device}") # TODO wandb - # Define hyperparameters # TODO proper configuration - hidden_dim = 256 - num_layers = 4 - num_heads = 4 - action_context_length = 100 - trajectory_prediction_length = 10 - epochs = 500 - batch_size = 16 - lr = 1e-4 - train_denoising_timesteps = 1000 - image_context_length = 10 - action_context_length = 100 - imu_context_length = 100 - joint_state_context_length = 100 - num_normalization_samples = 1000 - num_joints = 20 - checkpoint: str | None = None + + # Parse the command line arguments + parser = argparse.ArgumentParser(description="Train the model") + parser.add_argument("--config", "-c", type=str, default=None, help="Path to the configuration file") + parser.add_argument("--checkpoint", "-p", type=str, default=None, help="Path to the checkpoint to load") + parser.add_argument( + "--output", "-o", type=str, default="trajectory_transformer_model.pth", help="Path to save the model" + ) + args = parser.parse_args() + + assert ( + args.config is not None or args.checkpoint is not None + ), "Either a configuration file or a checkpoint must be provided" + + # Load the hyperparameters from the checkpoint + if args.checkpoint is not None: + logger.info(f"Loading checkpoint '{args.checkpoint}'") + checkpoint = torch.load(args.checkpoint, weights_only=True) + params = checkpoint["hyperparams"] + + # Load the hyperparameters from the configuration file + if args.config is not None: + logger.info(f"Loading configuration file '{args.config}'") + with open(args.config) as file: + config_params = yaml.safe_load(file) + + if args.checkpoint is not None: + logger.warning( + "Both a configuration file and a checkpoint are provided. " + "The configuration file will be used for the hyperparameters." 
+ ) + # Print the differences between the checkpoint and the configuration file + for key, value in config_params.items(): + if key not in params: + logger.warning(f"Key '{key}' is not present in the checkpoint") + elif value != params[key]: + logger.warning(f"Key '{key}' has a different value in the checkpoint: {params[key]} != {value}") + + # Now we are ready to use the configuration file + params = config_params # Load the dataset logger.info("Create dataset objects") dataset = DDLITLab2024Dataset( - num_joints=num_joints, - num_frames_video=image_context_length, - num_samples_joint_trajectory_future=trajectory_prediction_length, - num_samples_joint_trajectory=action_context_length, - num_samples_imu=imu_context_length, - num_samples_joint_states=joint_state_context_length, + num_joints=params["num_joints"], + num_frames_video=params["image_context_length"], + num_samples_joint_trajectory_future=params["trajectory_prediction_length"], + num_samples_joint_trajectory=params["action_context_length"], + num_samples_imu=params["imu_context_length"], + num_samples_joint_states=params["joint_state_context_length"], ) num_workers = 5 dataloader = DataLoader( dataset, - batch_size=batch_size, + batch_size=params["batch_size"], shuffle=True, collate_fn=DDLITLab2024Dataset.collate_fn, persistent_workers=num_workers > 1, @@ -67,31 +91,33 @@ # Get some samples to estimate the mean and std logger.info("Estimating normalization parameters") - random_indices = np.random.randint(0, len(dataset), (num_normalization_samples,)) + random_indices = np.random.randint(0, len(dataset), (params["num_normalization_samples"],)) normalization_samples = torch.cat([dataset[i].joint_command_history for i in tqdm(random_indices)], dim=0) normalizer = Normalizer.fit(normalization_samples.to(device)) # Initialize the Transformer model and optimizer, and move model to device - model = End2EndDiffusionTransformer( # TODO enforce all params to be consistent with the dataset - num_joints=num_joints, - hidden_dim=hidden_dim, - use_action_history=True, - num_action_history_encoder_layers=2, - max_action_context_length=action_context_length, - use_imu=True, - imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION, - num_imu_encoder_layers=2, - max_imu_context_length=imu_context_length, - use_joint_states=True, - joint_state_encoder_layers=2, - max_joint_state_context_length=joint_state_context_length, - use_images=True, - image_sequence_encoder_type=SequenceEncoderType.TRANSFORMER, - image_encoder_type=ImageEncoderType.RESNET18, - num_image_sequence_encoder_layers=1, - max_image_context_length=image_context_length, - num_decoder_layers=4, - trajectory_prediction_length=trajectory_prediction_length, + model = End2EndDiffusionTransformer( + num_joints=params["num_joints"], + hidden_dim=params["hidden_dim"], + use_action_history=params["use_action_history"], + num_action_history_encoder_layers=params["num_action_history_encoder_layers"], + max_action_context_length=params["action_context_length"], + use_imu=params["use_imu"], + imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod( + params["imu_orientation_embedding_method"] + ), + num_imu_encoder_layers=params["num_imu_encoder_layers"], + imu_context_length=params["imu_context_length"], + use_joint_states=params["use_joint_states"], + joint_state_encoder_layers=params["joint_state_encoder_layers"], + joint_state_context_length=params["joint_state_context_length"], + use_images=params["use_images"], + 
image_sequence_encoder_type=SequenceEncoderType(params["image_sequence_encoder_type"]), + image_encoder_type=ImageEncoderType(params["image_encoder_type"]), + num_image_sequence_encoder_layers=params["num_image_sequence_encoder_layers"], + image_context_length=params["image_context_length"], + num_decoder_layers=params["num_decoder_layers"], + trajectory_prediction_length=params["trajectory_prediction_length"], ).to(device) # Add normalization parameters to the model @@ -100,24 +126,40 @@ logger.info(f"Normalization values:\nJoint mean: {normalizer.mean}\nJoint std: {normalizer.std}") assert all(model.std != 0), "Normalization std is zero, this makes no sense. Some joints are constant." - # Utilize an Exponential Moving Average (EMA) for the model to smooth out the training process - ema = EMA(model, beta=0.999) - # Load the model if a checkpoint is provided - if checkpoint is not None: - logger.info(f"Loading model from {checkpoint}") - ema.load_state_dict(torch.load(checkpoint, weights_only=True)) + if args.checkpoint is not None: + logger.info("Loading model from checkpoint") + model.load_state_dict(checkpoint["model_state_dict"]) # Create optimizer and learning rate scheduler - optimizer = torch.optim.AdamW(model.parameters(), lr=lr) - lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=epochs * len(dataloader)) + optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"]) + + # Load the optimizer state if a checkpoint is provided + if args.checkpoint is not None: + if "optimizer_state_dict" in checkpoint: + logger.info("Loading optimizer state from checkpoint") + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + else: + logger.warning("No optimizer state found in the checkpoint") + + lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, max_lr=params["lr"], total_steps=params["epochs"] * len(dataloader) + ) + + # Load the learning rate scheduler state if a checkpoint is provided + if args.checkpoint is not None: + if "lr_scheduler_state_dict" in checkpoint: + logger.info("Loading learning rate scheduler state from checkpoint") + lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) + else: + logger.warning("No learning rate scheduler state found in the checkpoint") # Create diffusion noise scheduler scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False) - scheduler.config["num_train_timesteps"] = train_denoising_timesteps + scheduler.config["num_train_timesteps"] = params["train_denoising_timesteps"] # Training loop - for epoch in range(epochs): + for epoch in range(params["epochs"]): mean_loss = 0 # Iterate over the dataset @@ -157,11 +199,17 @@ loss.backward() optimizer.step() lr_scheduler.step() - ema.update() pbar.set_postfix_str( f"Epoch {epoch}, Loss: {mean_loss / (i + 1):.05f}, LR: {lr_scheduler.get_last_lr()[0]:0.7f}" ) # Save the model - torch.save(ema.state_dict(), "trajectory_transformer_model.pth") + checkpoint = { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "lr_scheduler_state_dict": lr_scheduler.state_dict(), + "hyperparams": params, + "current_epoch": epoch, + } + torch.save(checkpoint, args.output) diff --git a/ddlitlab2024/utils/embed_parameters.py b/ddlitlab2024/utils/embed_parameters.py new file mode 100644 index 0000000..ad65ea5 --- /dev/null +++ b/ddlitlab2024/utils/embed_parameters.py @@ -0,0 +1,62 @@ +import argparse + +import torch +import yaml +from ema_pytorch import EMA + +from ddlitlab2024.ml.model import 
End2EndDiffusionTransformer +from ddlitlab2024.ml.model.encoder.image import ImageEncoderType, SequenceEncoderType +from ddlitlab2024.ml.model.encoder.imu import IMUEncoder + +# This script embeds the parameters into the model itself + +if __name__ == "__main__": + # Get command line arguments + parser = argparse.ArgumentParser(description="Convert a legacy checkpoint to the new format") + parser.add_argument("checkpoint", type=str, help="Path to the checkpoint to load") + parser.add_argument("config", type=str, help="Path to the configuration file") + parser.add_argument("output", type=str, help="Path to save the model") + args = parser.parse_args() + + # Load the hyperparameters from training yaml file + with open(args.config) as file: + params = yaml.safe_load(file) + + # Initialize the Transformer model and optimizer, and move model to device + model = End2EndDiffusionTransformer( + num_joints=params["num_joints"], + hidden_dim=params["hidden_dim"], + use_action_history=params["use_action_history"], + num_action_history_encoder_layers=params["num_action_history_encoder_layers"], + max_action_context_length=params["action_context_length"], + use_imu=params["use_imu"], + imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod( + params["imu_orientation_embedding_method"] + ), + num_imu_encoder_layers=params["num_imu_encoder_layers"], + imu_context_length=params["imu_context_length"], + use_joint_states=params["use_joint_states"], + joint_state_encoder_layers=params["joint_state_encoder_layers"], + joint_state_context_length=params["joint_state_context_length"], + use_images=params["use_images"], + image_sequence_encoder_type=SequenceEncoderType(params["image_sequence_encoder_type"]), + image_encoder_type=ImageEncoderType(params["image_encoder_type"]), + num_image_sequence_encoder_layers=params["num_image_sequence_encoder_layers"], + image_context_length=params["image_context_length"], + num_decoder_layers=params["num_decoder_layers"], + trajectory_prediction_length=params["trajectory_prediction_length"], + ) + + ema_model = EMA(model, beta=0.999) + + # Load the model from the checkpoint + print(f"Loading model from {args.checkpoint}") + ema_model.load_state_dict(torch.load(args.checkpoint, weights_only=True)) + + # Save the checkpoint + print(f"Saving model to {args.output}") + checkpoint = { + "model_state_dict": model.state_dict(), + "hyperparams": params, + } + torch.save(checkpoint, args.output)
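Note on the new checkpoint format: the dictionary written at the end of train.py makes checkpoints self-describing, so plot.py and resumed training runs can rebuild the model without keeping a YAML file in sync with the weights. Below is a minimal sketch of consuming such a checkpoint, assuming it was produced by the updated train.py above and saved under the default output path:

    import torch

    # Checkpoints from the updated train.py are plain dictionaries holding the
    # state dicts plus the hyperparameters the model was trained with.
    ckpt = torch.load("trajectory_transformer_model.pth", weights_only=True)

    print(sorted(ckpt.keys()))
    # ['current_epoch', 'hyperparams', 'lr_scheduler_state_dict',
    #  'model_state_dict', 'optimizer_state_dict']

    # The embedded hyperparameters mirror ml/training/config/default.yaml, e.g.:
    params = ckpt["hyperparams"]
    print(params["num_joints"], params["trajectory_prediction_length"])

Legacy EMA checkpoints can be migrated with ddlitlab2024/utils/embed_parameters.py; converted files carry only "model_state_dict" and "hyperparams", so train.py will log warnings about the missing optimizer and learning-rate-scheduler states when resuming from them.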