Skip to content

Commit

Permalink
Merge pull request #53 from bit-bots/feature/proper_param_handling
Browse files Browse the repository at this point in the history
Proper hyperparameter handling
  • Loading branch information
Flova authored Jan 30, 2025
2 parents b422c44 + 3f047d4 commit d4fe368
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 123 deletions.
6 changes: 3 additions & 3 deletions ddlitlab2024/dataset/vizualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import sqlite3\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"con = sqlite3.connect(\"./db.sqlite3\")"
]
},
Expand Down
111 changes: 55 additions & 56 deletions ddlitlab2024/ml/inference/plot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
from dataclasses import asdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F # noqa
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from torch.utils.data import DataLoader

from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset, Normalizer, worker_init_fn
Expand All @@ -22,73 +22,67 @@
logger.info("Starting")
logger.info(f"Using device {device}")

hidden_dim = 256
num_layers = 4
num_heads = 4
action_context_length = 100
trajectory_prediction_length = 10
batch_size = 1
lr = 1e-4
train_denoising_timesteps = 1000
image_context_length = 10
action_context_length = 100
imu_context_length = 100
joint_state_context_length = 100
num_normalization_samples = 50
num_joints = 20
checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/trajectory_transformer_model.pth"
# Parse the command line arguments
parser = argparse.ArgumentParser(description="Inference Plot")
parser.add_argument("checkpoint", type=str, help="Path to the checkpoint to load")
parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps")
parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
args = parser.parse_args()

# Load the hyperparameters from the checkpoint
logger.info(f"Loading checkpoint '{args.checkpoint}'")
checkpoint = torch.load(args.checkpoint, weights_only=True)
params = checkpoint["hyperparams"]

logger.info("Load model")
model = End2EndDiffusionTransformer(
num_joints=20,
hidden_dim=hidden_dim,
use_action_history=True,
num_action_history_encoder_layers=2,
max_action_context_length=action_context_length,
use_imu=True,
imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION,
num_imu_encoder_layers=2,
max_imu_context_length=imu_context_length,
use_joint_states=True,
joint_state_encoder_layers=2,
max_joint_state_context_length=joint_state_context_length,
use_images=True,
image_sequence_encoder_type=SequenceEncoderType.TRANSFORMER,
image_encoder_type=ImageEncoderType.RESNET18,
num_image_sequence_encoder_layers=1,
max_image_context_length=image_context_length,
num_decoder_layers=4,
trajectory_prediction_length=trajectory_prediction_length,
num_joints=params["num_joints"],
hidden_dim=params["hidden_dim"],
use_action_history=params["use_action_history"],
num_action_history_encoder_layers=params["num_action_history_encoder_layers"],
max_action_context_length=params["action_context_length"],
use_imu=params["use_imu"],
imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod(
params["imu_orientation_embedding_method"]
),
num_imu_encoder_layers=params["num_imu_encoder_layers"],
imu_context_length=params["imu_context_length"],
use_joint_states=params["use_joint_states"],
joint_state_encoder_layers=params["joint_state_encoder_layers"],
joint_state_context_length=params["joint_state_context_length"],
use_images=params["use_images"],
image_sequence_encoder_type=SequenceEncoderType(params["image_sequence_encoder_type"]),
image_encoder_type=ImageEncoderType(params["image_encoder_type"]),
num_image_sequence_encoder_layers=params["num_image_sequence_encoder_layers"],
image_context_length=params["image_context_length"],
num_decoder_layers=params["num_decoder_layers"],
trajectory_prediction_length=params["trajectory_prediction_length"],
).to(device)
normalizer = Normalizer(model.mean, model.std)
model = EMA(model)
model.load_state_dict(torch.load(checkpoint, weights_only=True))
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(normalizer.mean)

num_samples = 10
inference_denosing_timesteps = 30

# Create diffusion noise scheduler
scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
scheduler.config["num_train_timesteps"] = train_denoising_timesteps
scheduler.set_timesteps(inference_denosing_timesteps)
scheduler.config["num_train_timesteps"] = params["train_denoising_timesteps"]
scheduler.set_timesteps(args.steps)

# Create Dataset object
dataset = DDLITLab2024Dataset(
num_joints=num_joints,
num_frames_video=image_context_length,
num_samples_joint_trajectory_future=trajectory_prediction_length,
num_samples_joint_trajectory=action_context_length,
num_samples_imu=imu_context_length,
num_samples_joint_states=joint_state_context_length,
num_joints=params["num_joints"],
num_frames_video=params["image_context_length"],
num_samples_joint_trajectory_future=params["trajectory_prediction_length"],
num_samples_joint_trajectory=params["action_context_length"],
num_samples_imu=params["imu_context_length"],
num_samples_joint_states=params["joint_state_context_length"],
)

