-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_qat.py
393 lines (330 loc) · 18.4 KB
/
run_qat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
from torch.utils.data import DataLoader
import argparse
import torch.nn.functional as F
import torch.nn as nn
import os
import pathlib
from datasets.dcase23 import get_training_set, get_test_set
from helpers.init import worker_init_fn
from models.cp_mobile_clean import get_model
from models.mel import AugmentMelSTFT
from helpers.lr_schedule import exp_warmup_linear_down
from helpers.utils import mixstyle, QuantizationCallback, QuantParamFreezeCallback
from helpers import nessi
class PLModule(pl.LightningModule):
    """PyTorch Lightning module for quantization aware training (QAT) of CP-Mobile.

    Training combines a cross-entropy label loss with a knowledge-distillation (KD) loss computed
    against teacher predictions that come with every training batch. Validation evaluates both the
    fp32 network (``self.model``) and its quantized counterpart (``self.model_int8``) and logs
    accuracy/loss overall, per recording device, per device group and per scene label.
    """

    def __init__(self, config):
        """
        :param config: argparse namespace holding all hyperparameters of the experiment
        """
        super().__init__()
        self.config = config  # results from argparse and contains all configurations for our experiment

        # model to preprocess waveforms into log mel spectrograms
        self.mel = AugmentMelSTFT(n_mels=config.n_mels,
                                  sr=config.resample_rate,
                                  win_length=config.window_size,
                                  hopsize=config.hop_size,
                                  n_fft=config.n_fft,
                                  freqm=config.freqm,
                                  timem=config.timem,
                                  fmin=config.fmin,
                                  fmax=config.fmax,
                                  fmin_aug_range=config.fmin_aug_range,
                                  fmax_aug_range=config.fmax_aug_range
                                  )

        # CP-Mobile - our model to be trained on the log mel spectrograms
        self.model = get_model(n_classes=config.n_classes,
                               in_channels=config.in_channels,
                               base_channels=config.base_channels,
                               channels_multiplier=config.channels_multiplier,
                               expansion_rate=config.expansion_rate
                               )

        # int8 model will be initialized later (in this script: by prepare_quantized());
        # quantized_forward() cannot be used while it is still None
        self.model_int8 = None

        # KL Divergence loss for soft targets. log_target=True: the target is interpreted as
        # log-probabilities; reduction="none": averaging is done explicitly in training_step
        self.kl_div_loss = nn.KLDivLoss(log_target=True, reduction="none")

        # device ids as parsed from the end of the audio file names in validation_step;
        # a file with a device id not listed here raises a KeyError there
        self.device_ids = ['a', 'b', 'c', 's1', 's2', 's3', 's4', 's5', 's6']
        # scene names; integer class label i corresponds to self.label_ids[i]
        self.label_ids = ['airport', 'bus', 'metro', 'metro_station', 'park', 'public_square', 'shopping_mall',
                          'street_pedestrian', 'street_traffic', 'tram']
        # categorization of devices into 'real', 'seen' and 'unseen'
        self.device_groups = {'a': "real", 'b': "real", 'c': "real",
                              's1': "seen", 's2': "seen", 's3': "seen",
                              's4': "unseen", 's5': "unseen", 's6': "unseen"}

    def mel_forward(self, x):
        """Convert a batch of raw waveforms into log mel spectrograms.

        :param x: batch of raw signals (waveforms), expected to be 3-D (batch, channels, samples);
            the channel dimension is folded into the batch dimension for the mel transform
        :return: batch of log mel spectrograms of shape (batch, channels, mels, time-frames)
        """
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])  # for calculating log mel spectrograms we remove the channel dimension
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])  # batch x channels x mels x time-frames
        return x

    def forward(self, x):
        """
        :param x: batch of spectrograms
        :return: final predictions (logits) of the fp32 model
        """
        x = self.model(x)
        return x

    def quantized_forward(self, x):
        """Run the quantized model ``self.model_int8`` on a batch of spectrograms.

        :param x: batch of spectrograms (on any device)
        :return: predictions of the int8 model, moved back to the device of ``x``
        """
        # quantized forward needs to be done on cpu: both the input and the int8 model
        # are moved to the CPU on every call
        orig_device = x.device
        x = x.cpu()
        self.model_int8.cpu()
        y = self.model_int8(x)
        return y.to(orig_device)

    def configure_optimizers(self):
        """
        This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
        The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
        The learning rate is Adam's base 'lr' multiplied by the per-epoch factor of exp_warmup_linear_down
        (applied through LambdaLR).
        :return: dict containing optimizer and learning rate scheduler
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr, weight_decay=self.config.weight_decay)
        schedule_lambda = \
            exp_warmup_linear_down(self.config.warm_up_len, self.config.ramp_down_len, self.config.ramp_down_start,
                                   self.config.last_lr_value)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda)
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler
        }

    def training_step(self, train_batch, batch_idx):
        """
        :param train_batch: contains one batch from train dataloader, unpacked as
            (waveforms, files, labels, devices, cities, teacher_logits)
        :param batch_idx: index of the batch (unused)
        :return: a dict containing at least loss that is used to update model parameters, can also contain
        other items that can be processed in 'training_epoch_end' to log other metrics than loss
        """
        x, file, labels, devices, cities, teacher_logits = train_batch
        x = self.mel_forward(x)  # we convert the raw audio signals into log mel spectrograms

        if self.config.mixstyle_p > 0:
            # frequency mixstyle
            x = mixstyle(x, self.config.mixstyle_p, self.config.mixstyle_alpha)

        y_hat = self.model(x)
        samples_loss = F.cross_entropy(y_hat, labels, reduction="none")
        label_loss = samples_loss.mean()

        # Temperature adjusted probabilities of teacher and student
        with torch.cuda.amp.autocast():
            y_hat_soft = F.log_softmax(y_hat / self.config.temperature, dim=-1)

        # teacher_logits is the KL target with log_target=True, so it must already be log-probabilities;
        # NOTE(review): presumably temperature-scaled log-softmax produced by the dataset — confirm.
        # .mean() averages over batch AND classes (differs from reduction="batchmean" by a factor of n_classes)
        kd_loss = self.kl_div_loss(y_hat_soft, teacher_logits).mean()
        # scale by T^2, the usual distillation convention to keep soft-target gradients comparable across T
        kd_loss = kd_loss * (self.config.temperature ** 2)
        # kd_lambda weights the label loss, (1 - kd_lambda) the distillation loss
        loss = self.config.kd_lambda * label_loss + (1 - self.config.kd_lambda) * kd_loss
        # the weighted loss components are returned for epoch-level logging
        results = {"loss": loss, "label_loss": label_loss * self.config.kd_lambda,
                   "kd_loss": kd_loss * (1 - self.config.kd_lambda)}
        return results

    def training_epoch_end(self, outputs):
        """
        Average the per-batch losses returned by 'training_step' and log them.
        NOTE(review): the *_epoch_end hooks belong to the PyTorch Lightning 1.x API (removed in 2.0) —
        confirm the pinned Lightning version.
        :param outputs: list of the dicts returned by 'training_step'
        """
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_label_loss = torch.stack([x['label_loss'] for x in outputs]).mean()
        avg_kd_loss = torch.stack([x['kd_loss'] for x in outputs]).mean()
        self.log_dict({'loss': avg_loss, 'label_loss': avg_label_loss, 'kd_loss': avg_kd_loss})

    def validation_step(self, val_batch, batch_idx):
        """Evaluate one validation batch with the fp32 and the quantized model.

        Overall loss/correct counts are returned for both models; the per-device and per-label
        sums are computed from the QUANTIZED model's outputs (samples_loss is reassigned below).
        :param val_batch: tuple (waveforms, files, labels, devices, cities)
        :param batch_idx: index of the batch (unused)
        :return: dict of per-batch sums/counts aggregated in 'validation_epoch_end'
        """
        x, files, labels, devices, cities = val_batch
        x = self.mel_forward(x)

        # fp32 accuracy + loss
        y_hat = self.forward(x)
        samples_loss = F.cross_entropy(y_hat, labels, reduction="none")
        fp32_loss = samples_loss.mean()
        _, preds = torch.max(y_hat, dim=1)
        n_correct_pred_per_sample = (preds == labels)
        fp32_n_correct_pred = n_correct_pred_per_sample.sum()

        # quantized metrics
        y_hat = self.quantized_forward(x)
        samples_loss = F.cross_entropy(y_hat, labels, reduction="none")
        loss = samples_loss.mean()

        # for computing accuracy
        _, preds = torch.max(y_hat, dim=1)
        n_correct_pred_per_sample = (preds == labels)
        n_correct_pred = n_correct_pred_per_sample.sum()

        # device id = text after the last '-' of the file name, minus its last 4 characters
        # (assumes names like '...-a.wav' -> 'a'; verify against the dataset)
        dev_names = [d.rsplit("-", 1)[1][:-4] for d in files]
        results = {'val_loss': loss, "n_correct_pred": n_correct_pred, "n_pred": len(labels),
                   "fp32_val_loss": fp32_loss, "fp32_n_correct_pred": fp32_n_correct_pred}

        # log metric per device and scene
        # every device key is initialized so all batches return the same keys
        for d in self.device_ids:
            results["devloss." + d] = torch.as_tensor(0., device=self.device)
            results["devcnt." + d] = torch.as_tensor(0., device=self.device)
            results["devn_correct." + d] = torch.as_tensor(0., device=self.device)
        for i, d in enumerate(dev_names):
            results["devloss." + d] = results["devloss." + d] + samples_loss[i]
            results["devn_correct." + d] = results["devn_correct." + d] + n_correct_pred_per_sample[i]
            results["devcnt." + d] = results["devcnt." + d] + 1

        for l in self.label_ids:
            results["lblloss." + l] = torch.as_tensor(0., device=self.device)
            results["lblcnt." + l] = torch.as_tensor(0., device=self.device)
            results["lbln_correct." + l] = torch.as_tensor(0., device=self.device)
        # labels are integer class indices into self.label_ids
        for i, l in enumerate(labels):
            results["lblloss." + self.label_ids[l]] = results["lblloss." + self.label_ids[l]] + samples_loss[i]
            results["lbln_correct." + self.label_ids[l]] = \
                results["lbln_correct." + self.label_ids[l]] + n_correct_pred_per_sample[i]
            results["lblcnt." + self.label_ids[l]] = results["lblcnt." + self.label_ids[l]] + 1
        return results

    def validation_epoch_end(self, outputs):
        """Aggregate the sums/counts of all 'validation_step' outputs into epoch metrics and log them.

        Logged: overall quantized and fp32 loss/accuracy; per device 'vloss.', 'vacc.', 'vcnt.';
        per device group 'acc.', 'count.', 'lloss.'; per scene label 'vloss.', 'vacc.', 'vcnt.';
        and 'macro_avg_acc' (mean of the per-label accuracies).
        NOTE(review): a device/label with zero samples divides by a zero count and logs NaN
        (possible e.g. in Lightning's sanity-check run). The fp32 loss key is spelled
        'fp32_val.loss' (dot) unlike the other keys — check dashboards before renaming.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = sum([x['n_correct_pred'] for x in outputs]) * 1.0 / sum(x['n_pred'] for x in outputs)
        fp32_avg_loss = torch.stack([x['fp32_val_loss'] for x in outputs]).mean()
        fp32_val_acc = sum([x['fp32_n_correct_pred'] for x in outputs]) * 1.0 / sum(x['n_pred'] for x in outputs)
        logs = {'val_acc': val_acc, 'val_loss': avg_loss, 'fp32_val.loss': fp32_avg_loss, 'fp32_val_acc': fp32_val_acc}

        # log metric per device and scene
        for d in self.device_ids:
            dev_loss = torch.stack([x["devloss." + d] for x in outputs]).sum()
            dev_cnt = torch.stack([x["devcnt." + d] for x in outputs]).sum()
            dev_corrct = torch.stack([x["devn_correct." + d] for x in outputs]).sum()
            logs["vloss." + d] = dev_loss / dev_cnt
            logs["vacc." + d] = dev_corrct / dev_cnt
            logs["vcnt." + d] = dev_cnt
            # device groups: accumulate raw sums here, normalized by the group count below
            logs["acc." + self.device_groups[d]] = logs.get("acc." + self.device_groups[d], 0.) + dev_corrct
            logs["count." + self.device_groups[d]] = logs.get("count." + self.device_groups[d], 0.) + dev_cnt
            logs["lloss." + self.device_groups[d]] = logs.get("lloss." + self.device_groups[d], 0.) + dev_loss

        # turn the accumulated group sums into averages
        for d in set(self.device_groups.values()):
            logs["acc." + d] = logs["acc." + d] / logs["count." + d]
            logs["lloss." + d] = logs["lloss." + d] / logs["count." + d]

        for l in self.label_ids:
            lbl_loss = torch.stack([x["lblloss." + l] for x in outputs]).sum()
            lbl_cnt = torch.stack([x["lblcnt." + l] for x in outputs]).sum()
            lbl_corrct = torch.stack([x["lbln_correct." + l] for x in outputs]).sum()
            logs["vloss." + l] = lbl_loss / lbl_cnt
            logs["vacc." + l] = lbl_corrct / lbl_cnt
            logs["vcnt." + l] = lbl_cnt

        logs["macro_avg_acc"] = torch.mean(torch.stack([logs["vacc." + l] for l in self.label_ids]))
        self.log_dict(logs)
def fuse_model(module):
    """Fuse the layers of ``module.model`` in place, in preparation for QAT.

    Fusion is done in eval mode on the CPU. Afterwards the network is moved back to CUDA when a
    GPU is available. On CPU-only machines it stays on the CPU; an unconditional ``.cuda()`` would
    raise there, although the trainer is configured with ``accelerator='auto'``.
    The model is left in eval mode.

    :param module: object with a ``model`` attribute that provides a ``fuse_model()`` method
    """
    module.model.eval()  # fusion only works in eval mode
    module.model.cpu()
    module.model.fuse_model()
    # put original net back on cuda (only if there is one)
    if torch.cuda.is_available():
        module.model.cuda()
def prepare_quantized(module):
    """Prepare ``module.model`` for quantization aware training and attach an int8 copy.

    ``module.model`` is replaced by its QAT-prepared version, with fake-quant modules inserted
    according to the default 'fbgemm' QAT qconfig. ``module.model_int8`` is set to a converted
    (quantized) copy of it. ``convert`` is not in-place here, so the QAT model itself stays intact.
    The QAT model is moved back to CUDA when a GPU is available (it stays on the CPU otherwise)
    and is left in train mode.

    :param module: object with a ``model`` attribute; ``model_int8`` is (re)assigned
    """
    module.model.train()  # prepare_qat only works in train mode
    module.model.cpu()
    # give information of what kind of observers to attach
    module.model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    # prepare model for QAT, insert observers and fake_quants
    module.model = torch.ao.quantization.prepare_qat(module.model)
    # attach the quantized model to module
    module.model_int8 = torch.ao.quantization.convert(module.model)
    # put original net back on cuda (only if there is one) and in train mode
    if torch.cuda.is_available():
        module.model.cuda()
    module.model.train()
def load_pretrained_from_id(module, project_name, wandb_id):
    """Load the most recent checkpoint of a previous run into ``module.model``.

    Checkpoint files (``*.ckpt``, searched recursively) are expected in
    ``<project_name>/<wandb_id>/checkpoints`` (``~`` is expanded). The newest file by
    modification time is used. A lexicographic sort of the names would rank e.g.
    'epoch=9-...' above 'epoch=19-...'. Only the entries of the checkpoint's ``state_dict``
    prefixed with ``"model."`` are loaded, with that prefix stripped. Other entries
    (e.g. ``model_int8.*`` or ``mel.*``) are ignored instead of being mangled.

    :param module: object with a ``model`` attribute (an ``nn.Module``)
    :param project_name: directory of the W&B project
    :param wandb_id: id of the run whose checkpoint should be loaded
    :raises FileNotFoundError: if the checkpoint directory or any ``*.ckpt`` file is missing
    """
    ckpt_path = os.path.expanduser(os.path.join(project_name, wandb_id, "checkpoints"))
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint path '{ckpt_path}' found.")
    ckpt_files = list(pathlib.Path(ckpt_path).rglob('*.ckpt'))
    if not ckpt_files:
        raise FileNotFoundError(f"No checkpoint files found in path {ckpt_path}.")
    # newest file wins; ties are broken by path name
    latest_ckpt = max(ckpt_files, key=lambda p: (p.stat().st_mtime, str(p)))
    # map to CPU so GPU-trained checkpoints also load on CPU-only machines;
    # load_state_dict copies the values into the model's own parameters anyway
    state_dict = torch.load(latest_ckpt, map_location="cpu")['state_dict']
    # keep only the network's weights and remove the "model." prefix
    prefix = "model."
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    module.model.load_state_dict(state_dict)
def train(config):
    """Fine-tune CP-Mobile with quantization aware training and log everything to W&B.

    Builds the W&B logger and the train/test data loaders, creates the Lightning module
    (optionally initialized from the checkpoint of W&B run ``config.wandb_id``), fuses and
    prepares the network for QAT, logs its MACs and parameter count, and finally trains it.
    The test split is passed to ``trainer.fit`` as validation data.

    :param config: argparse namespace holding all hyperparameters
    """
    wandb_logger = WandbLogger(
        project=config.project_name,
        notes="CPJKU pipeline for DCASE23 Task 1.",
        tags=["DCASE23"],
        config=config,  # every hyperparameter ends up in the W&B run config
        name=config.experiment_name
    )

    # data: shuffled training split, unshuffled test split
    train_set = get_training_set(config.cache_path, config.resample_rate, config.roll,
                                 config.dir_prob, config.temperature)
    train_dl = DataLoader(dataset=train_set,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          shuffle=True)
    test_set = get_test_set(config.cache_path, config.resample_rate)
    test_dl = DataLoader(dataset=test_set,
                         worker_init_fn=worker_init_fn,
                         num_workers=config.num_workers,
                         batch_size=config.batch_size)

    pl_module = PLModule(config)
    if config.wandb_id:
        # start QAT from a previously trained model instead of from scratch
        load_pretrained_from_id(pl_module, config.project_name, config.wandb_id)

    # fuse layers first, then insert the fake-quant modules for QAT
    fuse_model(pl_module)
    prepare_quantized(pl_module)

    # complexity (nessi) is measured on the spectrogram shape of one training example
    first_waveform = next(iter(train_dl))[0][0].unsqueeze(0)
    spec_shape = pl_module.mel_forward(first_waveform).size()
    macs, params = nessi.get_model_size(pl_module.model, input_size=spec_shape)
    for key, value in (('MACs', macs), ('Parameters', params)):
        wandb_logger.experiment.config[key] = value

    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),  # tracks the learning rate schedule
        QuantizationCallback(),
        QuantParamFreezeCallback(config.freeze_params_epochs),
    ]
    trainer = pl.Trainer(max_epochs=config.n_epochs,
                         logger=wandb_logger,
                         accelerator='auto',
                         devices=1,
                         callbacks=callbacks)
    # training + validation for config.n_epochs epochs
    trainer.fit(pl_module, train_dl, test_dl)
def build_parser():
    """Create the command line parser with all hyperparameters of the QAT run.

    :return: configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(description='Example of parser. ')

    # general
    parser.add_argument('--project_name', type=str, default="DCASE23_Task1")
    parser.add_argument('--wandb_id', type=str, default=None)  # for loading a pre-trained model
    parser.add_argument('--experiment_name', type=str, default="CPJKU_QAT")
    parser.add_argument('--num_workers', type=int, default=12)  # number of workers for dataloaders

    # dataset
    # location to store resampled waveform
    parser.add_argument('--cache_path', type=str, default=os.path.join("datasets", "cpath"))

    # model
    parser.add_argument('--n_classes', type=int, default=10)  # classification model with 'n_classes' output neurons
    parser.add_argument('--in_channels', type=int, default=1)
    # adapt the complexity of the neural network (3 main dimensions to scale CP-Mobile)
    parser.add_argument('--base_channels', type=int, default=32)
    # float, not int: the default is 2.3 and int('2.3') would reject the same value given on the command line
    parser.add_argument('--channels_multiplier', type=float, default=2.3)
    parser.add_argument('--expansion_rate', type=int, default=3)

    # training
    parser.add_argument('--n_epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--mixstyle_p', type=float, default=0.4)  # frequency mixstyle
    parser.add_argument('--mixstyle_alpha', type=float, default=0.3)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--roll', type=int, default=4000)  # roll waveform over time
    parser.add_argument('--dir_prob', type=float, default=0.6)  # prob. to apply device impulse response augmentation

    ## knowledge distillation
    parser.add_argument('--temperature', type=float, default=2.0)
    parser.add_argument('--kd_lambda', type=float, default=0.02)  # weight of the label loss (KD gets 1 - kd_lambda)

    # learning rate + schedule
    # phases:
    #  1. exponentially increasing warmup phase (for 'warm_up_len' epochs)
    #  2. constant lr phase using value specified in 'lr' (for 'ramp_down_start' - 'warm_up_len' epochs)
    #  3. linearly decreasing to value 'last_lr_value' * 'lr' (for 'ramp_down_len' epochs)
    #  4. finetuning phase using a learning rate of 'last_lr_value' * 'lr' (for the rest of epochs up to 'n_epochs')
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--warm_up_len', type=int, default=0)
    parser.add_argument('--ramp_down_start', type=int, default=1)
    parser.add_argument('--ramp_down_len', type=int, default=16)
    parser.add_argument('--last_lr_value', type=float, default=0.1)  # relative to 'lr'

    # preprocessing
    parser.add_argument('--resample_rate', type=int, default=32000)
    parser.add_argument('--window_size', type=int, default=3072)  # in samples (corresponds to 96 ms)
    parser.add_argument('--hop_size', type=int, default=500)  # in samples (corresponds to ~16 ms)
    parser.add_argument('--n_fft', type=int, default=4096)  # length (points) of fft, e.g. 4096 point FFT
    parser.add_argument('--n_mels', type=int, default=256)  # number of mel bins
    parser.add_argument('--freqm', type=int, default=48)  # mask up to 'freqm' spectrogram bins
    parser.add_argument('--timem', type=int, default=0)  # mask up to 'timem' spectrogram frames
    parser.add_argument('--fmin', type=int, default=0)  # mel bins are created for freqs. between 'fmin' and 'fmax'
    parser.add_argument('--fmax', type=int, default=None)
    parser.add_argument('--fmin_aug_range', type=int, default=1)  # data augmentation: vary 'fmin' and 'fmax'
    parser.add_argument('--fmax_aug_range', type=int, default=1000)

    # qat specific
    # freeze quantizer parameters and batchnorm stats for last n epochs
    parser.add_argument('--freeze_params_epochs', type=int, default=4)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    train(args)