Skip to content

Commit

Permalink
Merge pull request #26 from HiLab-git/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
taigw authored Aug 21, 2022
2 parents ee51fa9 + 42feb23 commit f2dad93
Show file tree
Hide file tree
Showing 55 changed files with 1,843 additions and 679 deletions.
32 changes: 18 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,56 +1,60 @@
# PyMIC: A Pytorch-Based Toolkit for Medical Image Computing

PyMIC is a pytorch-based toolkit for medical image computing with deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with higher dimension, multiple modalities and low contrast. The toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configure files.
PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.

Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. If you use this toolkit, please cite the following paper:


* G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang.
[A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.][tmi2020] IEEE Transactions on Medical Imaging. 39(8):2653-2663, 2020. DOI: [10.1109/TMI.2020.3000314][tmi2020]

[tmi2020]:https://ieeexplore.ieee.org/document/9109297


# Advantages
PyMIC provides some basic modules for medical image computing that can be share by different applications. We currently provide the following functions:
# Features
PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
* Easy-to-use I/O interface to read and write different 2D and 3D images.
* Various data pre-processing/transformation methods before sending a tensor into a network.
* Implementation of typical neural networks for medical image segmentation.
* Re-useable training and testing pipeline that can be transferred to different tasks.
* Various data pre-processing methods before sending a tensor into a network.
* Implementation of loss functions, especially for image segmentation.
* Implementation of evaluation metrics to get quantitative evaluation of your methods (for segmentation).
* Evaluation metrics for quantitative evaluation of your methods.

# Usage
## Requirement
* [Pytorch][torch_link] version >=1.0.1
* [TensorboardX][tbx_link] to visualize training performance
* Some common python packages such as Numpy, Pandas, SimpleITK
* See `requirements.txt` for details.

[torch_link]:https://pytorch.org/
[tbx_link]:https://github.com/lanpa/tensorboardX

## Installation
Run the following command to install the current released version of PyMIC:
Run the following command to install the latest released version of PyMIC:

```bash
pip install PYMIC
```
To install a specific version of PYMIC such as 0.2.4, run:
To install a specific version of PYMIC such as 0.3.0, run:

```bash
pip install PYMIC==0.2.4
pip install PYMIC==0.3.0
```
Alternatively, you can download the source code for the latest version. Run the following command to compile and install:

```bash
python setup.py install
```

## Examples
[PyMIC_examples][examples] provides some examples of starting to use PyMIC. For beginners, you only need to simply change the configuration files to select different datasets, networks and training methods for running the code. For advanced users, you can develop your own modules based on this package. You can find both types of examples
## How to start
* [PyMIC_examples][exp_link] shows some examples of starting to use PyMIC.
* [PyMIC_doc][docs_link] provides documentation of this project.

[examples]: https://github.com/HiLab-git/PyMIC_examples
[docs_link]:https://pymic.readthedocs.io/en/latest/
[exp_link]:https://github.com/HiLab-git/PyMIC_examples

# Projects based on PyMIC
## Projects based on PyMIC
Using PyMIC, it becomes easy to develop deep learning models for different projects, such as the following:

1, [MyoPS][myops] Winner of the MICCAI 2020 myocardial pathology segmentation (MyoPS) Challenge.
Expand Down
4 changes: 2 additions & 2 deletions pymic/loss/loss_dict_seg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import torch.nn as nn
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss
from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss
from pymic.loss.seg.slsr import SLSRLoss
from pymic.loss.seg.exp_log import ExpLogLoss
from pymic.loss.seg.mse import MSELoss, MAELoss

SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss,
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
'GeneralizedCELoss': GeneralizedCELoss,
'SLSRLoss': SLSRLoss,
'DiceLoss': DiceLoss,
'FocalDiceLoss': FocalDiceLoss,
Expand Down
25 changes: 14 additions & 11 deletions pymic/loss/seg/ce.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pymic.loss.seg.util import reshape_tensor_to_2D

class CrossEntropyLoss(nn.Module):
def __init__(self, params):
def __init__(self, params = None):
super(CrossEntropyLoss, self).__init__()
if(params is None):
self.softmax = True
Expand Down Expand Up @@ -59,41 +59,44 @@ def forward(self, loss_input_dict):
ce = torch.mean(ce)
return ce

class GeneralizedCrossEntropyLoss(nn.Module):
class GeneralizedCELoss(nn.Module):
"""
Generalized cross entropy loss to deal with noisy labels.
Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks
with Noisy Labels, NeurIPS 2018.
"""
def __init__(self, params):
super(GeneralizedCrossEntropyLoss, self).__init__()
self.enable_pix_weight = params['GeneralizedCrossEntropyLoss_Enable_Pixel_Weight'.lower()]
self.enable_cls_weight = params['GeneralizedCrossEntropyLoss_Enable_Class_Weight'.lower()]
self.q = params['GeneralizedCrossEntropyLoss_q'.lower()]
"""
q: in (0, 1), becmomes MAE when q = 1
"""
super(GeneralizedCELoss, self).__init__()
self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False)
self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False)
self.q = params.get('GeneralizedCELoss_q', 0.5)
self.softmax = params.get('loss_softmax', True)

