-
Notifications
You must be signed in to change notification settings - Fork 7
/
sltunet.py
346 lines (287 loc) · 14.3 KB
/
sltunet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
import func
from utils import util, dtype
def encoder(source, mask, params, in_text=False, to_gloss=False):
    """Encode sign-video features or text ids with a Transformer encoder.

    A learned indicator vector is placed in front of the input sequence to
    tell the model what to generate: the "gloss" vector when `to_gloss` is
    true, the "trans" vector otherwise.

    Args:
        source: word ids when `in_text` is true (an embedding table is used
            to obtain the inputs); sign video features otherwise (they are
            projected linearly to `params.hidden_size`).
        mask: mask over the positions of `source`. Only used for sign video
            input; for text input it is rebuilt from `source`, where every
            non-zero id counts as a valid position.
        params: hyper-parameter object (hidden/embedding sizes, vocabularies,
            dropout rates, layer counts, `sep_layer`, ...).
        in_text: True if `source` holds word ids, False for video features.
        to_gloss: True to translate into glosses, False to translate into
            text.

    Returns:
        A dict with
        - "encodes": output of the last encoder layer,
        - "decoder_initializer": for each decoder layer "layer_<i>", empty
          (length 0 on axis 1) "k" and "v" tensors,
        - "mask": the mask extended by one leading position (value 1) for
          the indicator vector.
    """
    hidden_size = params.hidden_size
    embed_init = tf.random_normal_initializer(0.0, hidden_size ** -0.5)

    if in_text:
        # text input: the mask comes from the ids themselves
        mask = dtype.tf_to_float(tf.cast(source, tf.bool))

        if params.shared_source_target_embedding:
            table_name = "embedding"
        else:
            table_name = "src_embedding"
        src_emb = tf.get_variable(
            table_name,
            [params.src_vocab.size(), params.embed_size],
            initializer=embed_init)
        src_bias = tf.get_variable("bias", [params.embed_size])

        embedded = tf.gather(src_emb, source) * (hidden_size ** 0.5)
        features = tf.nn.bias_add(embedded, src_bias)
    else:
        # sign video input: project the features to the embedding space
        features = func.linear(source, hidden_size, scope="premapper")

    # both indicators are always created; only one is used per call
    gloss_indicator = tf.get_variable("gloss", [1, params.embed_size])
    trans_indicator = tf.get_variable("trans", [1, params.embed_size])
    if to_gloss:
        indicator = gloss_indicator
    else:
        indicator = trans_indicator

    # put the indicator in front of the inputs; its position is always valid
    mask = tf.pad(mask, [[0, 0], [1, 0]], constant_values=1)
    batch_size = util.shape_list(features)[0]
    tiled_indicator = util.expand_tile_dims(indicator, batch_size, axis=0)
    features = tf.concat([tiled_indicator, features], 1)

    hidden = func.add_timing_signal(features)
    hidden = func.layer_norm(hidden)
    hidden = util.valid_apply_dropout(hidden, params.dropout)

    with tf.variable_scope("encoder"):
        for layer in range(params.num_encoder_layer):
            layer_initializer = None
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1) ** -0.5,
                    mode="fan_avg",
                    distribution="uniform")

            # modality-specific layers: up to and including `sep_layer`,
            # sign videos ('st') and texts ('mt') get separate layer scopes;
            # the layers above are shared
            if layer > params.sep_layer:
                scope_name = "layer_{}".format(layer)
            else:
                modality = 'mt' if in_text else 'st'
                scope_name = "layer_{}_{}".format(layer, modality)

            with tf.variable_scope(scope_name, initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    attn_bias = func.attention_bias(mask, "masking")
                    attn = func.dot_attention(
                        hidden,
                        None,
                        attn_bias,
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                    )
                    hidden = func.residual_fn(
                        hidden, attn['output'], dropout=params.residual_dropout)
                    hidden = func.layer_norm(hidden)

                with tf.variable_scope("feed_forward"):
                    ffn_out = func.ffn_layer(
                        hidden,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )
                    hidden = func.residual_fn(
                        hidden, ffn_out, dropout=params.residual_dropout)
                    hidden = func.layer_norm(hidden)

    # empty key/value caches, one pair per decoder layer
    out_batch = util.shape_list(hidden)[0]
    decoder_initializer = {}
    for dec_layer in range(params.num_decoder_layer):
        decoder_initializer["layer_{}".format(dec_layer)] = {
            "k": dtype.tf_to_float(tf.zeros([out_batch, 0, hidden_size])),
            "v": dtype.tf_to_float(tf.zeros([out_batch, 0, hidden_size])),
        }

    return {
        "encodes": hidden,
        "decoder_initializer": decoder_initializer,
        "mask": mask,
    }
def decoder(target, state, params, labels=None, is_img=None):
    """Run the Transformer decoder and compute the label-smoothed loss.

    The function works in two modes, selected by the content of `state`:
    training when `state` has no 'decoder' entry, step-wise inference
    otherwise (in which case `state['time']` and the per-layer caches under
    `state['decoder']['state']` are used and the caches are updated in
    place).

    Args:
        target: target word ids; non-zero ids count as valid positions.
        state: encoder output dict ("encodes", "mask", ...), see above for
            the extra inference entries.
        params: hyper-parameter object.
        labels: labels handed to `tf.nn.ctc_loss` for the CTC regularizer.
            The regularizer is only added in training mode, when
            `params.ctc_enable` is set and `labels` is not None.
        is_img: optional per-example weights (0.0 or 1.0) selecting the sign
            examples; when given, losses are averaged over the selected
            examples only, otherwise over the whole batch.

    Returns:
        A tuple `(loss, logits, state, per_sample_loss)`; `logits` are
        float32 and flattened to two dimensions, `state` is the (possibly
        updated) input dict.
    """
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size
    embed_init = tf.random_normal_initializer(0.0, hidden_size ** -0.5)

    # the decoding cache only exists at inference time
    is_training = 'decoder' not in state

    if params.shared_source_target_embedding:
        table_name = "embedding"
    else:
        table_name = "tgt_embedding"
    tgt_emb = tf.get_variable(
        table_name,
        [params.tgt_vocab.size(), params.embed_size],
        initializer=embed_init)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    embedded = tf.gather(tgt_emb, target) * (hidden_size ** 0.5)
    embedded = tf.nn.bias_add(embedded, tgt_bias)

    if not is_training:
        # inference: an all-padding input is replaced by zeros
        all_pad = tf.reduce_all(tf.equal(target, params.tgt_vocab.pad()))
        step_inputs = tf.cond(all_pad,
                              lambda: tf.zeros_like(embedded),
                              lambda: embedded)
        mask = tf.ones_like(mask)
        dec_inputs = func.add_timing_signal(
            step_inputs, time=dtype.tf_to_float(state['time']))
    else:
        # training: shift the inputs one step to the right
        shifted = tf.pad(embedded, [[0, 0], [1, 0], [0, 0]])
        shifted = shifted[:, :-1, :]
        dec_inputs = func.add_timing_signal(shifted)

    hidden = util.valid_apply_dropout(dec_inputs, params.dropout)

    with tf.variable_scope("decoder"):
        for layer in range(params.num_decoder_layer):
            layer_initializer = None
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1) ** -0.5,
                    mode="fan_avg",
                    distribution="uniform")

            layer_key = 'layer_{}'.format(layer)
            # the same dict object serves self- and cross-attention
            if is_training:
                layer_cache = None
            else:
                layer_cache = state['decoder']['state'][layer_key]

            with tf.variable_scope(layer_key, initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    causal_bias = func.attention_bias(tf.shape(mask)[1], "causal")
                    attn = func.dot_attention(
                        hidden,
                        None,
                        causal_bias,
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=layer_cache,
                    )
                    if layer_cache is not None:
                        # k, v
                        layer_cache.update(attn['cache'])
                    hidden = func.residual_fn(
                        hidden, attn['output'], dropout=params.residual_dropout)
                    hidden = func.layer_norm(hidden)

                with tf.variable_scope("cross_attention"):
                    memory_bias = func.attention_bias(state['mask'], "masking")
                    attn = func.dot_attention(
                        hidden,
                        state['encodes'],
                        memory_bias,
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=layer_cache,
                    )
                    if layer_cache is not None:
                        # mk, mv
                        layer_cache.update(attn['cache'])
                    hidden = func.residual_fn(
                        hidden, attn['output'], dropout=params.residual_dropout)
                    hidden = func.layer_norm(hidden)

                with tf.variable_scope("feed_forward"):
                    ffn_out = func.ffn_layer(
                        hidden,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )
                    hidden = func.residual_fn(
                        hidden, ffn_out, dropout=params.residual_dropout)
                    hidden = func.layer_norm(hidden)

    # output projection; source/target sharing takes precedence over
    # target/softmax sharing
    if params.shared_source_target_embedding:
        softmax_name = "embedding"
    elif params.shared_target_softmax_embedding:
        softmax_name = "tgt_embedding"
    else:
        softmax_name = "softmax_embedding"
    softmax_emb = tf.get_variable(
        softmax_name,
        [params.tgt_vocab.size(), params.embed_size],
        initializer=embed_init)

    flat_feature = tf.reshape(hidden, [-1, params.embed_size])
    logits = tf.matmul(flat_feature, softmax_emb, False, True)
    logits = tf.cast(logits, tf.float32)

    vocab_size = util.shape_list(logits)[-1]
    soft_label, normalizer = util.label_smooth(
        target, vocab_size, factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits,
        labels=soft_label
    )
    centropy = centropy - normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)

    # for sign-related tasks, is_img tells which examples are sign examples
    if is_img is not None:
        loss = tf.reduce_sum(per_sample_loss * is_img) / (tf.reduce_sum(is_img) + 1e-8)
    else:
        loss = tf.reduce_mean(per_sample_loss)

    # CTC regularization term on top of the encoder outputs; callers that do
    # not want it pass labels=None
    if is_training and params.ctc_enable and labels is not None:
        # batch x seq x dim, with one extra output class on top of the
        # source vocabulary
        enc_logits = func.linear(
            state['encodes'], params.src_vocab.size() + 1, scope="ctc_mapper")
        # put the seq dimension first
        enc_logits = tf.transpose(enc_logits, (1, 0, 2))
        enc_logits = tf.cast(enc_logits, tf.float32)

        with tf.name_scope('loss'):
            source_len = tf.cast(tf.reduce_sum(state['mask'], -1), tf.int32)
            ctc_loss = tf.nn.ctc_loss(
                labels, enc_logits, source_len,
                ignore_longer_outputs_than_inputs=True,
                preprocess_collapse_repeated=params.ctc_repeated)
            # normalized by the number of valid target positions
            ctc_loss = ctc_loss / tf.reduce_sum(mask, -1)

            if is_img is not None:
                ctc_loss = tf.reduce_sum(ctc_loss * is_img) / (tf.reduce_sum(is_img) + 1e-8)
            else:
                ctc_loss = tf.reduce_mean(ctc_loss)

        loss = params.ctc_alpha * ctc_loss + loss

    return loss, logits, state, per_sample_loss