# Create DataLoader object
num_workers = 5
dataloader = DataLoader(
dataset,
batch_size=batch_size,
batch_size=1,
shuffle=True,
collate_fn=DDLITLab2024Dataset.collate_fn,
persistent_workers=num_workers > 1,
Expand All @@ -98,7 +92,7 @@

dataloader = iter(dataloader)

for _ in range(num_samples):
for _ in range(args.num_samples):
batch = next(dataloader)
# Move the data to the device
batch = {k: v.to(device) for k, v in asdict(batch).items()}
Expand All @@ -111,7 +105,7 @@
trajectory = noisy_trajectory

# Perform the denoising process
scheduler.set_timesteps(inference_denosing_timesteps)
scheduler.set_timesteps(args.steps)
for t in scheduler.timesteps:
with torch.no_grad():
# Predict the noise residual
Expand All @@ -121,30 +115,35 @@
trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample

# Undo the normalization
print(normalizer.mean)
trajectory = normalizer.denormalize(trajectory)
noisy_trajectory = normalizer.denormalize(noisy_trajectory)

# Plot the trajectory context, the noisy trajectory, the denoised trajectory
# and the target trajectory for each joint
plt.figure(figsize=(10, 10))
for j in range(num_joints):
for j in range(params["num_joints"]):
plt.subplot(5, 4, j + 1)

joint_command_context = batch["joint_command_history"][0, :, j].cpu().numpy()
plt.plot(np.arange(len(joint_command_context)), joint_command_context, label="Context")
plt.plot(
np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length),
np.arange(
len(joint_command_context), len(joint_command_context) + params["trajectory_prediction_length"]
),
noisy_trajectory[0, :, j].cpu().numpy(),
label="Noisy Trajectory",
)
plt.plot(
np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length),
np.arange(
len(joint_command_context), len(joint_command_context) + params["trajectory_prediction_length"]
),
joint_targets[0, :, j].cpu().numpy(),
label="Target Trajectory",
)
plt.plot(
np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length),
np.arange(
len(joint_command_context), len(joint_command_context) + params["trajectory_prediction_length"]
),
trajectory[0, :, j].cpu().numpy(),
label="Denoised Trajectory",
)
Expand Down
12 changes: 6 additions & 6 deletions ddlitlab2024/ml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def __init__(
use_imu: bool,
imu_orientation_embedding_method: IMUEncoder.OrientationEmbeddingMethod,
num_imu_encoder_layers: int,
max_imu_context_length: int,
imu_context_length: int,
use_joint_states: bool,
joint_state_encoder_layers: int,
max_joint_state_context_length: int,
joint_state_context_length: int,
use_images: bool,
image_encoder_type: ImageEncoderType,
image_sequence_encoder_type: SequenceEncoderType,
num_image_sequence_encoder_layers: int,
max_image_context_length: int,
image_context_length: int,
num_decoder_layers: int = 4,
trajectory_prediction_length: int = 8,
):
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
hidden_dim=hidden_dim,
num_layers=num_imu_encoder_layers,
num_heads=4,
max_seq_len=max_imu_context_length,
max_seq_len=imu_context_length,
)
if use_imu
else None
Expand All @@ -71,7 +71,7 @@ def __init__(
hidden_dim=hidden_dim,
num_layers=joint_state_encoder_layers,
num_heads=4,
max_seq_len=max_joint_state_context_length,
max_seq_len=joint_state_context_length,
)
if use_joint_states
else None
Expand All @@ -84,7 +84,7 @@ def __init__(
image_encoder_type=image_encoder_type,
hidden_dim=hidden_dim,
num_layers=num_image_sequence_encoder_layers,
max_seq_len=max_image_context_length,
max_seq_len=image_context_length,
)
if use_images
else None
Expand Down
25 changes: 25 additions & 0 deletions ddlitlab2024/ml/training/config/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Default hyperparameters for end-to-end diffusion transformer training.
# All keys are read into the checkpoint's "hyperparams" dict, so names here
# must match the params[...] lookups in training/inference scripts.

# --- Optimization ---
epochs: 500
batch_size: 16
lr: 1.e-4
train_denoising_timesteps: 1000
# Number of samples used to estimate normalization statistics (mean/std).
num_normalization_samples: 1000

# --- Model architecture ---
hidden_dim: 256
num_joints: 20
num_decoder_layers: 4
trajectory_prediction_length: 10

# --- Action history encoder ---
use_action_history: True
num_action_history_encoder_layers: 2
action_context_length: 100

# --- IMU encoder ---
use_imu: True
imu_orientation_embedding_method: "quaternion"
# NOTE: this key appeared twice in the original file with the same value;
# duplicate removed (YAML mappings require unique keys).
num_imu_encoder_layers: 2
imu_context_length: 100

# --- Joint state encoder ---
use_joint_states: True
joint_state_encoder_layers: 2
joint_state_context_length: 100

# --- Image encoder ---
use_images: True
image_sequence_encoder_type: "transformer"
image_encoder_type: "resnet18"
num_image_sequence_encoder_layers: 1
image_context_length: 10
Loading

0 comments on commit d4fe368

Please sign in to comment.