-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_functions.lua
155 lines (115 loc) · 4.18 KB
/
train_functions.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
-- Emitted once, when this file is loaded, to mark the start of loading.
log 'Loading Train Functions ... '
--- Log one row of per-index accuracies, one `log` call per entry.
-- Entries 1..config.rho-1 end with a single space; the last entry
-- (index `config.rho`) ends with " \n" so the row is terminated.
-- @tparam string label prefix such as 'Train Accuracy -- Global'
-- @param values indexable by 1..config.rho (table or 1-D tensor — assumed)
local function log_accuracy_row(label, values)
  for i = 1, config.rho do
    local suffix = (i == config.rho) and ' \n' or ' '
    log(('%s [%d] = %f%s'):format(label, i, values[i], suffix))
  end
end

--- Main training loop.
-- Reads the file lists named by `config.train.datafile` and
-- `config.test.datafile`, then runs `config.nIter` iterations. Each iteration
-- loads a batch, sets `model.learningRate`, and trains on the batch.
-- It also does the following on a schedule:
--   * every `config.nDisplay` iterations: log train loss and accuracies;
--   * every `config.nEval` iterations: evaluate one test batch and log it;
--   * every `config.saveModelIter` iterations: save the model under
--     `config.logDirectory` and flush `config.imgFilenamesLog`.
--
-- NOTE(review): `input_cfg`, `input_cfg_test`, `dataset` and `dataset_test`
-- stay globals, as in the original, because code outside this file may read
-- them. TODO: confirm this, then make them local.
function train()
  input_cfg = config.train
  input_cfg_test = config.test
  dataset = lines_from(input_cfg.datafile)
  dataset_test = lines_from(input_cfg_test.datafile)

  for iter = 1, config.nIter do
    -- Load one batch. os.clock() measures CPU time, not wall-clock time.
    local tic = os.clock()
    local TrInput, TrTarget = GetAnImageBatch(input_cfg, dataset)
    log('loading time :' .. tostring(os.clock() - tic))

    -- Train the network on this batch.
    model.learningRate = model:LearningRateComp(iter)
    local acc, per_class, loss = model:TrainOneBatch(TrInput, TrTarget)

    -- Periodic explicit (and timed) garbage collection.
    if iter % 10 == 0 then
      local gc_tic = os.clock()
      collectgarbage()
      print("garbage collection :", os.clock() - gc_tic)
    end

    if iter % config.nDisplay == 0 then
      log(('Iter = %d | Train Loss = %f\n'):format(iter, loss))
      log_accuracy_row('Train Accuracy -- Global', acc)
      log_accuracy_row('Train Accuracy -- Per_Class', per_class)
    end

    if iter % config.nEval == 0 then
      local TeInput, TeTarget = GetAUniformImageBatch(input_cfg_test, dataset_test)
      local te_acc, te_per_class, te_loss = model:EvaluateOneBatch(TeInput, TeTarget)
      log(('Testing ---------> Iter = %d | Test Loss = %f\n'):format(iter, te_loss))
      log_accuracy_row('Test Accuracy -- Global', te_acc)
      log_accuracy_row('Test Accuracy -- Per_Class', te_per_class)
    end

    if iter % config.saveModelIter == 0 then
      local model_path = paths.concat(config.logDirectory, 'Model_iter_' .. iter .. '.t7')
      log('Saving NN model in ----> ' .. model_path .. '\n')
      model:SaveModel(model_path)
      config.imgFilenamesLog:flush()
    end
  end
end
---------------------------------------------------------
--- Run the model over `config.nIter` batches of the test set and save the
-- predictions and targets with `mattorch.save` to `config.test.outfile`.
--
-- The saved variables are `preds` and `targets`. Each is a tensor with one row
-- per sample (summed over all batches) and `config.rho` columns; column i holds
-- the i-th per-batch prediction/target, concatenated along dimension 1.
-- Assumes predicts[i] and TeTarget[i] are tensors whose first dimension is the
-- batch dimension (TODO: confirm against EvaluateOneBatch / GetAnImageBatch).
--
-- NOTE(review): `input_cfg`, `dataset`, `results_pred`, `results_gt` and
-- `vars` stay globals, as in the original, in case external code reads them.
function test()
  input_cfg = config.test
  dataset = lines_from(input_cfg.datafile)
  -- Without at least one batch there is nothing to save; fail clearly instead
  -- of indexing a nil result further down.
  assert(config.nIter >= 1, 'test(): config.nIter must be at least 1')

  -- For each output index i, keep a list of per-batch tensors. They are
  -- concatenated once at the end, rather than re-concatenating the growing
  -- result on every iteration (which costs O(n^2) copying).
  local pred_batches, target_batches = {}, {}
  for i = 1, config.rho do
    pred_batches[i], target_batches[i] = {}, {}
  end

  for iter = 1, config.nIter do
    -- Load one batch. os.clock() measures CPU time, not wall-clock time.
    local tic = os.clock()
    local TeInput, TeTarget = GetAnImageBatch(input_cfg, dataset)
    log('loading time :' .. tostring(os.clock() - tic))

    -- Periodic explicit (and timed) garbage collection.
    if iter % 10 == 0 then
      local gc_tic = os.clock()
      collectgarbage()
      print("garbage collection :", os.clock() - gc_tic)
    end

    local _, _, _, _, predicts = model:EvaluateOneBatch(TeInput, TeTarget)
    assert(predicts, 'test(): model:EvaluateOneBatch returned no predictions')
    for i = 1, config.rho do
      -- clone(): the model or loader may return buffers that are overwritten
      -- in place by the next call; keep a private copy of this batch.
      pred_batches[i][iter] = predicts[i]:clone()
      target_batches[i][iter] = TeTarget[i]:clone()
    end
  end

  local all_predictions, all_targets = {}, {}
  for i = 1, config.rho do
    all_predictions[i] = torch.cat(pred_batches[i], 1)
    all_targets[i] = torch.cat(target_batches[i], 1)
  end

  local fname = input_cfg.outfile
  results_pred = torch.Tensor(all_predictions[1]:size(1), config.rho)
  results_gt = torch.Tensor(all_targets[1]:size(1), config.rho)
  log("Saving all predictions at " .. fname)
  for i = 1, config.rho do
    results_pred[{{}, {i}}] = all_predictions[i]:double()
    results_gt[{{}, {i}}] = all_targets[i]:double()
  end
  vars = {preds = results_pred, targets = results_gt}
  log('Saving results in ----> ' .. fname .. '\n')
  mattorch.save(fname, vars)
end