PyTorch implementation of "Contrastive Neural Processes for Self-Supervised Learning", accepted as a Long Oral at ACML 2021.
This folder contains the code for Contrastive Neural Processes (ContrNP) and the baselines. Code taken from other projects has been modified according to the needs of this project; the original sources are cited below as [1] to [4].
Folders:
- Baselines : Includes all baselines, hyperparameters and evaluation metrics used for base experiments
- ContrNP : Includes code for ContrNP method and resources
- Results : Location where weights and results are saved
- Data : Location where datasets are stored. Use data_name_load.py to download and extract data. (Note: some data-extraction commands are Ubuntu-specific.)
Baselines includes implementations for Tloss [1], CPC [2], TNC [2] and SimCLR [3].
The npf and utils folders contain the implementation of Neural Processes [4].
- Main implementation of the contrastive convolutional CNP: Contrastive-ConvCNP-SSL.ipynb
- Implementation of the self-supervised convolutional CNP: ConvCNP-SSL.ipynb
- Implementation of the self-supervised CNP: CNP-SSL.ipynb
@misc{kallidromitis2021contrastive,
  title={Contrastive Neural Processes for Self-Supervised Learning},
  author={Konstantinos Kallidromitis and Denis Gudovskiy and Kazuki Kozuka and Ohama Iku and Luca Rigazio},
  journal={arXiv preprint arXiv:2110.13623},
  year={2021}
}
Method | AFDB Accuracy | AFDB AUPRC | AFDB Sil↑ | AFDB DBI↓ | IMS Bearing Accuracy | IMS Bearing AUPRC | IMS Bearing Sil↑ | IMS Bearing DBI↓ | Urban8K Accuracy | Urban8K AUPRC | Urban8K Sil↑ | Urban8K DBI↓ |
---|---|---|---|---|---|---|---|---|---|---|---|---|
CPC | 71.6 | 62.6 | 0.22 | 1.74 | 72.4 | 84.4 | 0.12 | 2.20 | 83.3 | 94.5 | 0.24 | 1.64 |
Tloss | 74.8 | 59.8 | 0.14 | 2.04 | 73.2 | 87.6 | 0.17 | 1.79 | 81.5 | 93.8 | 0.26 | 1.30 |
TNC | 74.5 | 56.3 | 0.24 | 1.44 | 70.3 | 86.3 | 0.31 | 0.94 | 80.7 | 93.9 | 0.36 | 0.72 |
SimCLR | 82.3 | 71.5 | 0.34 | 1.49 | 41.5 | 70.7 | 0.24 | 1.47 | 82.8 | 94.1 | 0.35 | 1.13 |
ContrNP (ours) | 94.2 | 89.1 | 0.36 | 1.35 | 73.6 | 89.3 | 0.38 | 0.91 | 84.2 | 95.4 | 0.42 | 0.89 |
Fully supervised | 98.4 | 81.6 | 0.43 | 0.83 | 86.3 | 94.8 | 0.47 | 0.77 | 99.9 | 99.9 | 0.49 | 0.80 |
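
The Sil↑ and DBI↓ columns are taken here to mean the silhouette score (higher is better) and the Davies-Bouldin index (lower is better), both computed on learned representations. The following is a minimal NumPy sketch of those two textbook definitions, not the evaluation code from the Baselines folder. The function names, the use of Euclidean distance, and scoring against ground-truth labels rather than a clustering of the embeddings are all assumptions made for illustration.

```python
import numpy as np


def silhouette(x, labels):
    """Mean silhouette coefficient in [-1, 1]; needs at least two classes."""
    # Pairwise Euclidean distances between all samples.
    dists = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    classes = np.unique(labels)
    scores = np.zeros(len(x))
    for i in range(len(x)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton clusters score 0 by convention
        # Mean distance to the other members of the sample's own cluster.
        a = dists[i, same].sum() / (n_same - 1)
        # Mean distance to the nearest other cluster.
        b = min(dists[i, labels == c].mean() for c in classes if c != labels[i])
        scores[i] = (b - a) / max(a, b)
    return scores.mean()


def davies_bouldin(x, labels):
    """Davies-Bouldin index (>= 0); compact, well-separated clusters score low."""
    classes = np.unique(labels)
    centroids = np.stack([x[labels == c].mean(axis=0) for c in classes])
    # Average distance of each cluster's points to its centroid.
    scatter = np.array([
        np.linalg.norm(x[labels == c] - centroids[k], axis=1).mean()
        for k, c in enumerate(classes)
    ])
    gaps = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=-1)
    np.fill_diagonal(gaps, np.inf)  # a cluster is never compared with itself
    ratios = (scatter[:, None] + scatter[None, :]) / gaps
    return ratios.max(axis=1).mean()


# Synthetic stand-in for encoder outputs: two Gaussian blobs in 8 dimensions.
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 50)
embeddings = rng.normal(size=(100, 8)) + 4.0 * labels[:, None]
print("silhouette:", silhouette(embeddings, labels))
print("davies-bouldin:", davies_bouldin(embeddings, labels))
```

Moving the two blobs closer together (a smaller offset than 4.0) should lower the silhouette score and raise the Davies-Bouldin index, which is the direction the arrows in the table indicate.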
Requirements:
python>=3.6.9
skorch==0.8
pytorch>=1.3.1
scikit-image
wfdb
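
A quick way to compare an existing environment against the list above is sketched below. It is a hypothetical helper, not a script shipped with this repository, and it needs Python 3.8 or newer for importlib.metadata. It assumes that "pytorch" refers to the distribution installed from PyPI as "torch", and it compares only the leading numeric part of each version string.

```python
import re
import sys
from importlib import metadata  # available from Python 3.8

# Copied from the requirements list above.
REQUIREMENTS = ["python>=3.6.9", "skorch==0.8", "pytorch>=1.3.1", "scikit-image", "wfdb"]
# Assumption: the name "pytorch" maps to the pip distribution "torch".
PIP_NAMES = {"pytorch": "torch"}


def version_tuple(text):
    """Turn the leading numeric part of '1.13.1+cu117' into (1, 13, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    return tuple(int(p) for p in match.group().split(".")) if match else ()


def parse_requirement(line):
    """Split 'name>=1.2' into ('name', '>=', '1.2'); unpinned names give None."""
    match = re.fullmatch(r"\s*([\w.\-]+)\s*(?:(==|>=)\s*([\w.]+))?\s*", line)
    if match is None:
        raise ValueError("unsupported requirement: " + line)
    return match.groups()


def satisfies(installed, op, required):
    """Compare versions as zero-padded integer tuples, so 0.8.0 == 0.8."""
    if op is None:
        return True
    a, b = version_tuple(installed), version_tuple(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a == b if op == "==" else a >= b


def check(requirements=REQUIREMENTS):
    """Map each requirement line to 'ok', 'missing', or the version found."""
    report = {}
    for line in requirements:
        name, op, required = parse_requirement(line)
        if name == "python":
            installed = ".".join(map(str, sys.version_info[:3]))
        else:
            try:
                installed = metadata.version(PIP_NAMES.get(name, name))
            except metadata.PackageNotFoundError:
                report[line] = "missing"
                continue
        report[line] = "ok" if satisfies(installed, op, required) else "found " + installed
    return report


if __name__ == "__main__":
    for line, status in check().items():
        print(line + ": " + status)
```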