Official code for the paper: Few-Shot Medical Image Segmentation via a Region-enhanced Prototypical Transformer (MICCAI 2023)
Automated segmentation of large volumes of medical images is often plagued by the limited availability of fully annotated data and by the diversity of organ surface properties resulting from the use of different acquisition protocols for different patients. In this paper, we introduce a promising few-shot learning-based method named Region-enhanced Prototypical Transformer (RPT) to mitigate the effects of large intra-class diversity/bias. First, a subdivision strategy is introduced to produce a collection of regional prototypes from the foreground of the support prototype. Second, a self-selection mechanism is proposed and incorporated into the Bias-alleviated Transformer (BaT) block to suppress or remove interference present in the query prototype and the regional support prototypes. By stacking BaT blocks, the proposed RPT iteratively optimizes the generated regional prototypes and finally produces rectified, more accurate global prototypes for Few-Shot Medical Image Segmentation (FSMS). Extensive experiments are conducted on three publicly available medical image datasets, and the results show consistent improvements over state-of-the-art FSMS methods.
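For intuition only, the snippet below is a minimal sketch of the two ingredients described above: extracting a global foreground prototype from the support features by masked average pooling, and subdividing the foreground into a collection of regional prototypes. It is not the authors' exact implementation; the grid-based subdivision, tensor shapes, and function names are illustrative assumptions.

```python
# Illustrative sketch (NOT the exact RPT implementation): global support prototype
# via masked average pooling, plus a simple grid-based subdivision of the foreground
# into regional prototypes. All names and the subdivision scheme are assumptions.
import torch
import torch.nn.functional as F

def masked_average_pooling(feat, mask):
    # feat: (B, C, H, W) support features; mask: (B, H0, W0) binary foreground mask (float)
    mask = F.interpolate(mask.unsqueeze(1), size=feat.shape[-2:],
                         mode="bilinear", align_corners=False)
    proto = (feat * mask).sum(dim=(2, 3)) / (mask.sum(dim=(2, 3)) + 1e-5)
    return proto  # (B, C) global foreground prototype

def regional_prototypes(feat, mask, grid=4):
    # Split the spatial foreground into a grid and pool each cell separately,
    # yielding a set of regional prototypes (cells without foreground are skipped;
    # handling of an entirely empty foreground is omitted for brevity).
    B, C, H, W = feat.shape
    mask = F.interpolate(mask.unsqueeze(1), size=(H, W),
                         mode="bilinear", align_corners=False)
    protos = []
    for i in range(grid):
        for j in range(grid):
            hs = slice(i * H // grid, (i + 1) * H // grid)
            ws = slice(j * W // grid, (j + 1) * W // grid)
            m = mask[:, :, hs, ws]
            if m.sum() > 0:
                p = (feat[:, :, hs, ws] * m).sum(dim=(2, 3)) / (m.sum(dim=(2, 3)) + 1e-5)
                protos.append(p)
    return torch.stack(protos, dim=1)  # (B, N_regions, C)

# Example usage with random tensors:
# feat = torch.randn(1, 256, 32, 32); mask = (torch.rand(1, 256, 256) > 0.5).float()
# print(regional_prototypes(feat, mask).shape)   # e.g. torch.Size([1, 16, 256])
```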
Please install the following essential dependencies (a convenience install snippet follows the list):
dcm2nii
json5==0.8.5
jupyter==1.0.0
nibabel==2.5.1
numpy==1.22.0
opencv-python==4.5.5.62
Pillow>=8.1.1
sacred==0.8.2
scikit-image==0.18.3
SimpleITK==1.2.3
torch==1.10.2
torchvision==0.11.2
tqdm==4.62.3
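If convenient, the pinned Python packages can also be installed programmatically; the snippet below is just a thin wrapper around pip, not part of the repository.

```python
# Convenience only: install the pinned Python dependencies from within Python
# (equivalent to `pip install <package>==<version>` for each entry).
# dcm2nii is a standalone conversion tool, not a pip package, and must be installed separately.
import subprocess
import sys

PINNED = [
    "json5==0.8.5", "jupyter==1.0.0", "nibabel==2.5.1", "numpy==1.22.0",
    "opencv-python==4.5.5.62", "Pillow>=8.1.1", "sacred==0.8.2",
    "scikit-image==0.18.3", "SimpleITK==1.2.3", "torch==1.10.2",
    "torchvision==0.11.2", "tqdm==4.62.3",
]
subprocess.check_call([sys.executable, "-m", "pip", "install", *PINNED])
```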
Pre-processing is performed according to Ouyang et al., and we follow the procedure described in their GitHub repository.
The trained models can be downloaded via the following links:
- trained models for CHAOS under Setting 1
- trained models for CHAOS under Setting 2
- trained models for SABS under Setting 1
- trained models for SABS under Setting 2
- trained models for CMR
The pre-processed data and supervoxels can be downloaded via the following links:
- Pre-processed CHAOS-T2 data and supervoxels
- Pre-processed SABS data and supervoxels
- Pre-processed CMR data and supervoxels
- Compile `./supervoxels/felzenszwalb_3d_cy.pyx` with Cython (`python ./supervoxels/setup.py build_ext --inplace`) and run `./supervoxels/generate_supervoxels.py` (see the supervoxel sketch after this list).
- Download the pre-trained ResNet-101 weights (vanilla version or deeplabv3 version), put them into your checkpoints folder, and replace the absolute path in `./models/encoder.py` (see the checkpoint-loading sketch after this list).
- Run `./script/train.sh` to train the model.
- Run `./script/test.sh` to evaluate the trained model.
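Supervoxel sketch. For orientation, here is a rough sketch of what the supervoxel-generation step produces, using scikit-image's SLIC as a stand-in for the repository's compiled Felzenszwalb 3D module; file names and parameters are placeholders, and `./supervoxels/generate_supervoxels.py` should be used to reproduce the results.

```python
# Conceptual sketch only: generate a 3D over-segmentation (supervoxels) of a
# pre-processed volume; the labels later serve as pseudo-labels for training.
# SLIC is used here as a stand-in segmenter; the repo uses its own Felzenszwalb 3D code.
import nibabel as nib
import numpy as np
from skimage.segmentation import slic

img = nib.load("chaos_t2_case01.nii.gz")          # placeholder pre-processed volume
vol = img.get_fdata().astype(np.float32)

# Over-segment the whole 3D volume into roughly n_segments supervoxels.
supervoxels = slic(vol, n_segments=1000, compactness=0.1, multichannel=False)

nib.save(nib.Nifti1Image(supervoxels.astype(np.int16), img.affine),
         "superpix_chaos_t2_case01.nii.gz")       # placeholder output name
```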
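Checkpoint-loading sketch. The absolute-path replacement mentioned above typically amounts to pointing the encoder at the downloaded backbone weights. The snippet below is only a hedged illustration of that idea using torchvision's ResNet-101; the actual variable name and loading logic live in `./models/encoder.py` and may differ.

```python
# Illustration only: loading a local ResNet-101 checkpoint into a torchvision backbone.
# CHECKPOINT_PATH is a placeholder for the absolute path referenced in ./models/encoder.py.
import torch
from torchvision.models import resnet101

CHECKPOINT_PATH = "/path/to/checkpoints/resnet101-weights.pth"  # placeholder

backbone = resnet101(pretrained=False)                 # build the architecture only
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
# Some released checkpoints nest the weights under a key such as "state_dict".
if isinstance(state_dict, dict) and "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
backbone.load_state_dict(state_dict, strict=False)     # strict=False tolerates head mismatches
```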
Our code is based on the following works: SSL-ALPNet, ADNet, and QNet.
@inproceedings{zhu2023few,
title={Few-Shot Medical Image Segmentation via a Region-Enhanced Prototypical Transformer},
author={Zhu, Yazhou and Wang, Shidong and Xin, Tong and Zhang, Haofeng},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={271--280},
year={2023},
organization={Springer}
}