diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ae319c70a..c66b4029c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,5 +1,15 @@
 # How to Contribute
 
+# Issues
+
+* Please tag your issue with `bug`, `feature request`, or `question` to help us
+  effectively respond.
+* Please include the versions of TensorFlow and Tensor2Tensor you are running
+  (run `pip list | grep tensor`)
+* Please provide the command line you ran as well as the log output.
+
+# Pull Requests
+
 We'd love to accept your patches and contributions to this project. There are
 just a few small guidelines you need to follow.
diff --git a/setup.py b/setup.py
index 5eebe27f3..bedb393fd 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.2.8',
+    version='1.2.9',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',
diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
index eba408074..2ac0f0db2 100644
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -43,7 +43,6 @@ from tensor2tensor.data_generators import all_problems  # pylint: disable=unused
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import snli
-from tensor2tensor.data_generators import translate
 from tensor2tensor.data_generators import wsj_parsing
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import usr_dir
@@ -82,10 +81,10 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
     "parsing_english_ptb8k": (
-        lambda: translate.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
-        lambda: translate.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
+        lambda: wsj_parsing.parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13, 2**9),
+        lambda: wsj_parsing.parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9)),
     "parsing_english_ptb16k": (
         lambda: wsj_parsing.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py
index dd1c11390..c067711be 100644
--- a/tensor2tensor/models/__init__.py
+++ b/tensor2tensor/models/__init__.py
@@ -37,8 +37,6 @@
 from tensor2tensor.models import shake_shake
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
-from tensor2tensor.models import transformer_adv
-from tensor2tensor.models import transformer_alternative
 from tensor2tensor.models import transformer_moe
 from tensor2tensor.models import transformer_revnet
 from tensor2tensor.models import transformer_sketch
diff --git a/tensor2tensor/models/shake_shake.py b/tensor2tensor/models/shake_shake.py
index a4dd2385a..bad951a32 100644
--- a/tensor2tensor/models/shake_shake.py
+++ b/tensor2tensor/models/shake_shake.py
@@ -132,6 +132,8 @@ def model_fn_body(self, features):
 @registry.register_hparams
 def shakeshake_cifar10():
   """Parameters for CIFAR-10."""
+  tf.logging.warning("shakeshake_cifar10 hparams have not been verified to "
+                     "achieve good performance.")
   hparams = common_hparams.basic_params1()
   # This leads to effective batch size 128 when number of GPUs is 1
   hparams.batch_size = 4096 * 8
diff --git a/tensor2tensor/models/transformer_adv.py b/tensor2tensor/models/transformer_adv.py
deleted file mode 100644
index 737aa822e..000000000
--- a/tensor2tensor/models/transformer_adv.py
+++ /dev/null
@@ -1,233 +0,0 @@
-# coding=utf-8
-# Copyright 2017 The Tensor2Tensor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Adversarial Transformer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Dependency imports
-
-from tensor2tensor.layers import common_layers
-from tensor2tensor.models import transformer
-from tensor2tensor.models import transformer_vae
-from tensor2tensor.utils import registry
-from tensor2tensor.utils import t2t_model
-
-import tensorflow as tf
-
-
-def encode(x, x_space, hparams, name):
-  """Transformer preparations and encoder."""
-  with tf.variable_scope(name):
-    (encoder_input, encoder_self_attention_bias,
-     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
-    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
-    return transformer.transformer_encoder(
-        encoder_input, encoder_self_attention_bias, hparams), ed
-
-
-def decode(encoder_output, encoder_decoder_attention_bias, targets,
-           hparams, name, reuse=False):
-  """Transformer decoder."""
-  with tf.variable_scope(name, reuse=reuse):
-    targets = common_layers.flatten4d3d(targets)
-
-    decoder_input, decoder_self_bias = transformer.transformer_prepare_decoder(
-        targets, hparams)
-
-    decoder_input = tf.nn.dropout(decoder_input,
-                                  1.0 - hparams.layer_prepostprocess_dropout)
-
-    decoder_output = transformer.transformer_decoder(
-        decoder_input,
-        encoder_output,
-        decoder_self_bias,
-        encoder_decoder_attention_bias,
-        hparams)
-
-    # Expand since t2t expects 4d tensors.
-    return tf.expand_dims(decoder_output, axis=2)
-
-
-def reverse_gradient(x, delta=1.0):
-  return tf.stop_gradient((1.0 + delta) * x) - delta * x
-
-
-def adversary(embedded, inputs, hparams, name, reuse=False):
-  with tf.variable_scope(name, reuse=reuse):
-    h0, i0 = common_layers.pad_to_same_length(
-        embedded, inputs, final_length_divisible_by=16)
-    h0 = tf.concat([h0, tf.expand_dims(i0, axis=2)], axis=-1)
-    h0 = tf.layers.dense(h0, hparams.hidden_size, name="io")
-    h1 = transformer_vae.compress(h0, None, False, hparams, "compress1")
-    h2 = transformer_vae.compress(h1, None, False, hparams, "compress2")
-    res_dense = tf.reduce_mean(h2, axis=[1, 2])
-    res_single = tf.squeeze(tf.layers.dense(res_dense, 1), axis=-1)
-    return tf.nn.sigmoid(res_single)
-
-
-def softmax_embed(x, embedding, batch_size, hparams):
-  """Softmax x and embed."""
-  x = tf.reshape(tf.nn.softmax(x), [-1, 34*1024])
-  x = tf.matmul(x, embedding)
-  return tf.reshape(x, [batch_size, -1, 1, hparams.hidden_size])
-
-
-def adv_transformer_internal(inputs, targets, target_space, hparams):
-  """Adversarial Transformer, main step used for training."""
-  with tf.variable_scope("adv_transformer"):
-    batch_size = tf.shape(targets)[0]
-    targets = tf.reshape(targets, [batch_size, -1, 1])
-    intermediate = tf.constant(34*1024 - 1)
-    intermediate += tf.zeros_like(targets)
-    targets = tf.concat([targets, intermediate], axis=2)
-    targets = tf.reshape(targets, [batch_size, -1, 1])
-    embedding = tf.get_variable("embedding", [34*1024, hparams.hidden_size])
-    targets_emb = tf.gather(embedding, targets)
-
-    # Noisy embedded targets.
-    targets_noisy = tf.one_hot(targets, 34*1024)
-    noise_val = hparams.noise_val
-    targets_noisy += tf.random_uniform(tf.shape(targets_noisy),
-                                       minval=-noise_val, maxval=noise_val)
-    targets_emb_noisy = softmax_embed(
-        targets_noisy, embedding, batch_size, hparams)
-
-    # Encoder.
-    if inputs is not None:
-      inputs_emb = common_layers.flatten4d3d(inputs)
-      inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc")
-    else:
-      ed = None
-
-    # Masking.
-    masking = common_layers.inverse_lin_decay(200000)
-    masking *= common_layers.inverse_exp_decay(50000)  # Not much at start.
-    masking -= tf.random_uniform([]) * 0.4
-    masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
-    mask = tf.less(masking, tf.random_uniform(tf.shape(targets)))
-    mask = tf.expand_dims(tf.to_float(mask), 3)
-    noise = tf.random_uniform(tf.shape(targets_emb))
-    targets_emb = mask * targets_emb + (1.0 - mask) * noise
-
-    # Decoder.
-    res_dec = decode(inputs, ed, targets_emb, hparams, "decoder")
-    res = tf.layers.dense(res_dec, 34*1024, name="res_sm")
-    res_emb = softmax_embed(res, embedding, batch_size, hparams)
-
-    # Extra steps.
-    extra_step_prob = masking * 0.6 + 0.3
-    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
-      extra_step_prob = 1.0
-    for _ in xrange(hparams.extra_steps):
-      def another_step(emb):
-        res_dec = decode(inputs, ed, emb, hparams, "decoder", reuse=True)
-        res = tf.layers.dense(res_dec, 34*1024, name="res_sm", reuse=True)
-        return softmax_embed(res, embedding, batch_size, hparams), res
-      res_emb, res = tf.cond(tf.less(tf.random_uniform([]), extra_step_prob),
-                             lambda e=res_emb: another_step(e),
-                             lambda: (res_emb, res))
-
-    # Adversary.
-    delta = masking * hparams.delta_max
-    true_logit = adversary(tf.stop_gradient(targets_emb_noisy),
-                           tf.stop_gradient(inputs + inputs_emb),
-                           hparams, "adversary")
-    gen_logit = adversary(reverse_gradient(res_emb, delta),
-                          tf.stop_gradient(inputs + inputs_emb),
-                          hparams, "adversary", reuse=True)
-    losses = {"adv": gen_logit - true_logit}
-    res = tf.stop_gradient(masking * res) + (1.0 - masking) * res
-    return res, losses
-
-
-@registry.register_model
-class TransformerAdv(t2t_model.T2TModel):
-  """Adversarial Transformer."""
-
-  def model_fn_body(self, features):
-    inputs = features.get("inputs", None)
-    return adv_transformer_internal(
-        inputs, features["targets_raw"],
-        features["target_space_id"], self._hparams)
-
-  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
-            alpha=0.0):
-    """Produce predictions from the model."""
-    if not features:
-      features = {}
-    inputs_old = None
-    if "inputs" in features and len(features["inputs"].shape) < 4:
-      inputs_old = features["inputs"]
-      features["inputs"] = tf.expand_dims(features["inputs"], 2)
-
-    # Create an initial targets tensor.
-    if "partial_targets" in features:
-      initial_output = tf.convert_to_tensor(features["partial_targets"])
-    else:
-      batch_size = tf.shape(features["inputs"])[0]
-      length = tf.shape(features["inputs"])[1]
-      initial_output = tf.zeros((batch_size, 2 * length, 1, 1), dtype=tf.int64)
-
-    features["targets"] = initial_output
-    sharded_logits, _ = self.model_fn(features, False)
-    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-    samples = tf.concat(sharded_samples, 0)
-
-    # More steps.
-    how_many_more_steps = 5
-    for _ in xrange(how_many_more_steps):
-      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-        features["targets"] = samples
-        sharded_logits, _ = self.model_fn(features, False)
-        sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-        samples = tf.concat(sharded_samples, 0)
-
-    if inputs_old is not None:  # Restore to not confuse Estimator.
-      features["inputs"] = inputs_old
-    return samples
-
-
-@registry.register_hparams
-def transformer_adv_small():
-  """Set of hyperparameters."""
-  hparams = transformer.transformer_small()
-  hparams.batch_size = 2048
-  hparams.learning_rate_warmup_steps = 4000
-  hparams.num_hidden_layers = 3
-  hparams.hidden_size = 384
-  hparams.filter_size = 2048
-  hparams.label_smoothing = 0.0
-  hparams.weight_decay = 0.1
-  hparams.symbol_modality_skip_top = True
-  hparams.target_modality = "symbol:ctc"
-  hparams.add_hparam("num_compress_steps", 2)
-  hparams.add_hparam("extra_steps", 0)
-  hparams.add_hparam("noise_val", 0.3)
-  hparams.add_hparam("delta_max", 2.0)
-  return hparams
-
-
-@registry.register_hparams
-def transformer_adv_base():
-  """Set of hyperparameters."""
-  hparams = transformer_adv_small()
-  hparams.batch_size = 1024
-  hparams.hidden_size = 512
-  hparams.filter_size = 4096
-  hparams.num_hidden_layers = 6
-  return hparams
diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py
deleted file mode 100644
index 2604748be..000000000
--- a/tensor2tensor/models/transformer_alternative.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-# Copyright 2017 The Tensor2Tensor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Alternative transformer network.
-
-Using different layer types to demonstrate alternatives to self attention.
-
-Code is mostly copied from original Transformer source.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Dependency imports
-
-from six.moves import xrange  # pylint: disable=redefined-builtin
-
-from tensor2tensor.layers import common_attention
-from tensor2tensor.layers import common_layers
-from tensor2tensor.models import transformer
-from tensor2tensor.utils import registry
-from tensor2tensor.utils import t2t_model
-
-import tensorflow as tf
-
-
-@registry.register_model
-class TransformerAlt(t2t_model.T2TModel):
-
-  def model_fn_body(self, features):
-    hparams = self._hparams
-    targets = features["targets"]
-    inputs = features.get("inputs")
-    target_space = features.get("target_space_id")
-
-    inputs = common_layers.flatten4d3d(inputs)
-    targets = common_layers.flatten4d3d(targets)
-
-    (encoder_input,
-     encoder_attention_bias, _) = (transformer.transformer_prepare_encoder(
-         inputs, target_space, hparams))
-    (decoder_input, _) = (transformer.transformer_prepare_decoder(
-        targets, hparams))
-
-    encoder_mask = bias_to_mask(encoder_attention_bias)
-
-    def residual_fn(x, y):
-      return common_layers.layer_norm(x + tf.nn.dropout(
-          y, 1.0 - hparams.residual_dropout))
-
-    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
-    decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
-
-    encoder_output = alt_transformer_encoder(encoder_input, residual_fn,
-                                             encoder_mask, hparams)
-
-    decoder_output = alt_transformer_decoder(decoder_input, encoder_output,
-                                             residual_fn,
-                                             encoder_attention_bias, hparams)
-
-    decoder_output = tf.expand_dims(decoder_output, 2)
-
-    return decoder_output
-
-
-def composite_layer(inputs, mask, hparams, for_output=False):
-  """Composite layer."""
-  x = inputs
-
-  # Applies ravanbakhsh on top of each other.
-  if hparams.composite_layer_type == "ravanbakhsh":
-    for layer in xrange(hparams.layers_per_layer):
-      with tf.variable_scope(".%d" % layer):
-        x = common_layers.ravanbakhsh_set_layer(
-            hparams.hidden_size,
-            x,
-            mask=mask,
-            sequential=for_output,
-            dropout=hparams.relu_dropout)
-
-  # Transforms elements to get a context, and then uses this in a final layer.
-  elif hparams.composite_layer_type == "reembedding":
-    # Transform elements n times and then pool.
-    for layer in xrange(hparams.layers_per_layer):
-      with tf.variable_scope("sub_layer_%d" % layer):
-        x = common_layers.linear_set_layer(
-            hparams.hidden_size, x, dropout=hparams.relu_dropout)
-        if for_output:
-          context = common_layers.running_global_pool_1d(x)
-        else:
-          context = common_layers.global_pool_1d(x, mask=mask)
-    # Final layer.
-    x = common_layers.linear_set_layer(
-        hparams.hidden_size, x, context=context, dropout=hparams.relu_dropout)
-
-  return x
-
-
-def alt_transformer_encoder(encoder_input,
-                            residual_fn,
-                            mask,
-                            hparams,
-                            name="encoder"):
-  """Alternative encoder."""
-  x = encoder_input
-  with tf.variable_scope(name):
-    x = encoder_input
-    for layer in xrange(hparams.num_hidden_layers):
-      with tf.variable_scope("layer_%d" % layer):
-        x = residual_fn(x, composite_layer(x, mask, hparams))
-
-  return x
-
-
-def alt_transformer_decoder(decoder_input,
-                            encoder_output,
-                            residual_fn,
-                            encoder_decoder_attention_bias,
-                            hparams,
-                            name="decoder"):
-  """Alternative decoder."""
-  with tf.variable_scope(name):
-    x = decoder_input
-    for layer in xrange(hparams.num_hidden_layers):
-      with tf.variable_scope("layer_%d" % layer):
-        x_ = common_attention.multihead_attention(
-            x,
-            encoder_output,
-            encoder_decoder_attention_bias,
-            hparams.attention_key_channels or hparams.hidden_size,
-            hparams.attention_value_channels or hparams.hidden_size,
-            hparams.hidden_size,
-            hparams.num_heads,
-            hparams.attention_dropout,
-            name="encdec_attention")
-
-        x_ = residual_fn(x_, composite_layer(
-            x_, None, hparams, for_output=True))
-        x = residual_fn(x, x_)
-
-  return x
-
-
-def bias_to_mask(bias):
-  # We need masks of the form batch size x input sequences
-  # Biases are of the form batch_size x num_heads x input sequences x
-  # output sequences. Squeeze out dim one, and get the first element of
-  # each vector.
-  bias = tf.squeeze(bias, [1])[:, :, 0]
-  bias = -tf.clip_by_value(bias, -1.0, 1.0)
-  mask = 1 - bias
-  return mask
-
-
-@registry.register_hparams
-def transformer_alt():
-  """Set of hyperparameters."""
-  hparams = transformer.transformer_base()
-  hparams.batch_size = 2048
-  hparams.num_hidden_layers = 10
-  hparams.add_hparam("layers_per_layer", 4)
-  # Composite layer: ravanbakhsh or reembedding.
-  hparams.add_hparam("composite_layer_type", "ravanbakhsh")
-  return hparams
diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py
index 81156babd..ad5143095 100644
--- a/tensor2tensor/models/transformer_vae.py
+++ b/tensor2tensor/models/transformer_vae.py
@@ -300,11 +300,11 @@ def compress(x, c, is_2d, hparams, name):
     # Run compression by strided convs.
     cur = x
     k1 = (3, 3) if is_2d else (3, 1)
+    cur = residual_conv(cur, hparams.num_compress_steps, k1, hparams, "rc")
    k2 = (2, 2) if is_2d else (2, 1)
     for i in xrange(hparams.num_compress_steps):
       if c is not None:
         cur = attend(cur, c, hparams, "compress_attend_%d" % i)
-      cur = residual_conv(cur, 1, k1, hparams, "compress_rc_%d" % i)
       cur = common_layers.conv_block(
           cur, hparams.hidden_size, [((1, 1), k2)],
           strides=k2, name="compress_%d" % i)
@@ -493,20 +493,24 @@ def ae_latent_sample(t_c, inputs, ed, embed, iters, hparams):
   t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra")
   t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
   t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp)
