-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdataloader.py
144 lines (122 loc) · 5.15 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from loguru import logger
from boltons import fileutils
import os
import os.path
from collections import defaultdict
import numpy as np
import torch
import torch.utils.data as data
import random
import soundfile
from hparams import *
from stft.stft import STFT
from add_noise import inject_noise_sample
from tqdm import tqdm, trange
def spect_loader(path, trim_start, return_phase=False, num_samples=16000, crop=True):
    """Read an audio file and run it through the project's STFT.

    Args:
        path: audio file to read with ``soundfile.read``.
        trim_start: index of the first sample to keep when cropping.
        return_phase: when true, return ``(spect, phase)``; otherwise only
            ``spect``.
        num_samples: length of the cropped window, in samples.
        crop: when true, keep ``num_samples`` samples starting at
            ``trim_start`` and zero-pad the tail if the file is too short.

    Returns:
        ``spect``, or ``(spect, phase)`` when ``return_phase`` is set, as
        produced by ``STFT(N_FFT, HOP_LENGTH).transform``.
    """
    samples, _sample_rate = soundfile.read(path)

    if crop:
        # Fixed-length window: slice first, then right-pad with zeros so the
        # result always holds exactly `num_samples` samples.
        window = samples[trim_start: trim_start + num_samples]
        padding = np.zeros(num_samples - len(window))
        samples = np.hstack((window, padding))

    # unsqueeze(0) adds a leading dimension before handing off to the STFT.
    batch = torch.FloatTensor(samples).unsqueeze(0)
    spect, phase = STFT(N_FFT, HOP_LENGTH).transform(batch)

    return (spect, phase) if return_phase else spect
