-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathattack_convergence.py
200 lines (163 loc) · 8.87 KB
/
attack_convergence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import logging
from multiprocessing.sharedctypes import Value
from pathlib import Path
import numpy as np
import timm
from timm.bits import initialize_device
from timm.data import create_dataset, create_loader_v2, resolve_data_config
from timm.models import apply_test_time_pool
from timm.utils import setup_default_logging
from torch import nn
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from src import attacks, utils
from src.random import random_seed
from validate import parser
_logger = logging.getLogger('validate')

# Convergence-experiment options layered on top of the parser imported from validate.py.
# Each entry is (flag, add_argument keyword options); order is preserved so that the
# generated --help output lists them in the same order as before.
_CONVERGENCE_ARGUMENTS = (
    ('--runs', {'type': int, 'default': 20, 'metavar': 'N', 'help': 'The number of runs'}),
    ('--n-points', {'type': int, 'default': 100, 'metavar': 'N', 'help': 'The number of points'}),
    ('--output-file', {'type': str, 'default': None, 'metavar': 'N', 'help': 'The output file'}),
    ('--steps-to-try', {
        'type': int,
        'nargs': "+",
        'default': (0, 1, 2, 5, 10, 50, 100, 200, 500),
        'metavar': 'X Y Z',
        'help': 'The number of steps to try',
    }),
    ('--one-instance', {
        'action': 'store_true',
        'help': 'Run only one instance and save the losses at each step',
    }),
)

for _flag, _options in _CONVERGENCE_ARGUMENTS:
    parser.add_argument(_flag, **_options)
# Do not leak the loop variables into the module namespace.
del _flag, _options
def _collect_correctly_classified(model, loader, n_points, batch_size, dev_env):
    """Gather the first ``n_points`` samples that ``model`` classifies correctly.

    Args:
        model: Classifier already placed on ``dev_env``'s device and in eval mode.
        loader: Evaluation loader yielding ``(sample, target)`` batches. The dataset id of
            a sample is reconstructed as ``batch_idx * batch_size + position``, which
            presumes the loader does not shuffle — TODO confirm for ``create_loader_v2``.
        n_points: Number of correctly classified samples to return.
        batch_size: Batch size used by ``loader`` (needed to reconstruct dataset ids).
        dev_env: Device environment; its ``to_cpu`` moves kept tensors off the device.

    Returns:
        Tuple ``(samples, targets, ids)`` of CPU tensors, each with exactly ``n_points`` rows.

    Raises:
        ValueError: If ``loader`` is exhausted before ``n_points`` correctly classified
            samples were found.
    """
    samples, targets, ids = [], [], []
    n_collected = 0
    # Only predictions are needed here; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for batch_idx, (sample, target) in enumerate(loader):
            accuracy_mask = model(sample).argmax(-1).eq(target)
            _logger.info(f"Batch {batch_idx} accuracy: {accuracy_mask.float().mean().item()}")
            batch_ids = accuracy_mask.nonzero().flatten() + batch_idx * batch_size
            samples.append(dev_env.to_cpu(sample[accuracy_mask]))
            targets.append(dev_env.to_cpu(target[accuracy_mask]))
            ids.append(dev_env.to_cpu(batch_ids))
            # Keep a running count rather than re-concatenating everything each batch.
            n_collected += samples[-1].shape[0]
            if n_collected >= n_points:
                break
    # Checked on the sample count (not the number of collected batches).
    if n_collected < n_points:
        raise ValueError("Impossible to have enough correctly classified samples.")
    return (torch.cat(samples)[:n_points],
            torch.cat(targets)[:n_points],
            torch.cat(ids)[:n_points])


