Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging

This repository contains the code to reproduce the results from the paper "Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging", which was accepted to the 15th International Workshop on Machine Learning in Medical Imaging (MLMI 2024).

We present a comprehensive performance comparison of dependence measures for preventing shortcut learning in medical imaging.
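As background, a dependence measure quantifies how much two sets of variables depend on each other. One way to use it against shortcut learning is to penalize the dependence between a model's learned representation and the shortcut attribute. The sketch below shows one widely used measure, the biased empirical HSIC (Hilbert-Schmidt Independence Criterion) with Gaussian kernels, in plain NumPy. It is an illustrative example under our own assumptions (fixed kernel bandwidths, small inputs) and is not taken from this repository's code. Which measures the benchmark includes, and how they are implemented, is defined in src/models.

```python
import numpy as np


def rbf_kernel(x: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian (RBF) kernel matrix for the rows of x, shape (n, d)."""
    sq_norms = np.sum(x**2, axis=1, keepdims=True)
    # Pairwise squared Euclidean distances; clip tiny negatives caused by rounding.
    d2 = np.maximum(sq_norms + sq_norms.T - 2.0 * x @ x.T, 0.0)
    return np.exp(-d2 / (2.0 * sigma**2))


def hsic(x: np.ndarray, y: np.ndarray, sigma_x: float = 1.0, sigma_y: float = 1.0) -> float:
    """Biased empirical HSIC estimate: trace(K H L H) / (n - 1)^2."""
    n = x.shape[0]
    k = rbf_kernel(x, sigma_x)
    l = rbf_kernel(y, sigma_y)
    h = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    return float(np.trace(k @ h @ l @ h) / (n - 1) ** 2)


# Usage: dependent variables give a clearly larger value than independent ones.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 1))
print("dependent:  ", hsic(x, x + 0.1 * rng.normal(size=(200, 1))))
print("independent:", hsic(x, rng.normal(size=(200, 1))))
```

Because the estimate is differentiable in the inputs, a term like this can in principle be added to a training loss. The kernel bandwidths (sigma_x, sigma_y) are hypothetical defaults and would normally be tuned or set by a heuristic.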


Installation

Set up a Python 3.10 environment. Then download the repository, activate the environment, and install all dependencies with:

cd dependence-measures-medical-imaging
pip install --editable . 

This installs the code in src as an editable package, together with all dependencies listed in requirements.txt.
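To confirm that the environment matches these requirements, a small check can be run from the repository root after installation. The helpers below (parse_requirement_names, missing_packages, check_environment) are hypothetical and not part of the repository; they assume requirements.txt uses plain `name[extras]<specifier>` lines and do not handle direct URL requirements such as `git+https://...`.

```python
import re
import sys
from importlib import metadata


def parse_requirement_names(lines):
    """Extract bare distribution names from requirements.txt-style lines."""
    names = []
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
        if not line or line.startswith("-"):
            continue  # skip blank lines and pip options such as -e or -r
        # The name ends at the first version specifier, extra, marker or space.
        names.append(re.split(r"[<>=!~\[;\s]", line, maxsplit=1)[0])
    return names


def missing_packages(names):
    """Return the names that are not installed in the current environment."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


def check_environment(requirements_path="requirements.txt"):
    """Warn about a non-3.10 interpreter and return missing requirements."""
    if sys.version_info[:2] != (3, 10):
        print(f"Warning: expected Python 3.10, found {sys.version.split()[0]}")
    with open(requirements_path) as f:
        return missing_packages(parse_requirement_names(f))
```

Calling `print(check_environment())` from the repository root prints an empty list when every listed requirement is installed.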

Organization of the repo

  • configs: Configuration files for all experiments.
  • scripts: Slurm scripts for model training and hyperparameter sweeps.
  • src: Main source code to run the experiments.
    • data: PyTorch datasets and scripts/instructions for downloading the data.
    • models: PyTorch Lightning module for training models with different methods to prevent shortcut learning.
    • eval: Model evaluation with kNN classifiers and embedding plots.
    • train.py: Main training script for k-fold cross-validation (with optional hyperparameter sweeps).
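For readers unfamiliar with k-fold cross-validation: the data is split into k folds, and each fold serves once as the validation set while the remaining folds are used for training. The generator below is a minimal NumPy sketch of that splitting step with simple random folds. It is an assumption for illustration only; train.py may stratify by label or shortcut attribute, keep a separate test set, or use a different k.

```python
import numpy as np


def kfold_indices(n_samples, k=5, seed=0):
    """Yield (train_idx, val_idx) index pairs for k-fold cross-validation."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), k)  # k nearly equal folds
    for i in range(k):
        val_idx = folds[i]
        # All other folds together form the training set for this split.
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, val_idx


# Usage: every sample appears in exactly one validation fold.
for fold, (train_idx, val_idx) in enumerate(kfold_indices(10, k=5)):
    print(fold, sorted(val_idx.tolist()))
```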

Usage

Download public datasets

First, download the two datasets, Morpho-MNIST and CheXpert. For Morpho-MNIST, we provide a download script:

python src/data/download_data/load_morpho_mnist.py -d path-to-dataset-directory -v True

CheXpert requires registration, so we provide instructions on how to register and download the dataset in load_chexpert.txt.

Training

To run k-fold cross-validation for one method, pass a config file to the training script. For example, for MINE on the Morpho-MNIST dataset, the command is:

python src/train.py -tc configs/morpho-mnist/mine.yaml

Note: Adjust dataset_path in the config file to point to your local dataset directory.
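If you need to update dataset_path in many configs at once, the hypothetical helper below sets every dataset_path entry in a loaded config, wherever it is nested. Where the key sits inside the repository's YAML files is not stated here, so the recursive search is a deliberate assumption; the function works on plain dicts and lists.

```python
def set_dataset_path(cfg, new_path, key="dataset_path"):
    """Recursively set every `key` entry in a nested config; return how many were changed."""
    count = 0
    if isinstance(cfg, dict):
        for k, v in cfg.items():
            if k == key:
                cfg[k] = new_path  # reassigning an existing key is safe while iterating
                count += 1
            else:
                count += set_dataset_path(v, new_path, key)
    elif isinstance(cfg, list):
        for item in cfg:
            count += set_dataset_path(item, new_path, key)
    return count


# Usage with an illustrative config structure (keys other than dataset_path are made up).
cfg = {"data": {"dataset_path": "old/path", "batch_size": 64}, "model": {"name": "mine"}}
print(set_dataset_path(cfg, "/data/morpho-mnist"), cfg["data"]["dataset_path"])
```

To apply it to a file, load the config with PyYAML's `yaml.safe_load`, call the helper, and write it back with `yaml.safe_dump`; note that this round trip drops YAML comments.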

To run training on a Slurm cluster, we provide a bash script:

sbatch scripts/train.sh configs/morpho-mnist/mine.yaml
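To train several methods in one go, the two commands above can be generated per config file. The helper below is a hypothetical sketch that only builds the argument lists (for a local run or a Slurm submission); it assumes that every YAML file in configs/morpho-mnist is a method config. Printing the commands first makes it easy to inspect them before launching anything.

```python
from pathlib import Path


def build_commands(config_paths, use_slurm=False):
    """Build one training command per config file (local run or Slurm job)."""
    commands = []
    for cfg in config_paths:
        if use_slurm:
            commands.append(["sbatch", "scripts/train.sh", str(cfg)])
        else:
            commands.append(["python", "src/train.py", "-tc", str(cfg)])
    return commands


# Usage: one command per YAML config in the Morpho-MNIST config directory.
configs = sorted(Path("configs/morpho-mnist").glob("*.yaml"))
for cmd in build_commands(configs, use_slurm=True):
    print(" ".join(cmd))  # launch with subprocess.run(cmd, check=True) once verified
```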

Run hyperparameter sweeps (wandb)

Initialize the sweep with

python src/utils/sweep_init.py -sc configs/example_sweep.yaml

This prints the sweep_id, which you pass to the following script to start multiple runs (10 in this example) on a Slurm cluster:

sh scripts/sweep.sh 10 configs/morpho-mnist/mine.yaml sweep_id
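A Weights & Biases sweep is defined by a search method, a metric to optimize, and the parameters to vary. The dict below sketches that structure together with a few basic sanity checks. The metric name and parameter names are placeholders rather than the keys used in configs/example_sweep.yaml, and what sweep_init.py does internally is not described here.

```python
# Hypothetical sweep definition in the Weights & Biases sweep format.
# The metric and parameter names are placeholders, not the repository's keys.
sweep_config = {
    "method": "bayes",  # W&B supports "grid", "random" and "bayes"
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
        "dependence_weight": {"values": [0.1, 1.0, 10.0]},
    },
}


def validate_sweep_config(cfg):
    """Basic sanity checks before handing the config to W&B."""
    if cfg.get("method") not in {"grid", "random", "bayes"}:
        raise ValueError(f"unknown sweep method: {cfg.get('method')!r}")
    if not cfg.get("parameters"):
        raise ValueError("sweep config needs at least one parameter")
    if cfg["method"] == "bayes" and "metric" not in cfg:
        raise ValueError("Bayesian search requires a metric to optimize")
    return True
```

With wandb installed and logged in, `wandb.sweep(sweep=sweep_config, project="my-project")` registers such a sweep and returns its ID; the project name here is hypothetical.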

Evaluation

To evaluate a trained model and compute the confusion matrix of kNN classifier accuracies, run:

python src/eval/knn_classifier.py -c model_config -ckpts model_checkpoints
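The idea behind this evaluation can be sketched with a plain NumPy kNN classifier: fit on training embeddings, predict labels for test embeddings, and report accuracy for each combination of embedding and label (for example the target label and the shortcut attribute). High accuracy for the shortcut attribute would suggest that the representation still encodes it. The matrix layout (rows: embeddings, columns: labels), the value of k, and all function names are assumptions; knn_classifier.py may use a library implementation and a different setup. Labels must be non-negative integers.

```python
import numpy as np


def knn_predict(train_x, train_y, test_x, k=5):
    """Predict integer labels for test_x by majority vote of the k nearest training points."""
    # Squared Euclidean distances between every test and every training embedding.
    d2 = (
        np.sum(test_x**2, axis=1)[:, None]
        + np.sum(train_x**2, axis=1)[None, :]
        - 2.0 * test_x @ train_x.T
    )
    nearest = np.argsort(d2, axis=1)[:, :k]  # indices of the k nearest neighbors
    votes = train_y[nearest]  # shape (n_test, k)
    # Majority vote; argmax breaks ties toward the smallest label.
    return np.array([np.bincount(row).argmax() for row in votes])


def knn_accuracy(train_x, train_y, test_x, test_y, k=5):
    """Fraction of test points whose kNN prediction matches the true label."""
    return float(np.mean(knn_predict(train_x, train_y, test_x, k) == test_y))


def knn_accuracy_matrix(embeddings, labels, k=5):
    """Accuracy of predicting each label (columns) from each embedding (rows).

    embeddings: {name: (train_x, test_x)}, labels: {name: (train_y, test_y)}.
    """
    matrix = np.zeros((len(embeddings), len(labels)))
    for i, (train_x, test_x) in enumerate(embeddings.values()):
        for j, (train_y, test_y) in enumerate(labels.values()):
            matrix[i, j] = knn_accuracy(train_x, train_y, test_x, test_y, k)
    return matrix
```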

To generate the embedding plots from the paper, run:

python src/eval/embeddings.py -cfgs list_of_model_configs -ckpts list_of_model_checkpoints
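Embeddings usually have more than two dimensions, so plotting them requires a 2D projection. This README does not say which projection embeddings.py uses; as a simple linear stand-in, the sketch below projects embeddings onto their first two principal components with NumPy's SVD.

```python
import numpy as np


def pca_2d(z):
    """Project embeddings z of shape (n, d) onto their first two principal components."""
    zc = z - z.mean(axis=0)  # center each embedding dimension
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(zc, full_matrices=False)
    return zc @ vt[:2].T


# Usage with random stand-in embeddings of dimension 8.
rng = np.random.default_rng(0)
points = pca_2d(rng.normal(size=(100, 8)))
print(points.shape)
```

The projected points can then be drawn with matplotlib, e.g. `plt.scatter(points[:, 0], points[:, 1], c=labels, s=5)`, coloring by the target label or the shortcut attribute.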

Cite

If you find our code or paper useful, please consider citing this work:

@InProceedings{mueller2025benchmarking,
    title = {Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging},
    author = {M\"uller, Sarah and Fay, Louisa and Koch, Lisa M. and Gatidis, Sergios and K\"ustner, Thomas and Berens, Philipp},
    booktitle = {Machine Learning in Medical Imaging},
    year = {2025},
    publisher = {Springer Nature Switzerland},
    pages = {53--62},
    isbn = {978-3-031-73290-4},
}