-
Notifications
You must be signed in to change notification settings - Fork 74
/
main.py
574 lines (483 loc) · 20.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
# System libs
import os
import random
import time
# Numerical libs
import torch
import torch.nn.functional as F
import numpy as np
import scipy.io.wavfile as wavfile
from scipy.misc import imsave
from mir_eval.separation import bss_eval_sources
# Our libs
from arguments import ArgParser
from dataset import MUSICMixDataset
from models import ModelBuilder, activate
from utils import AverageMeter, \
recover_rgb, magnitude2heatmap,\
istft_reconstruction, warpgrid, \
combine_video_audio, save_video, makedirs
from viz import plot_loss_metrics, HTMLVisualizer
# Network wrapper, defines forward pass
class NetWrapper(torch.nn.Module):
    """Bundle the sound, frame and synthesizer networks with the criterion.

    Keeping the three networks and the loss inside a single module lets the
    whole pipeline, loss included, run in one ``forward`` call (``main``
    wraps this module in ``torch.nn.DataParallel``).
    """

    def __init__(self, nets, crit):
        super().__init__()
        # ``nets`` is unpacked positionally as (sound, frame, synthesizer).
        self.net_sound, self.net_frame, self.net_synthesizer = nets
        self.crit = crit

    def forward(self, batch_data, args):
        """Predict one mask per source and compute the training loss.

        Args:
            batch_data: dict providing ``'mag_mix'`` (4-D tensor; dim 0 is the
                batch, dim 3 is treated as the time axis ``T``), ``'mags'``
                (list of per-source magnitude tensors) and ``'frames'`` (list
                of per-source frame tensors). When ``args.log_freq`` is set,
                the entries of ``batch_data['mags']`` are replaced in place
                by their warped versions.
            args: experiment options; reads ``num_mix``, ``log_freq``,
                ``weighted_loss``, ``binary_mask``, ``device`` and the
                ``*_activation`` names.

        Returns:
            ``(err, outputs)``: ``err`` is the criterion value reshaped to a
            single element; ``outputs`` maps ``'pred_masks'``,
            ``'gt_masks'``, ``'mag_mix'``, ``'mags'`` and ``'weight'`` to
            the tensors (or lists of tensors) computed here.
        """
        num_mix = args.num_mix
        mags = batch_data['mags']
        frames = batch_data['frames']
        # Tiny offset keeps the log and the ratio-mask division finite.
        mag_mix = batch_data['mag_mix'] + 1e-10
        batch_size, num_steps = mag_mix.size(0), mag_mix.size(3)

        # 0.0 optionally resample the spectrograms onto a warped frequency grid
        if args.log_freq:
            warp_grid = torch.from_numpy(
                warpgrid(batch_size, 256, num_steps, warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, warp_grid)
            for idx in range(num_mix):
                mags[idx] = F.grid_sample(mags[idx], warp_grid)

        # 0.1 per-bin loss weight derived from the (possibly warped) mixture
        if not args.weighted_loss:
            weight = torch.ones_like(mag_mix)
        else:
            weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)

        # 0.2 ground-truth masks, computed AFTER warping so they share the
        # predictions' grid
        if args.binary_mask:
            # for simplicity, a source owns a bin when mag_n > 0.5 * mag_mix
            gt_masks = [(mags[idx] > 0.5 * mag_mix).float()
                        for idx in range(num_mix)]
        else:
            # ratio masks, clamped in place to avoid very large ratios
            gt_masks = [(mags[idx] / mag_mix).clamp_(0., 5.)
                        for idx in range(num_mix)]

        # network input: log magnitude, detached from the autograd graph
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. forward net_sound -> BxCxHxW
        feat_sound = activate(self.net_sound(log_mag_mix),
                              args.sound_activation)

        # 2. forward net_frame -> Bx1xC, one feature per source
        feat_frames = [
            activate(self.net_frame.forward_multiframe(frames[idx]),
                     args.img_activation)
            for idx in range(num_mix)]

        # 3. sound synthesizer: one mask per source, conditioned on its frames
        pred_masks = [
            activate(self.net_synthesizer(feat_frame, feat_sound),
                     args.output_activation)
            for feat_frame in feat_frames]

        # 4. loss, reshaped to one element (callers take ``.mean()``;
        # presumably so DataParallel can gather per-replica losses)
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)
        outputs = {'pred_masks': pred_masks, 'gt_masks': gt_masks,
                   'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
        return err, outputs
# Calculate metrics
def calc_metrics(batch_data, outputs, args):
    """Compute source-separation metrics (SDR/SIR/SAR) for one batch.

    Processing steps:

    * When ``args.log_freq`` is set, each predicted mask is first mapped back
      to the linear-frequency STFT grid.
    * The mask is optionally binarized with ``args.mask_thres``.
    * It is multiplied with the mixture magnitude and inverted with the
      mixture phase.
    * The resulting waveforms are scored against the reference audio with
      ``bss_eval_sources``.
    * ``sdr_mix`` is a baseline: the SDR obtained when the unseparated
      mixture is used as every estimate.

    Samples in which any reference or estimate has a total absolute
    amplitude <= 1e-5 are skipped. NOTE(review): presumably because
    mir_eval rejects all-zero sources.

    Args:
        batch_data: loader batch; reads ``'mag_mix'``, ``'phase_mix'``
            (CPU tensors, converted with ``.numpy()``) and ``'audios'``.
        outputs: the dict returned by ``NetWrapper.forward``; reads
            ``'pred_masks'`` only. It is not modified.
        args: reads ``num_mix``, ``log_freq``, ``stft_frame``, ``stft_hop``,
            ``binary_mask``, ``mask_thres`` and ``device``.

    Returns:
        ``[sdr_mix, sdr, sir, sar]``: each entry is ``AverageMeter.average()``
        over the valid samples of this batch.
    """
    # meters
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']
    pred_masks_ = outputs['pred_masks']

    N = args.num_mix
    B = mag_mix.size(0)

    # Unwarp log-frequency masks to the linear STFT grid. The sampling grid
    # depends only on B and the mask shape, so build it once rather than
    # once per source.
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, pred_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_masks_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                             for n in range(N)]
    else:
        pred_masks_linear = [pred_masks_[n] for n in range(N)]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    pred_masks_linear = [m.detach().cpu().numpy() for m in pred_masks_linear]
    # threshold if binary mask
    if args.binary_mask:
        pred_masks_linear = [(m > args.mask_thres).astype(np.float32)
                             for m in pred_masks_linear]

    # loop over each sample
    for j in range(B):
        mix_wav = istft_reconstruction(
            mag_mix[j, 0], phase_mix[j, 0], hop_length=args.stft_hop)
        # masked mixture magnitude + mixture phase -> estimated waveform
        preds_wav = [
            istft_reconstruction(mag_mix[j, 0] * pred_masks_linear[n][j, 0],
                                 phase_mix[j, 0], hop_length=args.stft_hop)
            for n in range(N)]

        # references are truncated to the reconstructed length
        L = preds_wav[0].shape[0]
        gts_wav = [audios[n][j, 0:L].numpy() for n in range(N)]
        if not all(np.sum(np.abs(wav)) > 1e-5 for wav in gts_wav + preds_wav):
            continue

        sdr, sir, sar, _ = bss_eval_sources(
            np.asarray(gts_wav),
            np.asarray(preds_wav),
            False)
        # baseline: the raw mixture used as the estimate for every source
        sdr_mix, _, _, _ = bss_eval_sources(
            np.asarray(gts_wav),
            np.asarray([mix_wav[0:L]] * N),
            False)
        sdr_mix_meter.update(sdr_mix.mean())
        sdr_meter.update(sdr.mean())
        sir_meter.update(sir.mean())
        sar_meter.update(sar.mean())

    return [sdr_mix_meter.average(),
            sdr_meter.average(),
            sir_meter.average(),
            sar_meter.average()]
