-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMSEFeedback.lua
76 lines (68 loc) · 2.18 KB
/
MSEFeedback.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
------------------------------------------------------------------------
--[[ MSEFeedback ]]--
-- Feedback that accumulates the mean squared error (MSE) between the
-- model outputs (optionally passed through an output module) and the
-- batch targets, and exposes the running value through report().
------------------------------------------------------------------------
-- Declares the class "dp.MSEFeedback" as a subclass of "dp.Feedback".
-- `parent` is used by the methods below to call the base-class
-- implementations (parent.__init, parent.setup).
local MSEFeedback, parent = torch.class("dp.MSEFeedback", "dp.Feedback")
--- Construct an MSEFeedback from a table of key-value options.
-- Recognized keys of `config` (all optional):
--   name          (string, default 'MSE')
--                 name identifying this Feedback in reports
--   target_dim    (number, default -1)
--                 index applied to the targets before comparison;
--                 values <= 0 leave the targets un-indexed
--   output_module (nn.Module, default nn.Identity())
--                 module applied to the output before measuring the error
--   target_shape  (string, default 'cbw')
--                 shape requested from the batch targets
-- `config` is also forwarded to the parent constructor, with
-- `config.name` overwritten by the resolved name.
-- @param config table of key-value arguments (no positional entries)
function MSEFeedback:__init(config)
   config = config or {}
   assert(torch.type(config) == 'table' and not config[1],
      "Constructor requires key-value arguments")
   local args, name, target_dim, output_module, target_shape = xlua.unpack(
      {config},
      'MSEFeedback',
      '',
      {arg='name', type='string', default='MSE',
       help='name identifying Feedback in reports'},
      {arg='target_dim', type='number', default=-1,
       help='index of target vector to measure MSE against'},
      {arg='output_module', type='nn.Module',
       help='module applied to output before measuring mean squared error'},
      -- The type is the type *name* 'string'. The bare identifier
      -- `string` is the string library table, which cannot be
      -- concatenated into a usage message.
      {arg='target_shape', type='string',
       help='shape of batch targets'}
   )
   config.name = name
   self._output_module = output_module or nn.Identity()
   self._target_dim = target_dim
   parent.__init(self, config)
   -- Set after the parent constructor, as in the original ordering.
   self._target_shape = target_shape or 'cbw'
end
--- Run the parent setup, then subscribe to the "doneEpoch" channel.
-- The subscription names this object's "doneEpoch" method as the
-- callback. Presumably self._mediator is installed by parent.setup
-- (TODO confirm against dp.Feedback).
-- @param config setup options, forwarded unchanged to the parent
function MSEFeedback:setup(config)
   parent.setup(self, config)
   local mediator = self._mediator
   mediator:subscribe("doneEpoch", self, "doneEpoch")
end
--- End-of-epoch callback: print the current MSE when verbose.
-- Nothing is printed if no MSE has been computed yet (self.mse is nil)
-- or if self._verbose is falsy.
-- @param report epoch report (unused here)
function MSEFeedback:doneEpoch(report)
   local should_print = self.mse and self._verbose
   if not should_print then
      return
   end
   print(self._id:toString() .. " mse = " .. self.mse)
end
--- Accumulate the squared error of one batch into the running MSE.
-- Updates self.sse (sum of per-row mean squared errors), self.count
-- (number of rows seen) and self.mse (sse / count).
-- Relies on self.sse and self.count having been initialised by _reset().
-- Assumes the compared tensors are 2D (rows x features), since the
-- reduction runs over dimension 2 -- TODO confirm for all target shapes.
-- @param batch  batch whose targets() provide the ground truth
-- @param output model output for the batch
-- @param report current report (unused here)
function MSEFeedback:_add(batch, output, report)
   local module = self._output_module
   if module then
      output = module:updateOutput(output)
   end

   local target = batch:targets():forward(self._target_shape)
   local dim = self._target_dim
   if dim > 0 then
      target = target[dim]
   end

   -- Bring the target to the same tensor type as the output.
   if torch.type(target) ~= torch.type(output) then
      target = target:typeAs(output)
   end

   -- residual = output - target, computed on a copy so the caller's
   -- output tensor is left untouched.
   local residual = output:clone()
   residual:csub(target)

   -- Per row: squared L2 norm divided by the size of dimension 2.
   local squared_norms = residual:norm(2, 2):pow(2)
   local row_mse = squared_norms / residual:size(2)

   self.sse = self.sse + row_mse:sum()
   self.count = self.count + row_mse:size(1)
   self.mse = self.sse / self.count
end
--- Clear the running statistics.
-- Zeroes the error sum and row count, and drops the cached MSE so that
-- doneEpoch() prints nothing until _add() has run again.
function MSEFeedback:_reset()
   self.sse, self.count, self.mse = 0, 0, nil
end
--- Build the report table for this feedback.
-- @return table of the form
--   { [<feedback name>] = { mse = <running MSE or nil> },
--     n_sample = <number of rows accumulated> }
function MSEFeedback:report()
   local summary = {}
   -- The named entry is written first and n_sample second, matching the
   -- original constructor's field order.
   summary[self:name()] = { mse = self.mse }
   summary.n_sample = self.count
   return summary
end