def train_fn(features, params, initializer=None):
    """Build the multi-task training objective.

    Three encoder/decoder passes are run under one variable scope with
    `reuse=tf.AUTO_REUSE`, and their losses are summed:
    sign2text (with the optional CTC term), sign2gloss, and the
    text-to-text task (gloss2text / machine translation).

    Args:
        features: dict of input tensors:
            - image: [batch, sign_video_len, feature_dim] (float) extracted
              sign video features based on SMKD
            - mask: [batch, sign_video_len] (float) mask for sign video
              features
            - source: [batch, src_seq_len] (int, ids) gloss or MT source
              inputs
            - target: [batch, tgt_seq_len] (int, ids) gloss translation or
              MT target
            - is_img: [batch] (float, like mask, 0.0 or 1.0) indicates
              whether the example comes from SLT
            - label: only read when `params.ctc_enable` is set; passed to
              the decoder as CTC labels
            Note an SLT example contains sign videos but an MT example does
            not: SLT training data is a triple (sign video, gloss, text),
            MT training data is a triple (dummy video, source, target).
        params: hyper-parameter object.
        initializer: default initializer for the model variable scope.

    Returns:
        A dict with the single key "loss".
    """
    with tf.variable_scope(params.scope_name or "model",
                           initializer=initializer,
                           reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx()),
                           custom_getter=dtype.float32_variable_storage_getter):
        is_img = features["is_img"]

        # sign translation: sign2text; the only pass that may carry the
        # CTC objective
        st_state = encoder(
            features['image'], features['mask'], params,
            in_text=False, to_gloss=False)
        ctc_labels = features['label'] if params.ctc_enable else None
        loss_trans = decoder(
            features['target'], st_state, params,
            labels=ctc_labels, is_img=is_img)[0]

        # sign recognition: sign2gloss; no CTC objective here, so labels
        # is set to `None` directly
        sg_state = encoder(
            features['image'], features['mask'], params,
            in_text=False, to_gloss=True)
        loss_gloss = decoder(
            features['source'], sg_state, params,
            labels=None, is_img=is_img)[0]

        # gloss2text translation & machine translation: both are
        # text-to-text tasks, averaged over all examples
        g2t_state = encoder(
            features['source'], None, params,
            in_text=True, to_gloss=False)
        loss_g2t = decoder(
            features['target'], g2t_state, params,
            labels=None, is_img=None)[0]

        # not included in the final objective
        # # text2gloss translation
        # state = encoder(features['target'], None, params, in_text=True, to_gloss=True)
        # loss_t2g, *others = decoder(
        #     features['source'], state, params, labels=None, is_img=features["is_img"])

        # sum up all loss terms
        total_loss = loss_trans + loss_gloss + loss_g2t

        return {"loss": total_loss}