def main():
    """Run the attack-convergence experiment and write per-point losses to a CSV file.

    For the first ``--n-points`` samples the model classifies correctly, the attack chosen
    by ``--attack`` is run ``--runs`` times (seeded with the run index) for every step
    budget in ``--steps-to-try``. The per-sample cross-entropy loss of the model on each
    adversarial example is written through ``utils.GCSSummaryCsv`` to ``--output-file``
    (default ``<model>.csv``) as rows with keys ``point``, ``seed``, ``steps`` and ``loss``.

    With ``--one-instance`` only the largest step budget is used, points are attacked one
    at a time, and one row is written per loss in the sequence "intermediate losses
    returned by the attack, followed by the final loss"; ``steps`` is then the index in
    that sequence.

    Raises:
        ValueError: If ``--n-points`` is not a multiple of ``--batch-size``, or if the
            dataset does not contain enough correctly classified samples.
        RuntimeError: If a selected point is not classified correctly when it is attacked,
            or (XLA only) if the model state differs from the original after restoring
            batchnorm statistics.
    """
    setup_default_logging()
    args = parser.parse_args()

    # Validate/normalize experiment options before any expensive setup.
    if args.one_instance:
        args.steps_to_try = [max(args.steps_to_try)]
    if args.n_points % args.batch_size != 0:
        raise ValueError(f"n_points ({args.n_points}) must be a multiple of batch_size ({args.batch_size})")

    dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp)
    random_seed(args.seed, dev_env.global_rank)

    if args.output_file is None:
        args.output_file = f"{args.model}.csv"
    output_path = Path(args.output_file)
    csv_writer = utils.GCSSummaryCsv(output_path.parent, filename=output_path.name)

    model = timm.create_model(args.model, pretrained=not args.checkpoint, checkpoint_path=args.checkpoint)
    # reduction='none' keeps one loss per sample, so every point gets its own CSV row.
    criterion = dev_env.to_device(nn.CrossEntropyLoss(reduction='none'))
    # attack_eps is divided by 255 (presumably given on the 0-255 pixel scale — confirm).
    eps = args.attack_eps / 255
    lr = args.attack_lr or (1.5 * eps / args.attack_steps)
    attack_criterion = nn.NLLLoss(reduction="sum")

    dataset = create_dataset(root=args.data, name=args.dataset, split=args.split)
    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    # When the model normalizes its own inputs, the data pipeline must not normalize too.
    data_config['normalize'] = not (args.no_normalize or args.normalize_model)
    if args.normalize_model:
        mean = args.mean or data_config["mean"]
        std = args.std or data_config["std"]
        model = utils.normalize_model(model, mean=mean, std=std)
    model = dev_env.to_device(model)
    model.eval()

    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    eval_pp_cfg = utils.MyPreprocessCfg(  # type: ignore
        input_size=data_config['input_size'],
        interpolation=data_config['interpolation'],
        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
        mean=data_config['mean'],
        std=data_config['std'],
        normalize=data_config['normalize'],
    )

    loader = create_loader_v2(dataset,
                              batch_size=args.batch_size,
                              is_training=False,
                              pp_cfg=eval_pp_cfg,
                              num_workers=args.workers,
                              pin_memory=args.pin_mem)
    if not eval_pp_cfg.normalize:
        # Swap the last transform for a plain ToTensor (presumably replacing the
        # normalizing step built by create_loader_v2 — confirm).
        loader.dataset.transform.transforms[-1] = transforms.ToTensor()

    _logger.info("Starting creation of correctly classified DataSet and DataLoader")
    samples, targets, ids = _collect_correctly_classified(model, loader, args.n_points, args.batch_size,
                                                          dev_env)
    correctly_classified_dataset = TensorDataset(samples, targets, ids)
    experiment_batch_size = 1 if args.one_instance else args.batch_size
    correctly_classified_loader = DataLoader(correctly_classified_dataset, batch_size=experiment_batch_size)
    _logger.info("Created correctly classified DataSet and DataLoader")

    # On XLA the attack is run with the model in train mode, which can alter batchnorm
    # buffers; they are backed up here and restored after every attack.
    batch_stats_backup = utils.backup_batchnorm_stats(model)
    # Deep copy: state_dict() tensors share storage with the live model, so without
    # clone() an in-place update would also change this "original" and the consistency
    # check below could never fail.
    original_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # n_points is a multiple of experiment_batch_size (validated above), so the loader
    # yields exactly n_points // experiment_batch_size full batches.
    for batch_idx, (sample, target, sample_id) in enumerate(correctly_classified_loader):
        sample, target, sample_id = dev_env.to_device(sample, target, sample_id)
        first_point = batch_idx * experiment_batch_size
        last_point = first_point + experiment_batch_size

        # Sanity check, done once per batch: every selected point must be classified correctly.
        with torch.no_grad():
            all_correct = dev_env.to_cpu(model(sample).argmax(-1).eq(target).all()).item()
        if not all_correct:
            raise RuntimeError(f"Points ({first_point}, {last_point}) are not all correctly classified")

        for run in range(args.runs):
            for step in args.steps_to_try:
                random_seed(run, dev_env.global_rank)
                attack = attacks.make_attack(args.attack,
                                             eps,
                                             lr,
                                             step,
                                             args.attack_norm,
                                             args.attack_boundaries,
                                             attack_criterion,
                                             dev_env=dev_env,
                                             return_losses=True)
                _logger.info(f"Points ({first_point}, {last_point}) - run {run} - steps {step}")
                if dev_env.type_xla:
                    model.train()
                adv_sample, intermediate_losses = attack(model, sample, target)
                # The final loss is only reported, never differentiated.
                with torch.no_grad():
                    final_losses = criterion(model(adv_sample), target)
                if dev_env.type_xla:
                    # Restore batchnorm stats, verify the model is unchanged, back to eval.
                    model = utils.restore_batchnorm_stats(model, batch_stats_backup)
                    for k, v in model.state_dict().items():
                        if not torch.equal(original_state_dict[k], v):
                            raise RuntimeError(f"Model state '{k}' changed during the attack")
                    model.eval()

                final_losses_np = dev_env.to_cpu(final_losses).detach().numpy()
                sample_ids_np = dev_env.to_cpu(sample_id).detach().numpy()
                if not args.one_instance:
                    for point_id, loss in zip(sample_ids_np, final_losses_np):
                        csv_writer.update({"point": point_id, "seed": run, "steps": step, "loss": loss})
                        _logger.info(f"Point {point_id} - run {run} - steps {step} - loss: {loss:.4f}")
                else:
                    # Intermediate losses reported by the attack, followed by the final loss.
                    intermediate_losses_np = np.concatenate(
                        [dev_env.to_cpu(intermediate_losses).detach().numpy(), final_losses_np])
                    for step_idx, loss in enumerate(intermediate_losses_np):
                        csv_writer.update({
                            "point": sample_ids_np.item(),
                            "seed": run,
                            "steps": step_idx,
                            "loss": loss,
                        })
def _mp_entry(*_ignored):
    """Alternate entry point (presumably for multiprocessing spawners, given the name).

    Any positional arguments passed by the caller are ignored; configuration comes
    entirely from the command line parsed inside ``main``.
    """
    main()


if __name__ == '__main__':
    # Executed as a script (not imported): run the experiment.
    main()