Skip to content

Commit

Permalink
Merge pull request #126 from yilmazdoga/master
Browse files Browse the repository at this point in the history
Changed defaults for learned perceptual losses
  • Loading branch information
kaanaksit authored Nov 28, 2024
2 parents 6a772a8 + 8bc0fe8 commit 196a8aa
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions odak/learn/perception/learned_perceptual_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, device = torch.device('cpu')):
logging.warning(e)


def forward(self, predictions, targets, dim_order = 'CHW'):
def forward(self, predictions, targets, dim_order = 'BCHW'):
"""
Parameters
----------
Expand All @@ -31,7 +31,7 @@ def forward(self, predictions, targets, dim_order = 'CHW'):
targets h : torch.tensor
The ground truth images.
dim_order : str
The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
The dimension order of the input images. Defaults to 'BCHW' (channels, height, width).
Returns
-------
Expand All @@ -42,7 +42,7 @@ def forward(self, predictions, targets, dim_order = 'CHW'):
if len(predictions.shape) == 3:
predictions = predictions.unsqueeze(0)
targets = targets.unsqueeze(0)
l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)
l_ColorVideoVDP = self.cvvdp.predict(predictions, targets, dim_order = dim_order)[0]
return l_ColorVideoVDP
except Exception as e:
logging.warning('ColorVideoVDP failed to compute.')
Expand All @@ -68,7 +68,7 @@ def __init__(self, device = torch.device('cpu')):
logging.warning(e)


def forward(self, predictions, targets, dim_order = 'CHW'):
def forward(self, predictions, targets, dim_order = 'BCHW'):
"""
Parameters
----------
Expand All @@ -77,7 +77,7 @@ def forward(self, predictions, targets, dim_order = 'CHW'):
targets : torch.tensor
The ground truth images.
dim_order : str
The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
The dimension order of the input images. Defaults to 'BCHW' (channels, height, width).
Returns
-------
Expand Down

0 comments on commit 196a8aa

Please sign in to comment.