import torch
import torch.nn as nn


class BinaryCrossEntropyLoss(nn.Module):
    """
    This class implements the binary cross entropy loss.
    """

    def __init__(self, bootstrap: bool = True, k_top: float = 0.8,
                 label_smoothing: float = 0., w_0: float = 1., w_1: float = 1.) -> None:
        """
        Constructor method
        :param bootstrap: (bool) If true the bootstrapped version is utilized
        :param k_top: (float) Top k fraction of the samples to be used
        :param label_smoothing: (float) Label smoothing factor to be applied
        :param w_0: (float) Weight for label 0
        :param w_1: (float) Weight for label 1
        """
        # Call super constructor
        super(BinaryCrossEntropyLoss, self).__init__()
        # Save parameters
        self.bootstrap = bootstrap
        self.k_top = k_top
        self.w_0 = w_0
        self.w_1 = w_1
        # Init label smoothing module if a positive smoothing factor is given
        self.label_smoothing = None if label_smoothing <= 0.0 \
            else LabelSmoothing(label_smoothing=label_smoothing)

    def __repr__(self):
        """
        Get representation of the loss module
        :return: (str) String including information
        """
        return self.__class__.__name__
    def forward(self, prediction: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """
        Forward pass computes the binary cross entropy loss of segmentation masks
        :param prediction: (torch.Tensor) Prediction probability
        :param label: (torch.Tensor) Label one-hot encoded
        :return: (torch.Tensor) Loss value
        """
        # Perform label smoothing if utilized (two classes in the binary case)
        if self.label_smoothing is not None:
            label = self.label_smoothing(label, 2)
        # Calc weighted binary cross entropy loss (w_1 weights label 1, w_0 weights label 0)
        loss = -(self.w_1 * label * torch.log(prediction.clamp(min=1e-06, max=1. - 1e-06))
                 + self.w_0 * (1.0 - label) * torch.log(1. - prediction.clamp(min=1e-06, max=1. - 1e-06)))
        # Perform bootstrapping
        if self.bootstrap:
            # Flatten loss values
            loss = loss.view(-1)
            # Sort loss values and keep only the top k fraction of the elements
            loss = loss.sort(descending=True)[0][:int(loss.shape[0] * self.k_top)]
        return loss.mean()
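
# Minimal usage sketch (not part of the original module): the input shapes and
# the sigmoid activation are illustrative assumptions, the loss itself works
# elementwise on probability maps.
def _demo_binary_cross_entropy_loss() -> None:
    loss_function = BinaryCrossEntropyLoss(bootstrap=True, k_top=0.8)
    # Prediction probabilities in (0, 1), e.g. produced by a sigmoid
    prediction = torch.sigmoid(torch.randn(2, 1, 16, 16))
    # Random binary ground truth mask of the same shape
    label = (torch.rand(2, 1, 16, 16) > 0.5).float()
    print("BinaryCrossEntropyLoss:", loss_function(prediction, label).item())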


class CrossEntropyLoss(nn.Module):
    """
    This class implements the multi class cross entropy loss.
    """

    def __init__(self) -> None:
        """
        Constructor method
        """
        # Call super constructor
        super(CrossEntropyLoss, self).__init__()

    def __repr__(self):
        """
        Get representation of the loss module
        :return: (str) String including information
        """
        return self.__class__.__name__
    def forward(self, prediction: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """
        Forward pass computes the multi class cross entropy loss of segmentation masks
        :param prediction: (torch.Tensor) Prediction probability
        :param label: (torch.Tensor) Label one-hot encoded
        :return: (torch.Tensor) Loss value
        """
        # Calc multi class cross entropy loss (class dimension assumed at dim 0)
        loss = -(label * torch.log(prediction.clamp(min=1e-06, max=1. - 1e-06))).sum(dim=0)
        return loss.mean()
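
# Minimal usage sketch (not part of the original module). The class dimension
# is assumed to be dim 0, matching the sum(dim=0) in the forward pass above.
def _demo_cross_entropy_loss() -> None:
    loss_function = CrossEntropyLoss()
    # Class probability map of shape (classes, height, width), softmax over classes
    prediction = torch.softmax(torch.randn(3, 16, 16), dim=0)
    # One-hot encoded label of the same shape
    label = torch.nn.functional.one_hot(
        torch.randint(0, 3, (16, 16)), num_classes=3).permute(2, 0, 1).float()
    print("CrossEntropyLoss:", loss_function(prediction, label).item())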


class BinaryFocalLoss(nn.Module):
    """
    This class implements the binary segmentation focal loss.
    Paper: https://arxiv.org/abs/1708.02002
    Source: https://github.com/ChristophReich1996/Cell-DETR/blob/master/lossfunction.py
    """

    def __init__(self, gamma: float = 2.0, bootstrap: bool = False, k_top: float = 0.5) -> None:
        """
        Constructor method
        :param gamma: (float) Gamma constant (see paper)
        :param bootstrap: (bool) If true the bootstrapped version is utilized
        :param k_top: (float) Top k fraction of the samples to be used
        """
        # Call super constructor
        super(BinaryFocalLoss, self).__init__()
        # Save parameters
        self.gamma = gamma
        self.bootstrap = bootstrap
        self.k_top = k_top

    def __repr__(self):
        """
        Get representation of the loss module
        :return: (str) String including information
        """
        return "{}, gamma={}".format(self.__class__.__name__, self.gamma)
    def forward(self, prediction: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """
        Forward pass computes the binary focal loss of segmentation masks
        :param prediction: (torch.Tensor) Prediction probability
        :param label: (torch.Tensor) Label one-hot encoded
        :return: (torch.Tensor) Loss value
        """
        # Calc binary cross entropy loss
        binary_cross_entropy_loss = -(label * torch.log(prediction.clamp(min=1e-06, max=1. - 1e-06))
                                      + (1.0 - label) * torch.log((1.0 - prediction).clamp(min=1e-06, max=1. - 1e-06)))
        # Calc focal factor: the probability assigned to the true class
        focal_factor = prediction * label + (1.0 - prediction) * (1.0 - label)
        # Calc final focal loss by down-weighting well classified samples
        loss = (1.0 - focal_factor) ** self.gamma * binary_cross_entropy_loss
        # Perform bootstrapping
        if self.bootstrap:
            # Flatten loss values
            loss = loss.view(-1)
            # Sort loss values and keep only the top k fraction of the elements
            loss = loss.sort(descending=True)[0][:int(loss.shape[0] * self.k_top)]
        # Perform reduction
        loss = loss.mean()
        return loss
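
# Minimal usage sketch (not part of the original module): shapes and the
# sigmoid activation are illustrative assumptions.
def _demo_binary_focal_loss() -> None:
    loss_function = BinaryFocalLoss(gamma=2.0, bootstrap=False)
    # Prediction probabilities in (0, 1), e.g. produced by a sigmoid
    prediction = torch.sigmoid(torch.randn(2, 1, 16, 16))
    # Random binary ground truth mask of the same shape
    label = (torch.rand(2, 1, 16, 16) > 0.5).float()
    print("BinaryFocalLoss:", loss_function(prediction, label).item())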


class FocalLoss(nn.Module):
    """
    Implementation of the multi class focal loss.
    Paper: https://arxiv.org/abs/1708.02002
    Source: https://github.com/ChristophReich1996/Cell-DETR/blob/master/lossfunction.py
    """

    def __init__(self, gamma: float = 2.) -> None:
        """
        Constructor method
        :param gamma: (float) Gamma constant (see paper)
        """
        # Call super constructor
        super(FocalLoss, self).__init__()
        # Save parameters
        self.gamma = gamma

    def __repr__(self):
        """
        Get representation of the loss module
        :return: (str) String including information
        """
        return "{}, gamma={}".format(self.__class__.__name__, self.gamma)
    def forward(self, prediction: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """
        Forward pass computes the multi class focal loss of segmentation masks
        :param prediction: (torch.Tensor) Prediction probability
        :param label: (torch.Tensor) Label one-hot encoded
        :return: (torch.Tensor) Loss value
        """
        # Calc multi class cross entropy loss (class dimension assumed at dim 0)
        cross_entropy_loss = -(label * torch.log(prediction.clamp(min=1e-06, max=1. - 1e-06))).sum(dim=0)
        # Calc focal factor: the probability assigned to the true class
        focal_factor = prediction * label + (1.0 - prediction) * (1.0 - label)
        # Calc final focal loss by down-weighting well classified samples
        loss = ((1.0 - focal_factor) ** self.gamma * cross_entropy_loss).mean()
        return loss
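
# Minimal usage sketch (not part of the original module). As in
# CrossEntropyLoss, the class dimension is assumed to be dim 0.
def _demo_focal_loss() -> None:
    loss_function = FocalLoss(gamma=2.)
    # Class probability map of shape (classes, height, width), softmax over classes
    prediction = torch.softmax(torch.randn(3, 16, 16), dim=0)
    # One-hot encoded label of the same shape
    label = torch.nn.functional.one_hot(
        torch.randint(0, 3, (16, 16)), num_classes=3).permute(2, 0, 1).float()
    print("FocalLoss:", loss_function(prediction, label).item())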


class LabelSmoothing(nn.Module):
    """
    This class implements one-hot label smoothing for the Dirichlet segmentation loss.
    """

    def __init__(self, label_smoothing: float = 0.05) -> None:
        """
        Constructor method
        :param label_smoothing: (float) Label smoothing factor
        """
        # Call super constructor
        super(LabelSmoothing, self).__init__()
        # Save parameters
        self.label_smoothing = label_smoothing
    def forward(self, label: torch.Tensor, number_of_classes: int) -> torch.Tensor:
        """
        Forward pass smooths a given one-hot label
        :param label: (torch.Tensor) Label
        :param number_of_classes: (int) Number of classes
        :return: (torch.Tensor) Smoothed one-hot label
        """
        # Distribute the smoothing mass uniformly over all classes
        smooth_positive = 1.0 - self.label_smoothing
        smooth_negative = self.label_smoothing / number_of_classes
        return label * smooth_positive + smooth_negative
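
# Minimal usage sketch (not part of the original module): smooths a random
# three class one-hot label; with label_smoothing=0.05 positives map to
# 0.95 + 0.05 / 3 and negatives to 0.05 / 3.
def _demo_label_smoothing() -> None:
    label_smoothing = LabelSmoothing(label_smoothing=0.05)
    # One-hot encoded label of shape (classes, height, width)
    label = torch.nn.functional.one_hot(
        torch.randint(0, 3, (16, 16)), num_classes=3).permute(2, 0, 1).float()
    smoothed_label = label_smoothing(label, number_of_classes=3)
    print("LabelSmoothing min/max:", smoothed_label.min().item(), smoothed_label.max().item())


if __name__ == "__main__":
    # Run the usage sketches above as a quick smoke test
    _demo_binary_cross_entropy_loss()
    _demo_cross_entropy_loss()
    _demo_binary_focal_loss()
    _demo_focal_loss()
    _demo_label_smoothing()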