-  for i in xrange(iters):
+
+  def next_bit(t_bit, i):
     t_bit_prev = t_bit
     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       t_c = embed(t_bit)
       t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra")
       t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
       t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp)
-      t_bit = tf.concat([t_bit_prev[:, :(i+1), :],
-                         t_bit[:, (i+1):, :]], axis=1)
+      return tf.concat([t_bit_prev[:, :(i+1), :],
+                        t_bit[:, (i+1):, :]], axis=1)
+
+  for i in xrange(iters):
+    t_bit = next_bit(t_bit, i)
   return t_bit
 
 
 def ae_transformer_internal(inputs, targets, target_space, hparams,
-                            beam_size, cache=None):
+                            beam_size, cache=None, predict_mask=1.0):
   """AE Transformer, main step used for training."""
   hparams.z_size = hparams.hidden_size
   with tf.variable_scope("ae_transformer"):
@@ -525,12 +529,10 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
 
     # Autoencoding.
     losses = {"vc": tf.constant(0.0), "sm": tf.constant(0.0)}
-    latent_len = hparams.latent_length
     if hparams.do_ae:
-      targets_pad, _ = common_layers.pad_to_same_length(
-          targets, targets, final_length_divisible_by=latent_len * 2**k)
-      targets_c = compress(targets_pad, None, False, hparams, "compress")
-      targets_c = targets_c[:, :latent_len, :, :]
+      targets, _ = common_layers.pad_to_same_length(
+          targets, targets, final_length_divisible_by=2**k)
+      targets_c = compress(targets, None, False, hparams, "compress")
       if hparams.mode != tf.estimator.ModeKeys.PREDICT:
         # Compress and bottleneck.
         t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