# Visualize predictions
def output_visuals(vis_rows, batch_data, outputs, args):
    """Write per-sample visualization files and append one HTML row each.

    For every sample ``j`` of the batch, a folder ``args.vis/<prefix>/`` is
    created. It holds:

    * the mixture audio (``mix.wav``) and its heatmap (``mix.jpg``);
    * the loss-weight map (``weight.jpg``);
    * for each source ``n`` (1-based in file names): GT/predicted masks,
      GT/predicted spectrogram heatmaps, GT/predicted waveforms, the source
      video, and the video combined with the GT audio.

    ``vis_rows`` is extended in place with one list of cell dicts per sample.

    ``outputs`` is treated as read-only. Masks are converted into local numpy
    arrays instead of overwriting the tensors held in the caller's lists.

    Args:
        vis_rows: list that receives the HTML rows (mutated).
        batch_data: loader batch; reads ``'mag_mix'``, ``'phase_mix'``,
            ``'frames'`` and ``'infos'``.
        outputs: dict returned by ``NetWrapper.forward``; reads
            ``'pred_masks'``, ``'gt_masks'``, ``'mag_mix'``, ``'weight'``.
        args: experiment options (paths, STFT, audio and video settings).
    """
    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']
    pred_mask_tensors = outputs['pred_masks']
    gt_mask_tensors = outputs['gt_masks']

    N = args.num_mix
    B = mag_mix.size(0)

    # Unwarp log-frequency masks to the linear STFT grid. The sampling grid
    # is shared by every source, so build it once.
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, gt_mask_tensors[0].size(3),
                     warp=False)).to(args.device)
        pred_masks_linear = [F.grid_sample(pred_mask_tensors[n], grid_unwarp)
                             for n in range(N)]
        gt_masks_linear = [F.grid_sample(gt_mask_tensors[n], grid_unwarp)
                           for n in range(N)]
    else:
        pred_masks_linear = [pred_mask_tensors[n] for n in range(N)]
        gt_masks_linear = [gt_mask_tensors[n] for n in range(N)]

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    # Convert into local numpy arrays. The masks on the network's own grid
    # are used for the mask images; the linear ones reconstruct audio.
    mag_mix = mag_mix.numpy()
    mag_mix_ = to_numpy(outputs['mag_mix'])
    phase_mix = phase_mix.numpy()
    weight_ = to_numpy(outputs['weight'])
    pred_masks_ = [to_numpy(pred_mask_tensors[n]) for n in range(N)]
    pred_masks_linear = [to_numpy(m) for m in pred_masks_linear]
    gt_masks_ = [to_numpy(gt_mask_tensors[n]) for n in range(N)]
    gt_masks_linear = [to_numpy(m) for m in gt_masks_linear]

    # threshold if binary mask
    if args.binary_mask:
        pred_masks_ = [(m > args.mask_thres).astype(np.float32)
                       for m in pred_masks_]
        pred_masks_linear = [(m > args.mask_thres).astype(np.float32)
                             for m in pred_masks_linear]

    # loop over each sample
    for j in range(B):
        row_elements = []

        # Folder name built from each source's ``infos[n][0][j]``
        # (assumed to be a file path -- TODO confirm against the dataset).
        prefix = '+'.join(
            '-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0]
            for n in range(N))
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(
            mag_mix[j, 0], phase_mix[j, 0], hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight_img = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the first axis (presumably frequency) for display
        imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(args.vis, filename_weight), weight_img[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate, mix_wav)
        row_elements += [{'text': prefix}, {'image': filename_mixmag, 'audio': filename_mixwav}]

        # save each component
        for n in range(N):
            # GT and predicted audio recovery (linear-frequency masks)
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(
                gt_mag, phase_mix[j, 0], hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            pred_wav = istft_reconstruction(
                pred_mag, phase_mix[j, 0], hop_length=args.stft_hop)

            # output masks (on the network's own frequency grid)
            filename_gtmask = os.path.join(prefix, 'gtmask{}.jpg'.format(n+1))
            filename_predmask = os.path.join(prefix, 'predmask{}.jpg'.format(n+1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)
            imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            imsave(os.path.join(args.vis, filename_predmask), pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n+1))
            filename_predmag = os.path.join(prefix, 'predamp{}.jpg'.format(n+1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imsave(os.path.join(args.vis, filename_gtmag), gt_mag[::-1, :, :])
            imsave(os.path.join(args.vis, filename_predmag), pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n+1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n+1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav), args.audRate, pred_wav)

            # output video
            frames_tensor = np.asarray(
                [recover_rgb(frames[n][j, :, t]) for t in range(args.num_frames)])
            path_video = os.path.join(args.vis, prefix, 'video{}.mp4'.format(n+1))
            save_video(path_video, frames_tensor, fps=args.frameRate/args.stride_frames)

            # combine gt video and audio
            filename_av = os.path.join(prefix, 'av{}.mp4'.format(n+1))
            combine_video_audio(
                path_video,
                os.path.join(args.vis, filename_gtwav),
                os.path.join(args.vis, filename_av))

            row_elements += [
                {'video': filename_av},
                {'image': filename_predmag, 'audio': filename_predwav},
                {'image': filename_gtmag, 'audio': filename_gtwav},
                {'image': filename_predmask},
                {'image': filename_gtmask}]

        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
def evaluate(netWrapper, loader, history, epoch, args):
    """Evaluate on ``loader``, record the results, and write visualizations.

    Steps:

    * Wipe and recreate ``args.vis``, and switch ``netWrapper`` to eval mode.
    * Print per-iteration and summary losses and metrics.
    * Append the epoch, mean loss and mean SDR/SIR/SAR to ``history['val']``.
    * Visualize whole batches until at least ``args.num_vis`` rows are
      collected, and write them to ``args.vis/index.html``.
    * For ``epoch > 0``, plot the loss/metric curves into ``args.ckpt``.

    Autograd is disabled only while this function runs. The caller's previous
    grad mode is restored on return (including on exceptions).
    """
    print('Evaluating at {} epochs...'.format(epoch))
    with torch.no_grad():
        # remove previous viz results
        makedirs(args.vis, remove=True)

        # switch to eval mode
        netWrapper.eval()

        # initialize meters
        loss_meter = AverageMeter()
        sdr_mix_meter = AverageMeter()
        sdr_meter = AverageMeter()
        sir_meter = AverageMeter()
        sar_meter = AverageMeter()

        # initialize HTML header: two leading columns, five per source,
        # plus the loss-weight column
        visualizer = HTMLVisualizer(os.path.join(args.vis, 'index.html'))
        header = ['Filename', 'Input Mixed Audio']
        for n in range(1, args.num_mix+1):
            header += ['Video {:d}'.format(n),
                       'Predicted Audio {:d}'.format(n),
                       'GroundTruth Audio {}'.format(n),
                       'Predicted Mask {}'.format(n),
                       'GroundTruth Mask {}'.format(n)]
        header += ['Loss weighting']
        visualizer.add_header(header)
        vis_rows = []

        for i, batch_data in enumerate(loader):
            # forward pass; reduce the gathered loss to a scalar
            err, outputs = netWrapper.forward(batch_data, args)
            err = err.mean()

            loss_meter.update(err.item())
            print('[Eval] iter {}, loss: {:.4f}'.format(i, err.item()))

            # calculate metrics
            sdr_mix, sdr, sir, sar = calc_metrics(batch_data, outputs, args)
            sdr_mix_meter.update(sdr_mix)
            sdr_meter.update(sdr)
            sir_meter.update(sir)
            sar_meter.update(sar)

            # output visualization (one row per sample of the batch)
            if len(vis_rows) < args.num_vis:
                output_visuals(vis_rows, batch_data, outputs, args)

        print('[Eval Summary] Epoch: {}, Loss: {:.4f}, '
              'SDR_mixture: {:.4f}, SDR: {:.4f}, SIR: {:.4f}, SAR: {:.4f}'
              .format(epoch, loss_meter.average(),
                      sdr_mix_meter.average(),
                      sdr_meter.average(),
                      sir_meter.average(),
                      sar_meter.average()))
        history['val']['epoch'].append(epoch)
        history['val']['err'].append(loss_meter.average())
        history['val']['sdr'].append(sdr_meter.average())
        history['val']['sir'].append(sir_meter.average())
        history['val']['sar'].append(sar_meter.average())

        print('Plotting html for visualization...')
        visualizer.add_rows(vis_rows)
        visualizer.write_html()

        # Plot figure
        if epoch > 0:
            print('Plotting figures...')
            plot_loss_metrics(args.ckpt, history)
# train one epoch
def train(netWrapper, loader, optimizer, history, epoch, args):
    """Run one training epoch over ``loader``.

    Enables autograd, puts ``netWrapper`` in train mode and performs one
    optimizer step per batch. Every ``args.disp_iter`` iterations it does
    two things:

    * prints timing, learning-rate and loss information;
    * appends the fractional epoch and the current loss to
      ``history['train']``.

    ``torch.cuda.synchronize()`` is called around each timing point so that
    queued GPU work is included in the measurements (this requires CUDA).
    """
    torch.set_grad_enabled(True)
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # switch to train mode
    netWrapper.train()

    torch.cuda.synchronize()
    tic = time.perf_counter()
    for step, batch_data in enumerate(loader):
        # time spent waiting on the data loader
        torch.cuda.synchronize()
        data_time.update(time.perf_counter() - tic)

        # forward / backward / parameter update
        netWrapper.zero_grad()
        loss, _ = netWrapper.forward(batch_data, args)
        loss = loss.mean()  # reduce the gathered loss to a scalar
        loss.backward()
        optimizer.step()

        # total time of this iteration (data + compute)
        torch.cuda.synchronize()
        batch_time.update(time.perf_counter() - tic)
        tic = time.perf_counter()

        if step % args.disp_iter != 0:
            continue

        loss_value = loss.item()
        print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
              'lr_sound: {}, lr_frame: {}, lr_synthesizer: {}, '
              'loss: {:.4f}'
              .format(epoch, step, args.epoch_iters,
                      batch_time.average(), data_time.average(),
                      args.lr_sound, args.lr_frame, args.lr_synthesizer,
                      loss_value))
        history['train']['epoch'].append(epoch - 1 + 1. * step / args.epoch_iters)
        history['train']['err'].append(loss_value)
def checkpoint(nets, history, epoch, args):
    """Save training history and network weights under ``args.ckpt``.

    Always (over)writes ``history_latest.pth`` plus ``sound_``, ``frame_``
    and ``synthesizer_latest.pth``.

    When the most recent validation error (``history['val']['err'][-1]``) is
    strictly lower than ``args.best_err``, two more things happen:

    * ``args.best_err`` is updated in place;
    * the three weight files are also written with the ``best.pth`` suffix.
    """
    print('Saving checkpoints at {} epochs.'.format(epoch))
    net_sound, net_frame, net_synthesizer = nets
    weight_targets = (
        ('{}/sound_{}', net_sound),
        ('{}/frame_{}', net_frame),
        ('{}/synthesizer_{}', net_synthesizer),
    )

    latest = 'latest.pth'
    torch.save(history, '{}/history_{}'.format(args.ckpt, latest))
    for template, net in weight_targets:
        torch.save(net.state_dict(), template.format(args.ckpt, latest))

    newest_err = history['val']['err'][-1]
    if not newest_err < args.best_err:
        return

    args.best_err = newest_err
    best = 'best.pth'
    for template, net in weight_targets:
        torch.save(net.state_dict(), template.format(args.ckpt, best))
def create_optimizer(nets, args):
    """Build an SGD optimizer with one parameter group per sub-network.

    Groups, in order:

    1. sound net -> ``lr_sound``;
    2. synthesizer -> ``lr_synthesizer``;
    3. frame ``features`` -> ``lr_frame``;
    4. frame ``fc`` -> ``lr_sound``.

    Momentum is ``args.beta1`` and weight decay is ``args.weight_decay`` for
    all groups.
    """
    net_sound, net_frame, net_synthesizer = nets
    param_groups = []
    param_groups.append({'params': net_sound.parameters(), 'lr': args.lr_sound})
    param_groups.append({'params': net_synthesizer.parameters(), 'lr': args.lr_synthesizer})
    param_groups.append({'params': net_frame.features.parameters(), 'lr': args.lr_frame})
    # NOTE(review): the frame ``fc`` head uses lr_sound, not lr_frame --
    # presumably because it is freshly initialized while ``features`` is
    # pretrained; confirm this is intended.
    param_groups.append({'params': net_frame.fc.parameters(), 'lr': args.lr_sound})
    return torch.optim.SGD(param_groups,
                           momentum=args.beta1,
                           weight_decay=args.weight_decay)
def adjust_learning_rate(optimizer, args):
    """Decay every learning rate by a factor of 10, in place.

    Both the bookkeeping values on ``args`` (``lr_sound``, ``lr_frame``,
    ``lr_synthesizer``, which ``train`` prints) and the live ``'lr'`` of
    every optimizer param group are scaled.
    """
    decay = 0.1
    for name in ('lr_sound', 'lr_frame', 'lr_synthesizer'):
        setattr(args, name, getattr(args, name) * decay)
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * decay
def main(args):
    """Build networks, data and optimizer from ``args``, then evaluate/train.

    An initial evaluation at epoch 0 always runs; in ``'eval'`` mode the
    function stops after it.

    Otherwise it trains for ``args.num_epoch`` epochs:

    * every ``args.eval_epoch`` epochs it evaluates and checkpoints;
    * after each epoch listed in ``args.lr_steps`` it decays the learning
      rates.
    """
    # Network builders (built in the order sound, frame, synthesizer)
    builder = ModelBuilder()
    nets = (
        builder.build_sound(
            arch=args.arch_sound,
            fc_dim=args.num_channels,
            weights=args.weights_sound),
        builder.build_frame(
            arch=args.arch_frame,
            fc_dim=args.num_channels,
            pool_type=args.img_pool,
            weights=args.weights_frame),
        builder.build_synthesizer(
            arch=args.arch_synthesizer,
            fc_dim=args.num_channels,
            weights=args.weights_synthesizer),
    )
    crit = builder.build_criterion(arch=args.loss)

    # Datasets and loaders; training drops the incomplete last batch
    dataset_train = MUSICMixDataset(args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')
    loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size, shuffle=True,
        num_workers=int(args.workers), drop_last=True)
    loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=args.batch_size, shuffle=False,
        num_workers=2, drop_last=False)
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap networks + loss for multi-GPU execution
    netWrapper = torch.nn.DataParallel(
        NetWrapper(nets, crit), device_ids=range(args.num_gpus))
    netWrapper.to(args.device)

    optimizer = create_optimizer(nets, args)

    # History of performance
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {key: [] for key in ('epoch', 'err', 'sdr', 'sir', 'sar')},
    }

    # Baseline evaluation before any training step
    evaluate(netWrapper, loader_val, history, 0, args)
    if args.mode == 'eval':
        print('Evaluation Done!')
        return

    for epoch in range(1, args.num_epoch + 1):
        train(netWrapper, loader_train, optimizer, history, epoch, args)

        # evaluation, visualization and checkpointing share one cadence
        if epoch % args.eval_epoch == 0:
            evaluate(netWrapper, loader_val, history, epoch, args)
            checkpoint(nets, history, epoch, args)

        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

    print('Training Done!')
