Skip to content

Commit

Permalink
Merge pull request #37 from HiLab-git/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
taigw authored Feb 26, 2023
2 parents 78c1460 + ea8d405 commit e386513
Show file tree
Hide file tree
Showing 57 changed files with 936 additions and 333 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ BibTeX entry:
author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang},
title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}},
year = {2023},
url = {http://arxiv.org/abs/2208.09350},
url = {https://doi.org/10.1016/j.cmpb.2023.107398},
journal = {Computer Methods and Programs in Biomedicine},
volume = {February},
volume = {231},
pages = {107398},
}

# Features
PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning.
* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
* Easy-to-use I/O interface to read and write different 2D and 3D images.
* Various data pre-processing/transformation methods before sending a tensor into a network.
Expand All @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC:
```bash
pip install PYMIC
```
To install a specific version of PYMIC such as 0.3.1, run:
To install a specific version of PYMIC such as 0.4.0, run:

```bash
pip install PYMIC==0.3.1
pip install PYMIC==0.4.0
```
Alternatively, you can download the source code for the latest version. Run the following command to compile and install:

Expand Down
2 changes: 2 additions & 0 deletions pymic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
__version__ = "0.4.0"
4 changes: 4 additions & 0 deletions pymic/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from __future__ import absolute_import
from pymic.io.image_read_write import *
from pymic.io.nifty_dataset import *
from pymic.io.h5_dataset import *
2 changes: 2 additions & 0 deletions pymic/layer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
2 changes: 2 additions & 0 deletions pymic/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
2 changes: 2 additions & 0 deletions pymic/loss/cls/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
2 changes: 2 additions & 0 deletions pymic/loss/seg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
2 changes: 2 additions & 0 deletions pymic/net/cls/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
2 changes: 2 additions & 0 deletions pymic/net/net2d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
11 changes: 6 additions & 5 deletions pymic/net/net2d/unet2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class UNet2D(nn.Module):
:param class_num: (int) The class number for segmentation task.
:param bilinear: (bool) Using bilinear for up-sampling or not.
If False, deconvolution will be used for up-sampling.
:param deep_supervise: (bool) Using deep supervision for training or not.
:param multiscale_pred: (bool) Get multiscale prediction.
"""
def __init__(self, params):
super(UNet2D, self).__init__()
Expand All @@ -197,7 +197,7 @@ def __init__(self, params):
self.dropout = self.params['dropout']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.deep_sup = self.params['deep_supervise']
self.mul_pred = self.params['multiscale_pred']

assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

Expand All @@ -213,7 +213,7 @@ def __init__(self, params):
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)

self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
if(self.deep_sup):
if(self.mul_pred):
self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1)
self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1)
self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1)
Expand All @@ -239,7 +239,7 @@ def forward(self, x):
x_d1 = self.up3(x_d2, x1)
x_d0 = self.up4(x_d1, x0)
output = self.out_conv(x_d0)
if(self.deep_sup):
if(self.mul_pred):
output1 = self.out_conv1(x_d1)
output2 = self.out_conv2(x_d2)
output3 = self.out_conv3(x_d3)
Expand All @@ -261,7 +261,8 @@ def forward(self, x):
'feature_chns':[2, 8, 32, 48, 64],
'dropout': [0, 0, 0.3, 0.4, 0.5],
'class_num': 2,
'bilinear': True}
'bilinear': True,
'multiscale_pred': False}
Net = UNet2D(params)
Net = Net.double()

Expand Down
2 changes: 2 additions & 0 deletions pymic/net/net3d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
36 changes: 22 additions & 14 deletions pymic/net/net3d/unet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class Decoder(nn.Module):
:param class_num: (int) The class number for segmentation task.
:param trilinear: (bool) Using bilinear for up-sampling or not.
If False, deconvolution will be used for up-sampling.
:param multiscale_pred: (bool) Get multi-scale prediction.
"""
def __init__(self, params):
super(Decoder, self).__init__()
Expand All @@ -139,16 +140,21 @@ def __init__(self, params):
self.ft_chns = self.params['feature_chns']
self.dropout = self.params['dropout']
self.n_class = self.params['class_num']
self.trilinear = self.params['trilinear']
self.trilinear = self.params.get('trilinear', True)
self.mul_pred = self.params.get('multiscale_pred', False)

assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

if(len(self.ft_chns) == 5):
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear)
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear)
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear)
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear)
self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
if(self.mul_pred):
self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1)
self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1)
self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1)

