-
Notifications
You must be signed in to change notification settings - Fork 69
/
generate.lua
104 lines (87 loc) · 3.38 KB
/
generate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
require 'nn'
require 'rnn'
require 'nngraph'
utils = dofile('utils.lua');
-------------------------------------------------------------------------------
-- Input arguments and options
-------------------------------------------------------------------------------
-- The CmdLine object is only needed here to parse the command line, so it is
-- kept local instead of being leaked into the global table.
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Generate answers with a trained VisDial model')
cmd:text()
cmd:text('Options')
-- Data input settings
cmd:option('-inputImg','data/data_img.h5','h5file path with image feature')
cmd:option('-inputQues','data/visdial_data.h5','h5file path with preprocessed questions')
cmd:option('-inputJson','data/visdial_params.json','json path with info and vocab')
cmd:option('-loadPath', 'checkpoints/model.t7', 'path to saved model')
cmd:option('-resultPath', 'vis/results', 'path to save generated results')
-- sampling params (forwarded unchanged to model:generateAnswers below)
cmd:option('-beamSize', 5, 'Beam size')
cmd:option('-beamLen', 20, 'Beam length')
-- NOTE(review): numeric flag; presumably 0 = beam search and non-zero = sample.
-- The decision is made inside the model code, so confirm there (remember that
-- 0 is truthy in Lua if the model tests it with a plain `if`).
cmd:option('-sampleWords', 0, 'Whether to sample')
cmd:option('-temperature', 1.0, 'Sampling temperature')
-- NOTE(review): interpreted by the model's generateAnswers; presumably the
-- number of conversations decoded per batch. Confirm against model code.
cmd:option('-maxThreads', 50, 'Max threads')
-- 0-based GPU id; any negative value runs on the CPU (see the device setup).
cmd:option('-gpuid', 0, 'GPU id to use')
cmd:option('-backend', 'cudnn', 'nn|cudnn')
local opt = cmd:parse(arg);
print(opt)
-- Fix the CPU random seed so that repeated runs produce the same output.
torch.manualSeed(1234);
-- Pick the default tensor type: float tensors on the CPU (negative gpuid),
-- CUDA tensors on the requested GPU otherwise.
local useGpu = opt.gpuid >= 0
if not useGpu then
    torch.setdefaulttensortype('torch.FloatTensor');
else
    require 'cutorch'
    require 'cunn'
    if opt.backend == 'cudnn' then
        require 'cudnn'
    end
    -- cutorch numbers devices from 1 while the command-line id starts at 0
    cutorch.setDevice(opt.gpuid + 1)
    -- the GPU generator is seeded separately from the CPU one
    cutorch.manualSeed(1234)
    torch.setdefaulttensortype('torch.CudaTensor');
end
------------------------------------------------------------------------
-- Read saved model and parameters
------------------------------------------------------------------------
local savedModel = torch.load(opt.loadPath)
-- The checkpoint records the architecture it was trained with; mirror the
-- fields the data loader and model need into opt (same order as before).
local modelParams = savedModel.modelParams
for _, field in ipairs({'imgNorm', 'encoder', 'decoder'}) do
    opt[field] = modelParams[field]
end
-- Run on the GPU chosen on this command line, not the one used in training.
modelParams.gpuid = opt.gpuid
-- Encoder names embed their inputs as substrings: 'hist' turns on dialog
-- history, 'im' turns on image features. Flags stay nil when absent.
if string.match(opt.encoder, 'hist') then opt.useHistory = true; end
if string.match(opt.encoder, 'im') then opt.useIm = true; end
------------------------------------------------------------------------
-- Loading dataset
------------------------------------------------------------------------
-- Generation only runs on the validation split.
local evalSplits = {'val'}
local dataloader = dofile('dataloader.lua')
dataloader:initialize(opt, evalSplits)
-- Run a full garbage-collection cycle before building the model.
collectgarbage()
------------------------------------------------------------------------
-- Setup the model
------------------------------------------------------------------------
-- Build a fresh model from the checkpoint's parameters, then overwrite its
-- `wrapperW` tensor with the trained weights stored in the checkpoint.
require 'model'
local model = Model(modelParams)
local trainedWeights = savedModel.modelW
model.wrapperW:copy(trainedWeights)
------------------------------------------------------------------------
-- Generating
------------------------------------------------------------------------
-- Decoding settings forwarded to the model. Kept local so they do not leak
-- into the global table.
local sampleParams = {
    beamSize = opt.beamSize,
    beamLen = opt.beamLen,
    maxThreads = opt.maxThreads,
    sampleWords = opt.sampleWords,
    temperature = opt.temperature
}
local answers = model:generateAnswers(dataloader, 'val', sampleParams)
-- Store the run options next to the answers so the result file records how
-- it was produced.
local output = {opts = opt, data = answers}

-- Create `dir` together with any missing parent directories.
-- paths.mkdir only creates the final path component, so a nested -resultPath
-- such as 'a/b/c' would otherwise fail when 'a' or 'a/b' does not exist yet.
local function mkdirRecursive(dir)
    -- keep the leading '/' of absolute paths
    local prefix = (dir:sub(1, 1) == '/') and '/' or ''
    for part in dir:gmatch('[^/]+') do
        prefix = prefix .. part
        if not paths.dirp(prefix) then
            paths.mkdir(prefix)
        end
        prefix = prefix .. '/'
    end
end

-- save the file to json
local savePath = string.format('%s/results.json', opt.resultPath);
mkdirRecursive(opt.resultPath)
-- fail loudly here rather than with an opaque error inside writeJSON
assert(paths.dirp(opt.resultPath),
    'could not create result directory: ' .. opt.resultPath)
-- announce the destination first so a failed write still shows the path
print('Writing the results to '.. savePath);
utils.writeJSON(savePath, output);