Merge pull request #31 from HiLab-git/dev
Dev
taigw authored Dec 6, 2022
2 parents a114a79 + 1e79947 commit 1ecadee
Showing 15 changed files with 235 additions and 257 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC:
```bash
pip install PYMIC
```
-To install a specific version of PYMIC such as 0.3.0, run:
+To install a specific version of PYMIC such as 0.3.1, run:

```bash
-pip install PYMIC==0.3.0
+pip install PYMIC==0.3.1
```
Alternatively, you can download the source code for the latest version. Run the following command to compile and install:

10 changes: 2 additions & 8 deletions pymic/loss/cls/basic.py
@@ -65,10 +65,7 @@ def forward(self, loss_input_dict):
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
softmax = nn.Softmax(dim = 1)
predict = softmax(predict)
-num_class = list(predict.size())[1]
-data_type = 'float' if(predict.dtype is torch.float32) else 'double'
-soft_y = get_soft_label(labels, num_class, data_type)
-loss = self.l1_loss(predict, soft_y)
+loss = self.l1_loss(predict, labels)
return loss

class MSELoss(AbstractClassificationLoss):
@@ -84,10 +81,7 @@ def forward(self, loss_input_dict):
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
softmax = nn.Softmax(dim = 1)
predict = softmax(predict)
-num_class = list(predict.size())[1]
-data_type = 'float' if(predict.dtype is torch.float32) else 'double'
-soft_y = get_soft_label(labels, num_class, data_type)
-loss = self.mse_loss(predict, soft_y)
+loss = self.mse_loss(predict, labels)
return loss

class NLLLoss(AbstractClassificationLoss):
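For context, the removed `get_soft_label` call one-hot encoded the `(N, 1)` label tensor before the element-wise loss, while the new code compares predictions with the label tensor directly. A minimal standalone sketch contrasting the two behaviors (a hypothetical illustration, not PyMIC code; the shapes and variable names are assumptions):

```python
import torch
import torch.nn as nn

N, C = 4, 3                                     # batch size and class count (assumed)
predict = nn.Softmax(dim=1)(torch.randn(N, C))  # (N, C) class probabilities
labels = torch.randint(0, C, (N, 1))            # (N, 1) integer class indices

# Old behavior (removed by this commit): compare against one-hot targets.
soft_y = torch.zeros(N, C).scatter_(1, labels, 1.0)  # one-hot targets, shape (N, C)
old_loss = nn.MSELoss()(predict, soft_y)

# New behavior: compare against the (N, 1) label tensor directly;
# note that PyTorch broadcasts the target to (N, C) here.
new_loss = nn.MSELoss()(predict, labels.float())
```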
1 change: 0 additions & 1 deletion pymic/loss/seg/ce.py
@@ -18,7 +18,6 @@ class CrossEntropyLoss(AbstractSegLoss):
"""
def __init__(self, params = None):
super(CrossEntropyLoss, self).__init__(params)
-

def forward(self, loss_input_dict):
predict = loss_input_dict['prediction']
5 changes: 3 additions & 2 deletions pymic/loss/seg/ssl.py
@@ -6,8 +6,9 @@
import torch.nn as nn
import numpy as np
from pymic.loss.seg.util import reshape_tensor_to_2D
+from pymic.loss.seg.abstract import AbstractSegLoss

-class EntropyLoss(nn.Module):
+class EntropyLoss(AbstractSegLoss):
"""
Entropy Minimization for segmentation tasks.
The parameters should be written in the `params` dictionary, and it has the
@@ -43,7 +44,7 @@ def forward(self, loss_input_dict):
avg_ent = torch.mean(entropy)
return avg_ent

-class TotalVariationLoss(nn.Module):
+class TotalVariationLoss(AbstractSegLoss):
"""
Total Variation Loss for segmentation tasks.
The parameters should be written in the `params` dictionary, and it has the
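For context, entropy minimization encourages confident pixel-wise predictions by penalizing high-entropy probability maps. A minimal sketch of such a loss (a hypothetical simplification, not the PyMIC implementation; the input shape is an assumption):

```python
import math
import torch
import torch.nn as nn

class EntropyMinSketch(nn.Module):
    """Hypothetical simplified entropy-minimization loss for segmentation."""
    def forward(self, predict):
        # predict: (N, C, H, W) class probabilities, already softmax-normalized
        num_class = predict.shape[1]
        # Per-pixel entropy, scaled by log(C) so values fall in [0, 1].
        entropy = -torch.sum(predict * torch.log(predict + 1e-5), dim=1) / math.log(num_class)
        return torch.mean(entropy)
```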
8 changes: 4 additions & 4 deletions pymic/net/cls/torch_pretrained_net.py
@@ -75,7 +75,7 @@ def __init__(self, params):
def get_parameters_to_update(self):
if(self.update_mode == "all"):
return self.net.parameters()
elif(self.update_layers == "last"):
elif(self.update_mode == "last"):
params = self.net.fc.parameters()
if(self.in_chns !=3):
# combining the two iterables into a single one
@@ -119,7 +119,7 @@ def get_parameters_to_update(self):
params = self.net.classifier[-1].parameters()
if(self.in_chns !=3):
params = itertools.chain()
-for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0].parameters()]:
+for pram in [self.net.classifier[-1].parameters(), self.net.features[0].parameters()]:
params = itertools.chain(params, pram)
return params
else:
@@ -138,7 +138,7 @@ class MobileNetV2(BuiltInNet):
as well as the first layer when `input_chns` is not 3.
"""
def __init__(self, params):
-super(MobileNetV2, self).__init__()
+super(MobileNetV2, self).__init__(params)
self.net = models.mobilenet_v2(pretrained = self.pretrain)

# replace the last layer
@@ -157,7 +157,7 @@ def get_parameters_to_update(self):
params = self.net.classifier[-1].parameters()
if(self.in_chns !=3):
params = itertools.chain()
-for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]:
+for pram in [self.net.classifier[-1].parameters(), self.net.features[0][0].parameters()]:
params = itertools.chain(params, pram)
return params
else:
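For context, these `get_parameters_to_update` methods use `itertools.chain` to merge the parameter iterators of the replaced classifier head and the replaced first layer into a single iterable for the optimizer. A minimal standalone sketch (hypothetical stand-in layers, not the PyMIC classes):

```python
import itertools
import torch
import torch.nn as nn

# Hypothetical stand-ins for a pretrained backbone's replaced layers.
first_conv = nn.Conv2d(1, 16, kernel_size=3)  # replaced input layer (in_chns != 3)
classifier = nn.Linear(16, 5)                 # replaced classification head

# Chain the two parameter iterators so only these layers are updated.
params = itertools.chain(classifier.parameters(), first_conv.parameters())
optimizer = torch.optim.SGD(params, lr=1e-3)
```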
132 changes: 0 additions & 132 deletions pymic/net/net2d/unet2d_urpc.py

This file was deleted.

3 changes: 0 additions & 3 deletions pymic/net/net_dict_seg.py
@@ -4,7 +4,6 @@
* UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D`
* UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch`
-* UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC`
* UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT`
* UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE`
* AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D`
@@ -17,7 +16,6 @@
from __future__ import print_function, division
from pymic.net.net2d.unet2d import UNet2D
from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch
-from pymic.net.net2d.unet2d_urpc import UNet2D_URPC
from pymic.net.net2d.unet2d_cct import UNet2D_CCT
from pymic.net.net2d.cople_net import COPLENet
from pymic.net.net2d.unet2d_attention import AttentionUNet2D
@@ -30,7 +28,6 @@
SegNetDict = {
'UNet2D': UNet2D,
'UNet2D_DualBranch': UNet2D_DualBranch,
-'UNet2D_URPC': UNet2D_URPC,
'UNet2D_CCT': UNet2D_CCT,
'COPLENet': COPLENet,
'AttentionUNet2D': AttentionUNet2D,
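For context, `SegNetDict` is the registry that maps a network name from the configuration to its class, so removing the `UNet2D_URPC` entry makes that name unresolvable. A minimal sketch of how such a registry lookup is typically consumed (hypothetical usage; the `params` dict is an assumption):

```python
from pymic.net.net_dict_seg import SegNetDict

net_name = "UNet2D"             # e.g., read from a training configuration
if net_name not in SegNetDict:  # "UNet2D_URPC" would now fail this check
    raise ValueError("undefined network: " + net_name)
net_class = SegNetDict[net_name]
# net = net_class(params)       # params: network-specific settings dict
```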