@@ -546,25 +548,45 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
         t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
         losses["sm"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=t_bit, logits=t_pred)
-        losses["sm"] = tf.reduce_mean(losses["sm"]) * 0.2 * tf.to_float(cond)
+        losses["sm"] = tf.reduce_mean(losses["sm"]) * 0.5 * tf.to_float(cond)
       else:
+        latent_len = tf.shape(targets_c)[1]
         _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
-        t_c = tf.zeros_like(targets_c)
+        t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
         if cache is None:
-          cache = ae_latent_sample(t_c, inputs, ed, embed, 3, hparams)
+          cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
           cache = cache[0, :, :]
           cache = tf.reshape(cache, [1, latent_len, 1])
           cache = tf.tile(cache, [beam_size, 1, 1])
         t_c = embed(cache)
 
       # Postprocess.
-      pos = tf.get_variable("pos", [1, latent_len + 1, 1, hparams.hidden_size])
+      d = t_c
+      pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
+      pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
       t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
+
+      # Masking.
+      if hparams.do_mask:
+        masking = common_layers.inverse_lin_decay(100000)
+        masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
+        masking -= tf.random_uniform([]) * 0.3
+        masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
+        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
+          masking = predict_mask
+        mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
+        mask = tf.expand_dims(tf.to_float(mask), 3)
+        for i in xrange(hparams.num_compress_steps):
+          j = hparams.num_compress_steps - i - 1
+          d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+          d = decompress_step(d, None, hparams,
+                              i > 0, False, "decompress_%d" % j)
+        noise = d  # tf.random_uniform(tf.shape(targets))
+        targets = mask * targets + (1.0 - mask) * noise
       targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
