Skip to content

Commit

Permalink
add more logging
Browse files Browse the repository at this point in the history
  • Loading branch information
wolny committed Jan 3, 2025
1 parent 105aaea commit 0b6a531
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch3dunet/unet3d/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from torch import nn as nn
from torch.nn import MSELoss, SmoothL1Loss, L1Loss

from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('Loss')

def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
"""
Expand Down Expand Up @@ -278,13 +281,15 @@ def get_loss_criterion(config):
assert 'loss' in config, 'Could not find loss function configuration'
loss_config = config['loss']
name = loss_config.pop('name')
logger.info(f"Creating loss function: {name}")

ignore_index = loss_config.pop('ignore_index', None)
skip_last_target = loss_config.pop('skip_last_target', False)
weight = loss_config.pop('weight', None)

if weight is not None:
weight = torch.tensor(weight)
logger.info(f"Using class weights: {weight}")

pos_weight = loss_config.pop('pos_weight', None)
if pos_weight is not None:
Expand Down

0 comments on commit 0b6a531

Please sign in to comment.