# This preprocessing is based on code written by the TAs of the 'Deep Learning for Music Analysis and Generation' course.
import numpy as np
import torch
import torchaudio
import librosa
import os
num_mels = 80
n_fft = 1024
hop_size = 256
win_size = 1024
sampling_rate = 22050
fmin = 0
fmax = 8000
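
# With these settings each frame advances hop_size / sampling_rate ≈ 11.6 ms,
# i.e. roughly sampling_rate / hop_size ≈ 86 mel frames per second of audio,
# and every frame holds num_mels = 80 mel bins between fmin and fmax.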

def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    # Compute a power mel spectrogram with torchaudio. Padding (n_fft - hop_size) / 2
    # samples on each side with center=False keeps the frame count at roughly
    # num_samples / hop_size.
    device = y.device
    melTorch = torchaudio.transforms.MelSpectrogram(
        sample_rate=sampling_rate,
        n_fft=n_fft,
        n_mels=num_mels,
        hop_length=hop_size,
        win_length=win_size,
        f_min=fmin,
        f_max=fmax,
        pad=int((n_fft - hop_size) / 2),
        center=center
    ).to(device)
    spec = melTorch(y)
    return spec
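
# Usage sketch (dummy input, not part of the original script): given a mono batch
# of shape [1, num_samples], e.g.
#   y = torch.randn(1, sampling_rate)  # 1 second of dummy audio
#   mel = mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
# the result has shape [1, num_mels, num_frames] -- here [1, 80, 86].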

def to_mono(audio, dim=-2):
    # Average multi-channel audio down to a single channel, keeping the channel dimension.
    if len(audio.size()) > 1:
        return torch.mean(audio, dim=dim, keepdim=True)
    else:
        return audio

def load_audio(audio_path, sr=None, mono=True):
    # Load an audio file, optionally down-mix to mono and resample to `sr`.
    if 'mp3' in audio_path:
        torchaudio.set_audio_backend('sox_io')
    audio, org_sr = torchaudio.load(audio_path)
    audio = to_mono(audio) if mono else audio
    if sr and org_sr != sr:
        audio = torchaudio.transforms.Resample(org_sr, sr)(audio)
    return audio
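
# Usage sketch (hypothetical path): load_audio('dataset/valid/example.wav', sr=sampling_rate)
# returns a mono tensor of shape [1, num_samples], resampled to 22050 Hz if needed.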

if __name__ == '__main__':
    load_audio_path = 'dataset/valid'
    save_npy_path = 'dataset/valid_mel'
    if not os.path.exists(save_npy_path):
        os.mkdir(save_npy_path)
    audio_list = os.listdir(load_audio_path)
    audio_list.sort()
    for audio in audio_list:
        y = load_audio(os.path.join(load_audio_path, audio), sr=sampling_rate)
        mel_tensor = mel_spectrogram(
            y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
        mel = mel_tensor.squeeze().cpu().numpy()
        # assumes a three-character extension such as .wav or .mp3
        file_name = os.path.join(save_npy_path, audio[:-4] + '.npy')
        np.save(file_name, mel)
        mel = np.load(file_name)  # check that the saved .npy is readable

    # Plot the last mel spectrogram.
    # ref: https://librosa.org/doc/main/generated/librosa.feature.melspectrogram.html
    import matplotlib.pyplot as plt
    import librosa.display

    fig, ax = plt.subplots()
    # Don't forget the dB conversion before plotting.
    S_dB = librosa.power_to_db(mel, ref=np.max)
    img = librosa.display.specshow(S_dB, x_axis='time',
                                   y_axis='mel', sr=sampling_rate,
                                   fmax=fmax, ax=ax, hop_length=hop_size, n_fft=n_fft)
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set(title='Mel-frequency spectrogram')
    plt.show()
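
# Sketch of how a saved mel file could be inspected later (file name is hypothetical):
#   mel = np.load('dataset/valid_mel/example.npy')
#   print(mel.shape)  # (num_mels, num_frames), i.e. (80, ~86 frames per second of audio)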