-
Notifications
You must be signed in to change notification settings - Fork 40
/
simple_loader.py
76 lines (57 loc) · 1.85 KB
/
simple_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# coding: utf-8
"""Data-loading utilities: a thin Dataset wrapper and a DataLoader factory."""
import logging
import os
import os.path as osp
import random
import string
import time

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Fix the global RNG seeds so shuffling/sampling is reproducible across runs.
np.random.seed(1)
random.seed(1)
class FilePathDataset(torch.utils.data.Dataset):
    """Minimal map-style Dataset that indexes directly into the given collection."""

    def __init__(self, df):
        # Store the backing collection as-is; no copying or preprocessing.
        self.data = df

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class Collater(object):
    """Collate a batch of records into one flat list of token ids.

    Each batch element is expected to be a mapping with an 'input_ids'
    sequence; the sequences are concatenated in batch order.

    Args:
        return_wave (bool): kept for interface compatibility; not used in
            ``__call__`` as written here.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0  # padding index for text tokens (not used in __call__)
        self.return_wave = return_wave

    def __call__(self, batch):
        # Flatten every element's 'input_ids' into a single list,
        # preserving batch order.
        return [tok for item in batch for tok in item['input_ids']]
def build_dataloader(df,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a DataLoader over *df*.

    Args:
        df: indexable collection handed to ``FilePathDataset``.
        validation: when True, disable shuffling and keep the final
            partial batch; when False (training), shuffle and drop it.
        batch_size: number of samples per batch.
        num_workers: worker processes used by the DataLoader.
        device: host memory is pinned whenever this is not ``'cpu'``,
            speeding up host-to-device transfers.
        collate_config: extra keyword arguments for ``Collater``.
        dataset_config: extra keyword arguments for ``FilePathDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding collated batches.
    """
    # None sentinels instead of mutable {} defaults, which would be
    # shared across calls.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = FilePathDataset(df, **dataset_config)
    collate_fn = Collater(**collate_config)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=(not validation),
                      num_workers=num_workers,
                      drop_last=(not validation),
                      collate_fn=collate_fn,
                      pin_memory=(device != 'cpu'))