forked from szagoruyko/wide-residual-networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
augmentation.lua
57 lines (50 loc) · 1.53 KB
/
augmentation.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
require 'image'
require 'nn'
do -- random crop
  --- Random-translation augmentation module (nn.RandomCrop).
  -- In training mode the batch is padded by `pad` pixels on every side and a
  -- window with the input's original height and width is cut out at a random
  -- offset. One offset is drawn per forward call, so every image in the batch
  -- is shifted identically.
  -- In evaluation mode (`self.train` false) the output shares the input's
  -- storage, i.e. the input passes through unchanged.
  local RandomCrop, parent = torch.class('nn.RandomCrop', 'nn.Module')

  --- Constructor.
  -- @tparam number pad  padding added on each of the four sides before cropping
  -- @tparam string mode 'reflection' (nn.SpatialReflectionPadding) or
  --   'zero' (nn.SpatialZeroPadding); any other value raises 'unknown mode'
  function RandomCrop:__init(pad, mode)
    assert(pad)
    parent.__init(self)
    self.pad = pad
    if mode == 'reflection' then
      self.module = nn.SpatialReflectionPadding(pad, pad, pad, pad)
    elseif mode == 'zero' then
      self.module = nn.SpatialZeroPadding(pad, pad, pad, pad)
    else
      error'unknown mode'
    end
    -- Start in training mode; presumably toggled through nn.Module's
    -- training()/evaluate() (NOTE(review): confirm against the trainer).
    self.train = true
  end

  --- Forward pass.
  -- @param input 4-D tensor; dims 3 and 4 are the cropped spatial dims
  --   (assumed batch x channels x height x width — TODO confirm with callers)
  -- @return in training: a narrowed view of the padding module's output with
  --   the same size as `input`; in evaluation: `self.output` sharing `input`'s storage
  function RandomCrop:updateOutput(input)
    assert(input:dim() == 4)
    -- Height and width are read separately so non-square images keep their
    -- shape; using one size for both axes mis-crops or errors when H ~= W.
    local height, width = input:size(3), input:size(4)
    if self.train then
      local padded = self.module:forward(input)
      -- Offsets in [1, 2*pad + 1] keep the crop inside the padded tensor.
      local x = torch.random(1, self.pad*2 + 1)
      local y = torch.random(1, self.pad*2 + 1)
      self.output = padded:narrow(4, x, width):narrow(3, y, height)
    else
      self.output:set(input)
    end
    return self.output
  end

  --- Converts the inner padding module and this module to `dst_type`.
  -- Called with no argument it acts as a getter, like nn.Module:type.
  -- @tparam[opt] string dst_type   target tensor type name
  -- @tparam[opt] table  tensorCache forwarded to nn.Module:type so shared
  --   tensors are converted consistently
  function RandomCrop:type(dst_type, tensorCache)
    self.module:type(dst_type, tensorCache)
    return parent.type(self, dst_type, tensorCache)
  end
end
do -- random horizontal flip
  --- Horizontal-flip augmentation module (nn.BatchFlip).
  -- In training mode it horizontally flips a random subset of the samples
  -- along dim 1. In evaluation mode it returns an unmodified copy of the input.
  -- The caller's input tensor is never modified.
  local BatchFlip = torch.class('nn.BatchFlip', 'nn.Module')

  --- Forward pass.
  -- @param input tensor whose dim 1 indexes samples; each `input[i]` is
  --   handed to image.hflip (assumed channels x height x width — TODO confirm)
  -- @return `self.output`, a copy of `input` with the selected samples flipped
  function BatchFlip:updateOutput(input)
    -- This class defines no __init, so default `train` lazily.
    -- An explicit `false` (evaluation mode) is left untouched.
    if self.train == nil then
      self.train = true
    end

    -- Copy first and flip the copy, so the caller's tensor is not mutated.
    self.output:resize(input:size()):copy(input)

    if self.train then
      local bs = input:size(1)
      -- Flip exactly floor(bs/2) samples. For odd bs the leftover sample is
      -- flipped with probability 1/2, so every sample has a 1/2 flip chance.
      -- Without this, bs == 1 never flipped and odd batches were biased.
      local nflip = math.floor(bs / 2)
      if bs % 2 == 1 and torch.uniform() < 0.5 then
        nflip = nflip + 1
      end
      local flip_mask = torch.randperm(bs):le(nflip)
      for i = 1, bs do
        if flip_mask[i] == 1 then
          image.hflip(self.output[i], self.output[i])
        end
      end
    end
    return self.output
  end
end