main-demo-ConvLSTM.lua
--[[
Demo script to train a model using the convolutional LSTM module
to predict the next frame in a sequence.
--]]
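
-- Usage (assuming a working Torch7 installation with CUDA support and this
-- repository's dependencies installed):
--   th main-demo-ConvLSTM.lua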
unpack = unpack or table.unpack
require 'nn'
require 'cunn'
require 'paths'
require 'torch'
require 'cutorch'
require 'image'
require 'optim'
require 'ConvLSTM'
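
-- The helper scripts dofile'd in main() are expected to populate the globals
-- used below: opts-mnist.lua defines opt, data-mnist.lua defines
-- getdataSeq_mnist, and model-demo-ConvLSTM.lua defines model and criterion.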
local function main()
  cutorch.setDevice(1)
  paths.dofile('opts-mnist.lua')
  opt.untied = true
  paths.dofile('data-mnist.lua')
  paths.dofile('model-demo-ConvLSTM.lua')

  config = {}
  opt.train = true

  -----------------------------------------------------------------------------
  -- Create model or load a pre-trained one
  if opt.modelFile then -- resume training
    model = torch.load(opt.modelFile)
    if opt.train then
      config = torch.load(opt.configFile)
    end
  end
  if opt.train then
    ---------------------------------------------------------------------------
    -- Load data for training and verify one sample
    dataset = getdataSeq_mnist(opt.dataFile) -- we sample nSeq consecutive frames
    local trainSamples = dataset:size()
    print('main: Loaded ' .. trainSamples .. ' train sequences')

    local seq = dataset[1][1]
    print('main: Verify sample')
    print('main: Image size')
    print(seq[1]:size())
    print('main: Min ' .. seq[1]:min() .. ', Max ' .. seq[1]:max())
    if opt.display then
      _check_ = image.display{image=seq, win=_check_, legend='Check sample sequence', nrow=seq:size(1)}
    end

    parameters, grads = model:getParameters()
    print('Number of parameters ' .. parameters:nElement())
    print('Number of grads ' .. grads:nElement())

    local eta      = config.eta or opt.eta
    local momentum = config.momentum or opt.momentum
    local iter     = config.iter or 1
    local epoch    = config.epoch or 0
    local err      = 0

    model:training()
    model:forget()

    rmspropconf = {learningRate = eta}
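
    ---------------------------------------------------------------------------
    -- Training loop: one sequence per iteration, optimized with rmsprop; the
    -- recurrent state is reset with model:forget() after every sequence.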
    for t = 1, opt.maxIter do
      iter = iter + 1

      --------------------------------------------------------------------
      -- define eval closure
      local feval = function()
        local f = 0
        model:zeroGradParameters()

        -- build the input/target tables for this sequence: inputs are frames
        -- 1..n-1, targets are the same frames shifted by one (2..n)
        inputTable  = {}
        targetTable = {}
        sample = dataset[t]
        data   = sample[1]
        for i = 1, data:size(1) - 1 do
          table.insert(inputTable, data[i]:cuda())
        end
        for i = 2, data:size(1) do
          table.insert(targetTable, data[i]:cuda())
        end

        -- estimate f and gradients
        output = model:updateOutput(inputTable)
        f = criterion:updateOutput(output, targetTable)

        -- gradients
        local gradOutput = criterion:updateGradInput(output, targetTable)
        model:updateGradInput(inputTable, gradOutput)
        model:accGradParameters(inputTable, gradOutput)
        grads:clamp(-opt.gradClip, opt.gradClip)
        return f, grads
      end
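
      -- Learning-rate schedule: after each full pass over the training set,
      -- decay the learning rate as opt.eta * 0.5^(epoch/50).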
      if math.fmod(t, trainSamples) == 0 then
        epoch = epoch + 1
        eta = opt.eta * math.pow(0.5, epoch / 50)
        rmspropconf.learningRate = eta
      end

      _, fs = optim.rmsprop(feval, parameters, rmspropconf)
      err = err + fs[1]
      model:forget()

      --------------------------------------------------------------------
      -- compute statistics / report error
      if math.fmod(t, opt.statInterval) == 0 then
        print('==> iteration = ' .. t .. ', average loss = ' .. err / opt.nSeq .. ' lr ' .. eta)
        err = 0
      end
      if opt.display and math.fmod(t, opt.displayInterval) == 0 then
        _imInput_  = image.display{image=inputTable,  win=_imInput_,  legend='Input Sequence', nrow=#inputTable}
        _imTarget_ = image.display{image=targetTable, win=_imTarget_, legend='Target Frames',  nrow=#targetTable}
        _imOutput_ = image.display{image=output,      win=_imOutput_, legend='Output',         nrow=#output}
      end

      -- periodically checkpoint the model and the optimization state so that
      -- training can be resumed via opt.modelFile / opt.configFile
      if opt.save and math.fmod(t, opt.saveInterval) == 0 then
        model:clearState()
        torch.save(opt.dir .. '/model_' .. t .. '.bin', model)
        config = {eta = eta, momentum = momentum, iter = iter, epoch = epoch}
        torch.save(opt.dir .. '/config_' .. t .. '.bin', config)
      end
    end

    print('Training done')
    collectgarbage()
  end
  -------------------------------------------------------------------------
  -- Evaluation mode
  print('Start quantitative evaluation')
  model:evaluate()
  model:forget()

  dataset = getdataSeq_mnist(opt.dataFileTest)
  local testSamples = dataset:size()
  print('main: Loaded ' .. testSamples .. ' test sequences')

  err = 0
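  -- Per-sequence evaluation: the input/target tables are built exactly as in
  -- training, but only a forward pass is run and the criterion error is
  -- accumulated over the test set.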
  for t = 1, testSamples do
    local f = 0
    inputTable  = {}
    targetTable = {}
    sample = dataset[t]
    data   = sample[1]
    for i = 1, data:size(1) - 1 do
      table.insert(inputTable, data[i]:cuda())
    end
    for i = 2, data:size(1) do
      table.insert(targetTable, data[i]:cuda())
    end

    output = model:updateOutput(inputTable)
    f = criterion:updateOutput(output, targetTable)
    print('Error for sequence ' .. t .. ' ' .. f)

    if opt.display then
      _imInput_  = image.display{image=inputTable,  win=_imInput_,  legend='Input Sequence', nrow=#inputTable}
      _imTarget_ = image.display{image=targetTable, win=_imTarget_, legend='Target Frame',   nrow=#targetTable}
      _imOutput_ = image.display{image=output,      win=_imOutput_, legend='Output',         nrow=#output}
    end
    err = err + f
  end

  print('Average error ' .. err / testSamples)
  print('Quantitative testing done')
  collectgarbage()
end

main()