diff --git a/odak/learn/wave/loss.py b/odak/learn/wave/loss.py index 80c29cd1..d02f77d3 100644 --- a/odak/learn/wave/loss.py +++ b/odak/learn/wave/loss.py @@ -364,18 +364,19 @@ def __init__(self, target_image, target_depth, blur_ratio = 0.25, self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction) self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction) for key in self.additional_loss_weights.keys(): - if key == 'cvvdp': - self.cvvdp = CVVDP() - if key == 'fvvdp': - self.fvvdp = FVVDP() - if key == 'lpips': - self.lpips = LPIPS() - if key == 'psnr': - self.psnr = PSNR() - if key == 'ssim': - self.ssim = SSIM() - if key == 'msssim': - self.msssim = MSSSIM() + if self.additional_loss_weights[key]: + if key == 'cvvdp': + self.cvvdp = CVVDP(device = device) + if key == 'fvvdp': + self.fvvdp = FVVDP() + if key == 'lpips': + self.lpips = LPIPS() + if key == 'psnr': + self.psnr = PSNR() + if key == 'ssim': + self.ssim = SSIM() + if key == 'msssim': + self.msssim = MSSSIM() def get_targets(self): """ @@ -475,6 +476,7 @@ def __call__(self, image, target, plane_id = None): loss_components['l2'] = l2 loss_components['l2_mask'] = l2_mask loss_components['l2_cor'] = l2_cor + loss = l2 + l2_mask + l2_cor l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target) l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask) @@ -482,29 +484,34 @@ def __call__(self, image, target, plane_id = None): loss_components['l1'] = l1 loss_components['l1_mask'] = l1_mask loss_components['l1_cor'] = l1_cor - + loss += l1 + l1_mask + l1_cor + for key in self.additional_loss_weights.keys(): - if key == 'cvvdp': - loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target) - loss_components['cvvdp'] = loss_cvvdp - if key == 'fvvdp': - loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target) - loss_components['fvvdp'] = loss_fvvdp - if key == 'lpips': - loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target) - loss_components['lpips'] = loss_lpips - if key == 'psnr': - loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target) - loss_components['psnr'] = loss_psnr - if key == 'ssim': - loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target) - loss_components['ssim'] = loss_ssim - if key == 'msssim': - loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target) - loss_components['msssim'] = loss_msssim - - loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0) - + if self.additional_loss_weights[key]: + if key == 'cvvdp': + loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target) + loss_components['cvvdp'] = loss_cvvdp + loss += loss_cvvdp + if key == 'fvvdp': + loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target) + loss_components['fvvdp'] = loss_fvvdp + loss += loss_fvvdp + if key == 'lpips': + loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target) + loss_components['lpips'] = loss_lpips + loss += loss_lpips + if key == 'psnr': + loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target) + loss_components['psnr'] = loss_psnr + loss += loss_psnr + if key == 'ssim': + loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target) + loss_components['ssim'] = loss_ssim + loss += loss_ssim + if key == 'msssim': + loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target) + loss_components['msssim'] = loss_msssim + loss += loss_msssim if self.return_components: return loss, loss_components return loss