-    else:
-      targets = tf.pad(targets, [[0, 0], [latent_len + 1, 0], [0, 0], [0, 0]])
     res = decode_transformer(inputs, ed, targets, hparams, "decoder")
-    res = res[:, latent_len + 1:, :, :]
+    if hparams.do_ae:
+      res = res[:, tf.shape(t_c)[1]:, :, :]
     return res, losses, cache
 
 
@@ -572,6 +594,10 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
 class TransformerAE(t2t_model.T2TModel):
   """Autoencoder-augmented Transformer."""
 
+  def __init__(self, *args, **kwargs):
+    super(TransformerAE, self).__init__(*args, **kwargs)
+    self.predict_mask = 1.0
+
   @property
   def has_input(self):
     return self._problem_hparams.input_modality
@@ -585,7 +611,8 @@ def model_fn_body(self, features):
     with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
       res, loss, _ = ae_transformer_internal(
           inputs, features["targets"], features["target_space_id"],
-          self._hparams, beam_size, features.get("cache_raw", None))
+          self._hparams, beam_size, features.get("cache_raw", None),
+          predict_mask=self.predict_mask)
       return res, loss
 
   def prepare_features_for_infer(self, features):
@@ -603,6 +630,38 @@ def prepare_features_for_infer(self, features):
         self._hparams, beam_size)
     features["cache_raw"] = cache
 
