-- multisample.lua
-- forked from jcjohnson/torch-rnn
require 'torch'
require 'nn'
require 'LanguageModel'

torch.setdefaulttensortype('torch.FloatTensor')

-- Command-line flags and their defaults, registered in a fixed order so the
-- generated help text matches the original script.
local option_defs = {
  {'-checkpoint', 'cv/checkpoint_4000.t7'},
  {'-sample', 1},
  {'-temperature', 1},
  {'-bytes', 1},
  {'-start_text', '\n'},
  {'-length', 1024},
  {'-count', 16},
  {'-print_every', 10},
  {'-output_file', 'outputs/output-#.txt'},
  {'-verbose', 0},
  {'-forcelayer', 0},
  {'-forcevalue', 1},
  {'-gpu', 0},
  {'-line_start_with', ''},
}

local cmd = torch.CmdLine()
for _, def in ipairs(option_defs) do
  cmd:option(def[1], def[2])
end
local opt = cmd:parse(arg)
local timer = torch.Timer()
-- Load the trained checkpoint; 'model' is the saved LanguageModel.
local checkpoint = torch.load(opt.checkpoint)
local model = checkpoint.model
-- NOTE(review): for memory-mapped checkpoints, unmapTensors is passed the
-- checkpoint path — presumably to locate the mapped data alongside it;
-- confirm against LanguageModel's implementation.
if checkpoint.is_mapped then
model:unmapTensors(opt.checkpoint)
end
model:evaluate()
if opt.gpu > 0 then
require 'cutorch'
model:cuda()
end
-- One token of input per parallel stream (opt.count streams).
local input = torch.LongTensor(opt.count, 1)
local outtext = {}    -- per-stream buffer holding the current (unfinished) line
local outputs         -- score tensor from the most recent forward pass
local newline_idx = model.token_to_idx['\n']
local outfiles = {}   -- per-stream output file handles
local forcequeue = {} -- per-stream string of characters to force as next tokens
for i = 1, opt.count do
outtext[i] = ""
-- '#' in the output-file pattern is replaced by the zero-padded stream number.
outfiles[i] = io.open(string.gsub(opt.output_file, "#", string.format("%03d", i)), "w")
-- Unbuffered writes so the files can be tailed while sampling runs.
outfiles[i]:setvbuf('no')
end
if opt.bytes == 1 then model:convertTables() end
-- Prime the model on the start text; keep only the scores at the final
-- position and replicate them across all streams.
local inp = model:encode_string(opt.start_text):view(1, -1)
outputs = model:forward(inp)[{{}, {-1, -1}}]:expand(opt.count, 1, model.vocab_size)
-- Capture the hidden state from the priming pass, then grow the batch and
-- seed every stream with that same state.
-- NOTE(review): getState is called before setBatchSize — the order appears
-- deliberate; do not reorder.
local state = model:getState(1)
model:setBatchSize(opt.count)
for i = 1, opt.count do
model:setState(i, state)
end
print(string.format('Initialization complete in %.2fs', timer:time().real))
timer:reset()
-- Main sampling loop: one token per stream per step, for opt.length steps.
for i = 1, opt.length do
  for j = 1, opt.count do
    -- Sample the next token index for stream j from the latest scores.
    local ni = model:sampleFromScores(outputs[{{j}}], opt.temperature, 1)
    -- If this stream has forced prefix characters pending (from
    -- -line_start_with), override the sampled token with the next one.
    if forcequeue[j] and #forcequeue[j] > 0 then
      local forced = model.token_to_idx[forcequeue[j]:sub(1, 1)]
      forcequeue[j] = forcequeue[j]:sub(2, -1)
      -- Guard: a forced character missing from the vocabulary would make
      -- 'ni' nil and crash the tensor assignment below; skip forcing instead.
      if forced then
        ni = forced
      end
    end
    if ni == newline_idx then
      -- End of line: flush the buffered line and queue the forced prefix.
      if opt.verbose > 0 then print(string.format("%3d:%s", j, outtext[j])) end
      outfiles[j]:write(outtext[j] .. '\n')
      outtext[j] = ""
      forcequeue[j] = opt.line_start_with
    else
      outtext[j] = outtext[j] .. model.idx_to_token[ni]
    end
    input[{j, 1}] = ni
  end
  -- Optionally clamp a hidden unit in the chosen RNN layer for each stream.
  -- (Loop variable renamed from 'i', which shadowed the outer step counter.)
  if opt.forcelayer > 0 then
    for k = 1, opt.count do
      local layer_state = model.rnns[opt.forcelayer]:getState(k)
      -- NOTE(review): indexing by the batch index 'k' forces a DIFFERENT
      -- unit in each stream; confirm this is intended rather than a fixed
      -- unit index.
      layer_state[k] = opt.forcevalue
    end
  end
  outputs = model:forward(input)
  if i % opt.print_every == 0 then
    local t, b = timer:time().real, opt.count * opt.print_every
    print(string.format("%8d/%8d %.2f %.1f/s", i, opt.length, t, b / t))
    timer:reset()
  end
end
-- Flush any unterminated final line for each stream, then close its file.
for stream = 1, opt.count do
  local f = outfiles[stream]
  f:write(outtext[stream])
  f:close()
end