def infer_fn(params):
    """Create the encoding and decoding functions used at inference time.

    The functions work on a shallow copy of `params` that has been passed
    through `util.closing_dropout`.

    Args:
        params: hyper-parameter object; `params.eval_task` selects the task
            ('sign2text', 'sign2gloss' or 'gloss2text').

    Returns:
        A tuple `(encoding_fn, decoding_fn)`:
        - `encoding_fn(image, mask)` runs the encoder and returns its state
          dict with the decoder cache stored under `state["decoder"]`;
        - `decoding_fn(target, state, time)` runs one decoder call and
          returns `(step_logits, step_state)`.
    """
    params = util.closing_dropout(copy.copy(params))

    def model_scope():
        # same variable scope settings for encoding and decoding
        return tf.variable_scope(
            params.scope_name or "model",
            reuse=tf.AUTO_REUSE,
            dtype=tf.as_dtype(dtype.floatx()),
            custom_getter=dtype.float32_variable_storage_getter)

    def encoding_fn(image, mask):
        with model_scope():
            eval_task = params.eval_task

            # map the task name to the encoder switches
            if eval_task == 'sign2text':
                in_text, to_gloss = False, False
            elif eval_task == 'sign2gloss':
                in_text, to_gloss = False, True
            elif eval_task == 'gloss2text':
                in_text, to_gloss = True, False
            else:
                raise NotImplementedError(f"Not supporting {eval_task}")

            state = encoder(image, mask, params,
                            in_text=in_text, to_gloss=to_gloss)
            # the presence of this entry switches the decoder to inference
            state["decoder"] = {"state": state["decoder_initializer"]}

            return state

    def decoding_fn(target, state, time):
        with model_scope():
            # `time` is only visible to the decoder during this call
            state['time'] = time
            outputs = decoder(target, state, params)
            step_logits, step_state = outputs[1], outputs[2]
            del state['time']

            return step_logits, step_state

    return encoding_fn, decoding_fn