if __name__ == '__main__':
    # arguments
    parser = ArgParser()
    args = parser.parse_train_arguments()
    args.batch_size = args.num_gpus * args.batch_size_per_gpu
    args.device = torch.device("cuda")

    # experiment name: in train mode, encode the main hyper-parameters into
    # args.id so each configuration gets its own checkpoint directory
    if args.mode == 'train':
        args.id += '-{}mix'.format(args.num_mix)
        if args.log_freq:
            args.id += '-LogFreq'
        args.id += '-{}-{}-{}'.format(
            args.arch_frame, args.arch_sound, args.arch_synthesizer)
        args.id += '-frames{}stride{}'.format(args.num_frames, args.stride_frames)
        args.id += '-{}'.format(args.img_pool)
        if args.binary_mask:
            # Explicit check (not ``assert``) so it is not stripped under -O.
            if args.loss != 'bce':
                raise ValueError('Binary Mask should go with BCE loss')
            args.id += '-binary'
        else:
            args.id += '-ratio'
        if args.weighted_loss:
            args.id += '-weightedLoss'
        args.id += '-channels{}'.format(args.num_channels)
        args.id += '-epoch{}'.format(args.num_epoch)
        args.id += '-step' + '_'.join(str(x) for x in args.lr_steps)

    print('Model ID: {}'.format(args.id))

    # paths to save/load output
    args.ckpt = os.path.join(args.ckpt, args.id)
    args.vis = os.path.join(args.ckpt, 'visualization/')
    if args.mode == 'train':
        makedirs(args.ckpt, remove=True)
    elif args.mode == 'eval':
        # evaluation loads the best checkpoint of this experiment
        args.weights_sound = os.path.join(args.ckpt, 'sound_best.pth')
        args.weights_frame = os.path.join(args.ckpt, 'frame_best.pth')
        args.weights_synthesizer = os.path.join(args.ckpt, 'synthesizer_best.pth')

    # initialize best error with a big number
    args.best_err = float("inf")

    # Seed every RNG module this file imports, so that any np.random use
    # is reproducible too. NumPy accepts only seeds in [0, 2**32), hence
    # the modulo.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed % 2**32)
    main(args)