class BaseLoader(data.Dataset):
    """Dataset of (carrier, hidden messages) audio pairs served as spectrograms.

    The list of file pairs is built once, at construction time, by
    ``make_pairs_dataset``, which every concrete subclass must implement.
    Each item is loaded lazily in ``__getitem__`` through ``self.loader``.

    Args:
        root: directory handed to ``make_pairs_dataset``.
        n_messages: number of hidden-message files per carrier.
        n_pairs: number of (carrier, messages) pairs to generate.
        transform: optional callable applied to the carrier spectrogram, the
            carrier phase and every message spectrogram.
        trim_start: first sample kept by the loader (coerced to ``int``).
        num_samples: number of samples the loader keeps per file.
        test: when true, ``__getitem__`` also returns the message phases.
    """

    def __init__(self, root,
                 n_messages=1,
                 n_pairs=100000,
                 transform=None,
                 trim_start=0,
                 num_samples=16000,
                 test=False):
        # Re-seed the global RNG so the sampled pairs are reproducible.
        random.seed(0)
        self.spect_pairs = self.make_pairs_dataset(root, n_messages, n_pairs)
        self.root = root
        self.transform = transform
        self.loader = spect_loader
        self.trim_start = int(trim_start)
        self.num_samples = num_samples
        self.test = test

    def make_pairs_dataset(self, path, n_hidden_messages, n_pairs):
        """Build the list of ``(carrier_file, [message_files])`` pairs.

        Subclasses must override this hook; the base class has no pairing
        policy of its own.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement make_pairs_dataset()"
        )

    def __getitem__(self, index):
        """Load the pair at ``index``.

        Returns:
            ``(carrier_spect, carrier_phase, msg_spects)``, with
            ``msg_phases`` appended as a fourth element when ``self.test``
            is true. ``msg_spects`` and ``msg_phases`` are lists with one
            entry per message file.
        """
        carrier_file, msg_files = self.spect_pairs[index]
        carrier_spect, carrier_phase = self.loader(carrier_file,
                                                   self.trim_start,
                                                   return_phase=True,
                                                   num_samples=self.num_samples)
        loaded_msgs = [self.loader(msg_file,
                                   self.trim_start,
                                   return_phase=True,
                                   num_samples=self.num_samples)
                       for msg_file in msg_files]
        msg_spects = [loaded[0] for loaded in loaded_msgs]
        msg_phases = [loaded[1] for loaded in loaded_msgs]

        if self.transform is not None:
            carrier_spect = self.transform(carrier_spect)
            carrier_phase = self.transform(carrier_phase)
            msg_spects = [self.transform(msg_spect) for msg_spect in msg_spects]
            # NOTE(review): message phases are returned untransformed, unlike
            # the carrier phase -- confirm this asymmetry is intended.

        if self.test:
            return carrier_spect, carrier_phase, msg_spects, msg_phases
        return carrier_spect, carrier_phase, msg_spects

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.spect_pairs)
class YohoLoader(BaseLoader):
    """Loader whose carrier and hidden messages always share one speaker.

    The speaker ID is read from the path: it is the integer name of the
    directory two levels above each ``.wav`` file.
    """

    # A file is used only if it holds strictly more samples than this.
    MIN_FILE_SAMPLES = 3 * 8000

    def __init__(self, root,
                 n_messages=1,
                 n_pairs=100000,
                 transform=None,
                 trim_start=0,
                 num_samples=8000,
                 test=False):
        super().__init__(root,
                         n_messages,
                         n_pairs,
                         transform,
                         trim_start,
                         num_samples,
                         test)

    def make_pairs_dataset(self, path, n_hidden_messages, n_pairs):
        """Sample ``n_pairs`` same-speaker (carrier, messages) file pairs.

        Args:
            path: directory searched recursively for ``*.wav`` files.
            n_hidden_messages: number of message files per carrier.
            n_pairs: number of pairs to generate.

        Returns:
            A list of ``(carrier_file, [message_files])`` tuples; all files
            within a pair are distinct and belong to the same speaker.

        Raises:
            ValueError: if pairs are requested but no speaker has at least
                ``1 + n_hidden_messages`` usable files, or if a path's
                speaker directory name is not an integer.
        """
        group_size = 1 + n_hidden_messages

        files_by_speaker = defaultdict(list)
        for wav in fileutils.iter_find_files(path, "*.wav"):
            try:
                n_frames = soundfile.read(wav)[0].shape[0]
            except Exception:
                # Best effort: unreadable or corrupt files are skipped, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                continue
            if n_frames <= self.MIN_FILE_SAMPLES:
                continue  # filter out short files
            # Normalise before splitting so the speaker directory is found
            # regardless of the platform's path separator.
            speaker = int(os.path.normpath(wav).split(os.sep)[-3])
            files_by_speaker[speaker].append(wav)

        # random.sample needs a real sequence (dict views are rejected on
        # Python 3.11+). Speakers with too few files are dropped up front,
        # since sampling from them would fail mid-loop.
        speakers = [speaker for speaker, files in files_by_speaker.items()
                    if len(files) >= group_size]
        if n_pairs > 0 and not speakers:
            raise ValueError(
                f"no speaker under {path!r} has at least {group_size} usable wav files"
            )

        pairs = []
        for _ in range(n_pairs):
            speaker = random.sample(speakers, 1)[0]
            carrier_file, *hidden_message_files = random.sample(
                files_by_speaker[speaker], group_size)
            pairs.append((carrier_file, hidden_message_files))
        return pairs
class TimitLoader(BaseLoader):
    """Loader that pairs files drawn uniformly from every ``.wav`` under the root."""

    def __init__(self, root,
                 n_messages=1,
                 n_pairs=100000,
                 transform=None,
                 trim_start=0,
                 num_samples=16000,
                 test=False):
        super().__init__(root, n_messages, n_pairs, transform,
                         trim_start, num_samples, test)

    def make_pairs_dataset(self, path, n_hidden_messages, n_pairs):
        """Sample ``n_pairs`` (carrier, messages) file pairs.

        Each pair is one draw, without replacement, of
        ``1 + n_hidden_messages`` files from all ``*.wav`` files found under
        ``path``: the first is the carrier, the rest are the hidden messages.

        Returns:
            A list of ``(carrier_file, [message_files])`` tuples.
        """
        candidates = list(fileutils.iter_find_files(path, "*.wav"))
        group_size = 1 + n_hidden_messages

        result = []
        for _ in range(n_pairs):
            carrier, *hidden = random.sample(candidates, group_size)
            result.append((carrier, hidden))
        return result