From 5a234add0e3a45f075947d1922ed262c1bbde289 Mon Sep 17 00:00:00 2001 From: urw7rs Date: Mon, 30 Jan 2023 14:49:37 +0900 Subject: [PATCH 1/2] fix bug where set_detect_anomaly wasn't used as a context manager, improves gpu utilization --- diffusion/gaussian_diffusion.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index fbfb3da0..ce87e1f4 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -1318,22 +1318,22 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, data terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: - torch.autograd.set_detect_anomaly(True) - if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: - target_xyz = get_xyz(target) if target_xyz is None else target_xyz - model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz - # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 - l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 - relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] - gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] - gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] - fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) - pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] - pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] - pred_vel[~fc_mask] = 0 - terms["fc"] = self.masked_l2(pred_vel, - torch.zeros(pred_vel.shape, device=pred_vel.device), - mask[:, :, :, 1:]) + with torch.autograd.set_detect_anomaly(True): + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: + target_xyz = get_xyz(target) if target_xyz is None else target_xyz + model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz + # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 + l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 + relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] + gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] + gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) + pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] + pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] + pred_vel[~fc_mask] = 0 + terms["fc"] = self.masked_l2(pred_vel, + torch.zeros(pred_vel.shape, device=pred_vel.device), + mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) From 3dec2ced782b1498f148827dba00ee91c55db45a Mon Sep 17 00:00:00 2001 From: urw7rs Date: Mon, 30 Jan 2023 14:50:22 +0900 Subject: [PATCH 2/2] set cudnn benchmark to True --- train/train_mdm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train/train_mdm.py b/train/train_mdm.py index adeabe7f..34614673 100644 --- a/train/train_mdm.py +++ b/train/train_mdm.py @@ -13,6 +13,10 @@ from utils.model_util import create_model_and_diffusion from train.train_platforms import ClearmlPlatform, TensorboardPlatform, NoPlatform # required for the eval operation +import torch +if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + def main(): args = train_args() fixseed(args.seed)