hierarchical_loss.py
import torch
import torch.nn.functional as F

from ptsemseg.tree import getTreeList


def hierarchical_loss(cnn_output, target, root):  # root represents your hierarchy
    probabilities = F.softmax(cnn_output, dim=1)
    loss = 0
    precomputed_hierarchy_list = getTreeList(root)  # see ptsemseg/tree.py
    for level_loss_list in precomputed_hierarchy_list:
        probabilities_tosum = probabilities.clone()
        summed_probabilities = probabilities_tosum
        for branch in level_loss_list:
            # Extract the relevant probabilities according to a branch in our hierarchy.
            branch_probs = torch.FloatTensor()
            for channel in branch:
                branch_probs = torch.cat((branch_probs, probabilities_tosum[:, channel, :, :].unsqueeze(1)), 1)
            # Sum these probabilities into a single slice; this is hierarchical inference.
            summed_tree_branch_slice = branch_probs.sum(1, keepdim=True)
            # Insert the inferred probability slice into each channel of summed_probabilities given by branch.
            # This duplicates probabilities for easy passing to standard loss functions such as nll_loss.
            for channel in branch:
                summed_probabilities[:, channel:(channel + 1), :, :] = summed_tree_branch_slice
        level_loss = F.nll_loss(torch.log(summed_probabilities), target)
        loss = loss + level_loss
    return loss
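

# --- Illustrative note (an assumption, not part of the original file) ---
# From the loops above, getTreeList(root) is expected to return a nested list
# with one entry per hierarchy level; each level is a list of branches, and
# each branch is a list of output-channel indices sharing the same parent,
# e.g. for 5 leaf classes grouped under 2 parents:
#
#     [[[0, 1], [2, 3, 4]],            # level 1: two parent nodes
#      [[0], [1], [2], [3], [4]]]      # level 2: the individual leaf classes
#
# With a root built as in ptsemseg/tree.py, the loss is used like a standard
# segmentation criterion: cnn_output is (N, C, H, W) logits, target is
# (N, H, W) class indices, and hierarchical_loss(cnn_output, target, root)
# returns a scalar that sums nll_loss over every level of the hierarchy.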