diff --git a/GANDLF/losses/hybrid.py b/GANDLF/losses/hybrid.py
index ddf62fa01..f4c862606 100644
--- a/GANDLF/losses/hybrid.py
+++ b/GANDLF/losses/hybrid.py
@@ -1,5 +1,4 @@
 import torch
-
 from .segmentation import MCD_loss, FocalLoss
 from .regression import CCE_Generic, CE, CE_Logits
 
diff --git a/GANDLF/losses/hybrid_new.py b/GANDLF/losses/hybrid_new.py
new file mode 100644
index 000000000..4fa7edfcc
--- /dev/null
+++ b/GANDLF/losses/hybrid_new.py
@@ -0,0 +1,21 @@
+from .regression_new import BinaryCrossEntropyLoss, BinaryCrossEntropyWithLogitsLoss
+from .segmentation_new import MulticlassDiceLoss, MulticlassFocalLoss
+from .loss_interface import AbstractHybridLoss
+
+
+class DiceCrossEntropyLoss(AbstractHybridLoss):
+    def _initialize_all_loss_calculators(self):
+        return [MulticlassDiceLoss(self.params), BinaryCrossEntropyLoss(self.params)]
+
+
+class DiceCrossEntropyLossLogits(AbstractHybridLoss):
+    def _initialize_all_loss_calculators(self):
+        return [
+            MulticlassDiceLoss(self.params),
+            BinaryCrossEntropyWithLogitsLoss(self.params),
+        ]
+
+
+class DiceFocalLoss(AbstractHybridLoss):
+    def _initialize_all_loss_calculators(self):
+        return [MulticlassDiceLoss(self.params), MulticlassFocalLoss(self.params)]
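A minimal usage sketch for the new hybrid classes, assuming a hypothetical stripped-down `params` dictionary; the real GaNDLF config carries many more keys, and the `"dc_ce"` loss name below is illustrative only:

```python
import torch
from GANDLF.losses.hybrid_new import DiceCrossEntropyLoss

# hypothetical minimal params; the real config carries many more keys
params = {
    "model": {"class_list": [0, 1]},
    "loss_function": "dc_ce",  # illustrative name; a dict with options also works
}

loss_fn = DiceCrossEntropyLoss(params)
prediction = torch.rand(4, 2, 8, 8)  # (batch, classes, H, W), probabilities in [0, 1]
target = torch.randint(0, 2, (4, 2, 8, 8)).float()  # one-hot target mask
print(loss_fn(prediction, target))  # scalar: Dice term plus BCE term, summed
```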
+ """ + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + + for class_idx in range(self.num_classes): + current_loss = self._compute_single_class_loss( + prediction, target, class_idx + ) + accumulated_loss += ( + self._optional_loss_operations(current_loss) + * self.penalty_weights[class_idx] + ) + + accumulated_loss /= self.num_classes + + return accumulated_loss + + +class AbstractRegressionLoss(AbstractLossFunction): + """ + Base class for loss functions that are used for regression and classification tasks. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.loss_calculator = self._initialize_loss_function_object() + self.reduction_method = self._initialize_reduction_method() + + def _initialize_reduction_method(self) -> str: + """ + Initialize the reduction method for the loss function. Defaults to 'mean'. + """ + loss_params = self.params["loss_function"] + reduction_method = "mean" + if isinstance(loss_params, dict): + reduction_method = loss_params.get("reduction", reduction_method) + assert reduction_method in [ + "mean", + "sum", + ], f"Invalid reduction method defined for loss function: {reduction_method}. Valid options are ['mean', 'sum']" + return reduction_method + + def _calculate_loss_for_single_class( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Calculate loss for a single class. To be implemented by child classes. + """ + return self.loss_calculator(prediction, target) + + @abstractmethod + def _initialize_loss_function_object(self) -> nn.modules.loss._Loss: + """ + Initialize the loss function object used in the forward method. Has to return + callable pytorch loss function object. + """ + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + for class_idx in range(self.num_classes): + accumulated_loss += ( + self._calculate_loss_for_single_class( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + * self.penalty_weights[class_idx] + ) + + accumulated_loss /= self.num_classes + + return accumulated_loss + + +class AbstractHybridLoss(AbstractLossFunction): + """ + Base class for hybrid loss functions that are used for segmentation tasks. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.loss_calculators = self._initialize_all_loss_calculators() + + @abstractmethod + def _initialize_all_loss_calculators(self) -> List[AbstractLossFunction]: + """ + Each hybrid loss should implement this method, creating all loss functions as a list that + will be used during the forward pass. 
+ """ + pass + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + for loss_calculator in self._initialize_all_loss_calculators(): + accumulated_loss += loss_calculator(prediction, target) + + return accumulated_loss diff --git a/GANDLF/losses/regression.py b/GANDLF/losses/regression.py index 6d74a33a2..4949bd9d2 100644 --- a/GANDLF/losses/regression.py +++ b/GANDLF/losses/regression.py @@ -1,8 +1,8 @@ from typing import Optional import torch import torch.nn.functional as F -from torch.nn import CrossEntropyLoss from GANDLF.utils import one_hot +from torch.nn import CrossEntropyLoss def CEL( diff --git a/GANDLF/losses/regression_new.py b/GANDLF/losses/regression_new.py new file mode 100644 index 000000000..e9e0d5db0 --- /dev/null +++ b/GANDLF/losses/regression_new.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from .loss_interface import AbstractRegressionLoss + + +class CrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.CrossEntropyLoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCELoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyWithLogitsLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss with logits between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCEWithLogitsLoss(reduction=self.reduction_method) + + +class BaseLossWithScaledTarget(AbstractRegressionLoss): + """ + General interface for the loss functions requiring scaling of the target tensor. + """ + + def _initialize_scaling_factor(self): + loss_params: dict = self.params["loss_function"] + self.scaling_factor = loss_params.get("scaling_factor", 1.0) + if isinstance(loss_params, dict): + self.scaling_factor = loss_params.get("scaling_factor", self.scaling_factor) + return self.scaling_factor + + def _calculate_loss(self, prediction: torch.Tensor, target: torch.Tensor): + return self.loss_calculator(prediction, target * self.scaling_factor) + + +class L1Loss(BaseLossWithScaledTarget): + """ + This class computes the L1 loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.L1Loss(reduction=self.reduction_method) + + +class MSELoss(BaseLossWithScaledTarget): + """ + This class computes the mean squared error loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.MSELoss(reduction=self.reduction_method) diff --git a/GANDLF/losses/segmentation_new.py b/GANDLF/losses/segmentation_new.py new file mode 100644 index 000000000..4999686fe --- /dev/null +++ b/GANDLF/losses/segmentation_new.py @@ -0,0 +1,193 @@ +import sys +import torch +from .loss_interface import AbstractSegmentationLoss, AbstractLossFunction + + +class MulticlassDiceLoss(AbstractSegmentationLoss): + """ + This class computes the Dice loss between two tensors. + """ + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Dice score for a single class. 
diff --git a/GANDLF/losses/segmentation_new.py b/GANDLF/losses/segmentation_new.py
new file mode 100644
index 000000000..4999686fe
--- /dev/null
+++ b/GANDLF/losses/segmentation_new.py
@@ -0,0 +1,193 @@
+import sys
+import torch
+from .loss_interface import AbstractSegmentationLoss, AbstractLossFunction
+
+
+class MulticlassDiceLoss(AbstractSegmentationLoss):
+    """
+    This class computes the Dice loss between two tensors.
+    """
+
+    def _single_class_loss_calculator(
+        self, prediction: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute Dice score for a single class.
+
+        Args:
+            prediction (torch.Tensor): Network's predicted segmentation mask
+            target (torch.Tensor): Target segmentation mask
+
+        Returns:
+            torch.Tensor: The computed Dice score.
+        """
+        predicted_flat = prediction.flatten()
+        label_flat = target.flatten()
+        intersection = (predicted_flat * label_flat).sum()
+
+        dice_score = (2.0 * intersection + sys.float_info.min) / (
+            predicted_flat.sum() + label_flat.sum() + sys.float_info.min
+        )
+
+        return dice_score
+
+
+class MulticlassDiceLogLoss(MulticlassDiceLoss):
+    def _optional_loss_operations(self, loss):
+        # epsilon for numerical stability
+        return -torch.log(loss + torch.finfo(torch.float32).eps)
+
+
+class MulticlassMCCLoss(AbstractSegmentationLoss):
+    """
+    This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors.
+    """
+
+    def _single_class_loss_calculator(
+        self, prediction: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute MCC score for a single class.
+
+        Args:
+            prediction (torch.Tensor): Network's predicted segmentation mask
+            target (torch.Tensor): Target segmentation mask
+
+        Returns:
+            torch.Tensor: The computed MCC score.
+        """
+        tp = torch.sum(torch.mul(prediction, target))
+        tn = torch.sum(torch.mul((1 - prediction), (1 - target)))
+        fp = torch.sum(torch.mul(prediction, (1 - target)))
+        fn = torch.sum(torch.mul((1 - prediction), target))
+
+        numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
+        # Adding epsilon to the denominator to avoid divide-by-zero errors.
+        # The deprecated three-argument torch.add(a, 1, b) overload is replaced
+        # with plain tensor addition, which computes the same a + b.
+        denominator = (
+            torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
+            + torch.finfo(torch.float32).eps
+        )
+
+        return torch.div(numerator, denominator)
+
+
+class MulticlassMCCLogLoss(MulticlassMCCLoss):
+    def _optional_loss_operations(self, loss):
+        # epsilon for numerical stability
+        return -torch.log(loss + torch.finfo(torch.float32).eps)
+
+
+class MulticlassTverskyLoss(AbstractSegmentationLoss):
+    """
+    This class computes the Tversky loss between two tensors.
+    """
+
+    def __init__(self, params: dict):
+        super().__init__(params)
+        loss_params = params["loss_function"]
+        self.alpha = 0.5
+        self.beta = 0.5
+        if isinstance(loss_params, dict):
+            self.alpha = loss_params.get("alpha", self.alpha)
+            self.beta = loss_params.get("beta", self.beta)
+
+    def _single_class_loss_calculator(
+        self, prediction: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute Tversky score for a single class.
+
+        Args:
+            prediction (torch.Tensor): Network's predicted segmentation mask
+            target (torch.Tensor): Target segmentation mask
+
+        Returns:
+            torch.Tensor: The computed Tversky score.
+        """
+        predicted_flat = prediction.contiguous().view(-1)
+        target_flat = target.contiguous().view(-1)
+
+        true_positives = (predicted_flat * target_flat).sum()
+        false_positives = ((1 - target_flat) * predicted_flat).sum()
+        false_negatives = (target_flat * (1 - predicted_flat)).sum()
+
+        numerator = true_positives
+        denominator = (
+            true_positives + self.alpha * false_positives + self.beta * false_negatives
+        )
+        loss = (numerator + sys.float_info.min) / (denominator + sys.float_info.min)
+
+        return loss
+
+
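As a sanity check on the Tversky formulation above: with `alpha = beta = 0.5` the per-class score reduces to the Dice score. A small numeric sketch, using plain tensors rather than the classes themselves:

```python
import torch

prediction = torch.tensor([1.0, 1.0, 0.0, 0.0])
target = torch.tensor([1.0, 0.0, 1.0, 0.0])

tp = (prediction * target).sum()        # 1
fp = ((1 - target) * prediction).sum()  # 1
fn = (target * (1 - prediction)).sum()  # 1

alpha = beta = 0.5
tversky = tp / (tp + alpha * fp + beta * fn)  # 1 / 2
dice = 2 * tp / (2 * tp + fp + fn)            # 2 / 4
print(torch.isclose(tversky, dice))  # True: Tversky(0.5, 0.5) == Dice
```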
+ """ + + def __init__(self, params: dict): + super().__init__(params) + + self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none") + loss_params = params["loss_function"] + self.alpha = 1.0 + self.gamma = 2.0 + self.output_aggregation = "sum" + if isinstance(loss_params, dict): + self.alpha = loss_params.get("alpha", self.alpha) + self.gamma = loss_params.get("gamma", self.gamma) + self.output_aggregation = loss_params.get( + "size_average", + self.output_aggregation, # naming mismatch of key due to keeping API consistent with config format + ) + assert self.output_aggregation in [ + "sum", + "mean", + ], f"Invalid output aggregation method defined for Foal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']" + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute focal loss for a single class. It is based on the following formulas: + FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) + CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred) + CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t) + p_t = p if target = 1 else 1 - p + """ + ce_loss = self.ce_loss_helper(prediction, target) + p_t = torch.exp(-ce_loss) + loss = -self.alpha * (1 - p_t) ** self.gamma * ce_loss + return loss.sum() if self.output_aggregation == "sum" else loss.mean() + + def _compute_single_class_loss( + self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int + ) -> torch.Tensor: + """Compute loss for a single class.""" + loss_value = self._single_class_loss_calculator( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + return loss_value # no need to subtract from 1 in this case, hence the override + + +class KullbackLeiblerDivergence(AbstractLossFunction): + def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """ + Calculates the Kullback-Leibler divergence between two Gaussian distributions. + + Args: + mu (torch.Tensor): The mean of the first Gaussian distribution. + logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. + + Returns: + torch.Tensor: The computed Kullback-Leibler divergence + """ + loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) + return loss.mean()