From 836e77143f731eed5ed3d66fe7ebcc8a90f0c7c4 Mon Sep 17 00:00:00 2001 From: szmazurek Date: Sat, 9 Nov 2024 21:17:22 +0100 Subject: [PATCH 1/9] Interface and example implementation of the loss class --- GANDLF/losses/loss_interface.py | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 GANDLF/losses/loss_interface.py diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py new file mode 100644 index 000000000..86cba7c75 --- /dev/null +++ b/GANDLF/losses/loss_interface.py @@ -0,0 +1,41 @@ +import torch +from torch import nn +from abc import ABC, abstractmethod + + +class AbstractLossFunction(ABC, nn.Module): + def __init__(self, params: dict): + super().__init__() + self.params = params + + @abstractmethod + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + pass + + +class WeightedCE(AbstractLossFunction): + def __init__(self, params: dict): + """ + Cross entropy loss using class weights if provided. + """ + super().__init__(params) + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if len(target.shape) > 1 and target.shape[-1] == 1: + target = torch.squeeze(target, -1) + + weights = None + if self.params.get("penalty_weights") is not None: + num_classes = len(self.params["penalty_weights"]) + assert ( + prediction.shape[-1] == num_classes + ), f"Number of classes {num_classes} does not match prediction shape {prediction.shape[-1]}" + + weights = torch.tensor( + list(self.params["penalty_weights"].values()), + dtype=torch.float32, + device=target.device, + ) + + cel = nn.CrossEntropyLoss(weight=weights) + return cel(prediction, target) From 227a86f3d119676db155ae7d1b5d52e88db00e60 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Sat, 16 Nov 2024 15:13:47 +0100 Subject: [PATCH 2/9] Porting losses to new interface WIP --- GANDLF/losses/loss_interface.py | 75 ++++++++++++----- GANDLF/losses/segmentation.py | 144 ++++++++++++++++++++++++++++---- 2 files changed, 181 insertions(+), 38 deletions(-) diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py index 86cba7c75..49d5b5031 100644 --- a/GANDLF/losses/loss_interface.py +++ b/GANDLF/losses/loss_interface.py @@ -3,39 +3,68 @@ from abc import ABC, abstractmethod -class AbstractLossFunction(ABC, nn.Module): +class AbstractLossFunction(nn.Module, ABC): def __init__(self, params: dict): - super().__init__() + nn.Module.__init__(self) self.params = params @abstractmethod - def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, prediction: torch.Tensor, target: torch.Tensor, *args + ) -> torch.Tensor: pass -class WeightedCE(AbstractLossFunction): +class AbstractSegmentationMultiClassLoss(AbstractLossFunction): + """ + Base class for loss funcions that are used for multi-class segmentation tasks. + """ + def __init__(self, params: dict): + super().__init__(params) + self.num_classes = len(params["model"]["class_list"]) + self.penalty_weights = params["penalty_weights"] + + def _compute_single_class_loss( + self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int + ) -> torch.Tensor: + """Compute loss for a single class.""" + loss_value = self._single_class_loss_calculator( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + return 1 - loss_value + + def _optional_loss_operations(self, loss: torch.Tensor) -> torch.Tensor: """ - Cross entropy loss using class weights if provided. + Perform addtional operations of the loss value. Defaults to identity operation. + If needed, child classes can override this method. Useful in the cases where + for example, the loss value needs to log-transformed or clipped. """ - super().__init__(params) + return loss - def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if len(target.shape) > 1 and target.shape[-1] == 1: - target = torch.squeeze(target, -1) - - weights = None - if self.params.get("penalty_weights") is not None: - num_classes = len(self.params["penalty_weights"]) - assert ( - prediction.shape[-1] == num_classes - ), f"Number of classes {num_classes} does not match prediction shape {prediction.shape[-1]}" - - weights = torch.tensor( - list(self.params["penalty_weights"].values()), - dtype=torch.float32, - device=target.device, + @abstractmethod + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """Compute loss for a pair of prediction and target tensors. To be implemented by child classes.""" + pass + + def forward( + self, prediction: torch.Tensor, target: torch.Tensor, *args + ) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + + for class_idx in range(self.num_classes): + current_loss = self._compute_single_class_loss( + prediction, target, class_idx ) + current_loss = self._optional_loss_operations(current_loss) + + if self.penalty_weights is not None: + current_loss = current_loss * self.penalty_weights[class_idx] + accumulated_loss += current_loss + + if self.penalty_weights is None: + accumulated_loss /= self.num_classes - cel = nn.CrossEntropyLoss(weight=weights) - return cel(prediction, target) + return accumulated_loss diff --git a/GANDLF/losses/segmentation.py b/GANDLF/losses/segmentation.py index 32e43bc25..35feb3c25 100644 --- a/GANDLF/losses/segmentation.py +++ b/GANDLF/losses/segmentation.py @@ -1,29 +1,127 @@ import sys from typing import List, Optional import torch +from .loss_interface import AbstractSegmentationMultiClassLoss, AbstractLossFunction -# Dice scores and dice losses -def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +class MulticlassDiceLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Dice loss between two tensors. """ - This function computes a dice score between two tensors. - Args: - predicted (torch.Tensor): Predicted value by the network. - target (torch.Tensor): Required target label to match the predicted with + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Dice score for a single class. - Returns: - torch.Tensor: The computed dice score. + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed dice score. + """ + predicted_flat = prediction.flatten() + label_flat = target.flatten() + intersection = (predicted_flat * label_flat).sum() + + dice_score = (2.0 * intersection + sys.float_info.min) / ( + predicted_flat.sum() + label_flat.sum() + sys.float_info.min + ) + + return dice_score + + +class MulticlassDiceLogLoss(MulticlassDiceLoss): + def _optional_loss_operations(self, loss): + return -torch.log( + loss + torch.finfo(torch.float32).eps + ) # epsilon for numerical stability + + +class MulticlassMCCLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors. """ - predicted_flat = predicted.flatten() - label_flat = target.flatten() - intersection = (predicted_flat * label_flat).sum() - dice_score = (2.0 * intersection + sys.float_info.min) / ( - predicted_flat.sum() + label_flat.sum() + sys.float_info.min - ) + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute MCC score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed MCC score. + """ + tp = torch.sum(torch.mul(prediction, target)) + tn = torch.sum(torch.mul((1 - prediction), (1 - target))) + fp = torch.sum(torch.mul(prediction, (1 - target))) + fn = torch.sum(torch.mul((1 - prediction), target)) + + numerator = torch.mul(tp, tn) - torch.mul(fp, fn) + # Adding epsilon to the denominator to avoid divide-by-zero errors. + denominator = ( + torch.sqrt( + torch.add(tp, 1, fp) + * torch.add(tp, 1, fn) + * torch.add(tn, 1, fp) + * torch.add(tn, 1, fn) + ) + + torch.finfo(torch.float32).eps + ) - return dice_score + return torch.div(numerator.sum(), denominator.sum()) + + +class MulticlassMCLLogLoss(MulticlassMCCLoss): + def _optional_loss_operations(self, loss): + return -torch.log( + loss + torch.finfo(torch.float32).eps + ) # epsilon for numerical stability + + +class MulticlassTverskyLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Tversky loss between two tensors. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.alpha = params.get("alpha", 0.5) + self.beta = params.get("beta", 0.5) + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Tversky score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed Tversky score. + """ + predicted_flat = prediction.contiguous().view(-1) + target_flat = target.contiguous().view(-1) + + true_positives = (predicted_flat * target_flat).sum() + false_positives = ((1 - target_flat) * predicted_flat).sum() + false_negatives = (target_flat * (1 - predicted_flat)).sum() + + numerator = true_positives + denominator = ( + true_positives + self.alpha * false_positives + self.beta * false_negatives + ) + loss = (numerator + sys.float_info.min) / (denominator + sys.float_info.min) + + return loss def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: @@ -114,6 +212,22 @@ def generic_loss_calculator( return accumulated_loss +class KullbackLeiblerDivergence(AbstractLossFunction): + def forward(self, mu: torch.Tensor, logvar: torch.Tensor, *args) -> torch.Tensor: + """ + Calculates the Kullback-Leibler divergence between two Gaussian distributions. + + Args: + mu (torch.Tensor): The mean of the first Gaussian distribution. + logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. + + Returns: + torch.Tensor: The computed Kullback-Leibler divergence + """ + loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) + return loss.mean() + + def MCD_loss( predicted: torch.Tensor, target: torch.Tensor, params: dict ) -> torch.Tensor: From f7e168bde3cecbac5b8efe1c34ccaa4e7feeebd4 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Sat, 16 Nov 2024 22:37:50 +0100 Subject: [PATCH 3/9] Segmentation losses refactored --- GANDLF/losses/loss_interface.py | 9 +-- GANDLF/losses/segmentation.py | 105 +++++++++++++++++++++++++++----- 2 files changed, 92 insertions(+), 22 deletions(-) diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py index 49d5b5031..53c5a9325 100644 --- a/GANDLF/losses/loss_interface.py +++ b/GANDLF/losses/loss_interface.py @@ -9,9 +9,7 @@ def __init__(self, params: dict): self.params = params @abstractmethod - def forward( - self, prediction: torch.Tensor, target: torch.Tensor, *args - ) -> torch.Tensor: + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: pass @@ -49,9 +47,7 @@ def _single_class_loss_calculator( """Compute loss for a pair of prediction and target tensors. To be implemented by child classes.""" pass - def forward( - self, prediction: torch.Tensor, target: torch.Tensor, *args - ) -> torch.Tensor: + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: accumulated_loss = torch.tensor(0.0, device=prediction.device) for class_idx in range(self.num_classes): @@ -64,6 +60,7 @@ def forward( current_loss = current_loss * self.penalty_weights[class_idx] accumulated_loss += current_loss + # TODO shouldn't we always divide by the number of classes? if self.penalty_weights is None: accumulated_loss /= self.num_classes diff --git a/GANDLF/losses/segmentation.py b/GANDLF/losses/segmentation.py index 35feb3c25..675dab74c 100644 --- a/GANDLF/losses/segmentation.py +++ b/GANDLF/losses/segmentation.py @@ -124,6 +124,95 @@ def _single_class_loss_calculator( return loss +class MulticlassFocalLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Focal loss between two tensors. + """ + + def __init__(self, params: dict): + super().__init__(params) + + self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none") + loss_params = params["loss_function"] + self.alpha = 1.0 + self.gamma = 2.0 + self.output_aggregation = "sum" + if isinstance(loss_params, dict): + self.alpha = loss_params.get("alpha", self.alpha) + self.gamma = loss_params.get("gamma", self.gamma) + self.output_aggregation = loss_params.get( + "size_average", + self.output_aggregation, # naming mismatch of key due to keeping API consistent with config format + ) + assert self.output_aggregation in [ + "sum", + "mean", + ], f"Invalid output aggregation method defined for Foal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']" + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute focal loss for a single class. It is based on the following formulas: + FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) + CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred) + CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t) + p_t = p if target = 1 else 1 - p + """ + ce_loss = self.ce_loss_helper(prediction, target) + p_t = torch.exp(-ce_loss) + loss = -self.alpha * (1 - p_t) ** self.gamma * ce_loss + return loss.sum() if self.output_aggregation == "sum" else loss.mean() + + def _compute_single_class_loss( + self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int + ) -> torch.Tensor: + """Compute loss for a single class.""" + loss_value = self._single_class_loss_calculator( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + return loss_value # no need to subtract from 1 in this case, hence the override + + +class KullbackLeiblerDivergence(AbstractLossFunction): + def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """ + Calculates the Kullback-Leibler divergence between two Gaussian distributions. + + Args: + mu (torch.Tensor): The mean of the first Gaussian distribution. + logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. + + Returns: + torch.Tensor: The computed Kullback-Leibler divergence + """ + loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) + return loss.mean() + + +# Dice scores and dice losses +def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + This function computes a dice score between two tensors. + + Args: + predicted (torch.Tensor): Predicted value by the network. + target (torch.Tensor): Required target label to match the predicted with + + Returns: + torch.Tensor: The computed dice score. + """ + predicted_flat = predicted.flatten() + label_flat = target.flatten() + intersection = (predicted_flat * label_flat).sum() + + dice_score = (2.0 * intersection + sys.float_info.min) / ( + predicted_flat.sum() + label_flat.sum() + sys.float_info.min + ) + + return dice_score + + def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """ This function computes the Matthews Correlation Coefficient (MCC) between two tensors. Adapted from https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py. @@ -212,22 +301,6 @@ def generic_loss_calculator( return accumulated_loss -class KullbackLeiblerDivergence(AbstractLossFunction): - def forward(self, mu: torch.Tensor, logvar: torch.Tensor, *args) -> torch.Tensor: - """ - Calculates the Kullback-Leibler divergence between two Gaussian distributions. - - Args: - mu (torch.Tensor): The mean of the first Gaussian distribution. - logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. - - Returns: - torch.Tensor: The computed Kullback-Leibler divergence - """ - loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) - return loss.mean() - - def MCD_loss( predicted: torch.Tensor, target: torch.Tensor, params: dict ) -> torch.Tensor: From 572ea5e94724458865ce106a1f7d204b24ae57ed Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Sat, 16 Nov 2024 22:39:48 +0100 Subject: [PATCH 4/9] Move losses temporairly to a new file --- GANDLF/losses/segmentation.py | 187 ----------------------------- GANDLF/losses/segmentation_new.py | 189 ++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+), 187 deletions(-) create mode 100644 GANDLF/losses/segmentation_new.py diff --git a/GANDLF/losses/segmentation.py b/GANDLF/losses/segmentation.py index 675dab74c..32e43bc25 100644 --- a/GANDLF/losses/segmentation.py +++ b/GANDLF/losses/segmentation.py @@ -1,193 +1,6 @@ import sys from typing import List, Optional import torch -from .loss_interface import AbstractSegmentationMultiClassLoss, AbstractLossFunction - - -class MulticlassDiceLoss(AbstractSegmentationMultiClassLoss): - """ - This class computes the Dice loss between two tensors. - """ - - def _single_class_loss_calculator( - self, prediction: torch.Tensor, target: torch.Tensor - ) -> torch.Tensor: - """ - Compute Dice score for a single class. - - Args: - prediction (torch.Tensor): Network's predicted segmentation mask - target (torch.Tensor): Target segmentation mask - - Returns: - torch.Tensor: The computed dice score. - """ - predicted_flat = prediction.flatten() - label_flat = target.flatten() - intersection = (predicted_flat * label_flat).sum() - - dice_score = (2.0 * intersection + sys.float_info.min) / ( - predicted_flat.sum() + label_flat.sum() + sys.float_info.min - ) - - return dice_score - - -class MulticlassDiceLogLoss(MulticlassDiceLoss): - def _optional_loss_operations(self, loss): - return -torch.log( - loss + torch.finfo(torch.float32).eps - ) # epsilon for numerical stability - - -class MulticlassMCCLoss(AbstractSegmentationMultiClassLoss): - """ - This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors. - """ - - def _single_class_loss_calculator( - self, prediction: torch.Tensor, target: torch.Tensor - ) -> torch.Tensor: - """ - Compute MCC score for a single class. - - Args: - prediction (torch.Tensor): Network's predicted segmentation mask - target (torch.Tensor): Target segmentation mask - - Returns: - torch.Tensor: The computed MCC score. - """ - tp = torch.sum(torch.mul(prediction, target)) - tn = torch.sum(torch.mul((1 - prediction), (1 - target))) - fp = torch.sum(torch.mul(prediction, (1 - target))) - fn = torch.sum(torch.mul((1 - prediction), target)) - - numerator = torch.mul(tp, tn) - torch.mul(fp, fn) - # Adding epsilon to the denominator to avoid divide-by-zero errors. - denominator = ( - torch.sqrt( - torch.add(tp, 1, fp) - * torch.add(tp, 1, fn) - * torch.add(tn, 1, fp) - * torch.add(tn, 1, fn) - ) - + torch.finfo(torch.float32).eps - ) - - return torch.div(numerator.sum(), denominator.sum()) - - -class MulticlassMCLLogLoss(MulticlassMCCLoss): - def _optional_loss_operations(self, loss): - return -torch.log( - loss + torch.finfo(torch.float32).eps - ) # epsilon for numerical stability - - -class MulticlassTverskyLoss(AbstractSegmentationMultiClassLoss): - """ - This class computes the Tversky loss between two tensors. - """ - - def __init__(self, params: dict): - super().__init__(params) - self.alpha = params.get("alpha", 0.5) - self.beta = params.get("beta", 0.5) - - def _single_class_loss_calculator( - self, prediction: torch.Tensor, target: torch.Tensor - ) -> torch.Tensor: - """ - Compute Tversky score for a single class. - - Args: - prediction (torch.Tensor): Network's predicted segmentation mask - target (torch.Tensor): Target segmentation mask - - Returns: - torch.Tensor: The computed Tversky score. - """ - predicted_flat = prediction.contiguous().view(-1) - target_flat = target.contiguous().view(-1) - - true_positives = (predicted_flat * target_flat).sum() - false_positives = ((1 - target_flat) * predicted_flat).sum() - false_negatives = (target_flat * (1 - predicted_flat)).sum() - - numerator = true_positives - denominator = ( - true_positives + self.alpha * false_positives + self.beta * false_negatives - ) - loss = (numerator + sys.float_info.min) / (denominator + sys.float_info.min) - - return loss - - -class MulticlassFocalLoss(AbstractSegmentationMultiClassLoss): - """ - This class computes the Focal loss between two tensors. - """ - - def __init__(self, params: dict): - super().__init__(params) - - self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none") - loss_params = params["loss_function"] - self.alpha = 1.0 - self.gamma = 2.0 - self.output_aggregation = "sum" - if isinstance(loss_params, dict): - self.alpha = loss_params.get("alpha", self.alpha) - self.gamma = loss_params.get("gamma", self.gamma) - self.output_aggregation = loss_params.get( - "size_average", - self.output_aggregation, # naming mismatch of key due to keeping API consistent with config format - ) - assert self.output_aggregation in [ - "sum", - "mean", - ], f"Invalid output aggregation method defined for Foal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']" - - def _single_class_loss_calculator( - self, prediction: torch.Tensor, target: torch.Tensor - ) -> torch.Tensor: - """ - Compute focal loss for a single class. It is based on the following formulas: - FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) - CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred) - CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t) - p_t = p if target = 1 else 1 - p - """ - ce_loss = self.ce_loss_helper(prediction, target) - p_t = torch.exp(-ce_loss) - loss = -self.alpha * (1 - p_t) ** self.gamma * ce_loss - return loss.sum() if self.output_aggregation == "sum" else loss.mean() - - def _compute_single_class_loss( - self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int - ) -> torch.Tensor: - """Compute loss for a single class.""" - loss_value = self._single_class_loss_calculator( - prediction[:, class_idx, ...], target[:, class_idx, ...] - ) - return loss_value # no need to subtract from 1 in this case, hence the override - - -class KullbackLeiblerDivergence(AbstractLossFunction): - def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: - """ - Calculates the Kullback-Leibler divergence between two Gaussian distributions. - - Args: - mu (torch.Tensor): The mean of the first Gaussian distribution. - logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. - - Returns: - torch.Tensor: The computed Kullback-Leibler divergence - """ - loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) - return loss.mean() # Dice scores and dice losses diff --git a/GANDLF/losses/segmentation_new.py b/GANDLF/losses/segmentation_new.py new file mode 100644 index 000000000..e68965848 --- /dev/null +++ b/GANDLF/losses/segmentation_new.py @@ -0,0 +1,189 @@ +import sys +import torch +from .loss_interface import AbstractSegmentationMultiClassLoss, AbstractLossFunction + + +class MulticlassDiceLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Dice loss between two tensors. + """ + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Dice score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed dice score. + """ + predicted_flat = prediction.flatten() + label_flat = target.flatten() + intersection = (predicted_flat * label_flat).sum() + + dice_score = (2.0 * intersection + sys.float_info.min) / ( + predicted_flat.sum() + label_flat.sum() + sys.float_info.min + ) + + return dice_score + + +class MulticlassDiceLogLoss(MulticlassDiceLoss): + def _optional_loss_operations(self, loss): + return -torch.log( + loss + torch.finfo(torch.float32).eps + ) # epsilon for numerical stability + + +class MulticlassMCCLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors. + """ + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute MCC score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed MCC score. + """ + tp = torch.sum(torch.mul(prediction, target)) + tn = torch.sum(torch.mul((1 - prediction), (1 - target))) + fp = torch.sum(torch.mul(prediction, (1 - target))) + fn = torch.sum(torch.mul((1 - prediction), target)) + + numerator = torch.mul(tp, tn) - torch.mul(fp, fn) + # Adding epsilon to the denominator to avoid divide-by-zero errors. + denominator = ( + torch.sqrt( + torch.add(tp, 1, fp) + * torch.add(tp, 1, fn) + * torch.add(tn, 1, fp) + * torch.add(tn, 1, fn) + ) + + torch.finfo(torch.float32).eps + ) + + return torch.div(numerator.sum(), denominator.sum()) + + +class MulticlassMCLLogLoss(MulticlassMCCLoss): + def _optional_loss_operations(self, loss): + return -torch.log( + loss + torch.finfo(torch.float32).eps + ) # epsilon for numerical stability + + +class MulticlassTverskyLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Tversky loss between two tensors. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.alpha = params.get("alpha", 0.5) + self.beta = params.get("beta", 0.5) + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Tversky score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed Tversky score. + """ + predicted_flat = prediction.contiguous().view(-1) + target_flat = target.contiguous().view(-1) + + true_positives = (predicted_flat * target_flat).sum() + false_positives = ((1 - target_flat) * predicted_flat).sum() + false_negatives = (target_flat * (1 - predicted_flat)).sum() + + numerator = true_positives + denominator = ( + true_positives + self.alpha * false_positives + self.beta * false_negatives + ) + loss = (numerator + sys.float_info.min) / (denominator + sys.float_info.min) + + return loss + + +class MulticlassFocalLoss(AbstractSegmentationMultiClassLoss): + """ + This class computes the Focal loss between two tensors. + """ + + def __init__(self, params: dict): + super().__init__(params) + + self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none") + loss_params = params["loss_function"] + self.alpha = 1.0 + self.gamma = 2.0 + self.output_aggregation = "sum" + if isinstance(loss_params, dict): + self.alpha = loss_params.get("alpha", self.alpha) + self.gamma = loss_params.get("gamma", self.gamma) + self.output_aggregation = loss_params.get( + "size_average", + self.output_aggregation, # naming mismatch of key due to keeping API consistent with config format + ) + assert self.output_aggregation in [ + "sum", + "mean", + ], f"Invalid output aggregation method defined for Foal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']" + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute focal loss for a single class. It is based on the following formulas: + FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) + CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred) + CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t) + p_t = p if target = 1 else 1 - p + """ + ce_loss = self.ce_loss_helper(prediction, target) + p_t = torch.exp(-ce_loss) + loss = -self.alpha * (1 - p_t) ** self.gamma * ce_loss + return loss.sum() if self.output_aggregation == "sum" else loss.mean() + + def _compute_single_class_loss( + self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int + ) -> torch.Tensor: + """Compute loss for a single class.""" + loss_value = self._single_class_loss_calculator( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + return loss_value # no need to subtract from 1 in this case, hence the override + + +class KullbackLeiblerDivergence(AbstractLossFunction): + def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """ + Calculates the Kullback-Leibler divergence between two Gaussian distributions. + + Args: + mu (torch.Tensor): The mean of the first Gaussian distribution. + logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. + + Returns: + torch.Tensor: The computed Kullback-Leibler divergence + """ + loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) + return loss.mean() From d80ff6ace2b1b8bff34cc2fa6dfdd3a6fef43293 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Mon, 18 Nov 2024 14:17:22 +0100 Subject: [PATCH 5/9] Refactored regression losses WIP --- GANDLF/losses/loss_interface.py | 75 ++++++++++++++++++++++++++++--- GANDLF/losses/regression.py | 64 +++++++++++++++++++++++++- GANDLF/losses/segmentation_new.py | 8 +++- 3 files changed, 137 insertions(+), 10 deletions(-) diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py index 53c5a9325..4aeeedb51 100644 --- a/GANDLF/losses/loss_interface.py +++ b/GANDLF/losses/loss_interface.py @@ -7,6 +7,14 @@ class AbstractLossFunction(nn.Module, ABC): def __init__(self, params: dict): nn.Module.__init__(self) self.params = params + self.num_classes = len(params["model"]["class_list"]) + self._initialize_penalty_weights() + + def _initialize_penalty_weights(self): + default_penalty_weights = torch.ones(self.num_classes) + self.penalty_weights = self.params.get( + "penalty_weights", default_penalty_weights + ) @abstractmethod def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: @@ -20,8 +28,6 @@ class AbstractSegmentationMultiClassLoss(AbstractLossFunction): def __init__(self, params: dict): super().__init__(params) - self.num_classes = len(params["model"]["class_list"]) - self.penalty_weights = params["penalty_weights"] def _compute_single_class_loss( self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int @@ -54,14 +60,69 @@ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tenso current_loss = self._compute_single_class_loss( prediction, target, class_idx ) - current_loss = self._optional_loss_operations(current_loss) - - if self.penalty_weights is not None: - current_loss = current_loss * self.penalty_weights[class_idx] - accumulated_loss += current_loss + accumulated_loss += ( + self._optional_loss_operations(current_loss) + * self.penalty_weights[class_idx] + ) # TODO shouldn't we always divide by the number of classes? if self.penalty_weights is None: accumulated_loss /= self.num_classes return accumulated_loss + + +class AbstractRegressionLoss(AbstractLossFunction): + """ + Base class for loss functions that are used for regression and classification tasks. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.loss_calculator = self._initialize_loss_function_object() + self.reduction_method = self._initialize_reduction_method() + + def _initialize_reduction_method(self) -> str: + """ + Initialize the reduction method for the loss function. Defaults to 'mean'. + """ + loss_params = self.params["loss_function"] + reduction_method = "mean" + if isinstance(loss_params, dict): + reduction_method = loss_params.get("reduction", reduction_method) + assert reduction_method in [ + "mean", + "sum", + ], f"Invalid reduction method defined for loss function: {reduction_method}. Valid options are ['mean', 'sum']" + return reduction_method + + def _calculate_loss_for_single_class( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Calculate loss for a single class. To be implemented by child classes. + """ + return self.loss_calculator(prediction, target) + + @abstractmethod + def _initialize_loss_function_object(self) -> nn.modules.loss._Loss: + """ + Initialize the loss function object used in the forward method. Has to return + callable pytorch loss function object. + """ + pass + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + for class_idx in range(self.num_classes): + accumulated_loss += ( + self._calculate_loss_for_single_class( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + * self.penalty_weights[class_idx] + ) + + # TODO I Believe this is how it should be, also for segmentation - take average from all classes, despite weights being present or no + accumulated_loss /= self.num_classes + + return accumulated_loss diff --git a/GANDLF/losses/regression.py b/GANDLF/losses/regression.py index 6d74a33a2..593090319 100644 --- a/GANDLF/losses/regression.py +++ b/GANDLF/losses/regression.py @@ -1,8 +1,70 @@ from typing import Optional import torch +from torch import nn import torch.nn.functional as F -from torch.nn import CrossEntropyLoss from GANDLF.utils import one_hot +from GANDLF.losses.loss_interface import AbstractRegressionLoss + + +class CrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.CrossEntropyLoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCELoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyWithLogitsLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss with logits between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCEWithLogitsLoss(reduction=self.reduction_method) + + +class BaseLossWithScaledTarget(AbstractRegressionLoss): + """ + General interface for the loss functions requiring scaling of the target tensor. + """ + + def _initialize_scaling_factor(self): + loss_params: dict = self.params["loss_function"] + self.scaling_factor = loss_params.get("scaling_factor", 1.0) + if isinstance(loss_params, dict): + self.scaling_factor = loss_params.get("scaling_factor", self.scaling_factor) + return self.scaling_factor + + def _calculate_loss(self, prediction: torch.Tensor, target: torch.Tensor): + return self.loss_calculator(prediction, target * self.scaling_factor) + + +class L1Loss(BaseLossWithScaledTarget): + """ + This class computes the L1 loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.L1Loss(reduction=self.reduction_method) + + +class MSELoss(BaseLossWithScaledTarget): + """ + This class computes the mean squared error loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.MSELoss(reduction=self.reduction_method) def CEL( diff --git a/GANDLF/losses/segmentation_new.py b/GANDLF/losses/segmentation_new.py index e68965848..10133196d 100644 --- a/GANDLF/losses/segmentation_new.py +++ b/GANDLF/losses/segmentation_new.py @@ -91,8 +91,12 @@ class MulticlassTverskyLoss(AbstractSegmentationMultiClassLoss): def __init__(self, params: dict): super().__init__(params) - self.alpha = params.get("alpha", 0.5) - self.beta = params.get("beta", 0.5) + loss_params = params["loss_function"] + self.alpha = 0.5 + self.beta = 0.5 + if isinstance(loss_params, dict): + self.alpha = loss_params.get("alpha", self.alpha) + self.beta = loss_params.get("beta", self.beta) def _single_class_loss_calculator( self, prediction: torch.Tensor, target: torch.Tensor From d2578ad88ecc49ef192477334c30cb1338a39918 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Mon, 18 Nov 2024 14:20:41 +0100 Subject: [PATCH 6/9] Move losses to separate file --- GANDLF/losses/regression.py | 63 +------------------------------------ 1 file changed, 1 insertion(+), 62 deletions(-) diff --git a/GANDLF/losses/regression.py b/GANDLF/losses/regression.py index 593090319..62f80bffd 100644 --- a/GANDLF/losses/regression.py +++ b/GANDLF/losses/regression.py @@ -3,68 +3,7 @@ from torch import nn import torch.nn.functional as F from GANDLF.utils import one_hot -from GANDLF.losses.loss_interface import AbstractRegressionLoss - - -class CrossEntropyLoss(AbstractRegressionLoss): - """ - This class computes the cross entropy loss between two tensors. - """ - - def _initialize_loss_function_object(self): - return nn.CrossEntropyLoss(reduction=self.reduction_method) - - -class BinaryCrossEntropyLoss(AbstractRegressionLoss): - """ - This class computes the binary cross entropy loss between two tensors. - """ - - def _initialize_loss_function_object(self): - return nn.BCELoss(reduction=self.reduction_method) - - -class BinaryCrossEntropyWithLogitsLoss(AbstractRegressionLoss): - """ - This class computes the binary cross entropy loss with logits between two tensors. - """ - - def _initialize_loss_function_object(self): - return nn.BCEWithLogitsLoss(reduction=self.reduction_method) - - -class BaseLossWithScaledTarget(AbstractRegressionLoss): - """ - General interface for the loss functions requiring scaling of the target tensor. - """ - - def _initialize_scaling_factor(self): - loss_params: dict = self.params["loss_function"] - self.scaling_factor = loss_params.get("scaling_factor", 1.0) - if isinstance(loss_params, dict): - self.scaling_factor = loss_params.get("scaling_factor", self.scaling_factor) - return self.scaling_factor - - def _calculate_loss(self, prediction: torch.Tensor, target: torch.Tensor): - return self.loss_calculator(prediction, target * self.scaling_factor) - - -class L1Loss(BaseLossWithScaledTarget): - """ - This class computes the L1 loss between two tensors. - """ - - def _initialize_loss_function_object(self): - return nn.L1Loss(reduction=self.reduction_method) - - -class MSELoss(BaseLossWithScaledTarget): - """ - This class computes the mean squared error loss between two tensors. - """ - - def _initialize_loss_function_object(self): - return nn.MSELoss(reduction=self.reduction_method) +from torch.nn import CrossEntropyLoss def CEL( From 2d661b9dd1c668c4211e380edaf89b49088baea7 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Mon, 18 Nov 2024 14:35:09 +0100 Subject: [PATCH 7/9] Hybrid losses implementation --- GANDLF/losses/hybrid.py | 1 - GANDLF/losses/hybrid_new.py | 21 +++++++++++ GANDLF/losses/loss_interface.py | 22 ++++++++++++ GANDLF/losses/regression_new.py | 64 +++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 GANDLF/losses/hybrid_new.py create mode 100644 GANDLF/losses/regression_new.py diff --git a/GANDLF/losses/hybrid.py b/GANDLF/losses/hybrid.py index ddf62fa01..f4c862606 100644 --- a/GANDLF/losses/hybrid.py +++ b/GANDLF/losses/hybrid.py @@ -1,5 +1,4 @@ import torch - from .segmentation import MCD_loss, FocalLoss from .regression import CCE_Generic, CE, CE_Logits diff --git a/GANDLF/losses/hybrid_new.py b/GANDLF/losses/hybrid_new.py new file mode 100644 index 000000000..4fa7edfcc --- /dev/null +++ b/GANDLF/losses/hybrid_new.py @@ -0,0 +1,21 @@ +from .regression_new import BinaryCrossEntropyLoss, BinaryCrossEntropyWithLogitsLoss +from .segmentation_new import MulticlassDiceLoss, MulticlassFocalLoss +from .loss_interface import AbstractHybridLoss + + +class DiceCrossEntropyLoss(AbstractHybridLoss): + def _initialize_all_loss_calculators(self): + return [MulticlassDiceLoss(self.params), BinaryCrossEntropyLoss(self.params)] + + +class DiceCrossEntropyLossLogits(AbstractHybridLoss): + def _initialize_all_loss_calculators(self): + return [ + MulticlassDiceLoss(self.params), + BinaryCrossEntropyWithLogitsLoss(self.params), + ] + + +class DiceFocalLoss(AbstractHybridLoss): + def _initialize_all_loss_calculators(self): + return [MulticlassDiceLoss(self.params), MulticlassFocalLoss(self.params)] diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py index 4aeeedb51..bb983f2ed 100644 --- a/GANDLF/losses/loss_interface.py +++ b/GANDLF/losses/loss_interface.py @@ -1,6 +1,7 @@ import torch from torch import nn from abc import ABC, abstractmethod +from typing import List class AbstractLossFunction(nn.Module, ABC): @@ -126,3 +127,24 @@ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tenso accumulated_loss /= self.num_classes return accumulated_loss + + +class AbstractHybridLoss(AbstractLossFunction): + def __init__(self, params: dict): + super().__init__(params) + self.loss_calculators = self._initialize_all_loss_calculators() + + @abstractmethod + def _initialize_all_loss_calculators(self) -> List[AbstractLossFunction]: + """ + Each hybrid loss should implement this method, creating all loss functions as a list that + will be used during the forward pass. + """ + pass + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + for loss_calculator in self._initialize_all_loss_calculators(): + accumulated_loss += loss_calculator(prediction, target) + + return accumulated_loss diff --git a/GANDLF/losses/regression_new.py b/GANDLF/losses/regression_new.py new file mode 100644 index 000000000..e9e0d5db0 --- /dev/null +++ b/GANDLF/losses/regression_new.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from .loss_interface import AbstractRegressionLoss + + +class CrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.CrossEntropyLoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCELoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyWithLogitsLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss with logits between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCEWithLogitsLoss(reduction=self.reduction_method) + + +class BaseLossWithScaledTarget(AbstractRegressionLoss): + """ + General interface for the loss functions requiring scaling of the target tensor. + """ + + def _initialize_scaling_factor(self): + loss_params: dict = self.params["loss_function"] + self.scaling_factor = loss_params.get("scaling_factor", 1.0) + if isinstance(loss_params, dict): + self.scaling_factor = loss_params.get("scaling_factor", self.scaling_factor) + return self.scaling_factor + + def _calculate_loss(self, prediction: torch.Tensor, target: torch.Tensor): + return self.loss_calculator(prediction, target * self.scaling_factor) + + +class L1Loss(BaseLossWithScaledTarget): + """ + This class computes the L1 loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.L1Loss(reduction=self.reduction_method) + + +class MSELoss(BaseLossWithScaledTarget): + """ + This class computes the mean squared error loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.MSELoss(reduction=self.reduction_method) From 9ac300c24ca6f3151437bd47405fe1b5658f48a8 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Mon, 18 Nov 2024 22:01:59 +0100 Subject: [PATCH 8/9] Fix docstrings, remove todos --- GANDLF/losses/loss_interface.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py index bb983f2ed..69dc4360b 100644 --- a/GANDLF/losses/loss_interface.py +++ b/GANDLF/losses/loss_interface.py @@ -41,8 +41,8 @@ def _compute_single_class_loss( def _optional_loss_operations(self, loss: torch.Tensor) -> torch.Tensor: """ - Perform addtional operations of the loss value. Defaults to identity operation. - If needed, child classes can override this method. Useful in the cases where + Perform addtional operations on the loss value. Defaults to identity operation. + If needed, child classes can override this method. Useful in cases where for example, the loss value needs to log-transformed or clipped. """ return loss @@ -66,9 +66,7 @@ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tenso * self.penalty_weights[class_idx] ) - # TODO shouldn't we always divide by the number of classes? - if self.penalty_weights is None: - accumulated_loss /= self.num_classes + accumulated_loss /= self.num_classes return accumulated_loss @@ -123,7 +121,6 @@ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tenso * self.penalty_weights[class_idx] ) - # TODO I Believe this is how it should be, also for segmentation - take average from all classes, despite weights being present or no accumulated_loss /= self.num_classes return accumulated_loss From 96b64e461af5409991de18c04389e919c867c8a8 Mon Sep 17 00:00:00 2001 From: Szymon Mazurek Date: Tue, 19 Nov 2024 06:41:31 +0100 Subject: [PATCH 9/9] Cleaning up --- GANDLF/losses/loss_interface.py | 18 ++++++++++++------ GANDLF/losses/regression.py | 1 - GANDLF/losses/segmentation_new.py | 10 +++++----- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py index 69dc4360b..e8459f41d 100644 --- a/GANDLF/losses/loss_interface.py +++ b/GANDLF/losses/loss_interface.py @@ -19,12 +19,14 @@ def _initialize_penalty_weights(self): @abstractmethod def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - pass + """ + Forward pass of the loss function. To be implemented by child classes. + """ -class AbstractSegmentationMultiClassLoss(AbstractLossFunction): +class AbstractSegmentationLoss(AbstractLossFunction): """ - Base class for loss funcions that are used for multi-class segmentation tasks. + Base class for loss funcions that are used for segmentation tasks. """ def __init__(self, params: dict): @@ -51,8 +53,9 @@ def _optional_loss_operations(self, loss: torch.Tensor) -> torch.Tensor: def _single_class_loss_calculator( self, prediction: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: - """Compute loss for a pair of prediction and target tensors. To be implemented by child classes.""" - pass + """ + Compute loss for a pair of prediction and target tensors. To be implemented by child classes. + """ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: accumulated_loss = torch.tensor(0.0, device=prediction.device) @@ -109,7 +112,6 @@ def _initialize_loss_function_object(self) -> nn.modules.loss._Loss: Initialize the loss function object used in the forward method. Has to return callable pytorch loss function object. """ - pass def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: accumulated_loss = torch.tensor(0.0, device=prediction.device) @@ -127,6 +129,10 @@ def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tenso class AbstractHybridLoss(AbstractLossFunction): + """ + Base class for hybrid loss functions that are used for segmentation tasks. + """ + def __init__(self, params: dict): super().__init__(params) self.loss_calculators = self._initialize_all_loss_calculators() diff --git a/GANDLF/losses/regression.py b/GANDLF/losses/regression.py index 62f80bffd..4949bd9d2 100644 --- a/GANDLF/losses/regression.py +++ b/GANDLF/losses/regression.py @@ -1,6 +1,5 @@ from typing import Optional import torch -from torch import nn import torch.nn.functional as F from GANDLF.utils import one_hot from torch.nn import CrossEntropyLoss diff --git a/GANDLF/losses/segmentation_new.py b/GANDLF/losses/segmentation_new.py index 10133196d..4999686fe 100644 --- a/GANDLF/losses/segmentation_new.py +++ b/GANDLF/losses/segmentation_new.py @@ -1,9 +1,9 @@ import sys import torch -from .loss_interface import AbstractSegmentationMultiClassLoss, AbstractLossFunction +from .loss_interface import AbstractSegmentationLoss, AbstractLossFunction -class MulticlassDiceLoss(AbstractSegmentationMultiClassLoss): +class MulticlassDiceLoss(AbstractSegmentationLoss): """ This class computes the Dice loss between two tensors. """ @@ -39,7 +39,7 @@ def _optional_loss_operations(self, loss): ) # epsilon for numerical stability -class MulticlassMCCLoss(AbstractSegmentationMultiClassLoss): +class MulticlassMCCLoss(AbstractSegmentationLoss): """ This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors. """ @@ -84,7 +84,7 @@ def _optional_loss_operations(self, loss): ) # epsilon for numerical stability -class MulticlassTverskyLoss(AbstractSegmentationMultiClassLoss): +class MulticlassTverskyLoss(AbstractSegmentationLoss): """ This class computes the Tversky loss between two tensors. """ @@ -127,7 +127,7 @@ def _single_class_loss_calculator( return loss -class MulticlassFocalLoss(AbstractSegmentationMultiClassLoss): +class MulticlassFocalLoss(AbstractSegmentationLoss): """ This class computes the Focal loss between two tensors. """