Unsupervised Data Augmentation experiments in PyTorch

vfdev-5/UDA-pytorch

Experiments with "Unsupervised Data Augmentation" method on Cifar10 dataset.

Based on the paper "Unsupervised Data Augmentation".

Unsupervised Data Augmentation in a nutshell

[Figure: UDA]

Requirements

All experiments are run with MLflow. Please install the latest version of the library:

pip install --upgrade mlflow

Experiments

Start MLFlow UI server

Create an output folder (e.g. $PWD/output) and set up the MLflow server:

export OUTPUT_PATH=/path/to/output

and

mlflow server --backend-store-uri $OUTPUT_PATH/mlruns --default-artifact-root $OUTPUT_PATH/mlruns -p 5566 -h 0.0.0.0

The MLflow dashboard is then available in the browser at http://localhost:5566 (the server listens on all interfaces, 0.0.0.0, on port 5566).

CIFAR10 dataset

Create the "CIFAR10" experiment (this is only needed once):

export MLFLOW_TRACKING_URI=$OUTPUT_PATH/mlruns
mlflow experiments create -n CIFAR10

Implementation details:

  • Models: Fast ResNet and Wide ResNet

  • Consistency loss: KL

  • Data augs: AutoAugment + Cutout

  • Cosine LR decay

  • Training Signal Annealing

  • Updated UDA version: see main_uda2.py

    • batches of the 4k training samples are also passed through the unsupervised learning part
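
The consistency loss and Training Signal Annealing (TSA) items above can be illustrated with a minimal sketch in plain Python. It is not the code from this repository: the function names are hypothetical, the TSA schedules follow the formulas described in the UDA paper, and how the `TSA_proba_min` option used in the runs maps onto the threshold is not shown here.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_consistency(logits_orig, logits_aug):
    """KL(p_orig || p_aug) for one unlabelled sample.

    In training, p_orig is treated as a fixed target (no gradient flows
    through it); only the prediction on the augmented input is optimised.
    """
    log_p = log_softmax(logits_orig)
    log_q = log_softmax(logits_aug)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def tsa_threshold(step, total_steps, num_classes, schedule="linear"):
    """Threshold that grows roughly from 1/K towards 1 during training."""
    progress = step / total_steps
    if schedule == "linear":
        alpha = progress
    elif schedule == "log":
        alpha = 1.0 - math.exp(-progress * 5.0)
    elif schedule == "exp":
        alpha = math.exp((progress - 1.0) * 5.0)
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    return alpha * (1.0 - 1.0 / num_classes) + 1.0 / num_classes


def tsa_keep(prob_true_class, threshold):
    """A labelled sample contributes to the supervised loss only while the
    predicted probability of its true class is below the threshold."""
    return prob_true_class < threshold


if __name__ == "__main__":
    print(kl_consistency([2.0, 0.5, -1.0], [1.5, 0.7, -0.5]))
    print(tsa_threshold(step=50, total_steps=100, num_classes=10))
```

The idea behind TSA is that labelled samples the model already predicts confidently are masked out early on, so the small labelled set is not overfitted while the unlabelled consistency term is still being learned.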

Fast ResNet

Start a single run

export MLFLOW_TRACKING_URI=$OUTPUT_PATH/mlruns

mlflow run experiments/ --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=fastresnet -P params="data_path=../input/cifar10;num_epochs=100;learning_rate=0.08;batch_size=512;TSA_proba_min=0.5;unlabelled_batch_size=1024"
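
The `params` value in the command above is a single string of `key=value` pairs separated by `;`. When launching several runs from a script, it can help to build that string from a dictionary. The helpers below are a hypothetical convenience sketch, not part of this repository, and they make no claim about how the repository itself parses the string.

```python
def build_params(options):
    """Join a dict into the 'key=value;key=value' form passed via -P params=..."""
    return ";".join(f"{key}={value}" for key, value in options.items())


def parse_params(params):
    """Split a 'key=value;key=value' string back into a dict of strings."""
    result = {}
    for item in params.split(";"):
        if not item:
            continue  # tolerate a trailing ';'
        key, _, value = item.partition("=")
        result[key.strip()] = value.strip()
    return result


if __name__ == "__main__":
    params = build_params({
        "data_path": "../input/cifar10",
        "num_epochs": 100,
        "learning_rate": 0.08,
        "batch_size": 512,
        "TSA_proba_min": 0.5,
        "unlabelled_batch_size": 1024,
    })
    # Print the full command so it can be copied into a shell.
    print(
        "mlflow run experiments/ --experiment-name=CIFAR10 "
        f'-P dataset=CIFAR10 -P network=fastresnet -P params="{params}"'
    )
```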

Wide ResNet

Start a single run

export MLFLOW_TRACKING_URI=$OUTPUT_PATH/mlruns

mlflow run experiments/ --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=wideresnet -P params="data_path=../input/cifar10;num_epochs=100;learning_rate=0.1;batch_size=512;TSA_proba_min=0.1;unlabelled_batch_size=1024"

Paper's configuration

export MLFLOW_TRACKING_URI=$OUTPUT_PATH/mlruns

mlflow run experiments/ --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=wideresnet -P params="data_path=../input/cifar10;num_epochs=6250;learning_rate=0.03;batch_size=64;TSA_proba_min=0.1;unlabelled_batch_size=320;num_warmup_steps=20000"

Unfortunately, I could not reproduce the paper's result of 5.3% test error with this configuration.
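
For orientation, the training length and learning-rate curve implied by the paper's configuration can be estimated with a short sketch. It rests on several assumptions that are not confirmed by this README: that an epoch is one pass over 4,000 labelled samples, that the last incomplete batch is dropped, and that the schedule is a linear warm-up followed by cosine decay to zero. The function names are hypothetical.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=True):
    """Number of optimizer steps in one pass over the labelled data."""
    full, rest = divmod(num_samples, batch_size)
    return full if (drop_last or rest == 0) else full + 1


def lr_at(step, total_steps, base_lr, num_warmup_steps=0):
    """Linear warm-up to base_lr, then cosine decay to 0 (assumed schedule)."""
    if step < num_warmup_steps:
        return base_lr * step / num_warmup_steps
    decay_steps = max(1, total_steps - num_warmup_steps)
    progress = (step - num_warmup_steps) / decay_steps
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    num_labelled = 4000  # assumption: size of the labelled subset
    total_steps = 6250 * steps_per_epoch(num_labelled, batch_size=64)
    print(total_steps)  # 387500 under the assumptions above
    for step in (0, 10_000, 20_000, total_steps // 2, total_steps):
        print(step, lr_at(step, total_steps, base_lr=0.03, num_warmup_steps=20_000))
```

Under these assumptions the run takes a few hundred thousand optimizer steps, which is far longer than the 100-epoch runs above.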

Updated version of UDA

export MLFLOW_TRACKING_URI=$OUTPUT_PATH/mlruns

mlflow run experiments/ -e main_uda2 --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=fastresnet -P params="data_path=../input/cifar10;num_epochs=100;learning_rate=0.08;batch_size=512;unlabelled_batch_size=512"

Some results

[Figure: fastresnet_uda_vs_uda2]

Tensorboard

All experiments are also logged to TensorBoard. To visualize them, install tensorboard and run:

# tensorboard --logdir=$OUTPUT_PATH/mlruns/<experiment_id>
tensorboard --logdir=$OUTPUT_PATH/mlruns/1

Acknowledgements

This repository reuses code from other projects.

Thanks to the authors for sharing their code!