+  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
+            alpha=0.0):
+    """Produce predictions from the model."""
+    if not self._hparams.do_mask:
+      return super(TransformerAE, self).infer(
+          features, decode_length, beam_size, top_beams, alpha)
+    if not features:
+      features = {}
+    inputs_old = None
+    if "inputs" in features and len(features["inputs"].shape) < 4:
+      inputs_old = features["inputs"]
+      features["inputs"] = tf.expand_dims(features["inputs"], 2)
+
+    # Create an initial targets tensor.
+    if "partial_targets" in features:
+      initial_output = tf.convert_to_tensor(features["partial_targets"])
+    else:
+      batch_size = tf.shape(features["inputs"])[0]
+      length = tf.shape(features["inputs"])[1]
+      target_length = tf.to_int32(1.3 * tf.to_float(length))
+      initial_output = tf.zeros((batch_size, target_length, 1, 1),
+                                dtype=tf.int64)
+
+    features["targets"] = initial_output
+    sharded_logits, _ = self.model_fn(features, False, force_full_predict=True)
+    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
+    samples = tf.concat(sharded_samples, 0)
+
+    if inputs_old is not None:  # Restore to not confuse Estimator.
+      features["inputs"] = inputs_old
+    return samples
+
 
 @registry.register_hparams
 def transformer_ae_small():
