require 'torch'
require 'optim'
require 'image'
require 'pl'
require 'paths'
--image_utils = require 'utils.image'
ok, disp = pcall(require, 'display')
if not ok then print('display not found. unable to plot') end
ADVERSARIAL = require 'adversarial_c2f'
DATASET = require 'dataset_c2f'
NN_UTILS = require 'utils.nn_utils'
MODELS = require 'models_c2f'
----------------------------------------------------------------------
-- parse command-line options
OPT = lapp[[
    --save             (default "logs")    subdirectory to save logs
    --saveFreq         (default 30)        save every saveFreq epochs
    --network          (default "")        reload pretrained network
    --noplot                               disable plotting while training
    --D_sgd_lr         (default 0.02)      D SGD learning rate
    --G_sgd_lr         (default 0.02)      G SGD learning rate
    --D_sgd_momentum   (default 0)         D SGD momentum
    --G_sgd_momentum   (default 0)         G SGD momentum
    --batchSize        (default 32)        batch size
    --N_epoch          (default 1000)      number of examples per epoch (-1 means all)
    --G_L1             (default 0)         L1 penalty on the weights of G
    --G_L2             (default 0e-6)      L2 penalty on the weights of G
    --D_L1             (default 1e-7)      L1 penalty on the weights of D
    --D_L2             (default 0e-6)      L2 penalty on the weights of D
    --D_iterations     (default 1)         number of iterations to optimize D for
    --G_iterations     (default 1)         number of iterations to optimize G for
    --D_clamp          (default 1)         clamp threshold for D's gradient (+/- N)
    --G_clamp          (default 5)         clamp threshold for G's gradient (+/- N)
    --D_optmethod      (default "adam")    adam|adagrad|sgd
    --G_optmethod      (default "adam")    adam|adagrad|sgd
    --threads          (default 4)         number of threads
    --gpu              (default 0)         gpu device to run on (0-3; negative value for cpu)
    --noiseDim         (default 100)       dimensionality of noise vector
    --window           (default 3)         window id of sample image
    --coarseSize       (default 16)        coarse scale
    --fineSize         (default 32)        fine scale
    --grayscale                            grayscale mode on/off
    --seed             (default 1)         seed for the RNG
    --aws                                  run in AWS mode
]]
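-- Example invocations (illustrative only; the paths and flag values below are
-- assumptions, not taken from the repository's documentation):
--   th train_c2f.lua --gpu 0 --coarseSize 16 --fineSize 32 --batchSize 32 --save logs
--   th train_c2f.lua --gpu -1 --grayscale --noplot    (cpu-only run without live plotting)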
if OPT.fineSize ~= 32 then
    print("[Warning] Models are currently only optimized for fine size of 32.")
end
START_TIME = os.time()
if OPT.gpu < 0 or OPT.gpu > 3 then OPT.gpu = false end
print(OPT)
-- fix seed
math.randomseed(OPT.seed)
torch.manualSeed(OPT.seed)
-- threads
torch.setnumthreads(OPT.threads)
print('<torch> set nb of threads to ' .. torch.getnumthreads())
-- possible outputs of the discriminator
CLASSES = {"0", "1"}
Y_GENERATOR = 0
Y_NOT_GENERATOR = 1
-- axes of images: channels (1 for grayscale, 3 for color), <fineSize> height, <fineSize> width
if OPT.grayscale then
    IMG_DIMENSIONS = {1, OPT.fineSize, OPT.fineSize}
    COND_DIM = {1, OPT.fineSize, OPT.fineSize}
else
    IMG_DIMENSIONS = {3, OPT.fineSize, OPT.fineSize}
    COND_DIM = {3, OPT.fineSize, OPT.fineSize}
end
-- size in values/pixels per input image (channels*height*width)
INPUT_SZ = IMG_DIMENSIONS[1] * IMG_DIMENSIONS[2] * IMG_DIMENSIONS[3]
NOISE_DIM = {1, OPT.fineSize, OPT.fineSize}
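-- Shape summary for the default color 32x32 setting: INPUT_SZ = 3*32*32 = 3072 values per
-- image, and NOISE_DIM is a single-channel noise plane with the same height/width as the
-- fine image. Judging from getSamples() below, the conditioning input (COND_DIM) apparently
-- holds the coarse image already upscaled to the fine resolution, and G takes
-- {noise, condition} and outputs a residual ("diff") that is added to the coarse image to
-- obtain the refined result. (Inferred from this file, not verified against models_c2f.lua.)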
----------------------------------------------------------------------
-- get/create dataset
----------------------------------------------------------------------
DATASET.nbChannels = IMG_DIMENSIONS[1]
DATASET.setFileExtension("jpg")
DATASET.setCoarseScale(OPT.coarseSize)
DATASET.setFineScale(OPT.fineSize)
if OPT.aws then
    DATASET.setDirs({"/mnt/datasets/out_aug_64x64"})
else
    DATASET.setDirs({"dataset/out_aug_64x64"})
end
----------------------------------------------------------------------
-- run on gpu if chosen
print("<trainer> starting gpu support...")
require 'nn'
require 'cutorch'
require 'cunn'
if OPT.gpu then
cutorch.setDevice(OPT.gpu + 1)
cutorch.manualSeed(OPT.seed)
print(string.format("<trainer> using gpu device %d", OPT.gpu))
end
torch.setdefaulttensortype('torch.FloatTensor')
if OPT.network ~= "" then
    print(string.format("<trainer> reloading previously trained network: %s", OPT.network))
    local tmp = torch.load(OPT.network)
    MODEL_D = tmp.D
    MODEL_G = tmp.G
    OPTSTATE = tmp.optstate
    EPOCH = tmp.epoch + 1
    if OPT.gpu == false then
        MODEL_D:float()
        MODEL_G:float()
    end
else
    MODEL_D = MODELS.create_D(IMG_DIMENSIONS, OPT.gpu ~= false)
    MODEL_G = MODELS.create_G(IMG_DIMENSIONS, OPT.gpu ~= false)
end
-- loss function: binary cross-entropy (the negative log-likelihood for D's 0/1 output)
CRITERION = nn.BCECriterion()
-- retrieve parameters and gradients
PARAMETERS_D, GRAD_PARAMETERS_D = MODEL_D:getParameters()
PARAMETERS_G, GRAD_PARAMETERS_G = MODEL_G:getParameters()
-- this matrix records the current confusion across classes
CONFUSION = optim.ConfusionMatrix(CLASSES)
print("Model D:")
print(MODEL_D)
print("Model G:")
print(MODEL_G)
-- count free parameters in D/G (weight elements only; biases are not counted)
local nparams = 0
local dModules = MODEL_D:listModules()
for i=1,#dModules do
    if dModules[i].weight ~= nil then
        nparams = nparams + dModules[i].weight:nElement()
    end
end
print('\nNumber of free parameters in D: ' .. nparams)

local nparams = 0
local gModules = MODEL_G:listModules()
for i=1,#gModules do
    if gModules[i].weight ~= nil then
        nparams = nparams + gModules[i].weight:nElement()
    end
end
print('Number of free parameters in G: ' .. nparams .. '\n')
-- Set optimizer state
if OPTSTATE == nil or OPT.rebuildOptstate == 1 then
    OPTSTATE = {
        adagrad = {
            D = { learningRate = 1e-3 },
            G = { learningRate = 1e-3 * 3 }
        },
        adam = {
            D = {},
            G = {}
        },
        rmsprop = {D = {}, G = {}},
        sgd = {
            D = {learningRate = OPT.D_sgd_lr, momentum = OPT.D_sgd_momentum},
            G = {learningRate = OPT.G_sgd_lr, momentum = OPT.G_sgd_momentum}
        }
    }
end
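-- The per-optimizer sub-tables above are presumably passed as the state argument to
-- optim.adam/optim.adagrad/optim.sgd inside adversarial_c2f.lua (selected via
-- OPT.D_optmethod/OPT.G_optmethod), so that optimizer state persists across epochs and is
-- saved/restored together with the models via tmp.optstate above. This is an inference
-- from this file, not verified against adversarial_c2f.lua.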
if EPOCH == nil then
    EPOCH = 1
end
PLOT_DATA = {}
VIS_NOISE_INPUTS = NN_UTILS.createNoiseInputs(100)
-- Get examples to plot
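-- Collect N example image groups from dataset ds for visualization.
-- For each sampled example the returned list contains five images in order:
-- coarse conditioning image, ground truth fine image, refined image (coarse + G's output),
-- ground truth diff, and G's generated diff. This matches the plot title used in the
-- training loop below ("Coarse, GT, G img, GT diff, G diff").
-- @param ds Dataset to sample from (e.g. VAL_DATA)
-- @param N Number of examples to sample (default 8)
-- @returns List of 5*N image tensors suitable for disp.image()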
function getSamples(ds, N)
    local N = N or 8
    local noiseInputs = torch.Tensor(N, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3])
    local condInputs = torch.Tensor(N, COND_DIM[1], COND_DIM[2], COND_DIM[3])
    local gt_diff = torch.Tensor(N, IMG_DIMENSIONS[1], OPT.fineSize, OPT.fineSize)
    local gt = torch.Tensor(N, IMG_DIMENSIONS[1], OPT.fineSize, OPT.fineSize)

    -- Generate samples
    noiseInputs:uniform(-1, 1)
    for n = 1,N do
        local rand = math.random(ds:size())
        local example = ds[rand]
        condInputs[n] = example.coarse:clone()
        gt[n] = example.fine:clone()
        gt_diff[n] = example.diff:clone()
    end
    local samples = MODEL_G:forward({noiseInputs, condInputs})
    --local preds_D = MODEL_D:forward({samples, condInputs})

    local to_plot = {}
    for i=1,N do
        local refined = torch.add(condInputs[i]:float(), samples[i]:float())
        to_plot[#to_plot+1] = condInputs[i]:float()
        to_plot[#to_plot+1] = gt[i]:float()
        to_plot[#to_plot+1] = refined
        to_plot[#to_plot+1] = gt_diff[i]:float()
        to_plot[#to_plot+1] = samples[i]:float()
    end
    return to_plot
end
VAL_DATA = DATASET.loadImages(0, 500)
-- training loop
while true do
    print('Loading new training data...')
    TRAIN_DATA = DATASET.loadRandomImages(OPT.N_epoch, 500)

    -- plot a batch of validation samples (coarse / ground truth / refined / diffs)
    if not OPT.noplot then
        local to_plot = getSamples(VAL_DATA, 20)
        disp.image(to_plot, {win=OPT.window, width=2*10*IMG_DIMENSIONS[3], title="Coarse, GT, G img, GT diff, G diff (" .. OPT.save .. " epoch " .. (EPOCH-1) .. ")"})
    end

    -- train/test
    ADVERSARIAL.train(TRAIN_DATA)
    --adversarial.test(valData, nval)
    ADVERSARIAL.approxParzen(VAL_DATA, 200, OPT.batchSize)
end
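-- Note: this loop runs until interrupted. EPOCH is expected to be advanced and the models
-- checkpointed (every OPT.saveFreq epochs, into OPT.save) inside ADVERSARIAL.train();
-- that behavior lives in adversarial_c2f.lua and is assumed here, not shown in this file.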