
Contrastive Neural Processes

PyTorch implementation of "Contrastive Neural Processes for Self-Supervised Learning", accepted as a long oral presentation at ACML 2021.

Implementation Details

This folder contains the code for Contrastive Neural Processes (ContrNP) and the baselines. Code taken from other sources has been modified according to the needs of the project, and the original sources are cited below.

Folders:

  • Baselines: includes all baselines, hyperparameters and evaluation metrics used for the base experiments.
  • ContrNP: includes the code and resources for the ContrNP method.
  • Results: location where weights and results are saved.
  • Data: location of the datasets. Use data_name_load.py to download and extract the data (note: some data-extraction commands are Ubuntu-specific).

The Baselines folder includes implementations of Tloss [1], CPC [2], TNC [2] and SimCLR [3].
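SimCLR [3] trains an encoder with the normalised temperature-scaled cross-entropy (NT-Xent) loss over pairs of augmented views. The following is a generic PyTorch sketch of that loss, written for illustration only: the function name nt_xent_loss and the default temperature of 0.5 are assumptions, and the code is not taken from the Baselines folder.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss for N positive pairs (z1[i], z2[i]).

    z1, z2: (N, D) projections of two augmented views of the same N samples.
    The default temperature is an illustrative assumption.
    """
    n = z1.shape[0]
    # L2-normalise so that dot products become cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    sim = z @ z.t() / temperature  # (2N, 2N) temperature-scaled similarities
    # Remove each sample's similarity with itself from the softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is its other view: i + N for the first half, i - N for the second.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Row-wise cross-entropy = mean of -log softmax at the positive index.
    return F.cross_entropy(sim, targets)
```

Each row of the 2N x 2N similarity matrix is treated as a classification over the other 2N - 1 samples, with the sample's second view as the correct class, so well-aligned views give a lower loss than unrelated ones.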

The npf and utils folders contain the implementation of Neural Processes [4].
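A conditional neural process (CNP) encodes each context pair (x_i, y_i), averages the encodings into a single representation r, and decodes r together with a target input into a Gaussian predictive distribution. The minimal PyTorch sketch below illustrates that structure; the class MinimalCNP, the function cnp_loss, the layer sizes and the lower bound on the standard deviation are assumptions for illustration and do not reflect the API of the npf code in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalCNP(nn.Module):
    """Hypothetical minimal CNP: encode context, mean-aggregate, decode at targets."""

    def __init__(self, x_dim: int = 1, y_dim: int = 1, r_dim: int = 64):
        super().__init__()
        # Encoder h: (x_i, y_i) -> r_i
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + y_dim, r_dim), nn.ReLU(), nn.Linear(r_dim, r_dim)
        )
        # Decoder g: (r, x_target) -> (mean, raw std)
        self.decoder = nn.Sequential(
            nn.Linear(r_dim + x_dim, r_dim), nn.ReLU(), nn.Linear(r_dim, 2 * y_dim)
        )

    def forward(self, x_context, y_context, x_target):
        # Shapes: x_context (B, Nc, x_dim), y_context (B, Nc, y_dim), x_target (B, Nt, x_dim)
        r_i = self.encoder(torch.cat([x_context, y_context], dim=-1))  # (B, Nc, r_dim)
        # Permutation-invariant aggregation: mean over the context points.
        r = r_i.mean(dim=1, keepdim=True).expand(-1, x_target.shape[1], -1)  # (B, Nt, r_dim)
        mu, raw_sigma = self.decoder(torch.cat([r, x_target], dim=-1)).chunk(2, dim=-1)
        # Keep the predictive std >= 0.1 for numerical stability (illustrative choice).
        sigma = 0.1 + 0.9 * F.softplus(raw_sigma)
        return torch.distributions.Normal(mu, sigma)


def cnp_loss(model, x_c, y_c, x_t, y_t):
    """Negative Gaussian log-likelihood of the target outputs (training objective)."""
    return -model(x_c, y_c, x_t).log_prob(y_t).mean()
```

Because the context encodings are averaged, the prediction does not depend on the order of the context points; training minimises the value returned by cnp_loss.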

  • Main implementation of the contrastive convolutional conditional neural process (ConvCNP): Contrastive-ConvCNP-SSL.ipynb
  • Implementation of the self-supervised ConvCNP: ConvCNP-SSL.ipynb
  • Implementation of the self-supervised conditional neural process (CNP): CNP-SSL.ipynb
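The convolutional variants build on ConvCNP, whose first layer (SetConv) turns an irregularly sampled context set into a function on a uniform grid: an RBF-weighted density channel plus a density-normalised signal channel, which a standard CNN can then process. The sketch below shows only this discretisation step and is not the notebooks' code; the function name set_conv_to_grid, the fixed length scale (learnable in ConvCNP) and the grid are hypothetical choices.

```python
import torch


def set_conv_to_grid(x_context, y_context, grid, length_scale=0.1, eps=1e-8):
    """Hypothetical SetConv sketch: map an off-grid context set onto a uniform grid.

    x_context: (B, N, 1) input locations, y_context: (B, N, C) observed values,
    grid: (M,) grid locations. Returns (B, C + 1, M): one density channel
    followed by C density-normalised signal channels, in Conv1d layout.
    """
    # Squared distance between every grid point and every context point: (B, M, N)
    d2 = (grid.view(1, -1, 1) - x_context.transpose(1, 2)) ** 2
    weights = torch.exp(-0.5 * d2 / length_scale ** 2)  # RBF kernel weights
    density = weights.sum(dim=-1, keepdim=True)  # (B, M, 1): how much data is near each grid point
    signal = weights @ y_context  # (B, M, C): kernel-weighted sum of values
    signal = signal / (density + eps)  # normalise by density
    return torch.cat([density, signal], dim=-1).transpose(1, 2)
```

The output has shape (B, C + 1, M), so it can be fed directly to a 1D CNN such as torch.nn.Conv1d(C + 1, hidden_channels, kernel_size, padding=kernel_size // 2).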

Citing this work

[arXiv] [PMLR] [ACML2021]

@misc{kallidromitis2021contrastive,
      title={Contrastive Neural Processes for Self-Supervised Learning}, 
      author={Konstantinos Kallidromitis and Denis Gudovskiy and Kazuki Kozuka and Ohama Iku and Luca Rigazio},
      journal={arXiv preprint arXiv:2110.13623},
      year={2021}
}

Reproduced Results

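| Method | AFDB Accuracy | AFDB AUPRC | AFDB Sil↑ | AFDB DBI↓ | IMS Bearing Accuracy | IMS Bearing AUPRC | IMS Bearing Sil↑ | IMS Bearing DBI↓ | Urban8K Accuracy | Urban8K AUPRC | Urban8K Sil↑ | Urban8K DBI↓ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| CPC | 71.6 | 62.6 | 0.22 | 1.74 | 72.4 | 84.4 | 0.12 | 2.20 | 83.3 | 94.5 | 0.24 | 1.64 |
| Tloss | 74.8 | 59.8 | 0.14 | 2.04 | 73.2 | 87.6 | 0.17 | 1.79 | 81.5 | 93.8 | 0.26 | 1.30 |
| TNC | 74.5 | 56.3 | 0.24 | 1.44 | 70.3 | 86.3 | 0.31 | 0.94 | 80.7 | 93.9 | 0.36 | 0.72 |
| SimCLR | 82.3 | 71.5 | 0.34 | 1.49 | 41.5 | 70.7 | 0.24 | 1.47 | 82.8 | 94.1 | 0.35 | 1.13 |
| ContrNP (ours) | 94.2 | 89.1 | 0.36 | 1.35 | 73.6 | 89.3 | 0.38 | 0.91 | 84.2 | 95.4 | 0.42 | 0.89 |
| Fully supervised | 98.4 | 81.6 | 0.43 | 0.83 | 86.3 | 94.8 | 0.47 | 0.77 | 99.9 | 99.9 | 0.49 | 0.80 |

Accuracy and AUPRC are in %. Sil = silhouette score (↑, higher is better); DBI = Davies-Bouldin index (↓, lower is better).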
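The table combines classification scores (Accuracy, AUPRC) with clusterability scores (silhouette and Davies-Bouldin index) of the learned representations. The sketch below shows one way such numbers could be computed with scikit-learn, which is installed as a dependency of skorch. The logistic-regression probe, the use of k-means with one cluster per class, and the function name evaluate_embeddings are assumptions; this is not the repository's evaluation code.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    davies_bouldin_score,
    silhouette_score,
)


def evaluate_embeddings(z_train, y_train, z_test, y_test, n_classes, seed=0):
    """Hypothetical evaluation: linear-probe accuracy/AUPRC plus clusterability.

    z_*: (n, d) NumPy arrays of frozen representations; y_*: (n,) integer labels.
    Accuracy and AUPRC are returned as fractions (the table reports them in %).
    """
    # Linear probe trained on the frozen representations (assumed protocol).
    clf = LogisticRegression(max_iter=1000).fit(z_train, y_train)
    probs = clf.predict_proba(z_test)  # columns follow clf.classes_
    acc = accuracy_score(y_test, clf.predict(z_test))
    # One-vs-rest indicator matrix aligned with the probability columns.
    y_onehot = (y_test[:, None] == clf.classes_[None, :]).astype(int)
    auprc = average_precision_score(y_onehot, probs, average="macro")

    # Clusterability: k-means with one cluster per class (assumed protocol).
    clusters = KMeans(n_clusters=n_classes, n_init=10, random_state=seed).fit_predict(z_test)
    sil = silhouette_score(z_test, clusters)  # higher is better
    dbi = davies_bouldin_score(z_test, clusters)  # lower is better
    return {"accuracy": acc, "auprc": auprc, "silhouette": sil, "dbi": dbi}
```

Well-separated class clusters push the silhouette score towards 1 and the Davies-Bouldin index towards 0, which is why the two columns carry opposite arrows.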

Requirements [txt]

python>=3.6.9
skorch==0.8
pytorch>=1.3.1
scikit-image
wfdb