@@ -615,12 +674,12 @@ def transformer_ae_small():
   hparams.filter_size = 2048
   hparams.label_smoothing = 0.0
   hparams.add_hparam("c_size", 16)
-  hparams.add_hparam("latent_length", 4)
   hparams.add_hparam("noise_dev", 1.0)
   hparams.add_hparam("d_mix", 0.5)
   # Bottleneck kinds supported: dense, semhash, gumbel-softmax.
   hparams.add_hparam("bottleneck_kind", "semhash")
   hparams.add_hparam("do_ae", True)
+  hparams.add_hparam("do_mask", True)
   hparams.add_hparam("drop_inputs", False)
   hparams.add_hparam("z_size", 128)
   hparams.add_hparam("v_size", 1024*64)
diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py
index 9764b2b99..7d4912bc6 100644
--- a/tensor2tensor/utils/expert_utils.py
+++ b/tensor2tensor/utils/expert_utils.py
@@ -200,8 +200,8 @@ def daisy_chain_getter(getter, name, *args, **kwargs):
       else:
         var = getter(name, *args, **kwargs)
         v = tf.identity(var._ref())  # pylint: disable=protected-access
+      _add_variable_proxy_methods(var, v)
       # update the cache
-      _add_variable_proxy_methods(var, v)
       cache[name] = v
       cache[device_var_key] = v
       return v
