-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcreateBatches.m
50 lines (39 loc) · 1.52 KB
/
createBatches.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
function batches = createBatches(data, labels, batchSize)
%% Split a dataset row-wise into minibatches and move each one to the GPU.
%%
%% Usage:
%%   batches = createBatches(data, labels, batchSize)
%%   batches = createBatches(data, batchSize)   % no labels
%%
%% Arguments:
%%   data      - matrix with one sample per row.
%%   labels    - matrix with one row of labels per sample (same number of
%%               rows as data). Omit it for unlabeled data.
%%   batchSize - positive integer, number of samples per minibatch.
%%
%% Returns:
%%   batches - 1 x nBatches structure array, nBatches = ceil(nSamples/batchSize).
%%             Each element has a field data (the rows of that minibatch, as a
%%             gpuArray) and, when labels were given, a field labels (the
%%             matching label rows, as a gpuArray). If the number of samples is
%%             not a multiple of batchSize, the last minibatch holds only the
%%             remaining samples and is therefore smaller than the others.
%%             With zero samples the result is an empty (1 x 0) structure array.
%%
%% Errors:
%%   'Wrong number of arguments' when called with fewer than 2 or more than 3
%%   arguments; a descriptive error when batchSize is not a positive integer
%%   or when data and labels disagree on the number of rows.

% Validate the argument count before touching any optional argument, so a
% one-argument call reports the intended message instead of failing on the
% undefined variable labels.
if nargin < 2 || nargin > 3
  error('Wrong number of arguments');
end

% In the two-argument form the second positional argument is the batch size.
hasLabels = (nargin == 3);
if ~hasLabels
  batchSize = labels;
end

if ~isscalar(batchSize) || ~isfinite(batchSize) || batchSize < 1 || batchSize ~= fix(batchSize)
  error('batchSize must be a positive integer scalar');
end

nSamples = size(data, 1);
if hasLabels && size(labels, 1) ~= nSamples
  error('data and labels must have the same number of rows');
end

nBatches = ceil(nSamples / batchSize);

% Preallocate the structure array. The fields start empty because every
% element is overwritten below; with nBatches == 0 this yields a 1 x 0 array
% that still carries the expected field names.
if hasLabels
  template = struct('data', [], 'labels', []);
else
  template = struct('data', []);
end
batches = repmat(template, 1, nBatches);

for i = 1:nBatches
  % Rows of this minibatch; the upper bound is clamped so that the last
  % batch takes whatever samples remain.
  rows = ((i - 1) * batchSize + 1) : min(i * batchSize, nSamples);
  batches(i).data = gpuArray(data(rows, :));
  if hasLabels
    batches(i).labels = gpuArray(labels(rows, :));
  end
end
end