# utils.py
import csv
import dataclasses
import glob
import os
import tempfile
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Tuple, Union

import timm
import torch
import torch.nn.functional as F
from timm import bits
from timm.data import PreprocessCfg
from timm.data.fetcher import Fetcher
from timm.data.prefetcher_cuda import PrefetcherCuda
from timm.models import xcit
from torch import nn

import src.attacks as attacks


def get_outdir(path: str, *paths: str, inc=False) -> str:
    """Adapted to get out dir from GCS"""
    outdir = os.path.join(path, *paths)
    if path.startswith('gs://'):
        import tensorflow as tf
        os_module = tf.io.gfile
        exists_fn = lambda x: os_module.exists(x)
    else:
        os_module = os
        exists_fn = os.path.exists
    if not exists_fn(outdir):
        os_module.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while exists_fn(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            assert count < 100
        outdir = outdir_inc
        os_module.makedirs(outdir)
    return outdir
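

# Illustrative usage sketch (not called anywhere in this module); the bucket and
# experiment names below are hypothetical placeholders.
def _example_get_outdir() -> str:
    # With inc=True, an existing "gs://my-bucket/experiments/run" is not reused:
    # a "-1", "-2", ... suffix is appended until a fresh directory is found.
    return get_outdir('gs://my-bucket/experiments', 'run', inc=True)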


def load_model_from_gcs(checkpoint_path: str, model_name: str, **kwargs):
    import tensorflow as tf
    with tempfile.TemporaryDirectory() as dst:
        local_checkpoint_path = os.path.join(dst, os.path.basename(checkpoint_path))
        tf.io.gfile.copy(checkpoint_path, local_checkpoint_path)
        model = timm.create_model(model_name, checkpoint_path=local_checkpoint_path, **kwargs)
    return model


def load_state_dict_from_gcs(model: nn.Module, checkpoint_path: str):
    import tensorflow as tf
    with tempfile.TemporaryDirectory() as dst:
        local_checkpoint_path = os.path.join(dst, os.path.basename(checkpoint_path))
        tf.io.gfile.copy(checkpoint_path, local_checkpoint_path)
        model.load_state_dict(torch.load(local_checkpoint_path)["model"])
    return model
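

# Sketch of restoring weights stored on GCS into a freshly built model. The model name
# and checkpoint path are hypothetical; the checkpoint is assumed to store its weights
# under a "model" key, as expected by load_state_dict_from_gcs above.
def _example_load_from_gcs() -> nn.Module:
    model = timm.create_model('resnet50', pretrained=False)
    return load_state_dict_from_gcs(model, 'gs://my-bucket/checkpoints/best.pth.tar')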


def upload_checkpoints_gcs(checkpoints_dir: str, output_dir: str):
    import tensorflow as tf
    checkpoints_paths = glob.glob(os.path.join(checkpoints_dir, '*.pth.tar'))
    for checkpoint in checkpoints_paths:
        gcs_checkpoint_path = os.path.join(output_dir, os.path.basename(checkpoint))
        tf.io.gfile.copy(checkpoint, gcs_checkpoint_path)


def backup_batchnorm_stats(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Returns a snapshot of the BatchNorm entries of the model's state dict."""
    # Clone the tensors so that later in-place updates to the running stats
    # do not silently modify the backup as well.
    return {k: v.clone() for k, v in model.state_dict().items() if layer_is_batchnorm(k)}


def restore_batchnorm_stats(model: nn.Module, stats: Dict[str, torch.Tensor]) -> nn.Module:
    """Loads previously backed-up BatchNorm entries back into the model."""
    _, unexp_keys = model.load_state_dict(stats, strict=False)
    assert len(unexp_keys) == 0
    return model


def layer_is_batchnorm(layer_name: str):
    """Heuristically detects BatchNorm entries in a state dict by their name."""
    keys = {"bn", "batchnorm"}
    return any(map(lambda key: key in layer_name, keys))
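

# Sketch of the intended backup/restore pattern: snapshot the BatchNorm buffers before
# a procedure that may perturb them (e.g. forward passes on adversarial examples), then
# put them back. The model and batch below are hypothetical stand-ins.
def _example_batchnorm_roundtrip(model: nn.Module, x: torch.Tensor) -> None:
    stats = backup_batchnorm_stats(model)
    model(x)  # running_mean / running_var may drift here if the model is in train mode
    restore_batchnorm_stats(model, stats)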


class GCSSummaryCsv(bits.monitor.SummaryCsv):
    """SummaryCsv version that works with GCS"""

    def __init__(self, output_dir, filename='summary.csv'):
        super().__init__(output_dir, filename)

    def update(self, row_dict):
        import tensorflow as tf
        with tf.io.gfile.GFile(self.filename, mode='a') as cf:
            dw = csv.DictWriter(cf, fieldnames=row_dict.keys())
            if self.needs_header:  # first iteration (epoch == 1 can't be used)
                dw.writeheader()
                self.needs_header = False
            dw.writerow(row_dict)


class ComputeLossFn(nn.Module):
    """Wraps a plain loss function into the (loss, logits, extra) interface used by the
    training loop; the final integer argument (e.g. a step index) is ignored here."""

    def __init__(self, loss_fn: nn.Module):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, model: nn.Module, x: torch.Tensor, y: torch.Tensor,
                _: int) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        output = model(x)
        return self.loss_fn(output, y), output, None


@dataclasses.dataclass
class AdvTrainState(bits.TrainState):
    """bits.TrainState extended with a loss-computation function and an epsilon schedule
    for adversarial training."""
    # pytype: disable=annotation-type-mismatch
    compute_loss_fn: Callable[[nn.Module, torch.Tensor, torch.Tensor, int],
                              Tuple[torch.Tensor, torch.Tensor,
                                    Optional[torch.Tensor]]] = None  # type: ignore
    eps_schedule: attacks.EpsSchedule = None  # type: ignore
    # pytype: enable=annotation-type-mismatch

    @classmethod
    def from_bits(cls, instance: bits.TrainState, **kwargs):
        return cls(
            model=instance.model,  # type: ignore
            train_loss=instance.train_loss,
            eval_loss=instance.eval_loss,
            updater=instance.updater,
            lr_scheduler=instance.lr_scheduler,
            model_ema=instance.model_ema,
            train_cfg=instance.train_cfg,
            epoch=instance.epoch,
            step_count=instance.step_count,
            step_count_global=instance.step_count_global,
            **kwargs)
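

# Sketch of promoting a plain bits.TrainState to AdvTrainState. The loss function and
# epsilon schedule arguments are hypothetical placeholders built by the caller.
def _example_adv_train_state(train_state: bits.TrainState,
                             eps_schedule: attacks.EpsSchedule) -> AdvTrainState:
    return AdvTrainState.from_bits(train_state,
                                   compute_loss_fn=ComputeLossFn(nn.CrossEntropyLoss()),
                                   eps_schedule=eps_schedule)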


@dataclasses.dataclass
class MyPreprocessCfg(PreprocessCfg):
    normalize: bool = True
    rand_rotation: int = 0
    pad: int = 0


class ImageNormalizer(nn.Module):
    """From
    https://github.com/RobustBench/robustbench/blob/master/robustbench/model_zoo/architectures/utils_architectures.py#L8"""

    def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None:
        super(ImageNormalizer, self).__init__()
        self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1))
        self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (input - self.mean) / self.std


def normalize_model(model: nn.Module, mean: Tuple[float, float, float],
                    std: Tuple[float, float, float]) -> nn.Module:
    """From
    https://github.com/RobustBench/robustbench/blob/master/robustbench/model_zoo/architectures/utils_architectures.py#L20"""
    layers = OrderedDict([('normalize', ImageNormalizer(mean, std)), ('model', model)])
    return nn.Sequential(layers)
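

# Sketch of wrapping a model so it accepts inputs in [0, 1] (useful e.g. for attacks
# that operate on un-normalized images). The model name and the ImageNet mean/std are
# illustrative assumptions, not values prescribed by this module.
def _example_normalize_model() -> nn.Module:
    model = timm.create_model('resnet50', pretrained=False)
    return normalize_model(model, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))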


class CombinedLoaders:
    """Iterates over two loaders in lockstep, yielding concatenated and shuffled batches."""

    def __init__(self, loader_1: Union[Fetcher, PrefetcherCuda], loader_2: Union[Fetcher, PrefetcherCuda]):
        self.loader_1 = loader_1
        self.loader_2 = loader_2
        assert loader_1.mixup_enabled == loader_2.mixup_enabled
        self._mixup_enabled = loader_1.mixup_enabled

    def __iter__(self):
        return self._iterator()

    def __len__(self):
        return min(len(self.loader_1), len(self.loader_2))

    def _iterator(self):
        for (img1, label1), (img2, label2) in zip(self.loader_1, self.loader_2):
            images = torch.cat([img1, img2])
            labels = torch.cat([label1, label2])
            indices = torch.randperm(len(images))
            yield images[indices], labels[indices]

    @property
    def sampler(self):
        return self.loader_1.sampler

    @property
    def sampler2(self):
        return self.loader_2.sampler

    @property
    def mixup_enabled(self):
        return self._mixup_enabled
    @mixup_enabled.setter
    def mixup_enabled(self, value: bool):
        # A property setter must accept the assigned value; propagate it to both loaders.
        self.loader_1.mixup_enabled = value
        self.loader_2.mixup_enabled = value
        assert self.loader_1.mixup_enabled == self.loader_2.mixup_enabled
        self._mixup_enabled = value
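

# Sketch of combining two loaders (e.g. a clean-data and an extra-data loader) into a
# single shuffled stream; loader_a and loader_b are hypothetical Fetcher/PrefetcherCuda
# instances built elsewhere.
def _example_combined_loaders(loader_a, loader_b) -> None:
    combined = CombinedLoaders(loader_a, loader_b)
    for images, labels in combined:
        ...  # each batch mixes samples from both loaders in a random order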


def write_wandb_info(notes: str, output_dir: str, wandb_run):
    import tensorflow as tf
    assert output_dir is not None
    # Log run notes and *true* output dir to wandb
    if output_dir.startswith("gs://"):
        exp_dir = output_dir.split("gs://")[-1]
        bucket_url = f"https://console.cloud.google.com/storage/{exp_dir}"
        notes += f"Bucket: {exp_dir}\n"
        wandb_run.config.update({"output": bucket_url}, allow_val_change=True)
    else:
        wandb_run.config.update({"output": output_dir}, allow_val_change=True)
    wandb_run.notes = notes
    wandb_run_field = f"wandb_run: {wandb_run.url}\n"  # type: ignore
    # Log wandb run url to args file
    with tf.io.gfile.GFile(os.path.join(output_dir, 'args.yaml'), 'a') as f:
        f.write(wandb_run_field)


def interpolate_position_embeddings(model: nn.Module, checkpoint_model: Dict[str, Any]):
    """Interpolates the position embedding layer for different resolutions.
    Adapted from DeiT's original repo: https://github.com/facebookresearch/deit.
    The original license can be found here: https://github.com/facebookresearch/deit/blob/main/LICENSE"""
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches**0.5)
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = F.interpolate(pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
    return new_pos_embed
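

# Sketch of using the interpolated position embeddings when fine-tuning a checkpoint at
# a different resolution: overwrite the checkpoint entry before loading it. The model
# and checkpoint dict are hypothetical inputs prepared by the caller.
def _example_resize_pos_embed(model: nn.Module, checkpoint_model: Dict[str, Any]) -> None:
    checkpoint_model['pos_embed'] = interpolate_position_embeddings(model, checkpoint_model)
    model.load_state_dict(checkpoint_model, strict=False)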


def adapt_model_patches(model: xcit.XCiT, new_patch_size: int):
    """Reduces the effective patch size of an XCiT model by setting the stride of some
    of the patch-embedding convolutions to 1."""
    to_divide = model.patch_embed.patch_size / new_patch_size
    assert int(to_divide) == to_divide, "The new patch size should divide the original patch size"
    to_divide = int(to_divide)
    assert to_divide % 2 == 0, "The ratio between the original patch size and the new patch size should be divisible by 2"
    for conv_index in range(0, to_divide, 2):
        model.patch_embed.proj[conv_index][0].stride = (1, 1)
    model.patch_embed.patch_size = new_patch_size
    return model
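

# Sketch of shrinking an XCiT patch size, assuming a standard 16-pixel patch embedding:
# going from 16 to 8 sets the stride of one patch-embedding convolution to 1. The timm
# model name is an assumption used only for illustration.
def _example_adapt_patches() -> xcit.XCiT:
    model = timm.create_model('xcit_small_12_p16_224', pretrained=False)
    return adapt_model_patches(model, new_patch_size=8)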