diff --git a/.travis.yml b/.travis.yml index ecfcb699a..02e7e0768 100644 --- a/.travis.yml +++ b/.travis.yml @@ -41,6 +41,7 @@ script: --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/bin/t2t_trainer_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py + --ignore=tensor2tensor/models/research/r_transformer_test.py # Requires new feature in tf.foldl (rm with TF 1.9) - pytest tensor2tensor/utils/registry_test.py - pytest tensor2tensor/utils/trainer_lib_test.py - pytest tensor2tensor/visualization/visualization_test.py diff --git a/docs/distributed_training.md b/docs/distributed_training.md index b9e070721..48ef14a34 100644 --- a/docs/distributed_training.md +++ b/docs/distributed_training.md @@ -51,8 +51,10 @@ distributed training: Parameter servers only need `--master=grpc://$ADDRESS` and `--schedule=run_std_server`. ->> Note about `output_dir`: All the workers (masters and parameter servers) should use the same `output_dir`. If training ->> on separate nodes, output_dir can be a shared filesystem like NFS or an object store like GCS. +>> Note about `--output_dir`: All the nodes should use the same `--output_dir`. +>> When using multiple machines, `output_dir` should point to a shared +>> filesystem like NFS or an object store like Google Cloud Storage +>> (`gs://...`). ## Utility to produce `TF_CONFIG` and flags diff --git a/setup.py b/setup.py index cc22c8a0f..6cb7841c4 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.6.0', + version='1.6.1', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', @@ -14,6 +14,7 @@ packages=find_packages(), package_data={ 'tensor2tensor.data_generators': ['test_data/*'], + 'tensor2tensor.data_generators.wikisum': ['test_data/*'], 'tensor2tensor.visualization': [ 'attention.js', 'TransformerVisualization.ipynb' ], @@ -37,7 +38,8 @@ 'gevent', 'google-api-python-client', 'gunicorn', - 'gym<=0.9.5', # gym in version 0.9.6 has some temporary issues. + 'gym', + 'h5py', 'numpy', 'requests', 'scipy', @@ -47,7 +49,7 @@ extras_require={ 'tensorflow': ['tensorflow>=1.5.0'], 'tensorflow_gpu': ['tensorflow-gpu>=1.5.0'], - 'tests': ['pytest', 'h5py', 'mock'], + 'tests': ['pytest', 'mock'], }, classifiers=[ 'Development Status :: 4 - Beta', diff --git a/tensor2tensor/bin/make_tf_configs.py b/tensor2tensor/bin/make_tf_configs.py index 85ba874f3..cf3d10257 100644 --- a/tensor2tensor/bin/make_tf_configs.py +++ b/tensor2tensor/bin/make_tf_configs.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Output command line arguments and json-encoded TF_CONFIGs. Usage: diff --git a/tensor2tensor/bin/t2t_avg_all.py b/tensor2tensor/bin/t2t_avg_all.py index 694ab26ed..7d34f4a33 100644 --- a/tensor2tensor/bin/t2t_avg_all.py +++ b/tensor2tensor/bin/t2t_avg_all.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Script to continuously average last N checkpoints in a given directory.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/bin/t2t_bleu.py b/tensor2tensor/bin/t2t_bleu.py index 74117454d..9cf789e7c 100644 --- a/tensor2tensor/bin/t2t_bleu.py +++ b/tensor2tensor/bin/t2t_bleu.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Evaluate BLEU score for all checkpoints/translations in a given directory. This script can be used in two ways. diff --git a/tensor2tensor/bin/t2t_datagen.py b/tensor2tensor/bin/t2t_datagen.py index d1dcab834..45826c4bd 100644 --- a/tensor2tensor/bin/t2t_datagen.py +++ b/tensor2tensor/bin/t2t_datagen.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Produces the training and dev data for --problem into --data_dir. Produces sharded and shuffled TFRecord files of tensorflow.Example protocol diff --git a/tensor2tensor/bin/t2t_decoder.py b/tensor2tensor/bin/t2t_decoder.py index 08d8c7ee5..cb3b19032 100644 --- a/tensor2tensor/bin/t2t_decoder.py +++ b/tensor2tensor/bin/t2t_decoder.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - r"""Decode from trained T2T models. This binary performs inference using the Estimator API. @@ -82,9 +81,13 @@ def create_decode_hparams(): def decode(estimator, hparams, decode_hp): if FLAGS.decode_interactive: + if estimator.config.use_tpu: + raise ValueError("TPU can only decode from dataset.") decoding.decode_interactively(estimator, hparams, decode_hp, checkpoint_path=FLAGS.checkpoint_path) elif FLAGS.decode_from_file: + if estimator.config.use_tpu: + raise ValueError("TPU can only decode from dataset.") decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams, decode_hp, FLAGS.decode_to_file, checkpoint_path=FLAGS.checkpoint_path) @@ -160,7 +163,6 @@ def main(_): tf.logging.set_verbosity(tf.logging.INFO) trainer_lib.set_random_seed(FLAGS.random_seed) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) - FLAGS.use_tpu = False # decoding not supported on TPU if FLAGS.score_file: filename = os.path.expanduser(FLAGS.score_file) @@ -183,7 +185,7 @@ def main(_): hp, t2t_trainer.create_run_config(hp), decode_hparams=decode_hp, - use_tpu=False) + use_tpu=FLAGS.use_tpu) decode(estimator, hp, decode_hp) diff --git a/tensor2tensor/bin/t2t_distill.py b/tensor2tensor/bin/t2t_distill.py index 75c14ca55..2ed3d0cb6 100644 --- a/tensor2tensor/bin/t2t_distill.py +++ b/tensor2tensor/bin/t2t_distill.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - r"""Perform distillation for a teacher to student. This script is intended to be used with --model=distillation. See the model for diff --git a/tensor2tensor/bin/t2t_trainer.py b/tensor2tensor/bin/t2t_trainer.py index 87443ad47..6376aa2ad 100644 --- a/tensor2tensor/bin/t2t_trainer.py +++ b/tensor2tensor/bin/t2t_trainer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Train and evaluate.""" from __future__ import absolute_import from __future__ import division @@ -87,6 +86,8 @@ "Name of Cloud TPU instance to use or create.") flags.DEFINE_bool("cloud_delete_on_done", False, "Whether to delete the VM and TPU instance when done.") +flags.DEFINE_bool("cloud_skip_confirmation", False, + "Whether to skip launch confirmations.") # Google Cloud ML Engine flags.DEFINE_bool("cloud_mlengine", False, @@ -319,7 +320,8 @@ def maybe_cloud_tpu(): with cloud_tpu.cloud_tpu( FLAGS.cloud_vm_name, FLAGS.cloud_tpu_name, - delete_on_done=FLAGS.cloud_delete_on_done) as tpu_master: + delete_on_done=FLAGS.cloud_delete_on_done, + skip_confirmation=FLAGS.cloud_skip_confirmation) as tpu_master: FLAGS.master = tpu_master yield diff --git a/tensor2tensor/bin/t2t_trainer_test.py b/tensor2tensor/bin/t2t_trainer_test.py index 52d58111a..31b1e884b 100644 --- a/tensor2tensor/bin/t2t_trainer_test.py +++ b/tensor2tensor/bin/t2t_trainer_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for t2t_trainer.""" from __future__ import absolute_import diff --git a/tensor2tensor/bin/t2t_translate_all.py b/tensor2tensor/bin/t2t_translate_all.py index 7041fb8c1..c6f8354e5 100644 --- a/tensor2tensor/bin/t2t_translate_all.py +++ b/tensor2tensor/bin/t2t_translate_all.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Translate a file with all checkpoints in a given directory. t2t-decoder will be executed with these parameters: diff --git a/tensor2tensor/data_generators/algorithmic.py b/tensor2tensor/data_generators/algorithmic.py index 25dbb8add..d98a133c5 100644 --- a/tensor2tensor/data_generators/algorithmic.py +++ b/tensor2tensor/data_generators/algorithmic.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Algorithmic data generators.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/data_generators/algorithmic_math.py b/tensor2tensor/data_generators/algorithmic_math.py index ed96bbfad..3edc0db19 100644 --- a/tensor2tensor/data_generators/algorithmic_math.py +++ b/tensor2tensor/data_generators/algorithmic_math.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Algorithmic data generators for symbolic math tasks. See go/symbolic-math-dataset diff --git a/tensor2tensor/data_generators/algorithmic_math_test.py b/tensor2tensor/data_generators/algorithmic_math_test.py index ce5310b8e..953415947 100644 --- a/tensor2tensor/data_generators/algorithmic_math_test.py +++ b/tensor2tensor/data_generators/algorithmic_math_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.data_generators.algorithmic_math.""" # TODO(rsepassi): This test is flaky. Disable, remove, or update. diff --git a/tensor2tensor/data_generators/algorithmic_test.py b/tensor2tensor/data_generators/algorithmic_test.py index 2644a3b33..ffa2f4b38 100644 --- a/tensor2tensor/data_generators/algorithmic_test.py +++ b/tensor2tensor/data_generators/algorithmic_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Algorithmic generators test.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index ce4407c28..7e3b3e008 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Imports for problem modules.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import importlib +import six modules = [ @@ -55,12 +55,39 @@ "tensor2tensor.data_generators.translate_enzh", "tensor2tensor.data_generators.twentybn", "tensor2tensor.data_generators.wiki", + "tensor2tensor.data_generators.wikisum.wikisum", "tensor2tensor.data_generators.wsj_parsing", ] -for module in modules: + +def _py_err_msg(module): + if six.PY2: + msg = "No module named %s" % module.split(".")[-1] + else: + msg = "No module named '%s'" % module + return msg + + +def _handle_errors(errors): + """Log out and possibly reraise errors during import.""" + if not errors: + return + log_all = True # pylint: disable=unused-variable + err_msg = "Skipped importing {num_missing} data_generators modules." + print(err_msg.format(num_missing=len(errors))) + for module, err in errors: + err_str = str(err) + if err_str != _py_err_msg(module): + raise err + if log_all: + print("Did not import module: %s; Cause: %s" % (module, err_str)) + + +_errors = [] +for _module in modules: try: - importlib.import_module(module) + importlib.import_module(_module) except ImportError as error: - print("Did not import module: %s; Cause: %s" % (module, str(error))) + _errors.append((_module, error)) +_handle_errors(_errors) diff --git a/tensor2tensor/data_generators/all_problems_test.py b/tensor2tensor/data_generators/all_problems_test.py index 2cad7ee44..cc899fa6f 100644 --- a/tensor2tensor/data_generators/all_problems_test.py +++ b/tensor2tensor/data_generators/all_problems_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Tensor2Tensor's all_problems.py.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/audio.py b/tensor2tensor/data_generators/audio.py index 1f92ee939..3335cd70f 100644 --- a/tensor2tensor/data_generators/audio.py +++ b/tensor2tensor/data_generators/audio.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """TIMIT data generator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/data_generators/audio_test.py b/tensor2tensor/data_generators/audio_test.py index 81a55686d..37e188e3f 100644 --- a/tensor2tensor/data_generators/audio_test.py +++ b/tensor2tensor/data_generators/audio_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.data_generators.audio.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/celeba.py b/tensor2tensor/data_generators/celeba.py index a4e76fbb4..1ca5fc04c 100644 --- a/tensor2tensor/data_generators/celeba.py +++ b/tensor2tensor/data_generators/celeba.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """CelebA.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/cifar.py b/tensor2tensor/data_generators/cifar.py index 2c332e6e7..bdde9c89f 100644 --- a/tensor2tensor/data_generators/cifar.py +++ b/tensor2tensor/data_generators/cifar.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """CIFAR.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/cipher.py b/tensor2tensor/data_generators/cipher.py index d7147547b..d6a244d59 100644 --- a/tensor2tensor/data_generators/cipher.py +++ b/tensor2tensor/data_generators/cipher.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Cipher data generators.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index a09d22c3c..6cc0a48ec 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for the CNN and Daily Mail datasets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/desc2code.py b/tensor2tensor/data_generators/desc2code.py index 145279a84..99c15882e 100644 --- a/tensor2tensor/data_generators/desc2code.py +++ b/tensor2tensor/data_generators/desc2code.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for the Description2Code OpenAI data-set.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/desc2code_test.py b/tensor2tensor/data_generators/desc2code_test.py index 61f059789..cccfef801 100644 --- a/tensor2tensor/data_generators/desc2code_test.py +++ b/tensor2tensor/data_generators/desc2code_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for desc2code.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/dna_encoder.py b/tensor2tensor/data_generators/dna_encoder.py index a4b2c244b..ce1d09955 100644 --- a/tensor2tensor/data_generators/dna_encoder.py +++ b/tensor2tensor/data_generators/dna_encoder.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Encoders for DNA data. * DNAEncoder: ACTG strings to ints and back diff --git a/tensor2tensor/data_generators/dna_encoder_test.py b/tensor2tensor/data_generators/dna_encoder_test.py index a32e3f2bf..453faf6a6 100644 --- a/tensor2tensor/data_generators/dna_encoder_test.py +++ b/tensor2tensor/data_generators/dna_encoder_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.data_generators.dna_encoder.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/fsns.py b/tensor2tensor/data_generators/fsns.py index ed7baaec6..46fb8c021 100644 --- a/tensor2tensor/data_generators/fsns.py +++ b/tensor2tensor/data_generators/fsns.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """FSNS.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/gene_expression.py b/tensor2tensor/data_generators/gene_expression.py index cdd62491f..2b640dc11 100644 --- a/tensor2tensor/data_generators/gene_expression.py +++ b/tensor2tensor/data_generators/gene_expression.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Gene expression problems. Inputs are bases ACTG (with indices assigned in that order). diff --git a/tensor2tensor/data_generators/gene_expression_test.py b/tensor2tensor/data_generators/gene_expression_test.py index 8a7ccd55a..b70b0885d 100644 --- a/tensor2tensor/data_generators/gene_expression_test.py +++ b/tensor2tensor/data_generators/gene_expression_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Genetics problems.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 6bd069388..e6ab96c04 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for data generators.""" from __future__ import absolute_import @@ -20,6 +19,7 @@ from __future__ import print_function import gzip +import multiprocessing as mp import os import random import stat @@ -330,6 +330,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, reserved_tokens=reserved_tokens) if vocab_filepath: + tf.gfile.MakeDirs(data_dir) vocab.store_to_file(vocab_filepath) return vocab @@ -468,17 +469,24 @@ def generate_dataset_and_shuffle(train_gen, shuffle_dataset(train_paths + dev_paths) +def _shuffle_single(fname): + records = read_records(fname) + random.shuffle(records) + out_fname = fname.replace(UNSHUFFLED_SUFFIX, "") + write_records(records, out_fname) + tf.gfile.Remove(fname) + + def shuffle_dataset(filenames): if outputs_exist(filenames): tf.logging.info("Skipping shuffle because output files exist") return tf.logging.info("Shuffling data...") - for fname in filenames: - records = read_records(fname) - random.shuffle(records) - out_fname = fname.replace(UNSHUFFLED_SUFFIX, "") - write_records(records, out_fname) - tf.gfile.Remove(fname) + if len(filenames) > 1: + pool = mp.Pool(min(len(filenames), 20)) + pool.map(_shuffle_single, filenames) + else: + _shuffle_single(filenames[0]) class SequencePacker(object): diff --git a/tensor2tensor/data_generators/generator_utils_test.py b/tensor2tensor/data_generators/generator_utils_test.py index 0460ef2e7..6276e0d3a 100644 --- a/tensor2tensor/data_generators/generator_utils_test.py +++ b/tensor2tensor/data_generators/generator_utils_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Generator utilities test.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py index 6a82f1d4c..df2b8dd1c 100644 --- a/tensor2tensor/data_generators/gym.py +++ b/tensor2tensor/data_generators/gym.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for Gym environments.""" from __future__ import absolute_import @@ -99,7 +98,7 @@ def env(self): @property def num_actions(self): - raise NotImplementedError() + return self.env.action_space.n @property def num_rewards(self): @@ -154,10 +153,6 @@ def frame_height(self): def frame_width(self): return 160 - @property - def num_actions(self): - return 4 - @property def min_reward(self): return -1 diff --git a/tensor2tensor/data_generators/ice_parsing.py b/tensor2tensor/data_generators/ice_parsing.py index 22f5d1282..e5c39934f 100644 --- a/tensor2tensor/data_generators/ice_parsing.py +++ b/tensor2tensor/data_generators/ice_parsing.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """This module implements the ice_parsing_* problems.""" # These parse plain text into flattened parse trees and POS tags. diff --git a/tensor2tensor/data_generators/image_utils.py b/tensor2tensor/data_generators/image_utils.py index 06061a5ff..59e7a1277 100644 --- a/tensor2tensor/data_generators/image_utils.py +++ b/tensor2tensor/data_generators/image_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Base classes and utilities for image datasets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/image_utils_test.py b/tensor2tensor/data_generators/image_utils_test.py index 875863338..6c0ce9367 100644 --- a/tensor2tensor/data_generators/image_utils_test.py +++ b/tensor2tensor/data_generators/image_utils_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """image_utils test.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/imagenet.py b/tensor2tensor/data_generators/imagenet.py index e20a18fed..109d37c5d 100644 --- a/tensor2tensor/data_generators/imagenet.py +++ b/tensor2tensor/data_generators/imagenet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ImageNet.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/imdb.py b/tensor2tensor/data_generators/imdb.py index d0f1e5cac..865def20c 100644 --- a/tensor2tensor/data_generators/imdb.py +++ b/tensor2tensor/data_generators/imdb.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """IMDB Sentiment Classification Problem.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/inspect_tfrecord.py b/tensor2tensor/data_generators/inspect_tfrecord.py index dc6aae26a..0113757e6 100644 --- a/tensor2tensor/data_generators/inspect_tfrecord.py +++ b/tensor2tensor/data_generators/inspect_tfrecord.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - r"""Inspect a TFRecord file of tensorflow.Example and show tokenizations. python data_generators/inspect_tfrecord.py \ diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py index 978672ad2..28cb7b756 100644 --- a/tensor2tensor/data_generators/librispeech.py +++ b/tensor2tensor/data_generators/librispeech.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Librispeech dataset.""" import os diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py index b0b2e719a..a81ff02bd 100644 --- a/tensor2tensor/data_generators/lm1b.py +++ b/tensor2tensor/data_generators/lm1b.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for LM1B data-set.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/mnist.py b/tensor2tensor/data_generators/mnist.py index ef40f62e6..c3d122499 100644 --- a/tensor2tensor/data_generators/mnist.py +++ b/tensor2tensor/data_generators/mnist.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MNIST.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/mscoco.py b/tensor2tensor/data_generators/mscoco.py index c5472bd87..76b745ade 100644 --- a/tensor2tensor/data_generators/mscoco.py +++ b/tensor2tensor/data_generators/mscoco.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MS COCO.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/multinli.py b/tensor2tensor/data_generators/multinli.py index e70252005..4a6649a4f 100644 --- a/tensor2tensor/data_generators/multinli.py +++ b/tensor2tensor/data_generators/multinli.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for MultiNLI (https://www.nyu.edu/projects/bowman/multinli/). """ diff --git a/tensor2tensor/data_generators/ocr.py b/tensor2tensor/data_generators/ocr.py index 074686459..fcc5e07e5 100644 --- a/tensor2tensor/data_generators/ocr.py +++ b/tensor2tensor/data_generators/ocr.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """OCR.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 80d44ee61..75448bf70 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Base class for problem/dataset definitions.""" from __future__ import absolute_import from __future__ import division @@ -491,7 +490,8 @@ def dataset(self, dataset_split=None, shard=None, partition_id=0, - num_partitions=1): + num_partitions=1, + max_records=-1): """Build a Dataset for this problem. Args: @@ -512,6 +512,7 @@ def dataset(self, shard: int, if provided, will only read data from the specified shard. partition_id: integer - which partition of the dataset to read from num_partitions: how many partitions in the dataset + max_records: int, number of records to truncate to. Returns: Dataset containing dict. @@ -569,7 +570,7 @@ def _load_records_and_preprocess(filename): _load_records_and_preprocess, sloppy=is_training, cycle_length=8)) dataset = dataset.map( self.maybe_reverse_and_copy, num_parallel_calls=num_threads) - + dataset = dataset.take(max_records) if output_buffer_size: dataset = dataset.prefetch(output_buffer_size) @@ -829,21 +830,28 @@ def _pad_batch(features): dataset = dataset.map(_pad_batch, num_parallel_calls=num_threads) dataset = dataset.map(define_shapes, num_parallel_calls=num_threads) + + def prepare_for_output(example): + if not config or not config.use_tpu: + _summarize_features(example, + (config and config.data_parallelism.n) or 1) + if mode == tf.estimator.ModeKeys.PREDICT: + example["infer_targets"] = example.pop("targets") + return example + else: + return example, example["targets"] + + dataset = dataset.map(prepare_for_output, num_parallel_calls=num_threads) dataset = dataset.prefetch(2) - features = dataset.make_one_shot_iterator().get_next() - if not config or not config.use_tpu: - _summarize_features(features, (config and config.data_parallelism.n) or 1) if mode == tf.estimator.ModeKeys.PREDICT: - features["infer_targets"] = features["targets"] - features["targets"] = None # This is because of a bug in the Estimator that short-circuits prediction # if it doesn't see a QueueRunner. DummyQueueRunner implements the # minimal expected interface but does nothing. tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_reader.DummyQueueRunner()) - return features, features["targets"] + return dataset def serving_input_fn(self, hparams): """Input fn for serving export, starting from serialized example.""" diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 262a0dc51..7e84257c7 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Hyperparameters defining different problems. """ diff --git a/tensor2tensor/data_generators/ptb.py b/tensor2tensor/data_generators/ptb.py index 4ac3911b9..de3db0e28 100644 --- a/tensor2tensor/data_generators/ptb.py +++ b/tensor2tensor/data_generators/ptb.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for PTB data-sets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/snli.py b/tensor2tensor/data_generators/snli.py index 8533c72c2..7e2dd067c 100644 --- a/tensor2tensor/data_generators/snli.py +++ b/tensor2tensor/data_generators/snli.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for the SNLI data-set.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/speech_recognition.py b/tensor2tensor/data_generators/speech_recognition.py index 2777cd9cf..7f72d5afd 100644 --- a/tensor2tensor/data_generators/speech_recognition.py +++ b/tensor2tensor/data_generators/speech_recognition.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Common classes for automatic speech recognition (ASR) datasets. The audio import uses sox to generate normalized waveforms, please install diff --git a/tensor2tensor/data_generators/squad.py b/tensor2tensor/data_generators/squad.py index 7de1e4efc..178a3a4f4 100644 --- a/tensor2tensor/data_generators/squad.py +++ b/tensor2tensor/data_generators/squad.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for SquaAD (https://rajpurkar.github.io/SQuAD-explorer/). """ @@ -115,11 +114,34 @@ def preprocess_example(self, example, unused_mode, unused_model_hparams): [example['inputs'], sep, example['context']], 0) return example - def generate_data(self, data_dir, tmp_dir, task_id=-1): - tf.logging.warn('Use Squad to generate data for SquadConcat.') - def hparams(self, defaults, unused_model_hparams): (super(SquadConcat, self) .hparams(defaults, unused_model_hparams)) p = defaults del p.input_modality['context'] + + +@registry.register_problem +class SquadConcatPositioned(SquadConcat): + """SquadConcat with targets in format of answer position + answer length.""" + + def generate_targets(self, targets, context): + targets = targets[:-1] # skip last terminal symbol. + targets_new = [] + i = 0 + while i < len(context) - len(targets): + if context[i: i + len(targets)] == targets: + # emit answer's position and length. + targets_new.append(i) + targets_new.append(len(targets)) + i += 1 + return targets_new + + def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split): + samples = (super(SquadConcatPositioned, self) + .generate_encoded_samples(data_dir, tmp_dir, dataset_split)) + for sample in samples: + sample['targets'] = self.generate_targets(sample['targets'], + sample['context']) + if not sample['targets']: + yield sample diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index cb2a43978..699ad001a 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Encoders for text data. * TextEncoder: base class diff --git a/tensor2tensor/data_generators/text_encoder_build_subword.py b/tensor2tensor/data_generators/text_encoder_build_subword.py index 1bd904f62..4c5f01f6c 100644 --- a/tensor2tensor/data_generators/text_encoder_build_subword.py +++ b/tensor2tensor/data_generators/text_encoder_build_subword.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - r"""Program to build a SubwordTextEncoder. The flags --min_count and --corpus_max_lines will affect the size of the diff --git a/tensor2tensor/data_generators/text_encoder_test.py b/tensor2tensor/data_generators/text_encoder_test.py index b3248a7c4..1606e790c 100644 --- a/tensor2tensor/data_generators/text_encoder_test.py +++ b/tensor2tensor/data_generators/text_encoder_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.data_generators.text_encoder.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/text_problems.py b/tensor2tensor/data_generators/text_problems.py index de7fbb4e6..79e185190 100644 --- a/tensor2tensor/data_generators/text_problems.py +++ b/tensor2tensor/data_generators/text_problems.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Base classes for text-based Problems. * Text2TextProblem: input=text, target=text. diff --git a/tensor2tensor/data_generators/text_problems_test.py b/tensor2tensor/data_generators/text_problems_test.py index af39e35ef..bddc58a6e 100644 --- a/tensor2tensor/data_generators/text_problems_test.py +++ b/tensor2tensor/data_generators/text_problems_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Text problems test.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/tokenizer.py b/tensor2tensor/data_generators/tokenizer.py index b6c0e3236..92a42382c 100644 --- a/tensor2tensor/data_generators/tokenizer.py +++ b/tensor2tensor/data_generators/tokenizer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """A simple invertible tokenizer. Converts from a unicode string to a list of tokens diff --git a/tensor2tensor/data_generators/tokenizer_test.py b/tensor2tensor/data_generators/tokenizer_test.py index e977d1126..7c5ababd1 100644 --- a/tensor2tensor/data_generators/tokenizer_test.py +++ b/tensor2tensor/data_generators/tokenizer_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # coding=utf-8 """Tests for tensor2tensor.data_generators.tokenizer.""" diff --git a/tensor2tensor/data_generators/translate.py b/tensor2tensor/data_generators/translate.py index 8d5cf808f..200cc71c0 100644 --- a/tensor2tensor/data_generators/translate.py +++ b/tensor2tensor/data_generators/translate.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for translation data-sets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/translate_encs.py b/tensor2tensor/data_generators/translate_encs.py index 47f2b9adc..fed471ea9 100644 --- a/tensor2tensor/data_generators/translate_encs.py +++ b/tensor2tensor/data_generators/translate_encs.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for translation data-sets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/translate_ende.py b/tensor2tensor/data_generators/translate_ende.py index 2a1e52c2f..9991ad3f5 100644 --- a/tensor2tensor/data_generators/translate_ende.py +++ b/tensor2tensor/data_generators/translate_ende.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for translation data-sets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/translate_enet.py b/tensor2tensor/data_generators/translate_enet.py index 66056cc3a..a81b117ca 100644 --- a/tensor2tensor/data_generators/translate_enet.py +++ b/tensor2tensor/data_generators/translate_enet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for En-Et translation.""" from __future__ import absolute_import @@ -38,7 +37,7 @@ ("training/europarl-v8.et-en.en", "training/europarl-v8.et-en.et") ], [ - "https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-et.zipporah0-dedup-clean.tgz", # pylint: disable=line-too-long + "https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-et.zipporah0-dedup-clean.tgz", # pylint: disable=line-too-long ("paracrawl-release1.en-et.zipporah0-dedup-clean.en", "paracrawl-release1.en-et.zipporah0-dedup-clean.et") ], diff --git a/tensor2tensor/data_generators/translate_enfr.py b/tensor2tensor/data_generators/translate_enfr.py index 53b46b78a..56f1f23fc 100644 --- a/tensor2tensor/data_generators/translate_enfr.py +++ b/tensor2tensor/data_generators/translate_enfr.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for translation data-sets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/translate_enmk.py b/tensor2tensor/data_generators/translate_enmk.py index d7f3e08a3..8a0568f05 100644 --- a/tensor2tensor/data_generators/translate_enmk.py +++ b/tensor2tensor/data_generators/translate_enmk.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for translation data-sets.""" from __future__ import absolute_import @@ -33,7 +32,7 @@ # For English-Macedonian the SETimes corpus # from http://nlp.ffzg.hr/resources/corpora/setimes/ is used. _ENMK_TRAIN_DATASETS = [[ - "http://nlp.ffzg.hr/data/corpora/setimes/setimes.en-mk.txt.tgz", # pylint: disable=line-too-long + "http://nlp.ffzg.hr/data/corpora/setimes/setimes.en-mk.txt.tgz", ("setimes.en-mk.en.txt", "setimes.en-mk.mk.txt") ]] diff --git a/tensor2tensor/data_generators/translate_envi.py b/tensor2tensor/data_generators/translate_envi.py index 2003316eb..f102cdff8 100644 --- a/tensor2tensor/data_generators/translate_envi.py +++ b/tensor2tensor/data_generators/translate_envi.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for En-Vi translation.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/translate_enzh.py b/tensor2tensor/data_generators/translate_enzh.py index 9e2f56e04..a1d9ecd53 100644 --- a/tensor2tensor/data_generators/translate_enzh.py +++ b/tensor2tensor/data_generators/translate_enzh.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for translation data-sets.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/translate_test.py b/tensor2tensor/data_generators/translate_test.py index 201898352..1cb3c9f36 100644 --- a/tensor2tensor/data_generators/translate_test.py +++ b/tensor2tensor/data_generators/translate_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Translate generators test.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/twentybn.py b/tensor2tensor/data_generators/twentybn.py index 279f159d9..3505b17a2 100644 --- a/tensor2tensor/data_generators/twentybn.py +++ b/tensor2tensor/data_generators/twentybn.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generator for twenty bn video data-set.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/video_utils.py b/tensor2tensor/data_generators/video_utils.py index 869fad721..db965795d 100644 --- a/tensor2tensor/data_generators/video_utils.py +++ b/tensor2tensor/data_generators/video_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Base classes and utilities for video datasets.""" from __future__ import absolute_import @@ -267,7 +266,7 @@ def generate_encoded_samples_debug(self, data_dir, tmp_dir, dataset_split): data_dir, tmp_dir, dataset_split): if self.debug_dump_frames_path: path = os.path.join(self.debug_dump_frames_path, - "frame_%d.png" % counter) + "frame_%05d.png" % counter) with tf.gfile.Open(path, "wb") as f: f.write(sample["image/encoded"][0]) counter += 1 diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 5222b5a62..772898745 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generator for Wikipedia title to article dataset.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/wikisum/README.md b/tensor2tensor/data_generators/wikisum/README.md new file mode 100644 index 000000000..4713a8d37 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/README.md @@ -0,0 +1,204 @@ +# Generating Wikipedia by Summarizing Long Sequences + +This directory contains the code and scripts to generate the dataset from the +paper [Generating Wikipedia by Summarizing Long +Sequences](https://arxiv.org/abs/1801.10198). The task is to generate a +Wikipedia article based on the contents of the cited references in that article +and the top 10 Google search results for the article's title. + +There are 2 sources for the reference URLs used: + +1. [CommonCrawl](http://commoncrawl.org/), an open-source crawl of the web. The + advantage of using CommonCrawl is that the dataset is perfectly reproducible. + However, there is limited coverage of the reference URLs. +1. Live web fetches. Coverage is considerably increased, but the content is + subject to change. + +This document provides instructions for producing both datasets. + +## Support files + +Some files that are used in dataset generation have already been generated and +uploaded to Google Cloud Storage as `gs://tensor2tensor-data/wikisum`. + +**URLs:** The dataset contains ~90M URLs total (~2.3M Wikipedia articles, each +with ~40 reference URLs). The URLs in the dataset are available in sharded JSON +files here: `gs://tensor2tensor-data/wikisum/wiki_urls/`. + +**Wikipedia Articles:** We have processed the Wikipedia articles slightly to +extract the title, section breaks, and section headings. The processed Wikipedia +content is available in sharded `TFRecord` files containing serialized +`tensorflow.Example` protocol buffers here: +`gs://tensor2tensor-data/wikisum/wiki_content/`. The sharding is determined by a +hash of the Wikpedia article's title. The `Example`s contain features `[url, +title, section_titles, section_texts]`. + +**CommonCrawl References Index:** To enable efficiently extracting the reference +URLs from CommonCrawl, we provide a JSON file per CommonCrawl file which maps a +reference URL contained in that CommonCrawl file to a list of shard ids: +`gs://tensor2tensor-data/wikisum/commoncrawl_metadata/`. These shards are the +ones that contain one or more Wikipedia articles that cite this reference. The +scripts in this directory will use this information to efficiently join the +reference with their Wikipedia articles. + +*Note*: You can use [`gsutil`](https://cloud.google.com/storage/docs/gsutil) to +view the support files. + +## Data generation + +Data generation will first extract reference content (from either CommonCrawl or +the web), then generate a vocabulary, join the references with their Wikipedia +articles, run TF-IDF to rank reference paragraphs for a given article, and then +encode the references and the Wikipedia article with the vocabulary and write +the encoded training or evaluation example out to disk. + +The output of data generation is a set of `TFRecord` files containing serialized +`tensorflow.Example` protocol buffers, with feature keys `"inputs"` and +`"targets"`. The inputs are the reference tokens, and the targets are the +Wikipedia article tokens. + +In both cases, you must use multiple machines to extract references and produce +the final data to disk because of the size of the data. See `parallel_launch.py` +which is a script that will launch N machines in parallel on GCP. You can use it +as a guide if you'd like to launch on other infrastructure. + +There are 3 jobs to run: + +1. Extract references: `get_references_commoncrawl.py` for `WikisumCommoncrawl` + and `get_references_web.py` for `WikisumWeb`. +1. Build vocabulary (single-machine): `generate_vocab.py` +1. Produce Examples: `produce_examples.py` + +With 1,000 machines with a good internet connection, data generation takes well +under 24 hours. + +## Setup if using `parallel_launch.py` to launch on Google Cloud Platform + +First, [install the `gcloud` CLI](https://cloud.google.com/sdk/downloads). + +``` +# Initialize the CLI +gcloud init + +# Login +gcloud auth login + +# Update the CLI +gcloud components update + +# Set the default project and zone +gcloud config set core/project myproject +gcloud config set compute/zone us-central1-c +``` + +You'll also need to request the requisite +[quotas](https://console.cloud.google.com/iam-admin/quotas) in the zone you'll +be launching the machines in (whatever default zone you set above): + +* In-use IP addresses: 1,000 +* Internal IP addresses: 1,000 +* Persistent Disk Standard (GB): 10,000 +* CPUs: 4,000 + +**Running the commands below will launch instances on Google Cloud Platform and +you will incur charges.** If any of the commands go bad, immediately delete any +stranded instances. `delete_instances.sh` helps you delete instances in bulk +from the command-line, or you can delete many instances at once from the +[GCP Console](https://console.cloud.google.com/). + +### Cost estimates + +These are rough (and **not** guaranteed) estimates of cost if you were to launch +on GCP. + +Pricing is taken from +[here](https://cloud.google.com/compute/pricing#custommachinetypepricing). + +* `WikisumCommoncrawl` + * `get_references_commoncrawl`: $50 (1k machines, 1 CPU, 2G memory, 1 hour) + * `produce_examples`: $350 (1k machines, 1 CPU, 2G memory, 8 hours) +* `WikisumWeb` + * `get_references_web`: $750 (1k machines, 4 CPU, 4G memory, 5 hours) + * `produce_examples`: $350 (1k machines, 1 CPU, 2G memory, 8 hours) + +## Commands to generate `WikisumCommoncrawl` + +``` +pip install tensor2tensor -U --user + +# Set to your own GCS bucket +BUCKET=gs://my-gcs-bucket/wikisum_commoncrawl + +# Extract references from CommonCrawl +python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ + --num_instances=1000 \ + --cpu=1 --mem=2 \ + --name=wikisum-refs-cc \ + --log_dir=$BUCKET/refs_logs \ + --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ + --command_prefix="python -m tensor2tensor.data_generators.wikisum.get_references_commoncrawl --num_tasks=1000 --out_dir=$BUCKET/wiki_references --task_id" + +# Generate vocabulary file +python -m tensor2tensor.data_generators.wikisum.generate_vocab \ + --out_dir=$BUCKET/data \ + --refs_dir=$BUCKET/wiki_references \ + --for_commoncrawl + +# Produce examples +python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ + --num_instances=1000 \ + --cpu=1 --mem=2 \ + --name=wikisum-cc-produce \ + --log_dir=$BUCKET/produce_logs \ + --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ + --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples.py --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --for_commoncrawl --task_id" +``` + +## Commands to generate `WikisumWeb` + +``` +pip install tensor2tensor -U --user + +# Set to your own GCS bucket +BUCKET=gs://my-gcs-bucket/wikisum_web + +# Fetch references from web +python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ + --num_instances=1000 \ + --cpu=4 --mem=4 \ + --name=wikisum-refs-web \ + --log_dir=$BUCKET/refs_logs \ + --setup_command="pip3 install tensorflow tensor2tensor aiohttp cchardet aiodns bs4 -U -q --user" \ + --command_prefix="python3 wikisum/get_references_web.py --out_dir=$BUCKET/wiki_references --shard_id" + +# Generate vocabulary file +python -m tensor2tensor.data_generators.wikisum.generate_vocab \ + --out_dir=$BUCKET/data \ + --refs_dir=$BUCKET/wiki_references + +# Produce examples +python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ + --num_instances=1000 \ + --cpu=1 --mem=2 \ + --name=wikisum-web-produce \ + --log_dir=$BUCKET/produce_logs \ + --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ + --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples.py --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --task_id" +``` + +## Training + +**TODO(rsepassi)**: Put actual results achieved on `wikisum_web` and/or +`wikisum_commoncrawl` and with what `hparams_set`. + +``` +PROBLEM=wikisum_web # or wikisum_commoncrawl +t2t-trainer \ + --problem=$PROBLEM \ + --model=transformer \ + --hparams_set=transformer_base \ + --train_steps=250000 \ + --eval_steps=100 \ + --data_dir=$DATA_DIR \ + --output_dir=$TRAIN_DIR +``` diff --git a/tensor2tensor/data_generators/wikisum/__init__.py b/tensor2tensor/data_generators/wikisum/__init__.py new file mode 100644 index 000000000..dba7ece95 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensor2tensor/data_generators/wikisum/delete_instances.sh b/tensor2tensor/data_generators/wikisum/delete_instances.sh new file mode 100755 index 000000000..c35e48d8d --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/delete_instances.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Delete Google Compute Engine instances with naming structure $NAME-$INDEX +# (e.g. machines created with parallel_launch.py). +# Example usage: +# delete_instances.sh fetch-ref-urls 1000 + +NAME=$1 +MAX=$2 +MIN=${3:-0} + +LOG_F=/tmp/delete-$NAME-logs.txt + +echo "Deleting $MAX instances starting with $NAME-$MIN" + +for i in $(seq $MIN $MAX) +do + gcloud compute instances delete --quiet $NAME-$i > $LOG_F 2>&1 & + if [[ $(( i % 100 )) == 0 ]] + then + # Give it some room to breathe every 100 + sleep 30 + fi +done + +echo "Delete commands launched. Logs redirected to $LOG_F" diff --git a/tensor2tensor/data_generators/wikisum/generate_vocab.py b/tensor2tensor/data_generators/wikisum/generate_vocab.py new file mode 100644 index 000000000..b8e64702f --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/generate_vocab.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generate vocab from references and wikis.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensor2tensor.data_generators.wikisum import wikisum + +import tensorflow as tf + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("out_dir", None, "Directory to write vocab to.") +flags.DEFINE_string("wikis_dir", + "gs://tensor2tensor-data/wikisum/wiki_content/", + "Directory with wiki_content.tfrecords shards.") +flags.DEFINE_string("refs_dir", None, + "Directory with process_X folders with reference shards.") +flags.DEFINE_bool("for_commoncrawl", False, + "Whether to use WikisumCommoncrawl or WikisumWeb.") + + +def main(_): + if FLAGS.for_commoncrawl: + problem = wikisum.WikisumCommoncrawl() + else: + problem = wikisum.WikisumWeb() + problem.generate_vocab(FLAGS.out_dir, FLAGS.wikis_dir, FLAGS.refs_dir) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/get_references_commoncrawl.py b/tensor2tensor/data_generators/wikisum/get_references_commoncrawl.py new file mode 100644 index 000000000..5532f9a00 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/get_references_commoncrawl.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Extract references from CommonCrawl files.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensor2tensor.data_generators.wikisum import utils +from tensor2tensor.data_generators.wikisum import wikisum + +import tensorflow as tf + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_integer("num_tasks", 1000, "Number of parallel tasks.") +flags.DEFINE_integer("task_id", 0, "Task id in a parallel run.") +flags.DEFINE_string("metadata_dir", + "gs://tensor2tensor-data/wikisum/commoncrawl_metadata/", + "Path to metadata files specifying what references are in " + "which CommonCrawl files.") +flags.DEFINE_string("out_dir", None, "Directory to write references to.") +flags.DEFINE_string("commoncrawl_wet_dir", None, + "Path to CommonCrawl wet.gz files locally. If not " + "provided, will download.") + + +def main(_): + assert FLAGS.out_dir + assert FLAGS.metadata_dir + out_dir = os.path.join(FLAGS.out_dir, "process_%d" % FLAGS.task_id) + tf.gfile.MakeDirs(out_dir) + + with utils.timing("get_refs_commoncrawl"): + # Get all WET files + if FLAGS.commoncrawl_wet_dir: + wet_files = tf.gfile.Glob( + os.path.join(FLAGS.commoncrawl_wet_dir, "*.wet.gz")) + else: + tmp_dir = tempfile.gettempdir() + wet_files = list( + utils.wet_download_urls(utils.WET_PATHS_BY_DATE["0917"], tmp_dir)) + + # Shard and select this task's work + wet_files.sort() + wet_files = utils.shard(wet_files, FLAGS.num_tasks)[FLAGS.task_id] + tf.logging.info("Sharded out WET files. Processing %d files", + len(wet_files)) + + wikisum.extract_references_from_wets(wet_files, FLAGS.metadata_dir, out_dir) + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/get_references_web.py b/tensor2tensor/data_generators/wikisum/get_references_web.py new file mode 100644 index 000000000..05ddda100 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/get_references_web.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=line-too-long +r"""Fetch reference URLs from all groups for a single shard id. + +Because of an SSL memory leak in Python 3.5, fetching too many URLs in the same +Python process will OOM. This script wraps get_references_web_single_group.py +and calls it through subprocess for each group in the shard, where each group is +~5k URLs. + +Launch with parallel_launch.py + +Each job should finish in ~5 hours with the settings below. + +GCS_BUCKET=gs://my-bucket +python parallel_launch.py \ + --num_instances=1000 \ + --cpu=4 \ + --mem=4 \ + --name=get-refs-web \ + --code_dir=./ \ + --log_dir=$GCS_BUCKET/logs \ + --setup_command="pip3 install aiohttp cchardet aiodns bs4 -q --user" \ + --command_prefix="python3 wikisum/get_references_web.py --out_dir=$GCS_BUCKET/wiki_references --shard_id" +""" +# pylint: enable=line-too-long +import math +import os +import subprocess as sp + +from tensor2tensor.data_generators.wikisum import get_references_web_single_group as fetch +from tensor2tensor.data_generators.wikisum import utils + +import tensorflow as tf + + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("command", + "python3 wikisum/get_references_web_single_group.py", + "Command to run get_references_web_single_group, without " + "flags.") + + +def main(_): + shard_urls = fetch.get_urls_for_shard(FLAGS.urls_dir, FLAGS.shard_id) + num_groups = int(math.ceil(len(shard_urls) / fetch.URLS_PER_CLIENT)) + tf.logging.info("Launching get_references_web_single_group sequentially for " + "%d groups in shard %d. Total URLs: %d", + num_groups, FLAGS.shard_id, len(shard_urls)) + command_prefix = FLAGS.command.split() + [ + "--urls_dir=%s" % FLAGS.urls_dir, + "--shard_id=%d" % FLAGS.shard_id, + "--debug_num_urls=%d" % FLAGS.debug_num_urls, + ] + with utils.timing("all_groups_fetch"): + for i in range(num_groups): + command = list(command_prefix) + out_dir = os.path.join(FLAGS.out_dir, "process_%d" % i) + command.append("--out_dir=%s" % out_dir) + command.append("--group_id=%d" % i) + try: + # Even on 1 CPU, each group should finish within an hour. + sp.check_call(command, timeout=60*60) + except sp.TimeoutExpired: + tf.logging.error("Group %d timed out", i) + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py b/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py new file mode 100644 index 000000000..ca1717c73 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py @@ -0,0 +1,361 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fetch reference URLs for a single group_id within a single shard_id. + +See get_references_web.py to fetch URLs for all groups in within a single +shard_id. + +Requires Python 3.5 +pip3 install aiohttp cchardet aiodns bs4 tensorflow +""" + +import datetime +import json +import math +import multiprocessing +import os +import random + +import asyncio +import aiohttp +import bs4 +import tensorflow as tf + +from tensor2tensor.data_generators.wikisum import utils + + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("urls_dir", "gs://tensor2tensor-data/wikisum/wiki_urls/", + "Directory with wiki_urls.json files.") +flags.DEFINE_string("out_dir", None, "Directory to write reference files.") +flags.DEFINE_integer("max_parallel_requests", 50, + "Number of web requests to make in parallel.") + +# Identify which URLs to fetch +flags.DEFINE_integer("shard_id", 0, "ID of URL shard to process.") +flags.DEFINE_integer("group_id", 0, "ID of group within the shard to process.") + +flags.DEFINE_bool("log_samples", False, + "Whether to write out samples of the text extraction.") +flags.DEFINE_integer("log_every", 1000, + "How often to log and write out samples.") +flags.DEFINE_integer("debug_num_urls", 0, + "If >0, limits number of URLs fetched per input shard. " + "For debugging purposes only.") + + +WIKI_URLS_FILE = "wiki_urls.json-%05d-of-01000" +REF_SHARD_FILE = "references.tfrecords.gz-%05d-of-01000" + +# Note that this program leaks memory, likely due to a bug in Python's SSL +# implementation that leaks sockets. This constant is used here and in +# get_references_web.py to limit the number of requests made by a single +# Python process. The more requests made, the more memory required due to the +# leak. +# TODO(rsepassi): Document memory impact of changing this. +URLS_PER_CLIENT = 5000 + + +def concat_tfrecord_files(fnames, out_fname, rm_after=True): + with tf.gfile.Open(out_fname, "wb") as out_f: + for fname in fnames: + with tf.gfile.Open(fname, "rb") as in_f: + while True: + read = in_f.read(1000) + if not read: + break + out_f.write(read) + if rm_after: + tf.gfile.Remove(fname) + + +def shard(items, num_shards): + """Split items into num_shards groups.""" + sharded = [] + num_per_shard = len(items) // num_shards + start = 0 + for _ in range(num_shards): + sharded.append(items[start:start + num_per_shard]) + start += num_per_shard + + remainder = len(items) % num_shards + start = len(items) - remainder + for i in range(remainder): + sharded[i].append(items[start + i]) + + assert sum([len(fs) for fs in sharded]) == len(items) + return sharded + + +def soup_strings(soup): + paragraph_tags = set(["caption", "details", "h1", "h2", "h3", "h4", "h5", + "h6", "li", "p", "td", "div", "span"]) + + skip_children = None + for descendant in soup.descendants: + # If we've treated a tag as a contiguous paragraph, don't re-emit the + # children (see below). + if skip_children is not None: + try: + in_skip = descendant in skip_children + except RecursionError: + # Possible for this check to hit a nasty infinite recursion because of + # BeautifulSoup __eq__ checks. + in_skip = True + if in_skip: + continue + else: + skip_children = None + + # Treat some tags as contigous paragraphs, regardless of other tags nested + # inside (like or ). + if isinstance(descendant, bs4.Tag): + if descendant.name in paragraph_tags: + if descendant.find_all(paragraph_tags): + # If there are nested paragraph tags, don't treat it as a single + # contiguous tag. + continue + skip_children = list(descendant.descendants) + text = " ".join(descendant.get_text(" ", strip=True).split()) + if text: + yield text + continue + + if (isinstance(descendant, bs4.Comment) or + not isinstance(descendant, bs4.NavigableString)): + continue + + text = " ".join(descendant.strip().split()) + if text: + yield text + + +def mp_get_text(url, html): + return url, get_text_from_html(html) + + +def get_text_from_html(html): + try: + soup = bs4.BeautifulSoup(html, 'html.parser') + except: + # Some docs don't parse + return "" + # Remove script and style tags + for s in soup(["script", "style"]): + s.decompose() + return "\n".join([s for s in soup_strings(soup)]) + + +def encode(s): + return bytes(s, "utf-8") + + +def make_example_from_ref(url, ref): + try: + url = encode(url) + ref = encode(ref) + except UnicodeEncodeError: + return None + + features = { + "url": + tf.train.Feature(bytes_list=tf.train.BytesList(value=[url])), + "content": + tf.train.Feature( + bytes_list=tf.train.BytesList(value=[ref])), + } + return tf.train.Example(features=tf.train.Features(feature=features)) + + +def tfrecord_fname(out_dir, shard_id, idx=None): + fname = os.path.join(out_dir, REF_SHARD_FILE % shard_id) + if idx is not None: + fname += ".%d" % idx + return fname + + +def make_tfrecord_writer(fname): + opts = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP) + return tf.python_io.TFRecordWriter(fname, opts) + + +def write_ref_content(url, ref, f): + if not ref: + return False + ex = make_example_from_ref(url, ref) + if ex is None: + return False + f.write(ex.SerializeToString()) + return True + + +async def fetch_url(url, session, side_data): + text = None + try: + async with session.get(url, timeout=30, verify_ssl=False) as response: + if response.status == 200: + text = await response.text() + else: + tf.logging.error("Status %d, url: %s", response.status, url) + except: + # Request can fail for many reasons. + pass + + return text, side_data + + +async def throttled_fetch_url(url, sem, session, side_data): + async with sem: + return await fetch_url(url, session, side_data) + + +async def fetch_urls(urls, + out_fname, + logging_fnames=None): + tasks = [] + connector = aiohttp.TCPConnector(limit_per_host=1) + async with aiohttp.ClientSession( + connector=connector, cookie_jar=aiohttp.DummyCookieJar()) as session: + # Async fetch the urls + sem = asyncio.Semaphore(FLAGS.max_parallel_requests) + for url in urls: + side_data = {"url": url} + task = asyncio.ensure_future( + throttled_fetch_url(url, sem, session, side_data)) + tasks.append(task) + tf.logging.info("Async requested %d urls", len(urls)) + + # Setup output files + file_handles = [] + out_f = make_tfrecord_writer(out_fname) + file_handles.append(out_f) + + logging_fnames = logging_fnames or {} + + samples_f = None + if "samples" in logging_fnames: + samples_f = tf.gfile.Open(logging_fnames["samples"], "w") + file_handles.append(samples_f) + + refs_written = [0] # Made a list so can be mutated + + def text_extraction_callback(callback_arg): + url, text = callback_arg + written = write_ref_content(url, text, out_f) + if not written: + return + if not refs_written[0] % FLAGS.log_every: + timestamp = datetime.datetime.now().strftime("%H:%M") + tf.logging.info("%s: Wrote ref %d in group", timestamp, refs_written[0]) + if samples_f is not None: + samples_f.write(url) + samples_f.write("\n") + samples_f.write(text) + samples_f.write("\n\n---\n\n") + refs_written[0] += 1 + + try: + # Process each URL as it comes in. + # Using a multiprocessing Pool because the text extraction is expensive + # and so we distribute across cores. + pool = multiprocessing.Pool() + results = [] + for task in asyncio.as_completed(tasks): + html, side_data = await task + url = side_data["url"] + if not html: + continue + res = pool.apply_async(mp_get_text, (url, html), {}, + text_extraction_callback) + results.append(res) + for res in results: + try: + res.get(timeout=10) + except multiprocessing.TimeoutError: + pass + finally: + for f in file_handles: + f.close() + + return refs_written[0] + + +def get_urls_per_shard(urls_files): + total_urls = 0 + per_shard = {} + for urls_file in urls_files: + ref_urls = set() + shard_id = int(os.path.basename(urls_file)[15:20]) + with tf.gfile.Open(urls_file) as f: + wiki_urls = json.loads(f.read()) + for _, wiki_info in wiki_urls.items(): + ref_urls |= set(wiki_info["refs"]) + + per_shard[shard_id] = list(ref_urls) + total_urls += len(ref_urls) + return per_shard, total_urls + + +def get_urls_for_shard(urls_dir, shard_id): + urls_file = os.path.join(urls_dir, WIKI_URLS_FILE % shard_id) + urls_per_shard, _ = get_urls_per_shard([urls_file]) + assert len(urls_per_shard) == 1 + return urls_per_shard[shard_id] + + +def get_urls_for_shard_group(urls_dir, shard_id, group_id): + shard_urls = get_urls_for_shard(urls_dir, shard_id) + + # Deterministic sort and shuffle to prepare for sharding + shard_urls.sort() + random.seed(123) + random.shuffle(shard_urls) + groups = shard(shard_urls, int(math.ceil(len(shard_urls) / URLS_PER_CLIENT))) + group_urls = groups[group_id] + if FLAGS.debug_num_urls: + group_urls = group_urls[:FLAGS.debug_num_urls] + return group_urls + + +def main(_): + urls = get_urls_for_shard_group( + FLAGS.urls_dir, FLAGS.shard_id, FLAGS.group_id) + tf.logging.info("Fetching %d URLs for shard %d, group %d", + len(urls), FLAGS.shard_id, FLAGS.group_id) + + tf.gfile.MakeDirs(FLAGS.out_dir) + out_fname = tfrecord_fname(FLAGS.out_dir, FLAGS.shard_id) + + with utils.timing("group_fetch"): + logging_fnames = {} + if FLAGS.log_samples: + logging_fnames["samples"] = os.path.join( + FLAGS.out_dir, "samples.%d.txt" % FLAGS.shard_id) + loop = asyncio.get_event_loop() + num_written = loop.run_until_complete(asyncio.ensure_future( + fetch_urls(urls, + out_fname, + logging_fnames))) + + tf.logging.info("Total URLs: %d", len(urls)) + tf.logging.info("Num written: %d", num_written) + tf.logging.info("Coverage: %.1f", (num_written / len(urls)) * 100) + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/parallel_launch.py b/tensor2tensor/data_generators/wikisum/parallel_launch.py new file mode 100644 index 000000000..9cec88d07 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/parallel_launch.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=line-too-long +r"""Launch a script in parallel on GCP. + +For each instance (`--num_instances`), the script will copy the code in +`--code_dir` to the instance, run `--setup_command` and then run +`--command_prefix` joined with the task's id or a line in +`--per_instance_suffix_file`. + +Note that the machines will attempt to down themselves on completion or failure. +If they do not, you can delete them manually or use delete_instances.sh to +delete many at once. + +Example usage: + +``` +BUCKET=gs://my-bucket +python parallel_launch.py \ + --num_instances=1000 \ + --cpu=4 --mem=4 \ + --name=wikisum-refs-web \ + --code_dir=./ \ + --log_dir=$BUCKET/refs_logs \ + --setup_command="pip3 install aiohttp cchardet aiodns bs4 -q --user" \ + --command_prefix="python3 wikisum/get_references_web.py --out_dir=$BUCKET/wiki_references --shard_id" +``` +""" +# pylint: enable=line-too-long + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import multiprocessing as mp +import os +import subprocess as sp +import time + +from tensor2tensor.utils import cloud_tpu as cloud +import tensorflow as tf + +flags = tf.flags +FLAGS = flags.FLAGS + + +flags.DEFINE_integer("num_instances", None, "Number of instances to launch.") +flags.DEFINE_string("name", None, "Instance name prefix.") +flags.DEFINE_string("log_dir", None, "GCS bucket to copy logs out to.") +flags.DEFINE_string("code_dir", None, "Directory to copy.") +flags.DEFINE_string("setup_command", None, "Setup command to run.") +flags.DEFINE_string("command_prefix", None, "Command to run, prefix.") +flags.DEFINE_string("per_instance_suffix_file", None, + "Command to run, suffix per instance. If None, suffix will " + "be instance id.") +flags.DEFINE_integer("cpu", 1, "Number of CPUs per instance.") +flags.DEFINE_integer("mem", 4, "Memory in GB per instance.") +flags.DEFINE_integer("num_threads", 48, + "Number of threads to use to spin up jobs.") +flags.DEFINE_bool("debug_keep_up", False, + "If True, will keep the machine up. num_instances must be 1.") +flags.DEFINE_string("instance_ids", None, + "Comma-separated list of integer instance ids to launch. " + "Useful if some failed on a previous run and you only want " + "to rerun specific tasks.") + + +DELETE = "gcloud compute instances delete {name}" +DELETE_SELF = ("gcloud compute instances delete $(hostname) --quiet " + "--zone={zone}") +CREATE_INSTANCE = ("gcloud compute instances create {instance_name} " + "--custom-cpu {cpu} --custom-memory {mem} " + "--custom-extensions " + "--image-project=ml-images --image-family=tf-1-7 " + "--scopes=cloud-platform") +COPY_CODE = "gcloud compute scp --recurse {local_dir} {instance_name}:~/" +SSH = "gcloud compute ssh {instance_name} --command" +SCREEN = "screen -dmS test bash -c \"{command}\"" +SSH_CHECK = "nc -w 1 -z {ip} 22" +DEFAULT_ZONE = "gcloud config get-value compute/zone" +LOGS = "> ~/logs-{task_id}.txt 2>&1; gsutil cp ~/logs-{task_id}.txt {bucket}" + + +def remote_run(cmd, instance_name, detach=False, retries=1): + if detach: + cmd = SCREEN.format(command=cmd) + args = SSH.format(instance_name=instance_name).split() + args.append(cmd) + for i in range(retries + 1): + try: + if i > 0: + tf.logging.info("Retry %d for %s", i, args) + return sp.check_call(args) + except sp.CalledProcessError as e: + if i == retries: + raise e + + +def default_zone(): + return cloud.shell_output(DEFAULT_ZONE).strip() + + +def wait_for_ssh(ip): + """Wait for SSH to be available at given IP address.""" + i = 0 + while True: + try: + cloud.shell_run(SSH_CHECK, ip=ip) + break + except sp.CalledProcessError: + if i > 12: # ~2m + return False + time.sleep(10) + i += 1 + return True + + +def create_instance(instance_name, cpu=1, mem=4): + tf.logging.info("Creating instance %s", instance_name) + out = cloud.shell_output(CREATE_INSTANCE, instance_name=instance_name, + cpu=cpu, mem=mem) + return out.split("\n")[1:-1][0].split()[8] + + +def list_vm_names_and_ips(): + list_out = cloud.shell_output(cloud.Gcloud.LIST_VM) + lines = [l.split() for l in list_out.split("\n")[1:-1]] + names_and_ips = [(l[0].strip(), l[-2].strip()) for l in lines] + return names_and_ips + + +def shell_run_with_retry(cmd, retries=1, **kwargs): + for i in range(retries + 1): + try: + if i > 0: + tf.logging.info("Retry %d for %s", i, cmd) + cloud.shell_run(cmd, **kwargs) + return + except sp.CalledProcessError as e: + if i == retries: + raise e + + +def delete_instance(instance_name): + cloud.shell_run(DELETE, name=instance_name) + + +def launch_instance(instance_name, + command, + existing_ip=None, + cpu=1, + mem=4, + code_dir=None, + setup_command=None): + """Launch a GCE instance.""" + # Create instance + ip = existing_ip or create_instance(instance_name, cpu=cpu, mem=mem) + tf.logging.info("Waiting for SSH %s", instance_name) + ready = wait_for_ssh(ip) + if not ready: + raise ValueError("Instance %s never ready for SSH" % instance_name) + + # Copy code + if code_dir: + shell_run_with_retry(COPY_CODE, retries=2, + local_dir=code_dir, instance_name=instance_name) + + # Run setup + if setup_command: + tf.logging.info("Running setup on %s", instance_name) + remote_run(setup_command, instance_name) + + # Run command + tf.logging.info("Running command on %s", instance_name) + remote_run(command, instance_name, detach=True) + + +def main(_): + assert FLAGS.num_instances + assert FLAGS.name + zone = default_zone() + assert zone + + code_dir = None + if FLAGS.code_dir: + code_dir = os.path.abspath(os.path.expanduser(FLAGS.code_dir)) + + # Suffixes per instance + if FLAGS.per_instance_suffix_file: + with tf.gfile.Open(FLAGS.per_instance_suffix_file) as f: + suffixes = [l.strip() for l in f.readlines()] + else: + suffixes = list(range(FLAGS.num_instances)) + assert len(suffixes) == FLAGS.num_instances + + vm_info = list_vm_names_and_ips() + vm_names = zip(*vm_info)[0] if vm_info else [] + + pool = mp.Pool(FLAGS.num_threads) + async_results = [] + + log_dir = None + if FLAGS.log_dir: + log_dir = os.path.join(FLAGS.log_dir, FLAGS.name) + tf.gfile.MakeDirs(log_dir) + assert log_dir.startswith("gs://") + if not log_dir.endswith("/"): + log_dir += "/" + + instance_ids = list(range(FLAGS.num_instances)) + if FLAGS.instance_ids: + instance_ids = [int(i) for i in FLAGS.instance_ids.split(",")] + tf.logging.info("Launching %d instances", len(instance_ids)) + + for i in instance_ids: + instance_name = "%s-%d" % (FLAGS.name, i) + existing_ip = (vm_info[vm_names.index(instance_name)][1] + if instance_name in vm_names else None) + logging = LOGS.format(task_id=i, bucket=log_dir) if log_dir else "" + delete = DELETE_SELF.format(zone=zone) + if FLAGS.debug_keep_up: + assert FLAGS.num_instances == 1 + delete = "" + command = "{prefix} {suffix} {logging}; {delete}".format( + prefix=FLAGS.command_prefix, + suffix=suffixes[i], + delete=delete, + logging=logging) + args = (instance_name, command, existing_ip, + FLAGS.cpu, FLAGS.mem, code_dir, + FLAGS.setup_command) + res = pool.apply_async(launch_instance, args) + async_results.append(res) + + failed = [] + for i, res in enumerate(async_results): + try: + res.get() + except: # pylint: disable=bare-except + failed.append(i) + tf.logging.error("Failed to launch task %d", i) + + results = [] + if failed: + tf.logging.error("Failed to launch %d jobs. Task ids: %s. " + "Attempting delete in case they are still up.", + len(failed), str(failed)) + for i in failed: + instance_name = "%s-%d" % (FLAGS.name, i) + res = pool.apply_async(delete_instance, (instance_name,)) + results.append(res) + + for res in results: + try: + res.get() + except: # pylint: disable=bare-except + pass + + tf.logging.info("Launching complete.") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/produce_examples.py b/tensor2tensor/data_generators/wikisum/produce_examples.py new file mode 100644 index 000000000..7ce3a3508 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/produce_examples.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Produce examples given a vocab, wikis, references, and dataset URLs.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensor2tensor.data_generators.wikisum import utils +from tensor2tensor.data_generators.wikisum import wikisum + +import tensorflow as tf + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_integer("num_tasks", 1000, "Number of parallel tasks.") +flags.DEFINE_integer("task_id", 0, "Task id in a parallel run.") +flags.DEFINE_string("out_dir", None, "Directory to write to.") +flags.DEFINE_string("wikis_dir", + "gs://tensor2tensor-data/wikisum/wiki_content/", + "Directory with wiki_content.tfrecords.") +flags.DEFINE_string("refs_dir", None, "Directory with process_X dirs") +flags.DEFINE_string("urls_dir", "gs://tensor2tensor-data/wikisum/wiki_urls/", + "Directory with wiki_urls.json") +flags.DEFINE_string("vocab_dir", None, "Directory with vocab file") +flags.DEFINE_bool("for_commoncrawl", False, + "Whether to use WikisumCommoncrawl or WikisumWeb.") + + +def main(_): + if FLAGS.for_commoncrawl: + problem = wikisum.WikisumCommoncrawl() + else: + problem = wikisum.WikisumWeb() + + out_filepaths = problem.out_filepaths(FLAGS.out_dir) + out_filepaths = utils.shard(out_filepaths, FLAGS.num_tasks)[FLAGS.task_id] + + if not FLAGS.vocab_dir: + FLAGS.vocab_dir = FLAGS.out_dir + + shard_ids = utils.shard(list(range(utils.NUM_SHARDS)), + FLAGS.num_tasks)[FLAGS.task_id] + + with utils.timing("produce_examples"): + wikisum.produce_examples( + shard_ids=shard_ids, + wikis_dir=FLAGS.wikis_dir, + refs_dir=FLAGS.refs_dir, + urls_dir=FLAGS.urls_dir, + vocab_path=os.path.join(FLAGS.vocab_dir, problem.vocab_filename), + out_filepaths=out_filepaths) + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/test_data/para_bad1.txt b/tensor2tensor/data_generators/wikisum/test_data/para_bad1.txt new file mode 100644 index 000000000..b15107bd9 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/test_data/para_bad1.txt @@ -0,0 +1,11 @@ +kolkata ward no 97 37 +you are here : india » west bengal » kolkata » kolkata +this paragraph too short +a | b | c | d | e | f | g | h | i | j | k | l | m | n | o | p | q | r | s | t | u | v | w | x | y | z +123 123 123 123 985 9880 1230 0980 . 12398 . +- 5 . 7 % - 5 . 2 % - 15 . 1 % 4 . 7 % - 13 . 3 % +http : / / www . bbc . co . uk / sport / football / 24351521 +no . - 26 beadon street . +{ { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } }

