Skip to content

Commit

Permalink
Merge pull request #246 from gpengzhi/seq2seq_attn
Browse files Browse the repository at this point in the history
Resolve issue mentioned in #242
  • Loading branch information
gpengzhi authored Nov 19, 2019
2 parents 676c66a + b9b4c71 commit 8553974
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 147 deletions.
8 changes: 5 additions & 3 deletions texar/tf/data/tokenizers/xlnet_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ class XLNetTokenizerTest(tf.test.TestCase):

def setUp(self):
self.tmp_dir = tempfile.TemporaryDirectory()
# Use the test sentencepiece model downloaded from huggingface
# transformers
self.SAMPLE_VOCAB = maybe_download(
'https://github.com/gpengzhi/pytorch-transformers/blob/master/'
'pytorch_transformers/tests/fixtures/test_sentencepiece.model'
'?raw=true', self.tmp_dir.name)
'https://github.com/huggingface/transformers/blob/master/'
'transformers/tests/fixtures/test_sentencepiece.model?raw=true',
self.tmp_dir.name)

self.tokenizer = XLNetTokenizer.load(
self.SAMPLE_VOCAB[0], configs={'keep_accents': True})
Expand Down
295 changes: 151 additions & 144 deletions texar/tf/modules/decoders/dynamic_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def dynamic_decode(decoder,
Args:
decoder: A `Decoder` instance.
output_time_major: Python boolean. Default: `False` (batch major). If
`True`, outputs are returned as time major tensors (this mode is faster).
Otherwise, outputs are returned as batch major tensors (this adds extra
time to the computation).
`True`, outputs are returned as time major tensors (this mode is
faster). Otherwise, outputs are returned as batch major tensors
(this adds extra time to the computation).
impute_finished: Python boolean. If `True`, then states for batch
entries which are marked as finished get copied through and the
corresponding outputs get zeroed out. This causes some slowdown at
Expand All @@ -186,153 +186,160 @@ def dynamic_decode(decoder,
type(decoder))

with tf.variable_scope(scope, "decoder") as varscope:
# Properly cache variable values inside the while_loop
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)

if maximum_iterations is not None:
maximum_iterations = tf.convert_to_tensor(
maximum_iterations, dtype=tf.int32, name="maximum_iterations")
if maximum_iterations.get_shape().ndims != 0:
raise ValueError("maximum_iterations must be a scalar")

initial_finished, initial_inputs, initial_state = decoder.initialize()

zero_outputs = _create_zero_outputs(decoder.output_size,
decoder.output_dtype,
decoder.batch_size)

if maximum_iterations is not None:
initial_finished = tf.logical_or(
initial_finished, 0 >= maximum_iterations)
initial_sequence_lengths = tf.zeros_like(
initial_finished, dtype=tf.int32)
initial_time = tf.constant(0, dtype=tf.int32)

def _shape(batch_size, from_shape):
if (not isinstance(from_shape, tensor_shape.TensorShape) or
from_shape.ndims == 0):
return None
else:
batch_size = tf.get_static_value(
tf.convert_to_tensor(
batch_size, name="batch_size"))
return tensor_shape.TensorShape([batch_size]).\
concatenate(from_shape)

dynamic_size = True

def _create_ta(s, d):
return tf.TensorArray(
dtype=d,
size=0 if dynamic_size else maximum_iterations,
dynamic_size=dynamic_size,
element_shape=_shape(decoder.batch_size, s))

initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
decoder.output_dtype)

def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
finished, unused_sequence_lengths):
cond = tf.logical_not(tf.reduce_all(finished))
cond_time = (maximum_iterations is None or
unused_time < maximum_iterations)
ret = tf.logical_and(cond, tf.convert_to_tensor(cond_time))
return ret

def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
r"""Internal while_loop body.
Args:
time: scalar int32 tensor.
outputs_ta: structure of TensorArray.
state: (structure of) state tensors and TensorArrays.
inputs: (structure of) input tensors.
finished: bool tensor (keeping track of what's finished).
sequence_lengths: int32 tensor (keeping track of time of finish).
Returns:
`(time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)`.
"""
(next_outputs, state) = decoder.step(time, inputs, state)

# Check if the maximum iteration is met. If it is met, do not compute
# the next inputs.
reach_max = tf.equal(time + 1, maximum_iterations)
(decoder_finished, next_inputs, decoder_state) = tf.cond(
reach_max,
lambda: (tf.cast(tf.ones_like(finished), tf.bool),
inputs, state),
lambda: decoder.next_inputs(time, next_outputs, state)
)
if decoder.tracks_own_finished:
next_finished = decoder_finished
else:
next_finished = tf.logical_or(decoder_finished, finished)
next_sequence_lengths = tf.where(
tf.logical_not(finished),
tf.fill(tf.shape(sequence_lengths), time + 1),
sequence_lengths)

nest.assert_same_structure(state, decoder_state)
nest.assert_same_structure(outputs_ta, next_outputs)
nest.assert_same_structure(inputs, next_inputs)

# Zero out output values past finish
if impute_finished:
emit = nest.map_structure(
lambda out, zero: tf.where(finished, zero, out),
next_outputs,
zero_outputs)
else:
emit = next_outputs

# Copy through states past finish
def _maybe_copy_state(new, cur):
# TensorArrays and scalar states get passed through.
if isinstance(cur, tf.TensorArray):
pass_through = True
initial_finished, initial_inputs, initial_state = decoder.initialize()

zero_outputs = _create_zero_outputs(decoder.output_size,
decoder.output_dtype,
decoder.batch_size)

if maximum_iterations is not None:
initial_finished = tf.logical_or(
initial_finished, 0 >= maximum_iterations)
initial_sequence_lengths = tf.zeros_like(
initial_finished, dtype=tf.int32)
initial_time = tf.constant(0, dtype=tf.int32)

def _shape(batch_size, from_shape):
if (not isinstance(from_shape, tensor_shape.TensorShape) or
from_shape.ndims == 0):
return None
else:
batch_size = tf.get_static_value(
tf.convert_to_tensor(
batch_size, name="batch_size"))
return tensor_shape.TensorShape([batch_size]).\
concatenate(from_shape)

dynamic_size = True

def _create_ta(s, d):
return tf.TensorArray(
dtype=d,
size=0 if dynamic_size else maximum_iterations,
dynamic_size=dynamic_size,
element_shape=_shape(decoder.batch_size, s))

initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
decoder.output_dtype)

def condition(unused_time, unused_outputs_ta, unused_state,
unused_inputs, finished, unused_sequence_lengths):
cond = tf.logical_not(tf.reduce_all(finished))
cond_time = (maximum_iterations is None or
unused_time < maximum_iterations)
ret = tf.logical_and(cond, tf.convert_to_tensor(cond_time))
return ret

def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
r"""Internal while_loop body.
Args:
time: scalar int32 tensor.
outputs_ta: structure of TensorArray.
state: (structure of) state tensors and TensorArrays.
inputs: (structure of) input tensors.
finished: bool tensor (keeping track of what's finished).
sequence_lengths: int32 tensor (keeping track of time of
finish).
Returns:
`(time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)`.
"""
(next_outputs, state) = decoder.step(time, inputs, state)

# Check if the maximum iteration is met. If it is met, do not
# compute the next inputs.
reach_max = tf.equal(time + 1, maximum_iterations)
(decoder_finished, next_inputs, decoder_state) = tf.cond(
reach_max,
lambda: (tf.cast(tf.ones_like(finished), tf.bool),
inputs, state),
lambda: decoder.next_inputs(time, next_outputs, state)
)
if decoder.tracks_own_finished:
next_finished = decoder_finished
else:
next_finished = tf.logical_or(decoder_finished, finished)
next_sequence_lengths = tf.where(
tf.logical_not(finished),
tf.fill(tf.shape(sequence_lengths), time + 1),
sequence_lengths)

nest.assert_same_structure(state, decoder_state)
nest.assert_same_structure(outputs_ta, next_outputs)
nest.assert_same_structure(inputs, next_inputs)

# Zero out output values past finish
if impute_finished:
emit = nest.map_structure(
lambda out, zero: tf.where(finished, zero, out),
next_outputs,
zero_outputs)
else:
emit = next_outputs

# Copy through states past finish
def _maybe_copy_state(new, cur):
# TensorArrays and scalar states get passed through.
if isinstance(cur, tf.TensorArray):
pass_through = True
else:
new.set_shape(cur.shape)
pass_through = (new.shape.ndims == 0)
return new if pass_through else tf.where(finished, cur, new)

if impute_finished:
next_state = nest.map_structure(
_maybe_copy_state, decoder_state, state)
else:
new.set_shape(cur.shape)
pass_through = (new.shape.ndims == 0)
return new if pass_through else tf.where(finished, cur, new)

if impute_finished:
next_state = nest.map_structure(
_maybe_copy_state, decoder_state, state)
else:
next_state = decoder_state

outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
outputs_ta, emit)
return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)

res = tf.while_loop(
condition,
body,
loop_vars=(
initial_time,
initial_outputs_ta,
initial_state,
initial_inputs,
initial_finished,
initial_sequence_lengths,
),
parallel_iterations=parallel_iterations,
maximum_iterations=maximum_iterations,
swap_memory=swap_memory)

final_outputs_ta = res[1]
final_state = res[2]
final_sequence_lengths = res[5]

final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

try:
final_outputs, final_state = decoder.finalize(
final_outputs, final_state, final_sequence_lengths)
except NotImplementedError:
pass

if not output_time_major:
final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)
next_state = decoder_state

outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
outputs_ta, emit)
return (time + 1, outputs_ta, next_state, next_inputs,
next_finished, next_sequence_lengths)

res = tf.while_loop(
condition,
body,
loop_vars=(
initial_time,
initial_outputs_ta,
initial_state,
initial_inputs,
initial_finished,
initial_sequence_lengths,
),
parallel_iterations=parallel_iterations,
maximum_iterations=maximum_iterations,
swap_memory=swap_memory)

final_outputs_ta = res[1]
final_state = res[2]
final_sequence_lengths = res[5]

final_outputs = nest.map_structure(lambda ta: ta.stack(),
final_outputs_ta)

try:
final_outputs, final_state = decoder.finalize(
final_outputs, final_state, final_sequence_lengths)
except NotImplementedError:
pass

if not output_time_major:
final_outputs = nest.map_structure(_transpose_batch_time,
final_outputs)

return final_outputs, final_state, final_sequence_lengths

0 comments on commit 8553974

Please sign in to comment.