-
Notifications
You must be signed in to change notification settings - Fork 69
/
utils.lua
162 lines (136 loc) · 5.14 KB
/
utils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
-- script containing supporting code/methods
-- (sequence alignment, id->word translation, JSON I/O, likelihood and
-- rank/metric helpers). Every helper is attached to the `utils` table,
-- which is returned at the end of the file.
local utils = {};
-- NOTE(review): this assigns a *global* `cjson` (no `local`). Other files
-- may rely on that global after requiring this module; confirm before
-- turning it into a local.
cjson = require 'cjson'
--- Right-align the question tokens in a 2D or 3D zero-filled volume.
-- For every row, the first `len` entries of `sequences` are copied into the
-- last `len` columns of a fresh all-zero tensor of the same type and size.
-- Rows whose length is 0 (or negative) are left as all zeros.
-- Supported layouts (ranks fixed by the indexing below):
--   3D: sequences (numImgs x questionsPerImg x M), lengths indexed [img][ques]
--   2D: sequences (numImgs x M),                   lengths indexed [img]
-- Any other rank returns the all-zero tensor.
-- @param sequences tensor of token ids
-- @param lengths tensor/table of sequence lengths matching the leading dims
-- @return new right-aligned tensor (the input is not modified)
function utils.rightAlign(sequences, lengths)
  -- zero tensor with the same type and shape as the input
  local rAligned = sequences:clone():fill(0);
  local numDims = sequences:dim();
  if numDims == 3 then
    local M = sequences:size(3); -- maximum length of question
    local numImgs = sequences:size(1); -- number of images
    local maxCount = sequences:size(2); -- number of questions / image
    for imId = 1, numImgs do
      for quesId = 1, maxCount do
        local len = lengths[imId][quesId];
        -- skip empty slots (same rule as the 2D branch); a zero length no
        -- longer hides the non-empty questions that follow it
        if len > 0 then
          -- copy based on the sequence length
          rAligned[imId][quesId][{{M - len + 1, M}}] =
            sequences[imId][quesId][{{1, len}}];
        end
      end
    end
  elseif numDims == 2 then
    -- handle 2 dimensional matrices as well
    local M = sequences:size(2); -- maximum length of question
    local numImgs = sequences:size(1); -- number of images
    for imId = 1, numImgs do
      local len = lengths[imId];
      -- do only for non zero sequence counts
      if len > 0 then
        -- copy based on the sequence length
        rAligned[imId][{{M - len + 1, M}}] = sequences[imId][{{1, len}}];
      end
    end
  end
  return rAligned;
