This repository contains the code to reproduce the results from the paper "Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging", which was accepted to the 15th International Workshop on Machine Learning in Medical Imaging (MLMI 2024).
We present a comprehensive performance comparison of dependence measures for preventing shortcut learning in medical imaging.
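One of the benchmarked dependence measures is MINE (Mutual Information Neural Estimation), which also appears in the training examples in this README. MINE estimates mutual information through the Donsker-Varadhan lower bound $I(X;Z) \ge \mathbb{E}_{P_{XZ}}[T] - \log \mathbb{E}_{P_X \otimes P_Z}[e^{T}]$, maximized over a neural network $T$. The sketch below is not the repository's implementation. It only evaluates that bound for a fixed, hand-picked statistic so that it runs without any deep-learning library. The function names (`dv_lower_bound`, `mine_bound`, `closeness`) are hypothetical.

```python
import math
import random
from typing import Callable, Sequence


def dv_lower_bound(t_joint: Sequence[float], t_marginal: Sequence[float]) -> float:
    """Donsker-Varadhan bound: mean of T on joint pairs minus log-mean-exp of T on shuffled pairs."""
    mean_joint = sum(t_joint) / len(t_joint)
    # Numerically stable log-mean-exp: subtract the maximum before exponentiating.
    m = max(t_marginal)
    log_mean_exp = m + math.log(sum(math.exp(t - m) for t in t_marginal) / len(t_marginal))
    return mean_joint - log_mean_exp


def mine_bound(xs: Sequence[float], zs: Sequence[float],
               statistic: Callable[[float, float], float], seed: int = 0) -> float:
    """Evaluate the DV bound for a *fixed* statistic T(x, z).

    MINE trains T as a neural network to maximise this bound; here T is fixed,
    which still gives a valid (but looser) lower bound on I(X; Z).
    Shuffling zs breaks the pairing, approximating samples from P_X x P_Z.
    """
    zs_shuffled = list(zs)
    random.Random(seed).shuffle(zs_shuffled)
    t_joint = [statistic(x, z) for x, z in zip(xs, zs)]
    t_marginal = [statistic(x, z) for x, z in zip(xs, zs_shuffled)]
    return dv_lower_bound(t_joint, t_marginal)


def closeness(x: float, z: float) -> float:
    """Hypothetical hand-picked statistic: large when x and z are close."""
    return -(x - z) ** 2


if __name__ == "__main__":
    rng = random.Random(1)
    xs = [rng.gauss(0.0, 1.0) for _ in range(2000)]
    zs = [rng.gauss(0.0, 1.0) for _ in range(2000)]
    print(mine_bound(xs, xs, closeness))  # z = x (dependent): clearly positive
    print(mine_bound(xs, zs, closeness))  # independent: not positive
```

When MINE is used as a training penalty, the statistics network is optimized jointly with the encoder, and the resulting estimate is minimized to reduce dependence. The fixed statistic used here only illustrates the estimator itself.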
Set up a Python environment with Python 3.10. Then download the repository, activate the environment, and install all other dependencies with:

```bash
cd dependence-measures-medical-imaging
pip install --editable .
```
This installs the code in `src` as an editable package, along with all dependencies listed in `requirements.txt`.
- configs: Configuration files for all experiments.
- scripts: Slurm scripts for model training and hyperparameter sweeps.
- src: Main source code to run the experiments.
  - train.py: Main training script for k-fold cross-validation (with optional hyperparameter sweeps).
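To clarify what k-fold cross-validation in `train.py` involves, here is a minimal sketch of a plain, unstratified k-fold split over sample indices. Each sample ends up in exactly one validation fold. This is an illustration only: the splitting logic in `src/train.py` may differ (for instance, it may stratify), and `k_fold_indices` is a hypothetical name.

```python
import random
from typing import List, Tuple


def k_fold_indices(n_samples: int, k: int, seed: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, val_indices) for each of the k folds."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must satisfy 2 <= k <= n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # shuffle once, then cut into k contiguous folds
    # The first n_samples % k folds receive one extra sample each.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        folds.append((train, val))
        start += size
    return folds


if __name__ == "__main__":
    for fold, (train, val) in enumerate(k_fold_indices(10, 3, seed=42)):
        print(fold, len(train), sorted(val))
```

A model is trained once per fold on `train` and evaluated on `val`, and the metrics are then aggregated across the k runs.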
First, download the two datasets, Morpho-MNIST and CheXpert. For Morpho-MNIST, we provide a download script:

```bash
python src/data/download_data/load_morpho_mnist.py -d path-to-dataset-directory -v True
```

CheXpert requires registration; we describe how to register and download the dataset in load_chexpert.txt.
To run k-fold cross-validation for one method, pass its config file to the training script. For example, for MINE on the Morpho-MNIST dataset, run:

```bash
python src/train.py -tc configs/morpho-mnist/mine.yaml
```
Note: The `dataset_path` entry in the config file needs to be adjusted to the location of your dataset.
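Editing the config by hand is the simplest option. If you need to update several configs at once, the following stdlib-only sketch rewrites every `dataset_path:` line in a YAML file. It assumes the key's value sits on the same line as the key, because it edits the file line by line instead of parsing the YAML. The helper name `set_dataset_path` and the example paths are hypothetical.

```python
import re
from pathlib import Path

# Matches a `dataset_path:` key at any indentation level (commented-out lines are not matched).
DATASET_PATH_KEY = re.compile(r"^(\s*dataset_path\s*:)")


def set_dataset_path(config_file: str, dataset_path: str) -> int:
    """Point every `dataset_path:` entry in a YAML config at `dataset_path`.

    Works line by line without a YAML parser; returns the number of changed lines.
    """
    path = Path(config_file)
    lines = path.read_text(encoding="utf-8").splitlines(keepends=True)
    # Single-quoted YAML scalar: backslashes stay literal, single quotes are doubled.
    quoted = "'" + dataset_path.replace("'", "''") + "'"
    changed = 0
    for i, line in enumerate(lines):
        body = line.rstrip("\r\n")
        match = DATASET_PATH_KEY.match(body)
        if match:
            # Keep the original key and indentation and the original line ending.
            lines[i] = f"{match.group(1)} {quoted}{line[len(body):]}"
            changed += 1
    path.write_text("".join(lines), encoding="utf-8")
    return changed
```

For example, `set_dataset_path("configs/morpho-mnist/mine.yaml", "/data/morpho-mnist")` would point that config at a (hypothetical) local directory.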
To run training on a Slurm cluster, we provide a bash script:

```bash
sbatch scripts/train.sh configs/morpho-mnist/mine.yaml
```
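To submit one job per method, you can loop over a config directory and call the same `sbatch scripts/train.sh <config>` command for each file. The sketch below assumes that every `*.yaml` file in the directory is a complete training config and that you run it from the repository root. It only prints the commands unless `dry_run=False`. The function names are hypothetical.

```python
import shlex
import subprocess
from pathlib import Path
from typing import List


def build_sbatch_commands(config_dir: str, pattern: str = "*.yaml") -> List[List[str]]:
    """One `sbatch scripts/train.sh <config>` command per matching config file."""
    configs = sorted(Path(config_dir).glob(pattern))
    return [["sbatch", "scripts/train.sh", str(cfg)] for cfg in configs]


def submit_all(config_dir: str, dry_run: bool = True) -> List[List[str]]:
    """Print (and, unless dry_run, submit) one Slurm job per config file."""
    commands = build_sbatch_commands(config_dir)
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # Requires Slurm's sbatch on PATH, e.g. on a cluster login node.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    submit_all("configs/morpho-mnist")  # dry run: only prints the commands
```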
To run a hyperparameter sweep, first initialize it with:

```bash
python src/utils/sweep_init.py -sc configs/example_sweep.yaml
```
This prints a `sweep_id`, which you pass to the following script to start multiple runs (10 in this case) on a Slurm cluster:

```bash
sh scripts/sweep.sh 10 configs/morpho-mnist/mine.yaml sweep_id
```
To evaluate a trained model with the confusion matrix of kNN classifier accuracies, run:

```bash
python src/eval/knn_classifier.py -c model_config -ckpts model_checkpoints
```
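The following pure-Python sketch illustrates what a kNN accuracy matrix measures. It assumes (this is not taken from `src/eval/knn_classifier.py`) that each row corresponds to an embedding (sub)space and each column to an attribute to predict, such as the target label and the shortcut attribute. Each entry is the test accuracy of a Euclidean kNN classifier with majority voting. The subspace names `z_y` and `z_a` and all function names are hypothetical.

```python
import math
from collections import Counter
from typing import Dict, Hashable, List, Sequence

Point = Sequence[float]


def knn_predict(train_x: List[Point], train_y: List[Hashable], query: Point, k: int) -> Hashable:
    """Majority vote among the k training points closest to `query` (Euclidean distance)."""
    nearest = sorted(zip(train_x, train_y), key=lambda pair: math.dist(query, pair[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


def knn_accuracy(train_x, train_y, test_x, test_y, k: int = 5) -> float:
    """Fraction of test points whose kNN prediction matches the true label."""
    hits = sum(knn_predict(train_x, train_y, q, k) == y for q, y in zip(test_x, test_y))
    return hits / len(test_y)


def knn_accuracy_matrix(train_emb: Dict[str, List[Point]], test_emb: Dict[str, List[Point]],
                        train_labels: Dict[str, List[Hashable]], test_labels: Dict[str, List[Hashable]],
                        k: int = 5) -> Dict[str, Dict[str, float]]:
    """matrix[subspace][attribute]: accuracy of predicting `attribute` from embedding `subspace`."""
    return {
        space: {
            attr: knn_accuracy(train_emb[space], train_labels[attr],
                               test_emb[space], test_labels[attr], k)
            for attr in train_labels
        }
        for space in train_emb
    }
```

Under this reading, a successful dependence measure would leave the shortcut attribute hard to predict from the subspace intended for the target label. That cell of the matrix would then be close to chance-level accuracy.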
To generate the embedding plots from the paper, run:

```bash
python src/eval/embeddings.py -cfgs list_of_model_configs -ckpts list_of_model_checkpoints
```
If you find our code or paper useful, please consider citing this work:
```bibtex
@InProceedings{mueller2025benchmarking,
  title = {Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging},
  author = {M\"uller, Sarah and Fay, Louisa and Koch, Lisa M. and Gatidis, Sergios and K\"ustner, Thomas and Berens, Philipp},
  booktitle = {Machine Learning in Medical Imaging},
  year = {2025},
  publisher = {Springer Nature Switzerland},
  pages = {53--62},
  isbn = {978-3-031-73290-4},
}
```