diff --git a/odak/learn/perception/learned_perceptual_losses.py b/odak/learn/perception/learned_perceptual_losses.py index 219897ba..d9391b44 100644 --- a/odak/learn/perception/learned_perceptual_losses.py +++ b/odak/learn/perception/learned_perceptual_losses.py @@ -22,7 +22,7 @@ def __init__(self, device = torch.device('cpu')): logging.warning(e) - def forward(self, predictions, targets, dim_order = 'CHW'): + def forward(self, predictions, targets, dim_order = 'BCHW'): """ Parameters ---------- @@ -31,7 +31,7 @@ def forward(self, predictions, targets, dim_order = 'CHW'): targets h : torch.tensor The ground truth images. dim_order : str - The dimension order of the input images. Defaults to 'CHW' (channels, height, width). + The dimension order of the input images. Defaults to 'BCHW' (channels, height, width). Returns ------- @@ -42,7 +42,7 @@ def forward(self, predictions, targets, dim_order = 'CHW'): if len(predictions.shape) == 3: predictions = predictions.unsqueeze(0) targets = targets.unsqueeze(0) - l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order) + l_ColorVideoVDP = self.cvvdp.predict(predictions, targets, dim_order = dim_order)[0] return l_ColorVideoVDP except Exception as e: logging.warning('ColorVideoVDP failed to compute.') @@ -68,7 +68,7 @@ def __init__(self, device = torch.device('cpu')): logging.warning(e) - def forward(self, predictions, targets, dim_order = 'CHW'): + def forward(self, predictions, targets, dim_order = 'BCHW'): """ Parameters ---------- @@ -77,7 +77,7 @@ def forward(self, predictions, targets, dim_order = 'CHW'): targets : torch.tensor The ground truth images. dim_order : str - The dimension order of the input images. Defaults to 'CHW' (channels, height, width). + The dimension order of the input images. Defaults to 'BCHW' (channels, height, width). Returns -------