This is a dockerized implementation of Spatial Transformer Networks with CoordConv layers that lets you train several models, save the results, and interact with them through a Jupyter Notebook.
- Train different architectures from scratch
- Test them on MNIST and fashion-MNIST
- Save, load and compare results
- Everything is dockerized to allow fast deployment
This repository has been containerized with docker-compose (version 1.28.2).
Runs the dockerized jupyter-lab server.
docker-compose up jupyter
It copies the source code, results, and models into the Docker container so they can be analyzed through Jupyter. It also mounts the jupyter folder as a volume mapped to work/jupyter in the container, so any changes you make to the analysis notebook persist on your host. You may need to grant permissions so the Docker container can write to the original folder.
The analysis notebook path in the container is work/jupyter/result_analysis.ipynb. There you'll find:
- Accuracy metrics for the different models, averaged over many runs with their standard deviations (a minimal sketch of this aggregation follows the list)
- Visualizations of the images after the Spatial Transform layer for the models that have them
- Confusion matrices
- Examples of the misclassified items and insights on them
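As an illustration, below is a minimal sketch of the kind of aggregation behind the averaged accuracy metrics, assuming each run's test accuracy was saved as JSON under results/ (the file layout and field names here are hypothetical, not the repository's actual format):

```python
import json
from pathlib import Path

import numpy as np

# Hypothetical layout: results/<model>/<run>.json, each file containing {"test_accuracy": ...}
results_dir = Path("results")

for model_dir in sorted(p for p in results_dir.iterdir() if p.is_dir()):
    accuracies = [json.loads(f.read_text())["test_accuracy"] for f in model_dir.glob("*.json")]
    if accuracies:
        print(f"{model_dir.name}: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f} "
              f"over {len(accuracies)} runs")
```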
The training loop is also dockerized. You just need to run:
docker-compose build train
docker-compose run train
It mounts the results, models and datasets folders as volumes so your trained models and data persist on your host.
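Because these are plain host directories, persisting and reloading a trained model is a standard torch.save / torch.load round trip. A minimal sketch with a placeholder architecture and a hypothetical filename (the repository's own scripts handle this for you):

```python
import torch
from torch import nn

# Placeholder architecture; the real model classes live in this repository's source.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

# After training, persist the weights to the mounted models/ folder (filename is hypothetical).
torch.save(model.state_dict(), "models/example.pt")

# Later, e.g. from the analysis notebook, rebuild the same architecture and reload the weights.
model.load_state_dict(torch.load("models/example.pt", map_location="cpu"))
model.eval()
```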
To change the training parameters, edit the service configuration in the docker-compose.yml file:
train:
  build:
    context: .
    dockerfile: docker/Dockerfile.train
  container_name: train
  tty: true
  environment:
    - NVIDIA_VISIBLE_DEVICES=all
  volumes:
    - ./results/:/code/results/
    - ./models/:/code/models/
    - ./datasets/:/code/datasets/
  command: 'python train.py --dataset fashion-mnist --model stnet --epochs 20'
  runtime: nvidia
There you can see the command with the arguments. Allowed values are:
- Dataset:
  - mnist
  - fashion-mnist
- Model:
  - convnet: classical CNN
  - stnet: vanilla STNet as in https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
  - coordconv: CNN with CoordConv layers instead of classical conv layers, as in https://github.com/walsvid/CoordConv adapted to PyTorch 1.x (see the sketch after this list)
  - stcoordconv: combination of the STN and CoordConv implementations
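For context, the core idea of a CoordConv layer (used by the coordconv and stcoordconv models above) is to concatenate normalized x/y coordinate channels to the input before a regular convolution. The sketch below illustrates that idea only; it is not this repository's exact implementation, which follows https://github.com/walsvid/CoordConv:

```python
import torch
from torch import nn


class CoordConv2d(nn.Module):
    """Conv2d that first appends normalized x/y coordinate channels to its input."""

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Two extra input channels for the x and y coordinate maps.
        self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)

    def forward(self, x):
        batch, _, height, width = x.shape
        # Coordinate grids normalized to [-1, 1], broadcast over the batch dimension.
        ys = torch.linspace(-1, 1, height, device=x.device).view(1, 1, height, 1).expand(batch, 1, height, width)
        xs = torch.linspace(-1, 1, width, device=x.device).view(1, 1, 1, width).expand(batch, 1, height, width)
        return self.conv(torch.cat([x, xs, ys], dim=1))


# Drop-in usage on an MNIST-sized batch:
layer = CoordConv2d(1, 8, kernel_size=3, padding=1)
out = layer(torch.randn(4, 1, 28, 28))  # -> torch.Size([4, 8, 28, 28])
```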
If your system lacks GPU capabilities or you have not enabled the NVIDIA docker-compose integration properly, you may have to edit docker-compose.yml and remove both runtime: nvidia references in order to launch the project.
These visualizations are extracted from the Jupyter notebook provided by the jupyter service. Here we can see a comparison between the different architectures on the fashion-MNIST dataset, averaged over many runs:
The Jupyter notebook also implements other metrics and visualizations. For example, you can check confusion matrices over the test set:
And also see how the spatial transformer layers transform the input images in the models that include them:
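For reference, these transformed images come from resampling the input with the affine parameters predicted by the localization network, via affine_grid and grid_sample as in the linked PyTorch tutorial. The module below is a minimal illustration of that step, not this repository's exact architecture:

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinySTN(nn.Module):
    """Minimal spatial transformer: predict a 2x3 affine matrix, then resample the input."""

    def __init__(self):
        super().__init__()
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
        )
        self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32), nn.ReLU(), nn.Linear(32, 6))
        # Start from the identity transform so training begins with "no warping".
        self.fc_loc[-1].weight.data.zero_()
        self.fc_loc[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        theta = self.fc_loc(self.localization(x).flatten(1)).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)  # the image shown in the notebook


stn = TinySTN()
warped = stn(torch.randn(4, 1, 28, 28))  # same shape as the input, spatially transformed
```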
I have added implementations for the Caltech-UCSD Birds-200-2011 bird classification dataset used in the Spatial Transformer Networks paper:
You'll have to download the dataset from the original source and extract it into the dataset folder. Training is launched through a separate train-birds service in the docker-compose file because its tests are different: they check size compatibility against the model definitions, and I didn't want them to interfere with the first part of the project, which should run without this dataset.
I've tried to expand the implementation in https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html by adding more complexity to the localization layer, but I haven't gotten the expected results. Since the authors didn't release their code, I adapted another implementation, an STN with ResNext50 as backbone: I removed the use of initialized weights, since the goal of this project is to compare the architectures' results and convergence speed rather than to train a state-of-the-art model, and updated the input structure, which was designed for video instead of images. You can test these by choosing birds as the dataset and resnext or stresnext as the model.
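As a pointer, here is a minimal sketch of what plugging in a ResNext50 backbone without pretrained weights can look like with torchvision; this only illustrates the idea, it is not the adapted implementation itself:

```python
import torch
from torch import nn
import torchvision


def build_resnext_classifier(num_classes=200):
    # pretrained=False keeps randomly initialized weights (newer torchvision uses weights=None).
    backbone = torchvision.models.resnext50_32x4d(pretrained=False)
    # Replace the ImageNet head with one sized for CUB-200-2011's 200 classes.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone


model = build_resnext_classifier()
logits = model(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 200])
```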
The results don't seem to match the resnext baseline:
The ST-ResNext models are too big to be uploaded to GitHub (100 MB max). Six different versions of them can be found here: https://drive.google.com/file/d/1efcWaiE-1mt1b_FDU-9XPo7Lq3cK3eSP/view?usp=sharing
My conclusion is that STNs are an interesting technology, but applying them to different state-of-the-art challenges is not as straightforward as it may seem, and the literature about them is still at an early stage. They seem to require extensive tuning of the architecture to adapt them to a complex problem, so I'd choose them as an option for improving an already developed model, but I wouldn't start from them if the goal is fast-prototyping a new use case. A problem already mentioned in the paper is that they are prone to overfitting; my experiments corroborate that, since the training error was smaller than in the resnext baseline but the validation error was higher. In my experiments they also increased training time by up to ~140%, so you should consider that if you have time or infrastructure constraints. This seems worth it on MNIST and fashion-MNIST, since STN-based models stay several epochs ahead of non-STN models, but it may not be justified in other contexts.