def forward(self, loss_input_dict):
predict = loss_input_dict['prediction']
soft_y = loss_input_dict['ground_truth']
pix_w = loss_input_dict['pixel_weight']
cls_w = loss_input_dict['class_weight']
softmax = loss_input_dict['softmax']
soft_y = loss_input_dict['ground_truth']

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(softmax):
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)
gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y

if(self.enable_cls_weight):
cls_w = loss_input_dict.get('class_weight', None)
if(cls_w is None):
raise ValueError("Class weight is enabled but not defined")
gce = torch.sum(gce * cls_w, dim = 1)
else:
gce = torch.sum(gce, dim = 1)

if(self.enable_pix_weight):
pix_w = loss_input_dict.get('pixel_weight', None)
if(pix_w is None):
raise ValueError("Pixel weight is enabled but not defined")
pix_w = reshape_tensor_to_2D(pix_w)
Expand Down
27 changes: 2 additions & 25 deletions pymic/loss/seg/mumford_shah.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,14 @@
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
def __init__(self, params = None):
super(DiceLoss, self).__init__()
if(params is None):
self.softmax = True
else:
self.softmax = params.get('loss_softmax', True)

def forward(self, loss_input_dict):
predict = loss_input_dict['prediction']
soft_y = loss_input_dict['ground_truth']

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)
dice_score = get_classwise_dice(predict, soft_y)
dice_loss = 1.0 - dice_score.mean()
return dice_loss

class MumfordShahLoss(nn.Module):
"""
Implementation of Mumford Shah Loss in this paper:
Boah Kim and Jong Chul Ye, Mumford–Shah Loss Functional
Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional
for Image Segmentation With Deep Learning. IEEE TIP, 2019.
The oringial implementation is availabel at:
https://github.com/jongcye/CNN_MumfordShah_Loss
currently only 2D version is supported.
Currently only 2D version is supported.
"""
def __init__(self, params = None):
super(MumfordShahLoss, self).__init__()
Expand Down
9 changes: 5 additions & 4 deletions pymic/loss/seg/slsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
"""
Spatial Label Smoothing Regularization (SLSR) loss for learning from
noisy annotatins according to the following paper:
Minqing Zhang, Jiantao Gao et al., Characterizing Label Errors:
Confident Learning for Noisy-Labeled Image Segmentation, MICCAI 2020.
Minqing Zhang, Jiantao Gao et al.:
Characterizing Label Errors: Confident Learning for Noisy-Labeled Image
Segmentation, MICCAI 2020.
https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70
"""
from __future__ import print_function, division

Expand All @@ -17,7 +19,7 @@ def __init__(self, params):
if(params is None):
params = {}
self.softmax = params.get('loss_softmax', True)
self.epsilon = params.get('slsrloss_softmax', 0.25)
self.epsilon = params.get('slsrloss_epsilon', 0.25)

def forward(self, loss_input_dict):
predict = loss_input_dict['prediction']
Expand All @@ -35,7 +37,6 @@ def forward(self, loss_input_dict):
soft_y = reshape_tensor_to_2D(soft_y)
if(pix_w is not None):
pix_w = reshape_tensor_to_2D(pix_w > 0).float()

# smooth labels for pixels in the unconfident mask
smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5
smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y
Expand Down
69 changes: 65 additions & 4 deletions pymic/net/net2d/unet2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,67 @@ def forward(self, x1, x2):
x = torch.cat([x2, x1], dim=1)
return self.conv(x)

class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.dropout = self.params['dropout']
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
if(len(self.ft_chns) == 5):
self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])

def forward(self, x):
x0 = self.in_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
output = [x0, x1, x2, x3]
if(len(self.ft_chns) == 5):
x4 = self.down4(x3)
output.append(x4)
return output

class Decoder(nn.Module):
def __init__(self, params):
super(Decoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.dropout = self.params['dropout']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']

assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

if(len(self.ft_chns) == 5):
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)

def forward(self, x):
if(len(self.ft_chns) == 5):
assert(len(x) == 5)
x0, x1, x2, x3, x4 = x
x_d3 = self.up1(x4, x3)
else:
assert(len(x) == 4)
x0, x1, x2, x3 = x
x_d3 = x3
x_d2 = self.up2(x_d3, x2)
x_d1 = self.up3(x_d2, x1)
x_d0 = self.up4(x_d1, x0)
output = self.out_conv(x_d0)
return output

class UNet2D(nn.Module):
def __init__(self, params):
super(UNet2D, self).__init__()
Expand All @@ -91,10 +152,10 @@ def __init__(self, params):
self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
if(len(self.ft_chns) == 5):
self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear)
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear)
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear)
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear)
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)

self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
if(self.deep_sup):
Expand Down
Loading

0 comments on commit f2dad93

Please sign in to comment.