Skip to content

Commit

Permalink
Revisions in odak.learn.wave.perceptual_multiplane_loss().
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Nov 18, 2024
1 parent 70c4136 commit 098859d
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions odak/learn/wave/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,19 @@ def __init__(self, target_image, target_depth, blur_ratio = 0.25,
self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
for key in self.additional_loss_weights.keys():
if key == 'cvvdp':
self.cvvdp = CVVDP()
if key == 'fvvdp':
self.fvvdp = FVVDP()
if key == 'lpips':
self.lpips = LPIPS()
if key == 'psnr':
self.psnr = PSNR()
if key == 'ssim':
self.ssim = SSIM()
if key == 'msssim':
self.msssim = MSSSIM()
if self.additional_loss_weights[key]:
if key == 'cvvdp':
self.cvvdp = CVVDP(device = device)
if key == 'fvvdp':
self.fvvdp = FVVDP()
if key == 'lpips':
self.lpips = LPIPS()
if key == 'psnr':
self.psnr = PSNR()
if key == 'ssim':
self.ssim = SSIM()
if key == 'msssim':
self.msssim = MSSSIM()

def get_targets(self):
"""
Expand Down Expand Up @@ -475,36 +476,42 @@ def __call__(self, image, target, plane_id = None):
loss_components['l2'] = l2
loss_components['l2_mask'] = l2_mask
loss_components['l2_cor'] = l2_cor
loss = l2 + l2_mask + l2_cor

l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
loss_components['l1'] = l1
loss_components['l1_mask'] = l1_mask
loss_components['l1_cor'] = l1_cor

loss += l1 + l1_mask + l1_cor

for key in self.additional_loss_weights.keys():
if key == 'cvvdp':
loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
loss_components['cvvdp'] = loss_cvvdp
if key == 'fvvdp':
loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
loss_components['fvvdp'] = loss_fvvdp
if key == 'lpips':
loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
loss_components['lpips'] = loss_lpips
if key == 'psnr':
loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
loss_components['psnr'] = loss_psnr
if key == 'ssim':
loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
loss_components['ssim'] = loss_ssim
if key == 'msssim':
loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
loss_components['msssim'] = loss_msssim

loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)

if self.additional_loss_weights[key]:
if key == 'cvvdp':
loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
loss_components['cvvdp'] = loss_cvvdp
loss += loss_cvvdp
if key == 'fvvdp':
loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
loss_components['fvvdp'] = loss_fvvdp
loss += loss_fvvdp
if key == 'lpips':
loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
loss_components['lpips'] = loss_lpips
loss += loss_lpips
if key == 'psnr':
loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
loss_components['psnr'] = loss_psnr
loss += loss_psnr
if key == 'ssim':
loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
loss_components['ssim'] = loss_ssim
loss += loss_ssim
if key == 'msssim':
loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
loss_components['msssim'] = loss_msssim
loss += loss_msssim
if self.return_components:
return loss, loss_components
return loss

0 comments on commit 098859d

Please sign in to comment.