diff --git a/README.md b/README.md index 236d279c2..4e56d7855 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ You can chat with us and other users on with T2T announcements. Here is a one-command version that installs tensor2tensor, downloads the data, -trains an English-German translation model, and lets you use it interactively: +trains an English-German translation model, and evaluates it: ``` pip install tensor2tensor && t2t-trainer \ --generate_data \ @@ -37,7 +37,18 @@ pip install tensor2tensor && t2t-trainer \ --problems=translate_ende_wmt32k \ --model=transformer \ --hparams_set=transformer_base_single_gpu \ - --output_dir=~/t2t_train/base \ + --output_dir=~/t2t_train/base +``` + +You can decode from the model interactively: + +``` +t2t-decoder \ + --data_dir=~/t2t_data \ + --problems=translate_ende_wmt32k \ + --model=transformer \ + --hparams_set=transformer_base_single_gpu \ + --output_dir=~/t2t_train/base --decode_interactive ``` @@ -106,14 +117,12 @@ echo "Goodbye world" >> $DECODE_FILE BEAM_SIZE=4 ALPHA=0.6 -t2t-trainer \ +t2t-decoder \ --data_dir=$DATA_DIR \ --problems=$PROBLEM \ --model=$MODEL \ --hparams_set=$HPARAMS \ --output_dir=$TRAIN_DIR \ - --train_steps=0 \ - --eval_steps=0 \ --decode_beam_size=$BEAM_SIZE \ --decode_alpha=$ALPHA \ --decode_from_file=$DECODE_FILE diff --git a/docs/example_life.md b/docs/example_life.md new file mode 100644 index 000000000..2983f5077 --- /dev/null +++ b/docs/example_life.md @@ -0,0 +1,34 @@ +# T2T: Life of an Example + +[![PyPI +version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) +[![GitHub +Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues) +[![Contributions +welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) +[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) +[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) + +This document shows how a training example passes through the T2T pipeline, +and how all its parts are connected to work together. + +## The Life of an Example + +A training example passes through the following stages in T2T: +* raw input (text from command line or file) +* encoded input after the [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `encode`, usually a sparse tensor, e.g., a vector of `tf.int32`s +* batched input after the [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242), where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188), are grouped by their length and made into batches. +* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `bottom`. +* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542) +* back to sparse output through the [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `top`.
+* if decoding, back through [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `decode` to display on the screen. + +We go into these phases step by step below. + +## Feature Encoders + +TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) which is a dict of encoders that have `encode` and `decode` functions. + +## Modalities + +TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) which has `bottom` and `top` but also sharded versions and one for targets. diff --git a/docs/index.md b/docs/index.md index a5eeba137..9394809b3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,11 +1,4 @@ -# T2T: Tensor2Tensor Transformers - -Check us out on - -GitHub - - -. +# Tensor2Tensor Docs Index [![PyPI version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) @@ -16,8 +9,26 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) -See our -[README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/README.md) -for documentation. -More documentation and tutorials coming soon... +Welcome to Tensor2Tensor! + +Tensor2Tensor, or T2T for short, is a library we use to create, +investigate and deploy deep learning models. This page hosts our +documentation, from basic tutorials to full code documentation. + +## Basics + +* [Walkthrough: Install and Run](walkthrough.md) +* [Tutorial: Train on Your Data](new_problem.md) +* [Tutorial: Create Your Own Model](new_model.md) + +## Deep Dive + +* [Life of an Example](example_life.md): how all parts of T2T are connected and work together +* [Distributed Training](distributed_training.md) + +## Code documentation + +See our +[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md) +for now, code docs coming. diff --git a/docs/new_model.md b/docs/new_model.md new file mode 100644 index 000000000..5968c8325 --- /dev/null +++ b/docs/new_model.md @@ -0,0 +1,16 @@ +# T2T: Create Your Own Model + +[![PyPI +version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) +[![GitHub +Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues) +[![Contributions +welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) +[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) +[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) + +Here we show how to create your own model in T2T. + +## The T2TModel class + +TODO: complete. 
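+
+Until this section is written, here is a rough sketch of what a registered
+model can look like. The class name `MySimpleModel` and its single hidden
+layer are invented for this illustration; see `tensor2tensor/utils/t2t_model.py`
+and the models in `tensor2tensor/models/` for real examples.
+
+```python
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+@registry.register_model
+class MySimpleModel(t2t_model.T2TModel):
+  """Toy model: a single hidden layer over the already-embedded inputs."""
+
+  def model_fn_body(self, features):
+    # The input modality's `bottom` has already mapped features["inputs"] to
+    # a dense tensor of shape [batch, length, 1, hidden_size].
+    inputs = features["inputs"]
+    hparams = self._hparams
+    # Whatever is returned here is handed to the target modality's `top`,
+    # which maps it back to logits over the target vocabulary.
+    return tf.layers.dense(
+        inputs, hparams.hidden_size, activation=tf.nn.relu, name="hidden")
+```
+
+Once registered, the model should be selectable on the command line with
+`--model=my_simple_model`, the same way `--model=transformer` is used in the
+[walkthrough](walkthrough.md).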
diff --git a/docs/new_problem.md b/docs/new_problem.md new file mode 100644 index 000000000..c859c6eba --- /dev/null +++ b/docs/new_problem.md @@ -0,0 +1,240 @@ +# T2T: Train on Your Own Data + +[![PyPI +version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) +[![GitHub +Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues) +[![Contributions +welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) +[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) +[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) + +Let's add a new dataset together and train the transformer model. We'll be learning to define English words by training the transformer to "translate" between English words and their definitions on a character level. + +# About the Problem + +For each problem we want to tackle we create a new problem class and register it. Let's call our problem `Word2def`. + +Since many text2text problems share similar methods, there's already a class called `Text2TextProblem` that extends the base problem class, `Problem` (both found in `problem.py`). + +For our problem, we can go ahead and create the file `word2def.py` in the `data_generators` folder and add our new problem, `Word2def`, which extends `Text2TextProblem`. Let's also register it while we're at it so we can specify the problem through flags. + +```python +@registry.register_problem() +class Word2def(problem.Text2TextProblem): + """Problem spec for English word to dictionary definition.""" + pass # We'll fill in the required methods below. +``` + +We need to implement the following methods from `Text2TextProblem` in our new class: +* is_character_level +* targeted_vocab_size +* generator +* input_space_id +* target_space_id +* num_shards +* vocab_name +* use_subword_tokenizer + +Let's tackle them one by one: + +**input_space_id, target_space_id, is_character_level, targeted_vocab_size, use_subword_tokenizer**: + +SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These are things like EN_CHR (English character), EN_TOK (English token), AUDIO_WAV (audio waveform), IMAGE, DNA (genetic bases). The complete list can be found at `data_generators/problem.py` in the class `SpaceID`. + +Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to `True`, and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using a subword vocabulary, we don't need to give a `targeted_vocab_size`, and `use_subword_tokenizer` simply returns `False`. + +**vocab_name**: + +`vocab_name` will be used to name your vocabulary files. We can call ours `'vocab.word2def.en'`. + +**num_shards**: + +The number of shards to break data files into. + +```python +@registry.register_problem() +class Word2def(problem.Text2TextProblem): + """Problem spec for English word to dictionary definition.""" + @property + def is_character_level(self): + return True + + @property + def vocab_name(self): + return "vocab.word2def.en" + + @property + def input_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def target_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def num_shards(self): + return 100 + + @property + def use_subword_tokenizer(self): + return False +``` + +**generator**: + +We're almost done.
`generator` generates the training and evaluation data and stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully, several commonly used methods like `character_generator` and `token_generator` are already written in the file `wmt.py`. We will import `character_generator` and write: ```python + def generator(self, data_dir, tmp_dir, train): + character_vocab = text_encoder.ByteTextEncoder() + datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS + return character_generator(datasets[0], datasets[1], character_vocab, EOS) +``` + +Now our `word2def.py` file looks like this: + +```python +@registry.register_problem() +class Word2def(problem.Text2TextProblem): + """Problem spec for English word to dictionary definition.""" + @property + def is_character_level(self): + return True + + @property + def vocab_name(self): + return "vocab.word2def.en" + + def generator(self, data_dir, tmp_dir, train): + character_vocab = text_encoder.ByteTextEncoder() + datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS + return character_generator(datasets[0], datasets[1], character_vocab, EOS) + + @property + def input_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def target_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def num_shards(self): + return 100 + + @property + def use_subword_tokenizer(self): + return False +``` + +## Data: +Now we need to tell Tensor2Tensor where our data is located. + +I've gone ahead and split all words into a train and test set and saved them in files called `words_train.txt`, `words_test.txt`, +`definitions_train.txt`, and `definitions_test.txt` in a directory whose path we'll refer to as `LOCATION_OF_DATA`. Let's tell T2T where these files are: + +```python +# English Word2def datasets +_WORD2DEF_TRAIN_DATASETS = [ + LOCATION_OF_DATA+'words_train.txt', + LOCATION_OF_DATA+'definitions_train.txt' +] +_WORD2DEF_TEST_DATASETS = [ + LOCATION_OF_DATA+'words_test.txt', + LOCATION_OF_DATA+'definitions_test.txt' +] +``` + +## Putting it all together + +Now our `word2def.py` file looks like this (with the correct imports): +```python +""" Problem definition for word to dictionary definition.
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators.wmt import character_generator + +from tensor2tensor.utils import registry + +import tensorflow as tf + +FLAGS = tf.flags.FLAGS + +# End-of-sentence marker. +EOS = text_encoder.EOS_ID + +# English Word2def datasets +_WORD2DEF_TRAIN_DATASETS = [ + LOCATION_OF_DATA+'words_train.txt', + LOCATION_OF_DATA+'definitions_train.txt' +] + +_WORD2DEF_TEST_DATASETS = [ + LOCATION_OF_DATA+'words_test.txt', + LOCATION_OF_DATA+'definitions_test.txt' +] + +@registry.register_problem() +class Word2def(problem.Text2TextProblem): + """Problem spec for English word to dictionary definition.""" + @property + def is_character_level(self): + return True + + @property + def vocab_name(self): + return "vocab.word2def.en" + + def generator(self, data_dir, tmp_dir, train): + character_vocab = text_encoder.ByteTextEncoder() + datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS + return character_generator(datasets[0], datasets[1], character_vocab, EOS) + + @property + def input_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def target_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def num_shards(self): + return 100 + + @property + def use_subword_tokenizer(self): + return False + +``` + +# Hyperparameters +All hyperparameters inherit from `_default_hparams()` in `problem.py`. If you would like to customize them, add another method to the file `problem_hparams.py`. + +# Run the problem +Now that we've gotten our problem set up, let's train a model and generate definitions. + +We specify our problem name, the model, and hparams. +```bash +PROBLEM=word2def +MODEL=transformer +HPARAMS=transformer_base_single_gpu +``` + +The rest of the steps are as given in the [walkthrough](walkthrough.md); a consolidated command listing for this problem appears at the end of this page. + + +What if we wanted to train a model to generate words given definitions? In T2T, we can change the problem name to be `PROBLEM=word2def_rev`. + +All done. Let us know what definitions your model generated.
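+
+For reference, here is the full command sequence from the
+[walkthrough](walkthrough.md) with this problem substituted in. It assumes
+`word2def.py` sits in `tensor2tensor/data_generators/` and is imported in
+`all_problems.py` (like the other generators), so that the registration above
+actually runs:
+
+```bash
+PROBLEM=word2def
+MODEL=transformer
+HPARAMS=transformer_base_single_gpu
+
+DATA_DIR=$HOME/t2t_data
+TMP_DIR=/tmp/t2t_datagen
+TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
+
+mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
+
+# Generate data
+t2t-datagen \
+  --data_dir=$DATA_DIR \
+  --tmp_dir=$TMP_DIR \
+  --problem=$PROBLEM
+
+# Train
+t2t-trainer \
+  --data_dir=$DATA_DIR \
+  --problems=$PROBLEM \
+  --model=$MODEL \
+  --hparams_set=$HPARAMS \
+  --output_dir=$TRAIN_DIR
+```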
diff --git a/docs/walkthrough.md b/docs/walkthrough.md new file mode 100644 index 000000000..57d7a03f4 --- /dev/null +++ b/docs/walkthrough.md @@ -0,0 +1,129 @@ +# T2T Install and Run Walkthrough + +[![PyPI +version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) +[![GitHub +Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues) +[![Contributions +welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) +[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) +[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) + +Here is a one-command version that installs tensor2tensor, downloads the data, +trains an English-German translation model, and evaluates it: +``` +pip install tensor2tensor && t2t-trainer \ + --generate_data \ + --data_dir=~/t2t_data \ + --problems=translate_ende_wmt32k \ + --model=transformer \ + --hparams_set=transformer_base_single_gpu \ + --output_dir=~/t2t_train/base +``` + +You can decode from the model interactively: + +``` +t2t-decoder \ + --data_dir=~/t2t_data \ + --problems=translate_ende_wmt32k \ + --model=transformer \ + --hparams_set=transformer_base_single_gpu \ + --output_dir=~/t2t_train/base + --decode_interactive +``` + +## Walkthrough + +Here's a walkthrough training a good English-to-German translation +model using the Transformer model from [*Attention Is All You +Need*](https://arxiv.org/abs/1706.03762) on WMT data. + +``` +pip install tensor2tensor + +# See what problems, models, and hyperparameter sets are available. +# You can easily swap between them (and add new ones). +t2t-trainer --registry_help + +PROBLEM=translate_ende_wmt32k +MODEL=transformer +HPARAMS=transformer_base_single_gpu + +DATA_DIR=$HOME/t2t_data +TMP_DIR=/tmp/t2t_datagen +TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS + +mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR + +# Generate data +t2t-datagen \ + --data_dir=$DATA_DIR \ + --tmp_dir=$TMP_DIR \ + --problem=$PROBLEM + +# Train +# * If you run out of memory, add --hparams='batch_size=1024'. 
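+# * Checkpoints and summaries are written to $TRAIN_DIR; point TensorBoard at
+#   that directory to monitor training.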
+t2t-trainer \ + --data_dir=$DATA_DIR \ + --problems=$PROBLEM \ + --model=$MODEL \ + --hparams_set=$HPARAMS \ + --output_dir=$TRAIN_DIR + +# Decode + +DECODE_FILE=$DATA_DIR/decode_this.txt +echo "Hello world" >> $DECODE_FILE +echo "Goodbye world" >> $DECODE_FILE + +BEAM_SIZE=4 +ALPHA=0.6 + +t2t-trainer \ + --data_dir=$DATA_DIR \ + --problems=$PROBLEM \ + --model=$MODEL \ + --hparams_set=$HPARAMS \ + --output_dir=$TRAIN_DIR \ + --train_steps=0 \ + --eval_steps=0 \ + --decode_beam_size=$BEAM_SIZE \ + --decode_alpha=$ALPHA \ + --decode_from_file=$DECODE_FILE + +cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes +``` + +--- + +## Installation + +``` +# Assumes tensorflow or tensorflow-gpu installed +pip install tensor2tensor + +# Installs with tensorflow-gpu requirement +pip install tensor2tensor[tensorflow_gpu] + +# Installs with tensorflow (cpu) requirement +pip install tensor2tensor[tensorflow] +``` + +Binaries: + +``` +# Data generator +t2t-datagen + +# Trainer +t2t-trainer --registry_help +``` + +Library usage: + +``` +python -c "from tensor2tensor.models.transformer import Transformer" +``` + +--- diff --git a/setup.py b/setup.py index f32e8508c..b51070c77 100644 --- a/setup.py +++ b/setup.py @@ -5,17 +5,24 @@ setup( name='tensor2tensor', - version='1.2.0', + version='1.2.1', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', url='http://github.com/tensorflow/tensor2tensor', license='Apache 2.0', packages=find_packages(), - package_data={'tensor2tensor.data_generators': ['test_data/*']}, + package_data={ + 'tensor2tensor.data_generators': ['test_data/*'], + 'tensor2tensor.visualization': [ + 'attention.js', + 'TransformerVisualization.ipynb' + ], + }, scripts=[ 'tensor2tensor/bin/t2t-trainer', 'tensor2tensor/bin/t2t-datagen', + 'tensor2tensor/bin/t2t-decoder', 'tensor2tensor/bin/t2t-make-tf-configs', ], install_requires=[ diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index f7ea7e1f2..cb6253524 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -42,7 +42,6 @@ from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing @@ -92,19 +91,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9), lambda: wsj_parsing.parsing_token_generator( FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)), - "translate_ende_wmt_bpe32k": ( - lambda: wmt.ende_bpe_token_generator( - FLAGS.data_dir, FLAGS.tmp_dir, True), - lambda: wmt.ende_bpe_token_generator( - FLAGS.data_dir, FLAGS.tmp_dir, False)), - "languagemodel_1b32k": ( - lambda: lm1b.generator(FLAGS.tmp_dir, True), - lambda: lm1b.generator(FLAGS.tmp_dir, False) - ), - "languagemodel_1b_characters": ( - lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True), - lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True) - ), "inference_snli32k": ( lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15), lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15), diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder new file mode 100644 index 000000000..5c3eeb293 --- /dev/null +++ 
b/tensor2tensor/bin/t2t-decoder @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Decode from trained T2T models. + +This binary performs inference using the Estimator API. + +Example usage to decode from dataset: + + t2t-decoder \ + --data_dir ~/data \ + --problems=algorithmic_identity_binary40 \ + --model=transformer + --hparams_set=transformer_base + +Set FLAGS.decode_interactive or FLAGS.decode_from_file for alternative decode +sources. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +from tensor2tensor.utils import decoding +from tensor2tensor.utils import trainer_utils +from tensor2tensor.utils import usr_dir + +import tensorflow as tf + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("t2t_usr_dir", "", + "Path to a Python module that will be imported. The " + "__init__.py file should include the necessary imports. " + "The imported files should contain registrations, " + "e.g. @registry.register_model calls, that will then be " + "available to the t2t-decoder.") + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) + trainer_utils.log_registry() + trainer_utils.validate_flags() + data_dir = os.path.expanduser(FLAGS.data_dir) + output_dir = os.path.expanduser(FLAGS.output_dir) + + hparams = trainer_utils.create_hparams( + FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams) + estimator, _ = trainer_utils.create_experiment_components( + hparams=hparams, + output_dir=output_dir, + data_dir=data_dir, + model_name=FLAGS.model) + + if FLAGS.decode_interactive: + decoding.decode_interactively(estimator) + elif FLAGS.decode_from_file: + decoding.decode_from_file(estimator, FLAGS.decode_from_file) + else: + decoding.decode_from_dataset( + estimator, + FLAGS.problems.split("-"), + return_beams=FLAGS.decode_return_beams, + beam_size=FLAGS.decode_beam_size, + max_predictions=FLAGS.decode_num_samples, + decode_to_file=FLAGS.decode_to_file, + save_images=FLAGS.decode_save_images, + identity_output=FLAGS.identity_output) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index ec3a9d0af..f9afa895b 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -23,9 +23,11 @@ from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import cipher +from tensor2tensor.data_generators import cnn_dailymail from tensor2tensor.data_generators import desc2code from tensor2tensor.data_generators import ice_parsing from tensor2tensor.data_generators import image +from tensor2tensor.data_generators import imdb from tensor2tensor.data_generators import lm1b from 
tensor2tensor.data_generators import ptb from tensor2tensor.data_generators import snli diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py new file mode 100644 index 000000000..db4deae4e --- /dev/null +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the CNN and Daily Mail datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tarfile + +# Dependency imports + +import six +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import registry + +import tensorflow as tf + + +# Links to data from http://cs.nyu.edu/~kcho/DMQA/ +_CNN_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ" + +_DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs" + + +# End-of-sentence marker. +EOS = text_encoder.EOS_ID + + +def _maybe_download_corpora(tmp_dir): + """Download corpora if necessary and unzip them. + + Args: + tmp_dir: directory containing dataset. + + Returns: + filepath of the downloaded corpus file. + """ + cnn_filename = "cnn_stories.tgz" + dailymail_filename = "dailymail_stories.tgz" + cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/") + dailymail_finalpath = os.path.join(tmp_dir, "dailymail/stories/") + if not tf.gfile.Exists(cnn_finalpath): + cnn_file = generator_utils.maybe_download_from_drive( + tmp_dir, cnn_filename, _CNN_STORIES_DRIVE_URL) + with tarfile.open(cnn_file, "r:gz") as cnn_tar: + cnn_tar.extractall(tmp_dir) + if not tf.gfile.Exists(dailymail_finalpath): + dailymail_file = generator_utils.maybe_download_from_drive( + tmp_dir, dailymail_filename, _CNN_STORIES_DRIVE_URL) + with tarfile.open(dailymail_file, "r:gz") as dailymail_tar: + dailymail_tar.extractall(tmp_dir) + return [cnn_finalpath, dailymail_finalpath] + + +def story_generator(tmp_dir): + paths = _maybe_download_corpora(tmp_dir) + for path in paths: + for story_file in tf.gfile.Glob(path + "*"): + story = u"" + for line in tf.gfile.Open(story_file): + line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8") + story += line + yield story + + +def _story_summary_split(story): + end_pos = story.find("\n\n") # Upto first empty line. 
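+  # Everything before the first blank line is used as the target summary;
+  # the remainder of the story becomes the input (see the generator below).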
+ assert end_pos != -1 + return story[:end_pos], story[end_pos:].strip() + + +@registry.register_problem +class SummarizeCnnDailymail32k(problem.Text2TextProblem): + """Summarize CNN and Daily Mail articles to their first paragraph.""" + + @property + def is_character_level(self): + return False + + @property + def has_inputs(self): + return True + + @property + def input_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def target_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def num_shards(self): + return 100 + + @property + def vocab_name(self): + return "vocab.cnndailymail" + + @property + def use_subword_tokenizer(self): + return True + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + @property + def use_train_shards_for_dev(self): + return True + + def generator(self, data_dir, tmp_dir, _): + encoder = generator_utils.get_or_generate_vocab_inner( + data_dir, self.vocab_file, self.targeted_vocab_size, + lambda: story_generator(tmp_dir)) + for story in story_generator(tmp_dir): + summary, rest = _story_summary_split(story) + encoded_summary = encoder.encode(summary) + [EOS] + encoded_story = encoder.encode(rest) + [EOS] + yield {"inputs": encoded_story, "targets": encoded_summary} diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 71f4f0920..fbe91d70e 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -272,7 +272,8 @@ def hparams(self, defaults, model_hparams): small_modality = "%s:small_image_modality" % registry.Modalities.IMAGE modality = small_modality if self.is_small else registry.Modalities.IMAGE p.input_modality = {"inputs": (modality, None)} - p.target_modality = (registry.Modalities.CLASS_LABEL, self.num_classes) + p.target_modality = ("%s:2d" % registry.Modalities.CLASS_LABEL, + self.num_classes) p.batch_size_multiplier = 4 if self.is_small else 256 p.max_expected_batch_size_per_shard = 8 if self.is_small else 2 p.loss_multiplier = 3.0 if self.is_small else 1.0 diff --git a/tensor2tensor/data_generators/imdb.py b/tensor2tensor/data_generators/imdb.py new file mode 100644 index 000000000..281a03bee --- /dev/null +++ b/tensor2tensor/data_generators/imdb.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IMDB Sentiment Classification Problem.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tarfile + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import registry + +import tensorflow as tf + +# End-of-sentence marker. 
+EOS = text_encoder.EOS_ID + + +@registry.register_problem +class SentimentIMDB(problem.Problem): + """IMDB sentiment classification.""" + URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" + + @property + def num_shards(self): + return 10 + + @property + def vocab_file(self): + return "sentiment_imdb.vocab" + + @property + def targeted_vocab_size(self): + return 2**13 # 8k vocab suffices for this small dataset. + + def doc_generator(self, imdb_dir, dataset, include_label=False): + dirs = [(os.path.join(imdb_dir, dataset, "pos"), True), (os.path.join( + imdb_dir, dataset, "neg"), False)] + + for d, label in dirs: + for filename in os.listdir(d): + with tf.gfile.Open(os.path.join(d, filename)) as imdb_f: + doc = imdb_f.read().strip() + if include_label: + yield doc, label + else: + yield doc + + def generator(self, data_dir, tmp_dir, train): + """Generate examples.""" + # Download and extract + compressed_filename = os.path.basename(self.URL) + download_path = generator_utils.maybe_download(tmp_dir, compressed_filename, + self.URL) + imdb_dir = os.path.join(tmp_dir, "aclImdb") + if not tf.gfile.Exists(imdb_dir): + with tarfile.open(download_path, "r:gz") as tar: + tar.extractall(tmp_dir) + + # Generate vocab + encoder = generator_utils.get_or_generate_vocab_inner( + data_dir, self.vocab_file, self.targeted_vocab_size, + lambda: self.doc_generator(imdb_dir, "train")) + + # Generate examples + dataset = "train" if train else "test" + for doc, label in self.doc_generator(imdb_dir, dataset, include_label=True): + yield { + "inputs": encoder.encode(doc) + [EOS], + "targets": [int(label)], + } + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + train_paths = self.training_filepaths( + data_dir, self.num_shards, shuffled=False) + dev_paths = self.dev_filepaths(data_dir, 1, shuffled=False) + generator_utils.generate_dataset_and_shuffle( + self.generator(data_dir, tmp_dir, True), train_paths, + self.generator(data_dir, tmp_dir, False), dev_paths) + + def hparams(self, defaults, model_hparams): + p = defaults + source_vocab_size = self._encoders["inputs"].vocab_size + p.input_modality = { + "inputs": (registry.Modalities.SYMBOL, source_vocab_size) + } + p.target_modality = (registry.Modalities.CLASS_LABEL, 2) + p.input_space_id = problem.SpaceID.EN_TOK + p.target_space_id = problem.SpaceID.GENERIC + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, self.vocab_file) + encoder = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": encoder, + "targets": text_encoder.TextEncoder(), + } + + def example_reading_spec(self): + data_fields = { + "inputs": tf.VarLenFeature(tf.int64), + "targets": tf.FixedLenFeature([1], tf.int64), + } + data_items_to_decoders = None + return (data_fields, data_items_to_decoders) diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py index a3771e124..d45e4fe1e 100644 --- a/tensor2tensor/data_generators/lm1b.py +++ b/tensor2tensor/data_generators/lm1b.py @@ -29,8 +29,10 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder from tensor2tensor.data_generators import tokenizer +from tensor2tensor.utils import registry import tensorflow as tf @@ -53,7 +55,7 @@ def _original_vocab(tmp_dir): """ vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/" "vocab-2016-09-10.txt") - 
vocab_filename = os.path.basename(vocab_url) + vocab_filename = os.path.basename(vocab_url + ".en") vocab_filepath = os.path.join(tmp_dir, vocab_filename) if not os.path.exists(vocab_filepath): generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url) @@ -140,29 +142,80 @@ def _get_or_build_subword_text_encoder(tmp_dir): return ret -def generator(tmp_dir, train, characters=False): - """Generator for lm1b sentences. - - Args: - tmp_dir: a string. - train: a boolean. - characters: a boolean - - Yields: - A dictionary {"inputs": [0], "targets": []} - """ - _maybe_download_corpus(tmp_dir) - original_vocab = _original_vocab(tmp_dir) - files = (_train_data_filenames(tmp_dir) if train - else [_dev_data_filename(tmp_dir)]) - if characters: - encoder = text_encoder.ByteTextEncoder() - else: - encoder = _get_or_build_subword_text_encoder(tmp_dir) - for filepath in files: - tf.logging.info("filepath = %s", filepath) - for line in tf.gfile.Open(filepath): - tokens = encoder.encode( - _replace_oov(original_vocab, text_encoder.native_to_unicode(line))) - tokens.append(EOS) - yield {"inputs": [0], "targets": tokens} +@registry.register_problem +class LanguagemodelLm1b32k(problem.Text2TextProblem): + """A language model on the 1B words corpus.""" + + @property + def is_character_level(self): + return False + + @property + def has_inputs(self): + return True + + @property + def input_space_id(self): + # Ratio of dev tokens (including eos) to dev words (including eos) + # 176884 / 159658 = 1.107893; multiply ppx by this to compare results. + return problem.SpaceID.EN_TOK + + @property + def target_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def num_shards(self): + return 100 + + @property + def vocab_name(self): + return "vocab.lm1b.en" + + @property + def use_subword_tokenizer(self): + return True + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + @property + def use_train_shards_for_dev(self): + return True + + def generator(self, tmp_dir, train, characters=False): + """Generator for lm1b sentences. + + Args: + tmp_dir: a string. + train: a boolean. + characters: a boolean + + Yields: + A dictionary {"inputs": [0], "targets": []} + """ + _maybe_download_corpus(tmp_dir) + original_vocab = _original_vocab(tmp_dir) + files = (_train_data_filenames(tmp_dir) if train + else [_dev_data_filename(tmp_dir)]) + if characters: + encoder = text_encoder.ByteTextEncoder() + else: + encoder = _get_or_build_subword_text_encoder(tmp_dir) + for filepath in files: + tf.logging.info("filepath = %s", filepath) + for line in tf.gfile.Open(filepath): + tokens = encoder.encode( + _replace_oov(original_vocab, text_encoder.native_to_unicode(line))) + tokens.append(EOS) + yield {"inputs": [0], "targets": tokens} + + +@registry.register_problem +class LanguagemodelLm1bCharacters(LanguagemodelLm1b32k): + """A language model on the 1B words corpus, character level.""" + + @property + def is_character_level(self): + return True diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 63b835f38..e002329bc 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -147,8 +147,8 @@ def default_problem_hparams(): # Modalities used to map from input features to a space compatible with # chosen model architecture. One modality spec (which is a 2-tuple, # (modality_full_name, vocab_size)) per feature key. modality_full_name is - # a string type:name, e.g. 
class_label:class_label_2d. Leaving off the - # name uses the default modality for that type (e.g. class_label == + # a string type:name, e.g. class_label:2d. Leaving off the name uses the + # default modality for that type (e.g. class_label == # class_label:default). input_modality={}, @@ -267,103 +267,6 @@ def audio_timit_tokens(model_hparams, wrong_vocab_size): return p -def audio_wsj_characters(unused_model_hparams): - """English audio transcription benchmark.""" - p = default_problem_hparams() - p.input_modality = { - "inputs": (registry.Modalities.AUDIO, None), - } - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.batch_size_multiplier = 512 - p.loss_multiplier = 2.0 - p.input_space_id = 13 - p.target_space_id = 2 - return p - - -def audio_wsj_tokens(model_hparams, wrong_vocab_size): - """English audio transcription benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_vocab_size: a number used in the filename indicating the approximate - vocabulary size. This is not to be confused with the actual vocabulary - size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "vocab.endefr.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.AUDIO, None), - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": subtokenizer, - } - p.batch_size_multiplier = 512 - p.loss_multiplier = 2.0 - p.input_space_id = 12 - p.target_space_id = 3 - return p - - -def lm1b_32k(model_hparams): - """Billion-word language-modeling benchmark, 32k subword vocabulary.""" - p = default_problem_hparams() - # ratio of dev tokens (including eos) to dev words (including eos) - # 176884 / 159658 = 1.107893 - p.perplexity_exponent = 1.107893 - p.input_modality = {} - encoder = text_encoder.SubwordTextEncoder( - os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder")) - p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size) - p.vocabulary = {"targets": encoder} - p.target_space_id = 3 - return p - - -def lm1b_characters(unused_model_hparams): - """Billion-word language-modeling benchmark, 32k subword vocabulary.""" - p = default_problem_hparams() - # ratio of dev tokens (including eos) to dev words (including eos) - # 826189 / 159658 = 5.174742 - p.perplexity_exponent = 5.174742 - p.input_modality = {} - encoder = text_encoder.ByteTextEncoder() - p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size) - p.vocabulary = {"targets": encoder} - p.target_space_id = 2 - return p - - -def wmt_ende_bpe32k(model_hparams): - """English to German translation benchmark.""" - p = default_problem_hparams() - vocab_size = 40960 - modality_spec = (registry.Modalities.SYMBOL, vocab_size) - p.input_modality = {"inputs": modality_spec} - p.target_modality = modality_spec - # This vocab file must be present within the data directory. 
- vocab_filename = os.path.join(model_hparams.data_dir, "vocab.bpe.32000") - p.vocabulary = { - "inputs": text_encoder.TokenTextEncoder(vocab_filename=vocab_filename), - "targets": text_encoder.TokenTextEncoder(vocab_filename=vocab_filename), - } - p.loss_multiplier = 1.4 - p.input_space_id = 4 - p.target_space_id = 9 - return p - - def wmt_parsing_characters(model_hparams): """English to parse tree translation benchmark.""" del model_hparams # Unused. @@ -472,25 +375,11 @@ def img2img_imagenet(unused_model_hparams): lambda p: audio_timit_tokens(p, 2**13), "audio_timit_tokens_8k_test": lambda p: audio_timit_tokens(p, 2**13), - "audio_wsj_characters_tune": - audio_wsj_characters, - "audio_wsj_characters_test": - audio_wsj_characters, - "audio_wsj_tokens_8k_tune": - lambda p: audio_wsj_tokens(p, 2**13), - "audio_wsj_tokens_8k_test": - lambda p: audio_wsj_tokens(p, 2**13), - "languagemodel_1b_characters": - lm1b_characters, - "languagemodel_1b32k": - lm1b_32k, "parsing_english_ptb8k": lambda p: wmt_parsing_tokens(p, 2**13), "parsing_english_ptb16k": lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda p, "wsj", 2**14, 2**9), - "translate_ende_wmt_bpe32k": - wmt_ende_bpe32k, "img2img_imagenet": img2img_imagenet, } diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index c8a3bd1f9..ac9260cfa 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -161,6 +161,7 @@ def __init__(self, vocab_filename, reverse=False, vocab_list=None, + replace_oov=None, num_reserved_ids=NUM_RESERVED_TOKENS): """Initialize from a file or list, one token per line. @@ -176,10 +177,13 @@ def __init__(self, and decoding. vocab_list: If not None, a list of elements of the vocabulary. If this is not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when + encoding will be replaced by this string (which must be in vocab). num_reserved_ids: Number of IDs to save for reserved tokens like . """ super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse + self._replace_oov = replace_oov if vocab_filename: self._init_vocab_from_file(vocab_filename) else: @@ -188,7 +192,11 @@ def __init__(self, def encode(self, sentence): """Converts a space-separated string of tokens to a list of ids.""" - ret = [self._token_to_id[tok] for tok in sentence.strip().split()] + tokens = sentence.strip().split() + if self._replace_oov is not None: + tokens = [t if t in self._token_to_id else self._replace_oov + for t in tokens] + ret = [self._token_to_id[tok] for tok in tokens] return ret[::-1] if self._reverse else ret def decode(self, ids): @@ -639,19 +647,32 @@ def _init_alphabet_from_tokens(self, tokens): self._alphabet = {c for token in tokens for c in token} self._alphabet |= _ESCAPE_CHARS - def _load_from_file(self, filename): - """Load from a file. + def _load_from_file_object(self, f): + """Load from a file object. 
Args: - filename: filename to load vocabulary from + f: File object to load vocabulary from """ subtoken_strings = [] - with tf.gfile.Open(filename) as f: - for line in f: - subtoken_strings.append(native_to_unicode(line.strip()[1:-1])) + for line in f: + s = line.strip() + # Some vocab files wrap words in single quotes, but others don't + if ((s.startswith("'") and s.endswith("'")) or + (s.startswith("\"") and s.endswith("\""))): + s = s[1:-1] + subtoken_strings.append(native_to_unicode(s)) self._init_subtokens_from_list(subtoken_strings) self._init_alphabet_from_tokens(subtoken_strings) + def _load_from_file(self, filename): + """Load from a file. + + Args: + filename: Filename to load vocabulary from + """ + with tf.gfile.Open(filename) as f: + self._load_from_file_object(f) + def store_to_file(self, filename): with tf.gfile.Open(filename, "w") as f: for subtoken_string in self._all_subtoken_strings: diff --git a/tensor2tensor/data_generators/text_encoder_test.py b/tensor2tensor/data_generators/text_encoder_test.py index eadfcfb5e..c13078808 100644 --- a/tensor2tensor/data_generators/text_encoder_test.py +++ b/tensor2tensor/data_generators/text_encoder_test.py @@ -21,6 +21,7 @@ from __future__ import unicode_literals import collections +import io import os import shutil @@ -31,22 +32,30 @@ import tensorflow as tf +class NativeToUnicodeTest(tf.test.TestCase): + + def test_native_to_unicode(self): + s = r"foo bar" + self.assertIsInstance(text_encoder.native_to_unicode(s), unicode) + self.assertEqual(text_encoder.native_to_unicode(s), u"foo bar") + + class EscapeUnescapeTokenTest(tf.test.TestCase): def test_escape_token(self): escaped = text_encoder._escape_token( - 'Foo! Bar.\nunder_score back\\slash', - set('abcdefghijklmnopqrstuvwxyz .\n') | text_encoder._ESCAPE_CHARS) + "Foo! Bar.\nunder_score back\\slash", + set("abcdefghijklmnopqrstuvwxyz .\n") | text_encoder._ESCAPE_CHARS) self.assertEqual( - '\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_', escaped) + "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_", escaped) def test_unescape_token(self): unescaped = text_encoder._unescape_token( - '\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_') + "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_") self.assertEqual( - 'Foo! Bar.\nunder_score back\\slash', unescaped) + "Foo! Bar.\nunder_score back\\slash", unescaped) class TokenTextEncoderTest(tf.test.TestCase): @@ -54,7 +63,7 @@ class TokenTextEncoderTest(tf.test.TestCase): @classmethod def setUpClass(cls): """Make sure the test dir exists and is empty.""" - cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), 'encoder_test') + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") shutil.rmtree(cls.test_temp_dir, ignore_errors=True) os.mkdir(cls.test_temp_dir) @@ -65,8 +74,8 @@ def test_save_and_reload(self): that this test size be "large". """ - corpus = 'A B C D E F G H I J K L M N O P Q R S T U V W X Y Z' - vocab_filename = os.path.join(self.test_temp_dir, 'abc.vocab') + corpus = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z" + vocab_filename = os.path.join(self.test_temp_dir, "abc.vocab") # Make text encoder from a list and store vocab to fake filesystem. 
encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) @@ -80,7 +89,7 @@ def test_save_and_reload(self): def test_reserved_tokens_in_corpus(self): """Test that we handle reserved tokens appearing in the corpus.""" - corpus = 'A B {} D E F {} G {}'.format(text_encoder.EOS, + corpus = "A B {} D E F {} G {}".format(text_encoder.EOS, text_encoder.EOS, text_encoder.PAD) @@ -97,14 +106,14 @@ class SubwordTextEncoderTest(tf.test.TestCase): def test_encode_decode(self): corpus = ( - 'This is a corpus of text that provides a bunch of tokens from which ' - 'to build a vocabulary. It will be used when strings are encoded ' - 'with a TextEncoder subclass. The encoder was coded by a coder.') - token_counts = collections.Counter(corpus.split(' ')) - alphabet = set(corpus) ^ {' '} + "This is a corpus of text that provides a bunch of tokens from which " + "to build a vocabulary. It will be used when strings are encoded " + "with a TextEncoder subclass. The encoder was coded by a coder.") + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) ^ {" "} - original = 'This is a coded sentence encoded by the SubwordTextEncoder.' - token_counts.update(original.split(' ')) + original = "This is a coded sentence encoded by the SubwordTextEncoder." + token_counts.update(original.split(" ")) encoder = text_encoder.SubwordTextEncoder.build_to_target_size( 100, token_counts, 2, 10) @@ -118,31 +127,31 @@ def test_encode_decode(self): # they should appear in the vocabulary even though they are substrings # of other included strings. subtoken_strings = {encoder._all_subtoken_strings[i] for i in encoded} - self.assertIn('encoded_', subtoken_strings) - self.assertIn('coded_', subtoken_strings) - self.assertIn('TextEncoder', encoder._all_subtoken_strings) - self.assertIn('coder', encoder._all_subtoken_strings) + self.assertIn("encoded_", subtoken_strings) + self.assertIn("coded_", subtoken_strings) + self.assertIn("TextEncoder", encoder._all_subtoken_strings) + self.assertIn("coder", encoder._all_subtoken_strings) - # Every character in the corpus should be in the encoder's alphabet and + # Every character in the corpus should be in the encoders alphabet and # its subtoken vocabulary. self.assertTrue(alphabet.issubset(encoder._alphabet)) for a in alphabet: self.assertIn(a, encoder._all_subtoken_strings) def test_unicode(self): - corpus = 'Cat emoticons. \U0001F638 \U0001F639 \U0001F63A \U0001F63B' - token_counts = collections.Counter(corpus.split(' ')) + corpus = "Cat emoticons. 
\U0001F638 \U0001F639 \U0001F63A \U0001F63B" + token_counts = collections.Counter(corpus.split(" ")) encoder = text_encoder.SubwordTextEncoder.build_to_target_size( 100, token_counts, 2, 10) - self.assertIn('\U0001F638', encoder._alphabet) - self.assertIn('\U0001F63B', encoder._all_subtoken_strings) + self.assertIn("\U0001F638", encoder._alphabet) + self.assertIn("\U0001F63B", encoder._all_subtoken_strings) def test_small_vocab(self): - corpus = 'The quick brown fox jumps over the lazy dog' - token_counts = collections.Counter(corpus.split(' ')) - alphabet = set(corpus) ^ {' '} + corpus = "The quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) ^ {" "} encoder = text_encoder.SubwordTextEncoder.build_to_target_size( 10, token_counts, 2, 10) @@ -155,12 +164,12 @@ def test_small_vocab(self): self.assertIn(a, encoder._all_subtoken_strings) def test_encodable_when_not_in_alphabet(self): - corpus = 'the quick brown fox jumps over the lazy dog' - token_counts = collections.Counter(corpus.split(' ')) + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) encoder = text_encoder.SubwordTextEncoder.build_to_target_size( 100, token_counts, 2, 10) - original = 'This has UPPER CASE letters that are out of alphabet' + original = "This has UPPER CASE letters that are out of alphabet" # Early versions could have an infinite loop when breaking into subtokens # if there was any out-of-alphabet characters in the encoded string. @@ -168,24 +177,42 @@ def test_encodable_when_not_in_alphabet(self): decoded = encoder.decode(encoded) self.assertEqual(original, decoded) - encoded_str = ''.join(encoder._all_subtoken_strings[i] for i in encoded) - self.assertIn('\\84;', encoded_str) + encoded_str = "".join(encoder._all_subtoken_strings[i] for i in encoded) + self.assertIn("\\84;", encoded_str) - @mock.patch.object(text_encoder, '_ESCAPE_CHARS', new=set('\\_;13579')) + @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579")) def test_raises_exception_when_not_encodable(self): - corpus = 'the quick brown fox jumps over the lazy dog' - token_counts = collections.Counter(corpus.split(' ')) + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) # Deliberately exclude some required encoding chars from the alphabet # and token list, making some strings unencodable. encoder = text_encoder.SubwordTextEncoder.build_to_target_size( 100, token_counts, 2, 10) - original = 'This has UPPER CASE letters that are out of alphabet' + original = "This has UPPER CASE letters that are out of alphabet" # Previously there was a bug which produced an infinite loop in this case. 
with self.assertRaises(AssertionError): encoder.encode(original) - -if __name__ == '__main__': + def test_load_from_file(self): + # Test a vocab file with words not wrapped with single quotes + encoder = text_encoder.SubwordTextEncoder() + correct_vocab = ["the", "and", "of"] + vocab = io.StringIO("the\n" + "and\n" + "of\n") + encoder._load_from_file_object(vocab) + self.assertEqual(encoder._all_subtoken_strings, correct_vocab) + + # Test a vocab file with words wrapped in single quotes + encoder = text_encoder.SubwordTextEncoder() + vocab = io.StringIO("\"the\"\n" + "\"and\"\n" + "\"of\"\n") + encoder._load_from_file_object(vocab) + self.assertEqual(encoder._all_subtoken_strings, correct_vocab) + + +if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index d1c80f2e1..9610cb1d8 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -31,6 +31,7 @@ from tensor2tensor.data_generators import text_encoder from tensor2tensor.utils import registry +import tensorflow as tf # End-of-sentence marker. EOS = text_encoder.EOS_ID @@ -49,7 +50,7 @@ def _maybe_download_corpus(tmp_dir): "enwiki-20170620-pages-articles-multistream.xml.bz2") corpus_filename = os.path.basename(corpus_url) corpus_filepath = os.path.join(tmp_dir, corpus_filename) - if not os.path.exists(corpus_filepath): + if not tf.gfile.Exists(corpus_filepath): generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url) return corpus_filepath diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index 93fc27ac5..8d6cdae6f 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -305,17 +305,44 @@ def _get_wmt_ende_bpe_dataset(directory, filename): return train_path -def ende_bpe_token_generator(data_dir, tmp_dir, train): - """Instance of token generator for the WMT en->de task, training set.""" - dataset_path = ("train.tok.clean.bpe.32000" - if train else "newstest2013.tok.bpe.32000") - train_path = _get_wmt_ende_bpe_dataset(tmp_dir, dataset_path) - token_tmp_path = os.path.join(tmp_dir, "vocab.bpe.32000") - token_path = os.path.join(data_dir, "vocab.bpe.32000") - tf.gfile.Copy(token_tmp_path, token_path, overwrite=True) - token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path) - return token_generator(train_path + ".en", train_path + ".de", token_vocab, - EOS) +@registry.register_problem +class TranslateEndeWmtBpe32k(TranslateProblem): + """Problem spec for WMT En-De translation, BPE version.""" + + @property + def targeted_vocab_size(self): + return 32000 + + @property + def vocab_name(self): + return "vocab.bpe" + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, self.vocab_file) + encoder = text_encoder.TokenTextEncoder(vocab_filename, replace_oov="UNK") + return {"inputs": encoder, "targets": encoder} + + def generator(self, data_dir, tmp_dir, train): + """Instance of token generator for the WMT en->de task, training set.""" + dataset_path = ("train.tok.clean.bpe.32000" + if train else "newstest2013.tok.bpe.32000") + train_path = _get_wmt_ende_bpe_dataset(tmp_dir, dataset_path) + token_tmp_path = os.path.join(tmp_dir, self.vocab_file) + token_path = os.path.join(data_dir, self.vocab_file) + tf.gfile.Copy(token_tmp_path, token_path, overwrite=True) + with tf.gfile.GFile(token_path, mode="a") as f: + f.write("UNK\n") # Add UNK to the vocab. 
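+    # With replace_oov="UNK", any token missing from the BPE vocabulary is
+    # mapped to UNK at encode time instead of raising a KeyError.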
+ token_vocab = text_encoder.TokenTextEncoder(token_path, replace_oov="UNK") + return token_generator(train_path + ".en", train_path + ".de", + token_vocab, EOS) + + @property + def input_space_id(self): + return problem.SpaceID.EN_BPE_TOK + + @property + def target_space_id(self): + return problem.SpaceID.DE_BPE_TOK def _preprocess_sgm(line, is_sgm): diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index d69e68f80..253e9bee5 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -30,6 +30,8 @@ import tensorflow as tf +from tensorflow.python.framework import function + def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): """Adds a bunch of sinusoids of different frequencies to a Tensor. @@ -1008,6 +1010,8 @@ def self_attention_expert( depth = x.get_shape().as_list()[-1] length = tf.shape(batch_coordinate)[0] + tf.summary.scalar("batch_size", length, family="experts_stats_batch_size") + attention_kq_size = attention_kq_size or depth attention_v_size = attention_v_size or depth @@ -1059,6 +1063,7 @@ def local_expert_attention( loss_coef, attention_num_experts, train=True, + pad_remover=None, **kwargs ): """Attention using a mixture of experts. @@ -1073,6 +1078,7 @@ def local_expert_attention( loss_coef: a scalar. A multiplier for the expert loss attention_num_experts: The number of experts to use train: a boolean for the current mode + pad_remover (PadRemover): A util object containing the padding position **kwargs: Arguments to forward to self_attention_expert Returns: @@ -1093,4 +1099,153 @@ def local_expert_attention( loss_coef=loss_coef, pass_x=True, pass_gates=False, - additional_dispatch_params=additional_dispatch_params) + additional_dispatch_params=additional_dispatch_params, + pad_remover=pad_remover + ) + + +def scaled_dot_product_attention_simple(q, k, v, bias, name=None): + """scaled dot-product attention. One head. One spatial dimension. + + Args: + q: a Tensor with shape [batch, length_q, depth_k] + k: a Tensor with shape [batch, length_kv, depth_k] + v: a Tensor with shape [batch, length_kv, depth_v] + bias: optional Tensor broadcastable to [batch, length_q, length_kv] + name: an optional string + + Returns: + A Tensor. + """ + with tf.variable_scope( + name, default_name="scaled_dot_product_attention_simple"): + scalar = tf.rsqrt(tf.to_float(tf.shape(q)[2])) + logits = tf.matmul(q * scalar, k, transpose_b=True) + if bias is not None: + logits += bias + weights = tf.nn.softmax(logits, name="attention_weights") + return tf.matmul(weights, v) + + +_function_cache = {} + + +def multihead_self_attention_memory_efficient(x, + bias, + num_heads, + head_size=None, + epsilon=1e-6, + forget=True, + test_vars=None, + name=None): + """Multihead scaled-dot-product self-attention. + + Includes layer norm. + + Returns multihead-self-attention(layer_norm(x)) + + Computes one attention head at a time to avoid exhausting memory. + + If forget=True, then forget all forwards activations and recompute on + the backwards pass. + + Args: + x: a Tensor with shape [batch, length, input_size] + bias: an attention bias tensor broadcastable to [batch, 1, length, length] + num_heads: an integer + head_size: an optional integer - defaults to input_size/num_heads + epsilon: a float, for layer norm + forget: a boolean - forget forwards activations and recompute on backprop + test_vars: optional tuple of variables for testing purposes + name: an optional string + + Returns: + A Tensor. 
+ """ + io_size = x.get_shape().as_list()[-1] + if head_size is None: + assert io_size % num_heads == 0 + head_size = io_size / num_heads + + def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias): + """Forward function.""" + n = common_layers.layer_norm_compute_python( + x, epsilon, norm_scale, norm_bias) + wqkv_split = tf.unstack(wqkv, num=num_heads) + wo_split = tf.unstack(wo, num=num_heads) + y = 0 + for h in xrange(num_heads): + with tf.control_dependencies([y] if h > 0 else []): + combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME") + q, k, v = tf.split(combined, 3, axis=2) + o = scaled_dot_product_attention_simple(q, k, v, attention_bias) + y += tf.nn.conv1d(o, wo_split[h], 1, "SAME") + return y + + key = ("multihead_self_attention_memory_efficient %s %s" % + (num_heads, epsilon)) + if not forget: + forward_fn = forward_internal + elif key in _function_cache: + forward_fn = _function_cache[key] + else: + @function.Defun(compiled=True) + def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy): + with tf.control_dependencies([dy]): + n = common_layers.layer_norm_compute_python( + x, epsilon, norm_scale, norm_bias) + wqkv_split = tf.unstack(wqkv, num=num_heads) + wo_split = tf.unstack(wo, num=num_heads) + deps = [] + dwqkvs = [] + dwos = [] + dn = 0 + for h in xrange(num_heads): + with tf.control_dependencies(deps): + combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME") + q, k, v = tf.split(combined, 3, axis=2) + o = scaled_dot_product_attention_simple(q, k, v, attention_bias) + partial_y = tf.nn.conv1d(o, wo_split[h], 1, "SAME") + pdn, dwqkvh, dwoh = tf.gradients( + ys=[partial_y], + xs=[n, wqkv_split[h], wo_split[h]], + grad_ys=[dy]) + dn += pdn + dwqkvs.append(dwqkvh) + dwos.append(dwoh) + deps = [dn, dwqkvh, dwoh] + dwqkv = tf.stack(dwqkvs) + dwo = tf.stack(dwos) + with tf.control_dependencies(deps): + dx, dnorm_scale, dnorm_bias = tf.gradients( + ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn]) + return (dx, dwqkv, dwo, tf.zeros_like(attention_bias), + dnorm_scale, dnorm_bias) + + @function.Defun(grad_func=grad_fn, compiled=True, + separate_compiled_gradients=True) + def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias): + return forward_internal( + x, wqkv, wo, attention_bias, norm_scale, norm_bias) + _function_cache[key] = forward_fn + + if bias is not None: + bias = tf.squeeze(bias, 1) + with tf.variable_scope(name, default_name="multihead_attention", values=[x]): + # TODO(noam): it would be nice to save memory by casting x to float16 + # here, but this causes problems with the gradients. Figure out if there + # is a way to leave the gradients as float32. 
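# The weights created below keep an explicit head dimension: wqkv is
# [num_heads, 1, io_size, 3 * head_size] (the middle 1 is the conv1d filter
# width) and wo is [num_heads, 1, head_size, io_size], so forward_fn can
# unstack them and attend one head at a time, keeping only a single head's
# activations live at once.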
+ if test_vars is not None: + wqkv, wo, norm_scale, norm_bias = list(test_vars) + else: + wqkv = tf.get_variable( + "wqkv", [num_heads, 1, io_size, 3 * head_size], + initializer=tf.random_normal_initializer(stddev=io_size**-0.5)) + wo = tf.get_variable( + "wo", [num_heads, 1, head_size, io_size], + initializer=tf.random_normal_initializer( + stddev=(head_size * num_heads)**-0.5)) + norm_scale, norm_bias = common_layers.layer_norm_vars(io_size) + y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias) + y.set_shape(x.get_shape()) + return y diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index e49999fbb..6664bcc2d 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -23,6 +23,7 @@ import numpy as np from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_layers import tensorflow as tf @@ -117,6 +118,49 @@ def testLocalUnmaskedAttention2DMatchingBlockLength(self): res = session.run(a) self.assertEqual(res.shape, (5, 4, 25, 25, 16)) + def testMultiheadSelfAttentionMemoryEfficient(self): + num_heads = 4 + io_size = 16 + batch = 2 + length = 7 + head_size = 5 + x = np.random.rand(batch, length, io_size) + dy = np.random.rand(batch, length, io_size) + with self.test_session() as session: + x = tf.to_float(x) + dy = tf.to_float(dy) + bias = common_attention.attention_bias_lower_triangle(length) + wqkv = tf.get_variable( + "wqkv", [num_heads, 1, io_size, 3 * head_size], + initializer=tf.random_normal_initializer(stddev=io_size**-0.5)) + wo = tf.get_variable( + "wo", [num_heads, 1, head_size, io_size], + initializer=tf.random_normal_initializer( + stddev=(head_size * num_heads)**-0.5)) + norm_scale, norm_bias = common_layers.layer_norm_vars(io_size) + y = common_attention.multihead_self_attention_memory_efficient( + x, bias, num_heads, head_size=head_size, forget=False, + test_vars=(wqkv, wo, norm_scale, norm_bias)) + y_forget = common_attention.multihead_self_attention_memory_efficient( + x, bias, num_heads, head_size=head_size, forget=True, + test_vars=(wqkv, wo, norm_scale, norm_bias)) + dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients( + ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy]) + dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients( + ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy]) + session.run(tf.global_variables_initializer()) + (y, y_forget, + dx, dwqkv, dwo, dnorm_scale, dnorm_bias, + dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run( + [y, y_forget, + dx, dwqkv, dwo, dnorm_scale, dnorm_bias, + dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f]) + self.assertAllClose(y, y_forget) + self.assertAllClose(dwo, dwo_f) + self.assertAllClose(dwqkv, dwqkv_f) + self.assertAllClose(dnorm_scale, dnorm_scale_f) + self.assertAllClose(dnorm_bias, dnorm_bias_f) + self.assertAllClose(dx, dx_f) if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index d4751bb0d..2e33c9e94 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -33,14 +33,6 @@ def basic_params1(): """A set of basic hyperparameters.""" return tf.contrib.training.HParams( batch_size=4096, # in tokens per batch per gpu - # This flag controls the number of length buckets in the data reader. - # Too many buckets slows down data reading - this needs fixing. 
- # Too few buckets mean lots of wasted padding. - # If this value is 1, we have buckets with maximum lengths: - # [8, 12, 16, 24, 32, 48 ... (max_length or batch_size)] - # If this value is 2, we have buckets with maximum lengths: - # [8, 10, 12, 14, 16, 20, 24 ... (max_length or batch_size)] - batching_mantissa_bits=1, num_hidden_layers=4, kernel_height=3, kernel_width=1, @@ -98,9 +90,22 @@ def basic_params1(): # epsilon parameter to normalization function norm_epsilon=1e-6, symbol_modality_num_shards=16, - # setting the max length in a minibatch. 0 means default behavior, - # max_length = hparams.batch_size * length_multiplier + # During training, we drop sequences whose inputs or targets are longer + # than max_length. + # If max_length==0, we use hparams.batch_size instead. max_length=0, + # Maximum length in the smallest length bucket. Setting this + # flag too high will result in wasteful padding of short + # sequences. Due to some (hopefully) temporary hacks in the + # data reading and batching code, setting this flag too low + # results in a very long batch-shuffling queue. + # TODO(noam): change this once the Datasets API changes. + min_length_bucket=8, + # This flag controls the number of length buckets in the data + # reader. The buckets have maximum lengths from + # min_bucket_length to (max_length or batch_size), increasing + # (approximately) by factors of length_bucket_step. + length_bucket_step=1.1, # If set to True, drop sequences longer than max_length during eval. # This affects the validity of the evaluation metrics. eval_drop_long_sequences=int(False), diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index ad899bfbf..4b09e70cb 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -425,6 +425,15 @@ def conv_fn(inputs, filters, kernel_size, **kwargs): return conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs) +def layer_norm_vars(filters): + """Create Variables for layer norm.""" + scale = tf.get_variable( + "layer_norm_scale", [filters], initializer=tf.ones_initializer()) + bias = tf.get_variable( + "layer_norm_bias", [filters], initializer=tf.zeros_initializer()) + return scale, bias + + def layer_norm_compute_python(x, epsilon, scale, bias): """Layer norm raw computation.""" mean = tf.reduce_mean(x, axis=[-1], keep_dims=True) @@ -1773,7 +1782,7 @@ def smoothing_cross_entropy_factored_grad(op, dy): b = op.inputs[1] labels = op.inputs[2] confidence = op.inputs[3] - num_splits = 32 + num_splits = 16 vocab_size = tf.shape(b)[0] labels = approximate_split(labels, num_splits) a = approximate_split(a, num_splits) @@ -1817,7 +1826,7 @@ def smoothing_cross_entropy_factored(a, b, labels, confidence): Returns: A Tensor with shape [batch] """ - num_splits = 32 + num_splits = 16 vocab_size = tf.shape(b)[0] labels = approximate_split(labels, num_splits) a = approximate_split(a, num_splits) @@ -1957,3 +1966,113 @@ def identity(*args): id_out = identity(*(inputs + train_vars + outputs)) return id_out + + +_function_cache = {} + + +def conv_hidden_relu_memory_efficient(x, + filter_size, + epsilon=1e-6, + forget=True, + test_vars=None, + name=None): + """LayerNorm, Conv, ReLU, Conv. + + All convolutions have kernel size 1. + + returns conv(relu(conv(layer_norm(x)))) + + Args: + x: input Tensor with shape [batch, length, io_size] + filter_size: an integer - size of the hidden layer. 
+ epsilon: a float (for layer norm) + forget: a boolean - forget forwards activations and recompute on backprop + test_vars: optional tuple of variables for testing purposes + name: an optional string + + Returns: + a Tensor with shape [batch, length, io_size] + """ + io_size = x.get_shape().as_list()[-1] + + def forward_internal(x, f1, f2, scale, bias): + """Forward function.""" + # split batch-wise to avoid exhausting memory in cast the batch is large + # and the hidden layer is large. + num_splits = 4 + x_flat = tf.reshape(x, [-1, 1, tf.shape(x)[2]]) + xs = approximate_split(x_flat, num_splits) + ys = [] + for i in xrange(num_splits): + with tf.control_dependencies(ys[-1:]): + n = layer_norm_compute_python(xs[i], epsilon, scale, bias) + y = tf.nn.conv1d(n, f1, 1, "SAME") + y = tf.nn.relu(y) + y = tf.nn.conv1d(y, f2, 1, "SAME") + ys.append(y) + y = tf.concat(ys, 0) + y = tf.reshape(y, tf.shape(x)) + return y + key = ("conv_hidden_relu_memory_efficient %s" % epsilon) + if not forget: + forward_fn = forward_internal + elif key in _function_cache: + forward_fn = _function_cache[key] + else: + @function.Defun(compiled=True) + def grad_fn(x, f1, f2, scale, bias, dy): + with tf.control_dependencies([dy]): + num_splits = 4 + x_shape = tf.shape(x) + flat_shape = [-1, 1, x_shape[2]] + x = tf.reshape(x, flat_shape) + dy = tf.reshape(dy, flat_shape) + xs = approximate_split(x, num_splits) + dys = approximate_split(dy, num_splits) + dxs = [] + df1 = 0 + df2 = 0 + dscale = 0 + dbias = 0 + deps = [] + for i in xrange(num_splits): + with tf.control_dependencies(deps): + n = layer_norm_compute_python(xs[i], epsilon, scale, bias) + y = tf.nn.conv1d(n, f1, 1, "SAME") + y = tf.nn.relu(y) + y = tf.nn.conv1d(y, f2, 1, "SAME") + dxi, pdf1, pdf2, pdscale, pdbias = tf.gradients( + ys=[y], xs=[xs[i], f1, f2, scale, bias], grad_ys=[dys[i]]) + df1 += pdf1 + df2 += pdf2 + dscale += pdscale + dbias += pdbias + dxs.append(dxi) + deps = [dxi, df1, df2, dscale, dbias] + with tf.control_dependencies(deps): + dx = tf.concat(dxs, 0) + dx = tf.reshape(dx, x_shape) + return dx, df1, df2, dscale, dbias + + @function.Defun(grad_func=grad_fn, compiled=True, + separate_compiled_gradients=True) + def forward_fn(x, f1, f2, scale, bias): + return forward_internal(x, f1, f2, scale, bias) + + with tf.variable_scope(name, default_name="ffn2", values=[x]): + # TODO(noam): it would be nice to save memory by casting x to float16 + # here, but this causes problems with the gradients. Figure out if there + # is a way to leave the gradients as float32. 
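# f1 and f2 below are width-1 conv filters of shape [1, io_size, filter_size]
# and [1, filter_size, io_size], i.e. position-wise dense layers; together
# with the num_splits chunking in forward_internal this avoids materializing
# the full [batch * length, filter_size] hidden activations at once.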
+ if test_vars is not None: + f1, f2, scale, bias = list(test_vars) + else: + f1 = tf.get_variable("f1", [1, io_size, filter_size]) + f2 = tf.get_variable("f2", [1, filter_size, io_size]) + scale, bias = layer_norm_vars(io_size) + if forget: + y = forward_fn(x, f1, f2, scale, bias) + else: + y = forward_internal(x, f1, f2, scale, bias) + y.set_shape(x.get_shape()) + return y diff --git a/tensor2tensor/layers/common_layers_test.py b/tensor2tensor/layers/common_layers_test.py index 61023938f..d11f8ce2c 100644 --- a/tensor2tensor/layers/common_layers_test.py +++ b/tensor2tensor/layers/common_layers_test.py @@ -474,6 +474,43 @@ def testFactoredTensorImplicitConversion(self): out = session.run(d) self.assertEqual(out.shape, (3, 4, 6)) + def testConvHiddenReluMemoryEfficient(self): + batch = 3 + length = 23 + io_size = 16 + filter_size = 7 + x = np.random.rand(batch, length, io_size) + dy = np.random.rand(batch, length, io_size) + with self.test_session() as session: + x = tf.to_float(x) + dy = tf.to_float(dy) + f1 = tf.get_variable("f1", [1, io_size, filter_size]) + f2 = tf.get_variable("f2", [1, filter_size, io_size]) + norm_scale, norm_bias = common_layers.layer_norm_vars(io_size) + y = common_layers.conv_hidden_relu_memory_efficient( + x, filter_size, forget=False, + test_vars=(f1, f2, norm_scale, norm_bias)) + y_forget = common_layers.conv_hidden_relu_memory_efficient( + x, filter_size, forget=True, + test_vars=(f1, f2, norm_scale, norm_bias)) + dx, df1, df2, dnorm_scale, dnorm_bias = tf.gradients( + ys=[y], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy]) + dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f = tf.gradients( + ys=[y_forget], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy]) + session.run(tf.global_variables_initializer()) + (y, y_forget, + dx, df1, df2, dnorm_scale, dnorm_bias, + dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f) = session.run( + [y, y_forget, + dx, df1, df2, dnorm_scale, dnorm_bias, + dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f]) + self.assertAllClose(y, y_forget) + self.assertAllClose(df2, df2_f) + self.assertAllClose(df1, df1_f) + self.assertAllClose(dnorm_scale, dnorm_scale_f) + self.assertAllClose(dnorm_bias, dnorm_bias_f) + self.assertAllClose(dx, dx_f) + class FnWithCustomGradTest(tf.test.TestCase): diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 57652dbec..e03e6835e 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -361,9 +361,9 @@ def xnet_resblock(x, filters, res_relu, name): "compress_block_final") -@registry.register_class_label_modality("default") +@registry.register_class_label_modality("2d") class ClassLabelModality(modality.Modality): - """Used for label data.""" + """Used for label data; if is2d=True, uses Xception flow to logits.""" def __init__(self, model_hparams, vocab_size, is2d=True): super(ClassLabelModality, self).__init__(model_hparams, vocab_size) @@ -397,9 +397,11 @@ def targets_bottom(self, x): def top(self, body_output, _): """Transform inputs from model space to target space. - Perform the Xception "Exit flow", consisting of a single residual block and - two separable convolutional upscalings followed by global spatial average - pooling. + If instantiated with is2d=True, perform the Xception "Exit flow", consisting + of a single residual block and two separable convolutional upscalings + followed by global spatial average pooling. + + Otherwise, a single linear layer to logits. 
Args: body_output: A Tensor with shape [batch, ?, ?, body_output_size]. @@ -417,11 +419,12 @@ def top(self, body_output, _): spatial_dim = tf.to_int32(spatial_dim_float) x_depth = int(x.get_shape()[3]) x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth]) - x = common_layers.conv_block_downsample(x, self._kernel, self._strides, - self._padding) - x = tf.nn.relu(x) - x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) - res = common_layers.conv(x, self._vocab_size, (1, 1)) + x = common_layers.conv_block_downsample(x, self._kernel, self._strides, + self._padding) + x = tf.nn.relu(x) + x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) + + res = tf.layers.dense(x, self._vocab_size) return tf.expand_dims(res, 3) def loss(self, top_out, targets, weights_fn=common_layers.weights_all): @@ -431,7 +434,7 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all): top_out, targets, weights_fn=weights_fn) -@registry.register_class_label_modality("class_label_2d") +@registry.register_class_label_modality("default") class ClassLabel1DModality(ClassLabelModality): """Used for label data.""" diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 5bb63c303..3b72ea9c2 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -40,16 +40,18 @@ import tensorflow as tf -class AttentionMoeType(object): - NONE = "none" - LOCAL = "local" - GLOBAL = "global" +class AttentionType(object): + MULTIHEAD = "multihead" + LOCAL_EXPERTS = "local_experts" + GLOBAL_MOE = "global_experts" + MEMORY_EFFICIENT = "memory_efficient" @staticmethod def get_choices(): return [ - AttentionMoeType.NONE, - AttentionMoeType.LOCAL, + AttentionType.MULTIHEAD, + AttentionType.LOCAL_EXPERTS, + AttentionType.MEMORY_EFFICIENT, ] @@ -70,7 +72,7 @@ def preprocess(x): def postprocess(x, y): return dp(common_layers.layer_postprocess, x, y, hparams) - (decoder_input, decoder_self_attention_bias) = dp( + (decoder_input, decoder_self_attention_bias, pad_remover) = dp( attention_lm_moe_prepare_decoder, targets, hparams) x = dp(tf.nn.dropout, decoder_input, @@ -87,15 +89,15 @@ def _diet_expert(x): else: expert_fn = expert_utils.ffn_expert_fn( hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) + for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope( - "attention_{}".format(hparams.attention_moe_type)): - x = preprocess(x) - if hparams.attention_moe_type == AttentionMoeType.NONE: + "attention_{}".format(hparams.attention_type)): + if hparams.attention_type == AttentionType.MULTIHEAD: y = dp( common_attention.multihead_attention, - x, + preprocess(x), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, @@ -104,14 +106,23 @@ def _diet_expert(x): hparams.num_heads, hparams.attention_dropout, name="decoder_self_attention") - elif hparams.attention_moe_type == AttentionMoeType.LOCAL: + elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT: + assert hparams.layer_preprocess_sequence == "n" + y = dp( + common_attention.multihead_self_attention_memory_efficient, + x, + decoder_self_attention_bias, + hparams.num_heads, + name="decoder_self_attention") + elif hparams.attention_type == AttentionType.LOCAL_EXPERTS: y, loss = dp( common_attention.local_expert_attention, - x, + preprocess(x), k=2, - loss_coef=1e-2, + loss_coef=hparams.attention_load_balance, attention_num_experts=hparams.attention_num_experts, train=hparams.mode == 
tf.contrib.learn.ModeKeys.TRAIN, + pad_remover=pad_remover, mask_right=True, attention_kq_size=hparams.attention_kq_size, attention_v_size=hparams.attention_v_size) @@ -119,7 +130,7 @@ def _diet_expert(x): extra_loss += tf.add_n(loss) / dp.n else: raise ValueError("Only {} supported for now.".format( - AttentionMoeType.get_choices())) + AttentionType.get_choices())) x = postprocess(x, y) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers.split(","): @@ -134,6 +145,12 @@ def _diet_expert(x): k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) extra_loss += loss + elif hparams.memory_efficient_ffn: + assert hparams.layer_preprocess_sequence == "n" + y = dp( + common_layers.conv_hidden_relu_memory_efficient, + x, + hparams.filter_size) else: y = dp( common_layers.conv_hidden_relu, @@ -158,17 +175,22 @@ def attention_lm_moe_prepare_decoder(targets, hparams): decoder_input: a Tensor, bottom of decoder stack decoder_self_attention_bias: a Tensor, containing large negative values to implement masked attention and possibly biases for diagonal alignments + pad_remover (expert_utils.PadRemover): a util object to remove padding """ + targets_pad_mask = common_attention.embedding_to_padding(targets) + with tf.name_scope("pad_remover"): + pad_remover = expert_utils.PadRemover(targets_pad_mask) + if hparams.prepend_mode == "prepend_inputs_full_attention": - decoder_self_attention_bias = (common_attention.attention_bias_prepended( - common_attention.embedding_to_padding(targets))) + decoder_self_attention_bias = ( + common_attention.attention_bias_prepended(targets_pad_mask)) else: decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) decoder_input = common_layers.shift_left_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) - return (decoder_input, decoder_self_attention_bias) + return (decoder_input, decoder_self_attention_bias, pad_remover) @registry.register_hparams @@ -215,12 +237,15 @@ def attention_lm_moe_base(): hparams.add_hparam("pos", "timing") # timing, none hparams.add_hparam("moe_layers", "2") # comma separated list of layer numbers # moe params. local attention moe.
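# attention_kq_size and attention_v_size set the per-expert key/query and
# value projection depths, and attention_load_balance is forwarded as
# loss_coef to local_expert_attention, scaling the auxiliary loss that pushes
# the experts towards roughly equal utilization.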
- hparams.add_hparam("attention_moe_type", AttentionMoeType.NONE) + hparams.add_hparam("attention_type", AttentionType.MULTIHEAD) hparams.add_hparam("attention_num_experts", 16) # Key, query and value dimensions for the attention - hparams.add_hparam("attention_kq_size", 64) - hparams.add_hparam("attention_v_size", 64) + hparams.add_hparam("attention_kq_size", 128) + hparams.add_hparam("attention_v_size", 256) + # Loss coef for load balancing + hparams.add_hparam("attention_load_balance", 2e-2) hparams.add_hparam("diet_experts", int(False)) + hparams.add_hparam("memory_efficient_ffn", int(False)) return hparams @@ -228,9 +253,12 @@ def attention_lm_moe_base(): def attention_lm_moe_base_ae(): """Base model with attention expert.""" hparams = attention_lm_moe_base() - hparams.attention_moe_type = AttentionMoeType.LOCAL + hparams.attention_type = AttentionType.LOCAL_EXPERTS hparams.max_length = hparams.batch_size hparams.eval_drop_long_sequences = int(True) + hparams.batching_mantissa_bits = 2 # More buckets + hparams.learning_rate = 0.05 + hparams.learning_rate_warmup_steps = 10000 return hparams @@ -279,7 +307,7 @@ def attention_lm_attention_moe_tiny(): hparams.moe_layers = "" hparams.attention_num_experts = 128 hparams.filter_size = 8192 - hparams.attention_moe_type = AttentionMoeType.LOCAL + hparams.attention_type = AttentionType.LOCAL_EXPERTS return hparams @@ -335,6 +363,21 @@ def attention_lm_moe_large_diet(): return hparams +@registry.register_hparams +def attention_lm_moe_memory_efficient(): + """Memory-efficient version.""" + hparams = attention_lm_moe_large() + hparams.diet_experts = int(True) + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.layer_prepostprocess_dropout = 0.0 + hparams.memory_efficient_ffn = True + hparams.attention_type = AttentionType.MEMORY_EFFICIENT + hparams.num_heads = 8 + hparams.factored_logits = int(True) + return hparams + + @registry.register_hparams def attention_lm_moe_32b_diet(): """Unnecessarily large model with 32B params - because we can.""" diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 47db28c30..105d9eb32 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -378,7 +378,6 @@ def transformer_big(): hparams.hidden_size = 1024 hparams.filter_size = 4096 hparams.num_heads = 16 - hparams.batching_mantissa_bits = 2 hparams.layer_prepostprocess_dropout = 0.3 return hparams @@ -390,7 +389,6 @@ def transformer_big_single_gpu(): hparams.layer_prepostprocess_dropout = 0.1 hparams.learning_rate_warmup_steps = 16000 hparams.optimizer_adam_beta2 = 0.998 - hparams.batching_mantissa_bits = 3 return hparams @@ -400,7 +398,6 @@ def transformer_base_single_gpu(): hparams = transformer_base() hparams.batch_size = 2048 hparams.learning_rate_warmup_steps = 16000 - hparams.batching_mantissa_bits = 2 return hparams @@ -593,7 +590,6 @@ def transformer_big_dr1(): hparams.filter_size = 4096 hparams.num_heads = 16 hparams.layer_prepostprocess_dropout = 0.1 - hparams.batching_mantissa_bits = 2 return hparams diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index fa6b3f397..1c566e996 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""VAE Transformer.""" +"""AE Transformer.""" from __future__ import absolute_import from __future__ import division @@ -32,10 +32,9 @@ import tensorflow as tf -def residual_conv(x, repeat, hparams, name, reuse=None): +def residual_conv(x, repeat, k, hparams, name, reuse=None): """A stack of convolution blocks with residual connections.""" with tf.variable_scope(name, reuse=reuse): - k = (3, 1) dilations_and_kernels = [((1, 1), k) for _ in xrange(3)] for i in xrange(repeat): with tf.variable_scope("repeat_%d" % i): @@ -72,15 +71,19 @@ def interleave(x, y, axis=1): return tf.concat([x, y], axis=axis+1) -def decompress_step(source, c, hparams, first_relu, name): +def decompress_step(source, c, hparams, first_relu, is_2d, name): """Decompression function.""" with tf.variable_scope(name): shape = tf.shape(source) if c is not None: source = attend(source, c, hparams, "decompress_attend") + multiplier = 4 if is_2d else 2 + kernel = (1, 1) if is_2d else (1, 1) thicker = common_layers.conv_block( - source, hparams.hidden_size * 2, [((1, 1), (1, 1))], + source, hparams.hidden_size * multiplier, [((1, 1), kernel)], first_relu=first_relu, name="decompress_conv") + if is_2d: + return tf.depth_to_space(thicker, 2) return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size]) @@ -90,7 +93,7 @@ def gumbel_sample(shape): return -tf.log(-tf.log(uniform_samples)) -def dvae(x, hparams, name): +def dae(x, hparams, name): with tf.variable_scope(name): m = tf.layers.dense(x, hparams.v_size, name="mask") logsm = tf.nn.log_softmax(m) @@ -128,7 +131,7 @@ def nearest(x, means, hparams): _, nearest_idx = tf.nn.top_k(- dist, k=1) nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size) nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1], - 1, hparams.v_size]) + tf.shape(x)[2], hparams.v_size]) return tf.stop_gradient(nearest_hot) @@ -137,21 +140,23 @@ def kmeans(x, means, hparams, name): x_means_hot = nearest(x, means, hparams) x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1)) kl = tf.reduce_sum(tf.square(x - x_means), axis=-1) - return x_means_hot, x_means_hot, tf.reduce_mean(kl) * 100.0 + return x_means_hot, tf.reduce_mean(kl) * 10.0 -def compress(x, c, hparams, name): +def compress(x, c, is_2d, hparams, name): """Compress.""" with tf.variable_scope(name): # Run compression by strided convs. 
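# Each of the num_compress_steps iterations below halves the length dimension
# (and, when is_2d=True, the second spatial dimension as well) with a strided
# conv, so the compressed representation is 2**num_compress_steps times
# shorter along each compressed axis.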
cur = x + k1 = (3, 3) if is_2d else (3, 1) + k2 = (2, 2) if is_2d else (2, 1) for i in xrange(hparams.num_compress_steps): if c is not None: cur = attend(cur, c, hparams, "compress_attend_%d" % i) - cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i) + cur = residual_conv(cur, 1, k1, hparams, "compress_rc_%d" % i) cur = common_layers.conv_block( - cur, hparams.hidden_size, [((1, 1), (2, 1))], - strides=(2, 1), name="compress_%d" % i) + cur, hparams.hidden_size, [((1, 1), k2)], + strides=k2, name="compress_%d" % i) return cur @@ -188,7 +193,7 @@ def decode(cond_vec, cond_add, gold, c, ed, hparams): decoder_input = tf.squeeze(decoder_input, axis=2) decoder_input = common_attention.add_timing_signal_1d(decoder_input) bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1]) - if c is not None: + if c is not None and len(c.get_shape()) > 3: c = tf.squeeze(c, axis=2) return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams) @@ -205,62 +210,62 @@ def expand_batch(x, mul): return tf.reshape(cx, res_shape) -def vae_compress(x, c, ed, hparams, compress_name, decompress_name, reuse=None): - """Compress, then VAE.""" - with tf.variable_scope(compress_name, reuse=reuse): - cur = compress(x, None, hparams, "compress") +def ae_compress(x, is_2d, hparams, name, reuse=None): + """Compress, then AE.""" + with tf.variable_scope(name, reuse=reuse): + cur = compress(x, None, is_2d, hparams, "compress") # Convolve and ReLu to get state. cur = common_layers.conv_block( cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv") cur = tf.nn.l2_normalize(cur, dim=3) + cur_n = hparams.kmeans_lr_factor * cur + cur_n += (1.0 - hparams.kmeans_lr_factor) * tf.stop_gradient(cur) means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size]) - # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae") - # z_true, z_sample, kl_loss = dvae(cur, hparams, name="dvae") - z_true, z_sample, kl_loss = kmeans(cur, means, hparams, name="kmeans") - - # Compress context. - with tf.variable_scope(compress_name, reuse=reuse): - compress_c = compress(c, None, hparams, "compress_context") - dec_c = decode(None, compress_c, cur, None, None, hparams) - c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context") - reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits( - labels=z_true, logits=c_z) + hot, loss = kmeans(cur_n, means, hparams, name="kmeans") + # We need a linear layer to undo the l2-normalization. + cur = tf.layers.dense(cur, hparams.hidden_size, name="unnormalize") + return cur, hot, loss - # If not training, use the predicted z instead of the autoregressive one. - if hparams.mode == tf.contrib.learn.ModeKeys.INFER: - z = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size) - with tf.variable_scope(decompress_name, reuse=reuse): - # Decompress. 
- z_sample_flat = tf.reshape(z_sample, [-1, hparams.v_size]) - z = tf.matmul(z_sample_flat, means) - z = tf.reshape(z, [tf.shape(z_sample)[0], tf.shape(z_sample)[1], - 1, hparams.hidden_size]) +def ae_embed(hot, hparams, name, reuse=None): + with tf.variable_scope(name, reuse=reuse): + means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size]) + hot_flat = tf.reshape(hot, [-1, hparams.v_size]) + emb = tf.matmul(hot_flat, means) + emb = tf.reshape(emb, [tf.shape(hot)[0], tf.shape(hot)[1], + tf.shape(hot)[2], hparams.hidden_size]) + return tf.layers.dense(emb, hparams.hidden_size, + name="unnormalize", reuse=reuse) + +def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None): + """Decompress from z, leaking from ae.""" + with tf.variable_scope(name + "_decompress", reuse=reuse): # Leak at the beginning to help train. - z = mix(z, cur, hparams.startup_steps) + z = mix(z, ae, hparams.startup_steps) + prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8 + prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0 + z = tf.cond(tf.less(tf.random_uniform([]), prob_z), + lambda: z, lambda: ae) # Dropout for better autoencoding. - z = tf.nn.dropout(z, keep_prob=0.9) + z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout) # Decompress. d = z for i in xrange(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 - d = residual_conv(d, 1, hparams, "decompress_rc_%d" % j) - d = decompress_step(d, c, hparams, i > 0, "decompress_step_%d" % j) + d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j) + d = decompress_step(d, None, hparams, i > 0, is_2d, "decompress_%d" % j) k = 2**hparams.num_compress_steps z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size]) x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size]) d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size]) - # dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams) - c = expand_batch(c, tf.shape(x_batch)[0] / tf.shape(x)[0]) - ed = expand_batch(ed, tf.shape(x_batch)[0] / tf.shape(x)[0]) - dec_batch = decode(z_batch, d_batch, x_batch, c, ed, hparams) + dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams) z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size]) - return z, kl_loss, reconstruct_loss + return z def ffn(x, hparams, name): @@ -270,35 +275,42 @@ def ffn(x, hparams, name): return common_layers.layer_postprocess(x, y, hparams) -def vae_transformer_internal(inputs, targets, target_space, hparams): - """VAE Transformer, main step used for training.""" - with tf.variable_scope("vae_transformer"): - # Prepare inputs, targets, and k. - inputs = common_layers.flatten4d3d(inputs) - input_len = tf.shape(inputs)[1] # Double input size to cover targets. - inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]]) - inputs.set_shape([None, None, hparams.hidden_size]) - targets = common_layers.flatten4d3d(targets) +def ae_transformer_internal(inputs, targets, target_space, hparams): + """AE Transformer, main step used for training.""" + with tf.variable_scope("ae_transformer"): + # Prepare inputs, targets, k. k = 2**hparams.num_compress_steps - inputs, targets = common_layers.pad_to_same_length( - inputs, targets, final_length_divisible_by=k) - inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc") - - # Compress and vae. 
- z, kl, r = vae_compress(tf.expand_dims(targets, axis=2), - tf.expand_dims(inputs, axis=2), - ed_bias, hparams, "vae_compress", "vae_decompress") + _, targets = common_layers.pad_to_same_length( + targets, targets, final_length_divisible_by=k) + inputs = common_layers.flatten4d3d(inputs) + inputs, ed = encode(inputs, target_space, hparams, "input_enc") + + # Compress and ae. + ae, hot, kl = ae_compress(targets, False, hparams, "ae") + emb = ae_embed(hot, hparams, "ae", reuse=True) + + # Compress context and run autoregressive decoder on emb-hot. + dec_c = decode(None, None, emb, inputs, ed, hparams) + c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context") + reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits( + labels=hot, logits=c_z) + # If not training, use the predicted z instead of the autoregressive one. + if hparams.mode == tf.contrib.learn.ModeKeys.INFER: + hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size) + + # Decompress, pass for ae loss. + z = ae_decompress(emb, ae, targets, False, hparams, "ae") kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5)) - r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 2.0)) - losses = {"kl": kl, "reconstruction": r} + reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps) + losses = {"kl": kl, "reconstruction": reconstruct_loss} return z, losses @registry.register_model -class TransformerVAE(t2t_model.T2TModel): +class TransformerAE(t2t_model.T2TModel): def model_fn_body(self, features): - return vae_transformer_internal( + return ae_transformer_internal( features["inputs"], features["targets"], features["target_space_id"], self._hparams) @@ -341,7 +353,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, @registry.register_hparams -def transformer_vae_small(): +def transformer_ae_small(): """Set of hyperparameters.""" hparams = transformer.transformer_small() hparams.batch_size = 2048 @@ -351,13 +363,15 @@ def transformer_vae_small(): hparams.add_hparam("num_compress_steps", 4) hparams.add_hparam("kl_warmup_steps", 60000) hparams.add_hparam("startup_steps", 30000) + hparams.add_hparam("kmeans_lr_factor", 0.002) + hparams.add_hparam("z_dropout", 0.1) return hparams @registry.register_hparams -def transformer_vae_base(): +def transformer_ae_base(): """Set of hyperparameters.""" - hparams = transformer_vae_small() + hparams = transformer_ae_small() hparams.hidden_size = 512 hparams.filter_size = 2048 hparams.attention_dropout = 0.0 diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index dbbd8e936..d55911f19 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -18,8 +18,6 @@ from __future__ import division from __future__ import print_function -import fractions -import math import os import random @@ -271,10 +269,11 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, dataset = bucket_by_sequence_length(dataset, _example_length, batching_scheme["boundaries"], - batching_scheme["batch_sizes"]) - max_batch_size = max(batching_scheme["batch_sizes"]) + batching_scheme["batch_sizes"], + batching_scheme["window_size"]) # We reshuffle the batches to prevent many long-sequence batches at once. 
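# The shuffle buffer size is taken from the batching scheme: a few times the
# largest number of batches a single bucketing window can produce, so that
# batches coming out of one window get interleaved with batches from other
# windows.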
- dataset = dataset.shuffle(max_batch_size * 3) + if batching_scheme["shuffle_queue_size"] is not None: + dataset = dataset.shuffle(batching_scheme["shuffle_queue_size"]) batched_examples = dataset.make_one_shot_iterator().get_next() return batched_examples @@ -308,38 +307,8 @@ def _example_too_big(example, max_length): return tf.less_equal(_example_length(example), max_length) -def _lcm(l): - """Least common multiple of integers in a list.""" - if not l: - raise ValueError("LCD of an empty list.") - if len(l) == 1: - return l[0] - x = l[0] - y = _lcm(l[1:]) - return x * y // fractions.gcd(x, y) - - -def _closest_small_primes(x): - """Closest number to x which has only 2, 3, 5 as prime factors, 3,5 once.""" - assert x > 0 - def is_small_primes(x, covered3, covered5): - if x % 2 == 0: - return is_small_primes(x // 2, covered3, covered5) - if x % 3 == 0 and not covered3: - return is_small_primes(x // 3, True, covered5) - if x % 5 == 0 and not covered5: - return is_small_primes(x // 5, covered3, True) - return x == 1 - for i in xrange(x): - if is_small_primes(x - i, False, False): - return x - i - # We search for higher numbers too, but only 8 of them to not increase much. - if i < 9 and is_small_primes(x + i, False, False): - return x + i - - def bucket_by_sequence_length(dataset, example_length_fn, bucket_boundaries, - bucket_batch_sizes): + bucket_batch_sizes, window_size): """Bucket entries in dataset by length. Args: @@ -348,18 +317,11 @@ def bucket_by_sequence_length(dataset, example_length_fn, bucket_boundaries, the example, which will determine the bucket it goes into. bucket_boundaries: list, boundaries of the buckets. bucket_batch_sizes: list, batch size per bucket. + window_size: an integer divisible by all elements of bucket_batch_sizes Returns: Dataset of padded and batched examples. """ - # Since the Datasets API only allows a single constant for window_size, - # and it needs divide all bucket_batch_sizes, we first make sure they only - # have a few primes in them so that their LCM doesn't explode quickly. - # TODO(lukaszkaiser): remove this adjustment when Dataset API improves. - bucket_batch_sizes1 = [_closest_small_primes(b) for b in bucket_batch_sizes] - tf.logging.info("Corrected bucket_batch_sizes from %s to %s." 
- % (str(bucket_batch_sizes), str(bucket_batch_sizes1))) - bucket_batch_sizes = bucket_batch_sizes1 with tf.name_scope("bucket_by_seq_length"): def example_to_bucket_id(example): @@ -386,25 +348,27 @@ def batching_fn(bucket_id, grouped_dataset): for name, shape in grouped_dataset.output_shapes.items()]) return grouped_dataset.padded_batch(batch_size, padded_shapes) - window_size = _lcm(bucket_batch_sizes) dataset = dataset.group_by_window(example_to_bucket_id, batching_fn, window_size) return dataset -def _bucket_boundaries(max_length, min_length=8, mantissa_bits=2): +def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1): """A default set of length-bucket boundaries.""" + assert min_length <= max_length + assert length_bucket_step > 1.0 x = min_length boundaries = [] while x < max_length: boundaries.append(x) - x += 2**max(0, int(math.log(x, 2)) - mantissa_bits) + x = max(x + 1, int(x * length_bucket_step)) return boundaries -def _batching_scheme(batch_size=16 * 256, - max_length=None, - batching_mantissa_bits=1, +def _batching_scheme(batch_size, + max_length, + min_length_bucket, + length_bucket_step, drop_long_sequences=False, shard_multiplier=1, length_multiplier=1): @@ -416,7 +380,8 @@ def _batching_scheme(batch_size=16 * 256, batch_size: int, total number of tokens in a batch. max_length: int, sequences longer than this will be skipped. Defaults to batch_size. - batching_mantissa_bits: int, ??. + min_length_bucket: int, maximum length of the shortest length bucket. + length_bucket_step: float greater than 1.0; successive bucket boundaries grow by (approximately) this factor. drop_long_sequences: bool, if True, then sequences longer than `max_length` are dropped. This prevents generating batches with more than the usual number of tokens, which can cause out-of-memory @@ -434,19 +399,47 @@ """ max_length = max_length or batch_size boundaries = _bucket_boundaries( - max_length, mantissa_bits=batching_mantissa_bits) + max_length, min_length_bucket, length_bucket_step) boundaries = [boundary * length_multiplier for boundary in boundaries] max_length *= length_multiplier batch_sizes = [ - max(1, batch_size // length) * shard_multiplier - for length in boundaries + [max_length] + max(1, batch_size // length) for length in boundaries + [max_length] ] - return { + max_batch_size = max(batch_sizes) + # Since the Datasets API only allows a single constant for window_size, + # and it needs to divide all bucket_batch_sizes, we pick a highly-composite + # window size and then round down all batch sizes to divisors of that window + # size, so that a window can always be divided evenly into batches. + # TODO(noam): remove this when Dataset API improves.
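# Worked example (illustrative numbers): with batch_size=256 and
# min_length_bucket=8, the shortest bucket gets batch size 256 // 8 = 32, so
# max_batch_size = 32; the largest highly composite number <= 3 * 32 = 96 is
# 60, hence window_size = 60 and every bucket batch size is rounded down to a
# divisor of 60 (e.g. 32 -> 30, 9 -> 6, 7 -> 6).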
+ highly_composite_numbers = [ + 1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, + 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440, + 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280, + 720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480, + 7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400, + 36756720, 43243200, 61261200, 73513440, 110270160] + window_size = max([ + i for i in highly_composite_numbers if i <= 3 * max_batch_size]) + divisors = [i for i in xrange(1, window_size + 1) if window_size % i == 0] + batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes] + window_size *= shard_multiplier + batch_sizes = [bs * shard_multiplier for bs in batch_sizes] + # The Datasets API splits one window into multiple batches, which + # produces runs of many consecutive batches of the same size. This + # is bad for training. To solve this, we will shuffle the batches + # using a queue which must be several times as large as the maximum + # number of batches per window. + max_batches_per_window = window_size // min(batch_sizes) + shuffle_queue_size = max_batches_per_window * 3 + ret = { "boundaries": boundaries, "batch_sizes": batch_sizes, - "max_length": (max_length if drop_long_sequences else 10**9) + "max_length": (max_length if drop_long_sequences else 10**9), + "shuffle_queue_size": shuffle_queue_size, + "window_size": window_size, } + tf.logging.info("batching_scheme = %s" % ret) + return ret def hparams_to_batching_scheme(hparams, @@ -455,9 +448,10 @@ def hparams_to_batching_scheme(hparams, length_multiplier=1): """Wrapper around _batching_scheme with hparams.""" return _batching_scheme( - max_length=hparams.max_length, batch_size=hparams.batch_size, - batching_mantissa_bits=hparams.batching_mantissa_bits, + max_length=hparams.max_length, + min_length_bucket=hparams.min_length_bucket, + length_bucket_step=hparams.length_bucket_step, drop_long_sequences=drop_long_sequences, shard_multiplier=shard_multiplier, length_multiplier=length_multiplier) @@ -477,7 +471,9 @@ def constant_batching_scheme(constant_batch_size_in_sequences): return { "boundaries": boundaries, "batch_sizes": batch_sizes, - "max_length": 10**9 + "max_length": 10**9, + "shuffle_queue_size": None, + "window_size": constant_batch_size_in_sequences, } diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index 318fb1cab..991669a99 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -169,36 +169,62 @@ def testLengthFilter(self): def testBatchingSchemeMaxLength(self): scheme = data_reader._batching_scheme( - batch_size=20, max_length=None, drop_long_sequences=False) + batch_size=20, max_length=None, + min_length_bucket=8, length_bucket_step=1.1, + drop_long_sequences=False) self.assertGreater(scheme["max_length"], 10000) scheme = data_reader._batching_scheme( - batch_size=20, max_length=None, drop_long_sequences=True) + batch_size=20, max_length=None, + min_length_bucket=8, length_bucket_step=1.1, + drop_long_sequences=True) self.assertEqual(scheme["max_length"], 20) scheme = data_reader._batching_scheme( - batch_size=20, max_length=15, drop_long_sequences=True) + batch_size=20, max_length=15, + min_length_bucket=8, length_bucket_step=1.1, + drop_long_sequences=True) self.assertEqual(scheme["max_length"], 15) scheme = data_reader._batching_scheme( - batch_size=20, max_length=15, drop_long_sequences=False) + 
batch_size=20, max_length=15, + min_length_bucket=8, length_bucket_step=1.1, + drop_long_sequences=False) self.assertGreater(scheme["max_length"], 10000) def testBatchingSchemeBuckets(self): - scheme = data_reader._batching_scheme(batch_size=128) + scheme = data_reader._batching_scheme( + batch_size=128, + max_length=0, + min_length_bucket=8, + length_bucket_step=1.1) boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"] self.assertEqual(len(boundaries), len(batch_sizes) - 1) - expected_boundaries = [8, 12, 16, 24, 32, 48, 64, 96] + expected_boundaries = [ + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, + 30, 33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124] self.assertEqual(expected_boundaries, boundaries) - expected_batch_sizes = [16, 10, 8, 5, 4, 2, 2, 1, 1] + expected_batch_sizes = [ + 16, 12, 12, 8, 8, 8, 8, 8, 8, 6, 6, 6, 6, 4, 4, 4, 4, 4, 3, 3, 3, + 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1] self.assertEqual(expected_batch_sizes, batch_sizes) - scheme = data_reader._batching_scheme(batch_size=128, shard_multiplier=2) + scheme = data_reader._batching_scheme( + batch_size=128, + max_length=0, + min_length_bucket=8, + length_bucket_step=1.1, + shard_multiplier=2) boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"] self.assertAllEqual([bs * 2 for bs in expected_batch_sizes], batch_sizes) self.assertEqual(expected_boundaries, boundaries) - scheme = data_reader._batching_scheme(batch_size=128, length_multiplier=2) + scheme = data_reader._batching_scheme( + batch_size=128, + max_length=0, + min_length_bucket=8, + length_bucket_step=1.1, + length_multiplier=2) boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"] self.assertAllEqual([b * 2 for b in expected_boundaries], boundaries) self.assertEqual([max(1, bs // 2) @@ -211,14 +237,16 @@ def example_len(ex): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] + window_size = 40 dataset = data_reader.read_examples( self.problem, self.filepatterns[0], 32, mode=tf.contrib.learn.ModeKeys.EVAL) - dataset = data_reader.bucket_by_sequence_length(dataset, example_len, - boundaries, batch_sizes) + dataset = data_reader.bucket_by_sequence_length( + dataset, example_len, + boundaries, batch_sizes, window_size) batch = dataset.make_one_shot_iterator().get_next() input_vals = [] diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 2e430a204..3f00c25a9 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -37,85 +37,114 @@ FLAGS = tf.flags.FLAGS -def decode_from_dataset(estimator): +def _decode_from_dataset_log_results(inputs, + targets, + outputs, + problem_name, + prediction_idx, + inputs_vocab, + targets_vocab, + save_images=False, + model_dir=None, + identity_output=False): + """Log inference results.""" + if "image" in problem_name and save_images: + save_path = os.path.join(model_dir, "%s_prediction_%d.jpg" % + (problem_name, prediction_idx)) + show_and_save_image(inputs / 255., save_path) + elif inputs_vocab: + decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten())) + tf.logging.info("Inference results INPUT: %s" % decoded_inputs) + + if identity_output: + decoded_outputs = "".join(map(str, outputs.flatten())) + decoded_targets = "".join(map(str, targets.flatten())) + else: + decoded_outputs = "".join( + map(str, targets_vocab.decode(_save_until_eos(outputs.flatten())))) + decoded_targets = "".join( + map(str, targets_vocab.decode(_save_until_eos(targets.flatten())))) + + 
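# Note: with identity_output=True the two strings above are the raw integer
# ids joined into a string; otherwise the ids are decoded back to text
# through the targets vocabulary.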
tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs) + tf.logging.info("Inference results TARGET: %s" % decoded_targets) + return decoded_outputs, decoded_targets + + +def decode_from_dataset(estimator, + problem_names, + return_beams=False, + beam_size=1, + max_predictions=-1, + decode_to_file=None, + save_images=False, + identity_output=False): + tf.logging.info("Performing local inference from dataset for %s.", + str(problem_names)) hparams = estimator.hparams - for i, problem in enumerate(FLAGS.problems.split("-")): - inputs_vocab = hparams.problems[i].vocabulary.get("inputs", None) - targets_vocab = hparams.problems[i].vocabulary["targets"] - tf.logging.info("Performing local inference.") + + for problem_idx, problem_name in enumerate(problem_names): + # Build the inference input function infer_problems_data = data_reader.get_data_filepatterns( - FLAGS.problems, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER) + problem_name, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER) infer_input_fn = input_fn_builder.build_input_fn( mode=tf.contrib.learn.ModeKeys.INFER, hparams=hparams, data_file_patterns=infer_problems_data, num_datashards=devices.data_parallelism().n, - fixed_problem=i) - - def log_fn(inputs, - targets, - outputs, - problem, - j, - inputs_vocab=inputs_vocab, - targets_vocab=targets_vocab): - """Log inference results.""" - if "image" in problem and FLAGS.decode_save_images: - save_path = os.path.join(estimator.model_dir, - "%s_prediction_%d.jpg" % (problem, j)) - show_and_save_image(inputs / 255., save_path) - elif inputs_vocab: - decoded_inputs = inputs_vocab.decode( - _save_until_eos(inputs.flatten())) - tf.logging.info("Inference results INPUT: %s" % decoded_inputs) + fixed_problem=problem_idx) - if FLAGS.identity_output: - decoded_outputs = " ".join(map(str, outputs.flatten())) - decoded_targets = " ".join(map(str, targets.flatten())) - else: - decoded_outputs = " ".join(map( - str, targets_vocab.decode(_save_until_eos(outputs.flatten())))) - decoded_targets = " ".join(map( - str, targets_vocab.decode(_save_until_eos(targets.flatten())))) - - tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs) - tf.logging.info("Inference results TARGET: %s" % decoded_targets) - return decoded_outputs, decoded_targets - - result_iter = estimator.predict(input_fn=infer_input_fn, as_iterable=True) - count = 0 - agg_outputs = [] - agg_targets = [] - for result in result_iter: - # predictions from the test input. We use it to log inputs and decodes. - inputs = result["inputs"] - targets = result["targets"] - outputs = result["outputs"] - if FLAGS.decode_return_beams: - output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0) - for k, beam in enumerate(output_beams): - tf.logging.info("BEAM %d:" % k) - o, t = log_fn(inputs, targets, beam, problem, count) - agg_outputs.append(o) - agg_targets.append(t) - else: - o, t = log_fn(inputs, targets, outputs, problem, count) - agg_outputs.append(o) - agg_targets.append(t) + # Get the predictions as an iterable + predictions = estimator.predict(input_fn=infer_input_fn, as_iterable=True) + + # Prepare output file writers if decode_to_file passed + if decode_to_file: + output_filepath = decode_to_file + ".outputs." + problem_name + target_filepath = decode_to_file + ".targets." + problem_name - count += 1 - if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples: - break - if FLAGS.decode_to_file: - output_filepath = FLAGS.decode_to_file + ".outputs." 
+ problem output_file = tf.gfile.Open(output_filepath, "w") - target_filepath = FLAGS.decode_to_file + ".targets." + problem target_file = tf.gfile.Open(target_filepath, "w") - for o, t in zip(agg_outputs, agg_targets): - output_file.write(str(o)+"\n") - target_file.write(str(t)+"\n") - tf.logging.info("Completed inference on %d samples." % count) + + problem_hparams = hparams.problems[problem_idx] + inputs_vocab = problem_hparams.vocabulary.get("inputs", None) + targets_vocab = problem_hparams.vocabulary["targets"] + for num_predictions, prediction in enumerate(predictions): + inputs = prediction["inputs"] + targets = prediction["targets"] + outputs = prediction["outputs"] + + # Log predictions + decoded_outputs = [] + if return_beams: + output_beams = np.split(outputs, beam_size, axis=0) + for i, beam in enumerate(output_beams): + tf.logging.info("BEAM %d:" % i) + decoded = _decode_from_dataset_log_results( + inputs, targets, beam, problem_name, num_predictions, + inputs_vocab, targets_vocab, save_images, estimator.model_dir, + identity_output) + decoded_outputs.append(decoded) + else: + decoded = _decode_from_dataset_log_results( + inputs, targets, outputs, problem_name, num_predictions, + inputs_vocab, targets_vocab, save_images, estimator.model_dir, + identity_output) + decoded_outputs.append(decoded) + + # Write out predictions if decode_to_file passed + if decode_to_file: + for decoded_output, decoded_target in decoded_outputs: + output_file.write(str(decoded_output) + "\n") + target_file.write(str(decoded_target) + "\n") + + if max_predictions >= 0 and num_predictions >= max_predictions: + break + + if decode_to_file: + output_file.close() + target_file.close() + + tf.logging.info("Completed inference on %d samples." % num_predictions) # pylint: disable=undefined-loop-variable def decode_from_file(estimator, filename): diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index 6f26f20fa..fb1d1fac0 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -436,6 +436,87 @@ def noisy_top_k_gating(x, return gates, load +class PadRemover(object): + """Helper to remove padding from a tensor before sending to the experts. + + The padding is computed for one reference tensor containing the padding mask + and then can be applied to any other tensor of shape [dim_origin,...]. + + Ex: + input = [ + [tok1, tok2], + [tok3, tok4], + [0, 0], + [0, 0], + [tok5, tok6], + [0, 0], + ] + output = [ + [tok1, tok2], + [tok3, tok4], + [tok5, tok6], + ] + """ + + def __init__(self, pad_mask): + """Compute and store the location of the padding. + + Args: + pad_mask (tf.Tensor): Reference padding tensor of shape + [batch_size,length] or [dim_origin] (dim_origin=batch_size*length) + containing non-zeros positive values to indicate padding location. + """ + self.nonpad_ids = None + self.dim_origin = None + + with tf.name_scope("pad_reduce/get_ids"): + pad_mask = tf.reshape(pad_mask, [-1]) # Flatten the batch + # nonpad_ids contains coordinates of zeros rows (as pad_mask is + # float32, checking zero equality is done with |x| < epsilon, with + # epsilon=1e-9 as standard, here pad_mask only contains positive values + # so tf.abs would be redundant) + self.nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9)) + self.dim_origin = tf.shape(pad_mask)[:1] + + def remove(self, x): + """Remove padding from the given tensor. + + Args: + x (tf.Tensor): of shape [dim_origin,...] + + Returns: + a tensor of shape [dim_compressed,...] 
with dim_compressed <= dim_origin + """ + with tf.name_scope("pad_reduce/remove"): + x_shape = x.get_shape().as_list() + x = tf.gather_nd( + x, + indices=self.nonpad_ids, + ) + # This is a hack but for some reason, gather_nd return a tensor of + # undefined shape, so the shape is set up manually + x.set_shape([None] + x_shape[1:]) + return x + + def restore(self, x): + """Add padding back to the given tensor. + + Args: + x (tf.Tensor): of shape [dim_compressed,...] + + Returns: + a tensor of shape [dim_origin,...] with dim_compressed >= dim_origin. The + dim is restored from the original reference tensor + """ + with tf.name_scope("pad_reduce/restore"): + x = tf.scatter_nd( + indices=self.nonpad_ids, + updates=x, + shape=tf.concat([self.dim_origin, tf.shape(x)[1:]], axis=0), + ) + return x + + class SparseDispatcher(object): """Helper for implementing a mixture of experts. @@ -766,6 +847,7 @@ def local_moe(x, pass_x=True, pass_gates=False, additional_dispatch_params=None, + pad_remover=None, name=None): """Call a local mixture of experts. @@ -782,6 +864,8 @@ def local_moe(x, additional_dispatch_params: The extra tensors that need to be sent to each expert. Examples include batch batch coordinates (see common_attention.local_expert_attention) + pad_remover (PadRemover): If given, the padding is removed/restored before + sending to the experts name: a string Returns: @@ -791,8 +875,18 @@ def local_moe(x, training loss of the model. The backpropagation of this loss encourages all experts to be approximately equally used across a batch. """ + with tf.variable_scope(name, default_name="local_moe"): x_flat = flatten_all_but_last(x) + + # Remove the padding tokens + if pad_remover: + x_flat = pad_remover.remove(x_flat) + tf.summary.scalar( # Should match the targets_nonpadding_tokens + "nonpadding_tokens", + tf.shape(x_flat)[0], + family="experts_stats") + # The gates indicate which batch elements go to which tensors. # load is a measure of approximately how many examples go to each expert gates, load = noisy_top_k_gating( @@ -805,17 +899,27 @@ def local_moe(x, noise_epsilon=1e-2) # This magic object helps us shuffle data between datashards and experts. dispatcher = SparseDispatcher(num_experts, gates) + + # Set up expert_fn arguments expert_kwargs = {} if pass_x: expert_kwargs["x"] = dispatcher.dispatch(x_flat) if pass_gates: expert_kwargs["gates"] = dispatcher.expert_to_gates() for k, v in six.iteritems(additional_dispatch_params or {}): - expert_kwargs[k] = dispatcher.dispatch(flatten_all_but_last(v)) + v = flatten_all_but_last(v) + if pad_remover: + v = pad_remover.remove(v) + expert_kwargs[k] = dispatcher.dispatch(v) + ep = Parallelism([DEFAULT_DEV_STRING] * num_experts) expert_outputs = ep(expert_fn, **expert_kwargs) + y_flat = dispatcher.combine(expert_outputs) + if pad_remover: + y_flat = pad_remover.restore(y_flat) y = reshape_like(y_flat, x) + importance = tf.reduce_sum(gates, 0) loss = loss_coef * (cv_squared(importance) + cv_squared(load)) return y, loss diff --git a/tensor2tensor/utils/expert_utils_test.py b/tensor2tensor/utils/expert_utils_test.py new file mode 100644 index 000000000..93af9c78c --- /dev/null +++ b/tensor2tensor/utils/expert_utils_test.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tensor2tensor.utils.expert_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +from tensor2tensor.layers import common_attention +from tensor2tensor.utils import expert_utils +import tensorflow as tf + + +class ExpertUtilsTest(tf.test.TestCase): + + def _verify_value(self, sess, tensor, expected): + output = sess.run(tensor) + self.assertAllClose(output, expected, 1e-9) + + def testPadRemover(self): + """Check that the padding remover is working correctly.""" + x_1 = tf.constant([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [0, 0, 0], # pad + [0, 0, 0], # pad + [0, 0, 0], # pad + [10, 11, 12], + [13, 14, 15], + [0, 0, 0], # pad + ], dtype=tf.float32) + # Get padding mask + x_pad_mask = common_attention.embedding_to_padding(x_1) + x_2 = tf.constant([ + [1], + [2], + [3], + [4], # pad + [5], # pad + [6], # pad + [7], + [8], + [9], # pad + ], dtype=tf.float32) + x_3 = tf.constant([ + 1, + 2, + 3, + 4, # pad + 5, # pad + 6, # pad + 7, + 8, + 9, # pad + ], dtype=tf.float32) + + pad_remover = expert_utils.PadRemover(x_pad_mask) + + y_1 = pad_remover.remove(x_1) + y_2 = pad_remover.remove(x_2) + y_3 = pad_remover.remove(x_3) + + z_1 = pad_remover.restore(y_1 * 2) + z_2 = pad_remover.restore(y_2 * 2) + z_3 = pad_remover.restore(y_3 * 2) + + with self.test_session() as sess: + # Padding should have been removed + self._verify_value(sess, y_1, [ + [1., 2., 3.], + [4., 5., 6.], + [7., 8., 9.], + [10., 11., 12.], + [13., 14., 15.], + ]) + self._verify_value(sess, y_2, [ + [1.], + [2.], + [3.], + [7.], + [8.], + ]) + self._verify_value(sess, y_3, [ + 1., + 2., + 3., + 7., + 8., + ]) + + # Padding should have been restored + self._verify_value(sess, z_1, [ + [2., 4., 6.], + [8., 10., 12.], + [14., 16, 18.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [20., 22., 24.], + [26., 28., 30.], + [0., 0., 0.], + ]) + self._verify_value(sess, z_2, [ + [2.], + [4.], + [6.], + [0.], # pad + [0.], # pad + [0.], # pad + [14.], + [16.], + [0.], # pad + ]) + self._verify_value(sess, z_3, [ + 2., + 4., + 6., + 0., # pad + 0., # pad + 0., # pad + 14., + 16., + 0., # pad + ]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index c31ba0f31..bef95d58f 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -183,6 +183,12 @@ def input_fn(): if mode == tf.contrib.learn.ModeKeys.INFER: rand_feature_map["infer_targets"] = rand_target rand_target = None + # This is because of a bug in the tf.contrib.learn Estimator that + # short-circuits prediction if it doesn't see a QueueRunner. + # DummyQueueRunner implements the minimal expected interface but does + # nothing. + # TODO(rsepassi): Remove once we move to core Estimator. 
+ tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, DummyQueueRunner()) return rand_feature_map, rand_target return input_fn @@ -195,3 +201,14 @@ def cond_on_index(fn, index_tensor, cur_idx, max_idx): return tf.cond( tf.equal(index_tensor, cur_idx), lambda: fn(cur_idx), lambda: cond_on_index(fn, index_tensor, cur_idx + 1, max_idx)) + + +class DummyQueueRunner(object): + """Can stand-in for a QueueRunner but does nothing.""" + + def __init__(self): + pass + + def create_threads(self, sess, coord=None, daemon=False, start=False): + del sess, coord, daemon, start + return [] diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py index e5cb88ddf..baff66669 100644 --- a/tensor2tensor/utils/metrics.py +++ b/tensor2tensor/utils/metrics.py @@ -42,6 +42,7 @@ class Metrics(object): R2 = "r_squared" ROUGE_2_F = "rouge_2_fscore" ROUGE_L_F = "rouge_L_fscore" + EDIT_DISTANCE = "edit_distance" def padded_rmse(predictions, labels, weights_fn=common_layers.weights_all): @@ -122,6 +123,50 @@ def padded_sequence_accuracy(predictions, return correct_seq, tf.constant(1.0) +def sequence_edit_distance(predictions, + labels, + weights_fn=common_layers.weights_nonzero): + """Average edit distance, ignoring padding 0s. + + The score returned is the edit distance divided by the total length of + reference truth and the weight returned is the total length of the truth. + + Args: + predictions: Tensor of shape [`batch_size`, `length`, 1, `num_classes`] and + type tf.float32 representing the logits, 0-padded. + labels: Tensor of shape [`batch_size`, `length`, 1, 1] and type tf.int32 + representing the labels of same length as logits and 0-padded. + weights_fn: ignored. The weights returned are the total length of the ground + truth labels, excluding 0-paddings. + + Returns: + (edit distance / reference length, reference length) + + Raises: + ValueError: if weights_fn is not common_layers.weights_nonzero. + """ + if weights_fn is not common_layers.weights_nonzero: + raise ValueError("Only weights_nonzero can be used for this metric.") + + with tf.variable_scope("edit_distance", values=[predictions, labels]): + # Transform logits into sequence classes by taking max at every step. 
+ predictions = tf.to_int32( + tf.squeeze(tf.argmax(predictions, axis=-1), axis=(2, 3))) + nonzero_idx = tf.where(tf.not_equal(predictions, 0)) + sparse_outputs = tf.SparseTensor(nonzero_idx, + tf.gather_nd(predictions, nonzero_idx), + tf.shape(predictions, out_type=tf.int64)) + labels = tf.squeeze(labels, axis=(2, 3)) + nonzero_idx = tf.where(tf.not_equal(labels, 0)) + label_sparse_outputs = tf.SparseTensor(nonzero_idx, + tf.gather_nd(labels, nonzero_idx), + tf.shape(labels, out_type=tf.int64)) + distance = tf.reduce_sum( + tf.edit_distance(sparse_outputs, label_sparse_outputs, normalize=False)) + reference_length = tf.to_float(tf.shape(nonzero_idx)[0]) + return distance / reference_length, reference_length + + def padded_neg_log_perplexity(predictions, labels, weights_fn=common_layers.weights_nonzero): @@ -234,4 +279,5 @@ def problem_metric_fn(predictions, labels, weights): Metrics.R2: padded_variance_explained, Metrics.ROUGE_2_F: rouge.rouge_2_fscore, Metrics.ROUGE_L_F: rouge.rouge_l_fscore, + Metrics.EDIT_DISTANCE: sequence_edit_distance, } diff --git a/tensor2tensor/utils/metrics_test.py b/tensor2tensor/utils/metrics_test.py index 0d78e632c..528fd4755 100644 --- a/tensor2tensor/utils/metrics_test.py +++ b/tensor2tensor/utils/metrics_test.py @@ -72,6 +72,29 @@ def testSequenceAccuracyMetric(self): actual = session.run(a) self.assertEqual(actual, expected) + def testSequenceEditDistanceMetric(self): + predictions = np.array([[3, 4, 5, 1, 0, 0], + [2, 1, 3, 4, 0, 0], + [2, 1, 3, 4, 0, 0]]) + # Targets are just a bit different: + # - first sequence has a different prediction + # - second sequence has a different prediction and one extra step + # - third sequence is identical + targets = np.array([[5, 4, 5, 1, 0, 0], + [2, 5, 3, 4, 1, 0], + [2, 1, 3, 4, 0, 0]]) + # Reshape to match expected input format by metric fns. + predictions = np.reshape(predictions, [3, 6, 1, 1]) + targets = np.reshape(targets, [3, 6, 1, 1]) + with self.test_session() as session: + scores, weight = metrics.sequence_edit_distance( + tf.one_hot(predictions, depth=6, dtype=tf.float32), + tf.constant(targets, dtype=tf.int32)) + session.run(tf.global_variables_initializer()) + actual_scores, actual_weight = session.run([scores, weight]) + self.assertAlmostEqual(actual_scores, 3.0 / 13) + self.assertEqual(actual_weight, 13) + def testNegativeLogPerplexity(self): predictions = np.random.randint(4, size=(12, 12, 12, 1)) targets = np.random.randint(4, size=(12, 12, 12, 1)) diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index 24c17ca9e..34af6c827 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -104,6 +104,14 @@ def learning_rate_decay(): elif hparams.learning_rate_decay_scheme == "cosine": cycle_steps = hparams.learning_rate_cosine_cycle_steps return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps)) + elif hparams.learning_rate_decay_scheme == "cyclelinear10x": + # Cycle the rate linearly by 10x every warmup_steps, up and down. + cycle_steps = hparams.learning_rate_warmup_steps + cycle_position = step % (2 * cycle_steps) + cycle_position = tf.to_float( # Normalize to the interval [-1, 1]. + cycle_position - cycle_steps) / float(cycle_steps) + cycle_position = 1.0 - tf.abs(cycle_position) # 0 to 1 and back to 0. + return (cycle_position + 0.1) * 3.0 # 10x difference each cycle (0.3-3). 
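For orientation, here is a standalone sketch (not part of the patch) of the decay factor produced by the new `cyclelinear10x` branch above, rewritten in plain Python; the `warmup_steps=2000` value is only an illustrative choice:

```python
def cyclelinear10x_factor(step, warmup_steps=2000):
  """Mirrors the schedule above: a linear up/down cycle spanning roughly 10x."""
  cycle_steps = warmup_steps
  cycle_position = step % (2 * cycle_steps)
  cycle_position = float(cycle_position - cycle_steps) / cycle_steps  # [-1, 1]
  cycle_position = 1.0 - abs(cycle_position)  # rises 0 -> 1, then falls back.
  return (cycle_position + 0.1) * 3.0  # swings between 0.3 and 3.3 per cycle.

print([round(cyclelinear10x_factor(s), 2) for s in range(0, 8001, 1000)])
# [0.3, 1.8, 3.3, 1.8, 0.3, 1.8, 3.3, 1.8, 0.3]
```

The factor peaks once every `2 * warmup_steps` steps, giving the up-and-down cycle of roughly 10x that the comment in the patch describes.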
inv_base = tf.exp(tf.log(0.01) / warmup_steps) inv_decay = inv_base**(warmup_steps - step) @@ -156,12 +164,6 @@ def model_fn(features, targets, mode): features = _interactive_input_tensor_to_features_dict(features, my_hp) elif FLAGS.decode_from_file: features = _decode_input_tensor_to_features_dict(features, my_hp) - # A dictionary containing: - # - problem_choice: A Tensor containing an integer indicating which problem - # was selected for this run. - # - predictions: A Tensor containing the model's output predictions. - run_info = dict() - run_info["problem_choice"] = features["problem_choice"] if targets is not None: features["targets"] = targets @@ -185,17 +187,20 @@ def model_fn(features, targets, mode): tf.summary.scalar("%s_nonpadding_fraction" % k, tf.reduce_mean(nonpadding)) - # The new data reader occasionally emits very small batches, which - # cause the examples in those batches to be grossly overweighted. - # We decrease the loss proportionally to the ratio of the size of this - # batch to the size of the largest training batch ever. - # TODO(noam): to be more sophisticated, we could keep separate - # maxima based on problem choice. - max_nonpadding_var = tf.get_variable( - "max_nonpadding", shape=[], - initializer=tf.ones_initializer(), trainable=False) - max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens) if is_training: + # The new data reader occasionally emits very small batches, which + # cause the examples in those batches to be grossly overweighted. + # We decrease the loss proportionally to the ratio of the size of this + # batch to the size of the largest training batch ever. + # TODO(noam): to be more sophisticated, we could keep separate + # maxima based on problem choice. + max_nonpadding_var = tf.get_variable( + "max_nonpadding", + shape=[], + initializer=tf.ones_initializer(), + trainable=False) + max_nonpadding = tf.maximum(max_nonpadding_var, + targets_nonpadding_tokens) with tf.control_dependencies( [tf.assign(max_nonpadding_var, max_nonpadding)]): small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding @@ -203,6 +208,7 @@ def model_fn(features, targets, mode): # Get multi-problem logits and loss based on features["problem_choice"]. loss_variable_names = [] + def nth_model(n): """Build the model for the n-th problem, plus some added variables.""" model_class = registry.model(model)( @@ -249,8 +255,8 @@ def nth_model(n): # Total loss was already constructed on input. loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n) except ValueError: - loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n, - initializer=100.0, trainable=False) + loss_moving_avg = tf.get_variable( + "problem_%d/total_loss" % n, initializer=100.0, trainable=False) ops.append( loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1)) with tf.variable_scope("train_stats"): # Count steps for this problem. @@ -287,11 +293,13 @@ def nth_model(n): sharded_logits, total_loss = result_list[1:], result_list[0] if mode == tf.contrib.learn.ModeKeys.EVAL: - logits = tf.concat(sharded_logits, 0) # For evaluation, return the logits layer as our predictions. 
- run_info["predictions"] = logits - train_op = None - return run_info, total_loss, None + logits = tf.concat(sharded_logits, 0) + ret = { + "predictions": logits, + "problem_choice": features["problem_choice"], + } + return ret, total_loss, None assert mode == tf.contrib.learn.ModeKeys.TRAIN @@ -373,7 +381,7 @@ def nth_model(n): del summaries[i] tf.logging.info("Global model_fn finished.") - return run_info, total_loss, train_op + return {"problem_choice": features["problem_choice"]}, total_loss, train_op return model_fn diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index fa9d9233e..a747b9a09 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -27,7 +27,6 @@ from tensor2tensor.data_generators import problem_hparams from tensor2tensor.models import models # pylint: disable=unused-import from tensor2tensor.utils import data_reader -from tensor2tensor.utils import decoding from tensor2tensor.utils import devices from tensor2tensor.utils import input_fn_builder from tensor2tensor.utils import metrics @@ -101,16 +100,13 @@ flags.DEFINE_string("ps_job", "/job:ps", "name of ps job") flags.DEFINE_integer("ps_replicas", 0, "How many ps replicas.") -# Decode flags -# Set one of {decode_from_dataset, decode_interactive, decode_from_file} to -# decode. -flags.DEFINE_bool("decode_from_dataset", False, "Decode from dataset on disk.") -flags.DEFINE_bool("decode_use_last_position_only", False, - "In inference, use last position only for speedup.") +# Decoding flags +flags.DEFINE_string("decode_from_file", None, "Path to decode file") flags.DEFINE_bool("decode_interactive", False, "Interactive local inference mode.") +flags.DEFINE_bool("decode_use_last_position_only", False, + "In inference, use last position only for speedup.") flags.DEFINE_bool("decode_save_images", False, "Save inference input images.") -flags.DEFINE_string("decode_from_file", None, "Path to decode file") flags.DEFINE_string("decode_to_file", None, "Path to inference output file") flags.DEFINE_integer("decode_shards", 1, "How many shards to decode.") flags.DEFINE_integer("decode_problem_id", 0, "Which problem to decode.") @@ -128,7 +124,7 @@ "Maximum number of ids in input. Or <= 0 for no max.") flags.DEFINE_bool("identity_output", False, "To print the output as identity") flags.DEFINE_integer("decode_num_samples", -1, - "Number of samples to decode. Currently used in" + "Number of samples to decode. Currently used in " "decode_from_dataset. Use -1 for all.") @@ -149,8 +145,8 @@ def experiment_fn(output_dir): def create_experiment(output_dir, data_dir, model_name, train_steps, eval_steps): """Create Experiment.""" - hparams = create_hparams(FLAGS.hparams_set, FLAGS.problems, data_dir, - passed_hparams=FLAGS.hparams) + hparams = create_hparams( + FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams) estimator, input_fns = create_experiment_components( hparams=hparams, output_dir=output_dir, @@ -303,7 +299,6 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): if exp.train_steps > 0 or exp.eval_steps > 0: tf.logging.info("Performing local training and evaluation.") exp.train_and_evaluate() - decode(exp.estimator) else: # Perform distributed training/evaluation. 
learn_runner.run( @@ -350,12 +345,3 @@ def session_config(): def get_data_filepatterns(data_dir, mode): return data_reader.get_data_filepatterns(FLAGS.problems, data_dir, mode) - - -def decode(estimator): - if FLAGS.decode_interactive: - decoding.decode_interactively(estimator) - elif FLAGS.decode_from_file is not None and FLAGS.decode_from_file is not "": - decoding.decode_from_file(estimator, FLAGS.decode_from_file) - elif FLAGS.decode_from_dataset: - decoding.decode_from_dataset(estimator) diff --git a/tensor2tensor/visualization/TransformerVisualization.ipynb b/tensor2tensor/visualization/TransformerVisualization.ipynb index ef1c7b45d..e3fb8f958 100644 --- a/tensor2tensor/visualization/TransformerVisualization.ipynb +++ b/tensor2tensor/visualization/TransformerVisualization.ipynb @@ -86,7 +86,7 @@ "import os\n", "# PUT THE MODEL YOU WANT TO LOAD HERE!\n", "\n", - "PROBLEM = 'wmt_ende_tokens_32k'\n", + "PROBLEM = 'translate_ende_wmt32k'\n", "MODEL = 'transformer'\n", "HPARAMS = 'transformer_base_single_gpu'\n", "\n", @@ -118,7 +118,7 @@ } ], "source": [ - "hparams = utils.create_hparams(HPARAMS, DATA_DIR)\n", + "hparams = utils.create_hparams(HPARAMS, PROBLEM, DATA_DIR)\n", "\n", "# SET EXTRA HYPER PARAMS HERE!\n", "# e.g.\n", @@ -381,8 +381,7 @@ }, "outputs": [], "source": [ - "der = decode(beam_decode[0])\n", - "output_ids = encode(der)\n", + "output_ids = beam_decode\n", "\n", "# Get attentions\n", "np_enc_atts, np_dec_atts, np_encdec_atts = sess.run([enc_atts, dec_atts, encdec_atts], {\n", diff --git a/tensor2tensor/visualization/__init__.py b/tensor2tensor/visualization/__init__.py new file mode 100644 index 000000000..b62605264 --- /dev/null +++ b/tensor2tensor/visualization/__init__.py @@ -0,0 +1,16 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + diff --git a/tensor2tensor/visualization/attention.py b/tensor2tensor/visualization/attention.py index 2c1f61c9c..bc4238081 100644 --- a/tensor2tensor/visualization/attention.py +++ b/tensor2tensor/visualization/attention.py @@ -21,8 +21,7 @@ import json import os -from IPython.display import HTML -from IPython.display import Javascript +import IPython.display as display import numpy as np @@ -53,9 +52,9 @@ def show(inp_text, out_text, enc_atts, dec_atts, encdec_atts): def _show_attention(att_json): - display(HTML(vis_html)) # pylint: disable=undefined-variable - display(Javascript('window.attention = %s' % att_json)) # pylint: disable=undefined-variable - display(Javascript(vis_js)) # pylint: disable=undefined-variable + display.display(display.HTML(vis_html)) + display.display(display.Javascript('window.attention = %s' % att_json)) + display.display(display.Javascript(vis_js)) def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts): @@ -88,8 +87,8 @@ def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts): """ def get_full_attention(layer): """Get the full input+output - input+output attentions.""" - enc_att = enc_atts[layer][0], - dec_att = dec_atts[layer][0], + enc_att = enc_atts[layer][0] + dec_att = dec_atts[layer][0] encdec_att = encdec_atts[layer][0] enc_att = np.transpose(enc_att, [0, 2, 1]) dec_att = np.transpose(dec_att, [0, 2, 1])
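A note on the trailing-comma fix in `get_full_attention` above: with the comma, `enc_att` and `dec_att` were one-element tuples rather than the attention arrays themselves, so the later `np.transpose(..., [0, 2, 1])` saw an extra leading axis and failed. A minimal illustration, assuming each `enc_atts[layer][0]` is a 3-D array of shape `[num_heads, length, length]`:

```python
import numpy as np

att = np.zeros((8, 10, 10))  # stand-in for one layer's attention weights

fixed = att                  # corrected form: the array itself
buggy = att,                 # old form: the trailing comma builds a 1-tuple

print(np.transpose(fixed, [0, 2, 1]).shape)  # (8, 10, 10)
try:
  np.transpose(buggy, [0, 2, 1])             # tuple becomes a 4-D array here
except ValueError as err:
  print("old form fails:", err)              # axes don't match array
```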