train.lua
require 'torch'
-- optim (ConfusionMatrix, entropysgd) and cutorch (:cuda(), synchronize) are used
-- directly below; they may already be pulled in by utils/models, but requiring
-- them here keeps this file self-contained
require 'optim'
require 'cutorch'

local lapp = require 'pl.lapp'
opt = lapp[[
-m,--model          (default 'mnistconv')   Model to train: mnistfc | mnistconv | cifarconv
-b,--batch_size     (default 128)           Batch size
--LR                (default 0)             Learning rate (0 = use the built-in schedule)
--dropout           (default 0.5)           Dropout probability
--L                 (default 0)             Num. Langevin iterations
--gamma             (default 1e-4)          Langevin gamma coefficient
--scoping           (default 1e-3)          Scoping parameter \gamma*(1+scoping)^t
--noise             (default 1e-4)          Langevin dynamics additive noise factor (*stepSize)
-g,--gpu            (default 1)             GPU id
--L2                (default 0)             L2 regularization
-s,--seed           (default 42)            Random seed
-e,--max_epochs     (default 10)            Max. number of training epochs
-v,--verbose                                Print gradient statistics
-h,--help                                   Print this message
]]
print(opt)
if opt.help then
    os.exit()
end

local utils = require 'utils.lua'
local models = require 'models'
require 'entropyoptim'
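
-- Example invocation (a sketch; the flag values below are illustrative, only the
-- flags themselves come from the lapp spec above):
--   th train.lua -m mnistconv -b 128 --L 20 --gamma 1e-4 --scoping 1e-3 --dropout 0.5
-- With --LR 0 (the default) the learning rate follows the built-in schedule in
-- learning_rate_schedule() below.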

-- One epoch of training: samples random mini-batches with replacement and takes
-- one entropy-SGD step per batch.
function trainer(d)
    local x, y = d.data, d.labels
    local w, dw = model:getParameters()
    model:training()

    local bs = opt.batch_size
    local num_batches = math.floor(x:size(1)/bs)
    local timer = torch.Timer()
    local timer1 = torch.Timer()

    local loss = 0
    confusion:zero()
    for b = 1, num_batches do
        collectgarbage()
        timer1:reset()

        -- closure handed to the optimizer; when dry == true the forward/backward
        -- pass is run but the running loss and confusion matrix are not updated
        local feval = function(_w, dry)
            local dry = dry or false
            if _w ~= w then w:copy(_w) end
            dw:zero()

            -- sample a mini-batch with replacement
            local idx = torch.Tensor(bs):random(1, d.size):type('torch.LongTensor')
            local xc, yc = x:index(1, idx):cuda(), y:index(1, idx):cuda()

            local yh = model:forward(xc)
            local f = cost:forward(yh, yc)
            local dfdy = cost:backward(yh, yc)
            model:backward(xc, dfdy)
            cutorch.synchronize()

            if dry == false then
                loss = loss + f
                confusion:batchAdd(yh, yc)
                confusion:updateValids()
            end
            return f, dw
        end

        optim.entropysgd(feval, w, optim_state)

        if b % 100 == 0 then
            print(('+[%3d][%3d/%3d] %.5f %.3f%%'):format(epoch,
                    b, num_batches, loss/b, (1 - confusion.totalValid)*100))
        end
    end

    loss = loss/num_batches
    print(('Train: [%3d] %.5f %.3f%% [%.2fs]'):format(epoch, loss,
            (1 - confusion.totalValid)*100, timer:time().real))
    print('')
end

-- Set the dropout probability of every nn.Dropout module in the model to p.
function set_dropout(p)
    local p = p or 0
    for i, m in ipairs(model.modules) do
        if m.module_name == 'nn.Dropout' or torch.typename(m) == 'nn.Dropout' then
            m.p = p
        end
    end

    -- the cifarconv model uses a fixed 0.2 input dropout; restore it whenever
    -- dropout is being turned back on
    if opt.model == 'cifarconv' then
        if p > 0 then
            local m = model.modules[1]
            assert(m.module_name == 'nn.Dropout' or torch.typename(m) == 'nn.Dropout')
            m.p = 0.2
        end
    end
end

-- This is a weird hack: the batch-normalization parameters do not settle well in
-- the presence of dropout, so this function sets the dropout to zero, dry-feeds
-- the dataset through the network to let the batch-normalization statistics
-- settle, and then restores the dropout to its old value.
function compute_bn_params(d)
    set_dropout(0)

    local x, y = d.data, d.labels
    local w, dw = model:getParameters()
    model:training()

    local bs = 1024
    local num_batches = math.ceil(x:size(1)/bs)
    for b = 1, num_batches do
        collectgarbage()
        -- forward pass only; no loss and no backward pass are needed here
        local feval = function(_w)
            if _w ~= w then w:copy(_w) end
            dw:zero()

            local sidx, eidx = (b-1)*bs, math.min(b*bs, x:size(1))
            local xc, yc = x:narrow(1, sidx + 1, eidx - sidx):cuda(), y:narrow(1, sidx + 1, eidx - sidx):cuda()
            local yh = model:forward(xc)
            cutorch.synchronize()
            return 0, dw   -- no loss is computed during the dry pass
        end
        feval(w)
    end

    set_dropout(opt.dropout)
end

-- Evaluate the model on dataset d (validation or test set).
function tester(d)
    compute_bn_params(d)

    local x, y = d.data, d.labels
    model:evaluate()

    local bs = 1024
    local num_batches = math.ceil(x:size(1)/bs)
    local loss = 0
    confusion:zero()
    for b = 1, num_batches do
        collectgarbage()

        local sidx, eidx = (b-1)*bs, math.min(b*bs, x:size(1))
        local xc, yc = x:narrow(1, sidx + 1, eidx - sidx):cuda(),
                       y:narrow(1, sidx + 1, eidx - sidx):cuda()
        local yh = model:forward(xc)
        local f = cost:forward(yh, yc)
        cutorch.synchronize()

        confusion:batchAdd(yh, yc)
        confusion:updateValids()
        loss = loss + f

        if b % 100 == 0 then
            print(('*[%2d][%3d/%3d] %.5f %.3f%%'):format(epoch, b, num_batches,
                    loss/b, (1 - confusion.totalValid)*100))
        end
    end

    loss = loss/num_batches
    print(('Test: [%2d] %.5f %.3f%%'):format(epoch, loss, (1 - confusion.totalValid)*100))
    print('')
end

-- Learning rate schedule: if --LR > 0 use it directly, otherwise pick the rate
-- from a per-model regime table of rows {start_epoch, end_epoch, lr}; the tables
-- (and max_epochs) differ depending on whether Langevin iterations are used.
function learning_rate_schedule()
    local lr = opt.LR
    if opt.LR > 0 then
        print(('[LR] %.5f'):format(lr))
        return lr
    end

    local regimes = {}
    if opt.L == 0 then
        -- no Langevin iterations (L == 0): long schedules
        if opt.model == 'mnistfc' or opt.model == 'mnistconv' then
            regimes = {
                {1,  30,  0.1},
                {30, 60,  0.1*0.2},
                {60, 150, 0.1*0.2^2}}
            opt.max_epochs = 100
        elseif opt.model == 'cifarconv' then
            regimes = {
                {1,   60,  0.1},
                {60,  120, 0.1*0.2^1},
                {120, 180, 0.1*0.2^2},
                {180, 250, 0.1*0.2^3}}
            opt.max_epochs = 200
        end
    else
        -- with Langevin iterations (L > 0): short schedules
        if opt.model == 'mnistfc' then
            regimes = {
                {1, 2,  1},
                {3, 15, 0.1}}
            opt.max_epochs = 5
        elseif opt.model == 'mnistconv' then
            regimes = {
                {1, 2,  1},
                {3, 7,  0.1},
                {8, 15, 0.01}}
            opt.max_epochs = 5
        elseif opt.model == 'cifarconv' then
            regimes = {
                {1, 3,  1},
                {4, 6,  0.2},
                {7, 12, 0.04}}
            opt.max_epochs = 10
        end
    end

    for _, row in ipairs(regimes) do
        if epoch >= row[1] and epoch <= row[2] then
            lr = row[3]
            break
        end
    end
    print(('[LR] %.5f'):format(lr))
    return lr
end
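
-- Worked example of the lookup above (values taken from the regime tables): with
-- opt.LR == 0, opt.L == 0 and opt.model == 'mnistconv', epoch 45 falls in the row
-- {30, 60, 0.1*0.2}, so the returned learning rate is 0.02.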

-- Build the model, set up the optimizer state and run the train/validation loop.
function main()
    utils.set_gpu()
    model, cost = models.build()
    local train, val, test = utils.load_dataset()

    local classes = torch.totable(torch.range(1, 10))
    confusion = optim.ConfusionMatrix(classes)

    -- note: opt.rho is not defined in the lapp options above, so it is nil here;
    -- entropyoptim is presumably expected to fall back to its own default
    optim_state = optim_state or { learningRate = opt.LR,
                                   learningRateDecay = 0,
                                   weightDecay = opt.L2,
                                   momentum = 0.9,
                                   nesterov = true,
                                   dampening = 0,
                                   rho = opt.rho,
                                   gamma = opt.gamma,
                                   scoping = opt.scoping,
                                   L = opt.L,
                                   noise = opt.noise}

    -- validate every epoch when Langevin iterations are enabled, otherwise every 5 epochs
    local freq = 5
    if opt.L > 0 then freq = 1 end

    epoch = epoch or 1
    while epoch <= opt.max_epochs do
        optim_state.learningRate = learning_rate_schedule()
        trainer(train)

        if epoch % freq == 0 then
            tester(val)
        end
        epoch = epoch + 1
        print('')
    end

    print('Finished')
    tester(test)
end

main()