-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathadv_test.py
124 lines (97 loc) · 3.87 KB
/
adv_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""PyTorch Inference Script
An evaluation script that reports clean and L-inf PGD adversarial top-1 accuracy of a model on an ImageFolder validation set.
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from json import load
import os
import time
import argparse
import logging
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from advertorch.attacks import LinfPGDAttack
from tqdm import tqdm
from collections import OrderedDict
from PIL import Image
import robust_models
from timm.models import create_model
# Enable cuDNN benchmark mode (see the torch.backends.cudnn documentation).
torch.backends.cudnn.benchmark = True

# Command-line interface. Help strings use %(default)s so the documented
# default can never drift from the actual default value.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('--model', '-m', metavar='MODEL', default='resnet50',
                    help='model architecture (default: %(default)s)')
parser.add_argument('--ckpt_path', default='', type=str,
                    help='path to the checkpoint (state dict) loaded into the model')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--num-gpu', type=int, default=8,
                    help='Number of GPUS to use (default: %(default)s)')
def normalize_fn(tensor, mean, std):
    """Return ``(tensor - mean) / std`` with per-channel ``mean`` and ``std``.

    Differentiable counterpart of torchvision.functional.normalize. The
    colour channel is assumed to sit at dim=1 of ``tensor``; ``mean`` and
    ``std`` are indexed as 1-D vectors and expanded to ``(1, C, 1, 1)`` so
    they broadcast against it.
    """
    # Index that inserts singleton axes around the channel axis.
    channel_only = (None, slice(None), None, None)
    centered = tensor - mean[channel_only]
    return centered / std[channel_only]
class NormalizeByChannelMeanStd(nn.Module):
    """Module wrapper around ``normalize_fn`` holding ``mean``/``std`` buffers.

    ``mean`` and ``std`` may be tensors or anything ``torch.tensor`` accepts;
    they are registered as buffers named ``"mean"`` and ``"std"``.
    """

    def __init__(self, mean, std):
        super().__init__()
        for buffer_name, value in (("mean", mean), ("std", std)):
            # Tensors are stored as given; other inputs are converted first.
            if not isinstance(value, torch.Tensor):
                value = torch.tensor(value)
            self.register_buffer(buffer_name, value)

    def forward(self, tensor):
        """Normalize ``tensor`` with the stored per-channel statistics."""
        return normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        """Return the mean/std summary shown in the module's repr."""
        return 'mean={}, std={}'.format(self.mean, self.std)
def main():
    """Evaluate clean and PGD-adversarial top-1 accuracy on the validation set.

    Builds the timm model named by ``--model``, loads the state dict at
    ``--ckpt_path``, prepends input normalization, and iterates over the
    ImageFolder validation set. For every batch it counts correct
    predictions on the clean images and on images perturbed by an
    untargeted L-inf PGD attack, then prints both accuracies.
    """
    args = parser.parse_args()

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=1000
    )
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # Loading onto CPU avoids depending on the device the checkpoint was
    # saved from; the model is moved to the GPU(s) further below.
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Normalization lives inside the model so the data loader yields raw
    # [0, 1] pixels, matching the attack's clip_min/clip_max range.
    normalize = NormalizeByChannelMeanStd(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    model = nn.Sequential(normalize, model)

    dataset = datasets.ImageFolder('/data/ILSVRC2012/ILSVRC/Data/CLS-LOC/val', transform=transforms.Compose([
        transforms.Resize(249, interpolation=Image.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ]))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=8, pin_memory=True, drop_last=False)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()
    model.eval()

    total_num = 0
    correct_num = 0
    adv_num = 0
    adversary = LinfPGDAttack(
        model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=4/255,
        nb_iter=3, eps_iter=8/255/3, rand_init=True, clip_min=0.0, clip_max=1.0,
        targeted=False)

    for images, targets in tqdm(loader):
        images = images.cuda()
        targets = targets.cuda()
        total_num += images.shape[0]

        # Clean accuracy: no gradients are needed, so skip building the graph.
        with torch.no_grad():
            logits = model(images)
        correct_num += logits.max(1)[1].eq(targets).sum().item()

        # The attack needs gradients w.r.t. the input, so it must run
        # outside of no_grad.
        adv_untargeted = adversary.perturb(images, targets)
        with torch.no_grad():
            logits_adv = model(adv_untargeted.detach())
        adv_num += logits_adv.max(1)[1].eq(targets).sum().item()

    print('acc: {} adv_acc:{} '.format(correct_num/total_num, adv_num/total_num))


if __name__ == '__main__':
    main()