end
--- Translate a vector of word ids into a sentence string.
-- Ids <= 0 (padding) are skipped. Translation stops right after the first
-- '<END>' word, which is itself included. Each emitted word is preceded by
-- one space, so a non-empty result starts with ' '.
-- @param vector 1D tensor of ids, or a plain Lua array (sequence) of ids
-- @param ind2word table mapping id -> word; an id with no entry raises an
--   error (concatenating nil)
-- @return sentence string ('' when no word is emitted)
function utils.idToWords(vector, ind2word)
  -- plain Lua arrays have no :size(); fall back to the length operator
  local count;
  if type(vector) == 'table' and vector.size == nil then
    count = #vector;
  else
    count = vector:size(1);
  end
  -- collect the pieces and join once instead of concatenating in the loop
  local pieces = {};
  for wordId = 1, count do
    local id = vector[wordId];
    if id > 0 then
      local word = ind2word[id];
      pieces[#pieces + 1] = ' ' .. word;
      -- stop once the end-of-sentence token has been emitted
      if word == '<END>' then break; end
    end
  end
  return table.concat(pieces);
end
--- Read a JSON file and decode it into a Lua value.
-- @param fileName path of the JSON file
-- @return value produced by cjson.decode for the whole file contents
-- Raises an error (carrying io.open's message) if the file cannot be opened.
function utils.readJSON(fileName)
  -- assert surfaces io.open's "nil, message" failure as a readable error
  local file = assert(io.open(fileName, 'r'));
  -- '*a' reads the entire file; the default format only reads the first
  -- line, which truncates multi-line / pretty-printed JSON
  local text = file:read('*a');
  file:close();
  -- convert and return the decoded information
  return cjson.decode(text);
end
--- Serialize a Lua table with cjson and write it to a file (overwriting it).
-- @param fileName destination path
-- @param luaTable value to encode with cjson.encode
-- Raises an error if the file cannot be opened, written, or closed.
function utils.writeJSON(fileName, luaTable)
  -- serialize lua table
  local text = cjson.encode(luaTable);
  local file = assert(io.open(fileName, 'w'));
  -- close the handle before reporting, so a failed write does not leak it
  local wrote, writeErr = file:write(text);
  local closed, closeErr = file:close();
  assert(wrote, writeErr);
  assert(closed, closeErr);
end
--- Compute the likelihood of ground-truth words under predicted scores.
-- For every position, the predicted score of the ground-truth token is
-- gathered. Positions whose token id is 0 (padding) contribute 0, and the
-- result is summed over dim 1 of `words`.
-- NOTE(review): gathered values are summed as log-probabilities, so
-- `predProbs` presumably holds log-probs (e.g. LogSoftMax output); confirm.
-- Assumes `words` is (seqLen x batch) and `predProbs` is
-- (seqLen x batch x vocab) with matching leading dims (TODO confirm).
-- @param words tensor of ground-truth token ids (0 = padding); not modified
-- @param predProbs 3D tensor of per-token scores (last dim = vocabulary)
-- @return squeezed tensor of summed scores, one per sentence
function utils.computeLhood(words, predProbs)
  -- flatten predictions to (numPositions x vocab)
  local predVec = predProbs:view(-1, predProbs:size(3));
  -- clone so the in-place masking below never writes into the caller's
  -- `words`: contiguous() + view() share storage with a contiguous input
  local indices = words:clone():view(-1, 1);
  local mask = indices:eq(0);
  -- assign proxy values to avoid 0 index errors
  indices[mask] = 1;
  local logProbs = predVec:gather(2, indices);
  -- neutralize the padded positions
  logProbs[mask] = 0;
  logProbs = logProbs:viewAs(words);
  -- sum up for each sentence
  return logProbs:sum(1):squeeze();
end
--- Convert option scores into 1-based ranks (rank 1 = highest score).
-- input: scores for all options, ground truth positions
-- @param scores 2D tensor (rows x numOptions) of option scores
-- @param gtPos optional tensor with one ground-truth option index per row
-- @return DoubleTensor. Without gtPos: the rank of every option (same size
--   as `scores`). With gtPos: a 1D vector holding the rank of each row's
--   ground-truth option.
function utils.computeRanks(scores, gtPos)
  -- sort in descending order - largest score gets highest rank
  local _, rankedIdx = scores:sort(2, true)
  local numRows, numOptions = rankedIdx:size(1), rankedIdx:size(2)
  -- invert the sort permutation: ranks[i][option] = position in row i
  -- (loop over the real option count instead of a hard-coded 100)
  local ranks = rankedIdx:clone():fill(0)
  for i = 1, numRows do
    for j = 1, numOptions do
      ranks[{i, rankedIdx[{i, j}]}] = j
    end
  end
  if gtPos then
    gtPos = gtPos:view(-1)
    local gtRanks = torch.LongTensor(gtPos:size(1))
    for i = 1, gtPos:size(1) do
      gtRanks[i] = ranks[{i, gtPos[i]}]
    end
    ranks = gtRanks
  end
  return ranks:double()
end
--- Print retrieval metrics (r@1, r@5, r@10, median/mean rank, mean
-- reciprocal rank) for a tensor of 1-based ranks.
-- Ranks <= 0 or > numOptions are reported with a warning and dropped.
-- The r@k values still divide by the total number of ranks passed in, so
-- dropped ranks count as misses; median/mean use only the kept ranks.
-- @param ranks tensor of ranks, any shape (e.g. computeRanks output)
-- @param numOptions number of candidate options per question (default 100)
function utils.processRanks(ranks, numOptions)
  numOptions = numOptions or 100;
  -- total number of ranks, counted before any filtering (works for 1D too)
  local numQues = ranks:nElement();
  -- convert ranks to a double vector
  ranks = ranks:double():view(-1);
  -- none of the values should be 0, the gt is among the options
  local numZero = torch.sum(ranks:le(0));
  if numZero > 0 then
    print(string.format('Warning: some of ranks are zero : %d', numZero))
    ranks = ranks[ranks:gt(0)];
  end
  local numGreater = torch.sum(ranks:gt(numOptions));
  if numGreater > 0 then
    print(string.format('Warning: some of ranks >%d : %d', numOptions, numGreater))
    -- keep only valid ranks (previously rank numOptions + 1 slipped through)
    ranks = ranks[ranks:le(numOptions)];
  end
  ------------------------------------------------
  print(string.format('\tNo. questions: %d', numQues))
  print(string.format('\tr@1: %f', torch.sum(torch.le(ranks, 1))/numQues))
  print(string.format('\tr@5: %f', torch.sum(torch.le(ranks, 5))/numQues))
  print(string.format('\tr@10: %f', torch.sum(torch.le(ranks, 10))/numQues))
  print(string.format('\tmedianR: %f', torch.median(ranks:view(-1))[1]))
  print(string.format('\tmeanR: %f', torch.mean(ranks)))
  -- clone first: cinv() is in-place, and `ranks` may still share storage
  -- with the caller's tensor when nothing was filtered out
  print(string.format('\tmeanRR: %f', torch.mean(ranks:clone():cinv())))
end
-- export the helper table to whoever require()s this module
return utils;