@@ -210,10 +210,12 @@ def daisy_chain_getter(getter, name, *args, **kwargs):
     # so we make a custom getter that uses identity to cache the variable.
     # pylint: disable=cell-var-from-loop
     def caching_getter(getter, name, *args, **kwargs):
-      v = getter(name, *args, **kwargs)
+      """Cache variables on device."""
       key = (self._caching_devices[i], name)
       if key in cache:
         return cache[key]
+
+      v = getter(name, *args, **kwargs)
       with tf.device(self._caching_devices[i]):
         ret = tf.identity(v._ref())  # pylint: disable=protected-access
         _add_variable_proxy_methods(v, ret)
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index ac11d54aa..f5ec04679 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -505,13 +505,15 @@ def _shard_features(self, features):  # pylint: disable=missing-docstring
             0))
     return sharded_features
 
-  def model_fn(self, features, skip=False):
+  def model_fn(self, features, skip=False, force_full_predict=False):
     """Computes the entire model and produces sharded logits and losses.
 
     Args:
       features: A dictionary of feature name to tensor.
-      skip: a boolean, if we're just dummy-calling and actually skip this model
+      skip: a Boolean, if we're just dummy-calling and actually skip this model
         (but we need to create variables to not confuse distributed training).
+      force_full_predict: a Boolean, if set, then last-position-only
+        optimizations are not used even when allowed and in PREDICT mode.
 
     Returns:
       sharded_logits: a list of `Tensor`s, one per datashard.
@@ -579,7 +581,8 @@ def model_fn(self, features, skip=False):
 
       with tf.variable_scope(target_modality.name, reuse=target_reuse):
         last_only = (target_modality.top_is_pointwise and
-                     self._hparams.mode == tf.estimator.ModeKeys.PREDICT)
+                     self._hparams.mode == tf.estimator.ModeKeys.PREDICT and
+                     not force_full_predict)
         if not last_only:
           sharded_logits = target_modality.top_sharded(
               body_outputs, sharded_features["targets"], dp)