Deep Frank-Wolfe For Neural Network Optimization

This repository contains the PyTorch implementation of the paper Deep Frank-Wolfe For Neural Network Optimization. If you use this work for your research, please cite the paper:

@Article{berrada2019deep,
  author       = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
  title        = {Deep Frank-Wolfe For Neural Network Optimization},
  journal      = {International Conference on Learning Representations},
  year         = {2019},
}

Note: you might be interested in our follow-up algorithm ALI-G, which has explicit convergence guarantees and outperforms DFW in our experiments.

Requirements

This code should work with PyTorch >= 1.0 and Python 3. Detailed requirements are available in requirements.txt.

Installation

  • Clone this repository: git clone --recursive https://github.com/oval-group/dfw (the --recursive option is necessary to clone the submodules; these are needed to reproduce the experiments but not for the DFW implementation itself).
  • Go to the directory and install the requirements: cd dfw && pip install -r requirements.txt
  • Install the DFW package: python setup.py install

Example of Usage

  • Simple usage example (a fuller training-loop sketch is given after this list):
from dfw import DFW
from dfw.losses import MultiClassHingeLoss


# boilerplate code:
# `model` is a nn.Module
# `x` is an input sample, `y` is a label

# create loss function
svm = MultiClassHingeLoss()

# create DFW optimizer with learning rate of 0.1
optimizer = DFW(model.parameters(), eta=0.1)

# DFW can be used with the standard PyTorch syntax
optimizer.zero_grad()
loss = svm(model(x), y)
loss.backward()
# NB: DFW needs access to the current loss value
# (this syntax is also compatible with standard PyTorch optimizers)
optimizer.step(lambda: float(loss))
  • Technical requirement: DFW computes a custom step-size at each iteration. For this update to make sense, the loss function must be convex and piecewise linear. For instance, one can use a multi-class SVM loss or an l1 regression loss (a minimal regression sketch is given at the end of this section).

  • Smoothing: the multi-class SVM loss sometimes does not fare well with a large number of classes. This issue can be alleviated by dual smoothing, which is easy to plug into the code:

from dfw.losses import set_smoothing_enabled
...
with set_smoothing_enabled(True):
    loss = svm(model(x), y)
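
Putting the pieces together, here is a minimal training-loop sketch based on the snippets above. It only relies on the calls documented in this README; `model` and `loader` are placeholders for your own network and data loader:

from dfw import DFW
from dfw.losses import MultiClassHingeLoss, set_smoothing_enabled

# `model` is your nn.Module, `loader` is your data loader (placeholders)
svm = MultiClassHingeLoss()
optimizer = DFW(model.parameters(), eta=0.1)

for x, y in loader:
    optimizer.zero_grad()
    # optional dual smoothing, useful when the number of classes is large
    with set_smoothing_enabled(True):
        loss = svm(model(x), y)
    loss.backward()
    # pass the current loss value to DFW through the closure
    optimizer.step(lambda: float(loss))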

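For the l1-regression case mentioned above, a hypothetical sketch can use PyTorch's built-in nn.L1Loss, which is piecewise linear; the linear model and dummy data below are placeholders for illustration only:

import torch
import torch.nn as nn
from dfw import DFW

# placeholder regression model and dummy data, for illustration only
model = nn.Linear(10, 1)
x = torch.randn(32, 10)
y = torch.randn(32, 1)

l1_loss = nn.L1Loss()  # piecewise linear, hence suitable for DFW
optimizer = DFW(model.parameters(), eta=0.1)

optimizer.zero_grad()
loss = l1_loss(model(x), y)
loss.backward()
# as above, DFW needs the current loss value through the closure
optimizer.step(lambda: float(loss))
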
Reproducing the Results


  • To reproduce the CIFAR experiments: VISION_DATA=[path/to/your/cifar/data] python reproduce/cifar.py
  • To reproduce the SNLI experiments: follow the preparation instructions and run python reproduce/snli.py

DFW largely outperforms all baselines that do not use a manual schedule for the learning rate. The tables below show the performance on the CIFAR data sets with data augmentation (AMSGrad, a variant of Adam, is the strongest baseline in our experiments) and on the SNLI data set.

CIFAR-10 (test accuracy in %):

Optimizer            Wide Residual Networks   Densely Connected Networks
AMSGrad              90.1                     91.8
DFW                  94.7                     94.9
SGD (with schedule)  95.4                     95.3

CIFAR-100 (test accuracy in %):

Optimizer            Wide Residual Networks   Densely Connected Networks
AMSGrad              67.8                     69.6
DFW                  74.7                     73.2
SGD (with schedule)  77.8                     76.3

SNLI:

Optimizer            Test Accuracy (%)
Adagrad              84.6
Adam                 85.0
AMSGrad              85.1
BPGrad               84.2
DFW                  85.2
SGD (with schedule)  85.2

Acknowledgments

We use the following third-party implementations: