-- SmoothHuberPenalty.lua
require 'nn'
require 'cunn'      -- needed for the :cuda() casts on the fixed filters below
require 'extracunn'
--[[
SmoothHuberPenalty does not modify its input in the forward pass; it only adds a
gradient during backprop that penalises non-smoothness. A non-trainable
convolution with fixed central-difference (5-point stencil) filters computes the
spatial gradient magnitude, and a Huber (smoothed L1) penalty is applied to it.
--]]
local SmoothHuberPenalty, parent = torch.class('nn.SmoothHuberPenalty', 'nn.Module')
function SmoothHuberPenalty:__init(transf, l1weight, threshold, sizeAverage)
   parent.__init(self)
   -- Note: the 'transf' argument is accepted but not used in this implementation.

   -- Non-trainable convolutions with fixed central-difference (5-point stencil) filters.
   -- gx approximates the horizontal derivative:
   local gx = torch.Tensor(3,3):zero()
   gx[2][1] = -1/2
   gx[2][2] = 0
   gx[2][3] = 1/2
   gx = gx:cuda()
   local gradx = nn.SpatialConvolution(1,1,3,3,1,1,1,1)
   gradx.weight:copy(gx)
   gradx.bias:fill(0)

   -- gy approximates the vertical derivative:
   local gy = torch.Tensor(3,3):zero()
   gy[1][2] = -1/2
   gy[2][2] = 0
   gy[3][2] = 1/2
   gy = gy:cuda()
   local grady = nn.SpatialConvolution(1,1,3,3,1,1,1,1)
   grady.weight:copy(gy)
   grady.bias:fill(0)

   -- Gradient magnitude sqrt(dx^2 + dy^2), computed by two parallel branches.
   local branchx = nn.Sequential()
   branchx:add(gradx):add(nn.Square())
   local branchy = nn.Sequential()
   branchy:add(grady):add(nn.Square())
   local gradconcat = nn.ConcatTable()
   gradconcat:add(branchx):add(branchy)
   local grad = nn.Sequential()
   grad:add(gradconcat)
   grad:add(nn.CAddTable())
   grad:add(nn.Sqrt())
   self.grad = grad:cuda()  -- the filters are CUDA tensors, so keep the whole branch on the GPU

   self.threshold = threshold or 0.001
   self.l1weight = l1weight or 0.01
   -- 'sizeAverage or true' would always evaluate to true; honour an explicit false.
   if sizeAverage ~= nil then
      self.sizeAverage = sizeAverage
   else
      self.sizeAverage = true
   end
end
-- The forward pass is the identity; the penalty acts only through the gradient.
function SmoothHuberPenalty:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   return self.output
end
function SmoothHuberPenalty:updateGradInput(input, gradOutput)
   local m = self.l1weight
   if self.sizeAverage then
      m = m / input:nElement()
   end
   self.gradInput:resizeAs(gradOutput)
   -- Penalise each map in the batch independently.
   for i = 1, input:size(1) do
      -- Gradient magnitude of map i, passed through the in-place Huber transform from extracunn,
      -- then scaled by the penalty weight.
      local dx = self.grad:updateOutput(input[{{i},{},{}}])
      dx:Huber(self.threshold)
      local gradL1 = dx:mul(m)
      -- Back-propagate the penalty through the fixed convolutions.
      self.gradInput[{{i},{},{}}] = self.grad:updateGradInput(input[{{i},{},{}}], gradL1):clone()
   end
   -- Add the penalty gradient to the gradient coming from the rest of the network.
   self.gradInput:add(gradOutput)
   return self.gradInput
end
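
--[[
Usage sketch (not part of the original module): a minimal, hedged example of how
this penalty might be dropped into a pipeline. The batch size, map size and the
0.1 penalty weight below are illustrative assumptions, not values taken from
this repository.

   require 'cunn'
   require 'extracunn'

   -- identity in forward, adds a smoothness gradient in backward
   local penalty = nn.SmoothHuberPenalty(1, 0.1):cuda()

   local maps = torch.CudaTensor(4, 128, 128):uniform()  -- batch of 4 single-channel maps
   local out  = penalty:forward(maps)                     -- equal to maps
   local zero = torch.CudaTensor(4, 128, 128):zero()
   local grad = penalty:backward(maps, zero)              -- holds only the scaled smoothness term
--]]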