-
Notifications
You must be signed in to change notification settings - Fork 118
/
base.py
26 lines (23 loc) · 1.02 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from nvidia.dali.plugin.pytorch import DALIGenericIterator
class DALIDataloader(DALIGenericIterator):
    """Wrapper around ``DALIGenericIterator`` that yields ``[data, label]`` lists.

    Each call to ``__next__`` takes element 0 of the batch produced by the
    base iterator and returns the two entries named by ``output_map[0]`` and
    ``output_map[1]`` as a two-element list. ``len()`` reports the number of
    batches needed to cover ``size`` samples.

    Args:
        pipeline: Forwarded to the base class as ``pipelines``.
        size: Number of samples; stored on the instance and forwarded to the
            base class.
        batch_size: Samples per batch; used only by ``__len__`` in this class.
        output_map: Names of the pipeline outputs. Index 0 is treated as the
            data entry and index 1 as the label entry. Defaults to
            ``["data", "label"]`` (a fresh list per instance).
        auto_reset: Forwarded to the base class.
        onehot_label: When truthy, the label entry is returned as
            ``label.squeeze().long()`` instead of unchanged.
    """

    def __init__(self, pipeline, size, batch_size, output_map=None,
                 auto_reset=True, onehot_label=False):
        # A list literal as the default would be created once and shared by
        # every instance (and it is stored on ``self`` below), so build a new
        # one per call. Explicitly supplied values are passed through as-is.
        if output_map is None:
            output_map = ["data", "label"]

        # These attributes are assigned before the base constructor runs,
        # matching the original ordering.
        self.size = size
        self.batch_size = batch_size
        self.onehot_label = onehot_label
        self.output_map = output_map
        super().__init__(pipelines=pipeline, size=size, auto_reset=auto_reset,
                         output_map=output_map)

    def __next__(self):
        """Return the next batch as ``[data, label]``.

        Raises whatever the base iterator raises when it is exhausted.
        """
        # ``_first_batch`` is not assigned anywhere in this class; presumably
        # the base iterator pre-fetches one batch during construction.
        # NOTE(review): the cached batch is returned without the
        # ``[data, label]`` conversion applied below. Whether it is already in
        # that form depends on how the installed DALI version populates
        # ``_first_batch`` -- confirm against the DALI release in use.
        if self._first_batch is not None:
            cached = self._first_batch
            self._first_batch = None
            return cached

        # Only element 0 of the base iterator's result is consumed here.
        outputs = super().__next__()[0]
        data_key = self.output_map[0]
        label_key = self.output_map[1]

        data = outputs[data_key]
        label = outputs[label_key]
        if self.onehot_label:
            label = label.squeeze().long()
        return [data, label]

    def __len__(self):
        """Return the number of batches: ``size / batch_size`` rounded up."""
        full_batches, leftover = divmod(self.size, self.batch_size)
        # A trailing partial batch still counts as one batch.
        return full_batches + 1 if leftover else full_batches