forked from jcjohnson/torch-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_mapped_checkpoint.lua
99 lines (88 loc) · 2.51 KB
/
make_mapped_checkpoint.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
require 'torch'
require 'nn'
require 'LanguageModel'
-- Command-line interface:
--   -i  checkpoint to read
--   -o  path of the rewritten checkpoint; every distinct parameter storage is
--       also written next to it as <o>.<index>
--   -j  optional path for a JSON description of the layers ('' skips it)
local cmd = torch.CmdLine()
cmd:text('Rewrite a checkpoint so weight/bias tensors reference storages saved in separate files.')
cmd:option('-i', 'cv/in.t7', 'input checkpoint')
cmd:option('-o', 'cv/out.t7', 'output checkpoint; storages are written to <o>.<index>')
cmd:option('-j', '', 'optional JSON layer description output path (empty = none)')
opt = cmd:parse(arg)
cp = torch.load(opt.i)
-- Maps torch.pointer(storage) -> 0-based index of the file written for that
-- storage, so a storage shared by several tensors is written only once.
local known_storages = {}
-- Number of distinct storages written so far (= next index to hand out).
-- known_storages is keyed by pointer values rather than 1..n, so
-- #known_storages cannot serve as this count (it is always 0).
local num_storages = 0

--- Write the storage behind tensor `t` to its own file-backed FloatStorage
-- (once per distinct storage) and return a plain descriptor of `t`'s view.
-- Files are named <opt.o>.<index>; indices start at 0 and are assigned in the
-- order storages are first encountered.
-- @param t a torch tensor
-- @return table {storage = index, stride = t:stride(),
--         offset = t:storageOffset(), size = t:size()}
function maptensor(t)
  local storage = t:storage()
  local ptr = torch.pointer(storage)
  local stride = t:stride()
  local offset = t:storageOffset()
  local size = t:size()
  local storageidx = known_storages[ptr]
  if storageidx == nil then
    storageidx = num_storages
    num_storages = num_storages + 1
    local filename = opt.o .. '.' .. storageidx
    print(string.format('Found storage id=%d size=%d file=%s', storageidx, storage:size(), filename))
    -- shared=true: the new storage is backed by `filename`, so copying the
    -- data in is what writes the file.
    local newstorage = torch.FloatStorage(filename, true, storage:size())
    newstorage:copy(storage)
    known_storages[ptr] = storageidx
  end
  return {storage=storageidx, stride=stride, offset=offset, size=size}
end
--- nn.Module:apply callback: replace the module's weight and bias tensors
-- (when present) with the descriptors returned by maptensor.
-- An nn.TemporalAdapter keeps its inner network in `module.net`, which is
-- walked here explicitly (presumably because apply does not descend into
-- it — confirm against TemporalAdapter).
-- @param module an nn module visited by apply
function rg(module)
  local is_adapter = torch.type(module) == 'nn.TemporalAdapter'
  if is_adapter then
    module.net:apply(rg)
  end
  for _, field in ipairs({'weight', 'bias'}) do
    local tensor = module[field]
    if tensor then
      module[field] = maptensor(tensor)
    end
  end
end
-- Mapping pass: rg swaps each module's weight/bias tensors for storage
-- descriptors (maptensor writes the backing files as it goes); the rewritten
-- checkpoint is then flagged as mapped and saved to the -o path.
local net = cp.model.net
net:apply(rg)
cp.is_mapped = true
torch.save(opt.o, cp)
--- Turn a descriptor produced by maptensor into a JSON-friendly table.
-- The storage offset is shifted down by one (torch storage offsets are
-- 1-based), and the size/stride sequences are copied into plain Lua arrays.
-- @param t table with fields storage, offset, size, stride
-- @return table {storage, offset, size, stride}
function jsonize(t)
  local storage_id = t.storage
  local zero_based_offset = t.offset - 1
  local dims, steps = {}, {}
  for d = 1, t.size:size() do
    dims[d] = t.size[d]
    steps[d] = t.stride[d]
  end
  return {
    storage = storage_id,
    offset = zero_based_offset,
    size = dims,
    stride = steps,
  }
end
-- Optional JSON export (enabled by a non-empty -j): writes the vocabulary
-- maps plus one record per top-level module of cp.model.net. Parameter
-- entries come from jsonize applied to the descriptors that the mapping
-- pass stored in place of each weight/bias tensor.
if opt.j ~= '' then
  local util = require 'util.utils'

  -- Return the JSON record for the i-th top-level module `m`; raises an
  -- error naming the module index and type when the type is unsupported.
  local function describe_module(i, m)
    local mtype = torch.type(m)
    print(string.format("found module %d: %s", i, mtype))
    if mtype == 'nn.LookupTable' then
      return {type = 'LookupTable', weight = jsonize(m.weight)}
    elseif mtype == 'nn.GRIDGRU' then
      return {
        type = 'GRIDGRU',
        weight = jsonize(m.weight),
        bias = jsonize(m.bias),
        zoneout_p = m.zoneout_prob,
        zoneout_pd = m.zoneout_probd,
        input_dim = m.input_dim,
        hidden_dim = m.hidden_dim,
      }
    elseif mtype == 'nn.Dropout' then
      return {type = 'Dropout', p = m.p}
    elseif mtype == 'nn.TemporalAdapter' then
      -- Parameters are taken from position 2 of the adapter's inner
      -- container (presumably the wrapped linear layer — TODO confirm).
      local inner = m.net.modules[2]
      if not (inner and inner.weight) then
        error(string.format("module %d (%s): no weighted layer at net.modules[2]", i, mtype))
      end
      return {type = 'Linear', weight = jsonize(inner.weight), bias = jsonize(inner.bias)}
    end
    error(string.format("unknown type %s for module %d", mtype, i))
  end

  local jdata = {
    idx_to_token = cp.model.idx_to_token,
    token_to_idx = cp.model.token_to_idx,
    layers = {},
  }
  for i, m in ipairs(cp.model.net.modules) do
    jdata.layers[i] = describe_module(i, m)
  end
  util.write_json(opt.j, jdata)
end