-
Notifications
You must be signed in to change notification settings - Fork 2
/
Mul.lua
38 lines (29 loc) · 860 Bytes
/
Mul.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
-- nn.Mul: multiplies its entire input by a single learnable scalar weight.
-- torch.class registers 'nn.Mul' as a subclass of 'nn.Module' and returns the
-- new class table plus its parent class; `parent` is used to call the base
-- constructor from Mul:__init.
local Mul, parent = torch.class('nn.Mul', 'nn.Module')
--- Construct a module holding one learnable scalar weight.
-- The weight is stored as a 1-element tensor and initialised by reset().
function Mul:__init()
   parent.__init(self)
   -- The single learnable parameter; its value is drawn in reset() below.
   self.weight = torch.Tensor(1)
   -- Gradient accumulator for the weight. torch.Tensor(n) does not initialise
   -- its storage, and accGradParameters adds into gradWeight[1], so start it
   -- at zero to keep the first accumulation well-defined even when
   -- zeroGradParameters() has not been called yet.
   self.gradWeight = torch.Tensor(1):zero()
   self:reset()
end
--- Re-draw the scalar weight from a symmetric uniform distribution.
-- @param stdv optional number. When supplied, the sampling bound is
--   stdv * sqrt(3), so the uniform draw has standard deviation stdv.
--   When omitted (or false), the bound is 1 / sqrt(#weight).
function Mul:reset(stdv)
   local bound
   if not stdv then
      bound = 1. / math.sqrt(self.weight:size(1))
   else
      -- A uniform on [-a, a] has std a / sqrt(3); scale so the std equals stdv.
      bound = stdv * math.sqrt(3)
   end
   self.weight[1] = torch.uniform(-bound, bound)
end
--- Forward pass: output = weight * input, element-wise.
-- @param input tensor to be scaled
-- @return self.output, resized to input's shape and holding the scaled values
function Mul:updateOutput(input)
   local out = self.output
   out:resizeAs(input)
   out:copy(input)
   out:mul(self.weight[1])
   return out
end
--- Backward pass w.r.t. the input: gradInput = weight * gradOutput.
-- @param input the forward-pass input (only its shape is used here)
-- @param gradOutput gradient flowing back from the next module
-- @return self.gradInput, resized to input's shape
function Mul:updateGradInput(input, gradOutput)
   local gradIn = self.gradInput
   gradIn:resizeAs(input)
   gradIn:zero()
   -- add(value, tensor) accumulates value * tensor onto the zeroed buffer.
   gradIn:add(self.weight[1], gradOutput)
   return gradIn
end
--- Accumulate the gradient of the loss w.r.t. the scalar weight.
-- The contribution is scale * sum(input .* gradOutput), added onto gradWeight[1].
-- @param input the forward-pass input
-- @param gradOutput gradient flowing back from the next module
-- @param scale optional multiplier for the contribution (defaults to 1)
function Mul:accGradParameters(input, gradOutput, scale)
   local factor = scale or 1
   local previous = self.gradWeight[1]
   local contribution = factor * input:dot(gradOutput)
   self.gradWeight[1] = previous + contribution
end