{ { # playpopup } } { { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } } { { genre } } +denham , samuel coulter , sally 133 oct 28 1819 +browse by diff --git a/tensor2tensor/data_generators/wikisum/test_data/para_good1.txt b/tensor2tensor/data_generators/wikisum/test_data/para_good1.txt new file mode 100644 index 000000000..99f78ef45 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/test_data/para_good1.txt @@ -0,0 +1,15 @@ +this is a very good paragraph . it even has two sentences . +the castle that was soon to figure so largely in lee’s life lay fourteen miles +to the southwest of where he sat perched atop his tank . topped with storybook +crenelations and accompanied by a rich history , schloss itter , as it’s called +in german , was first mentioned in land records as early as 1240 . since then , +itter has passed through a number of hands . after germany’s march 1938 +annexation of austria , the castle’s robust construction and relatively remote +location attracted the attention of the notoriously secretive nazis . within +months of absorbing austria into the greater reich , the german government +requisitioned castle itter for unspecified “official use”—which included housing +for several months in 1942 an organization called the “german association for +combating the dangers of tobacco . ” on february 7 , 1943 , it fell into new +hands yet again , for on that day , the structure and all its outbuildings were +requisitioned by the wehrmacht on behalf of the ss . +the url for the site is http : / / www . bbc . co . uk / sport / football / 24351521 . diff --git a/tensor2tensor/data_generators/wikisum/utils.py b/tensor2tensor/data_generators/wikisum/utils.py new file mode 100644 index 000000000..c2c6f8c88 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/utils.py @@ -0,0 +1,268 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wikisum data generation utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import contextlib +import datetime +import gzip +import os +import re +import urllib + +import tensorflow as tf + +# pylint: disable=g-import-not-at-top +# To maintain compatibility with Python 2 and 3 +try: + import cStringIO as StringIO +except ImportError: + import io as StringIO +# pylint: enable=g-import-not-at-top + + +# Each entry is a URL to the wet.paths.gz file for that CommonCrawl dump. +WET_PATHS_BY_DATE = { + '0917': ('https://commoncrawl.s3.amazonaws.com/crawl-data/CC-MAIN-2017-39/' + 'wet.paths.gz'), +} + +S3_HTTP_PREFIX = 'https://commoncrawl.s3.amazonaws.com/' +NUM_SHARDS = 1000 +METADTA_SUFFIX = '.metadata.json' + + + +def readahead(path): + return path + + +class WETHeader(collections.namedtuple('WETHeader', ['url', 'length'])): + URI_HEADER = 'WARC-Target-URI: ' + LENGTH_HEADER = 'Content-Length: ' + + @classmethod + def read(cls, f): + """Read header from file. Headers end with length and then 1 blank line.""" + url = None + + line = f.readline() + if not line: + # EOF + return None + while not line.startswith(cls.LENGTH_HEADER): + if line.startswith(cls.URI_HEADER): + url = line[len(cls.URI_HEADER):].strip() + line = f.readline() + + # Consume empty separator + f.readline() + + # Read content + length = int(line.split(':')[1]) + + return cls(url, length) + + +class WETRecord(collections.namedtuple('WETRecord', ['url', 'content'])): + + @classmethod + def read(cls, f): + """Read WETRecord from file. Records end with 2 blank lines.""" + header = WETHeader.read(f) + if header is None: + # EOF + return None + content = f.read(header.length) + + # Consume empty separators + f.readline() + f.readline() + + return cls(header.url, content) + + +def wet_records_from_file_obj(f, take_ownership=False): + """Iterate through records in WET file object.""" + while True: + record = WETRecord.read(f) + + if record is None: + break + + if not record.url: + continue + + yield record + + if take_ownership: + f.close() + + +def wet_records(wet_filepath): + """Generate WETRecords from filepath.""" + if wet_filepath.endswith('.gz'): + fopen = gzip.open + else: + fopen = tf.gfile.FastGFile + + with fopen(wet_filepath) as f: + for record in wet_records_from_file_obj(f): + yield record + + +def download(url, download_dir): + outname = os.path.join(download_dir, os.path.basename(url)) + if tf.gfile.Exists(outname): + print('Found %s, skipping download' % outname) + return outname + inprogress = outname + '.incomplete' + print('Downloading %s' % url) + inprogress, _ = urllib.urlretrieve(url, inprogress) + tf.gfile.Rename(inprogress, outname) + return outname + + +def wet_download_urls(wet_paths_url, tmp_dir, rm_after=True): + paths_gz = download(wet_paths_url, tmp_dir) + with gzip.open(paths_gz) as f: + path = f.readline() + while path: + download_path = S3_HTTP_PREFIX + path[:-1] + yield download_path + path = f.readline() + if rm_after: + tf.gfile.Remove(paths_gz) + + +def wet_records_from_url(download_url, tmp_dir, rm_after=True): + wet_gz = download(download_url, tmp_dir) + try: + for wet_record in wet_records(wet_gz): + yield wet_record + finally: + if rm_after: + tf.gfile.Remove(wet_gz) + + +class DummyPool(object): + + def __init__(self, processes=None): + pass + + def apply_async(self, fn, args=None): + args = args or tuple() + return DummyResult(fn(*args)) + + def map(self, fn, arg_list): + return [fn(a) for a in arg_list] + + +class DummyResult(object): + + def __init__(self, result): + self.result = result + + def get(self): + return self.result + + +def shard(items, num_shards): + """Split items into num_shards groups.""" + sharded = [] + num_per_shard = len(items) // num_shards + start = 0 + for _ in range(num_shards): + sharded.append(items[start:start + num_per_shard]) + start += num_per_shard + + remainder = len(items) % num_shards + start = len(items) - remainder + for i in range(remainder): + sharded[i].append(items[start + i]) + + assert sum([len(fs) for fs in sharded]) == len(items) + return sharded + + +def gzip_memfile(fname): + with tf.gfile.Open(readahead(fname)) as f: + memfile = StringIO.StringIO(f.read()) + return gzip.GzipFile(fileobj=memfile) + + +_SOME_ALPHA_RE = re.compile(r'[A-Za-z]+') +_ONLY_ALPHA_RE = re.compile(r'^[A-Za-z]*$') + + +def filter_paragraph(p): + """Simple filter to remove obviously bad paragraphs (bad text extraction). + + Note this needs to run very quickly as it is applied to every paragraph + in the corpus, so nothing fancy! This whole method should be linear + expected time in len(p). + + Args: + p: string, paragraph + + Returns: + True if we should remove the paragraph. + """ + # Expect a minimum number of words. + tokens = p.split() + if len(tokens) < 6: + return True + + # Require some letters. + if not re.search(_SOME_ALPHA_RE, p): + return True + + # Keep this one at the end, probably the most complicated logic. + # We try to detect sentences, which should have a minimum of 3 tokens + # with only alphabetic characters. + last = 0 + found_sentence = False + num_alpha = 0 + for i, x in enumerate(tokens): + if x == '.': + if i - last > 3 and num_alpha >= 3: + found_sentence = True + break + last = i + num_alpha = 0 + if re.match(_ONLY_ALPHA_RE, x): + num_alpha += 1 + if not found_sentence: + return True + + return False + + +@contextlib.contextmanager +def timing(name=''): + """Log start, end, and duration.""" + start = datetime.datetime.now() + timestamp = start.strftime('%H:%M') + tf.logging.info('Starting job [%s] at %s', name, timestamp) + yield + end = datetime.datetime.now() + timestamp = end.strftime('%H:%M') + tf.logging.info('Finished job [%s] at %s', name, timestamp) + duration = end - start + duration_mins = duration.total_seconds() / 60 + tf.logging.info('Total time [%s] (m): %d', name, int(duration_mins)) diff --git a/tensor2tensor/data_generators/wikisum/utils_test.py b/tensor2tensor/data_generators/wikisum/utils_test.py new file mode 100644 index 000000000..d57c187ae --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/utils_test.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tensor2tensor.data_generators.wikisum.utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +from tensor2tensor.data_generators.wikisum import utils + +import tensorflow as tf + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.join(pkg_dir, "test_data") + + +def _get_testdata(filename): + with tf.gfile.Open(os.path.join(_TESTDATA, filename)) as f: + return f.read() + + +class UtilsTest(tf.test.TestCase): + + def test_filter_paragraph(self): + for bad in tf.gfile.Glob(os.path.join(_TESTDATA, "para_bad*.txt")): + for p in _get_testdata(bad).split("\n"): + self.assertTrue(utils.filter_paragraph(p), + msg="Didn't filter %s" % p) + for good in tf.gfile.Glob(os.path.join(_TESTDATA, "para_good*.txt")): + for p in _get_testdata(good).split("\n"): + p = _get_testdata(good) + self.assertFalse(utils.filter_paragraph(p), msg="Filtered %s" % p) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/data_generators/wikisum/wikisum.py b/tensor2tensor/data_generators/wikisum/wikisum.py new file mode 100644 index 000000000..2dac504d6 --- /dev/null +++ b/tensor2tensor/data_generators/wikisum/wikisum.py @@ -0,0 +1,578 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wikipedia Summarization Problems.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import json +import math +import os +import re +import string +import tempfile + +import six +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import tokenizer +from tensor2tensor.data_generators.wikisum import utils as cc_utils +from tensor2tensor.utils import metrics +from tensor2tensor.utils import registry +import tensorflow as tf + +PROCESS_FOLDER_PREFIX = "process" +REF_SHARD_FILE_PREFIX = "references.tfrecords.gz" +REF_SHARD_FILE = REF_SHARD_FILE_PREFIX + "-%05d-of-01000" + +# Support files +BASE_SUPPORT_DIR = "gs://tensor2tensor-data/wikisum" +WIKI_CONTENT_DIR = os.path.join(BASE_SUPPORT_DIR, "wiki_content") +WIKI_URLS_DIR = os.path.join(BASE_SUPPORT_DIR, "wiki_urls") +WET_METADATA_DIR = os.path.join(BASE_SUPPORT_DIR, "commoncrawl_metadata") +WIKI_CONTENT_FILE = "wiki_content.tfrecords-%05d-of-01000" +WIKI_URLS_FILE = "wiki_urls.json-%05d-of-01000" + +EOT = "" # end-of-title string +_MIN_REFS = 1 +_MIN_LEADSECTION_TOKENS = 1 + + +class WikisumBase(problem.Problem): + """Base class for Wikisum problems.""" + + def example_reading_spec(self): + data_fields = { + "inputs": tf.VarLenFeature(tf.int64), + "targets": tf.VarLenFeature(tf.int64), + "section_boundaries": tf.VarLenFeature(tf.int64), + } + data_items_to_decoders = None + return (data_fields, data_items_to_decoders) + + @property + def target_vocab_size(self): + return 2**15 + + @property + def vocab_filename(self): + return "vocab.%s.%d" % (self.dataset_filename(), self.target_vocab_size) + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, self.vocab_filename) + encoder = text_encoder.SubwordTextEncoder(vocab_filename) + # Shared encoder for inputs and targets + return {"inputs": encoder, "targets": encoder} + + def hparams(self, defaults, unused_model_hparams): + p = defaults + p.stop_at_eos = True + + source_vocab_size = self._encoders["inputs"].vocab_size + target_vocab_size = self._encoders["targets"].vocab_size + p.input_modality = { + "inputs": (registry.Modalities.SYMBOL, source_vocab_size) + } + p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size) + + def eval_metrics(self): + return super(WikisumBase, self).eval_metrics() + [ + metrics.Metrics.ROUGE_2_F, metrics.Metrics.ROUGE_L_F + ] + + def generate_lines_for_vocab(self, wikis_dir, refs_dir, max_chars=10**7): + total_chars = 0 + ref_files_by_shard = _references_files_by_shard(refs_dir) + for shard_id in range(cc_utils.NUM_SHARDS): + # Wikipedia articles + for wiki in _wiki_articles(shard_id, wikis_dir): + yield _normalize_text(wiki.title) + EOT + for section in wiki.sections: + yield _format_title(_normalize_text(section.title)) + yield _normalize_text(section.text) + total_chars += len(section.title) + total_chars += len(section.text) + + # References + for i, content in enumerate( + six.itervalues(_references_content(ref_files_by_shard[shard_id]))): + for line in content.split("\n"): + if line: + yield _normalize_text(line) + total_chars += len(line) + + # Make sure we use at least 1k references + if i >= 1000 and total_chars >= max_chars: + break + + if total_chars >= max_chars: + tf.logging.info("Seen enough chars: %d; finished.", max_chars) + break + tf.logging.info("Built vocabulary using %d chars", total_chars) + + def generate_vocab(self, data_dir, wikis_dir, refs_dir): + # Produce a SubwordTextEncoder from a subset of the data + return generator_utils.get_or_generate_vocab_inner( + data_dir, self.vocab_filename, self.target_vocab_size, + self.generate_lines_for_vocab(wikis_dir, refs_dir)) + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + tf.logging.warn("See wikisum/README.md for instructions to generate data.") + + def out_filepaths(self, data_dir): + train_shards = 800 + dev_shards = 100 + test_shards = 100 + train_filepaths = self.training_filepaths( + data_dir, train_shards, shuffled=True) + dev_filepaths = self.dev_filepaths(data_dir, dev_shards, shuffled=True) + test_filepaths = self.test_filepaths(data_dir, test_shards, shuffled=True) + out_filepaths = train_filepaths + dev_filepaths + test_filepaths + out_filepaths.sort() + assert len(out_filepaths) == cc_utils.NUM_SHARDS + return out_filepaths + + +@registry.register_problem +class WikisumCommoncrawl(WikisumBase): + """Wikipedia references->article summarization task based on CommonCrawl.""" + pass + + +@registry.register_problem +class WikisumWeb(WikisumBase): + """Wikipedia references->article summarization task based on web data.""" + pass + + +@registry.register_problem +class WikisumCommoncrawlLeadSection(WikisumCommoncrawl): + """Wikipedia references->lead section summarization task.""" + + def preprocess_example(self, example, mode, hparams): + example["targets"] = _truncate_to_lead_section(example) + return super(WikisumCommoncrawlLeadSection, self).preprocess_example( + example, mode, hparams) + + def dataset_filename(self): + return WikisumCommoncrawl.name + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + tf.logging.warn("Problem %s reuses data from problem %s", self.name, + WikisumCommoncrawl.name) + + +@registry.register_problem +class WikisumWebLeadSection(WikisumWeb): + """Wikipedia references->lead section summarization task.""" + + def preprocess_example(self, example, mode, hparams): + example["targets"] = _truncate_to_lead_section(example) + return super(WikisumWebLeadSection, self).preprocess_example( + example, mode, hparams) + + def dataset_filename(self): + return WikisumWeb.name + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + tf.logging.warn("Problem %s reuses data from problem %s", self.name, + WikisumWeb.name) + + +def make_ref_shard_files(out_dir): + tf.gfile.MakeDirs(out_dir) + opts = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP) + files = [ + tf.python_io.TFRecordWriter( + os.path.join(out_dir, REF_SHARD_FILE % i), opts) + for i in range(cc_utils.NUM_SHARDS) + ] + return files + + +def _truncate_to_lead_section(example): + wiki = example["targets"] + lead_boundary = example["section_boundaries"][0] + # Concat a new EOS to the lead since the original one gets truncated. + lead = tf.concat((wiki[:lead_boundary], [text_encoder.EOS_ID]), 0) + return lead + + +def _make_example_from_record(record): + features = { + "url": + tf.train.Feature(bytes_list=tf.train.BytesList(value=[record.url])), + "content": + tf.train.Feature( + bytes_list=tf.train.BytesList(value=[record.content])), + } + return tf.train.Example(features=tf.train.Features(feature=features)) + + +def _shard_id_for_file(sharded_filename): + suffix = "00000-of-00000" + parts = sharded_filename[-len(suffix):].split("-") + assert len(parts) == 3 + return int(parts[0]) + + +def _references_files_by_shard(refs_dir): + process_dirs = _process_folders(refs_dir) + shards = collections.defaultdict(list) + for d in process_dirs: + ref_files = tf.gfile.Glob(os.path.join(d, REF_SHARD_FILE_PREFIX) + "*") + for f in ref_files: + shards[_shard_id_for_file(f)].append(f) + return shards + + +def _references_content(ref_files): + """Returns dict.""" + with tf.Graph().as_default(): + dataset = tf.data.Dataset.from_tensor_slices(ref_files) + + def _load_records(filename): + return tf.data.TFRecordDataset( + filename, + compression_type=tf.constant("GZIP"), + buffer_size=16 * 1000 * 1000) + + dataset = dataset.flat_map(_load_records) + + def _parse_example(ex_ser): + features = { + "url": tf.VarLenFeature(tf.string), + "content": tf.VarLenFeature(tf.string), + } + ex = tf.parse_single_example(ex_ser, features) + for k in ex.keys(): + ex[k] = ex[k].values[0] + return ex + + dataset = dataset.map(_parse_example, num_parallel_calls=32) + dataset = dataset.prefetch(100) + record_it = dataset.make_one_shot_iterator().get_next() + + data = {} + + with tf.Session() as sess: + i = 0 + while True: + try: + ex = sess.run(record_it) + except tf.errors.OutOfRangeError: + break + + data[ex["url"]] = ex["content"] + i += 1 + + return data + + +def _wiki_urls_for_shard(shard_id, urls_dir=None): + """Urls for chunk: dict ref_urls>.""" + urls_dir = urls_dir or WIKI_URLS_DIR + urls_filepath = os.path.join(urls_dir, WIKI_URLS_FILE % shard_id) + with tf.gfile.GFile(urls_filepath) as f: + return json.loads(f.read()) + + +class WikipediaSection( + collections.namedtuple("WikipediaSection", ["title", "text"])): + pass + + +class WikipediaArticle( + collections.namedtuple("WikipediaArticle", ["url", "title", "sections"])): + pass + + +def _wiki_articles(shard_id, wikis_dir=None): + """Generates WikipediaArticles from GCS that are part of shard shard_id.""" + if not wikis_dir: + wikis_dir = WIKI_CONTENT_DIR + with tf.Graph().as_default(): + dataset = tf.data.TFRecordDataset( + cc_utils.readahead( + os.path.join(wikis_dir, WIKI_CONTENT_FILE % shard_id)), + buffer_size=16 * 1000 * 1000) + + def _parse_example(ex_ser): + """Parse serialized Example containing Wikipedia article content.""" + features = { + "url": tf.VarLenFeature(tf.string), + "title": tf.VarLenFeature(tf.string), + "section_titles": tf.VarLenFeature(tf.string), + "section_texts": tf.VarLenFeature(tf.string), + } + ex = tf.parse_single_example(ex_ser, features) + for k in ex.keys(): + ex[k] = ex[k].values + ex["url"] = ex["url"][0] + ex["title"] = ex["title"][0] + return ex + + dataset = dataset.map(_parse_example, num_parallel_calls=32) + dataset = dataset.prefetch(100) + record_it = dataset.make_one_shot_iterator().get_next() + + with tf.Session() as sess: + while True: + try: + ex = sess.run(record_it) + except tf.errors.OutOfRangeError: + break + + sections = [ + WikipediaSection(title=title, text=text) + for title, text in zip(ex["section_titles"], ex["section_texts"]) + ] + yield WikipediaArticle( + url=ex["url"], title=ex["title"], sections=sections) + + +def _token_counts(text, token_set=None): + counts = collections.defaultdict(int) + for token in tokenizer.encode(text_encoder.native_to_unicode(text)): + if token_set and token not in token_set: + continue + counts[token] += 1 + return counts + + +def _normalize_text(text): + text = text.lower() + # Space around punctuation + text = re.sub("[%s]" % re.escape(string.punctuation), r" \g<0> ", text) + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def _tokens_to_score(tokens): + return {t for t in tokens if re.search("[a-z0-9]", t)} + + +def _rank_reference_paragraphs(wiki_title, references_content): + """Rank and return reference paragraphs by tf-idf score on title tokens.""" + title_tokens = _tokens_to_score(set( + tokenizer.encode(text_encoder.native_to_unicode(wiki_title)))) + ref_paragraph_info = [] + doc_counts = collections.defaultdict(int) + for ref in references_content: + for paragraph in ref.split("\n"): + paragraph = _normalize_text(paragraph) + if cc_utils.filter_paragraph(paragraph): + # Skip paragraph + continue + counts = _token_counts(paragraph, title_tokens) + for token in title_tokens: + if counts[token]: + doc_counts[token] += 1 + info = {"content": paragraph, "counts": counts} + ref_paragraph_info.append(info) + + for info in ref_paragraph_info: + score = 0. + for token in title_tokens: + term_frequency = info["counts"][token] + inv_doc_frequency = ( + float(len(ref_paragraph_info)) / max(doc_counts[token], 1)) + score += term_frequency * math.log(inv_doc_frequency) + info["score"] = score + + ref_paragraph_info.sort(key=lambda el: el["score"], reverse=True) + return [info["content"] for info in ref_paragraph_info] + + +def produce_examples(shard_ids, wikis_dir, refs_dir, urls_dir, vocab_path, + out_filepaths): + """Produce examples from shard_ids to out_filepaths.""" + # * Join the Wikipedia articles with their references + # * Run Tf-idf to sort reference paragraphs + # * Encode the Wikipedia and reference text with the vocabulary + # * Write out TFRecords of tensorflow.Example + tf.logging.info("Processing %d input shards into %d output files.", + len(shard_ids), len(out_filepaths)) + + vocab = text_encoder.SubwordTextEncoder(vocab_path) + eot_ids = vocab.encode(EOT) + + def example_generator(): + """Generate Example dicts.""" + stats = dict(total_original_wikis=0, total_original_refs=0, + total_found_refs=0, ref_lengths=[], wiki_original_refs=[], + wiki_found_refs=[], wikis_skipped_no_refs=0, + wikis_skipped_short_lead=0, num_wikis_written=0) + ref_files_by_shard = _references_files_by_shard(refs_dir) + for shard_id in shard_ids: + tf.logging.info("Processing shard %d", shard_id) + wiki_urls = _wiki_urls_for_shard(shard_id, urls_dir) + tf.logging.info("Loaded wiki URLs for shard") + refs_content = _references_content(ref_files_by_shard[shard_id]) + tf.logging.info("Loaded reference content for shard") + for i, wiki in enumerate(_wiki_articles(shard_id, wikis_dir)): + if not i % 1000: + tf.logging.info("Processing wiki index %d for shard %d", i, shard_id) + stats["total_original_wikis"] += 1 + + # Get reference content + wiki_ref_content = [] + ref_urls = wiki_urls[wiki.url]["refs"] + stats["total_original_refs"] += len(ref_urls) + stats_wiki_original_refs = len(ref_urls) + stats_wiki_found_refs = 0 + for ref_url in ref_urls: + ref_content = refs_content.get(ref_url) + if not ref_content: + continue + stats["total_found_refs"] += 1 + stats["ref_lengths"].append(len(ref_content)) + stats_wiki_found_refs += 1 + wiki_ref_content.append(ref_content) + + stats["wiki_original_refs"].append(stats_wiki_original_refs) + stats["wiki_found_refs"].append(stats_wiki_found_refs) + if not wiki_ref_content or len(wiki_ref_content) < _MIN_REFS: + # No/few refs were found + stats["wikis_skipped_no_refs"] += 1 + continue + + # Rank reference paragraphs with TFIDF + wiki_title = _normalize_text(wiki.title) + ranked_paragraphs = _rank_reference_paragraphs(wiki_title, + wiki_ref_content) + + # Construct inputs from Wiki title and references + inputs = [] + inputs.extend(vocab.encode(wiki_title)) + inputs.extend(eot_ids) + for paragraph in ranked_paragraphs: + if len(inputs) >= 1e6: + break + paragraph += " " + inputs.extend(vocab.encode(paragraph)) + + # Construct targets from article sections + targets, section_boundaries = _encode_wiki_sections( + wiki.sections, vocab) + + # Skip if lead section is too short + if (not section_boundaries or + section_boundaries[0] < _MIN_LEADSECTION_TOKENS): + stats["wikis_skipped_short_lead"] += 1 + continue + + inputs.append(text_encoder.EOS_ID) + targets.append(text_encoder.EOS_ID) + + stats["num_wikis_written"] += 1 + yield { + "inputs": inputs, + "targets": targets, + "section_boundaries": section_boundaries, + } + + tf.logging.info("Total: %d, Skipped: %d", + stats["num_wikis_written"], + stats["total_original_wikis"] - stats["num_wikis_written"]) + tf.logging.info("Total refs: %d, Skipped refs: %d", + stats["total_found_refs"], + stats["total_original_refs"] - stats["total_found_refs"]) + stats_fname = os.path.join(os.path.split(out_filepaths[0])[0], + "stats.%d.json" % shard_ids[0]) + with tf.gfile.Open(stats_fname, "w") as f: + f.write(json.dumps(stats)) + + generator_utils.generate_files(example_generator(), out_filepaths) + + +def _format_title(title): + return " == %s == " % title + + +def _encode_wiki_sections(sections, vocab): + """Encodes sections with vocab. Returns ids and section boundaries.""" + ids = [] + section_boundaries = [] + for i, section in enumerate(sections): + if i > 0: + # Skip including article title + ids.extend(vocab.encode(_format_title(_normalize_text(section.title)))) + ids.extend(vocab.encode(_normalize_text(section.text))) + section_boundaries.append(len(ids)) + + return ids, section_boundaries + + +def _process_folders(tmp_dir): + return tf.gfile.Glob(os.path.join(tmp_dir, PROCESS_FOLDER_PREFIX) + "*") + + +def extract_references_from_wets(wet_files, metadata_dir, out_dir, + tmp_dir=None): + """Extract references from WET files into sharded output files.""" + # Setup output files + shard_files = make_ref_shard_files(out_dir) + + num_refs = 0 + for i, wet_file in enumerate(wet_files): + num_refs_in_wet = 0 + tf.logging.info("Processing file %d", i) + + # Read metadata file + metadata_fname = os.path.join( + metadata_dir, os.path.basename(wet_file)) + cc_utils.METADTA_SUFFIX + with tf.gfile.Open(cc_utils.readahead(metadata_fname)) as f: + wet_metadata = json.loads(f.read()) + + if not wet_metadata: + # No references in this WET file + continue + + if wet_file.startswith("http"): + # download + if not tmp_dir: + tmp_dir = tempfile.gettempdir() + record_gen = cc_utils.wet_records_from_url(wet_file, tmp_dir) + else: + # local + record_gen = cc_utils.wet_records_from_file_obj( + cc_utils.gzip_memfile(wet_file), take_ownership=True) + + for wet_record in record_gen: + shard_ids = wet_metadata.get(wet_record.url) + if not shard_ids: + # URL not in dataset + continue + + # Serialize and write out + ex = _make_example_from_record(wet_record) + ex_str = ex.SerializeToString() + for shard_id in shard_ids: + shard_files[shard_id].write(ex_str) + num_refs += 1 + num_refs_in_wet += 1 + + tf.logging.info("Wrote out %d references for this WET", num_refs_in_wet) + + tf.logging.info("Wrote out %d references total", num_refs) + + # Cleanup + for shard_file in shard_files: + shard_file.close() diff --git a/tensor2tensor/data_generators/wsj_parsing.py b/tensor2tensor/data_generators/wsj_parsing.py index 867277de9..15281b4d0 100644 --- a/tensor2tensor/data_generators/wsj_parsing.py +++ b/tensor2tensor/data_generators/wsj_parsing.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data generators for parsing data-sets.""" import os diff --git a/tensor2tensor/insights/graph.py b/tensor2tensor/insights/graph.py index 17e18ea3c..896f4d8ed 100644 --- a/tensor2tensor/insights/graph.py +++ b/tensor2tensor/insights/graph.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Graph representation for building decoding graph visualizations.""" diff --git a/tensor2tensor/insights/query_processor.py b/tensor2tensor/insights/query_processor.py index 905a20c95..00a2ca297 100644 --- a/tensor2tensor/insights/query_processor.py +++ b/tensor2tensor/insights/query_processor.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """A base class for all query processing classes.""" diff --git a/tensor2tensor/insights/server.py b/tensor2tensor/insights/server.py index e61c465ff..54421db58 100644 --- a/tensor2tensor/insights/server.py +++ b/tensor2tensor/insights/server.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """A GUnicorn + Flask Debug Frontend for Transformer models.""" import json diff --git a/tensor2tensor/insights/transformer_model.py b/tensor2tensor/insights/transformer_model.py index da8cf5fe3..3a4e5101d 100644 --- a/tensor2tensor/insights/transformer_model.py +++ b/tensor2tensor/insights/transformer_model.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """A QueryProcessor using the Transformer framework.""" from collections import deque diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index f8676bf32..1545f477f 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for attention.""" from __future__ import absolute_import from __future__ import division @@ -1464,6 +1463,140 @@ def dot_product_attention_relative(q, return _relative_attention_inner(weights, v, relations_values, False) +def _relative_position_to_absolute_position_masked(x): + """Helper to dot_product_self_attention_relative_v2. + + Rearrange an attention logits or weights Tensor. + + The dimensions of the input represent: + [batch, heads, query_position, memory_position - query_position + length - 1] + + The dimensions of the output represent: + [batch, heads, query_position, memory_position] + + Only works with masked_attention. Undefined behavior for regions of the + input where memory_position > query_position. + + Args: + x: a Tensor with shape [batch, heads, length, length] + + Returns: + a Tensor with shape [batch, heads, length, length] + """ + batch, heads, length, _ = common_layers.shape_list(x) + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]]) + x = tf.reshape(x, [batch, heads, 1 + length, length]) + x = tf.slice(x, [0, 0, 1, 0], [-1, -1, -1, -1]) + return x + + +def _absolute_position_to_relative_position_masked(x): + """Helper to dot_product_self_attention_relative_v2. + + Rearrange an attention logits or weights Tensor. + + The dimensions of the input represent: + [batch, heads, query_position, memory_position] + + The dimensions of the output represent: + [batch, heads, query_position, memory_position - query_position + length - 1] + + Only works with masked_attention. Undefined behavior for regions of the + input where memory_position > query_position. + + Args: + x: a Tensor with shape [batch, heads, length, length] + + Returns: + a Tensor with shape [batch, heads, length, length] + """ + batch, heads, length, _ = common_layers.shape_list(x) + x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]]) + x = tf.reshape(x, [batch, heads, length, length + 1]) + x = tf.slice(x, [0, 0, 0, 1], [batch, heads, length, length]) + return x + + +def dot_product_self_attention_relative_v2(q, + k, + v, + bias, + max_length=None, + dropout_rate=0.0, + image_shapes=None, + name=None, + make_image_summary=True, + dropout_broadcast_dims=None): + """Calculate relative position-aware dot-product self-attention. + + Only works for masked self-attention (no looking forward). + TODO(noam): extend to unmasked self-attention + + The attention calculation is augmented with learned representations for the + relative position between each element in q and each element in k and v. + + Args: + q: a Tensor with shape [batch, heads, length, depth]. + k: a Tensor with shape [batch, heads, length, depth]. + v: a Tensor with shape [batch, heads, length, depth]. + bias: bias Tensor. + max_length: an integer - changing this invalidates checkpoints + dropout_rate: a floating point number. + image_shapes: optional tuple of integer scalars. + name: an optional string. + make_image_summary: Whether to make an attention image summary. + dropout_broadcast_dims: an optional list of integers less than 4 + specifying in which dimensions to broadcast the dropout decisions. + saves memory. + + Returns: + A Tensor. + """ + with tf.variable_scope( + name, default_name="dot_product_self_attention_relative_v2", + values=[q, k, v]): + + # This calculation only works for self attention. + # q, k and v must therefore have the same shape. + q.get_shape().assert_is_compatible_with(k.get_shape()) + q.get_shape().assert_is_compatible_with(v.get_shape()) + + # Use separate embeddings suitable for keys and values. + length = common_layers.shape_list(q)[2] + assert max_length is not None + + # [batch, num_heads, query_length, memory_length] + logits = tf.matmul(q, k, transpose_b=True) + + # now add relative logits + # [batch, num_heads, query_length, max_length] + rel_logits = common_layers.dense(q, max_length, name="rel0") + # [batch, num_heads, query_length, max_length] + rel_logits = tf.slice( + rel_logits, [0, 0, 0, max_length - length], [-1, -1, -1, -1]) + rel_logits = _relative_position_to_absolute_position_masked(rel_logits) + logits += rel_logits + + if bias is not None: + logits += bias + weights = tf.nn.softmax(logits, name="attention_weights") + # dropping out the attention links for each of the heads + weights = common_layers.dropout_with_broadcast_dims( + weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) + if expert_utils.should_generate_summaries() and make_image_summary: + attention_image_summary(weights, image_shapes) + ret = tf.matmul(weights, v) + # [batch, num_heads, query_length, memory_length] + relative_weights = _absolute_position_to_relative_position_masked(weights) + # [batch, num_heads, query_length, memory_length] + relative_weights = tf.pad( + relative_weights, [[0, 0], [0, 0], [0, 0], [max_length - length, 0]]) + relative_weights.set_shape([None, None, None, max_length]) + depth_v = common_layers.shape_list(v)[3] + ret += common_layers.dense(relative_weights, depth_v, name="rel1") + return ret + + def masked_within_block_local_attention_1d(q, k, v, block_length=64, name=None): """Attention to the source and a neighborhood to the left within a block. @@ -2445,6 +2578,7 @@ def multihead_attention(query_antecedent, save_weights_to=None, make_image_summary=True, dropout_broadcast_dims=None, + max_length=None, **kwargs): """Multihead scaled-dot-product attention with input/output transformations. @@ -2492,6 +2626,7 @@ def multihead_attention(query_antecedent, dropout_broadcast_dims: an optional list of integers less than 4 specifying in which dimensions to broadcast the dropout decisions. saves memory. + max_length: an integer - needed by relative attention **kwargs (dict): Parameters for the attention function Caching: @@ -2562,6 +2697,11 @@ def multihead_attention(query_antecedent, x = dot_product_attention_relative(q, k, v, bias, max_relative_position, dropout_rate, image_shapes, make_image_summary=make_image_summary) + elif attention_type == "dot_product_relative_v2": + x = dot_product_self_attention_relative_v2( + q, k, v, bias, max_length, dropout_rate, image_shapes, + make_image_summary=make_image_summary, + dropout_broadcast_dims=dropout_broadcast_dims) elif attention_type == "local_within_block_mask_right": x = masked_within_block_local_attention_1d(q, k, v, block_length=block_length) diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index 61ff5a6d5..e3c24f5b4 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for common attention.""" from __future__ import absolute_import diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index 230579888..509569b0f 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Hyperparameters and ranges common to multiple models.""" from __future__ import absolute_import diff --git a/tensor2tensor/layers/common_image_attention.py b/tensor2tensor/layers/common_image_attention.py index f60fa3711..30959398f 100644 --- a/tensor2tensor/layers/common_image_attention.py +++ b/tensor2tensor/layers/common_image_attention.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utils for attention mechanism for images.""" # Dependency imports @@ -543,7 +542,8 @@ def prepare_decoder(targets, hparams): # during training, images are [batch, IMG_LEN, IMG_LEN, 3]. # At inference, they are [batch, curr_infer_length, 1, 1] - if hparams.mode == tf.contrib.learn.ModeKeys.INFER: + if (hparams.mode == tf.contrib.learn.ModeKeys.INFER and + hparams.block_raster_scan): curr_infer_length = targets_shape[1] if hparams.block_raster_scan: assert hparams.img_len*channels % hparams.query_shape[1] == 0 diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index ca5f3efc8..1a7c2500b 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -12,20 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Layers common to multiple models.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from collections import defaultdict + import contextlib import functools +from functools import partial import math import random # Dependency imports + import numpy as np from six.moves import range # pylint: disable=redefined-builtin from tensor2tensor.utils import expert_utils as eu @@ -1384,18 +1386,49 @@ def maybe_zero_out_padding(inputs, kernel_size, nonpadding_mask): return inputs -def dense_relu_dense(inputs, filter_size, output_size, dropout=0.0, - dropout_broadcast_dims=None): +def dense_relu_dense(inputs, + filter_size, + output_size, + output_activation=None, + dropout=0.0, + dropout_broadcast_dims=None, + name=None): """Hidden layer with RELU activation followed by linear projection.""" + layer_name = "%s_{}" % name if name else "{}" h = dense( - inputs, filter_size, use_bias=True, activation=tf.nn.relu, name="conv1") + inputs, + filter_size, + use_bias=True, + activation=tf.nn.relu, + name=layer_name.format("conv1")) + if dropout != 0.0: h = dropout_with_broadcast_dims( h, 1.0 - dropout, broadcast_dims=dropout_broadcast_dims) - o = dense(h, output_size, use_bias=True, name="conv2") + o = dense( + h, + output_size, + activation=output_activation, + use_bias=True, + name=layer_name.format("conv2")) return o +def dense_dropconnect(inputs, + output_size, + dropconnect_dropout=0.0, + name="dense_dropconnect", + **kwargs): + """Dense layer with dropconnect.""" + + if dropconnect_dropout != 0.0: + tf.logging.info("Applying dropconnect as the kernel regularization.") + kwargs["kernel_regularizer"] = partial( + tf.nn.dropout, keep_prob=1.0 - dropconnect_dropout) + + return dense(inputs, output_size, use_bias=True, name=name, **kwargs) + + def conv_relu_conv(inputs, filter_size, output_size, @@ -2765,10 +2798,18 @@ def dense(x, units, **kwargs): def mix(x1, x2, steps, is_training, - min_prob=0.0, max_prob=1.0, mode="lin", simple=False): + min_prob=0.0, max_prob=1.0, + mode="lin", simple=False, broadcast_last=False): """Mix starting with x2, mixing mixing, going towards x1.""" if not is_training: - return x1 + if max_prob >= 1.0: + return x1 + alpha_shape = shape_list(x1) + if broadcast_last: + alpha_shape = alpha_shape[:-1] + [1] + alpha = tf.random_uniform(alpha_shape) + alpha = tf.to_float(tf.less(alpha, max_prob)) + return alpha * x1 + (1.0 - alpha) * x2 def get_res(): """Create the result. Separate function to speed it up later (see below).""" @@ -2779,7 +2820,10 @@ def get_res(): alpha_p = alpha_p * (max_prob - min_prob) + min_prob if simple: return alpha_p * x1 + (1.0 - alpha_p) * x2 - alpha = tf.random_uniform(shape_list(x1)) + alpha_shape = shape_list(x1) + if broadcast_last: + alpha_shape = alpha_shape[:-1] + [1] + alpha = tf.random_uniform(alpha_shape) alpha = tf.to_float(tf.less(alpha, alpha_p)) return alpha * x1 + (1.0 - alpha) * x2 @@ -2807,3 +2851,32 @@ def belu(x): y1 = tf.nn.elu(x1) y2 = -tf.nn.elu(-x2) return tf.reshape(tf.concat([y1, y2], axis=-1), x_shape) + + +def argmax_with_score(logits, axis=None): + """Argmax along with the value.""" + axis = axis or len(logits.get_shape()) - 1 + predictions = tf.argmax(logits, axis=axis) + + logits_shape = shape_list(logits) + prefix_shape, vocab_size = logits_shape[:-1], logits_shape[-1] + prefix_size = 1 + for d in prefix_shape: + prefix_size *= d + + # Flatten to extract scores + flat_logits = tf.reshape(logits, [prefix_size, vocab_size]) + flat_predictions = tf.reshape(predictions, [prefix_size]) + flat_indices = tf.stack( + [tf.range(tf.to_int64(prefix_size)), + tf.to_int64(flat_predictions)], axis=1) + flat_scores = tf.gather_nd(flat_logits, flat_indices) + + # Unflatten + scores = tf.reshape(flat_scores, prefix_shape) + + return predictions, scores + + +def log_prob_from_logits(logits, reduce_axis=-1): + return logits - tf.reduce_logsumexp(logits, axis=reduce_axis, keep_dims=True) diff --git a/tensor2tensor/layers/common_layers_test.py b/tensor2tensor/layers/common_layers_test.py index 2cc36d42d..c715f2fa1 100644 --- a/tensor2tensor/layers/common_layers_test.py +++ b/tensor2tensor/layers/common_layers_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for common layers.""" from __future__ import absolute_import diff --git a/tensor2tensor/layers/discretization.py b/tensor2tensor/layers/discretization.py index f0fc57391..f29024be3 100644 --- a/tensor2tensor/layers/discretization.py +++ b/tensor2tensor/layers/discretization.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Discretization bottlenecks used to train discrete latent variables.""" from __future__ import absolute_import @@ -683,9 +682,9 @@ def isemhash_bottleneck(x, bottleneck_size, bottleneck_noise, noise = tf.random_uniform(common_layers.shape_list(x)) noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0 d *= noise - d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps, - mode == tf.estimator.ModeKeys.TRAIN, - max_prob=isemhash_mix_prob) + d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps, + mode == tf.estimator.ModeKeys.TRAIN, + max_prob=isemhash_mix_prob) return d diff --git a/tensor2tensor/layers/discretization_test.py b/tensor2tensor/layers/discretization_test.py index 74eb3d6fb..25576e09d 100644 --- a/tensor2tensor/layers/discretization_test.py +++ b/tensor2tensor/layers/discretization_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.layers.discretization.""" from __future__ import absolute_import diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 5dba2c6f9..6f3dd8f49 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Modalities define the bottom and top of the model (not the body).""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/layers/modalities_test.py b/tensor2tensor/layers/modalities_test.py index 1c305ed17..949e0b817 100644 --- a/tensor2tensor/layers/modalities_test.py +++ b/tensor2tensor/layers/modalities_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Modalities.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index a6e462f7b..964b294d0 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Reversible Residual Block. From diff --git a/tensor2tensor/layers/rev_block_test.py b/tensor2tensor/layers/rev_block_test.py index 62b167e72..6c3a10be7 100644 --- a/tensor2tensor/layers/rev_block_test.py +++ b/tensor2tensor/layers/rev_block_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for RevBlock.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index 76c51f581..883193316 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Models defined in T2T. Imports here force registration.""" from __future__ import absolute_import from __future__ import division @@ -48,6 +47,7 @@ from tensor2tensor.models.research import gene_expression from tensor2tensor.models.research import lm_experiments from tensor2tensor.models.research import multimodel +from tensor2tensor.models.research import r_transformer from tensor2tensor.models.research import rl from tensor2tensor.models.research import super_lm from tensor2tensor.models.research import transformer_moe diff --git a/tensor2tensor/models/basic.py b/tensor2tensor/models/basic.py index d6fdc6101..1541d8cfb 100644 --- a/tensor2tensor/models/basic.py +++ b/tensor2tensor/models/basic.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Basic models for testing simple tasks.""" from __future__ import absolute_import @@ -33,7 +32,7 @@ class BasicFcRelu(t2t_model.T2TModel): def body(self, features): - hparams = self._hparams + hparams = self.hparams x = features["inputs"] shape = common_layers.shape_list(x) x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]]) @@ -54,7 +53,7 @@ def __init__(self, *args, **kwargs): def bottleneck(self, x): with tf.variable_scope("bottleneck"): - hparams = self._hparams + hparams = self.hparams x = tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck") if hparams.mode == tf.estimator.ModeKeys.TRAIN: noise = 2.0 * tf.random_uniform(common_layers.shape_list(x)) - 1.0 @@ -69,12 +68,27 @@ def unbottleneck(self, x, res_size): def bottleneck_loss(self, b): return 0.0 + def make_even_size(self, x): + shape = [dim if dim is not None else -1 for dim in x.get_shape().as_list()] + if shape[1] % 2 == 0 and shape[2] % 2 == 0: + return x + if shape[1] % 2 == 0 and self.is1d: + return x + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=1) + if self.is1d: + return x + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=2) + return x + def encoder(self, x): with tf.variable_scope("encoder"): - hparams = self._hparams + hparams = self.hparams kernel, strides = self._get_kernel_and_strides() # Down-convolutions. for i in range(hparams.num_hidden_layers): + x = self.make_even_size(x) x = tf.layers.conv2d( x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides, padding="SAME", activation=common_layers.belu, name="conv_%d" % i) @@ -83,7 +97,7 @@ def encoder(self, x): def decoder(self, x): with tf.variable_scope("decoder"): - hparams = self._hparams + hparams = self.hparams kernel, strides = self._get_kernel_and_strides() # Up-convolutions. for i in range(hparams.num_hidden_layers): @@ -95,19 +109,13 @@ def decoder(self, x): return x def body(self, features): - hparams = self._hparams + hparams = self.hparams is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN if hparams.mode != tf.estimator.ModeKeys.PREDICT: x = features["targets"] shape = common_layers.shape_list(x) is1d = shape[2] == 1 self.is1d = is1d - x, _ = common_layers.pad_to_same_length( - x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=1) - if not is1d: - x, _ = common_layers.pad_to_same_length( - x, x, final_length_divisible_by=2**hparams.num_hidden_layers, - axis=2) # Run encoder. x = self.encoder(x) # Bottleneck (mix during early training, not too important but stable). @@ -123,13 +131,13 @@ def body(self, features): x = b else: b = self.sample() - res_size = self._hparams.hidden_size * 2**self._hparams.num_hidden_layers + res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers res_size = min(res_size, hparams.max_hidden_size) x = self.unbottleneck(b, res_size) # Run decoder. x = self.decoder(x) if hparams.mode == tf.estimator.ModeKeys.PREDICT: - return x + return x, {"bottleneck_loss": 0.0} # Cut to the right size and mix before returning. res = x[:, :shape[1], :shape[2], :] res = common_layers.mix(res, features["targets"], @@ -137,7 +145,7 @@ def body(self, features): return res, {"bottleneck_loss": b_loss} def sample(self): - hp = self._hparams + hp = self.hparams div_x = 2**hp.num_hidden_layers div_y = 1 if self.is1d else 2**hp.num_hidden_layers size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y, @@ -159,11 +167,11 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, # Sample and decode. # TODO(lukaszkaiser): is this a universal enough way to get channels? try: - num_channels = self._hparams.problem.num_channels + num_channels = self.hparams.problem.num_channels except AttributeError: num_channels = 1 features["targets"] = tf.zeros( - [self._hparams.batch_size, 1, 1, num_channels], + [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32) logits, _ = self(features) # pylint: disable=not-callable samples = tf.argmax(logits, axis=-1) @@ -176,7 +184,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, return samples def _get_kernel_and_strides(self): - hparams = self._hparams + hparams = self.hparams kernel = (hparams.kernel_height, hparams.kernel_width) kernel = (hparams.kernel_height, 1) if self.is1d else kernel strides = (2, 1) if self.is1d else (2, 2) diff --git a/tensor2tensor/models/basic_test.py b/tensor2tensor/models/basic_test.py index 5a07a5502..1b581b718 100644 --- a/tensor2tensor/models/basic_test.py +++ b/tensor2tensor/models/basic_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Basic nets tests.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/bytenet.py b/tensor2tensor/models/bytenet.py index 74f46c27c..6f65adb03 100644 --- a/tensor2tensor/models/bytenet.py +++ b/tensor2tensor/models/bytenet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ByteNet.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/bytenet_test.py b/tensor2tensor/models/bytenet_test.py index d077a5b69..fc920487c 100644 --- a/tensor2tensor/models/bytenet_test.py +++ b/tensor2tensor/models/bytenet_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ByteNet tests.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/distillation.py b/tensor2tensor/models/distillation.py index 3f468f54c..3630cc334 100644 --- a/tensor2tensor/models/distillation.py +++ b/tensor2tensor/models/distillation.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Traditional Student-Teacher Distillation.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/image_transformer.py b/tensor2tensor/models/image_transformer.py index f0130b195..834a80d34 100644 --- a/tensor2tensor/models/image_transformer.py +++ b/tensor2tensor/models/image_transformer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """image generation with transformer (attention). encoder: [Self-Attention, Feed-forward] x n diff --git a/tensor2tensor/models/image_transformer_2d.py b/tensor2tensor/models/image_transformer_2d.py index cdcd0c654..07aa1231c 100644 --- a/tensor2tensor/models/image_transformer_2d.py +++ b/tensor2tensor/models/image_transformer_2d.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """image generation with transformer (attention). encoder: [Self-Attention, Feed-forward] x n diff --git a/tensor2tensor/models/image_transformer_2d_test.py b/tensor2tensor/models/image_transformer_2d_test.py index 4098792a4..42cb0ce58 100644 --- a/tensor2tensor/models/image_transformer_2d_test.py +++ b/tensor2tensor/models/image_transformer_2d_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Transformer.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/image_transformer_test.py b/tensor2tensor/models/image_transformer_test.py index a997a6bc5..9c9110d8e 100644 --- a/tensor2tensor/models/image_transformer_test.py +++ b/tensor2tensor/models/image_transformer_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Transformer.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py index d05c1f599..e2b23c067 100644 --- a/tensor2tensor/models/lstm.py +++ b/tensor2tensor/models/lstm.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """RNN LSTM models.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py index c392f23fd..e22760311 100644 --- a/tensor2tensor/models/lstm_test.py +++ b/tensor2tensor/models/lstm_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """LSTMSeq2Seq models tests.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/neural_gpu.py b/tensor2tensor/models/neural_gpu.py index 7d6433b92..e278d4606 100644 --- a/tensor2tensor/models/neural_gpu.py +++ b/tensor2tensor/models/neural_gpu.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """The Neural GPU model and its variants.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/neural_gpu_test.py b/tensor2tensor/models/neural_gpu_test.py index bbf9f1d4d..08fb2b18c 100644 --- a/tensor2tensor/models/neural_gpu_test.py +++ b/tensor2tensor/models/neural_gpu_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Neural GPU.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/adafactor_experiments.py b/tensor2tensor/models/research/adafactor_experiments.py index d7d3d4e2c..e9753c6c8 100644 --- a/tensor2tensor/models/research/adafactor_experiments.py +++ b/tensor2tensor/models/research/adafactor_experiments.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Experiments with Adafactor. """ diff --git a/tensor2tensor/models/research/aligned.py b/tensor2tensor/models/research/aligned.py index 5f19d4db1..ed048a68c 100644 --- a/tensor2tensor/models/research/aligned.py +++ b/tensor2tensor/models/research/aligned.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Single stack of transformations with no masking. Produces output aligned with inputs. diff --git a/tensor2tensor/models/research/attention_lm.py b/tensor2tensor/models/research/attention_lm.py index bf7315f07..1d2d2acb0 100644 --- a/tensor2tensor/models/research/attention_lm.py +++ b/tensor2tensor/models/research/attention_lm.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Self-attention based language model. Like transformer.py, but no encoder diff --git a/tensor2tensor/models/research/attention_lm_moe.py b/tensor2tensor/models/research/attention_lm_moe.py index 14b633495..6fd549cbe 100644 --- a/tensor2tensor/models/research/attention_lm_moe.py +++ b/tensor2tensor/models/research/attention_lm_moe.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Self-attention based language model. Like transformer.py, but no encoder diff --git a/tensor2tensor/models/research/autoencoders.py b/tensor2tensor/models/research/autoencoders.py index d9c852742..4a05af024 100644 --- a/tensor2tensor/models/research/autoencoders.py +++ b/tensor2tensor/models/research/autoencoders.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Autoencoders.""" from __future__ import absolute_import @@ -34,17 +33,29 @@ class AutoencoderAutoregressive(basic.BasicAutoencoder): """Autoencoder with an autoregressive part.""" def body(self, features): - hparams = self._hparams - shape = common_layers.shape_list(features["targets"]) + hparams = self.hparams + is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN # Run the basic autoencoder part first. basic_result, losses = super(AutoencoderAutoregressive, self).body(features) + shape = common_layers.shape_list(basic_result) + basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]]) + # During autoregressive inference, don't resample. + if hparams.mode == tf.estimator.ModeKeys.PREDICT: + if hasattr(hparams, "sampled_basic1d_tensor"): + basic1d = hparams.sampled_basic1d_tensor + else: + hparams.sampled_basic1d_tensor = basic1d # Prepare inputs for autoregressive modes. - targets_keep_prob = 1.0 - hparams.autoregressive_dropout - targets_dropout = common_layers.dropout_with_broadcast_dims( - features["targets"], targets_keep_prob, broadcast_dims=[-1]) + if common_layers.shape_list(features["targets"])[1] == 1: + # This happens on the first step of predicitions. + assert hparams.mode == tf.estimator.ModeKeys.PREDICT + features["targets"] = tf.zeros_like(basic_result) + targets_dropout = common_layers.mix( + features["targets"], tf.zeros_like(basic_result), + hparams.bottleneck_warmup_steps, is_training, + max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True) targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]]) targets_shifted = common_layers.shift_right_3d(targets1d) - basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]]) concat1d = tf.concat([basic1d, targets_shifted], axis=-1) # The forget_base hparam sets purely-autoregressive mode, no autoencoder. if hparams.autoregressive_forget_base: @@ -74,6 +85,57 @@ def body(self, features): raise ValueError("Unsupported autoregressive mode: %s" % hparams.autoregressive_mode) + def infer(self, features=None, *args, **kwargs): + """Produce predictions from the model by sampling.""" + # Inputs and features preparation needed to handle edge cases. + if not features: + features = {} + inputs_old = None + if "inputs" in features and len(features["inputs"].shape) < 4: + inputs_old = features["inputs"] + features["inputs"] = tf.expand_dims(features["inputs"], 2) + + # Sample first. + try: + num_channels = self.hparams.problem.num_channels + except AttributeError: + num_channels = 1 + features["targets"] = tf.zeros( + [self.hparams.batch_size, 1, 1, num_channels], + dtype=tf.int32) + logits, _ = self(features) # pylint: disable=not-callable + samples = common_layers.sample_with_temperature( + logits, 0.0) + shape = common_layers.shape_list(samples) + + # Sample again if requested for the autoregressive part. + extra_samples = 0 + self.hparams.autoregressive_dropout = 0.2 + for i in range(extra_samples): + if i == extra_samples - 2: + self.hparams.autoregressive_dropout -= 0.1 + self.hparams.sampling_temp /= 2 + if i == extra_samples - 1: + self.hparams.autoregressive_dropout -= 0.1 + self.hparams.sampling_temp = 0.0 + features["targets"] = samples + old_samples1d = tf.reshape(samples, [shape[0], -1, shape[3]]) + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + logits, _ = self(features) # pylint: disable=not-callable + samples = common_layers.sample_with_temperature( + logits, self.hparams.sampling_temp) + samples1d = tf.reshape(samples, [shape[0], -1, shape[3]]) + samples1d = tf.concat([old_samples1d[:, :i, :], samples1d[:, i:, :]], + axis=1) + samples = tf.reshape(samples1d, shape) + + # Restore inputs to not confuse Estimator in edge cases. + if inputs_old is not None: + features["inputs"] = inputs_old + + # Return samples. + return samples + @registry.register_model class AutoencoderResidual(AutoencoderAutoregressive): @@ -81,7 +143,7 @@ class AutoencoderResidual(AutoencoderAutoregressive): def encoder(self, x): with tf.variable_scope("encoder"): - hparams = self._hparams + hparams = self.hparams kernel, strides = self._get_kernel_and_strides() residual_kernel = (hparams.residual_kernel_height, hparams.residual_kernel_width) @@ -93,6 +155,7 @@ def encoder(self, x): # Down-convolutions. for i in range(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % i): + x = self.make_even_size(x) x = tf.nn.dropout(x, 1.0 - hparams.dropout) filters = hparams.hidden_size * 2**(i + 1) filters = min(filters, hparams.max_hidden_size) @@ -115,7 +178,7 @@ def encoder(self, x): def decoder(self, x): with tf.variable_scope("decoder"): - hparams = self._hparams + hparams = self.hparams kernel, strides = self._get_kernel_and_strides() residual_kernel = (hparams.residual_kernel_height, hparams.residual_kernel_width) @@ -156,7 +219,7 @@ class AutoencoderBasicDiscrete(AutoencoderAutoregressive): """Discrete autoencoder.""" def bottleneck(self, x): - hparams = self._hparams + hparams = self.hparams x = tf.tanh(tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck")) d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x) if hparams.mode == tf.estimator.ModeKeys.TRAIN: @@ -168,7 +231,7 @@ def bottleneck(self, x): return x def sample(self): - hp = self._hparams + hp = self.hparams div_x = 2**hp.num_hidden_layers div_y = 1 if self.is1d else 2**hp.num_hidden_layers size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y, @@ -183,24 +246,25 @@ class AutoencoderResidualDiscrete(AutoencoderResidual): def bottleneck(self, x, bottleneck_size=None): if bottleneck_size is not None: - old_bottleneck_size = self._hparams.bottleneck_size - self._hparams.bottleneck_size = bottleneck_size - res = discretization.parametrized_bottleneck(x, self._hparams) + old_bottleneck_size = self.hparams.bottleneck_size + self.hparams.bottleneck_size = bottleneck_size + res = discretization.parametrized_bottleneck(x, self.hparams) if bottleneck_size is not None: - self._hparams.bottleneck_size = old_bottleneck_size + self.hparams.bottleneck_size = old_bottleneck_size return res def unbottleneck(self, x, res_size): - return discretization.parametrized_unbottleneck(x, res_size, self._hparams) + return discretization.parametrized_unbottleneck(x, res_size, self.hparams) def bottleneck_loss(self, b): part = tf.random_uniform(common_layers.shape_list(b)) selection = tf.to_float(tf.less(part, tf.random_uniform([]))) - part_avg = tf.abs(tf.reduce_sum(b * selection)) / tf.reduce_sum(selection) + selection_size = tf.reduce_sum(selection) + part_avg = tf.abs(tf.reduce_sum(b * selection)) / (selection_size + 1) return part_avg def sample(self): - hp = self._hparams + hp = self.hparams div_x = 2**hp.num_hidden_layers div_y = 1 if self.is1d else 2**hp.num_hidden_layers size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y, @@ -209,7 +273,8 @@ def sample(self): res = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0 # If you want to set some first bits to a fixed value, do this: # fixed = tf.zeros_like(rand) - 1.0 - # res = tf.concat([fixed[:, :, :, :2], res[:, :, :, 2:]], axis=-1) + # nbits = 3 + # res = tf.concat([fixed[:, :, :, :nbits], res[:, :, :, nbits:]], axis=-1) return res @@ -218,27 +283,23 @@ class AutoencoderOrderedDiscrete(AutoencoderResidualDiscrete): """Ordered discrete autoencoder.""" def bottleneck(self, x): - hparams = self._hparams + hparams = self.hparams + noise = hparams.bottleneck_noise + hparams.bottleneck_noise = 0.0 # We'll add noise below. x = discretization.parametrized_bottleneck(x, hparams) + hparams.bottleneck_noise = noise if hparams.mode == tf.estimator.ModeKeys.TRAIN: - # In the ordered case, we'll have no noise on top bits, let's make a mask. - # Start with randomly uniformly choosing numbers [0, number_of_bits) where - # the number of bits in our case is bottleneck size. We pick separately - # for every position and batch just to keep it varied. - no_noise_mask = tf.random_uniform(common_layers.shape_list(x)[:-1]) - no_noise_mask *= hparams.bottleneck_size - # Now let's make a 1-hot vector that is 1 on the index i from which on - # we want to be noisy and 0 everywhere else. - no_noise_mask = tf.one_hot(tf.to_int32(no_noise_mask), - hparams.bottleneck_size) - # Use tf.cumsum to make the mask (0 before index i, 1 after index i). - no_noise_mask = tf.cumsum(no_noise_mask, axis=-1) + # We want a number p such that p^bottleneck_size = 1 - noise. + # So log(p) * bottleneck_size = log(noise) + log_p = tf.log(1 - float(noise) / 2) / float(hparams.bottleneck_size) + # Probabilities of flipping are p, p^2, p^3, ..., p^bottleneck_size. + noise_mask = 1.0 - tf.exp(tf.cumsum(tf.zeros_like(x) + log_p, axis=-1)) # Having the no-noise mask, we can make noise just uniformly at random. - ordered_noise = tf.random_uniform(tf.shape(x)) * no_noise_mask + ordered_noise = tf.random_uniform(tf.shape(x)) # We want our noise to be 1s at the start and random {-1, 1} bits later. - ordered_noise = 2.0 * tf.to_float(tf.less(ordered_noise, 0.5)) - 1.0 + ordered_noise = tf.to_float(tf.less(noise_mask, ordered_noise)) # Now we flip the bits of x on the noisy positions (ordered and normal). - x *= ordered_noise + x *= 2.0 * ordered_noise - 1 return x @@ -289,7 +350,7 @@ def full_stack(self, b, x_size, bottleneck_size, losses, is_training, i): return tf.reshape(b1, b_shape) def body(self, features): - hparams = self._hparams + hparams = self.hparams num_stacks = hparams.num_hidden_layers hparams.num_hidden_layers = 1 is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN @@ -322,7 +383,7 @@ def body(self, features): x = b else: b = self.sample() - res_size = self._hparams.hidden_size * 2**self._hparams.num_hidden_layers + res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers res_size = min(res_size, hparams.max_hidden_size) x = self.unbottleneck(b, res_size) # Run decoder. diff --git a/tensor2tensor/models/research/autoencoders_test.py b/tensor2tensor/models/research/autoencoders_test.py index 9cdcd139a..23c8108e9 100644 --- a/tensor2tensor/models/research/autoencoders_test.py +++ b/tensor2tensor/models/research/autoencoders_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Autoencoders tests.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/basic_conv_gen.py b/tensor2tensor/models/research/basic_conv_gen.py index f35509237..9d4a810bb 100644 --- a/tensor2tensor/models/research/basic_conv_gen.py +++ b/tensor2tensor/models/research/basic_conv_gen.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Basic models for testing simple tasks.""" from __future__ import absolute_import @@ -91,8 +90,7 @@ def body(self, features): reward_pred = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) return {"targets": x, "target_reward": reward_pred} - def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, - alpha=0.0): + def infer(self, features=None, *args, **kwargs): """Produce predictions from the model by running it.""" # Inputs and features preparation needed to handle edge cases. if not features: @@ -144,11 +142,11 @@ def basic_conv(): hparams.learning_rate_constant = 0.0002 hparams.learning_rate_warmup_steps = 500 hparams.learning_rate_schedule = "constant * linear_warmup" - hparams.label_smoothing = 0.05 + hparams.label_smoothing = 0.0 hparams.initializer = "uniform_unit_scaling" hparams.initializer_gain = 1.0 hparams.weight_decay = 0.0 - hparams.dropout = 0.1 + hparams.dropout = 0.2 hparams.add_hparam("num_compress_steps", 5) return hparams diff --git a/tensor2tensor/models/research/cycle_gan.py b/tensor2tensor/models/research/cycle_gan.py index 55cbe350d..bba12768c 100644 --- a/tensor2tensor/models/research/cycle_gan.py +++ b/tensor2tensor/models/research/cycle_gan.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Cycle GAN.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/gene_expression.py b/tensor2tensor/models/research/gene_expression.py index abe0a4834..97134184c 100644 --- a/tensor2tensor/models/research/gene_expression.py +++ b/tensor2tensor/models/research/gene_expression.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Models for gene expression from DNA.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/models/research/gene_expression_test.py b/tensor2tensor/models/research/gene_expression_test.py index 70403935c..06779c978 100644 --- a/tensor2tensor/models/research/gene_expression_test.py +++ b/tensor2tensor/models/research/gene_expression_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Gene Expression models.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/models/research/lm_experiments.py b/tensor2tensor/models/research/lm_experiments.py index d33206d59..10f2b943a 100644 --- a/tensor2tensor/models/research/lm_experiments.py +++ b/tensor2tensor/models/research/lm_experiments.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Experiments with Language Models. Train languagemodel_lm1b32k_packed and measure log-ppl/token (dev). @@ -98,3 +97,21 @@ def lmx_h4k_f16k(): hparams.batch_size = 1024 hparams.weight_dtype = "bfloat16" return hparams + + +@registry.register_hparams +def lmx_relative(): + """Language model using relative attention.""" + hparams = lmx_base() + hparams.self_attention_type = "dot_product_relative_v2" + hparams.activation_dtype = "float32" + hparams.weight_dtype = "float32" + return hparams + + +@registry.register_hparams +def lmx_relative_nopos(): + """Language model using relative attention and no positional encoding.""" + hparams = lmx_relative() + hparams.pos = "none" + return hparams diff --git a/tensor2tensor/models/research/multimodel.py b/tensor2tensor/models/research/multimodel.py index 4b3d93445..ccb62bae2 100644 --- a/tensor2tensor/models/research/multimodel.py +++ b/tensor2tensor/models/research/multimodel.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MultiModel.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/models/research/multimodel_test.py b/tensor2tensor/models/research/multimodel_test.py index c480d23e1..64a510dab 100644 --- a/tensor2tensor/models/research/multimodel_test.py +++ b/tensor2tensor/models/research/multimodel_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Xnet.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/r_transformer.py b/tensor2tensor/models/research/r_transformer.py new file mode 100644 index 000000000..9af75e8fc --- /dev/null +++ b/tensor2tensor/models/research/r_transformer.py @@ -0,0 +1,657 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers with depthwise recurrency (go/r-transformer). + + +A high-level explanation on the idea and the architecture: + +The vanilla Transformer model has no recurrence and struggles with some tasks +that a fully recurrent model can easily solve. Instead of incorporating +recurrence in time (which has a dependency on sequence length T), +we apply recurrence in depth (which we can set to some fixed length D << T), +and apply self-attention instead of sequential processing to enable the model +to incorporate long-range dependencies. + +Structure of the code is explained in r_transformer_util.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.models.research import r_transformer_util +from tensor2tensor.utils import registry + +import tensorflow as tf + + +@registry.register_model +class RTransformer(transformer.Transformer): + """R-Transformer: Depth-wise recurrent transoformer model.""" + + def encode(self, inputs, target_space, hparams, features=None): + """Encode r-transformer inputs. + + It is similar to "transformer.encode", but it uses + "r_transformer_util.r_transformer_encoder" instead of + "transformer.transformer_encoder". + + Args: + inputs: Transformer inputs [batch_size, input_length, input_height, + hidden_dim] which will be flattened along the two spatial dimensions. + target_space: scalar, target space ID. + hparams: hyperparmeters for model. + features: optionally pass the entire features dictionary as well. + This is needed now for "packed" datasets. + + Returns: + Tuple of: + encoder_output: Encoder representation. + [batch_size, input_length, hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for + encoder-decoder attention. [batch_size, input_length] + encoder_extra_output: which is extra encoder output used in some + variants of the model (e.g. in ACT, to pass the ponder-time to body) + """ + + inputs = common_layers.flatten4d3d(inputs) + + encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( + transformer.transformer_prepare_encoder( + inputs, target_space, hparams, features=features)) + + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + (encoder_output, + encoder_extra_output) = r_transformer_util.r_transformer_encoder( + encoder_input, + self_attention_bias, + hparams, + nonpadding=transformer.features_to_nonpadding(features, "inputs"), + save_weights_to=self.attention_weights) + + return encoder_output, encoder_decoder_attention_bias, encoder_extra_output + + def decode(self, + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + nonpadding=None): + """Decode R-Transformer outputs from encoder representation. + + It is similar to "transformer.decode", but it uses + "r_transformer_util.r_transformer_decoder" instead of + "transformer.transformer_decoder". + + Args: + decoder_input: inputs to bottom of the model. [batch_size, decoder_length, + hidden_dim] + encoder_output: Encoder representation. [batch_size, input_length, + hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder + attention. [batch_size, input_length] + decoder_self_attention_bias: Bias and mask weights for decoder + self-attention. [batch_size, decoder_length] + hparams: hyperparmeters for model. + nonpadding: optional Tensor with shape [batch_size, decoder_length] + + Returns: + Tuple of: + Final decoder representation. [batch_size, decoder_length, + hidden_dim] + encoder_extra_output: which is extra encoder output used in some + variants of the model (e.g. in ACT, to pass the ponder-time to body) + + """ + + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + # No caching in r-transformers! + decoder_output, dec_extra_output = r_transformer_util.r_transformer_decoder( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + nonpadding=nonpadding, + save_weights_to=self.attention_weights) + + # Expand since t2t expects 4d tensors. + return tf.expand_dims(decoder_output, axis=2), dec_extra_output + + def body(self, features): + """R-Transformer main model_fn. + + + Args: + features: Map of features to the model. Should contain the following: + "inputs": Transformer inputs [batch_size, input_length, hidden_dim] + "targets": Target decoder outputs. + [batch_size, decoder_length, hidden_dim] + "target_space_id" + + Returns: + Final decoder representation. [batch_size, decoder_length, hidden_dim] + """ + hparams = self._hparams + + if self.has_input: + inputs = features["inputs"] + target_space = features["target_space_id"] + (encoder_output, encoder_decoder_attention_bias, + enc_extra_output) = self.encode( + inputs, target_space, hparams, features=features) + else: + (encoder_output, encoder_decoder_attention_bias, + enc_extra_output) = (None, None, (None, None)) + + targets = features["targets"] + targets = common_layers.flatten4d3d(targets) + + (decoder_input, + decoder_self_attention_bias) = transformer.transformer_prepare_decoder( + targets, hparams, features=features) + + decoder_output, dec_extra_output = self.decode( + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + nonpadding=transformer.features_to_nonpadding(features, "targets")) + + expected_attentions = features.get("expected_attentions") + if expected_attentions is not None: + attention_loss = common_attention.encoder_decoder_attention_loss( + expected_attentions, self.attention_weights, + hparams.expected_attention_loss_type, + hparams.expected_attention_loss_multiplier) + return decoder_output, {"attention_loss": attention_loss} + + if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0: + if self.has_input: + enc_ponder_times, enc_remainders = enc_extra_output + enc_act_loss = ( + hparams.act_loss_weight * + tf.reduce_mean(enc_ponder_times + enc_remainders)) + else: + enc_act_loss = 0.0 + + (dec_ponder_times, dec_remainders) = dec_extra_output + dec_act_loss = ( + hparams.act_loss_weight * + tf.reduce_mean(dec_ponder_times + dec_remainders)) + act_loss = enc_act_loss + dec_act_loss + tf.summary.scalar("act_loss", act_loss) + return decoder_output, {"act_loss": act_loss} + + return decoder_output + + +@registry.register_model +class RTransformerEncoder(transformer.Transformer): + """R-Transformer Encoder: Depth-wise recurrent transoformer encoder-only.""" + + def encode(self, inputs, target_space, hparams, features=None): + """Encode transformer inputs. + + Args: + inputs: Transformer inputs [batch_size, input_length, input_height, + hidden_dim] which will be flattened along the two spatial dimensions. + target_space: scalar, target space ID. + hparams: hyperparmeters for model. + features: optionally pass the entire features dictionary as well. + This is needed now for "packed" datasets. + + Returns: + Tuple of: + encoder_output: Encoder representation. + [batch_size, input_length, hidden_dim] + encoder_extra_output: which is extra encoder output used in some + variants of the model (e.g. in ACT, to pass the ponder-time to body) + """ + inputs = common_layers.flatten4d3d(inputs) + + (encoder_input, self_attention_bias, _) = ( + transformer.transformer_prepare_encoder(inputs, target_space, hparams)) + + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + (encoder_output, + encoder_extra_output) = r_transformer_util.r_transformer_encoder( + encoder_input, + self_attention_bias, + hparams, + nonpadding=transformer.features_to_nonpadding(features, "inputs"), + save_weights_to=self.attention_weights) + + return encoder_output, encoder_extra_output + + def body(self, features): + """R-Transformer main model_fn. + + Args: + features: Map of features to the model. Should contain the following: + "inputs": Transformer inputs [batch_size, input_length, hidden_dim] + "targets": Target decoder outputs. + [batch_size, decoder_length, hidden_dim] + "target_space_id" + + Returns: + Final decoder representation. [batch_size, decoder_length, hidden_dim] + """ + hparams = self._hparams + + assert self.has_input, ("r_transformer_encoder is applicable on problems" + "with inputs") + + inputs = features["inputs"] + target_space = features["target_space_id"] + encoder_output, enc_extra_output = self.encode( + inputs, target_space, hparams, features=features) + + encoder_output = tf.expand_dims(encoder_output, 2) + + if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0: + ponder_times, remainders = enc_extra_output + act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times + + remainders) + tf.summary.scalar("act_loss", act_loss) + + return encoder_output, {"act_loss": act_loss} + return encoder_output + + +def update_hparams_for_r_transformer(hparams): + """Adds deault hparams for all of the variants of the R-transformer. + + Args: + hparams: default hparams (usually one of the standard hparams from + transformer model (like "transformer_base") + + Returns: + hparams with default values for R-Transformers hyper-parameters + + """ + # Type of recurrency: + # None(no-recurrency) basic, highway, skip, dwa, act, rnn, gru, lstm. + hparams.add_hparam("recurrence_type", "basic") + + # Number of steps (which is equivalent to num layer in transformer). + hparams.add_hparam("num_rec_steps", hparams.num_hidden_layers) + + # Default ffn layer is separable convolution. + hparams.add_hparam("transformer_ffn_type", "sep") + + # Transform bias (in models with highway or skip connection). + hparams.add_hparam("transform_bias_init", -1.0) + hparams.add_hparam("couple_carry_transform_gates", True) + + # Depth-wise attention (grid-transformer!) hparams: + # Adds depth embedding, if true. + hparams.add_hparam("depth_embedding", True) + # Learns attention weights for elements (instead of positions), if true. + hparams.add_hparam("dwa_elements", True) + + # Type of ffn_layer used for gate in skip, highway, etc. + # "dense" or "dense_dropconnect". + # With dense_relu_dense, the bias/kernel initializations will not be applied. + hparams.add_hparam("gate_ffn_layer", "dense") + + # Config for all rnn style recurrencies (rnn, lstm, gru): + # Input of the gate functions: i:input/s:state/t:transformed state. + # or any combination: e.g. is, ts, ist, etc. + hparams.add_hparam("gates_inputs", "i") + + # LSTEM forget bias. + hparams.add_hparam("lstm_forget_bias", 1.0) + + # How to combine state and input in each step: + # "mh_attention_ffn_add" or "add_mh_attention_ffn" or "dense_mh_attention" + # or "mh_attention_dense". + # Interpretation for e.g. "mh_attention_ffn_add": + # Apply transformer attention then transformer ffn, then add. + hparams.add_hparam("inputs_states_combination", "mh_attention_ffn_add") + + # Config for gru_style recurrency: + # What to transform in gru: state/output/candidate/combination of them. + hparams.add_hparam("gru_transformation", ["state_transformation"]) + + # Config for lstm_style Recurrency: + # What to transform in lstm: state/modulated_input/memory. + hparams.add_hparam("lstm_transformation", ["state_transformation"]) + # Uses the mememory at the last step as the final touput, if true. + hparams.add_hparam("use_memory_as_final_state", False) + + # Type of act: basic/accumulated/global (instead of position-wise!)/random. + hparams.add_hparam("act_type", "basic") + # Max number of steps (forces halting at this step). + hparams.add_hparam("act_max_steps", 2 * hparams.num_hidden_layers) + hparams.add_hparam("act_halting_bias_init", 1.0) + hparams.add_hparam("act_epsilon", 0.01) + hparams.add_hparam("act_loss_weight", 0.01) + + return hparams + + +@registry.register_hparams +def r_transformer_big(): + hparams = transformer.transformer_big() + hparams = update_hparams_for_r_transformer(hparams) + return hparams + + +@registry.register_hparams +def r_transformer_base(): + hparams = transformer.transformer_base() + hparams = update_hparams_for_r_transformer(hparams) + return hparams + + +@registry.register_hparams +def r_transformer_tiny(): + hparams = transformer.transformer_tiny() + hparams = update_hparams_for_r_transformer(hparams) + hparams.num_rec_steps = 8 + return hparams + + +@registry.register_hparams +def transformer_teeny(): + hparams = transformer.transformer_base() + hparams.num_rec_steps = 2 + hparams.hidden_size = 128 + hparams.filter_size = 128 + hparams.num_heads = 2 + return hparams + + +@registry.register_hparams +def r_transformer_teeny(): + hparams = transformer_teeny() + hparams = update_hparams_for_r_transformer(hparams) + hparams.num_rec_steps = 10 + return hparams + + +@registry.register_hparams +def r_transformer_base_dropconnect(): + hparams = r_transformer_base() + hparams.gate_ffn_layer = "dense_dropconnect" + hparams.add_hparam("dropconnect_dropout", 0.5) + return hparams + + +@registry.register_hparams +def r_transformer_act_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + return hparams + + +@registry.register_hparams +def r_transformer_act_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + return hparams + + +@registry.register_hparams +def r_transformer_act_big(): + hparams = r_transformer_big() + hparams.recurrence_type = "act" + return hparams + + +@registry.register_hparams +def r_transformer_act_random_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.act_type = "random" + return hparams + + +@registry.register_hparams +def r_transformer_act_accumulated_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.act_type = "accumulated" + return hparams + + +@registry.register_hparams +def r_transformer_act_global_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.act_type = "global" + return hparams + + +@registry.register_hparams +def r_transformer_act_accumulated_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.act_type = "accumulated" + return hparams + + +@registry.register_hparams +def r_transformer_act_global_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.act_type = "global" + return hparams + + +@registry.register_hparams +def r_transformer_act_random_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.act_type = "random" + return hparams + + +@registry.register_hparams +def r_transformer_act_base_sb(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.batch_size = 2048 + return hparams + + +@registry.register_hparams +def r_transformer_act_large(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.hidden_size = 1024 + hparams.batch_size = 2048 + hparams.filter_size = 2048 + return hparams + + +@registry.register_hparams +def r_transformer_act_tall(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.num_hidden_layers = 16 + hparams.batch_size = 1024 + hparams.act_max_steps = 24 + return hparams + + +@registry.register_hparams +def r_transformer_act_tall_actlossw0(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.num_hidden_layers = 16 + hparams.batch_size = 1024 + hparams.act_max_steps = 24 + return hparams + + +@registry.register_hparams +def r_transformer_act_tall_actlossw001(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.num_hidden_layers = 16 + hparams.batch_size = 1024 + hparams.act_max_steps = 24 + return hparams + + +@registry.register_hparams +def r_transformer_act_base_d03(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.layer_prepostprocess_dropout = 0.3 + hparams.attention_dropout = 0.3 + hparams.relu_dropout = 0.3 + return hparams + + +@registry.register_hparams +def r_transformer_act_big_d03(): + hparams = r_transformer_big() + hparams.recurrence_type = "act" + hparams.layer_prepostprocess_dropout = 0.3 + hparams.attention_dropout = 0.3 + hparams.relu_dropout = 0.3 + return hparams + + +@registry.register_hparams +def r_transformer_act_tiny_d02(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.layer_prepostprocess_dropout = 0.2 + hparams.attention_dropout = 0.2 + hparams.relu_dropout = 0.2 + return hparams + + +@registry.register_hparams +def r_transformer_act_tiny_d02_sb(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.layer_prepostprocess_dropout = 0.2 + hparams.attention_dropout = 0.2 + hparams.relu_dropout = 0.2 + hparams.batch_size = 2048 + return hparams + + +@registry.register_hparams +def r_transformer_act_tiny_sb(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.batch_size = 2048 + return hparams + + +@registry.register_hparams +def r_transformer_act_tiny_d05(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.layer_prepostprocess_dropout = 0.5 + hparams.attention_dropout = 0.5 + hparams.relu_dropout = 0.5 + return hparams + + +@registry.register_hparams +def r_transformer_base_sb(): + hparams = r_transformer_base() + hparams.batch_size = 2048 + return hparams + + +@registry.register_hparams +def r_transformer_skip_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "skip" + return hparams + + +@registry.register_hparams +def r_transformer_skip_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "skip" + return hparams + + +@registry.register_hparams +def r_transformer_highway_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "highway" + return hparams + + +@registry.register_hparams +def r_transformer_highway_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "highway" + return hparams + + +@registry.register_hparams +def r_transformer_dwa_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "dwa" + return hparams + + +@registry.register_hparams +def r_transformer_dwa_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "dwa" + return hparams + + +@registry.register_hparams +def r_transformer_dwa_tiny_test(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "dwa" + return hparams + + +@registry.register_hparams +def r_transformer_rnn_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "rnn" + return hparams + + +@registry.register_hparams +def r_transformer_gru_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "gru" + return hparams + + +@registry.register_hparams +def r_transformer_lstm_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "lstm" + return hparams diff --git a/tensor2tensor/models/research/r_transformer_test.py b/tensor2tensor/models/research/r_transformer_test.py new file mode 100644 index 000000000..c9fd6521d --- /dev/null +++ b/tensor2tensor/models/research/r_transformer_test.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np + +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models.research import r_transformer + +import tensorflow as tf + +BATCH_SIZE = 3 +INPUT_LENGTH = 5 +TARGET_LENGTH = 7 +VOCAB_SIZE = 10 + + +class RTransformerTest(tf.test.TestCase): + + def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN, has_input=True): + hparams.hidden_size = 8 + hparams.filter_size = 32 + hparams.num_heads = 1 + hparams.layer_prepostprocess_dropout = 0.0 + + p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE) + if not has_input: + p_hparams.input_modality = {} + hparams.problems = [p_hparams] + + inputs = -1 + np.random.random_integers( + VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1)) + targets = -1 + np.random.random_integers( + VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1)) + features = { + "inputs": tf.constant(inputs, dtype=tf.int32, name="inputs"), + "targets": tf.constant(targets, dtype=tf.int32, name="targets"), + "target_space_id": tf.constant(1, dtype=tf.int32) + } + + return r_transformer.RTransformer(hparams, mode, p_hparams), features + + def testTransformer(self): + model, features = self.getModel(r_transformer.r_transformer_base()) + logits, _ = model(features) + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/models/research/r_transformer_util.py b/tensor2tensor/models/research/r_transformer_util.py new file mode 100644 index 000000000..5ac242852 --- /dev/null +++ b/tensor2tensor/models/research/r_transformer_util.py @@ -0,0 +1,1760 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for R-Transformer. + + +R-Transformer learns a function (for instance the transformer multi-head +attention plus a feed-forward unit) and uses this function over n-steps to +process the input. +In other words, we can describe this as having a vanilla transformer, in which +the weights in the layers are shared and we have a module(the recurrency module) +next to this transformer that controls how steps communicate with each other in +depth. + +For instance, the recurrency module, can be a simple identity function +which passes the output of a step as the input to next step (applying one layer +of transformer n times on the input in a row --> lead to a better +generalization!). Or as another example, the recurrent module can be an LSTM, +(filliped vertically) next to the transformer which controls how state of the +model changes in depth, Or even a grit transformer (a transformer which learns +the attention over steps of an R-Transformer) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import functools + +# Dependency imports + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.utils import expert_utils + +import tensorflow as tf + + +def r_transformer_encoder(encoder_input, + encoder_self_attention_bias, + hparams, + name="encoder", + nonpadding=None, + save_weights_to=None, + make_image_summary=True): + """R_transformer_encoder function. + + Prepares all the arguments and the inputs and passes it to a + r_transformer_layer to encode the encoder_input. + + Args: + encoder_input: a Tensor + encoder_self_attention_bias: bias Tensor for self-attention + (see common_attention.attention_bias()) + hparams: hyperparameters for model + name: a string + nonpadding: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. This must either be + passed in, which we do for "packed" datasets, or inferred from + encoder_self_attention_bias. The knowledge about padding is used + for pad_remover(efficiency) and to mask out padding in convoltutional + layers. + save_weights_to: an optional dictionary to capture attention weights + for vizualization; the weights tensor will be appended there under + a string key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. + + Returns: + y: a Tensors as the output of the encoder + extra_output: which can be used to pass extra information to the body + """ + + x = encoder_input + attention_dropout_broadcast_dims = ( + common_layers.comma_separated_string_to_integer_list( + getattr(hparams, "attention_dropout_broadcast_dims", ""))) + with tf.variable_scope(name): + if nonpadding is not None: + padding = 1.0 - nonpadding + else: + padding = common_attention.attention_bias_to_padding( + encoder_self_attention_bias) + nonpadding = 1.0 - padding + pad_remover = None + if hparams.use_pad_remover and not common_layers.is_on_tpu(): + pad_remover = expert_utils.PadRemover(padding) + + ffn_unit = functools.partial( + transformer_encoder_ffn_unit, + hparams=hparams, + pad_remover=pad_remover, + nonpadding_mask=nonpadding) + + attention_unit = functools.partial( + transformer_encoder_attention_unit, + hparams=hparams, + encoder_self_attention_bias=encoder_self_attention_bias, + attention_dropout_broadcast_dims=attention_dropout_broadcast_dims, + save_weights_to=save_weights_to, + make_image_summary=make_image_summary) + + x, extra_output = r_transformer_layer( + x, hparams, ffn_unit, attention_unit, pad_remover=pad_remover) + + if hparams.get("use_memory_as_last_state", False): + x = extra_output # which is memory + return common_layers.layer_preprocess(x, hparams), extra_output + + +def r_transformer_decoder(decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + name="decoder", + nonpadding=None, + save_weights_to=None, + make_image_summary=True): + """R_transformer decoder function. + + Prepares all the arguments and the inputs and passes it to a + core_r_transformer_layer to decoder. + + Args: + decoder_input: a Tensor + encoder_output: a Tensor + decoder_self_attention_bias: bias Tensor for self-attention + (see common_attention.attention_bias()) + encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention + (see common_attention.attention_bias()) + hparams: hyperparameters for model + name: a string + nonpadding: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. This is used + to mask out padding in convoltutional layers. We generally only + need this mask for "packed" datasets, because for ordinary datasets, + no padding is ever followed by nonpadding. + save_weights_to: an optional dictionary to capture attention weights + for vizualization; the weights tensor will be appended there under + a string key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. + + Returns: + y: the output Tensors + extra_output: which can be used to pass extra information to the body + """ + x = decoder_input + attention_dropout_broadcast_dims = ( + common_layers.comma_separated_string_to_integer_list( + getattr(hparams, "attention_dropout_broadcast_dims", ""))) + with tf.variable_scope(name): + ffn_unit = functools.partial( + transformer_decoder_ffn_unit, + hparams=hparams, + nonpadding_mask=nonpadding) + + attention_unit = functools.partial( + transformer_decoder_attention_unit, + hparams=hparams, + encoder_output=encoder_output, + decoder_self_attention_bias=decoder_self_attention_bias, + encoder_decoder_attention_bias=encoder_decoder_attention_bias, + attention_dropout_broadcast_dims=attention_dropout_broadcast_dims, + save_weights_to=save_weights_to, + make_image_summary=make_image_summary) + + x, extra_output = r_transformer_layer(x, hparams, ffn_unit, attention_unit) + + return common_layers.layer_preprocess(x, hparams), extra_output + + +def r_transformer_layer(x, hparams, ffn_unit, attention_unit, pad_remover=None): + """Core function applying the r-transforemr layer. + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + Returns: + the output tensor, extra output (can be memory, ponder time, etc.) + + Raises: + ValueError: Unknown recurrence type + """ + with tf.variable_scope("r_transformer_%s" % hparams.recurrence_type): + + if hparams.recurrence_type == "act": + return r_transformer_act(x, hparams, ffn_unit, attention_unit) + + else: # for all the other recurrency types with fixed number of steps + rt_function, initializer = get_rt_layer(x, hparams, ffn_unit, + attention_unit, pad_remover) + + output, _, extra_output = tf.foldl( + rt_function, tf.range(hparams.num_rec_steps), initializer=initializer) + + # This can be the if we use r_transformer_lstm layer. + if hparams.get("use_memory_as_final_state", False): + output = extra_output + + return output, extra_output + + +def get_rt_layer(x, hparams, ffn_unit, attention_unit, pad_remover=None): + """provides the function that is used in r-transforemr steps. + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + Returns: + rt_function and the rt_initializer + + Raises: + ValueError: Unknown recurrence type + """ + + if hparams.recurrence_type == "basic": + rt_initializer = (x, x, x) # (state, input, memory) + rt_function = functools.partial( + r_transformer_basic, ffn_unit=ffn_unit, attention_unit=attention_unit) + + elif hparams.recurrence_type == "highway": + rt_initializer = (x, x, x) # (state, input, memory) + rt_function = functools.partial( + r_transformer_highway, + hparams=hparams, + ffn_unit=ffn_unit, + attention_unit=attention_unit, + pad_remover=pad_remover) + + elif hparams.recurrence_type == "skip": + rt_initializer = (x, x, x) # (state, input, memory) + rt_function = functools.partial( + r_transformer_skip, + hparams=hparams, + ffn_unit=ffn_unit, + attention_unit=attention_unit, + pad_remover=pad_remover) + + elif hparams.recurrence_type == "dwa": + # memory contains the original input + all the states + memory_size = hparams.num_rec_steps + 1 + + # prepare initializer: + memory_empty = tf.zeros([memory_size] + common_layers.shape_list(x)) + + # filling the first slot with the original input + memory = fill_memory_slot(memory_empty, x, 0) + + rt_initializer = (x, x, memory) # (state, input, memory) + rt_function = functools.partial( + r_transformer_depthwise_attention, + hparams=hparams, + ffn_unit=ffn_unit, + attention_unit=attention_unit) + + elif hparams.recurrence_type == "rnn": + rt_initializer = (x, x, x) # (state, input, memory) + rt_function = functools.partial( + r_transformer_rnn, + hparams=hparams, + ffn_unit=ffn_unit, + attention_unit=attention_unit, + pad_remover=pad_remover) + + elif hparams.recurrence_type == "gru": + rt_initializer = (x, x, x) # (state, input, memory) + rt_function = functools.partial( + r_transformer_gru, + hparams=hparams, + attention_unit=attention_unit, + pad_remover=pad_remover) + + elif hparams.recurrence_type == "lstm": + memory = tf.zeros(common_layers.shape_list(x)) + rt_initializer = (x, x, memory) # (state, input, memory) + rt_function = functools.partial( + r_transformer_lstm, + hparams=hparams, + attention_unit=attention_unit, + pad_remover=pad_remover) + + else: + raise ValueError("Unknown recurrence type: %s" % hparams.recurrence_type) + + return rt_function, rt_initializer + + +def transformer_encoder_ffn_unit(x, + hparams, + pad_remover=None, + nonpadding_mask=None): + """Applies a feed-forward function which is parametrised for encoding. + + Args: + x: input + hparams: model hyper-parameters + pad_remover: to mask out padding in convolutional layers (efficiency). + nonpadding_mask: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. This is used + to mask out padding in convoltutional layers. We generally only + need this mask for "packed" datasets, because for ordinary datasets, + no padding is ever followed by nonpadding. + + Returns: + the output tensor + """ + + with tf.variable_scope("ffn"): + if hparams.transformer_ffn_type == "fc": + y = transformer.transformer_ffn_layer( + common_layers.layer_preprocess(x, hparams), + hparams, + pad_remover, + conv_padding="SAME", + nonpadding_mask=nonpadding_mask) + + if hparams.transformer_ffn_type == "sep": + y = common_layers.conv_hidden_relu( + common_layers.layer_preprocess(x, hparams), + padding="SAME", + kernel_size=(3, 1), + second_kernel_size=(31, 1), + hidden_size=hparams.filter_size, + output_size=hparams.hidden_size, + dropout=hparams.relu_dropout) + + x = common_layers.layer_postprocess(x, y, hparams) + + return x + + +def transformer_encoder_attention_unit(x, + hparams, + encoder_self_attention_bias, + attention_dropout_broadcast_dims, + save_weights_to=None, + make_image_summary=True): + """Applies multihead attention function which is parametrised for encoding. + + Args: + x: input + hparams: model hyper-parameters + encoder_self_attention_bias: a bias tensor for use in encoder self-attention + attention_dropout_broadcast_dims: Fpr noise broadcasting in the dropout + layers to save memory during training + save_weights_to: an optional dictionary to capture attention weights for + visualization; the weights tensor will be appended there under a string + key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. + + Returns: + the output tensor + + """ + + with tf.variable_scope("self_attention"): + y = common_attention.multihead_attention( + common_layers.layer_preprocess(x, hparams), + None, + encoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=hparams.self_attention_type, + save_weights_to=save_weights_to, + max_relative_position=hparams.max_relative_position, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims) + x = common_layers.layer_postprocess(x, y, hparams) + return x + + +def transformer_decoder_ffn_unit(x, hparams, nonpadding_mask=None): + """Applies a feed-forward function which is parametrised for decoding. + + Args: + x: input + hparams: model hyper-parameters + nonpadding_mask: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. This is used + to mask out padding in convoltutional layers. We generally only + need this mask for "packed" datasets, because for ordinary datasets, + no padding is ever followed by nonpadding. + + Returns: + the output tensor + + """ + + with tf.variable_scope("ffn"): + if hparams.transformer_ffn_type == "fc": + y = transformer.transformer_ffn_layer( + common_layers.layer_preprocess(x, hparams), + hparams, + conv_padding="LEFT", + nonpadding_mask=nonpadding_mask) + + if hparams.transformer_ffn_type == "sep": + y = common_layers.conv_hidden_relu( + common_layers.layer_preprocess(x, hparams), + padding="LEFT", + kernel_size=(3, 1), + second_kernel_size=(31, 1), + hidden_size=hparams.filter_size, + output_size=hparams.hidden_size, + dropout=hparams.relu_dropout) + + x = common_layers.layer_postprocess(x, y, hparams) + + return x + + +def transformer_decoder_attention_unit(x, + hparams, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + attention_dropout_broadcast_dims, + save_weights_to=None, + make_image_summary=True): + """Applies multihead attention function which is parametrised for decoding. + + Args: + x: input (decoder input) + hparams: model hyper-parameters + encoder_output: Encoder representation. [batch_size, input_length, + hidden_dim] + decoder_self_attention_bias: Bias and mask weights for decoder + self-attention. [batch_size, decoder_length] + encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder + attention. [batch_size, input_length] + attention_dropout_broadcast_dims: Fpr noise broadcasting in the dropout + layers to save memory during training + save_weights_to: an optional dictionary to capture attention weights for + visualization; the weights tensor will be appended there under a string + key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. + + Returns: + The output tensor + """ + + with tf.variable_scope("self_attention"): + y = common_attention.multihead_attention( + common_layers.layer_preprocess(x, hparams), + None, + decoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=hparams.self_attention_type, + save_weights_to=save_weights_to, + max_relative_position=hparams.max_relative_position, + cache=None, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims) + x = common_layers.layer_postprocess(x, y, hparams) + if encoder_output is not None: + with tf.variable_scope("encdec_attention"): + y = common_attention.multihead_attention( + common_layers.layer_preprocess(x, hparams), + encoder_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + save_weights_to=save_weights_to, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims) + x = common_layers.layer_postprocess(x, y, hparams) + return x + + +def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): + """Basic r_transformer. + + This is in fact vanilla transformer in which weights are shared between + layers. For some tasks, this simple idea brings a generalization that is not + achievable by playing with the size of the model or drop_out parameters in + the vanilla transformer. + + Args: + layer_inputs: + - state: state + unused_step: indicating number of steps take so far + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + Returns: + layer_output: + new_state: new state + """ + state, inputs, memory = layer_inputs + + new_state = ffn_unit(attention_unit(state)) + + return new_state, inputs, memory + + +def r_transformer_highway(layer_inputs, + unused_step, + hparams, + ffn_unit, + attention_unit, + pad_remover=None): + """R_transformer with highway connection. + + + It transforms the state using attention and ffn and wrap this transformation + with a highway connection. (the new state is a combination of the state and + the transformed-state based on cary/transform gates.) + + Interesting observation: + Controlling the cary/transform gate with the original inputs works usually + better (i.e. hparams.gates_inputs="i") + + Args: + layer_inputs: + - state: state + - inputs: the original embedded inputs (= inputs to the first step) + unused_step: indicating number of steps take so far + hparams: model hyper-parameters. + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + Returns: + layer_output: + new_state: new state + inputs: the original embedded inputs (= inputs to the first step) + + """ + + state, inputs, memory = layer_inputs + + transformed_state = ffn_unit(attention_unit(state)) + state.get_shape().assert_is_compatible_with(state.get_shape()) + + gate_inputs = [] + if "s" in hparams.gates_inputs: + gate_inputs.append(state) + + if "t" in hparams.gates_inputs: + gate_inputs.append(transformed_state) + + if "i" in hparams.gates_inputs: + gate_inputs.append(inputs) + + gate_ffn_layer = hparams.gate_ffn_layer + + transform_gate = _ffn_layer_multi_inputs( + gate_inputs, + hparams, + ffn_layer_type=gate_ffn_layer, + name="transform", + bias_initializer=tf.constant_initializer(hparams.transform_bias_init), + activation=tf.sigmoid, + pad_remover=pad_remover, + preprocess=True, + postprocess=True) + + if hparams.couple_carry_transform_gates: + carry_gate = tf.subtract(1.0, transform_gate, name="carry") + + else: + carry_gate = _ffn_layer_multi_inputs( + gate_inputs, + hparams, + ffn_layer_type=gate_ffn_layer, + name="carry", + bias_initializer=tf.constant_initializer(-hparams.transform_bias_init), + activation=tf.sigmoid, + pad_remover=pad_remover, + preprocess=True, + postprocess=True) + + new_state = state * carry_gate + transformed_state * transform_gate + + tf.contrib.summary.scalar("highway_transform_gate_layer", + tf.reduce_mean(transform_gate)) + + tf.contrib.summary.scalar("highway_carry_gate_layer", + tf.reduce_mean(carry_gate)) + + return new_state, inputs, memory + + +def r_transformer_skip(layer_inputs, + unused_step, + hparams, + ffn_unit, + attention_unit, + pad_remover=None): + """R_transformer with highway connection. + + + It transforms the state using attention and ffn and wrap this transformation + with a skip-all connection. (the new state is a combination of the state and + the inputs (original inputs) based on cary/transform gates.) + + Observation: + Controlling the cary/transform gate with the original inputs works usually + better (i.e. hparams.gates_inputs="i") + + Args: + layer_inputs: + - state: state + - inputs: the original embedded inputs (= inputs to the first step) + unused_step: indicating number of steps take so far + hparams: model hyper-parameters. + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + + Returns: + layer_output: + new_state: new state + inputs: the original embedded inputs (= inputs to the first step) + """ + + state, inputs, memory = layer_inputs + + transformed_state = ffn_unit(attention_unit(state)) + + inputs.get_shape().assert_is_compatible_with(state.get_shape()) + + gate_inputs = [] + if "s" in hparams.gates_inputs: + gate_inputs.append(state) + + if "t" in hparams.gates_inputs: + gate_inputs.append(transformed_state) + + if "i" in hparams.gates_inputs: + gate_inputs.append(inputs) + + gate_ffn_layer = hparams.gate_ffn_layer + + transform_gate = _ffn_layer_multi_inputs( + gate_inputs, + hparams, + ffn_layer_type=gate_ffn_layer, + name="transform", + bias_initializer=tf.constant_initializer(hparams.transform_bias_init), + activation=tf.sigmoid, + pad_remover=pad_remover, + preprocess=True, + postprocess=True) + + if hparams.couple_carry_transform_gates: + carry_gate = tf.subtract(1.0, transform_gate, name="carry") + + else: + carry_gate = _ffn_layer_multi_inputs( + gate_inputs, + hparams, + ffn_layer_type=gate_ffn_layer, + name="carry", + bias_initializer=tf.constant_initializer(-hparams.transform_bias_init), + activation=tf.sigmoid, + pad_remover=pad_remover, + preprocess=True, + postprocess=True) + + tf.contrib.summary.scalar("skip_transform_gate_layer", + tf.reduce_mean(transform_gate)) + + tf.contrib.summary.scalar("skip_carry_gate_layer", tf.reduce_mean(carry_gate)) + + new_state = inputs * carry_gate + transformed_state * transform_gate + return new_state, inputs, memory + + +def r_transformer_depthwise_attention(layer_inputs, step, hparams, ffn_unit, + attention_unit): + """R_transformer with depth-wise attention. + + It uses an attention mechanism-flipped vertically- + over all the states from previous steps to generate the new_state. + + Args: + layer_inputs: + - state: state + - memory: contains states from all the previous steps. + step: indicating number of steps take so far + hparams: model hyper-parameters. + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + + Returns: + layer_output: + new_state: new state + memory: contains states from all the previous steps. + + """ + _, inputs, memory = layer_inputs + all_states = memory + + # add depth signal + if hparams.depth_embedding: + all_states = add_depth_embedding(all_states) + + # get the states up to the current step (non-zero part of the memory) + states_so_far = all_states[:step, :, :, :] + + states_so_far_weights = tf.nn.softmax( + common_layers.dense( + states_so_far, (hparams.hidden_size if hparams.dwa_elements else 1), + activation=None, + use_bias=True), + axis=-1) + + # prepare the state tensor that will be transformed + state_to_be_transformed = tf.reduce_sum( + (states_so_far * states_so_far_weights), axis=0) + + new_state = ffn_unit(attention_unit(state_to_be_transformed)) + + # add the new state to the memory + memory = fill_memory_slot(memory, new_state, step + 1) + + return new_state, inputs, memory + + +def r_transformer_rnn(layer_inputs, + unused_step, + hparams, + ffn_unit, + attention_unit, + pad_remover=None): + """The RT cell which models recurencey similar to basic RNN cell. + + It's an R-transformer with an RNN applied over the stats on depth. + + Args: + layer_inputs: + - state: state + - inputs: the original embedded inputs (= inputs to the first step) + unused_step: indicating number of steps take so far + hparams: model hyper-parameters. + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + Returns: + layer_output: + new_state: new state + inputs: the original embedded inputs (= inputs to the first step) + + Raises: + ValueError: Unknown inputs_states_combination type + + """ + + state, inputs, memory = layer_inputs + + # TODO(dehghani) keep only the meaningful cases: + if hparams.inputs_states_combination == "mh_attention_ffn_add": + state.get_shape().assert_is_compatible_with(inputs.get_shape()) + state = ffn_unit(attention_unit(state)) + new_state = state + inputs + + elif hparams.inputs_states_combination == "add_mh_attention_ffn": + state.get_shape().assert_is_compatible_with(inputs.get_shape()) + state += inputs + new_state = ffn_unit(attention_unit(state)) + + elif hparams.inputs_states_combination == "dense_mh_attention": + state = _ffn_layer_multi_inputs( + [state, inputs], + hparams=hparams, + ffn_layer_type="dense_relu_dense", + name="rnn", + activation=tf.tanh, + pad_remover=pad_remover) + + new_state = attention_unit(state) + + elif hparams.inputs_states_combination == "mh_attention_dense": + state = attention_unit(state) + new_state = _ffn_layer_multi_inputs( + [state, inputs], + hparams=hparams, + ffn_layer_type="dense_relu_dense", + name="rnn", + activation=tf.tanh, + pad_remover=pad_remover) + + else: + raise ValueError("Unknown inputs_states_combination type: %s" % + hparams.inputs_states_combination) + + return new_state, inputs, memory + + +def r_transformer_gru(layer_inputs, + unused_step, + hparams, + attention_unit, + pad_remover=None): + """The RT cell which models recurencey similar to GRU cell. + + It's an R-transformer with a gru applied over the stats on depth. + Based on GRU paper: http://arxiv.org/abs/1406.1078 + + Args: + layer_inputs: + - state: state + - inputs: the original embedded inputs (= inputs to the first step) + unused_step: indicating number of steps take so far + hparams: model hyper-parameters. + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + + Returns: + layer_output: + new_state: new state + inputs: the original embedded inputs (= inputs to the first step) + """ + + state, inputs, memory = layer_inputs + + # TODO(dehghani): do we need preprocess here? + state = common_layers.layer_preprocess(state, hparams) + inputs = common_layers.layer_preprocess(inputs, hparams) + + update_gate = _ffn_layer_multi_inputs( + [inputs, state], + hparams, + name="update", + bias_initializer=tf.constant_initializer(1.0), + activation=tf.sigmoid, + pad_remover=pad_remover) + + reset_gate = _ffn_layer_multi_inputs( + [inputs, state], + hparams, + name="reset", + bias_initializer=tf.constant_initializer(1.0), + activation=tf.sigmoid, + pad_remover=pad_remover) + + reset_state = reset_gate * state + + candidate = _ffn_layer_multi_inputs( + [inputs, reset_state], + hparams, + name="candidate", + bias_initializer=tf.zeros_initializer(), + activation=tf.tanh, + pad_remover=pad_remover) + + if "candidate_transformation" in hparams.gru_transformation: + candidate = attention_unit(candidate) + + if "state_transformation" in hparams.gru_transformation: + state = attention_unit(state) + + state = update_gate * state + (1 - update_gate) * candidate + + if "state_transformation" in hparams.gru_transformation: + state = attention_unit(state) + # normalization on the output + new_state = common_layers.layer_preprocess(state, hparams) + + return new_state, inputs, memory + + +def r_transformer_lstm(layer_inputs, + unused_step, + hparams, + attention_unit, + pad_remover=None): + """The RT cell which models recurencey similar to GRU cell. + + It's an R-transformer with a gru applied over the stats on depth. + based on LSTM paper: https://arxiv.org/pdf/1409.2329.pdf + + Args: + layer_inputs: + - state: state + - inputs: the original embedded inputs (= inputs to the first step) + - memory: memory used in lstm. + unused_step: indicating number of steps take so far + hparams: model hyper-parameters. + attention_unit: multi-head attention unit + pad_remover: to mask out padding in convolutional layers (efficiency). + + Returns: + layer_output: + new_state: new state + inputs: the original embedded inputs (= inputs to the first step) + memory: contains states from all the previous steps. + """ + state, inputs, memory = layer_inputs + + state = common_layers.layer_preprocess(state, hparams) + inputs = common_layers.layer_preprocess(inputs, hparams) + + input_gate = _ffn_layer_multi_inputs( + [inputs, state], + hparams, + name="input_g", + bias_initializer=tf.zeros_initializer(), + activation=tf.sigmoid, + pad_remover=pad_remover) + + forget_gate = _ffn_layer_multi_inputs( + [inputs, state], + hparams, + name="forget_g", + bias_initializer=tf.zeros_initializer(), + activation=None, + pad_remover=pad_remover) + + output_gate = _ffn_layer_multi_inputs( + [inputs, state], + hparams, + name="output_g", + bias_initializer=tf.zeros_initializer(), + activation=tf.sigmoid, + pad_remover=pad_remover) + + input_modulation = _ffn_layer_multi_inputs( + [inputs, state], + hparams, + name="input_modulation", + bias_initializer=tf.zeros_initializer(), + activation=tf.tanh, + pad_remover=pad_remover) + + forget_bias_tensor = tf.constant(hparams.lstm_forget_bias) + forget_gate = tf.sigmoid(forget_gate + forget_bias_tensor) + + if "modulated_input_transformation" in hparams.lstm_transformation: + input_modulation = attention_unit(input_modulation) + + memory = memory * forget_gate + input_gate * input_modulation + + if "memory_transformation" in hparams.lstm_transformation: + memory = attention_unit(memory) + + new_state = tf.tanh(memory) * output_gate + + if "state_transformation" in hparams.lstm_transformation: + new_state = attention_unit(new_state) + + return new_state, inputs, memory + + +def r_transformer_act(x, hparams, ffn_unit, attention_unit): + """ACT based models. + + Implementations of all act models are based on craffel@'s cl/160711592. + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + Returns: + the output tensor, (ponder_times, remainders) + + Raises: + ValueError: Unknown act type + + """ + + if hparams.act_type == "basic": + return r_transformer_act_basic(x, hparams, ffn_unit, attention_unit) + + elif hparams.act_type == "accumulated": + return r_transformer_act_accumulated(x, hparams, ffn_unit, attention_unit) + + elif hparams.act_type == "global": + return r_transformer_act_global(x, hparams, ffn_unit, attention_unit) + + elif hparams.act_type == "random": + return r_transformer_act_random(x, hparams, ffn_unit, attention_unit) + + else: + raise ValueError("Unknown act type: %s" % hparams.act_type) + + +def r_transformer_act_basic(x, hparams, ffn_unit, attention_unit): + """Basic r_transformer with ACT based on remainder-distribution ACT. + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + Returns: + the output tensor, (ponder_times, remainders) + + """ + + state = x + act_max_steps = hparams.act_max_steps + threshold = 1.0 - hparams.act_epsilon + + batch_size = tf.shape(state)[0] + length = tf.shape(state)[1] + + # Halting probabilities (p_t^n in the paper) + halting_probability = tf.zeros( + ( + batch_size, + length, + ), name="halting_probability") + # Remainders (R(t) in the paper) + remainders = tf.zeros( + ( + batch_size, + length, + ), name="remainder") + # Number of updates performed (N(t) in the paper) + n_updates = tf.zeros( + ( + batch_size, + length, + ), name="n_updates") + + # Previous cell states (s_t in the paper) + previous_state = tf.zeros_like(state, name="previous_state") + step = tf.constant(0, dtype=tf.int32) + + def rt_function(state, step, halting_probability, remainders, n_updates, + previous_state): + """implements act (position-wise halting). + + Args: + state: 3-D Tensor: [batch_size, length, channel] + step: indicating number of steps take so far + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + previous_state: previous state + + Returns: + transformed_state: transformed state + step: step+1 + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + new_state: new state + """ + state_shape = state.get_shape() + + with tf.variable_scope("sigmoid_activation_for_pondering"): + p = common_layers.dense( + state, + 1, + activation=tf.nn.sigmoid, + use_bias=True, + bias_initializer=tf.constant_initializer( + hparams.act_halting_bias_init)) + p = tf.squeeze(p) + + # Mask for inputs which have not halted yet + still_running = tf.cast(tf.less(halting_probability, 1.0), tf.float32) + + # Mask of inputs which halted at this step + new_halted = tf.cast( + tf.greater(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Mask of inputs which haven't halted, and didn't halt this step + still_running = tf.cast( + tf.less_equal(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Add the halting probability for this step to the halting + # probabilities for those input which haven't halted yet + halting_probability += p * still_running + + # Compute remainders for the inputs which halted at this step + remainders += new_halted * (1 - halting_probability) + + # Add the remainders to those inputs which halted at this step + halting_probability += new_halted * remainders + + # Increment n_updates for all inputs which are still running + n_updates += still_running + new_halted + + # Compute the weight to be applied to the new state and output + # 0 when the input has already halted + # p when the input hasn't halted yet + # the remainders when it halted this step + update_weights = tf.expand_dims(p * still_running + new_halted * remainders, + -1) + + # apply transformation on the state + transformed_state = ffn_unit(attention_unit(state)) + + # update running part in the weighted state and keep the rest + new_state = ((transformed_state * update_weights) + + (previous_state * 1 - update_weights)) + + # remind TensorFlow of everything's shape + transformed_state.set_shape(state_shape) + for x in [halting_probability, remainders, n_updates]: + x.set_shape([ + state_shape[0], + state_shape[1], + ]) + new_state.set_shape(state_shape) + step += 1 + return (transformed_state, step, halting_probability, remainders, n_updates, + new_state) + + # While loop stops when this predicate is FALSE. + # Ie all (probability < 1-eps AND counter < N) are false. + def should_continue(u0, u1, halting_probability, u2, n_updates, u3): + del u0, u1, u2, u3 + return tf.reduce_any( + tf.logical_and( + tf.less(halting_probability, threshold), + tf.less(n_updates, act_max_steps))) + + # Do while loop iterations until predicate above is false. + (_, _, _, remainder, n_updates, new_state) = tf.while_loop( + should_continue, rt_function, + (state, step, halting_probability, remainders, n_updates, previous_state)) + + ponder_times = n_updates + remainders = remainder + + tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + + return new_state, (ponder_times, remainders) + + +def r_transformer_act_accumulated(x, hparams, ffn_unit, attention_unit): + """The RTAct cell where the final state is accumulation of all states. + + (similar to the main ACT paper: --> check the issue of differentiability) + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + Returns: + the output tensor, (ponder_times, remainders) + + """ + state = x + act_max_steps = hparams.act_max_steps + threshold = 1.0 - hparams.act_epsilon + + batch_size = tf.shape(state)[0] + length = tf.shape(state)[1] + + # Halting probabilities (p_t^n in the paper) + halting_probability = tf.zeros( + ( + batch_size, + length, + ), name="halting_probability") + # Remainders (R(t) in the paper) + remainders = tf.zeros( + ( + batch_size, + length, + ), name="remainder") + # Number of updates performed (N(t) in the paper) + n_updates = tf.zeros( + ( + batch_size, + length, + ), name="n_updates") + + # Accumulated cell states (s_t in the paper) + accumulated_state = tf.zeros_like(state, name="previous_state") + step = tf.constant(0, dtype=tf.int32) + + def rt_function(state, step, halting_probability, remainders, n_updates, + accumulated_state): + """Position-wise act. + + Args: + state: 3-D Tensor: [batch_size, length, channel] + step: indicating number of steps take so far + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + accumulated_state: accumulated state + + Returns: + transformed_state: transformed state + step: step+1 + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + accumulated_state: accumulated state + """ + state_shape = state.get_shape() + + with tf.variable_scope("sigmoid_activation_for_pondering"): + p = common_layers.dense( + state, + 1, + activation=tf.nn.sigmoid, + use_bias=True, + bias_initializer=tf.constant_initializer( + hparams.act_halting_bias_init)) + p = tf.squeeze(p) + + # Mask for inputs which have not halted yet + still_running = tf.cast(tf.less(halting_probability, 1.0), tf.float32) + + # Mask of inputs which halted at this step + new_halted = tf.cast( + tf.greater(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Mask of inputs which haven't halted, and didn't halt this step + still_running = tf.cast( + tf.less_equal(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Add the halting probability for this step to the halting + # probabilities for those input which haven't halted yet + halting_probability += p * still_running + + # Compute remainders for the inputs which halted at this step + remainders += new_halted * (1 - halting_probability) + + # Add the remainders to those inputs which halted at this step + halting_probability += new_halted * remainders + + # Increment n_updates for all inputs which are still running + n_updates += still_running + new_halted + + # Compute the weight to be applied to the new state and output + # 0 when the input has already halted + # p when the input hasn't halted yet + # the remainders when it halted this step + update_weights = tf.expand_dims(p * still_running + new_halted * remainders, + -1) + + # apply transformation on the state + transformed_state = ffn_unit(attention_unit(state)) + + # Add in the weighted state + accumulated_state = (transformed_state * update_weights) + accumulated_state + + # Remind TensorFlow of everything's shape + state.set_shape(state_shape) + for x in [halting_probability, remainders, n_updates]: + x.set_shape([ + state_shape[0], + state_shape[1], + ]) + accumulated_state.set_shape(state_shape) + step += 1 + return (transformed_state, step, halting_probability, remainders, n_updates, + accumulated_state) + + # While loop stops when this predicate is FALSE. + # Ie all (probability < 1-eps AND counter < N) are false. + def should_continue(u0, u1, halting_probability, u2, n_updates, u3): + del u0, u1, u2, u3 + return tf.reduce_any( + tf.logical_and( + tf.less(halting_probability, threshold), + tf.less(n_updates, act_max_steps))) + + # Do while loop iterations until predicate above is false. + (_, _, _, remainder, n_updates, accumulated_state) = tf.while_loop( + should_continue, rt_function, (state, step, halting_probability, + remainders, n_updates, accumulated_state)) + + ponder_times = n_updates + remainders = remainder + + tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + + return accumulated_state, (ponder_times, remainders) + + +def r_transformer_act_global(x, hparams, ffn_unit, attention_unit): + """The RTAct with global halting probability (not position-wise). + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + Returns: + the output tensor, (ponder_times, remainders) + + """ + state = x + act_max_steps = hparams.act_max_steps + threshold = 1.0 - hparams.act_epsilon + act_max_steps = hparams.act_max_steps + batch_size = tf.shape(state)[0] + state_shape = state.get_shape() + + # Halting probabilities (p_t^n in the paper) + halting_probability = tf.zeros((batch_size,), name="halting_probability") + # Remainders (R(t) in the paper) + remainders = tf.zeros((batch_size,), name="remainder") + # Number of updates performed (N(t) in the paper) + n_updates = tf.zeros((batch_size,), name="n_updates") + # Previous cell states (s_t in the paper) + previous_state = tf.zeros_like(state, name="previous_state") + step = tf.constant(0, dtype=tf.int32) + + def rt_function(state, step, halting_probability, remainders, n_updates, + previous_state): + """implements act (global halting). + + Args: + state: 3-D Tensor: [batch_size, length, channel] + step: indicating number of steps take so far + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + previous_state: previous state + + Returns: + transformed_state: transformed state + step: step+1 + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + new_state: new state + + """ + + with tf.variable_scope("sigmoid_activation_for_pondering"): + p = common_layers.dense( + state, + 1, + activation=tf.nn.sigmoid, + use_bias=True, + bias_initializer=tf.constant_initializer( + hparams.act_halting_bias_init)) + # average over all positions (as a global halting prob) + p = tf.reduce_mean(p, axis=1) + p = tf.squeeze(p) + + # Mask for inputs which have not halted yet + still_running = tf.cast(tf.less(halting_probability, 1.0), tf.float32) + + # Mask of inputs which halted at this step + new_halted = tf.cast( + tf.greater(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Mask of inputs which haven't halted, and didn't halt this step + still_running = tf.cast( + tf.less_equal(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Add the halting probability for this step to the halting + # probabilities for those input which haven't halted yet + halting_probability += p * still_running + + # Compute remainders for the inputs which halted at this step + remainders += new_halted * (1 - halting_probability) + + # Add the remainders to those inputs which halted at this step + halting_probability += new_halted * remainders + + # Increment n_updates for all inputs which are still running + n_updates += still_running + new_halted + + # Compute the weight to be applied to the new state and output + # 0 when the input has already halted + # p when the input hasn't halted yet + # the remainders when it halted this step + update_weights = tf.expand_dims( + tf.expand_dims(p * still_running + new_halted * remainders, -1), -1) + + # apply transformation on the state + transformed_state = ffn_unit(attention_unit(state)) + + # Add in the weighted state + new_state = ((transformed_state * update_weights) + + (previous_state * 1 - update_weights)) + + # Remind TensorFlow of everything's shape + state.set_shape(state_shape) + for x in [halting_probability, remainders, n_updates]: + x.set_shape([ + state_shape[0], + ]) + new_state.set_shape(state_shape) + + step += 1 + return [ + transformed_state, step, halting_probability, remainders, n_updates, + new_state + ] + + # While loop stops when this predicate is FALSE. + # Ie all (probability < 1-eps AND counter < N) are false. + def should_continue(u0, u1, halting_probability, u2, n_updates, u3): + del u0, u1, u2, u3 + return tf.reduce_any( + tf.logical_and( + tf.less(halting_probability, threshold), + tf.less(n_updates, act_max_steps))) + + # Do while loop iterations until predicate above is false. + (_, _, _, remainder, n_updates, new_state) = tf.while_loop( + should_continue, rt_function, + (state, step, halting_probability, remainders, n_updates, previous_state)) + + ponder_times = n_updates + remainders = remainder + + tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + + return new_state, (ponder_times, remainders) + + +def r_transformer_act_random(x, hparams, ffn_unit, attention_unit): + """r_transformer with ACT with random halting probability. + + Args: + x: input + hparams: model hyper-parameters + ffn_unit: feed-forward unit + attention_unit: multi-head attention unit + + Returns: + the output tensor, (ponder_times, remainders) + + """ + + state = x + act_max_steps = hparams.act_max_steps + threshold = 1.0 - hparams.act_epsilon + + batch_size = tf.shape(state)[0] + length = tf.shape(state)[1] + + # Halting probabilities (p_t^n in the paper) + halting_probability = tf.zeros( + ( + batch_size, + length, + ), name="halting_probability") + # Remainders (R(t) in the paper) + remainders = tf.zeros( + ( + batch_size, + length, + ), name="remainder") + # Number of updates performed (N(t) in the paper) + n_updates = tf.zeros( + ( + batch_size, + length, + ), name="n_updates") + + # Previous cell states (s_t in the paper) + previous_state = tf.zeros_like(state, name="previous_state") + step = tf.constant(0, dtype=tf.int32) + + def rt_function(state, step, halting_probability, remainders, n_updates, + previous_state): + """Implements act (position-wise halting). + + Args: + state: 3-D Tensor: [batch_size, length, channel] + step: indicating number of steps take so far + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + previous_state: previous state + + Returns: + transformed_state: transformed state + step: step+1 + halting_probability: halting probability + remainders: act remainders + n_updates: act n_updates + new_state: new state + + """ + state_shape = state.get_shape() + + # random as halting probability + p = tf.random_uniform(shape=common_layers.shape_list(halting_probability)) + + # Mask for inputs which have not halted yet + still_running = tf.cast(tf.less(halting_probability, 1.0), tf.float32) + + # Mask of inputs which halted at this step + new_halted = tf.cast( + tf.greater(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Mask of inputs which haven't halted, and didn't halt this step + still_running = tf.cast( + tf.less_equal(halting_probability + p * still_running, threshold), + tf.float32) * still_running + + # Add the halting probability for this step to the halting + # probabilities for those input which haven't halted yet + halting_probability += p * still_running + + # Compute remainders for the inputs which halted at this step + remainders += new_halted * (1 - halting_probability) + + # Add the remainders to those inputs which halted at this step + halting_probability += new_halted * remainders + + # Increment n_updates for all inputs which are still running + n_updates += still_running + new_halted + + # Compute the weight to be applied to the new state and output + # 0 when the input has already halted + # p when the input hasn't halted yet + # the remainders when it halted this step + update_weights = tf.expand_dims(p * still_running + new_halted * remainders, + -1) + + # apply transformation on the state + transformed_state = ffn_unit(attention_unit(state)) + + # update running part in the weighted state and keep the rest + new_state = ((transformed_state * update_weights) + + (previous_state * 1 - update_weights)) + + # remind TensorFlow of everything's shape + transformed_state.set_shape(state_shape) + for x in [halting_probability, remainders, n_updates]: + x.set_shape([ + state_shape[0], + state_shape[1], + ]) + new_state.set_shape(state_shape) + step += 1 + return [ + transformed_state, step, halting_probability, remainders, n_updates, + new_state + ] + + # While loop stops when this predicate is FALSE. + # Ie all (probability < 1-eps AND counter < N) are false. + def should_continue(u0, u1, halting_probability, u2, n_updates, u3): + del u0, u1, u2, u3 + return tf.reduce_any( + tf.logical_and( + tf.less(halting_probability, threshold), + tf.less(n_updates, act_max_steps))) + + # Do while loop iterations until predicate above is false. + (_, _, _, remainder, n_updates, new_state) = tf.while_loop( + should_continue, rt_function, + (state, step, halting_probability, remainders, n_updates, previous_state)) + + ponder_times = n_updates + remainders = remainder + + tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + + return new_state, (ponder_times, remainders) + + +def _ffn_layer_multi_inputs(inputs_list, + hparams, + ffn_layer_type="dense", + name="ffn", + kernel_initializer=None, + bias_initializer=None, + activation=None, + pad_remover=None, + preprocess=True, + postprocess=True): + """Implements a Feed-forward layer with multiple inputs, pad-removing, etc. + + Args: + inputs_list: list of input tensors + hparams: hyper-parameters + ffn_layer_type: dense / dense_dropconnect/ dense_relu_dense + name: name + kernel_initializer: kernel initializer + bias_initializer: bias initializer + activation: activation function + pad_remover: pad remover + preprocess: if preprocess the input + postprocess: if postprocess the output + + Returns: + a tensor + Raises: + ValueError: Unknown ffn_layer type. + + """ + + # need at least one inputs + num_inputs = len(inputs_list) + assert num_inputs > 0 + + if preprocess and num_inputs == 1: + inputs_list[0] = common_layers.layer_preprocess(inputs_list[0], hparams) + + if postprocess: + original_inputs = inputs_list[0] + + # the output size is the hidden size of the main inputs + main_input = inputs_list[0] + original_shape = common_layers.shape_list(main_input) + assert hparams.hidden_size == common_layers.shape_list(main_input)[-1] + + # all the inputs are in the same shape with main inputs + for inputs in inputs_list: + main_input.get_shape().assert_is_compatible_with(inputs.get_shape()) + + def remove_pads(x): + original_shape = common_layers.shape_list(x) + # Collapse `x` across examples, and remove padding positions. + x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0)) + x = tf.expand_dims(pad_remover.remove(x), axis=0) + return x + + if pad_remover: + for i, inputs in enumerate(inputs_list): + inputs_list[i] = remove_pads(inputs) + + ffn_inputs = ( + inputs_list[0] + if len(inputs_list) == 1 else tf.concat(inputs_list, axis=-1)) + + if ffn_layer_type == "dense": + output = common_layers.dense( + ffn_inputs, + hparams.hidden_size, + name=name, + activation=activation, + use_bias=True, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + + elif ffn_layer_type == "dense_dropconnect": + output = common_layers.dense_dropconnect( + ffn_inputs, + hparams.hidden_size, + name=name, + dropconnect_dropout=hparams.dropconnect_dropout, + output_activation=activation) + postprocess = False # no dropout on the output unit + + elif ffn_layer_type == "dense_relu_dense": + output = common_layers.dense_relu_dense( + ffn_inputs, + hparams.filter_size, + hparams.hidden_size, + name=name, + dropout=hparams.relu_dropout, + output_activation=activation, + ) + + else: + raise ValueError("Unknown ffn_layer type: %s" % ffn_layer_type) + + if pad_remover: + # Restore `output` to the original shape of `x`, including padding. + output = tf.reshape( + pad_remover.restore(tf.squeeze(output, axis=0)), original_shape) + + if postprocess: + if num_inputs == 1: + output = common_layers.layer_postprocess(original_inputs, output, hparams) + else: # only dropout (no residual)x + hp = copy.copy(hparams) + hp.layer_postprocess_sequence = hp.layer_postprocess_sequence.replace( + "a", "") + output = common_layers.layer_postprocess(original_inputs, output, hp) + + return output + + +def fill_memory_slot(memory, value, index): + """Fills the memory slot at a particular index with the given value. + + Args: + memory: a 4-d tensor [memory_size, batch, length, channel] containing + the state of all steps + value: a 3-d tensor [batch, length, channel] as the sate + index: integer in [0, memory_size) + + Returns: + filled memory + + """ + mask = tf.to_float( + tf.one_hot(index, + tf.shape(memory)[0])[:, None, None, None]) + fill_memory = (1 - mask) * memory + mask * value[None, ...] + return fill_memory + + +def add_depth_embedding(x): + """Add n-dimensional embedding as the depth embedding (timing signal). + + Adds embeddings to represent the position of the step in the recurrent + tower. + + Args: + x: a tensor with shape [max_step, batch, length, depth] + + Returns: + a Tensor the same shape as x. + """ + x_shape = common_layers.shape_list(x) + depth = x_shape[-1] + num_steps = x_shape[0] + shape = [num_steps, 1, 1, x_shape[-1]] + depth_embedding = ( + tf.get_variable( + "depth_embedding", + shape, + initializer=tf.random_normal_initializer(0, depth**-0.5)) * (depth** + 0.5)) + x += depth_embedding + return x diff --git a/tensor2tensor/models/research/rl.py b/tensor2tensor/models/research/rl.py index 2c5181d95..6d7eb15b5 100644 --- a/tensor2tensor/models/research/rl.py +++ b/tensor2tensor/models/research/rl.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Reinforcement learning models and parameters.""" import collections diff --git a/tensor2tensor/models/research/super_lm.py b/tensor2tensor/models/research/super_lm.py index 40bfb7f64..5ffe5e256 100644 --- a/tensor2tensor/models/research/super_lm.py +++ b/tensor2tensor/models/research/super_lm.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Supercomputer-based language model. Uses model-parallelism. diff --git a/tensor2tensor/models/research/transformer_moe.py b/tensor2tensor/models/research/transformer_moe.py index 7baeeb691..c0d3ed2e9 100644 --- a/tensor2tensor/models/research/transformer_moe.py +++ b/tensor2tensor/models/research/transformer_moe.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """transformer (attention seq-seq model) with mixtures of experts. """ diff --git a/tensor2tensor/models/research/transformer_revnet.py b/tensor2tensor/models/research/transformer_revnet.py index c15ba314f..0bafb546b 100644 --- a/tensor2tensor/models/research/transformer_revnet.py +++ b/tensor2tensor/models/research/transformer_revnet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Reversible Residual Transformer.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/transformer_revnet_test.py b/tensor2tensor/models/research/transformer_revnet_test.py index 89e075c12..4393943ce 100644 --- a/tensor2tensor/models/research/transformer_revnet_test.py +++ b/tensor2tensor/models/research/transformer_revnet_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for TransformerRevnet.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/transformer_sketch.py b/tensor2tensor/models/research/transformer_sketch.py index c32a28493..d26f8e96b 100644 --- a/tensor2tensor/models/research/transformer_sketch.py +++ b/tensor2tensor/models/research/transformer_sketch.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Transformer Sketch for im2sketch problems. """ diff --git a/tensor2tensor/models/research/transformer_symshard.py b/tensor2tensor/models/research/transformer_symshard.py index e3c541a07..0acc4251e 100644 --- a/tensor2tensor/models/research/transformer_symshard.py +++ b/tensor2tensor/models/research/transformer_symshard.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Test of the SymShard programming model. Symmetric model parallellism. diff --git a/tensor2tensor/models/research/transformer_vae.py b/tensor2tensor/models/research/transformer_vae.py index 6d7b35b3e..9825d8290 100644 --- a/tensor2tensor/models/research/transformer_vae.py +++ b/tensor2tensor/models/research/transformer_vae.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """AE Transformer.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/research/transformer_vae_test.py b/tensor2tensor/models/research/transformer_vae_test.py index ae08f6dc3..09efd32ec 100644 --- a/tensor2tensor/models/research/transformer_vae_test.py +++ b/tensor2tensor/models/research/transformer_vae_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.models.research.transformer_vae.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/models/resnet.py b/tensor2tensor/models/resnet.py index e953ba6a0..ed753225b 100644 --- a/tensor2tensor/models/resnet.py +++ b/tensor2tensor/models/resnet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Resnets.""" # Copied from cloud_tpu/models/resnet/resnet_model.py and modified @@ -23,6 +22,7 @@ # Dependency imports from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -394,6 +394,7 @@ def resnet_v2(inputs, @registry.register_model class Resnet(t2t_model.T2TModel): + """Residual Network.""" def body(self, features): hp = self.hparams @@ -425,6 +426,26 @@ def body(self, features): return out + def infer(self, + features=None, + decode_length=50, + beam_size=1, + top_beams=1, + alpha=0.0, + use_tpu=False): + """Predict.""" + del decode_length, beam_size, top_beams, alpha, use_tpu + assert features is not None + logits, _ = self(features) # pylint: disable=not-callable + assert len(logits.get_shape()) == 5 + logits = tf.squeeze(logits, [1, 2, 3]) + log_probs = common_layers.log_prob_from_logits(logits) + predictions, scores = common_layers.argmax_with_score(log_probs) + return { + "outputs": predictions, + "scores": scores, + } + def resnet_base(): """Set of hyperparameters.""" diff --git a/tensor2tensor/models/resnet_test.py b/tensor2tensor/models/resnet_test.py index c5c7312da..24c11a325 100644 --- a/tensor2tensor/models/resnet_test.py +++ b/tensor2tensor/models/resnet_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Resnet tests.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/revnet.py b/tensor2tensor/models/revnet.py index 63ae19717..7ddab9a2b 100644 --- a/tensor2tensor/models/revnet.py +++ b/tensor2tensor/models/revnet.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """Creates a RevNet with the bottleneck residual function. Implements the following equations described in the RevNet paper: diff --git a/tensor2tensor/models/revnet_test.py b/tensor2tensor/models/revnet_test.py index 68fec94a2..6f4f50f3c 100644 --- a/tensor2tensor/models/revnet_test.py +++ b/tensor2tensor/models/revnet_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Revnet.""" from tensor2tensor.models import revnet diff --git a/tensor2tensor/models/shake_shake.py b/tensor2tensor/models/shake_shake.py index c5a6caf04..f05659a89 100644 --- a/tensor2tensor/models/shake_shake.py +++ b/tensor2tensor/models/shake_shake.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Shake-shake model for CIFAR.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index be6c51e4d..e2c2271aa 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """SliceNet.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/models/slicenet_test.py b/tensor2tensor/models/slicenet_test.py index 299944b6b..85a6d161b 100644 --- a/tensor2tensor/models/slicenet_test.py +++ b/tensor2tensor/models/slicenet_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for SliceNet.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 4fb89db61..5e228fd51 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Transformer model from "Attention Is All You Need". The Transformer model consists of an encoder and a decoder. Both are stacks @@ -160,6 +159,7 @@ def body(self, features): encoder_output, encoder_decoder_attention_bias = (None, None) targets = features["targets"] + targets_shape = common_layers.shape_list(targets) targets = common_layers.flatten4d3d(targets) decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( @@ -181,7 +181,7 @@ def body(self, features): hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} - return decoder_output + return tf.reshape(decoder_output, targets_shape) def _greedy_infer(self, features, decode_length): """Fast version of greedy decoding. @@ -496,7 +496,7 @@ def fast_decode(encoder_output, def inner_loop(i, finished, next_id, decoded_ids, cache, log_prob): """One step of greedy decoding.""" logits, cache = symbols_to_logits_fn(next_id, i, cache) - log_probs = beam_search.log_prob_from_logits(logits) + log_probs = common_layers.log_prob_from_logits(logits) temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) next_id = common_layers.sample_with_temperature(logits, temperature) @@ -566,7 +566,7 @@ def infer(self, logits = tf.squeeze(logits, [2, 3]) # Compute the log probabilities - log_probs = beam_search.log_prob_from_logits(logits) + log_probs = common_layers.log_prob_from_logits(logits) # Slice out the log_probs of the targets targets = features["targets"] @@ -577,11 +577,10 @@ def infer(self, flat_targets = tf.reshape(targets, [batch_size * timesteps]) flat_log_probs = tf.reshape(log_probs, [batch_size * timesteps, vocab_size]) flat_indices = tf.stack( - [tf.range(tf.to_int64(batch_size) * tf.to_int64(timesteps)), + [tf.range(tf.to_int64(common_layers.shape_list(flat_targets)[0])), tf.to_int64(flat_targets)], axis=1) - log_probs = tf.reshape( - tf.gather_nd(flat_log_probs, flat_indices), - [batch_size, timesteps]) + flat_log_probs = tf.gather_nd(flat_log_probs, flat_indices) + log_probs = tf.reshape(flat_log_probs, [batch_size, timesteps]) # Sum over time to get the log_prob of the sequence scores = tf.reduce_sum(log_probs, axis=1) @@ -790,7 +789,8 @@ def transformer_encoder(encoder_input, save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position, make_image_summary=make_image_summary, - dropout_broadcast_dims=attention_dropout_broadcast_dims) + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length")) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( @@ -864,7 +864,8 @@ def transformer_decoder(decoder_input, max_relative_position=hparams.max_relative_position, cache=layer_cache, make_image_summary=make_image_summary, - dropout_broadcast_dims=attention_dropout_broadcast_dims) + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length")) x = common_layers.layer_postprocess(x, y, hparams) if encoder_output is not None: with tf.variable_scope("encdec_attention"): @@ -880,7 +881,8 @@ def transformer_decoder(decoder_input, hparams.attention_dropout, save_weights_to=save_weights_to, make_image_summary=make_image_summary, - dropout_broadcast_dims=attention_dropout_broadcast_dims) + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length")) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 7eaf0e285..fdcf86731 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Transformer.""" from __future__ import absolute_import @@ -153,13 +152,11 @@ def testSlowVsFastNoInput(self): def testBeamDecodeWithRelativeAttention(self): decode_length = 2 model, features = get_model(transformer.transformer_relative_tiny()) - model(features) model.set_mode(tf.estimator.ModeKeys.PREDICT) - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - beam_result = model._beam_decode( - features, decode_length, beam_size=4, top_beams=1, - alpha=1.0)["outputs"] + beam_result = model._beam_decode( + features, decode_length, beam_size=4, top_beams=1, + alpha=1.0)["outputs"] with self.test_session(): tf.global_variables_initializer().run() diff --git a/tensor2tensor/models/vanilla_gan.py b/tensor2tensor/models/vanilla_gan.py index e78d56679..0aa32b136 100644 --- a/tensor2tensor/models/vanilla_gan.py +++ b/tensor2tensor/models/vanilla_gan.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Simple Generative Adversarial Model with two linear layers. Example of how to create a GAN in T2T. diff --git a/tensor2tensor/models/xception.py b/tensor2tensor/models/xception.py index bec758687..0730966de 100644 --- a/tensor2tensor/models/xception.py +++ b/tensor2tensor/models/xception.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Xception.""" from __future__ import absolute_import diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py index 90d2ba9fb..553a218b1 100644 --- a/tensor2tensor/models/xception_test.py +++ b/tensor2tensor/models/xception_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Xception tests.""" from __future__ import absolute_import diff --git a/tensor2tensor/problems.py b/tensor2tensor/problems.py index 3c25abb40..ae0a42da1 100644 --- a/tensor2tensor/problems.py +++ b/tensor2tensor/problems.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Access T2T Problems. See problems_test.py for basic usage. diff --git a/tensor2tensor/problems_test.py b/tensor2tensor/problems_test.py index 76742aafa..7c67f3de9 100644 --- a/tensor2tensor/problems_test.py +++ b/tensor2tensor/problems_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """tensor2tensor.problems test.""" from __future__ import absolute_import diff --git a/tensor2tensor/rl/collect.py b/tensor2tensor/rl/collect.py index a43b3a551..e4a19422b 100644 --- a/tensor2tensor/rl/collect.py +++ b/tensor2tensor/rl/collect.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Collect trajectories from interactions of agent with environment.""" import tensorflow as tf diff --git a/tensor2tensor/rl/envs/batch_env.py b/tensor2tensor/rl/envs/batch_env.py index 14f36a121..f3a72844c 100644 --- a/tensor2tensor/rl/envs/batch_env.py +++ b/tensor2tensor/rl/envs/batch_env.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Combine multiple environments to step them in batch.""" # The code was based on Danijar Hafner's code from tf.agents: diff --git a/tensor2tensor/rl/envs/in_graph_batch_env.py b/tensor2tensor/rl/envs/in_graph_batch_env.py index e671d8f1b..ee0ae94d1 100644 --- a/tensor2tensor/rl/envs/in_graph_batch_env.py +++ b/tensor2tensor/rl/envs/in_graph_batch_env.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Batch of environments inside the TensorFlow graph.""" # The code was based on Danijar Hafner's code from tf.agents: diff --git a/tensor2tensor/rl/envs/py_func_batch_env.py b/tensor2tensor/rl/envs/py_func_batch_env.py index 518c7bf29..011da95bf 100644 --- a/tensor2tensor/rl/envs/py_func_batch_env.py +++ b/tensor2tensor/rl/envs/py_func_batch_env.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Batch of environments inside the TensorFlow graph.""" # The code was based on Danijar Hafner's code from tf.agents: diff --git a/tensor2tensor/rl/envs/simulated_batch_env.py b/tensor2tensor/rl/envs/simulated_batch_env.py index 5d4a7e066..20fe868f5 100644 --- a/tensor2tensor/rl/envs/simulated_batch_env.py +++ b/tensor2tensor/rl/envs/simulated_batch_env.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Batch of environments inside the TensorFlow graph.""" # The code was based on Danijar Hafner's code from tf.agents: diff --git a/tensor2tensor/rl/envs/tf_atari_wrappers.py b/tensor2tensor/rl/envs/tf_atari_wrappers.py index cfe25caa2..83b9a9ae7 100644 --- a/tensor2tensor/rl/envs/tf_atari_wrappers.py +++ b/tensor2tensor/rl/envs/tf_atari_wrappers.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Batch of environments inside the TensorFlow graph.""" from __future__ import absolute_import diff --git a/tensor2tensor/rl/envs/utils.py b/tensor2tensor/rl/envs/utils.py index c5bf484f0..cbd4cbb60 100644 --- a/tensor2tensor/rl/envs/utils.py +++ b/tensor2tensor/rl/envs/utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for using batched environments.""" # The code was based on Danijar Hafner's code from tf.agents: diff --git a/tensor2tensor/rl/model_rl_experiment.py b/tensor2tensor/rl/model_rl_experiment.py index 139cf6f46..87c07b26e 100644 --- a/tensor2tensor/rl/model_rl_experiment.py +++ b/tensor2tensor/rl/model_rl_experiment.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Training of model-based RL agents.""" import datetime @@ -69,8 +68,7 @@ def train(hparams, output_dir): FLAGS.model = hparams.generative_model FLAGS.hparams_set = hparams.generative_model_params FLAGS.train_steps = hparams.model_train_steps - FLAGS.train_steps = 1 - FLAGS.eval_steps = 1 + FLAGS.eval_steps = 10 t2t_trainer.main([]) # Dump frames from env model. @@ -109,13 +107,13 @@ def train(hparams, output_dir): def main(_): hparams = tf.contrib.training.HParams( - epochs=100, - true_env_generator_num_steps=100, + epochs=10, + true_env_generator_num_steps=5000, generative_model="basic_conv_gen", generative_model_params="basic_conv", - model_train_steps=5000, + model_train_steps=15000, simulated_env_generator_num_steps=300, - ppo_epochs_num=2, + ppo_epochs_num=200, ppo_epoch_length=300, ) train(hparams, FLAGS.output_dir) diff --git a/tensor2tensor/rl/ppo.py b/tensor2tensor/rl/ppo.py index ca0481e9e..c5ff6e8a6 100644 --- a/tensor2tensor/rl/ppo.py +++ b/tensor2tensor/rl/ppo.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """PPO algorithm implementation. Based on: https://arxiv.org/abs/1707.06347 diff --git a/tensor2tensor/rl/rl_trainer_lib.py b/tensor2tensor/rl/rl_trainer_lib.py index e710a2eeb..d244bcf2b 100644 --- a/tensor2tensor/rl/rl_trainer_lib.py +++ b/tensor2tensor/rl/rl_trainer_lib.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Library for training of RL agent with PPO algorithm.""" from __future__ import absolute_import diff --git a/tensor2tensor/rl/rl_trainer_lib_test.py b/tensor2tensor/rl/rl_trainer_lib_test.py index 461e7a0da..75e7ac978 100644 --- a/tensor2tensor/rl/rl_trainer_lib_test.py +++ b/tensor2tensor/rl/rl_trainer_lib_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests of basic flow of collecting trajectories and training PPO.""" # Dependency imports diff --git a/tensor2tensor/rl/t2t_rl_trainer.py b/tensor2tensor/rl/t2t_rl_trainer.py index 188433789..2f884d462 100644 --- a/tensor2tensor/rl/t2t_rl_trainer.py +++ b/tensor2tensor/rl/t2t_rl_trainer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Training of RL agent with PPO algorithm.""" # Dependency imports diff --git a/tensor2tensor/serving/export.py b/tensor2tensor/serving/export.py index 5b5dccf5b..1f65b67cb 100644 --- a/tensor2tensor/serving/export.py +++ b/tensor2tensor/serving/export.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Export a trained model for serving.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/serving/query.py b/tensor2tensor/serving/query.py index 1af0e9f2d..7d89d3c7e 100644 --- a/tensor2tensor/serving/query.py +++ b/tensor2tensor/serving/query.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Query an exported model. Py2 only. Install tensorflow-serving-api.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/serving/serving_utils.py b/tensor2tensor/serving/serving_utils.py index 5bb2fe724..e755d9551 100644 --- a/tensor2tensor/serving/serving_utils.py +++ b/tensor2tensor/serving/serving_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for serving tensor2tensor.""" from __future__ import absolute_import diff --git a/tensor2tensor/test_data/example_usr_dir/__init__.py b/tensor2tensor/test_data/example_usr_dir/__init__.py index 61a511e17..c1d23fdfa 100644 --- a/tensor2tensor/test_data/example_usr_dir/__init__.py +++ b/tensor2tensor/test_data/example_usr_dir/__init__.py @@ -12,6 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example T2T user directory.""" from . import my_submodule diff --git a/tensor2tensor/test_data/example_usr_dir/my_submodule.py b/tensor2tensor/test_data/example_usr_dir/my_submodule.py index e3ffd962c..a6b31469b 100644 --- a/tensor2tensor/test_data/example_usr_dir/my_submodule.py +++ b/tensor2tensor/test_data/example_usr_dir/my_submodule.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example registrations for T2T.""" import re diff --git a/tensor2tensor/utils/adafactor.py b/tensor2tensor/utils/adafactor.py index 69a240733..2d328d4ba 100644 --- a/tensor2tensor/utils/adafactor.py +++ b/tensor2tensor/utils/adafactor.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Optimization.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/avg_checkpoints.py b/tensor2tensor/utils/avg_checkpoints.py index 6d73de2ac..15175405d 100644 --- a/tensor2tensor/utils/avg_checkpoints.py +++ b/tensor2tensor/utils/avg_checkpoints.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Script to average values of variables in a list of checkpoint files.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py index e42fd9621..c01df0276 100644 --- a/tensor2tensor/utils/beam_search.py +++ b/tensor2tensor/utils/beam_search.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Implementation of beam search with penalties.""" from __future__ import absolute_import @@ -89,10 +88,6 @@ def get_state_shape_invariants(tensor): return tf.TensorShape(shape) -def log_prob_from_logits(logits, reduce_axis=-1): - return logits - tf.reduce_logsumexp(logits, axis=reduce_axis, keep_dims=True) - - def compute_batch_indices(batch_size, beam_size): """Computes the i'th coordinate that contains the batch index for gathers. @@ -360,7 +355,7 @@ def grow_topk(i, alive_seq, alive_log_probs, states): logits = tf.reshape(flat_logits, [batch_size, beam_size, -1]) # Convert logits to normalized log probs - candidate_log_probs = log_prob_from_logits(logits) + candidate_log_probs = common_layers.log_prob_from_logits(logits) # Multiply the probabilities by the current probabilities of the beam. # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1) diff --git a/tensor2tensor/utils/beam_search_test.py b/tensor2tensor/utils/beam_search_test.py index 0d4f35a0a..2c1262724 100644 --- a/tensor2tensor/utils/beam_search_test.py +++ b/tensor2tensor/utils/beam_search_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.beam_search.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/bleu_hook.py b/tensor2tensor/utils/bleu_hook.py index e18b72c0b..a0fa7fe3e 100644 --- a/tensor2tensor/utils/bleu_hook.py +++ b/tensor2tensor/utils/bleu_hook.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """BLEU metric util used during eval for MT.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/bleu_hook_test.py b/tensor2tensor/utils/bleu_hook_test.py index 01de21d26..522573b08 100644 --- a/tensor2tensor/utils/bleu_hook_test.py +++ b/tensor2tensor/utils/bleu_hook_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # coding=utf-8 """Tests for tensor2tensor.utils.bleu_hook.""" diff --git a/tensor2tensor/utils/checkpoint_compatibility_test.py b/tensor2tensor/utils/checkpoint_compatibility_test.py index c95fc8886..73a2958e3 100644 --- a/tensor2tensor/utils/checkpoint_compatibility_test.py +++ b/tensor2tensor/utils/checkpoint_compatibility_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Test for checkpoint compatibility.""" # The checkpoint in test_data/transformer_test_ckpt is generated with the OSS # release. diff --git a/tensor2tensor/utils/cloud_mlengine.py b/tensor2tensor/utils/cloud_mlengine.py index 36e3bfcb5..248d96666 100755 --- a/tensor2tensor/utils/cloud_mlengine.py +++ b/tensor2tensor/utils/cloud_mlengine.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Launch on GCP's ML Engine.""" import datetime diff --git a/tensor2tensor/utils/cloud_tpu.py b/tensor2tensor/utils/cloud_tpu.py index d1ea417be..ef78458a9 100644 --- a/tensor2tensor/utils/cloud_tpu.py +++ b/tensor2tensor/utils/cloud_tpu.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Launch on TPU on GCP.""" from __future__ import absolute_import from __future__ import division @@ -49,7 +48,8 @@ def __init__(self): self._tmp_dir = os.path.expanduser("~/.t2t/cloud_state") tf.gfile.MakeDirs(self._tmp_dir) - def cleanup(self, current_vm_name=None, current_tpu_name=None): + def cleanup(self, current_vm_name=None, current_tpu_name=None, + skip_confirmation=False): process_pids = os.listdir(self._tmp_dir) for pid in process_pids: try: @@ -74,13 +74,19 @@ def cleanup(self, current_vm_name=None, current_tpu_name=None): if (info["vm_name"] != current_vm_name and info["vm_name"] in zip(*list_vm_names_and_ips())[0]): print("Old VM %s found. Delete?" % info["vm_name"]) - if confirm(): + if skip_confirmation: del_vm = True + else: + if confirm(): + del_vm = True if (info["tpu_name"] != current_tpu_name and info["tpu_name"] in zip(*list_tpu_names_and_ips())[0]): print("Old TPU %s found. Delete?" % info["tpu_name"]) - if confirm(): + if skip_confirmation: del_tpu = True + else: + if confirm(): + del_tpu = True results = [] pool = mp.Pool(2) @@ -111,25 +117,29 @@ def add_current(self, tunnel_pid, vm_name, tpu_name, delete_on_done): @contextlib.contextmanager -def cloud_tpu(vm_name, tpu_name, delete_on_done=False): +def cloud_tpu(vm_name, tpu_name, delete_on_done=False, skip_confirmation=False): """Gets or creates a VM and TPU instance, and forwards ports. Args: vm_name: str, name of VM. tpu_name: str, name of TPU instance. delete_on_done: bool, whether to delete the instances when done. + skip_confirmation: bool, whether to skip launch confirmations. Yields: master: str, grpc master pointing to the TPU instance. """ state = CloudState() # Read state from previous processes and possibly cleanup - state.cleanup(current_vm_name=vm_name, current_tpu_name=tpu_name) + state.cleanup(current_vm_name=vm_name, current_tpu_name=tpu_name, + skip_confirmation=skip_confirmation) done_str = "" if delete_on_done else "NOT " print("Will %sdelete VM and TPU instance on done." % done_str) - assert confirm() - _, tpu_ip = create_vm_tpu_pair(vm_name, tpu_name) + if not skip_confirmation: + assert confirm() + _, tpu_ip = create_vm_tpu_pair(vm_name, tpu_name, + skip_confirmation=skip_confirmation) with tpu_tunnel(vm_name, tpu_ip) as (local_ports, tunnel_pid): master = "grpc://localhost:%d" % local_ports["tpu"] @@ -151,8 +161,8 @@ def cloud_tpu(vm_name, tpu_name, delete_on_done=False): class Gcloud(object): """gcloud command strings.""" # Note these can be modified by set_versions - VM_VERSION = "tf-1-5" - TPU_VERSION = "1.5" + VM_VERSION = "tf-1-7" + TPU_VERSION = "1.7" @classmethod def set_versions(cls, vm, tpu): @@ -175,16 +185,16 @@ def create_vm(cls): @classmethod def create_tpu(cls): create_tpu_str = """ - gcloud alpha compute tpus create \ + gcloud beta compute tpus create \ {name} \ --range={tpu_ip}/29 \ --version=%s """ % cls.TPU_VERSION return create_tpu_str - DELETE_TPU = "gcloud alpha compute tpus delete {name} --quiet" + DELETE_TPU = "gcloud beta compute tpus delete {name} --quiet" - LIST_TPU = "gcloud alpha compute tpus list" + LIST_TPU = "gcloud beta compute tpus list" LIST_VM = "gcloud compute instances list" SSH_LOCAL_PORT_FORWARD = "-L {local_port}:{host}:{remote_port}" @@ -310,7 +320,8 @@ def tpu_tunnel(vm_name, tpu_ip): yield local_ports, tunnel_process.pid -def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True): +def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True, + skip_confirmation=False): """Create a VM and paired TPU instance. Args: @@ -318,6 +329,7 @@ def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True): tpu_name: str, name for TPU instance. reuse_if_exists: bool, if True, this will act as a get or create. If False and vm_name or tpu_name already exists, will error. + skip_confirmation: bool, whether to skip launch confirmations. Returns: tuple: (vm_ip, tpu_ip) @@ -328,8 +340,8 @@ def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True): vm_info = list_vm_names_and_ips() tpu_info = list_tpu_names_and_ips() - vm_names = zip(*vm_info)[0] - tpu_names = zip(*tpu_info)[0] + vm_names = zip(*vm_info)[0] if vm_info else [] + tpu_names = zip(*tpu_info)[0] if tpu_info else [] make_vm = False vm_ip = None @@ -341,7 +353,8 @@ def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True): vm_ip = vm_info[vm_names.index(vm_name)][1] else: print("Creating VM %s" % vm_name) - assert confirm() + if not skip_confirmation: + assert confirm() make_vm = True make_tpu = False @@ -354,7 +367,8 @@ def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True): tpu_ip = tpu_info[tpu_names.index(tpu_name)][1] else: print("Creating TPU instance %s" % tpu_name) - assert confirm() + if not skip_confirmation: + assert confirm() make_tpu = True # Create VM and TPU in parallel diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 9ad3a712f..3f660b1ec 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data reader module.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index ec5897092..176785321 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Data reader test.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 7daa12b21..2856c76ad 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Decoding utilities.""" from __future__ import absolute_import from __future__ import division @@ -120,6 +119,7 @@ def decode_from_dataset(estimator, dataset_kwargs = { "shard": shard, "dataset_split": dataset_split, + "max_records": decode_hp.num_samples } # Build the inference input function diff --git a/tensor2tensor/utils/devices.py b/tensor2tensor/utils/devices.py index a405b498e..9371a4885 100644 --- a/tensor2tensor/utils/devices.py +++ b/tensor2tensor/utils/devices.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Device placement and data parallelism.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/diet.py b/tensor2tensor/utils/diet.py index df78c77a1..2efaaa0b8 100644 --- a/tensor2tensor/utils/diet.py +++ b/tensor2tensor/utils/diet.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Diet variables are much more memory-efficient than regular variables. Using diet variables, we can reduce memory overhead per parameter from diff --git a/tensor2tensor/utils/diet_test.py b/tensor2tensor/utils/diet_test.py index 4e9202440..93fcaefa5 100644 --- a/tensor2tensor/utils/diet_test.py +++ b/tensor2tensor/utils/diet_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for common layers.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index a83616500..3cfa3592d 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for creating Sparsely-Gated Mixture-of-Experts Layers. See "Outrageously Large Neural Networks" diff --git a/tensor2tensor/utils/expert_utils_test.py b/tensor2tensor/utils/expert_utils_test.py index eccad731b..b90edec8c 100644 --- a/tensor2tensor/utils/expert_utils_test.py +++ b/tensor2tensor/utils/expert_utils_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.utils.expert_utils.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/flags.py b/tensor2tensor/utils/flags.py index 20827e69a..7cd9b3f3f 100644 --- a/tensor2tensor/utils/flags.py +++ b/tensor2tensor/utils/flags.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Common command-line flags.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/get_rouge.py b/tensor2tensor/utils/get_rouge.py index 50f0994d9..a56319154 100644 --- a/tensor2tensor/utils/get_rouge.py +++ b/tensor2tensor/utils/get_rouge.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Computing rouge scores using pyrouge.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/learning_rate.py b/tensor2tensor/utils/learning_rate.py index f7fe33100..ec08cca2c 100644 --- a/tensor2tensor/utils/learning_rate.py +++ b/tensor2tensor/utils/learning_rate.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Optimization.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py index 8375121d5..054912693 100644 --- a/tensor2tensor/utils/metrics.py +++ b/tensor2tensor/utils/metrics.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utils for metrics used in eval.""" from __future__ import absolute_import from __future__ import division @@ -353,16 +352,14 @@ def sigmoid_cross_entropy_one_hot(logits, labels, weights_fn=None): return cross_entropy, tf.constant(1.0) -def roc_auc(logits, - labels, - weights_fn=None): +def roc_auc(logits, labels, weights_fn=None): """Calculate ROC AUC. Requires binary classes. Args: - logits: Tensor of size [batch-size, o=1, p=1, num-classes] - labels: Tensor of size [batch-size, o=1, p=1, num-classes] + logits: Tensor of size [batch_size, 1, 1, num_classes] + labels: Tensor of size [batch_size, 1, 1, num_classes] weights_fn: Function that takes in labels and weighs examples (unused) Returns: ROC AUC (scalar), weights @@ -370,7 +367,7 @@ def roc_auc(logits, del weights_fn with tf.variable_scope("roc_auc", values=[logits, labels]): predictions = tf.argmax(logits, axis=-1) - _, auc = tf.metrics.auc(labels, predictions, curve='ROC') + _, auc = tf.metrics.auc(labels, predictions, curve="ROC") return auc, tf.constant(1.0) @@ -541,8 +538,8 @@ def metric_means(): Metrics.SIGMOID_RECALL_ONE_HOT: sigmoid_recall_one_hot, Metrics.SIGMOID_PRECISION_ONE_HOT: sigmoid_precision_one_hot, Metrics.SIGMOID_CROSS_ENTROPY_ONE_HOT: sigmoid_cross_entropy_one_hot, - Metrics.ROC_AUC: roc_auc, Metrics.SET_PRECISION: set_precision, Metrics.SET_RECALL: set_recall, + Metrics.ROC_AUC: roc_auc, Metrics.IMAGE_SUMMARY: image_summary, } diff --git a/tensor2tensor/utils/metrics_hook.py b/tensor2tensor/utils/metrics_hook.py index 78681df50..b87564e0c 100644 --- a/tensor2tensor/utils/metrics_hook.py +++ b/tensor2tensor/utils/metrics_hook.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Summary-based SessionRunHooks.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/metrics_hook_test.py b/tensor2tensor/utils/metrics_hook_test.py index 1350679a7..3eac23208 100644 --- a/tensor2tensor/utils/metrics_hook_test.py +++ b/tensor2tensor/utils/metrics_hook_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for metrics_hook.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/metrics_test.py b/tensor2tensor/utils/metrics_test.py index 96d4684de..a47379a65 100644 --- a/tensor2tensor/utils/metrics_test.py +++ b/tensor2tensor/utils/metrics_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.utils.metrics.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/modality.py b/tensor2tensor/utils/modality.py index 611d4ea04..c26154783 100644 --- a/tensor2tensor/utils/modality.py +++ b/tensor2tensor/utils/modality.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Modality base class - defines the bottom and top of the model.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/optimize.py b/tensor2tensor/utils/optimize.py index 86d8a3b7c..9395e2fa2 100644 --- a/tensor2tensor/utils/optimize.py +++ b/tensor2tensor/utils/optimize.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Optimization.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/quantization.py b/tensor2tensor/utils/quantization.py index 339bcf9f5..21215d91e 100644 --- a/tensor2tensor/utils/quantization.py +++ b/tensor2tensor/utils/quantization.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities related to using bfloat16 activations and/or parameters.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py index ef0a6cfc0..5fcc07e43 100644 --- a/tensor2tensor/utils/registry.py +++ b/tensor2tensor/utils/registry.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Registry for models, hyperparameter settings, problem types, and datasets. Define a new model by subclassing T2TModel and register it: diff --git a/tensor2tensor/utils/registry_test.py b/tensor2tensor/utils/registry_test.py index 9c7d9d8f2..b0c85027f 100644 --- a/tensor2tensor/utils/registry_test.py +++ b/tensor2tensor/utils/registry_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for tensor2tensor.registry.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/rouge.py b/tensor2tensor/utils/rouge.py index cb3c9af4b..7a956b5dd 100644 --- a/tensor2tensor/utils/rouge.py +++ b/tensor2tensor/utils/rouge.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # coding=utf-8 """ROUGE metric implementation. diff --git a/tensor2tensor/utils/rouge_test.py b/tensor2tensor/utils/rouge_test.py index a73fd309b..2760c4b42 100644 --- a/tensor2tensor/utils/rouge_test.py +++ b/tensor2tensor/utils/rouge_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for Rouge metric.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 50d036f33..f314c33bd 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """T2TModel Base Class.""" from __future__ import absolute_import from __future__ import division @@ -289,10 +288,11 @@ def bottom(self, features): transformed_features[k] = v.targets_bottom(features[k]) else: with tf.variable_scope(target_modality.name): - log_info("Transforming 'targets' with %s.targets_bottom", - target_modality.name) - transformed_features["targets"] = target_modality.targets_bottom( - features["targets"]) + if "targets" in features: + log_info("Transforming 'targets' with %s.targets_bottom", + target_modality.name) + transformed_features["targets"] = target_modality.targets_bottom( + features["targets"]) for key in features: if key not in transformed_features: @@ -333,7 +333,7 @@ def _top_single(self, body_output, target_modality, features): self.hparams.mode == tf.estimator.ModeKeys.PREDICT and not self.hparams.force_full_predict) if not last_only: - logits = target_modality.top(body_output, features["targets"]) + logits = target_modality.top(body_output, features.get("targets")) else: # Take body outputs for the last position only, and targets too. last_position_body_output = tf.expand_dims( @@ -503,7 +503,8 @@ def infer(self, decode_length=50, beam_size=1, top_beams=1, - alpha=0.0): + alpha=0.0, + use_tpu=False): """A inference method. Quadratic time in decode_length. @@ -515,6 +516,7 @@ def infer(self, top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for longer translations. + use_tpu: bool, whether to build the inference graph for TPU. Returns: A dict of decoding results { @@ -955,8 +957,7 @@ def estimator_model_fn(cls, # PREDICT mode if mode == tf.estimator.ModeKeys.PREDICT: - assert not use_tpu - return model.estimator_spec_predict(features) + return model.estimator_spec_predict(features, use_tpu=use_tpu) # TRAIN and EVAL modes if hparams.eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL: @@ -993,7 +994,7 @@ def estimator_model_fn(cls, tf.summary.scalar(loss_name, loss_val) # Accumulate losses - loss = sum(losses_dict.values()) + loss = sum(losses_dict[key] for key in sorted(losses_dict.keys())) # EVAL mode if mode == tf.estimator.ModeKeys.EVAL: @@ -1070,7 +1071,7 @@ def estimator_spec_eval(self, features, logits, labels, loss, losses_dict): eval_metric_ops=eval_metrics, loss=loss) - def estimator_spec_predict(self, features): + def estimator_spec_predict(self, features, use_tpu=False): """Construct EstimatorSpec for PREDICT mode.""" decode_hparams = self._decode_hparams infer_out = self.infer( @@ -1079,7 +1080,8 @@ def estimator_spec_predict(self, features): top_beams=(decode_hparams.beam_size if decode_hparams.return_beams else 1), alpha=decode_hparams.alpha, - decode_length=decode_hparams.extra_length) + decode_length=decode_hparams.extra_length, + use_tpu=use_tpu) if isinstance(infer_out, dict): outputs = infer_out["outputs"] scores = infer_out["scores"] @@ -1111,14 +1113,19 @@ def estimator_spec_predict(self, features): _remove_summaries() - return tf.estimator.EstimatorSpec( - tf.estimator.ModeKeys.PREDICT, - predictions=predictions, - export_outputs={ - tf.saved_model.signature_constants. - DEFAULT_SERVING_SIGNATURE_DEF_KEY: - tf.estimator.export.PredictOutput(export_out) - }) + if use_tpu: + return tf.contrib.tpu.TPUEstimatorSpec( + tf.estimator.ModeKeys.PREDICT, + predictions=predictions) + else: + return tf.estimator.EstimatorSpec( + tf.estimator.ModeKeys.PREDICT, + predictions=predictions, + export_outputs={ + tf.saved_model.signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY: + tf.estimator.export.PredictOutput(export_out) + }) def _normalize_body_output(self, body_out): if isinstance(body_out, tuple): diff --git a/tensor2tensor/utils/trainer_lib.py b/tensor2tensor/utils/trainer_lib.py index 9cfd1264a..ddc2da83e 100644 --- a/tensor2tensor/utils/trainer_lib.py +++ b/tensor2tensor/utils/trainer_lib.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Library for training. See t2t_trainer.py.""" from __future__ import absolute_import @@ -188,12 +187,16 @@ def create_estimator(model_name, problem = hparams.problem batch_size = (problem.tpu_batch_size_per_shard(hparams) * run_config.tpu_config.num_shards) + predict_batch_size = batch_size + if decode_hparams and decode_hparams.batch_size: + predict_batch_size = decode_hparams.batch_size return tf.contrib.tpu.TPUEstimator( model_fn=model_fn, model_dir=run_config.model_dir, config=run_config, train_batch_size=batch_size, - eval_batch_size=batch_size if "eval" in schedule else None) + eval_batch_size=batch_size if "eval" in schedule else None, + predict_batch_size=predict_batch_size) else: return tf.estimator.Estimator( model_fn=model_fn, model_dir=run_config.model_dir, config=run_config) diff --git a/tensor2tensor/utils/trainer_lib_test.py b/tensor2tensor/utils/trainer_lib_test.py index 6ae599721..271394fe1 100644 --- a/tensor2tensor/utils/trainer_lib_test.py +++ b/tensor2tensor/utils/trainer_lib_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for trainer_lib.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py index 4aab19ac9..367d98c1d 100644 --- a/tensor2tensor/utils/usr_dir.py +++ b/tensor2tensor/utils/usr_dir.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utility to load code from an external user-supplied directory.""" from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/utils/yellowfin.py b/tensor2tensor/utils/yellowfin.py index 8653d2baf..6e0252882 100644 --- a/tensor2tensor/utils/yellowfin.py +++ b/tensor2tensor/utils/yellowfin.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """YellowFin for TensorFlow. Thanks Jian Zhang: zjian [@] stanford [.] edu.""" from __future__ import absolute_import diff --git a/tensor2tensor/utils/yellowfin_test.py b/tensor2tensor/utils/yellowfin_test.py index 4d3a2584a..914068e41 100644 --- a/tensor2tensor/utils/yellowfin_test.py +++ b/tensor2tensor/utils/yellowfin_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """YellowFin Test Module for TensorFlow.""" from __future__ import absolute_import diff --git a/tensor2tensor/visualization/__init__.py b/tensor2tensor/visualization/__init__.py index dba7ece95..4bd418a74 100644 --- a/tensor2tensor/visualization/__init__.py +++ b/tensor2tensor/visualization/__init__.py @@ -12,3 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + diff --git a/tensor2tensor/visualization/attention.py b/tensor2tensor/visualization/attention.py index 56ece8154..abf9a0e57 100644 --- a/tensor2tensor/visualization/attention.py +++ b/tensor2tensor/visualization/attention.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Module for postprocessing and displaying transformer attentions. This module is designed to be called from an ipython notebook. diff --git a/tensor2tensor/visualization/visualization.py b/tensor2tensor/visualization/visualization.py index 119e7dbb3..c32b44295 100644 --- a/tensor2tensor/visualization/visualization.py +++ b/tensor2tensor/visualization/visualization.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Shared code for visualizing transformer attentions.""" from __future__ import absolute_import diff --git a/tensor2tensor/visualization/visualization_test.py b/tensor2tensor/visualization/visualization_test.py index c430c815e..c40204ed2 100644 --- a/tensor2tensor/visualization/visualization_test.py +++ b/tensor2tensor/visualization/visualization_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for visualization library. IF ANY OF THESE TESTS BREAK PLEASE UPDATE THE CODE IN THE VIZ NOTEBOOK