diff --git a/README.md b/README.md
new file mode 100644
index 00000000..0245c859
--- /dev/null
+++ b/README.md
@@ -0,0 +1,143 @@
+
+# Proxy Anchor Loss for Deep Metric Learning
+
+Official PyTorch implementation of CVPR 2020 paper [**Proxy Anchor Loss for Deep Metric Learning**](https://arxiv.org/abs/2003.13911).
+
+A standard embedding network trained with **Proxy-Anchor Loss** achieves state-of-the-art performance and converges most quickly.
+
+This repository provides the source code for experiments on four datasets (CUB-200-2011, Cars-196, Stanford Online Products, and In-shop) as well as pretrained models.
+
+#### Accuracy in Recall@1 versus training time on the Cars-196
+
+
+
+## Requirements
+
+- Python3
+- PyTorch (> 1.0)
+- NumPy
+- tqdm
+- wandb
+- [Pytorch-Metric-Learning](https://github.com/KevinMusgrave/pytorch-metric-learning)
+
+
+
+## Datasets
+
+1. Download the four public benchmarks for deep metric learning:
+   - [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz)
+   - Cars-196 ([Img](http://imagenet.stanford.edu/internal/car196/car_ims.tgz), [Annotation](http://imagenet.stanford.edu/internal/car196/cars_annos.mat))
+   - [Stanford Online Products](ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip)
+   - In-shop Clothes Retrieval ([Link](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html))
+
+2. Extract the tgz or zip file into `./data/` (for Cars-196 only, put the files in `./data/cars196`).
+
+
+
+## Training Embedding Network
+
+Note that a sufficiently large batch size and well-chosen hyperparameters yield better overall performance than that reported in the paper. You can download the trained models through the hyperlinks in the tables below.
+
+### CUB-200-2011
+
+- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**
+
+```bash
+python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 1e-4 --dataset cub --warm 1 --bn-freeze 1 --lr-decay-step 10
+```
+
+- Train an embedding network of ResNet-50 (d=512) using **Proxy-Anchor loss**
+
+```bash
+python train.py --gpu-id 0 --loss Proxy_Anchor --model resnet50 --embedding-size 512 --batch-size 120 --lr 1e-4 --dataset cub --warm 5 --bn-freeze 1 --lr-decay-step 5
+```
+
+| Method | Backbone | R@1 | R@2 | R@4 | R@8 |
+|:-:|:-:|:-:|:-:|:-:|:-:|
+| [Proxy-Anchor512](https://drive.google.com/file/d/1twaY6S2QIR8eanjDB6PoVPlCTsn-6ZJW/view?usp=sharing) | Inception-BN | 69.1 | 78.9 | 86.1 | 91.2 |
+| [Proxy-Anchor512](https://drive.google.com/file/d/1s-cRSEL2PhPFL9S7bavkrD_c59bJXL_u/view?usp=sharing) | ResNet-50 | 69.9 | 79.6 | 86.6 | 91.4 |
+
+### Cars-196
+
+- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**
+
+```bash
+python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 1e-4 --dataset cars --warm 1 --bn-freeze 1 --lr-decay-step 20
+```
+
+- Train an embedding network of ResNet-50 (d=512) using **Proxy-Anchor loss**
+
+```bash
+python train.py --gpu-id 0 --loss Proxy_Anchor --model resnet50 --embedding-size 512 --batch-size 120 --lr 1e-4 --dataset cars --warm 5 --bn-freeze 1 --lr-decay-step 10
+```
+
+| Method | Backbone | R@1 | R@2 | R@4 | R@8 |
+|:-:|:-:|:-:|:-:|:-:|:-:|
+| [Proxy-Anchor512](https://drive.google.com/file/d/1wwN4ojmOCEAOaSYQHArzJbNdJQNvo4E1/view?usp=sharing) | Inception-BN | 86.4 | 91.9 | 95.0 | 97.0 |
+| [Proxy-Anchor512](https://drive.google.com/file/d/1_4P90jZcDr0xolRduNpgJ9tX9HZ1Ih7n/view?usp=sharing) | ResNet-50 | 87.7 | 92.7 | 95.5 | 97.3 |
+
+### Stanford Online Products
+
+- Train an embedding network of Inception-BN (d=512) using **Proxy-Anchor loss**
+
+```bash
+python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 6e-4 --dataset SOP --warm 1 --bn-freeze 0 --l2-norm 1 --lr-decay-step 20 --lr-decay-gamma 0.25
+```
+
+| Method | Backbone | R@1 | R@10 | R@100 | R@1000 |
+|:-:|:-:|:-:|:-:|:-:|:-:|
+| [Proxy-Anchor512](https://drive.google.com/file/d/1hBdWhLP2J83JlOMRgZ4LLZY45L-9Gj2X/view?usp=sharing) | Inception-BN | 79.2 | 90.7 | 96.2 | 98.6 |
+
+### In-Shop Clothes Retrieval
+
+- Train an embedding network of Inception-BN 
(d=512) using **Proxy-Anchor loss**
+
+```bash
+python train.py --gpu-id 0 --loss Proxy_Anchor --model bn_inception --embedding-size 512 --batch-size 180 --lr 6e-4 --dataset Inshop --warm 1 --bn-freeze 0 --l2-norm 1 --lr-decay-step 20 --lr-decay-gamma 0.25
+```
+
+| Method | Backbone | R@1 | R@10 | R@20 | R@30 | R@40 |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+| [Proxy-Anchor512](https://drive.google.com/file/d/1VE7psay7dblDyod8di72Sv7Z2xGtUGra/view?usp=sharing) | Inception-BN | 91.9 | 98.1 | 98.7 | 99.0 | 99.1 |
+
+
+
+## Evaluating Image Retrieval
+
+Follow the steps below to evaluate the provided pretrained model or your own trained model. The best checkpoint is saved in `./logs/folder_name` during training.
+
+```bash
+# The parameters should be changed according to the model to be evaluated.
+python evaluate.py --gpu-id 0 --batch-size 120 --model bn_inception --embedding-size 512 --dataset cub --resume /set/your/model/path/best_model.pth
+```
+
+
+
+## Acknowledgements
+
+Our code is modified and adapted from these great repositories:
+
+- [No Fuss Distance Metric Learning using Proxies](https://github.com/dichotomies/proxy-nca)
+- [PyTorch Metric Learning](https://github.com/KevinMusgrave/pytorch-metric-learning)
+
+
+
+## Other Implementations
+
+- [PyTorch, TensorFlow and MXNet implementations](https://github.com/geonm/proxy-anchor-loss) (Thanks to Geonmo Gu :D)
+
+
+
+## Citation
+
+If you use this method or this code in your research, please cite as:
+
+    @inproceedings{kim2020proxy,
+      title={Proxy Anchor Loss for Deep Metric Learning},
+      author={Kim, Sungyeon and Kim, Dongwon and Cho, Minsu and Kwak, Suha},
+      booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+      year={2020}
+    }
+
diff --git a/code/dataset/Inshop.py b/code/dataset/Inshop.py
new file mode 100644
index 00000000..bfa6c902
--- /dev/null
+++ b/code/dataset/Inshop.py
@@ -0,0 +1,72 @@
+from .base import *
+
+import numpy as np, os, sys, pandas as pd, csv, copy
+import torch
+import torchvision
+import PIL.Image
+
+
+class Inshop_Dataset(torch.utils.data.Dataset):
+    def __init__(self, root, mode, transform = None):
+        self.root = root + '/Inshop_Clothes'
+        self.mode = mode
+        self.transform = transform
+        self.train_ys, self.train_im_paths = [], []
+        self.query_ys, self.query_im_paths = [], []
+        self.gallery_ys, self.gallery_im_paths = [], []
+
+        data_info = np.array(pd.read_table(self.root +'/Eval/list_eval_partition.txt', header=1, delim_whitespace=True))[:,:]
+        # Separate into training dataset and query/gallery dataset for testing.
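+        # Each row of list_eval_partition.txt is (image path, item id, split), where split is 'train', 'query' or 'gallery'.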
+ train, query, gallery = data_info[data_info[:,2]=='train'][:,:2], data_info[data_info[:,2]=='query'][:,:2], data_info[data_info[:,2]=='gallery'][:,:2] + + #Generate conversions + lab_conv = {x:i for i,x in enumerate(np.unique(np.array([int(x.split('_')[-1]) for x in train[:,1]])))} + train[:,1] = np.array([lab_conv[int(x.split('_')[-1])] for x in train[:,1]]) + + lab_conv = {x:i for i,x in enumerate(np.unique(np.array([int(x.split('_')[-1]) for x in np.concatenate([query[:,1], gallery[:,1]])])))} + query[:,1] = np.array([lab_conv[int(x.split('_')[-1])] for x in query[:,1]]) + gallery[:,1] = np.array([lab_conv[int(x.split('_')[-1])] for x in gallery[:,1]]) + + #Generate Image-Dicts for training, query and gallery of shape {class_idx:[list of paths to images belong to this class] ...} + for img_path, key in train: + self.train_im_paths.append(os.path.join(self.root, 'Img', img_path)) + self.train_ys += [int(key)] + + for img_path, key in query: + self.query_im_paths.append(os.path.join(self.root, 'Img', img_path)) + self.query_ys += [int(key)] + + for img_path, key in gallery: + self.gallery_im_paths.append(os.path.join(self.root, 'Img', img_path)) + self.gallery_ys += [int(key)] + + if self.mode == 'train': + self.im_paths = self.train_im_paths + self.ys = self.train_ys + elif self.mode == 'query': + self.im_paths = self.query_im_paths + self.ys = self.query_ys + elif self.mode == 'gallery': + self.im_paths = self.gallery_im_paths + self.ys = self.gallery_ys + + def nb_classes(self): + return len(set(self.ys)) + + def __len__(self): + return len(self.ys) + + def __getitem__(self, index): + + def img_load(index): + im = PIL.Image.open(self.im_paths[index]) + # convert gray to rgb + if len(list(im.split())) == 1 : im = im.convert('RGB') + if self.transform is not None: + im = self.transform(im) + return im + + im = img_load(index) + target = self.ys[index] + + return im, target diff --git a/code/dataset/SOP.py b/code/dataset/SOP.py new file mode 100644 index 00000000..be81919b --- /dev/null +++ b/code/dataset/SOP.py @@ -0,0 +1,20 @@ +from .base import * + +class SOP(BaseDataset): + def __init__(self, root, mode, transform = None): + self.root = root + '/Stanford_Online_Products' + self.mode = mode + self.transform = transform + if self.mode == 'train': + self.classes = range(0,11318) + elif self.mode == 'eval': + self.classes = range(11318,22634) + + BaseDataset.__init__(self, self.root, self.mode, self.transform) + metadata = open(os.path.join(self.root, 'Ebay_train.txt' if self.classes == range(0, 11318) else 'Ebay_test.txt')) + for i, (image_id, class_id, _, path) in enumerate(map(str.split, metadata)): + if i > 0: + if int(class_id)-1 in self.classes: + self.ys += [int(class_id)-1] + self.I += [int(image_id)-1] + self.im_paths.append(os.path.join(self.root, path)) \ No newline at end of file diff --git a/code/dataset/__init__.py b/code/dataset/__init__.py new file mode 100644 index 00000000..ee74a180 --- /dev/null +++ b/code/dataset/__init__.py @@ -0,0 +1,16 @@ +from .cars import Cars +from .cub import CUBirds +from .SOP import SOP +from .import utils +from .base import BaseDataset + + +_type = { + 'cars': Cars, + 'cub': CUBirds, + 'SOP': SOP +} + +def load(name, root, mode, transform = None): + return _type[name](root = root, mode = mode, transform = transform) + diff --git a/code/dataset/base.py b/code/dataset/base.py new file mode 100644 index 00000000..aaf450e2 --- /dev/null +++ b/code/dataset/base.py @@ -0,0 +1,45 @@ + +from __future__ import print_function +from __future__ import 
division + +import os +import torch +import torchvision +import numpy as np +import PIL.Image + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, root, mode, transform = None): + self.root = root + self.mode = mode + self.transform = transform + self.ys, self.im_paths, self.I = [], [], [] + + def nb_classes(self): + assert set(self.ys) == set(self.classes) + return len(self.classes) + + def __len__(self): + return len(self.ys) + + def __getitem__(self, index): + def img_load(index): + im = PIL.Image.open(self.im_paths[index]) + # convert gray to rgb + if len(list(im.split())) == 1 : im = im.convert('RGB') + if self.transform is not None: + im = self.transform(im) + return im + + im = img_load(index) + target = self.ys[index] + + return im, target + + def get_label(self, index): + return self.ys[index] + + def set_subset(self, I): + self.ys = [self.ys[i] for i in I] + self.I = [self.I[i] for i in I] + self.im_paths = [self.im_paths[i] for i in I] \ No newline at end of file diff --git a/code/dataset/cars.py b/code/dataset/cars.py new file mode 100644 index 00000000..98e8f112 --- /dev/null +++ b/code/dataset/cars.py @@ -0,0 +1,25 @@ +from .base import * +import scipy.io + +class Cars(BaseDataset): + def __init__(self, root, mode, transform = None): + self.root = root + '/cars196' + self.mode = mode + self.transform = transform + if self.mode == 'train': + self.classes = range(0,98) + elif self.mode == 'eval': + self.classes = range(98,196) + + BaseDataset.__init__(self, self.root, self.mode, self.transform) + annos_fn = 'cars_annos.mat' + cars = scipy.io.loadmat(os.path.join(self.root, annos_fn)) + ys = [int(a[5][0] - 1) for a in cars['annotations'][0]] + im_paths = [a[0][0] for a in cars['annotations'][0]] + index = 0 + for im_path, y in zip(im_paths, ys): + if y in self.classes: # choose only specified classes + self.im_paths.append(os.path.join(self.root, im_path)) + self.ys.append(y) + self.I += [index] + index += 1 \ No newline at end of file diff --git a/code/dataset/cub.py b/code/dataset/cub.py new file mode 100644 index 00000000..c6460bf0 --- /dev/null +++ b/code/dataset/cub.py @@ -0,0 +1,25 @@ +from .base import * + +class CUBirds(BaseDataset): + def __init__(self, root, mode, transform = None): + self.root = root + '/CUB_200_2011' + self.mode = mode + self.transform = transform + if self.mode == 'train': + self.classes = range(0,100) + elif self.mode == 'eval': + self.classes = range(100,200) + + BaseDataset.__init__(self, self.root, self.mode, self.transform) + index = 0 + for i in torchvision.datasets.ImageFolder(root = + os.path.join(self.root, 'images')).imgs: + # i[1]: label, i[0]: root + y = i[1] + # fn needed for removing non-images starting with `._` + fn = os.path.split(i[0])[1] + if y in self.classes and fn[:2] != '._': + self.ys += [y] + self.I += [index] + self.im_paths.append(os.path.join(self.root, i[0])) + index += 1 \ No newline at end of file diff --git a/code/dataset/sampler.py b/code/dataset/sampler.py new file mode 100644 index 00000000..6b291b55 --- /dev/null +++ b/code/dataset/sampler.py @@ -0,0 +1,31 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data.sampler import Sampler +from tqdm import * + +class BalancedSampler(Sampler): + def __init__(self, data_source, batch_size, images_per_class=3): + self.data_source = data_source + self.ys = data_source.ys + self.num_groups = batch_size // images_per_class + self.batch_size = batch_size + self.num_instances = images_per_class + self.num_samples = 
len(self.ys) + self.num_classes = len(set(self.ys)) + + def __len__(self): + return self.num_samples + + def __iter__(self): + num_batches = len(self.data_source) // self.batch_size + ret = [] + while num_batches > 0: + sampled_classes = np.random.choice(self.num_classes, self.num_groups, replace=False) + for i in range(len(sampled_classes)): + ith_class_idxs = np.nonzero(np.array(self.ys) == sampled_classes[i])[0] + class_sel = np.random.choice(ith_class_idxs, size=self.num_instances, replace=True) + ret.extend(np.random.permutation(class_sel)) + num_batches -= 1 + return iter(ret) + \ No newline at end of file diff --git a/code/dataset/utils.py b/code/dataset/utils.py new file mode 100644 index 00000000..0d410f52 --- /dev/null +++ b/code/dataset/utils.py @@ -0,0 +1,100 @@ +from __future__ import print_function +from __future__ import division + +import torchvision +from torchvision import transforms +import PIL.Image +import torch +import random + +def std_per_channel(images): + images = torch.stack(images, dim = 0) + return images.view(3, -1).std(dim = 1) + + +def mean_per_channel(images): + images = torch.stack(images, dim = 0) + return images.view(3, -1).mean(dim = 1) + + +class Identity(): # used for skipping transforms + def __call__(self, im): + return im + +class print_shape(): + def __call__(self, im): + print(im.size) + return im + +class RGBToBGR(): + def __call__(self, im): + assert im.mode == 'RGB' + r, g, b = [im.getchannel(i) for i in range(3)] + # RGB mode also for BGR, `3x8-bit pixels, true color`, see PIL doc + im = PIL.Image.merge('RGB', [b, g, r]) + return im + +class pad_shorter(): + def __call__(self, im): + h,w = im.size[-2:] + s = max(h, w) + new_im = PIL.Image.new("RGB", (s, s)) + new_im.paste(im, ((s-h)//2, (s-w)//2)) + return new_im + + +class ScaleIntensities(): + def __init__(self, in_range, out_range): + """ Scales intensities. 
For example [-1, 1] -> [0, 255].""" + self.in_range = in_range + self.out_range = out_range + + def __oldcall__(self, tensor): + tensor.mul_(255) + return tensor + + def __call__(self, tensor): + tensor = ( + tensor - self.in_range[0] + ) / ( + self.in_range[1] - self.in_range[0] + ) * ( + self.out_range[1] - self.out_range[0] + ) + self.out_range[0] + return tensor + + +def make_transform(is_train = True, is_inception = False): + # Resolution Resize List : 256, 292, 361, 512 + # Resolution Crop List: 224, 256, 324, 448 + + resnet_sz_resize = 256 + resnet_sz_crop = 224 + resnet_mean = [0.485, 0.456, 0.406] + resnet_std = [0.229, 0.224, 0.225] + resnet_transform = transforms.Compose([ + transforms.RandomResizedCrop(resnet_sz_crop) if is_train else Identity(), + transforms.RandomHorizontalFlip() if is_train else Identity(), + transforms.Resize(resnet_sz_resize) if not is_train else Identity(), + transforms.CenterCrop(resnet_sz_crop) if not is_train else Identity(), + transforms.ToTensor(), + transforms.Normalize(mean=resnet_mean, std=resnet_std) + ]) + + inception_sz_resize = 256 + inception_sz_crop = 224 + inception_mean = [104, 117, 128] + inception_std = [1, 1, 1] + inception_transform = transforms.Compose( + [ + RGBToBGR(), + transforms.RandomResizedCrop(inception_sz_crop) if is_train else Identity(), + transforms.RandomHorizontalFlip() if is_train else Identity(), + transforms.Resize(inception_sz_resize) if not is_train else Identity(), + transforms.CenterCrop(inception_sz_crop) if not is_train else Identity(), + transforms.ToTensor(), + ScaleIntensities([0, 1], [0, 255]), + transforms.Normalize(mean=inception_mean, std=inception_std) + ]) + + return inception_transform if is_inception else resnet_transform \ No newline at end of file diff --git a/code/evaluate.py b/code/evaluate.py new file mode 100644 index 00000000..8a67c9fc --- /dev/null +++ b/code/evaluate.py @@ -0,0 +1,156 @@ +import torch, math, time, argparse, json, os, sys +import random, dataset, utils, losses, net +import numpy as np +import matplotlib.pyplot as plt + +from dataset.Inshop import Inshop_Dataset +from net.resnet import * +from net.googlenet import * +from net.bn_inception import * +from dataset import sampler +from torch.utils.data.sampler import BatchSampler +from torch.utils.data.dataloader import default_collate + +from tqdm import * +import wandb + +seed = 1 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) # set random seed for all gpus + +parser = argparse.ArgumentParser(description= + 'Official implementation of `Proxy Anchor Loss for Deep Metric Learning`' + + 'Our code is modified from `https://github.com/dichotomies/proxy-nca`' +) +parser.add_argument('--dataset', + default='cub', + help = 'Training dataset, e.g. cub, cars, SOP, Inshop' +) +parser.add_argument('--embedding-size', default = 512, type = int, + dest = 'sz_embedding', + help = 'Size of embedding that is appended to backbone model.' +) +parser.add_argument('--batch-size', default = 150, type = int, + dest = 'sz_batch', + help = 'Number of samples per batch.' +) +parser.add_argument('--gpu-id', default = 0, type = int, + help = 'ID of GPU that is used for training.' +) +parser.add_argument('--workers', default = 4, type = int, + dest = 'nb_workers', + help = 'Number of workers for dataloader.' 
+) +parser.add_argument('--model', default = 'bn_inception', + help = 'Model for training' +) +parser.add_argument('--l2-norm', default = 1, type = int, + help = 'L2 normlization' +) +parser.add_argument('--resume', default = '', + help = 'Path of resuming model' +) +parser.add_argument('--remark', default = '', + help = 'Any reamrk' +) + +args = parser.parse_args() + +if args.gpu_id != -1: + torch.cuda.set_device(args.gpu_id) + +# Data Root Directory +os.chdir('../data/') +data_root = os.getcwd() + +# Dataset Loader and Sampler +if args.dataset != 'Inshop': + ev_dataset = dataset.load( + name = args.dataset, + root = data_root, + mode = 'eval', + transform = dataset.utils.make_transform( + is_train = False, + is_inception = (args.model == 'bn_inception') + )) + + dl_ev = torch.utils.data.DataLoader( + ev_dataset, + batch_size = args.sz_batch, + shuffle = False, + num_workers = args.nb_workers, + pin_memory = True + ) + +else: + query_dataset = Inshop_Dataset( + root = data_root, + mode = 'query', + transform = dataset.utils.make_transform( + is_train = False, + is_inception = (args.model == 'bn_inception') + )) + + dl_query = torch.utils.data.DataLoader( + query_dataset, + batch_size = args.sz_batch, + shuffle = False, + num_workers = args.nb_workers, + pin_memory = True + ) + + gallery_dataset = Inshop_Dataset( + root = data_root, + mode = 'gallery', + transform = dataset.utils.make_transform( + is_train = False, + is_inception = (args.model == 'bn_inception') + )) + + dl_gallery = torch.utils.data.DataLoader( + gallery_dataset, + batch_size = args.sz_batch, + shuffle = False, + num_workers = args.nb_workers, + pin_memory = True + ) + +# Backbone Model +if args.model.find('googlenet')+1: + model = googlenet(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = 1) +elif args.model.find('bn_inception')+1: + model = bn_inception(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = 1) +elif args.model.find('resnet18')+1: + model = Resnet18(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = 1) +elif args.model.find('resnet50')+1: + model = Resnet50(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = 1) +elif args.model.find('resnet101')+1: + model = Resnet101(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = 1) +model = model.cuda() + +if args.gpu_id == -1: + model = nn.DataParallel(model) + +if os.path.isfile(args.resume): + print('=> loading checkpoint {}'.format(args.resume)) + checkpoint = torch.load(args.resume) + model.load_state_dict(checkpoint['model_state_dict']) +else: + print('=> No checkpoint found at {}'.format(args.resume)) + sys.exit(0) + +with torch.no_grad(): + print("**Evaluating...**") + if args.dataset == 'Inshop': + NMI = 0 + Recalls = utils.evaluate_cos_Inshop(model, dl_query, dl_gallery) + + elif args.dataset != 'SOP': + Recalls = utils.evaluate_cos(model, dl_ev) + + else: + Recalls = utils.evaluate_cos_SOP(model, dl_ev) + + \ No newline at end of file diff --git a/code/losses.py b/code/losses.py new file mode 100644 index 00000000..a6296525 --- /dev/null +++ b/code/losses.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import random +from pytorch_metric_learning import miners, losses + +def binarize(T, nb_classes): + T = T.cpu().numpy() + import sklearn.preprocessing + T = sklearn.preprocessing.label_binarize( + T, classes = range(0, nb_classes) + ) + 
T = torch.FloatTensor(T).cuda() + return T + +def l2_norm(input): + input_size = input.size() + buffer = torch.pow(input, 2) + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + output = _output.view(input_size) + return output + +class Proxy_Anchor(torch.nn.Module): + def __init__(self, nb_classes, sz_embed, mrg = 0.1, alpha = 32): + torch.nn.Module.__init__(self) + # Proxy Anchor Initialization + self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).cuda()) + nn.init.kaiming_normal_(self.proxies, mode='fan_out') + + self.nb_classes = nb_classes + self.sz_embed = sz_embed + self.mrg = mrg + self.alpha = alpha + + def forward(self, X, T): + P = self.proxies + + cos = F.linear(l2_norm(X), l2_norm(P)) # Calcluate cosine similarity + P_one_hot = binarize(T = T, nb_classes = self.nb_classes) + N_one_hot = 1 - P_one_hot + + pos_exp = torch.exp(-self.alpha * (cos - self.mrg)) + neg_exp = torch.exp(self.alpha * (cos + self.mrg)) + + with_pos_proxies = torch.nonzero(P_one_hot.sum(dim = 0) != 0).squeeze(dim = 1) # The set of positive proxies of data in the batch + num_valid_proxies = len(with_pos_proxies) # The number of positive proxies + + P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0) + N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0) + + pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies + neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes + loss = pos_term + neg_term + + return loss + +# We use PyTorch Metric Learning library for the following codes. +# Please refer to "https://github.com/KevinMusgrave/pytorch-metric-learning" for details. +class Proxy_NCA(torch.nn.Module): + def __init__(self, nb_classes, sz_embed, scale=32): + super(Proxy_NCA, self).__init__() + self.nb_classes = nb_classes + self.sz_embed = sz_embed + self.scale = scale + self.loss_func = losses.ProxyNCALoss(num_classes = self.nb_classes, embedding_size = self.sz_embed, softmax_scale = self.scale).cuda() + + def forward(self, embeddings, labels): + loss = self.loss_func(embeddings, labels) + return loss + +class MultiSimilarityLoss(torch.nn.Module): + def __init__(self, ): + super(MultiSimilarityLoss, self).__init__() + self.thresh = 0.5 + self.epsilon = 0.1 + self.scale_pos = 2 + self.scale_neg = 50 + + self.miner = miners.MultiSimilarityMiner(epsilon=self.epsilon) + self.loss_func = losses.MultiSimilarityLoss(self.scale_pos, self.scale_neg, self.thresh) + + def forward(self, embeddings, labels): + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_func(embeddings, labels, hard_pairs) + return loss + +class ContrastiveLoss(nn.Module): + def __init__(self, margin=0.5, **kwargs): + super(ContrastiveLoss, self).__init__() + self.margin = margin + self.loss_func = losses.ContrastiveLoss(neg_margin=self.margin) + + def forward(self, embeddings, labels): + loss = self.loss_func(embeddings, labels) + return loss + +class TripletLoss(nn.Module): + def __init__(self, margin=0.1, **kwargs): + super(TripletLoss, self).__init__() + self.margin = margin + self.miner = miners.TripletMarginMiner(margin, type_of_triplets = 'semihard') + self.loss_func = losses.TripletMarginLoss(margin = self.margin) + + def forward(self, embeddings, labels): + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_func(embeddings, labels, hard_pairs) + return loss + +class NPairLoss(nn.Module): + def __init__(self, l2_reg=0): + super(NPairLoss, 
self).__init__() + self.l2_reg = l2_reg + self.loss_func = losses.NPairsLoss(l2_reg_weight=self.l2_reg, normalize_embeddings = False) + + def forward(self, embeddings, labels): + loss = self.loss_func(embeddings, labels) + return loss \ No newline at end of file diff --git a/code/net/bn_inception.py b/code/net/bn_inception.py new file mode 100644 index 00000000..bbbc8306 --- /dev/null +++ b/code/net/bn_inception.py @@ -0,0 +1,530 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +import random + +__all__ = ['BNInception', 'bn_inception'] + +""" +Inception v2 was ported from Caffee to pytorch 0.2, see +https://github.com/Cadene/pretrained-models.pytorch. I've ported it to +PyTorch 0.4 for the Proxy-NCA implementation, see +https://github.com/dichotomies/proxy-nca. +""" + +class bn_inception(nn.Module): + def __init__(self, embedding_size, pretrained = True, is_norm=True, bn_freeze = True): + super(bn_inception, self).__init__() + self.model = BNInception(embedding_size, pretrained, is_norm) + if pretrained: +# weight = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-239d2248.pth') + weight = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth') + weight = {k: v.squeeze(0) if v.size(0) == 1 else v for k, v in weight.items()} + self.model.load_state_dict(weight) + + self.model.gap = nn.AdaptiveAvgPool2d(1) + self.model.gmp = nn.AdaptiveMaxPool2d(1) + + self.model.embedding = nn.Linear(self.model.num_ftrs, self.model.embedding_size) + init.kaiming_normal_(self.model.embedding.weight, mode='fan_out') + init.constant_(self.model.embedding.bias, 0) + + if bn_freeze: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad_(False) + m.bias.requires_grad_(False) + + + def forward(self, input): + return self.model.forward(input) + +class BNInception(nn.Module): + + def __init__(self, embedding_size, pretrained = True, is_norm=True): + super(BNInception, self).__init__() + + inplace = True + self.embedding_size = embedding_size + self.num_ftrs = 1024 + + self.is_norm = is_norm + + self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) + self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.conv1_relu_7x7 = nn.ReLU (inplace) + self.pool1_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) + self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.conv2_relu_3x3_reduce = nn.ReLU (inplace) + self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.conv2_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.conv2_relu_3x3 = nn.ReLU (inplace) + self.pool2_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_1x1 = nn.ReLU (inplace) + self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1)) + self.inception_3a_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_3x3 = nn.ReLU (inplace) + self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_pool_proj = nn.ReLU (inplace) + self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_1x1 = nn.ReLU (inplace) + self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_3x3 = nn.ReLU (inplace) + self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_pool_proj = nn.ReLU (inplace) + self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_3c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_3x3 = nn.ReLU (inplace) + 
self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3c_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_1x1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_1x1 = nn.ReLU (inplace) + self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_3x3 = nn.ReLU (inplace) + self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_pool_proj = nn.ReLU (inplace) + self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_1x1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_1x1 = nn.ReLU (inplace) + self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_3x3 = nn.ReLU (inplace) + self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + 
self.inception_4b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_pool_proj = nn.ReLU (inplace) + self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_1x1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_1x1 = nn.ReLU (inplace) + self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_3x3 = nn.ReLU (inplace) + self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4c_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_pool_proj = nn.ReLU (inplace) + self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_1x1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_1x1 = nn.ReLU (inplace) + self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_3x3 = nn.ReLU (inplace) + self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_double_3x3_reduce_bn = 
nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4d_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_pool_proj = nn.ReLU (inplace) + self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_4e_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_3x3 = nn.ReLU (inplace) + self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4e_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_1x1 = nn.ReLU (inplace) + self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_3x3 = nn.ReLU (inplace) + self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_double_3x3_1 = nn.ReLU (inplace) + 
self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_5a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_pool_proj = nn.ReLU (inplace) + self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_1x1 = nn.ReLU (inplace) + self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_3x3 = nn.ReLU (inplace) + self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_5b_pool = nn.MaxPool2d ((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) + self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_pool_proj = nn.ReLU (inplace) + self.global_pool = nn.AvgPool2d(7, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + self.last_linear = nn.Linear(1024, 1000) + + def features(self, input): + conv1_7x7_s2_out = self.conv1_7x7_s2(input) + conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) + conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) + pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_7x7_s2_bn_out) + conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) + conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) + conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) + conv2_3x3_out = self.conv2_3x3(conv2_3x3_reduce_bn_out) + conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) + conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) + pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_3x3_bn_out) + inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) + inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) + inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) + inception_3a_3x3_reduce_out = 
self.inception_3a_3x3_reduce(pool2_3x3_s2_out) + inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) + inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) + inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_3x3_reduce_bn_out) + inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) + inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) + inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) + inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(inception_3a_double_3x3_reduce_out) + inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(inception_3a_double_3x3_reduce_bn_out) + inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_double_3x3_reduce_bn_out) + inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) + inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) + inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_double_3x3_1_bn_out) + inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) + inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) + inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) + inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) + inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) + inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) + inception_3a_output_out = torch.cat([inception_3a_1x1_bn_out,inception_3a_3x3_bn_out,inception_3a_double_3x3_2_bn_out,inception_3a_pool_proj_bn_out], 1) + inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) + inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) + inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) + inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) + inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) + inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) + inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_3x3_reduce_bn_out) + inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) + inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) + inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) + inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(inception_3b_double_3x3_reduce_out) + inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(inception_3b_double_3x3_reduce_bn_out) + inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_double_3x3_reduce_bn_out) + inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) + inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) + inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_double_3x3_1_bn_out) + inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) + 
inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) + inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) + inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) + inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) + inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) + inception_3b_output_out = torch.cat([inception_3b_1x1_bn_out,inception_3b_3x3_bn_out,inception_3b_double_3x3_2_bn_out,inception_3b_pool_proj_bn_out], 1) + inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) + inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) + inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) + inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_3x3_reduce_bn_out) + inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) + inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) + inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) + inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(inception_3c_double_3x3_reduce_out) + inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(inception_3c_double_3x3_reduce_bn_out) + inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_double_3x3_reduce_bn_out) + inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) + inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) + inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_double_3x3_1_bn_out) + inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) + inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) + inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) + inception_3c_output_out = torch.cat([inception_3c_3x3_bn_out,inception_3c_double_3x3_2_bn_out,inception_3c_pool_out], 1) + inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) + inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) + inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) + inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) + inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) + inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) + inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_3x3_reduce_bn_out) + inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) + inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) + inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) + inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(inception_4a_double_3x3_reduce_out) + inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(inception_4a_double_3x3_reduce_bn_out) + inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_double_3x3_reduce_bn_out) + inception_4a_double_3x3_1_bn_out = 
self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) + inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) + inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_double_3x3_1_bn_out) + inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) + inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) + inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) + inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) + inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) + inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) + inception_4a_output_out = torch.cat([inception_4a_1x1_bn_out,inception_4a_3x3_bn_out,inception_4a_double_3x3_2_bn_out,inception_4a_pool_proj_bn_out], 1) + inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) + inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) + inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) + inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) + inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) + inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) + inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_3x3_reduce_bn_out) + inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) + inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) + inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) + inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(inception_4b_double_3x3_reduce_out) + inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(inception_4b_double_3x3_reduce_bn_out) + inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_double_3x3_reduce_bn_out) + inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) + inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) + inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_double_3x3_1_bn_out) + inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) + inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) + inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) + inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) + inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) + inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) + inception_4b_output_out = torch.cat([inception_4b_1x1_bn_out,inception_4b_3x3_bn_out,inception_4b_double_3x3_2_bn_out,inception_4b_pool_proj_bn_out], 1) + inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) + inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) + inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) + inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) + 
inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) + inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) + inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_3x3_reduce_bn_out) + inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) + inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) + inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) + inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(inception_4c_double_3x3_reduce_out) + inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(inception_4c_double_3x3_reduce_bn_out) + inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_double_3x3_reduce_bn_out) + inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) + inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) + inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_double_3x3_1_bn_out) + inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) + inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) + inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) + inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) + inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) + inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) + inception_4c_output_out = torch.cat([inception_4c_1x1_bn_out,inception_4c_3x3_bn_out,inception_4c_double_3x3_2_bn_out,inception_4c_pool_proj_bn_out], 1) + inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) + inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) + inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) + inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) + inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) + inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) + inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_3x3_reduce_bn_out) + inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) + inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) + inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) + inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(inception_4d_double_3x3_reduce_out) + inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(inception_4d_double_3x3_reduce_bn_out) + inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_double_3x3_reduce_bn_out) + inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) + inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) + inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_double_3x3_1_bn_out) + inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) + inception_4d_relu_double_3x3_2_out = 
self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) + inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) + inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) + inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) + inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) + inception_4d_output_out = torch.cat([inception_4d_1x1_bn_out,inception_4d_3x3_bn_out,inception_4d_double_3x3_2_bn_out,inception_4d_pool_proj_bn_out], 1) + inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) + inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) + inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) + inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_3x3_reduce_bn_out) + inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) + inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) + inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) + inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(inception_4e_double_3x3_reduce_out) + inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(inception_4e_double_3x3_reduce_bn_out) + inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_double_3x3_reduce_bn_out) + inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) + inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) + inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_double_3x3_1_bn_out) + inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) + inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) + inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) + inception_4e_output_out = torch.cat([inception_4e_3x3_bn_out,inception_4e_double_3x3_2_bn_out,inception_4e_pool_out], 1) + inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) + inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) + inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) + inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) + inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) + inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) + inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_3x3_reduce_bn_out) + inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) + inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) + inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) + inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(inception_5a_double_3x3_reduce_out) + inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(inception_5a_double_3x3_reduce_bn_out) + inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_double_3x3_reduce_bn_out) + inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) 
+ inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) + inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_double_3x3_1_bn_out) + inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) + inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) + inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) + inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) + inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) + inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) + inception_5a_output_out = torch.cat([inception_5a_1x1_bn_out,inception_5a_3x3_bn_out,inception_5a_double_3x3_2_bn_out,inception_5a_pool_proj_bn_out], 1) + inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) + inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) + inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) + inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out) + inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) + inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) + inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_3x3_reduce_bn_out) + inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) + inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) + inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) + inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(inception_5b_double_3x3_reduce_out) + inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(inception_5b_double_3x3_reduce_bn_out) + inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_double_3x3_reduce_bn_out) + inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) + inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) + inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_double_3x3_1_bn_out) + inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) + inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) + inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) + inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) + inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) + inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) + inception_5b_output_out = torch.cat([inception_5b_1x1_bn_out,inception_5b_3x3_bn_out,inception_5b_double_3x3_2_bn_out,inception_5b_pool_proj_bn_out], 1) + return inception_5b_output_out + + def l2_norm(self,input): + input_size = input.size() + buffer = torch.pow(input, 2) + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + output = _output.view(input_size) + return output + + def forward(self, input): + x = self.features(input) + avg_x = self.gap(x) + max_x = self.gmp(x) + + 
x = avg_x + max_x + x = x.view(x.size(0), -1) + x = self.embedding(x) + + if self.is_norm: + x = self.l2_norm(x) + return x \ No newline at end of file diff --git a/code/net/googlenet.py b/code/net/googlenet.py new file mode 100644 index 00000000..4d9de087 --- /dev/null +++ b/code/net/googlenet.py @@ -0,0 +1,256 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +import torch.nn.init as init +import torch.utils.model_zoo as model_zoo + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import model_zoo + +__all__ = ['GoogLeNet', 'googlenet'] + +model_urls = { + # GoogLeNet ported from TensorFlow + 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', +} + +class GoogLeNet(nn.Module): + + def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): + super(GoogLeNet, self).__init__() + self.aux_logits = aux_logits + self.transform_input = transform_input + + self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.conv2 = BasicConv2d(64, 64, kernel_size=1) + self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) + self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) + self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) + if aux_logits: + self.aux1 = InceptionAux(512, num_classes) + self.aux2 = InceptionAux(528, num_classes) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(0.4) + self.fc = nn.Linear(1024, num_classes) + + if init_weights: + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.2) + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + if self.transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + + x = self.conv1(x) + x = self.maxpool1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.maxpool2(x) + + x = self.inception3a(x) + x = self.inception3b(x) + x = self.maxpool3(x) + x = self.inception4a(x) + if self.training and self.aux_logits: + aux1 = self.aux1(x) + + x = self.inception4b(x) + x = self.inception4c(x) + x = self.inception4d(x) + if self.training and self.aux_logits: + aux2 = self.aux2(x) + + x = self.inception4e(x) + x = self.maxpool4(x) + x = self.inception5a(x) + x = 
self.inception5b(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.dropout(x) + x = self.fc(x) + if self.training and self.aux_logits: + return aux1, aux2, x + return x + + +class Inception(nn.Module): + + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): + super(Inception, self).__init__() + + self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) + + self.branch2 = nn.Sequential( + BasicConv2d(in_channels, ch3x3red, kernel_size=1), + BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) + ) + + self.branch3 = nn.Sequential( + BasicConv2d(in_channels, ch5x5red, kernel_size=1), + BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) + ) + + self.branch4 = nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), + BasicConv2d(in_channels, pool_proj, kernel_size=1) + ) + + def forward(self, x): + branch1 = self.branch1(x) + branch2 = self.branch2(x) + branch3 = self.branch3(x) + branch4 = self.branch4(x) + + outputs = [branch1, branch2, branch3, branch4] + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes): + super(InceptionAux, self).__init__() + self.conv = BasicConv2d(in_channels, 128, kernel_size=1) + + self.fc1 = nn.Linear(2048, 1024) + self.fc2 = nn.Linear(1024, num_classes) + + def forward(self, x): + x = F.adaptive_avg_pool2d(x, (4, 4)) + x = self.conv(x) + x = x.view(x.size(0), -1) + x = F.relu(self.fc1(x), inplace=True) + x = F.dropout(x, 0.7, training=self.training) + x = self.fc2(x) + + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class googlenet(nn.Module): + def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True): + super(googlenet, self).__init__() + + self.model = GoogLeNet() + if pretrained: + self.model.load_state_dict(model_zoo.load_url(model_urls['googlenet']),strict=False) + + self.transform_input=False + self.is_norm = is_norm + self.embedding_size = embedding_size + self.num_ftrs = self.model.fc.in_features + self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size) + self.model.gap = nn.AdaptiveAvgPool2d(1) + self.model.gmp = nn.AdaptiveMaxPool2d(1) + + self._initialize_weights() + + if bn_freeze: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad_(False) + m.bias.requires_grad_(False) + + + def l2_norm(self,input): + input_size = input.size() + buffer = torch.pow(input, 2) + normp = torch.sum(buffer, 1).add_(1e-5) + norm = torch.sqrt(normp) + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + output = _output.view(input_size) + + return output + + def forward(self, x): + if self.transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + + x = self.model.conv1(x) + x = self.model.maxpool1(x) + x = self.model.conv2(x) + x = self.model.conv3(x) + x = self.model.maxpool2(x) + + x = self.model.inception3a(x) + x = self.model.inception3b(x) + x = self.model.maxpool3(x) + x = 
self.model.inception4a(x) + + x = self.model.inception4b(x) + x = self.model.inception4c(x) + x = self.model.inception4d(x) + + x = self.model.inception4e(x) + x = self.model.maxpool4(x) + x = self.model.inception5a(x) + x = self.model.inception5b(x) + + avg_x = self.model.gap(x) + max_x = self.model.gmp(x) + + x = max_x + avg_x + x = x.view(x.size(0), -1) + x = self.model.embedding(x) + + if self.is_norm: + x = self.l2_norm(x) + + self.features = x + + return self.features + + def _initialize_weights(self): + init.kaiming_normal_(self.model.embedding.weight, mode='fan_out') + init.constant_(self.model.embedding.bias, 0) \ No newline at end of file diff --git a/code/net/resnet.py b/code/net/resnet.py new file mode 100644 index 00000000..41a32de3 --- /dev/null +++ b/code/net/resnet.py @@ -0,0 +1,254 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +import torch.nn.init as init +from torchvision.models import resnet18 +from torchvision.models import resnet34 +from torchvision.models import resnet50 +from torchvision.models import resnet101 +import torch.utils.model_zoo as model_zoo + +class Resnet18(nn.Module): + def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True): + super(Resnet18, self).__init__() + + self.model = resnet18(pretrained) + self.is_norm = is_norm + self.embedding_size = embedding_size + self.num_ftrs = self.model.fc.in_features + self.model.gap = nn.AdaptiveAvgPool2d(1) + self.model.gmp = nn.AdaptiveMaxPool2d(1) + + self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size) + self._initialize_weights() + + if bn_freeze: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad_(False) + m.bias.requires_grad_(False) + + def l2_norm(self,input): + input_size = input.size() + buffer = torch.pow(input, 2) + + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + + output = _output.view(input_size) + + return output + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + + avg_x = self.model.gap(x) + max_x = self.model.gmp(x) + + x = max_x + avg_x + + x = x.view(x.size(0), -1) + x = self.model.embedding(x) + + if self.is_norm: + x = self.l2_norm(x) + + return x + + def _initialize_weights(self): + init.kaiming_normal_(self.model.embedding.weight, mode='fan_out') + init.constant_(self.model.embedding.bias, 0) + +class Resnet34(nn.Module): + def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True): + super(Resnet34, self).__init__() + + self.model = resnet34(pretrained) + self.is_norm = is_norm + self.embedding_size = embedding_size + self.num_ftrs = self.model.fc.in_features + self.model.gap = nn.AdaptiveAvgPool2d(1) + self.model.gmp = nn.AdaptiveMaxPool2d(1) + + self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size) + self._initialize_weights() + + if bn_freeze: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad_(False) + m.bias.requires_grad_(False) + + def l2_norm(self,input): + input_size = input.size() + buffer = torch.pow(input, 2) + + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + + _output = torch.div(input, norm.view(-1, 
1).expand_as(input)) + + output = _output.view(input_size) + + return output + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + + avg_x = self.model.gap(x) + max_x = self.model.gmp(x) + + x = avg_x + max_x + + x = x.view(x.size(0), -1) + x = self.model.embedding(x) + + if self.is_norm: + x = self.l2_norm(x) + + return x + + def _initialize_weights(self): + init.kaiming_normal_(self.model.embedding.weight, mode='fan_out') + init.constant_(self.model.embedding.bias, 0) + +class Resnet50(nn.Module): + def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True): + super(Resnet50, self).__init__() + + self.model = resnet50(pretrained) + self.is_norm = is_norm + self.embedding_size = embedding_size + self.num_ftrs = self.model.fc.in_features + self.model.gap = nn.AdaptiveAvgPool2d(1) + self.model.gmp = nn.AdaptiveMaxPool2d(1) + + self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size) + self._initialize_weights() + + if bn_freeze: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad_(False) + m.bias.requires_grad_(False) + + def l2_norm(self,input): + input_size = input.size() + buffer = torch.pow(input, 2) + + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + + output = _output.view(input_size) + + return output + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + + avg_x = self.model.gap(x) + max_x = self.model.gmp(x) + + x = max_x + avg_x + x = x.view(x.size(0), -1) + x = self.model.embedding(x) + + if self.is_norm: + x = self.l2_norm(x) + + return x + + def _initialize_weights(self): + init.kaiming_normal_(self.model.embedding.weight, mode='fan_out') + init.constant_(self.model.embedding.bias, 0) + +class Resnet101(nn.Module): + def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True): + super(Resnet101, self).__init__() + + self.model = resnet101(pretrained) + self.is_norm = is_norm + self.embedding_size = embedding_size + self.num_ftrs = self.model.fc.in_features + self.model.gap = nn.AdaptiveAvgPool2d(1) + self.model.gmp = nn.AdaptiveMaxPool2d(1) + + self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size) + self._initialize_weights() + + if bn_freeze: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + m.weight.requires_grad_(False) + m.bias.requires_grad_(False) + + def l2_norm(self,input): + input_size = input.size() + buffer = torch.pow(input, 2) + + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + + output = _output.view(input_size) + + return output + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + + avg_x = self.model.gap(x) + max_x = self.model.gmp(x) + + x = max_x + avg_x + x = x.view(x.size(0), -1) + x = self.model.embedding(x) + + if self.is_norm: + x = self.l2_norm(x) + + return x + + def 
_initialize_weights(self): + init.kaiming_normal_(self.model.embedding.weight, mode='fan_out') + init.constant_(self.model.embedding.bias, 0) \ No newline at end of file diff --git a/code/train.py b/code/train.py new file mode 100644 index 00000000..cafa428b --- /dev/null +++ b/code/train.py @@ -0,0 +1,355 @@ +import torch, math, time, argparse, os +import random, dataset, utils, losses, net +import numpy as np +import matplotlib.pyplot as plt + +from dataset.Inshop import Inshop_Dataset +from net.resnet import * +from net.googlenet import * +from net.bn_inception import * +from dataset import sampler +from torch.utils.data.sampler import BatchSampler +from torch.utils.data.dataloader import default_collate + +from tqdm import * +import wandb + +seed = 1 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) # set random seed for all gpus + +parser = argparse.ArgumentParser(description= + 'Official implementation of `Proxy Anchor Loss for Deep Metric Learning`' + + 'Our code is modified from `https://github.com/dichotomies/proxy-nca`' +) +# export directory, training and val datasets, test datasets +parser.add_argument('--LOG_DIR', + default='../logs', + help = 'Path to log folder' +) +parser.add_argument('--dataset', + default='cub', + help = 'Training dataset, e.g. cub, cars, SOP, Inshop' +) +parser.add_argument('--embedding-size', default = 512, type = int, + dest = 'sz_embedding', + help = 'Size of embedding that is appended to backbone model.' +) +parser.add_argument('--batch-size', default = 150, type = int, + dest = 'sz_batch', + help = 'Number of samples per batch.' +) +parser.add_argument('--epochs', default = 60, type = int, + dest = 'nb_epochs', + help = 'Number of training epochs.' +) +parser.add_argument('--gpu-id', default = 0, type = int, + help = 'ID of GPU that is used for training.' +) +parser.add_argument('--workers', default = 4, type = int, + dest = 'nb_workers', + help = 'Number of workers for dataloader.' 
+) +parser.add_argument('--model', default = 'bn_inception', + help = 'Model for training' +) +parser.add_argument('--loss', default = 'Proxy_Anchor', + help = 'Criterion for training' +) +parser.add_argument('--optimizer', default = 'adamw', + help = 'Optimizer setting' +) +parser.add_argument('--lr', default = 1e-4, type =float, + help = 'Learning rate setting' +) +parser.add_argument('--weight-decay', default = 1e-4, type =float, + help = 'Weight decay setting' +) +parser.add_argument('--lr-decay-step', default = 10, type =int, + help = 'Learning decay step setting' +) +parser.add_argument('--lr-decay-gamma', default = 0.5, type =float, + help = 'Learning decay gamma setting' +) +parser.add_argument('--alpha', default = 32, type = float, + help = 'Scaling Parameter setting' +) +parser.add_argument('--mrg', default = 0.1, type = float, + help = 'Margin parameter setting' +) +parser.add_argument('--IPC', type = int, + help = 'Balanced sampling, images per class' +) +parser.add_argument('--warm', default = 1, type = int, + help = 'Warmup training epochs' +) +parser.add_argument('--bn-freeze', default = 1, type = int, + help = 'Batch normalization parameter freeze' +) +parser.add_argument('--l2-norm', default = 1, type = int, + help = 'L2 normlization' +) +parser.add_argument('--remark', default = '', + help = 'Any reamrk' +) + +args = parser.parse_args() + +if args.gpu_id != -1: + torch.cuda.set_device(args.gpu_id) + +# Directory for Log +LOG_DIR = args.LOG_DIR + '/logs_{}/{}_{}_embedding{}_alpha{}_mrg{}_{}_lr{}_batch{}{}'.format(args.dataset, args.model, args.loss, args.sz_embedding, args.alpha, + args.mrg, args.optimizer, args.lr, args.sz_batch, args.remark) +# Wandb Initialization +wandb.init(project=args.dataset + '_ProxyAnchor', notes=LOG_DIR) +wandb.config.update(args) + +os.chdir('../data/') +data_root = os.getcwd() +# Dataset Loader and Sampler +if args.dataset != 'Inshop': + trn_dataset = dataset.load( + name = args.dataset, + root = data_root, + mode = 'train', + transform = dataset.utils.make_transform( + is_train = True, + is_inception = (args.model == 'bn_inception') + )) +else: + trn_dataset = Inshop_Dataset( + root = data_root, + mode = 'train', + transform = dataset.utils.make_transform( + is_train = True, + is_inception = (args.model == 'bn_inception') + )) + +if args.IPC: + balanced_sampler = sampler.BalancedSampler(trn_dataset, batch_size=args.sz_batch, images_per_class = args.IPC) + batch_sampler = BatchSampler(balanced_sampler, batch_size = args.sz_batch, drop_last = True) + dl_tr = torch.utils.data.DataLoader( + trn_dataset, + num_workers = args.nb_workers, + pin_memory = True, + batch_sampler = batch_sampler + ) + print('Balanced Sampling') + +else: + dl_tr = torch.utils.data.DataLoader( + trn_dataset, + batch_size = args.sz_batch, + shuffle = True, + num_workers = args.nb_workers, + drop_last = True, + pin_memory = True + ) + print('Random Sampling') + +if args.dataset != 'Inshop': + ev_dataset = dataset.load( + name = args.dataset, + root = data_root, + mode = 'eval', + transform = dataset.utils.make_transform( + is_train = False, + is_inception = (args.model == 'bn_inception') + )) + + dl_ev = torch.utils.data.DataLoader( + ev_dataset, + batch_size = args.sz_batch, + shuffle = False, + num_workers = args.nb_workers, + pin_memory = True + ) + +else: + query_dataset = Inshop_Dataset( + root = data_root, + mode = 'query', + transform = dataset.utils.make_transform( + is_train = False, + is_inception = (args.model == 'bn_inception') + )) + + dl_query = 
torch.utils.data.DataLoader( + query_dataset, + batch_size = args.sz_batch, + shuffle = False, + num_workers = args.nb_workers, + pin_memory = True + ) + + gallery_dataset = Inshop_Dataset( + root = data_root, + mode = 'gallery', + transform = dataset.utils.make_transform( + is_train = False, + is_inception = (args.model == 'bn_inception') + )) + + dl_gallery = torch.utils.data.DataLoader( + gallery_dataset, + batch_size = args.sz_batch, + shuffle = False, + num_workers = args.nb_workers, + pin_memory = True + ) + +nb_classes = trn_dataset.nb_classes() + +# Backbone Model +if args.model.find('googlenet')+1: + model = googlenet(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = args.bn_freeze) +elif args.model.find('bn_inception')+1: + model = bn_inception(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = args.bn_freeze) +elif args.model.find('resnet18')+1: + model = Resnet18(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = args.bn_freeze) +elif args.model.find('resnet50')+1: + model = Resnet50(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = args.bn_freeze) +elif args.model.find('resnet101')+1: + model = Resnet101(embedding_size=args.sz_embedding, pretrained=True, is_norm=args.l2_norm, bn_freeze = args.bn_freeze) +model = model.cuda() + +if args.gpu_id == -1: + model = nn.DataParallel(model) + +# DML Losses +if args.loss == 'Proxy_Anchor': + criterion = losses.Proxy_Anchor(nb_classes = nb_classes, sz_embed = args.sz_embedding, mrg = args.mrg, alpha = args.alpha).cuda() +elif args.loss == 'Proxy_NCA': + criterion = losses.Proxy_NCA().cuda() +elif args.loss == 'MS': + criterion = losses.MultiSimilarityLoss().cuda() +elif args.loss == 'Contrastive': + criterion = losses.ContrastiveLoss().cuda() +elif args.loss == 'Triplet': + criterion = losses.TripletLoss().cuda() +elif args.loss == 'NPair': + criterion = losses.NPairLoss().cuda() + +# Train Parameters +param_groups = [ + {'params': list(set(model.parameters()).difference(set(model.model.embedding.parameters()))) if args.gpu_id != -1 else + list(set(model.module.parameters()).difference(set(model.module.model.embedding.parameters())))}, + {'params': model.model.embedding.parameters() if args.gpu_id != -1 else model.module.model.embedding.parameters(), 'lr':float(args.lr) * 1}, +] +if args.loss == 'Proxy_Anchor': + param_groups.append({'params': criterion.proxies, 'lr':float(args.lr) * 100}) + +# Optimizer Setting +if args.optimizer == 'sgd': + opt = torch.optim.SGD(param_groups, lr=float(args.lr), weight_decay = args.weight_decay, momentum = 0.9, nesterov=True) +elif args.optimizer == 'adam': + opt = torch.optim.Adam(param_groups, lr=float(args.lr), weight_decay = args.weight_decay) +elif args.optimizer == 'rmsprop': + opt = torch.optim.RMSprop(param_groups, lr=float(args.lr), alpha=0.9, weight_decay = args.weight_decay, momentum = 0.9) +elif args.optimizer == 'adamw': + opt = torch.optim.AdamW(param_groups, lr=float(args.lr), weight_decay = args.weight_decay) + +scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=args.lr_decay_step, gamma = args.lr_decay_gamma) + +print("Training parameters: {}".format(vars(args))) +print("Training for {} epochs.".format(args.nb_epochs)) +losses_list = [] +best_recall=[0] +best_epoch = 0 + +for epoch in range(0, args.nb_epochs): + model.train() + bn_freeze = args.bn_freeze + if bn_freeze: + modules = model.model.modules() if args.gpu_id != -1 else 
model.module.model.modules() + for m in modules: + if isinstance(m, nn.BatchNorm2d): + m.eval() + + losses_per_epoch = [] + + # Warmup: Train only new params, helps stabilize learning. + if args.warm > 0: + if args.gpu_id != -1: + unfreeze_model_param = list(model.model.embedding.parameters()) + list(criterion.parameters()) + else: + unfreeze_model_param = list(model.module.model.embedding.parameters()) + list(criterion.parameters()) + + if epoch == 0: + for param in list(set(model.parameters()).difference(set(unfreeze_model_param))): + param.requires_grad = False + if epoch == args.warm: + for param in list(set(model.parameters()).difference(set(unfreeze_model_param))): + param.requires_grad = True + + pbar = tqdm(enumerate(dl_tr)) + + for batch_idx, (x, y) in pbar: + m = model(x.squeeze().cuda()) + loss = criterion(m, y.squeeze().cuda()) + + opt.zero_grad() + loss.backward() + + torch.nn.utils.clip_grad_value_(model.parameters(), 10) + if args.loss == 'Proxy_Anchor': + torch.nn.utils.clip_grad_value_(criterion.parameters(), 10) + + losses_per_epoch.append(loss.data.cpu().numpy()) + opt.step() + + pbar.set_description( + 'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( + epoch, batch_idx + 1, len(dl_tr), + 100. * batch_idx / len(dl_tr), + loss.item())) + + losses_list.append(np.mean(losses_per_epoch)) + wandb.log({'loss': losses_list[-1]}, step=epoch) + scheduler.step() + + if(epoch >= 0): + with torch.no_grad(): + print("**Evaluating...**") + if args.dataset == 'Inshop': + NMI = 0 + Recalls = utils.evaluate_cos_Inshop(model, dl_query, dl_gallery) + elif args.dataset != 'SOP': + Recalls = utils.evaluate_cos(model, dl_ev) + else: + Recalls = utils.evaluate_cos_SOP(model, dl_ev) + + # Logging Evaluation Score + if args.dataset == 'Inshop': + for i, K in enumerate([1,10,20,30,40,50]): + wandb.log({"R@{}".format(K): Recalls[i]}, step=epoch) + elif args.dataset != 'SOP': + for i in range(6): + wandb.log({"R@{}".format(2**i): Recalls[i]}, step=epoch) + else: + for i in range(4): + wandb.log({"R@{}".format(10**i): Recalls[i]}, step=epoch) + + # Best model save + if best_recall[0] < Recalls[0]: + best_recall = Recalls + best_epoch = epoch + if not os.path.exists('{}'.format(LOG_DIR)): + os.makedirs('{}'.format(LOG_DIR)) + torch.save({'model_state_dict':model.state_dict()}, '{}/{}_{}_best.pth'.format(LOG_DIR, args.dataset, args.model)) + with open('{}/{}_{}_best_results.txt'.format(LOG_DIR, args.dataset, args.model), 'w') as f: + f.write('Best Epoch: {}\n'.format(best_epoch)) + if args.dataset == 'Inshop': + for i, K in enumerate([1,10,20,30,40,50]): + f.write("Best Recall@{}: {:.4f}\n".format(K, best_recall[i] * 100)) + elif args.dataset != 'SOP': + for i in range(6): + f.write("Best Recall@{}: {:.4f}\n".format(2**i, best_recall[i] * 100)) + else: + for i in range(4): + f.write("Best Recall@{}: {:.4f}\n".format(10**i, best_recall[i] * 100)) + + \ No newline at end of file diff --git a/code/utils.py b/code/utils.py new file mode 100644 index 00000000..0e275f39 --- /dev/null +++ b/code/utils.py @@ -0,0 +1,166 @@ +import numpy as np +import torch +import logging +import losses +import json +from tqdm import tqdm +import torch.nn.functional as F +import math + +def l2_norm(input): + input_size = input.size() + buffer = torch.pow(input, 2) + normp = torch.sum(buffer, 1).add_(1e-12) + norm = torch.sqrt(normp) + _output = torch.div(input, norm.view(-1, 1).expand_as(input)) + output = _output.view(input_size) + + return output + +def calc_recall_at_k(T, Y, k): + """ + T : [nb_samples] (target 
labels) + Y : [nb_samples x k] (k predicted labels/neighbours) + """ + + s = 0 + for t,y in zip(T,Y): + if t in torch.Tensor(y).long()[:k]: + s += 1 + return s / (1. * len(T)) + + +def predict_batchwise(model, dataloader): + device = "cuda" + model_is_training = model.training + model.eval() + + ds = dataloader.dataset + A = [[] for i in range(len(ds[0]))] + with torch.no_grad(): + # extract batches (A becomes list of samples) + for batch in tqdm(dataloader): + for i, J in enumerate(batch): + # i = 0: sz_batch * images + # i = 1: sz_batch * labels + # i = 2: sz_batch * indices + if i == 0: + # move images to device of model (approximate device) + J = model(J.cuda()) + + for j in J: + A[i].append(j) + model.train() + model.train(model_is_training) # revert to previous training state + + return [torch.stack(A[i]) for i in range(len(A))] + +def proxy_init_calc(model, dataloader): + nb_classes = dataloader.dataset.nb_classes() + X, T, *_ = predict_batchwise(model, dataloader) + + proxy_mean = torch.stack([X[T==class_idx].mean(0) for class_idx in range(nb_classes)]) + + return proxy_mean + +def evaluate_cos(model, dataloader): + nb_classes = dataloader.dataset.nb_classes() + + # calculate embeddings with model and get targets + X, T = predict_batchwise(model, dataloader) + X = l2_norm(X) + + # get predictions by assigning nearest 8 neighbors with cosine + K = 32 + Y = [] + xs = [] + + cos_sim = F.linear(X, X) + Y = T[cos_sim.topk(1 + K)[1][:,1:]] + Y = Y.float().cpu() + + recall = [] + for k in [1, 2, 4, 8, 16, 32]: + r_at_k = calc_recall_at_k(T, Y, k) + recall.append(r_at_k) + print("R@{} : {:.3f}".format(k, 100 * r_at_k)) + + return recall + +def evaluate_cos_Inshop(model, query_dataloader, gallery_dataloader): + nb_classes = query_dataloader.dataset.nb_classes() + + # calculate embeddings with model and get targets + query_X, query_T = predict_batchwise(model, query_dataloader) + gallery_X, gallery_T = predict_batchwise(model, gallery_dataloader) + + query_X = l2_norm(query_X) + gallery_X = l2_norm(gallery_X) + + # get predictions by assigning nearest 8 neighbors with cosine + K = 50 + Y = [] + xs = [] + + cos_sim = F.linear(query_X, gallery_X) + + def recall_k(cos_sim, query_T, gallery_T, k): + m = len(cos_sim) + match_counter = 0 + + for i in range(m): + pos_sim = cos_sim[i][gallery_T == query_T[i]] + neg_sim = cos_sim[i][gallery_T != query_T[i]] + + thresh = torch.max(pos_sim).item() + + if torch.sum(neg_sim > thresh) < k: + match_counter += 1 + + return match_counter / m + + # calculate recall @ 1, 2, 4, 8 + recall = [] + for k in [1, 10, 20, 30, 40, 50]: + r_at_k = recall_k(cos_sim, query_T, gallery_T, k) + recall.append(r_at_k) + print("R@{} : {:.3f}".format(k, 100 * r_at_k)) + + return recall + +def evaluate_cos_SOP(model, dataloader): + nb_classes = dataloader.dataset.nb_classes() + + # calculate embeddings with model and get targets + X, T = predict_batchwise(model, dataloader) + X = l2_norm(X) + + # get predictions by assigning nearest 8 neighbors with cosine + K = 1000 + Y = [] + xs = [] + for x in X: + if len(xs)<10000: + xs.append(x) + else: + xs.append(x) + xs = torch.stack(xs,dim=0) + cos_sim = F.linear(xs,X) + y = T[cos_sim.topk(1 + K)[1][:,1:]] + Y.append(y.float().cpu()) + xs = [] + + # Last Loop + xs = torch.stack(xs,dim=0) + cos_sim = F.linear(xs,X) + y = T[cos_sim.topk(1 + K)[1][:,1:]] + Y.append(y.float().cpu()) + Y = torch.cat(Y, dim=0) + + # calculate recall @ 1, 2, 4, 8 + recall = [] + for k in [1, 10, 100, 1000]: + r_at_k = calc_recall_at_k(T, Y, k) + 
recall.append(r_at_k)
+        print("R@{} : {:.3f}".format(k, 100 * r_at_k))
+    return recall
diff --git a/misc/Recall_Trainingtime.jpg b/misc/Recall_Trainingtime.jpg
new file mode 100644
index 00000000..afd7cbf8
Binary files /dev/null and b/misc/Recall_Trainingtime.jpg differ
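
All of the backbone wrappers added above (`Resnet18/34/50/101`, `googlenet`, `bn_inception`) share the same embedding head: the trunk's final feature map is both average-pooled and max-pooled, the two pooled vectors are summed, projected by a new linear embedding layer, and (when `--l2-norm 1`) L2-normalized. The sketch below shows that head on top of torchvision's ResNet-50 as a minimal, standalone illustration; `EmbeddingNet` and the demo shapes are not part of the repository, and the `pretrained` flag follows the older torchvision API the repo itself uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

class EmbeddingNet(nn.Module):
    """Sketch of the GAP + GMP embedding head used by the wrappers in net/."""
    def __init__(self, embedding_size=512, pretrained=False, is_norm=True):
        super().__init__()
        self.backbone = resnet50(pretrained=pretrained)
        num_ftrs = self.backbone.fc.in_features   # 2048 for ResNet-50
        self.backbone.fc = nn.Identity()           # classifier is unused; only the trunk is needed
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gmp = nn.AdaptiveMaxPool2d(1)
        self.embedding = nn.Linear(num_ftrs, embedding_size)
        self.is_norm = is_norm

    def forward(self, x):
        # convolutional trunk only (everything before pooling / fc)
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        x = self.gap(x) + self.gmp(x)               # sum of average- and max-pooled features
        x = x.flatten(1)
        x = self.embedding(x)
        if self.is_norm:
            x = F.normalize(x, p=2, dim=1)          # unit-length embeddings
        return x

if __name__ == "__main__":
    net = EmbeddingNet(embedding_size=512).eval()
    with torch.no_grad():
        emb = net(torch.randn(4, 3, 224, 224))
    print(emb.shape)        # torch.Size([4, 512])
    print(emb.norm(dim=1))  # close to 1.0 per row
```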
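
Evaluation in `utils.py` operates on these unit-length embeddings: cosine similarity is computed as `F.linear(X, X)`, the query itself (the most similar column) is dropped, and Recall@k is the fraction of queries with at least one same-class sample among their k nearest neighbours. The retrieval depth varies by dataset: up to 32 neighbours for CUB/Cars (`evaluate_cos`), 50 for In-shop (`evaluate_cos_Inshop`), and 1000 for SOP, where `evaluate_cos_SOP` processes queries in chunks of 10,000 to keep the similarity matrix in memory. A self-contained sketch of the Recall@k computation on random data follows; it mirrors the logic of `calc_recall_at_k` but is not the repository's exact code.

```python
import torch
import torch.nn.functional as F

def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    """Fraction of queries with a same-class sample among their k nearest neighbours."""
    X = F.normalize(embeddings, p=2, dim=1)
    cos_sim = X @ X.t()                            # equivalent to F.linear(X, X) in utils.py
    max_k = max(ks)
    # take the (1 + max_k) highest similarities; column 0 is the query itself, so drop it
    knn_idx = cos_sim.topk(1 + max_k, dim=1)[1][:, 1:]
    knn_labels = labels[knn_idx]                   # [N, max_k] labels of the retrieved neighbours
    recalls = {}
    for k in ks:
        hits = (knn_labels[:, :k] == labels.unsqueeze(1)).any(dim=1)
        recalls[k] = hits.float().mean().item()
    return recalls

if __name__ == "__main__":
    torch.manual_seed(0)
    emb = torch.randn(100, 512)                    # stand-in for network embeddings
    lab = torch.randint(0, 10, (100,))
    print(recall_at_k(emb, lab))                   # {1: ..., 2: ..., 4: ..., 8: ...}
```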
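
The training script in `code/train.py` combines three ingredients: BatchNorm layers are kept in eval mode and their affine parameters frozen when `--bn-freeze 1`, only the new embedding layer and the loss parameters (the Proxy-Anchor proxies) are trained for the first `--warm` epochs, and the proxies receive a 100x larger learning rate in the AdamW parameter groups. The sketch below compresses that schedule into two helpers; it assumes a `model.embedding` head and a `criterion` module holding the proxies, and the `Dummy` objects are hypothetical stand-ins added only to make the snippet runnable.

```python
import torch
import torch.nn as nn

def build_optimizer(model, criterion, lr=1e-4, weight_decay=1e-4):
    """Backbone at the base lr, loss parameters (proxies) at 100x lr, as in train.py."""
    embedding_params = list(model.embedding.parameters())
    backbone_params = [p for p in model.parameters()
                       if all(p is not q for q in embedding_params)]
    param_groups = [
        {'params': backbone_params},
        {'params': embedding_params, 'lr': lr},
        {'params': criterion.parameters(), 'lr': lr * 100},  # proxies learn faster
    ]
    return torch.optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay)

def set_warmup(model, criterion, epoch, warm_epochs):
    """During the first `warm_epochs` epochs only the embedding layer and the proxies train."""
    new_params = set(model.embedding.parameters()) | set(criterion.parameters())
    freeze_backbone = epoch < warm_epochs
    for p in model.parameters():
        if p not in new_params:
            p.requires_grad_(not freeze_backbone)

if __name__ == "__main__":
    class Dummy(nn.Module):                        # hypothetical stand-in for the real network
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 8)
            self.embedding = nn.Linear(8, 4)
    model, criterion = Dummy(), nn.Linear(4, 10)   # criterion stands in for the proxy-based loss
    opt = build_optimizer(model, criterion)
    set_warmup(model, criterion, epoch=0, warm_epochs=1)  # backbone frozen
    set_warmup(model, criterion, epoch=1, warm_epochs=1)  # backbone unfrozen
```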