# Proxy Anchor Loss for Deep Metric Learning

Official PyTorch implementation of the CVPR 2020 paper [**Proxy Anchor Loss for Deep Metric Learning**](https://arxiv.org/abs/2003.13911).

A standard embedding network trained with **Proxy-Anchor Loss** achieves state-of-the-art performance and converges most quickly.

This repository provides the source code for experiments on four datasets (CUB-200-2011, Cars-196, Stanford Online Products, and In-shop) as well as pretrained models.

#### Accuracy in Recall@1 versus training time on the Cars-196 dataset

<p align="center"><img src="../misc/Recall_Trainingtime.jpg" alt="graph" width="60%"></p>
## Requirements

- Python 3
- PyTorch (> 1.0)
- NumPy
- tqdm
- wandb
- [pytorch-metric-learning](https://github.com/KevinMusgrave/pytorch-metric-learning)
## Datasets

1. Download the four public benchmarks for deep metric learning:
   - [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz)
   - Cars-196 ([Img](http://imagenet.stanford.edu/internal/car196/car_ims.tgz), [Annotation](http://imagenet.stanford.edu/internal/car196/cars_annos.mat))
   - [Stanford Online Products](ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip)
   - In-shop Clothes Retrieval ([Link](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html))

2. Extract each tgz or zip file into `./data/` (as the one exception, put the Cars-196 files in `./data/cars196/`); see the sketch below.
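For example, the extraction might look as follows (archive names here are illustrative; adjust them to what you actually downloaded):

```bash
mkdir -p data/cars196
tar -xzf images.tgz -C data/                  # CUB-200-2011
tar -xzf car_ims.tgz -C data/cars196          # Cars-196 images
mv cars_annos.mat data/cars196/               # Cars-196 annotations
unzip Stanford_Online_Products.zip -d data/   # Stanford Online Products
```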
## Training Embedding Network

Note that a sufficiently large batch size and well-chosen hyperparameters yield better overall performance than reported in the paper. You can download the trained models through the hyperlinks in the tables below. For reference, the sketch after this paragraph shows the loss that the `--loss Proxy_Anchor` option selects.
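The following is a minimal, unofficial sketch of the Proxy-Anchor loss as defined in the paper; the repository's own loss implementation is the authoritative one, and the names `ProxyAnchorLoss`, `mrg`, and `alpha` here are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProxyAnchorLoss(nn.Module):
    """Sketch of the Proxy-Anchor loss: one learnable proxy per class."""
    def __init__(self, nb_classes, sz_embed, mrg=0.1, alpha=32):
        super().__init__()
        # The official code may use a different initialization scheme.
        self.proxies = nn.Parameter(torch.randn(nb_classes, sz_embed))
        self.nb_classes, self.mrg, self.alpha = nb_classes, mrg, alpha

    def forward(self, X, T):
        # Cosine similarity between embeddings X (B x D) and proxies (C x D).
        cos = F.linear(F.normalize(X), F.normalize(self.proxies))  # B x C
        P_one_hot = F.one_hot(T, self.nb_classes).float()          # positives
        N_one_hot = 1.0 - P_one_hot                                # negatives

        pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
        neg_exp = torch.exp( self.alpha * (cos + self.mrg))

        # P+ : proxies with at least one positive embedding in the batch.
        with_pos = P_one_hot.sum(dim=0) > 0
        num_valid = with_pos.sum().clamp(min=1)

        pos_term = torch.log(
            1 + (P_one_hot * pos_exp).sum(dim=0)[with_pos]).sum() / num_valid
        neg_term = torch.log(
            1 + (N_one_hot * neg_exp).sum(dim=0)).sum() / self.nb_classes
        return pos_term + neg_term
```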
### CUB-200-2011

- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**:

```bash
python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 1e-4 --dataset cub --warm 1 --bn-freeze 1 --lr-decay-step 10
```

- Train an embedding network of ResNet-50 (d=512) using **Proxy-Anchor loss**:

```bash
python train.py --gpu-id 0 --loss Proxy_Anchor --model resnet50 --embedding-size 512 --batch-size 120 --lr 1e-4 --dataset cub --warm 5 --bn-freeze 1 --lr-decay-step 5
```

| Method | Backbone | R@1 | R@2 | R@4 | R@8 |
|:-:|:-:|:-:|:-:|:-:|:-:|
| [Proxy-Anchor<sup>512</sup>](https://drive.google.com/file/d/1twaY6S2QIR8eanjDB6PoVPlCTsn-6ZJW/view?usp=sharing) | Inception-BN | 69.1 | 78.9 | 86.1 | 91.2 |
| [Proxy-Anchor<sup>512</sup>](https://drive.google.com/file/d/1s-cRSEL2PhPFL9S7bavkrD_c59bJXL_u/view?usp=sharing) | ResNet-50 | 69.9 | 79.6 | 86.6 | 91.4 |
### Cars-196

- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**:

```bash
python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 1e-4 --dataset cars --warm 1 --bn-freeze 1 --lr-decay-step 20
```

- Train an embedding network of ResNet-50 (d=512) using **Proxy-Anchor loss**:

```bash
python train.py --gpu-id 0 --loss Proxy_Anchor --model resnet50 --embedding-size 512 --batch-size 120 --lr 1e-4 --dataset cars --warm 5 --bn-freeze 1 --lr-decay-step 10
```

| Method | Backbone | R@1 | R@2 | R@4 | R@8 |
|:-:|:-:|:-:|:-:|:-:|:-:|
| [Proxy-Anchor<sup>512</sup>](https://drive.google.com/file/d/1wwN4ojmOCEAOaSYQHArzJbNdJQNvo4E1/view?usp=sharing) | Inception-BN | 86.4 | 91.9 | 95.0 | 97.0 |
| [Proxy-Anchor<sup>512</sup>](https://drive.google.com/file/d/1_4P90jZcDr0xolRduNpgJ9tX9HZ1Ih7n/view?usp=sharing) | ResNet-50 | 87.7 | 92.7 | 95.5 | 97.3 |
### Stanford Online Products

- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**:

```bash
python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 6e-4 --dataset SOP --warm 1 --bn-freeze 0 --l2-norm 1 --lr-decay-step 20 --lr-decay-gamma 0.25
```

| Method | Backbone | R@1 | R@10 | R@100 | R@1000 |
|:-:|:-:|:-:|:-:|:-:|:-:|
| [Proxy-Anchor<sup>512</sup>](https://drive.google.com/file/d/1hBdWhLP2J83JlOMRgZ4LLZY45L-9Gj2X/view?usp=sharing) | Inception-BN | 79.2 | 90.7 | 96.2 | 98.6 |
### In-Shop Clothes Retrieval

- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**:

```bash
python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 6e-4 --dataset Inshop --warm 1 --bn-freeze 0 --l2-norm 1 --lr-decay-step 20 --lr-decay-gamma 0.25
```

| Method | Backbone | R@1 | R@10 | R@20 | R@30 | R@40 |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| [Proxy-Anchor<sup>512</sup>](https://drive.google.com/file/d/1VE7psay7dblDyod8di72Sv7Z2xGtUGra/view?usp=sharing) | Inception-BN | 91.9 | 98.1 | 98.7 | 99.0 | 99.1 |
## Evaluating Image Retrieval

Follow the steps below to evaluate a provided pretrained model or your own trained model. The best trained model is saved in `./logs/folder_name`.

```bash
# The parameters should be changed according to the model to be evaluated.
python evaluate.py --gpu-id 0 --batch-size 120 --model bn_inception --embedding-size 512 --dataset cub --resume /set/your/model/path/best_model.pth
```
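The tables above report Recall@K. For reference, a minimal sketch of that metric under the usual convention (cosine nearest neighbours, with the query excluded from its own neighbour list) is:

```python
import torch
import torch.nn.functional as F

def recall_at_k(emb, labels, ks=(1, 2, 4, 8)):
    """Fraction of queries whose K nearest neighbours contain a same-class sample."""
    sim = F.normalize(emb) @ F.normalize(emb).t()
    sim.fill_diagonal_(-float('inf'))               # exclude self-matches
    knn = labels[sim.topk(max(ks), dim=1).indices]  # N x max(ks) neighbour labels
    match = knn == labels.unsqueeze(1)
    return {k: match[:, :k].any(dim=1).float().mean().item() for k in ks}
```

(For the In-shop benchmark, queries are ranked against the gallery split rather than against themselves.)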
## Acknowledgements

Our code is modified and adapted from these great repositories:

- [No Fuss Distance Metric Learning using Proxies](https://github.com/dichotomies/proxy-nca)
- [PyTorch Metric Learning](https://github.com/KevinMusgrave/pytorch-metric-learning)
## Other Implementations

- [PyTorch, TensorFlow and MXNet implementations](https://github.com/geonm/proxy-anchor-loss) (thanks to Geonmo Gu :D)
## Citation

If you use this method or this code in your research, please cite as:

```
@inproceedings{kim2020proxy,
  title={Proxy Anchor Loss for Deep Metric Learning},
  author={Kim, Sungyeon and Kim, Dongwon and Cho, Minsu and Kwak, Suha},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2020}
}
```
## Source Files

#### In-shop dataset loader (`Inshop_Dataset`)
```python
from .base import *

import os
import numpy as np
import pandas as pd
import torch
import PIL.Image


class Inshop_Dataset(torch.utils.data.Dataset):
    """In-shop Clothes Retrieval: separate train, query and gallery splits."""

    def __init__(self, root, mode, transform=None):
        self.root = root + '/Inshop_Clothes'
        self.mode = mode
        self.transform = transform
        self.train_ys, self.train_im_paths = [], []
        self.query_ys, self.query_im_paths = [], []
        self.gallery_ys, self.gallery_im_paths = [], []

        data_info = np.array(pd.read_table(
            self.root + '/Eval/list_eval_partition.txt',
            header=1, delim_whitespace=True))
        # Separate into the training split and the query/gallery splits for testing.
        train = data_info[data_info[:, 2] == 'train'][:, :2]
        query = data_info[data_info[:, 2] == 'query'][:, :2]
        gallery = data_info[data_info[:, 2] == 'gallery'][:, :2]

        # Convert item ids to contiguous integer labels.
        lab_conv = {x: i for i, x in enumerate(np.unique(np.array(
            [int(x.split('_')[-1]) for x in train[:, 1]])))}
        train[:, 1] = np.array([lab_conv[int(x.split('_')[-1])] for x in train[:, 1]])

        lab_conv = {x: i for i, x in enumerate(np.unique(np.array(
            [int(x.split('_')[-1]) for x in np.concatenate([query[:, 1], gallery[:, 1]])])))}
        query[:, 1] = np.array([lab_conv[int(x.split('_')[-1])] for x in query[:, 1]])
        gallery[:, 1] = np.array([lab_conv[int(x.split('_')[-1])] for x in gallery[:, 1]])

        # Collect image paths and labels for each split.
        for img_path, key in train:
            self.train_im_paths.append(os.path.join(self.root, 'Img', img_path))
            self.train_ys += [int(key)]

        for img_path, key in query:
            self.query_im_paths.append(os.path.join(self.root, 'Img', img_path))
            self.query_ys += [int(key)]

        for img_path, key in gallery:
            self.gallery_im_paths.append(os.path.join(self.root, 'Img', img_path))
            self.gallery_ys += [int(key)]

        if self.mode == 'train':
            self.im_paths = self.train_im_paths
            self.ys = self.train_ys
        elif self.mode == 'query':
            self.im_paths = self.query_im_paths
            self.ys = self.query_ys
        elif self.mode == 'gallery':
            self.im_paths = self.gallery_im_paths
            self.ys = self.gallery_ys

    def nb_classes(self):
        return len(set(self.ys))

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        def img_load(index):
            im = PIL.Image.open(self.im_paths[index])
            # Convert grayscale to RGB.
            if len(list(im.split())) == 1:
                im = im.convert('RGB')
            if self.transform is not None:
                im = self.transform(im)
            return im

        im = img_load(index)
        target = self.ys[index]

        return im, target
```
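A hypothetical usage sketch (the root path and transform below are illustrative):

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
query = Inshop_Dataset(root='./data', mode='query', transform=transform)
gallery = Inshop_Dataset(root='./data', mode='gallery', transform=transform)
query_loader = DataLoader(query, batch_size=120, shuffle=False, num_workers=4)
```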
#### `SOP.py` — Stanford Online Products
```python
from .base import *

class SOP(BaseDataset):
    def __init__(self, root, mode, transform=None):
        self.root = root + '/Stanford_Online_Products'
        self.mode = mode
        self.transform = transform
        # The first 11,318 classes are for training, the rest for evaluation.
        if self.mode == 'train':
            self.classes = range(0, 11318)
        elif self.mode == 'eval':
            self.classes = range(11318, 22634)

        BaseDataset.__init__(self, self.root, self.mode, self.transform)
        metadata = open(os.path.join(
            self.root, 'Ebay_train.txt' if self.classes == range(0, 11318) else 'Ebay_test.txt'))
        for i, (image_id, class_id, _, path) in enumerate(map(str.split, metadata)):
            if i > 0:  # skip the header line
                if int(class_id) - 1 in self.classes:
                    self.ys += [int(class_id) - 1]
                    self.I += [int(image_id) - 1]
                    self.im_paths.append(os.path.join(self.root, path))
```
#### Dataset package `__init__.py`
```python
from .cars import Cars
from .cub import CUBirds
from .SOP import SOP
from . import utils
from .base import BaseDataset


# Registry mapping dataset names to their classes.
_type = {
    'cars': Cars,
    'cub': CUBirds,
    'SOP': SOP
}

def load(name, root, mode, transform=None):
    return _type[name](root=root, mode=mode, transform=transform)
```
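For illustration, a driver script could resolve a dataset by name like this (assuming the package is importable as `dataset`; the transform is a made-up example):

```python
import torchvision.transforms as T
from dataset import load

transform = T.Compose([T.RandomResizedCrop(224), T.ToTensor()])
train_set = load(name='cub', root='./data', mode='train', transform=transform)
print(train_set.nb_classes())  # 100 classes in the CUB train split
```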
#### `base.py` — shared dataset base class
```python
from __future__ import print_function
from __future__ import division

import os
import torch
import torchvision
import numpy as np
import PIL.Image

class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, root, mode, transform=None):
        self.root = root
        self.mode = mode
        self.transform = transform
        # Subclasses populate these: labels, image paths, and global indices.
        self.ys, self.im_paths, self.I = [], [], []

    def nb_classes(self):
        assert set(self.ys) == set(self.classes)
        return len(self.classes)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        def img_load(index):
            im = PIL.Image.open(self.im_paths[index])
            # Convert grayscale to RGB.
            if len(list(im.split())) == 1:
                im = im.convert('RGB')
            if self.transform is not None:
                im = self.transform(im)
            return im

        im = img_load(index)
        target = self.ys[index]

        return im, target

    def get_label(self, index):
        return self.ys[index]

    def set_subset(self, I):
        self.ys = [self.ys[i] for i in I]
        self.I = [self.I[i] for i in I]
        self.im_paths = [self.im_paths[i] for i in I]
```
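As a hypothetical illustration of the contract a subclass must satisfy (populate `classes`, `ys`, `I`, and `im_paths` in `__init__`; all names and paths below are made up):

```python
class ToyDataset(BaseDataset):
    def __init__(self, root, mode, transform=None):
        BaseDataset.__init__(self, root, mode, transform)
        self.classes = range(0, 2)
        # Two fake samples, one per class.
        for index, (path, y) in enumerate([('img/a.jpg', 0), ('img/b.jpg', 1)]):
            self.ys += [y]
            self.I += [index]
            self.im_paths.append(os.path.join(root, path))
```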
#### `cars.py` — Cars-196
```python
from .base import *
import scipy.io

class Cars(BaseDataset):
    def __init__(self, root, mode, transform=None):
        self.root = root + '/cars196'
        self.mode = mode
        self.transform = transform
        # First 98 classes for training, remaining 98 for evaluation.
        if self.mode == 'train':
            self.classes = range(0, 98)
        elif self.mode == 'eval':
            self.classes = range(98, 196)

        BaseDataset.__init__(self, self.root, self.mode, self.transform)
        annos_fn = 'cars_annos.mat'
        cars = scipy.io.loadmat(os.path.join(self.root, annos_fn))
        ys = [int(a[5][0] - 1) for a in cars['annotations'][0]]  # class labels
        im_paths = [a[0][0] for a in cars['annotations'][0]]     # relative paths
        index = 0
        for im_path, y in zip(im_paths, ys):
            if y in self.classes:  # choose only the specified classes
                self.im_paths.append(os.path.join(self.root, im_path))
                self.ys.append(y)
                self.I += [index]
            index += 1
```
#### `cub.py` — CUB-200-2011
```python
from .base import *

class CUBirds(BaseDataset):
    def __init__(self, root, mode, transform=None):
        self.root = root + '/CUB_200_2011'
        self.mode = mode
        self.transform = transform
        # First 100 classes for training, remaining 100 for evaluation.
        if self.mode == 'train':
            self.classes = range(0, 100)
        elif self.mode == 'eval':
            self.classes = range(100, 200)

        BaseDataset.__init__(self, self.root, self.mode, self.transform)
        index = 0
        for i in torchvision.datasets.ImageFolder(
                root=os.path.join(self.root, 'images')).imgs:
            # i[0]: full path to the image, i[1]: class label
            y = i[1]
            # fn needed for removing non-images starting with `._`
            fn = os.path.split(i[0])[1]
            if y in self.classes and fn[:2] != '._':
                self.ys += [y]
                self.I += [index]
                self.im_paths.append(os.path.join(self.root, i[0]))
            index += 1
```
#### Balanced batch sampler module
```python
import numpy as np
from torch.utils.data.sampler import Sampler

class BalancedSampler(Sampler):
    """Yields indices so that each batch holds `images_per_class` samples
    from each of `batch_size // images_per_class` randomly chosen classes."""

    def __init__(self, data_source, batch_size, images_per_class=3):
        self.data_source = data_source
        self.ys = data_source.ys
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)
        self.num_classes = len(set(self.ys))

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        while num_batches > 0:
            # Pick distinct classes for this batch, then sample (with
            # replacement) a fixed number of images from each class.
            sampled_classes = np.random.choice(self.num_classes, self.num_groups, replace=False)
            for i in range(len(sampled_classes)):
                ith_class_idxs = np.nonzero(np.array(self.ys) == sampled_classes[i])[0]
                class_sel = np.random.choice(ith_class_idxs, size=self.num_instances, replace=True)
                ret.extend(np.random.permutation(class_sel))
            num_batches -= 1
        return iter(ret)
```
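One plausible way to plug this into a `DataLoader` (a sketch; the actual training script may wire it differently):

```python
from torch.utils.data import DataLoader, BatchSampler

balanced = BalancedSampler(train_set, batch_size=180, images_per_class=3)
batch_sampler = BatchSampler(balanced, batch_size=180, drop_last=True)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, num_workers=4)
```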
Review comment (jdhao, Contributor) on the `im_paths.append` line in `cub.py`: here, `i[0]` is already the full path to a single image, so you do not need `os.path.join(self.root, i[0])` to generate the full path again. Although it does no harm, there is no benefit.