-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_passt_training.py
296 lines (239 loc) · 13.3 KB
/
run_passt_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
import argparse
import os
from helpers.utils import mixstyle
from helpers.lr_schedule import exp_warmup_linear_down
from helpers.init import worker_init_fn
from models.passt import get_model
from models.mel import AugmentMelSTFT
from helpers import nessi
from datasets.dcase23 import get_training_set, get_test_set
class PLModule(pl.LightningModule):
    """Lightning module that trains a PaSST classifier on waveform batches.

    Waveforms are turned into log-mel spectrograms by ``self.mel`` (an
    ``AugmentMelSTFT`` configured with frequency/time masking and fmin/fmax
    augmentation ranges), optionally augmented with MixStyle during training,
    and classified by ``self.model``. Besides overall loss/accuracy, the module
    can log statistics per recording device and per device group
    ("real" / "seen" / "unseen") when ``calc_device_info`` is True.
    """

    def __init__(self, config):
        """
        :param config: hyperparameter namespace (as produced by the argument parser of
            this script); kept as ``self.config`` and read by the training/optimizer hooks
        """
        super(PLModule, self).__init__()
        self.config = config

        # model to preprocess waveforms into log mel spectrograms
        self.mel = AugmentMelSTFT(n_mels=config.n_mels,
                                  sr=config.resample_rate,
                                  win_length=config.window_size,
                                  hopsize=config.hop_size,
                                  n_fft=config.n_fft,
                                  freqm=config.freqm,
                                  timem=config.timem,
                                  fmin=config.fmin,
                                  fmax=config.fmax,
                                  fmin_aug_range=config.fmin_aug_range,
                                  fmax_aug_range=config.fmax_aug_range
                                  )

        # Honour the '--arch' option instead of a hard-coded name; configs that do not
        # carry 'arch' fall back to the checkpoint that used to be hard-coded here.
        self.model = get_model(arch=getattr(config, "arch", "passt_s_swa_p16_128_ap476"),
                               n_classes=config.n_classes,
                               input_fdim=config.input_fdim,
                               s_patchout_t=config.s_patchout_t,
                               s_patchout_f=config.s_patchout_f)

        # recording devices for which per-device metrics are accumulated
        self.device_ids = ['a', 'b', 'c', 's1', 's2', 's3', 's4', 's5', 's6']
        # device -> group used for the aggregated validation metrics
        self.device_groups = {'a': "real", 'b': "real", 'c': "real",
                              's1': "seen", 's2': "seen", 's3': "seen",
                              's4': "unseen", 's5': "unseen", 's6': "unseen"}
        # when True, step outputs carry per-device sums that the epoch-end hooks aggregate
        self.calc_device_info = True
        # number of finished validation epochs; only used to format console output
        self.epoch = 0

    def mel_forward(self, x):
        """Apply the mel front end to every row of a 3-D batch of waveforms.

        The first two dimensions are flattened before ``self.mel`` is applied and
        restored afterwards, so a ``(d0, d1, samples)`` input yields a
        ``(d0, d1, mel_dim1, mel_dim2)`` output.
        Assumes x is (batch, channels, samples) — TODO confirm against the dataset.

        :param x: waveform tensor with exactly three dimensions
        :return: 4-D spectrogram tensor
        """
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        return x

    def forward(self, x):
        """Run the classifier; the callers unpack its output as ``(logits, embedding)``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch.

        :param batch: tuple ``(x, files, labels, devices, cities, teacher_logits)``; only
            ``x``, ``files`` and ``labels`` are used (device ids are re-derived from ``files``)
        :param batch_idx: index of the batch (unused)
        :return: dict with the mean ``loss`` (optimized by Lightning), ``n_correct_pred``,
            ``n_pred`` and, if ``calc_device_info``, per-device loss sums and sample counts
        """
        x, files, labels, devices, cities, teacher_logits = batch
        if self.mel:
            x = self.mel_forward(x)
        if self.config.mixstyle_p > 0:
            x = mixstyle(x, self.config.mixstyle_p, self.config.mixstyle_alpha)
        y_hat, embed = self.forward(x)

        samples_loss = F.cross_entropy(y_hat, labels, reduction="none")
        loss = samples_loss.mean()
        # the per-sample losses are only used for logging -> keep them out of the graph
        samples_loss = samples_loss.detach()

        _, preds = torch.max(y_hat, dim=1)
        n_correct_pred = (preds == labels).sum()
        results = {"loss": loss, "n_correct_pred": n_correct_pred, "n_pred": len(labels)}

        if self.calc_device_info:
            # device id = filename part after the last '-', minus its last 4 characters
            # (presumably a '.wav' extension — confirm against the dataset file names)
            devices = [d.rsplit("-", 1)[1][:-4] for d in files]
            for d in self.device_ids:
                results["devloss." + d] = torch.as_tensor(0., device=self.device)
                results["devcnt." + d] = torch.as_tensor(0., device=self.device)
            for i, d in enumerate(devices):
                results["devloss." + d] = results["devloss." + d] + samples_loss[i]
                results["devcnt." + d] = results["devcnt." + d] + 1.
        return results

    def training_epoch_end(self, outputs):
        """Aggregate the training step outputs of one epoch and log them.

        Logs ``train.loss`` (mean of per-batch mean losses), ``train_acc`` and, if enabled,
        per-device mean loss (``tloss.<dev>``) and sample count (``tcnt.<dev>``).
        A device with no samples in the epoch yields 0/0, i.e. a NaN ``tloss`` entry.

        :param outputs: list of dicts returned by :meth:`training_step`
        """
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        train_acc = sum([x['n_correct_pred'] for x in outputs]) * 1.0 / sum(x['n_pred'] for x in outputs)
        logs = {'train.loss': avg_loss, 'train_acc': train_acc}

        if self.calc_device_info:
            for d in self.device_ids:
                dev_loss = torch.stack([x["devloss." + d] for x in outputs]).sum()
                dev_cnt = torch.stack([x["devcnt." + d] for x in outputs]).sum()
                logs["tloss." + d] = dev_loss / dev_cnt
                logs["tcnt." + d] = dev_cnt

        self.log_dict(logs)
        print(f"Training Loss: {avg_loss}")
        print(f"Training Accuracy: {train_acc}")

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch.

        :param batch: tuple ``(x, files, labels, devices, cities)``; device ids are
            re-derived from ``files``
        :param batch_idx: index of the batch (unused)
        :return: dict with ``val_loss``, ``n_correct_pred``, ``n_pred`` and, if enabled,
            per-device loss sums, correct-prediction counts and sample counts
        """
        x, files, labels, devices, cities = batch
        if self.mel:
            x = self.mel_forward(x)
        y_hat, embed = self.forward(x)

        samples_loss = F.cross_entropy(y_hat, labels, reduction="none")
        loss = samples_loss.mean()
        self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)

        _, preds = torch.max(y_hat, dim=1)
        n_correct_pred_per_sample = (preds == labels)
        n_correct_pred = n_correct_pred_per_sample.sum()
        results = {"val_loss": loss, "n_correct_pred": n_correct_pred, "n_pred": len(labels)}

        if self.calc_device_info:
            # same filename -> device id parsing as in training_step
            devices = [d.rsplit("-", 1)[1][:-4] for d in files]
            for d in self.device_ids:
                results["devloss." + d] = torch.as_tensor(0., device=self.device)
                results["devcnt." + d] = torch.as_tensor(0., device=self.device)
                results["devn_correct." + d] = torch.as_tensor(0., device=self.device)
            for i, d in enumerate(devices):
                results["devloss." + d] = results["devloss." + d] + samples_loss[i]
                results["devn_correct." + d] = results["devn_correct." + d] + n_correct_pred_per_sample[i]
                results["devcnt." + d] = results["devcnt." + d] + 1
        return results

    def validation_epoch_end(self, outputs):
        """Aggregate the validation step outputs of one epoch and log them.

        Logs ``val.loss``, ``val_acc`` and, if enabled, per device ``vloss.<dev>``,
        ``vacc.<dev>``, ``vcnt.<dev>`` plus per device group ``acc.<group>``,
        ``lloss.<group>`` (both averaged over the group's samples) and ``count.<group>``.

        :param outputs: list of dicts returned by :meth:`validation_step`
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = sum([x['n_correct_pred'] for x in outputs]) * 1.0 / sum(x['n_pred'] for x in outputs)
        logs = {'val.loss': avg_loss, 'val_acc': val_acc}

        if self.calc_device_info:
            for d in self.device_ids:
                dev_loss = torch.stack([x["devloss." + d] for x in outputs]).sum()
                dev_cnt = torch.stack([x["devcnt." + d] for x in outputs]).sum()
                dev_correct = torch.stack([x["devn_correct." + d] for x in outputs]).sum()
                logs["vloss." + d] = dev_loss / dev_cnt
                logs["vacc." + d] = dev_correct / dev_cnt
                logs["vcnt." + d] = dev_cnt
                # accumulate raw sums per device group; normalised below
                group = self.device_groups[d]
                logs["acc." + group] = logs.get("acc." + group, 0.) + dev_correct
                logs["count." + group] = logs.get("count." + group, 0.) + dev_cnt
                logs["lloss." + group] = logs.get("lloss." + group, 0.) + dev_loss
            for group in set(self.device_groups.values()):
                count = logs["count." + group]
                logs["acc." + group] = logs["acc." + group] / count
                # normalise the group loss in place (it was previously written to a
                # mistyped 'lloss.False<group>' key, leaving 'lloss.<group>' as a raw sum)
                logs["lloss." + group] = logs["lloss." + group] / count

        self.log_dict(logs)

        if self.epoch > 0:
            print()
            print(f"Validation Loss: {avg_loss}")
            print(f"Validation Accuracy: {val_acc}")
        self.epoch += 1

    def configure_optimizers(self):
        """
        This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
        The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).

        Adam with ``config.lr`` / ``config.weight_decay``; the learning rate is multiplied by the
        factor returned by ``exp_warmup_linear_down`` via ``LambdaLR``.

        :return: dict containing optimizer and learning rate scheduler
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr, weight_decay=self.config.weight_decay)
        schedule_lambda = \
            exp_warmup_linear_down(self.config.warm_up_len, self.config.ramp_down_len, self.config.ramp_down_start,
                                   self.config.last_lr_value)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda)
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler
        }
def train(config):
    """Set up data, model, logging and trainer, then run training with validation.

    Runs are tracked with Weights & Biases through ``WandbLogger``; the whole
    ``config`` namespace is attached to the run, followed by the model's MACs and
    parameter count as measured by ``nessi``. The test split serves as the
    validation set passed to ``trainer.fit``.

    :param config: parsed command-line namespace (see the ``__main__`` block)
    """
    wb_logger = WandbLogger(
        project=config.project_name,
        notes="CPJKU pipeline for DCASE23 Task 1.",
        tags=["DCASE23"],
        config=config,  # stores every hyperparameter with the run
        name=config.experiment_name
    )

    # settings shared by both loaders
    loader_opts = dict(worker_init_fn=worker_init_fn,
                       num_workers=config.num_workers,
                       batch_size=config.batch_size)

    train_set = get_training_set(config.cache_path, config.resample_rate, config.roll, config.dir_prob)
    train_loader = DataLoader(dataset=train_set, shuffle=True, **loader_opts)

    eval_set = get_test_set(config.cache_path, config.resample_rate)
    eval_loader = DataLoader(dataset=eval_set, **loader_opts)

    module = PLModule(config)

    # Complexity probe: one training clip is pushed through the mel front end to obtain the
    # input shape of the classifier, which nessi then uses to count MACs and parameters.
    # NOTE: measured before layer fusion, so the numbers deviate slightly from the ones
    # reported in the challenge submission.
    probe_clip = next(iter(train_loader))[0][0].unsqueeze(0)
    probe_shape = module.mel_forward(probe_clip).size()
    macs, params = nessi.get_model_size(module.model, input_size=probe_shape)
    run_config = wb_logger.experiment.config
    run_config['MACs'] = macs
    run_config['Parameters'] = params

    trainer = pl.Trainer(
        max_epochs=config.n_epochs,
        logger=wb_logger,
        accelerator='auto',
        devices=1,
        # log the learning rate once per epoch to inspect the schedule
        callbacks=[LearningRateMonitor(logging_interval='epoch')]
    )
    trainer.fit(module, train_loader, eval_loader)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Example of parser. ')
    add = parser.add_argument

    # ---- general ----
    add('--project_name', type=str, default="DCASE23_Task1")
    add('--experiment_name', type=str, default="CPJKU_passt_teacher_training")
    add('--num_workers', type=int, default=12)  # worker processes per dataloader

    # ---- dataset ----
    # directory in which resampled waveforms are cached
    add('--cache_path', type=str, default=os.path.join("datasets", "cpath"))

    # ---- model ----
    add('--arch', type=str, default='passt_s_swa_p16_128_ap476')  # pretrained PaSST model
    add('--n_classes', type=int, default=10)  # number of output neurons of the classifier
    add('--input_fdim', type=int, default=128)
    add('--s_patchout_t', type=int, default=0)
    add('--s_patchout_f', type=int, default=6)

    # ---- training ----
    add('--n_epochs', type=int, default=25)
    add('--batch_size', type=int, default=80)
    add('--mixstyle_p', type=float, default=0.4)  # frequency MixStyle; any value > 0 enables it
    add('--mixstyle_alpha', type=float, default=0.4)
    add('--weight_decay', type=float, default=0.001)
    add('--roll', type=int, default=10000)  # roll waveform over time
    add('--dir_prob', type=float, default=0.6)  # probability of device impulse response augmentation

    # ---- learning rate schedule ----
    # The schedule runs through four phases:
    #   1. exponential warm-up for 'warm_up_len' epochs
    #   2. constant 'lr' for ('ramp_down_start' - 'warm_up_len') epochs
    #   3. linear decay to 'last_lr_value' * 'lr' over 'ramp_down_len' epochs
    #   4. fine-tuning at 'last_lr_value' * 'lr' for the remaining epochs up to 'n_epochs'
    add('--lr', type=float, default=0.00001)
    add('--warm_up_len', type=int, default=3)
    add('--ramp_down_start', type=int, default=3)
    add('--ramp_down_len', type=int, default=10)
    add('--last_lr_value', type=float, default=0.01)  # fraction of 'lr'

    # ---- preprocessing ----
    add('--resample_rate', type=int, default=32000)
    add('--window_size', type=int, default=800)  # STFT window length in samples
    add('--hop_size', type=int, default=320)  # STFT hop length in samples
    add('--n_fft', type=int, default=1024)  # number of FFT points
    add('--n_mels', type=int, default=128)  # number of mel bins
    add('--freqm', type=int, default=48)  # frequency masking: up to 'freqm' mel bins
    add('--timem', type=int, default=20)  # time masking: up to 'timem' frames
    add('--fmin', type=int, default=0)  # lower edge of the mel filterbank
    add('--fmax', type=int, default=None)  # upper edge of the mel filterbank
    add('--fmin_aug_range', type=int, default=1)  # augmentation range for 'fmin'
    add('--fmax_aug_range', type=int, default=1000)  # augmentation range for 'fmax'

    train(parser.parse_args())