-
Notifications
You must be signed in to change notification settings - Fork 85
/
BilinearSamplerBHWD.lua
executable file
·117 lines (91 loc) · 3.46 KB
/
BilinearSamplerBHWD.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
-- The stnbhwd package must already have registered nn.BilinearSamplerBHWD,
-- because this file patches that existing module rather than creating one.
assert(nn.BilinearSamplerBHWD, "stnbhwd package not preloaded")

-- The methods defined below overwrite those of the stnbhwd module of the same name.
local BilinearSamplerBHWD = nn.BilinearSamplerBHWD
local parent = nn.Module
--[[
BilinearSamplerBHWD() :
BilinearSamplerBHWD:updateOutput({inputImages, grids})
BilinearSamplerBHWD:updateGradInput({inputImages, grids}, gradOutput)
BilinearSamplerBHWD will perform bilinear sampling of the input images according to the
normalized coordinates provided in the grid. Output will be of same size as the grids,
with as many features as the input images.
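- inputImages has to be in BHWD layout (a single non-batched HWD sample is also accepted)
- grids have to be in BHWD layout, with dim(D)=2 (HWD, without the batch dim, when inputImages is non-batched)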
- grids contains, for each sample (first dim), the normalized coordinates of the output wrt the input sample
- first coordinate is Y coordinate, second is X
- normalized coordinates : (-1,-1) points to top left, (-1,1) points to top right
- if the normalized coordinates fall outside of the image, then output will be filled with zeros
]]
-- Constructor: runs the nn.Module initialiser, then sets gradInput to an empty
-- table. updateGradInput later stores one gradient tensor per input in it
-- ({gradInputImages, gradGrids}).
function BilinearSamplerBHWD:__init()
   parent.__init(self)
   -- table-valued because the module receives a table of two inputs
   self.gradInput = {}
end
-- Validates the shapes of the (already batched) inputs and, when supplied, of
-- gradOutput. Raises an error through assert on the first violated constraint.
--
-- input      : table {inputImages, grids}, both 4D tensors
--                inputImages : B x H x W x D
--                grids       : B x H' x W' x 2
-- gradOutput : optional 4D tensor; when given it must have the shape of the
--              module output, i.e. B x H' x W' x D (see the resize in
--              updateOutput).
function BilinearSamplerBHWD:check(input, gradOutput)
   local inputImages = input[1]
   local grids = input[2]

   assert(inputImages:nDimension() == 4, "inputImages must be a 4D (BHWD) tensor")
   assert(grids:nDimension() == 4, "grids must be a 4D (BHWD) tensor")
   assert(inputImages:size(1) == grids:size(1),
      "inputImages and grids must have the same batch size") -- batch
   assert(grids:size(4) == 2,
      "last dimension of grids must be 2 (y, x coordinates)") -- coordinates

   if gradOutput then
      -- gradOutput has to mirror the output tensor exactly; the rank and the
      -- channel dimension are checked too so that a malformed gradient is
      -- rejected here instead of being handed to the native kernel.
      assert(gradOutput:nDimension() == 4, "gradOutput must be a 4D (BHWD) tensor")
      assert(grids:size(1) == gradOutput:size(1),
         "gradOutput batch size must match grids")
      assert(grids:size(2) == gradOutput:size(2),
         "gradOutput height must match grids")
      assert(grids:size(3) == gradOutput:size(3),
         "gradOutput width must match grids")
      assert(inputImages:size(4) == gradOutput:size(4),
         "gradOutput must have as many channels as inputImages")
   end
end
-- Returns a view of tensor t with an extra leading dimension of size 1
-- (e.g. H x W x D becomes 1 x H x W x D). Used to treat a single sample as a
-- batch of one.
local function addOuterDim(t)
   local dims = t:size()
   local rank = dims:size()
   local batched = torch.LongStorage(rank + 1)
   -- copy the existing sizes shifted by one slot, then fill the new front slot
   for axis = rank, 1, -1 do
      batched[axis + 1] = dims[axis]
   end
   batched[1] = 1
   return t:view(batched)
end
-- Forward pass.
--
-- input : table {inputImages, grids}. Either both are batched 4D tensors, or
--         inputImages is 3D, in which case a leading dimension of size 1 is
--         added to both tensors before sampling.
-- Returns self.output, sized B x gridH x gridW x D. For a 3D inputImages the
-- leading dimension is removed again before returning.
function BilinearSamplerBHWD:updateOutput(input)
   local rawImages, rawGrids = input[1], input[2]
   local singleSample = rawImages:nDimension() == 3

   -- work on batched views; a lone sample becomes a batch of one
   local inputImages, grids = rawImages, rawGrids
   if singleSample then
      inputImages = addOuterDim(rawImages)
      grids = addOuterDim(rawGrids)
   end

   self:check({inputImages, grids})

   -- spatial extent comes from the grids, channel count from the images
   self.output:resize(inputImages:size(1), grids:size(2), grids:size(3), inputImages:size(4))
   -- dispatch to the native kernel registered for the images' tensor type
   inputImages.nn.BilinearSamplerBHWD_updateOutput(self, inputImages, grids, self.output)

   if singleSample then
      -- strip the artificial batch dimension again
      self.output = self.output:select(1, 1)
   end
   return self.output
end
-- Backward pass.
--
-- _input      : table {inputImages, grids}, as given to updateOutput
-- _gradOutput : gradient w.r.t. the module output, with the same
--               dimensionality as the output returned for _input
-- Returns self.gradInput = {gradInputImages, gradGrids}, each with the shape
-- of the corresponding input.
--
-- When inputImages is 3D (single sample), a leading dimension of size 1 is
-- added to all three tensors for the computation and removed from the
-- gradients afterwards.
function BilinearSamplerBHWD:updateGradInput(_input, _gradOutput)
   local _inputImages = _input[1]
   local _grids = _input[2]

   -- One flag decides both adding and removing the batch dimension, so the
   -- two steps can never disagree (the removal used to be keyed on
   -- _gradOutput while the addition was keyed on _inputImages).
   local singleSample = _inputImages:nDimension() == 3

   local inputImages, grids, gradOutput = _inputImages, _grids, _gradOutput
   if singleSample then
      inputImages = addOuterDim(_inputImages)
      grids = addOuterDim(_grids)
      gradOutput = addOuterDim(_gradOutput)
   end

   local input = {inputImages, grids}
   self:check(input, gradOutput)

   for i = 1, #input do
      -- allocate each gradient buffer from its own input so that it shares
      -- that input's tensor type (previously always taken from input[1])
      self.gradInput[i] = self.gradInput[i] or input[i].new()
      self.gradInput[i]:resizeAs(input[i]):zero()
   end

   local gradInputImages = self.gradInput[1]
   local gradGrids = self.gradInput[2]

   -- dispatch to the native kernel registered for the images' tensor type
   inputImages.nn.BilinearSamplerBHWD_updateGradInput(self, inputImages, grids, gradInputImages, gradGrids, gradOutput)

   if singleSample then
      -- strip the artificial batch dimension so the gradients match _input
      self.gradInput[1] = self.gradInput[1]:select(1, 1)
      self.gradInput[2] = self.gradInput[2]:select(1, 1)
   end

   return self.gradInput
end