def forward(self, x):
if(len(self.ft_chns) == 5):
Expand All @@ -163,6 +169,11 @@ def forward(self, x):
x_d1 = self.up3(x_d2, x1)
x_d0 = self.up4(x_d1, x0)
output = self.out_conv(x_d0)
if(self.mul_pred):
output1 = self.out_conv1(x_d1)
output2 = self.out_conv2(x_d2)
output3 = self.out_conv3(x_d3)
output = [output, output1, output2, output3]
return output

class UNet3D(nn.Module):
Expand All @@ -187,7 +198,7 @@ class UNet3D(nn.Module):
:param class_num: (int) The class number for segmentation task.
:param trilinear: (bool) Using trilinear for up-sampling or not.
If False, deconvolution will be used for up-sampling.
:param deep_supervise: (bool) Using deep supervision for training or not.
:param multiscale_pred: (bool) Get multi-scale prediction.
"""
def __init__(self, params):
super(UNet3D, self).__init__()
Expand All @@ -197,7 +208,7 @@ def __init__(self, params):
self.dropout = self.params['dropout']
self.n_class = self.params['class_num']
self.trilinear = self.params['trilinear']
self.deep_sup = self.params['deep_supervise']
self.mul_pred = self.params['multiscale_pred']
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
Expand All @@ -216,7 +227,7 @@ def __init__(self, params):
dropout_p = self.dropout[0], trilinear=self.trilinear)

self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
if(self.deep_sup):
if(self.mul_pred):
self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1)
self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1)
self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1)
Expand All @@ -235,14 +246,10 @@ def forward(self, x):
x_d1 = self.up3(x_d2, x1)
x_d0 = self.up4(x_d1, x0)
output = self.out_conv(x_d0)
if(self.deep_sup):
out_shape = list(output.shape)[2:]
if(self.mul_pred):
output1 = self.out_conv1(x_d1)
output1 = interpolate(output1, out_shape, mode = 'trilinear')
output2 = self.out_conv2(x_d2)
output2 = interpolate(output2, out_shape, mode = 'trilinear')
output3 = self.out_conv3(x_d3)
output3 = interpolate(output3, out_shape, mode = 'trilinear')
output = [output, output1, output2, output3]
return output

Expand All @@ -251,7 +258,8 @@ def forward(self, x):
'class_num': 2,
'feature_chns':[2, 8, 32, 64],
'dropout' : [0, 0, 0, 0.5],
'trilinear': True}
'trilinear': True,
'multiscale_pred': False}
Net = UNet3D(params)
Net = Net.double()

Expand Down
2 changes: 2 additions & 0 deletions pymic/net_run/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from . import *
8 changes: 4 additions & 4 deletions pymic/net_run/agent_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def worker_init_fn(worker_id):
self.test_loader = torch.utils.data.DataLoader(self.test_set,
batch_size = bn_test, shuffle=False, num_workers= bn_test)

def create_optimizer(self, params):
def create_optimizer(self, params, checkpoint = None):
"""
Create optimizer based on configuration.
Expand All @@ -288,9 +288,9 @@ def create_optimizer(self, params):
self.optimizer = get_optimizer(opt_params['optimizer'],
params, opt_params)
last_iter = -1
if(self.checkpoint is not None):
self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
last_iter = self.checkpoint['iteration'] - 1
if(checkpoint is not None):
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
last_iter = checkpoint['iteration'] - 1
if(self.scheduler is None):
opt_params["last_iter"] = last_iter
self.scheduler = get_lr_scheduler(self.optimizer, opt_params)
Expand Down
37 changes: 23 additions & 14 deletions pymic/net_run/agent_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,6 @@ def training(self):
loss = self.get_loss_value(data, outputs, labels)
loss.backward()
self.optimizer.step()
if(self.scheduler is not None and \
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
self.scheduler.step()

# statistics
sample_num += labels.size(0)
Expand All @@ -183,7 +180,7 @@ def validation(self):
inputs = self.convert_tensor_type(data['image'])
labels = self.convert_tensor_type(data['label_prob'])
inputs, labels = inputs.to(self.device), labels.to(self.device)
self.optimizer.zero_grad()
# self.optimizer.zero_grad()
# forward + backward + optimize
outputs = self.net(inputs)
loss = self.get_loss_value(data, outputs, labels)
Expand All @@ -196,20 +193,17 @@ def validation(self):
avg_loss = running_loss / sample_num
avg_score= running_score.double() / sample_num
metrics = self.config['training'].get("evaluation_metric", "accuracy")
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
self.scheduler.step(avg_score)
valid_scalers = {'loss': avg_loss, metrics: avg_score}
return valid_scalers

def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
metrics =self.config['training'].get("evaluation_metric", "accuracy")
metrics = self.config['training'].get("evaluation_metric", "accuracy")
loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']}
acc_scalar ={'train':train_scalars[metrics],'valid':valid_scalars[metrics]}
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
self.summ_writer.add_scalars(metrics, acc_scalar, glob_it)
self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)

logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it))
logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format(
train_scalars['loss'], metrics, train_scalars[metrics]))
logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format(
Expand Down Expand Up @@ -251,7 +245,10 @@ def train_valid(self):
checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
assert(self.checkpoint['iteration'] == iter_start)
self.net.load_state_dict(self.checkpoint['model_state_dict'])
if(len(device_ids) > 1):
self.net.module.load_state_dict(self.checkpoint['model_state_dict'])
else:
self.net.load_state_dict(self.checkpoint['model_state_dict'])
self.max_val_score = self.checkpoint.get('valid_pred', 0)
self.max_val_it = self.checkpoint['iteration']
self.best_model_wts = self.checkpoint['model_state_dict']
Expand All @@ -266,15 +263,28 @@ def train_valid(self):
self.glob_it = iter_start
for it in range(iter_start, iter_max, iter_valid):
lr_value = self.optimizer.param_groups[0]['lr']
t0 = time.time()
train_scalars = self.training()
t1 = time.time()
valid_scalars = self.validation()
t2 = time.time()
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
self.scheduler.step(valid_scalars[metrics])
else:
self.scheduler.step()

self.glob_it = it + iter_valid
logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
logging.info('learning rate {0:}'.format(lr_value))
logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)

if(valid_scalars[metrics] > self.max_val_score):
self.max_val_score = valid_scalars[metrics]
self.max_val_it = self.glob_it
self.best_model_wts = copy.deepcopy(self.net.state_dict())
if(len(device_ids) > 1):
self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
else:
self.best_model_wts = copy.deepcopy(self.net.state_dict())

stop_now = True if(early_stop_it is not None and \
self.glob_it - self.max_val_it > early_stop_it) else False
Expand Down Expand Up @@ -306,7 +316,6 @@ def train_valid(self):
self.max_val_it, metrics, self.max_val_score))
self.summ_writer.close()


def infer(self):
device_ids = self.config['testing']['gpus']
device = torch.device("cuda:{0:}".format(device_ids[0]))
Expand All @@ -318,8 +327,8 @@ def infer(self):

if(self.config['testing'].get('evaluation_mode', True)):
self.net.eval()
output_csv = self.config['testing']['output_csv']

output_csv = self.config['testing']['output_dir'] + '/' + self.config['testing']['output_csv']
class_num = self.config['network']['class_num']
save_probability = self.config['testing'].get('save_probability', False)

Expand Down
Loading

0 comments on commit e386513

Please sign in to comment.