diff --git a/README.md b/README.md
index 236d279c2..4e56d7855 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ You can chat with us and other users on
with T2T announcements.
Here is a one-command version that installs tensor2tensor, downloads the data,
-trains an English-German translation model, and lets you use it interactively:
+trains an English-German translation model, and evaluates it:
```
pip install tensor2tensor && t2t-trainer \
--generate_data \
@@ -37,7 +37,18 @@ pip install tensor2tensor && t2t-trainer \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
- --output_dir=~/t2t_train/base \
+ --output_dir=~/t2t_train/base
+```
+
+You can decode from the model interactively:
+
+```
+t2t-decoder \
+ --data_dir=~/t2t_data \
+ --problems=translate_ende_wmt32k \
+ --model=transformer \
+ --hparams_set=transformer_base_single_gpu \
+  --output_dir=~/t2t_train/base \
--decode_interactive
```
@@ -106,14 +117,12 @@ echo "Goodbye world" >> $DECODE_FILE
BEAM_SIZE=4
ALPHA=0.6
-t2t-trainer \
+t2t-decoder \
--data_dir=$DATA_DIR \
--problems=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
- --train_steps=0 \
- --eval_steps=0 \
--decode_beam_size=$BEAM_SIZE \
--decode_alpha=$ALPHA \
--decode_from_file=$DECODE_FILE
diff --git a/docs/example_life.md b/docs/example_life.md
new file mode 100644
index 000000000..2983f5077
--- /dev/null
+++ b/docs/example_life.md
@@ -0,0 +1,34 @@
+# T2T: Life of an Example
+
+[![PyPI
+version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
+[![GitHub
+Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
+[![Contributions
+welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
+[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
+[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
+
+This document shows how a training example passes through the T2T pipeline,
+and how all the parts of T2T are connected to make it work.
+
+## The Life of an Example
+
+A training example passes through the following stages in T2T:
+* raw input (text from the command line or a file)
+* encoded input after the `encode` function of [Problem.feature_encoders](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173), usually a sparse tensor, e.g., a vector of `tf.int32`s
+* batched input after the [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242), where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188), are grouped by length and made into batches
+* dense input after being processed by the `bottom` function of a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30)
+* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542)
+* back to sparse output through the `top` function of a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30)
+* if decoding, back through the `decode` function of [Problem.feature_encoders](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) to display text on the screen
+
+We go through these stages step by step below.
+
+## Feature Encoders
+
+TODO: describe [Problem.feature_encoders](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173), which is a dict of encoders that provide `encode` and `decode` functions.
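+
+As a rough sketch of how they are used (assuming a problem instance `ende_problem` whose `"inputs"` encoder is a text encoder such as `SubwordTextEncoder`, and a `data_dir` that holds its vocabulary files; both names are placeholders):
+
+```python
+# Sketch only: feature_encoders returns a dict like {"inputs": ..., "targets": ...}.
+encoders = ende_problem.feature_encoders(data_dir)
+ids = encoders["inputs"].encode("Hello world")  # text -> list of ids
+text = encoders["inputs"].decode(ids)           # ids -> text
+```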
+
+## Modalities
+
+TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30), which has `bottom` and `top` functions, as well as sharded versions and a separate bottom for targets.
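+
+As a rough conceptual sketch in the meantime (this is not the exact `Modality` interface; see `utils/modality.py` for the real one), a modality turns sparse feature ids into dense tensors on the way into the model (`bottom`) and turns the model's dense output back into logits over a vocabulary on the way out (`top`):
+
+```python
+import tensorflow as tf
+
+
+class SketchSymbolModality(object):
+  """Conceptual stand-in for a symbol Modality; shapes are illustrative."""
+
+  def __init__(self, embedding):
+    # embedding: a [vocab_size, hidden_size] variable shared by bottom and top.
+    self._embedding = embedding
+
+  def bottom(self, x):
+    # Sparse ids [batch, length] -> dense embeddings [batch, length, hidden].
+    return tf.gather(self._embedding, x)
+
+  def top(self, body_output, _):
+    # Dense output [batch, length, hidden] -> logits [batch, length, vocab].
+    return tf.tensordot(body_output, self._embedding, axes=[[2], [1]])
+```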
diff --git a/docs/index.md b/docs/index.md
index a5eeba137..9394809b3 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,11 +1,4 @@
-# T2T: Tensor2Tensor Transformers
-
-Check us out on
-
-GitHub
-
-
-.
+# Tensor2Tensor Docs Index
[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
@@ -16,8 +9,26 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
-See our
-[README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/README.md)
-for documentation.
-More documentation and tutorials coming soon...
+Welcome to Tensor2Tensor!
+
+Tensor2Tensor, or T2T for short, is a library we use to create,
+investigate and deploy deep learning models. This page hosts our
+documentation, from basic tutorials to full code documentation.
+
+## Basics
+
+* [Walkthrough: Install and Run](walkthrough.md)
+* [Tutorial: Train on Your Data](new_problem.md)
+* [Tutorial: Create Your Own Model](new_model.md)
+
+## Deep Dive
+
+* [Life of an Example](example_life.md): how all parts of T2T are connected and work together
+* [Distributed Training](distributed_training.md)
+
+## Code documentation
+
+See our
+[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md)
+for now; code documentation is coming soon.
diff --git a/docs/new_model.md b/docs/new_model.md
new file mode 100644
index 000000000..5968c8325
--- /dev/null
+++ b/docs/new_model.md
@@ -0,0 +1,16 @@
+# T2T: Create Your Own Model
+
+[![PyPI
+version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
+[![GitHub
+Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
+[![Contributions
+welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
+[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
+[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
+
+Here we show how to create your own model in T2T.
+
+## The T2TModel class
+
+TODO: complete.
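+
+As a first orientation, here is a minimal sketch (assuming the `registry.register_model` decorator and the `T2TModel.model_fn_body` method referenced in [Life of an Example](example_life.md); the single dense layer is purely illustrative): a new model is a registered subclass of `T2TModel` that implements `model_fn_body`.
+
+```python
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+@registry.register_model
+class MySimpleModel(t2t_model.T2TModel):
+  """Toy model: one dense layer over the already-embedded inputs."""
+
+  def model_fn_body(self, features):
+    inputs = features["inputs"]
+    return tf.layers.dense(inputs, self._hparams.hidden_size)
+```
+
+Once registered, the model can be selected by name, e.g. `--model=my_simple_model` (the registry snake-cases class names).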
diff --git a/docs/new_problem.md b/docs/new_problem.md
new file mode 100644
index 000000000..c859c6eba
--- /dev/null
+++ b/docs/new_problem.md
@@ -0,0 +1,240 @@
+# T2T: Train on Your Own Data
+
+[![PyPI
+version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
+[![GitHub
+Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
+[![Contributions
+welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
+[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
+[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
+
+Let's add a new dataset together and train the transformer model. We'll be learning to define English words by training the transformer to "translate" between English words and their definitions on a character level.
+
+## About the Problem
+
+For each problem we want to tackle we create a new problem class and register it. Let's call our problem `Word2def`.
+
+Since many text2text problems share similar methods, there's already a class called `Text2TextProblem` that extends the base problem class, `Problem` (both found in `problem.py`).
+
+For our problem, we can go ahead and create the file `word2def.py` in the `data_generators` folder and add our new problem, `Word2def`, which extends `Text2TextProblem`. Let's also register it while we're at it so we can specify the problem through flags.
+
+```python
+@registry.register_problem()
+class Word2def(problem.Text2TextProblem):
+  """Problem spec for English word to dictionary definition."""
+  # We'll fill in the required properties and methods below.
+```
+
+We need to implement the following methods from `Text2TextProblem` in our new class:
+* is_character_level
+* targeted_vocab_size
+* generator
+* input_space_id
+* target_space_id
+* num_shards
+* vocab_name
+* use_subword_tokenizer
+
+Let's tackle them one by one:
+
+**input_space_id, target_space_id, is_character_level, targeted_vocab_size, use_subword_tokenizer**:
+
+SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These are things like EN_CHR (English character), EN_TOK (English token), AUDIO_WAV (audio waveform), IMAGE, DNA (genetic bases). The complete list can be found at `data_generators/problem.py` in the class `SpaceID`.
+
+Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to `True` and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using tokens, we don't need to give a `targeted_vocab_size`, and `use_subword_tokenizer` should simply return `False`.
+
+**vocab_name**:
+
+`vocab_name` will be used to name your vocabulary files. We can call ours `'vocab.word2def.en'`.
+
+**num_shards**:
+
+The number of shards to break data files into.
+
+```python
+@registry.register_problem()
+class Word2def(problem.Text2TextProblem):
+ """Problem spec for English word to dictionary definition."""
+  @property
+  def is_character_level(self):
+ return True
+
+ @property
+ def vocab_name(self):
+ return "vocab.word2def.en"
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def num_shards(self):
+ return 100
+
+ @property
+ def use_subword_tokenizer(self):
+ return False
+```
+
+**generator**:
+
+We're almost done. `generator` generates the training and evaluation data and stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully, several commonly used methods like `character_generator` and `token_generator` are already written in the file `wmt.py`. We will import `character_generator` and write:
+```python
+ def generator(self, data_dir, tmp_dir, train):
+ character_vocab = text_encoder.ByteTextEncoder()
+ datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
+ tag = "train" if train else "dev"
+ return character_generator(datasets[0], datasets[1], character_vocab, EOS)
+```
+
+Now our `word2def.py` file looks like the below:
+
+```python
+@registry.register_problem()
+class Word2def(problem.Text2TextProblem):
+ """Problem spec for English word to dictionary definition."""
+ @property
+ def is_character_level(self):
+ return True
+
+ @property
+ def vocab_name(self):
+ return "vocab.word2def.en"
+
+ def generator(self, data_dir, tmp_dir, train):
+ character_vocab = text_encoder.ByteTextEncoder()
+ datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
+ tag = "train" if train else "dev"
+ return character_generator(datasets[0], datasets[1], character_vocab, EOS)
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def num_shards(self):
+ return 100
+
+ @property
+ def use_subword_tokenizer(self):
+ return False
+```
+
+## Data
+Now we need to tell Tensor2Tensor where our data is located.
+
+I've gone ahead and split all words into a train and test set and saved them in files called `words_train.txt`, `words_test.txt`,
+`definitions_train.txt`, and `definitions_test.txt` in a directory called `LOCATION_OF_DATA/`. Let's tell T2T where these files are:
+
+```python
+# English Word2def datasets
+_WORD2DEF_TRAIN_DATASETS = [
+    LOCATION_OF_DATA+'words_train.txt',
+    LOCATION_OF_DATA+'definitions_train.txt'
+]
+_WORD2DEF_TEST_DATASETS = [
+    LOCATION_OF_DATA+'words_test.txt',
+    LOCATION_OF_DATA+'definitions_test.txt'
+]
+```
+
+## Putting it all together
+
+With the correct imports, our complete `word2def.py` file now looks like this:
+```python
+""" Problem definition for word to dictionary definition.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.data_generators.wmt import character_generator
+
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
+# English Word2def datasets
+_WORD2DEF_TRAIN_DATASETS = [
+ LOCATION_OF_DATA+'words_train.txt',
+ LOCATION_OF_DATA+'definitions_train.txt'
+]
+
+_WORD2DEF_TEST_DATASETS = [
+ LOCATION_OF_DATA+'words_test.txt',
+ LOCATION_OF_DATA+'definitions_test.txt'
+]
+
+@registry.register_problem()
+class Word2def(problem.Text2TextProblem):
+ """Problem spec for English word to dictionary definition."""
+ @property
+ def is_character_level(self):
+ return True
+
+ @property
+ def vocab_name(self):
+ return "vocab.word2def.en"
+
+ def generator(self, data_dir, tmp_dir, train):
+ character_vocab = text_encoder.ByteTextEncoder()
+ datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
+ tag = "train" if train else "dev"
+ return character_generator(datasets[0], datasets[1], character_vocab, EOS)
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def num_shards(self):
+ return 100
+
+ @property
+ def use_subword_tokenizer(self):
+ return False
+
+```
+
+## Hyperparameters
+All hyperparameters inherit from `_default_hparams()` in `problem.py`. If you would like to customize your hyperparameters, add another method to the file `problem_hparams.py`.
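+
+Alternatively, some problems in this repository adjust these defaults in place by overriding the problem's `hparams(defaults, model_hparams)` method (the IMDB and image problems do this). As a hedged sketch, and with purely illustrative values, inside our `Word2def` class this could look like:
+
+```python
+  def hparams(self, defaults, model_hparams):
+    # Let Text2TextProblem set up modalities and space ids first.
+    super(Word2def, self).hparams(defaults, model_hparams)
+    # Illustrative tweaks only, not tuned recommendations.
+    defaults.batch_size_multiplier = 256
+    defaults.loss_multiplier = 2.0
+```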
+
+## Run the problem
+Now that we've gotten our problem set up, let's train a model and generate definitions.
+
+We specify our problem name, the model, and hparams.
+```bash
+PROBLEM=word2def
+MODEL=transformer
+HPARAMS=transformer_base_single_gpu
+```
+
+The rest of the steps are as given in the [walkthrough](walkthrough.md).
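+
+For example, data generation and training then look like this (a sketch that simply reuses the walkthrough's directory layout; adjust the paths to your setup):
+
+```bash
+DATA_DIR=$HOME/t2t_data
+TMP_DIR=/tmp/t2t_datagen
+TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
+
+mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
+
+t2t-datagen \
+  --data_dir=$DATA_DIR \
+  --tmp_dir=$TMP_DIR \
+  --problem=$PROBLEM
+
+t2t-trainer \
+  --data_dir=$DATA_DIR \
+  --problems=$PROBLEM \
+  --model=$MODEL \
+  --hparams_set=$HPARAMS \
+  --output_dir=$TRAIN_DIR
+```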
+
+
+What if we wanted to train a model to generate words given definitions? In T2T, we can change the problem name to be `PROBLEM=word2def_rev`.
+
+All done. Let us know what definitions your model generated.
diff --git a/docs/walkthrough.md b/docs/walkthrough.md
new file mode 100644
index 000000000..57d7a03f4
--- /dev/null
+++ b/docs/walkthrough.md
@@ -0,0 +1,129 @@
+# T2T Install and Run Walkthrough
+
+[![PyPI
+version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
+[![GitHub
+Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
+[![Contributions
+welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
+[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
+[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
+
+Here is a one-command version that installs tensor2tensor, downloads the data,
+trains an English-German translation model, and evaluates it:
+```
+pip install tensor2tensor && t2t-trainer \
+ --generate_data \
+ --data_dir=~/t2t_data \
+ --problems=translate_ende_wmt32k \
+ --model=transformer \
+ --hparams_set=transformer_base_single_gpu \
+ --output_dir=~/t2t_train/base
+```
+
+You can decode from the model interactively:
+
+```
+t2t-decoder \
+ --data_dir=~/t2t_data \
+ --problems=translate_ende_wmt32k \
+ --model=transformer \
+ --hparams_set=transformer_base_single_gpu \
+  --output_dir=~/t2t_train/base \
+ --decode_interactive
+```
+
+## Walkthrough
+
+Here's a walkthrough training a good English-to-German translation
+model using the Transformer model from [*Attention Is All You
+Need*](https://arxiv.org/abs/1706.03762) on WMT data.
+
+```
+pip install tensor2tensor
+
+# See what problems, models, and hyperparameter sets are available.
+# You can easily swap between them (and add new ones).
+t2t-trainer --registry_help
+
+PROBLEM=translate_ende_wmt32k
+MODEL=transformer
+HPARAMS=transformer_base_single_gpu
+
+DATA_DIR=$HOME/t2t_data
+TMP_DIR=/tmp/t2t_datagen
+TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
+
+mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
+
+# Generate data
+t2t-datagen \
+ --data_dir=$DATA_DIR \
+ --tmp_dir=$TMP_DIR \
+ --problem=$PROBLEM
+
+# Train
+# * If you run out of memory, add --hparams='batch_size=1024'.
+t2t-trainer \
+ --data_dir=$DATA_DIR \
+ --problems=$PROBLEM \
+ --model=$MODEL \
+ --hparams_set=$HPARAMS \
+ --output_dir=$TRAIN_DIR
+
+# Decode
+
+DECODE_FILE=$DATA_DIR/decode_this.txt
+echo "Hello world" >> $DECODE_FILE
+echo "Goodbye world" >> $DECODE_FILE
+
+BEAM_SIZE=4
+ALPHA=0.6
+
+t2t-decoder \
+ --data_dir=$DATA_DIR \
+ --problems=$PROBLEM \
+ --model=$MODEL \
+ --hparams_set=$HPARAMS \
+ --output_dir=$TRAIN_DIR \
+ --decode_beam_size=$BEAM_SIZE \
+ --decode_alpha=$ALPHA \
+ --decode_from_file=$DECODE_FILE
+
+cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
+```
+
+---
+
+## Installation
+
+```
+# Assumes tensorflow or tensorflow-gpu installed
+pip install tensor2tensor
+
+# Installs with tensorflow-gpu requirement
+pip install tensor2tensor[tensorflow_gpu]
+
+# Installs with tensorflow (cpu) requirement
+pip install tensor2tensor[tensorflow]
+```
+
+Binaries:
+
+```
+# Data generator
+t2t-datagen
+
+# Trainer
+t2t-trainer --registry_help
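+
+# Decoder
+t2t-decoder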
+```
+
+Library usage:
+
+```
+python -c "from tensor2tensor.models.transformer import Transformer"
+```
+
+---
diff --git a/setup.py b/setup.py
index f32e8508c..b51070c77 100644
--- a/setup.py
+++ b/setup.py
@@ -5,17 +5,24 @@
setup(
name='tensor2tensor',
- version='1.2.0',
+ version='1.2.1',
description='Tensor2Tensor',
author='Google Inc.',
author_email='no-reply@google.com',
url='http://github.com/tensorflow/tensor2tensor',
license='Apache 2.0',
packages=find_packages(),
- package_data={'tensor2tensor.data_generators': ['test_data/*']},
+ package_data={
+ 'tensor2tensor.data_generators': ['test_data/*'],
+ 'tensor2tensor.visualization': [
+ 'attention.js',
+ 'TransformerVisualization.ipynb'
+ ],
+ },
scripts=[
'tensor2tensor/bin/t2t-trainer',
'tensor2tensor/bin/t2t-datagen',
+ 'tensor2tensor/bin/t2t-decoder',
'tensor2tensor/bin/t2t-make-tf-configs',
],
install_requires=[
diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
index f7ea7e1f2..cb6253524 100644
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -42,7 +42,6 @@ from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
-from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
@@ -92,19 +91,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
lambda: wsj_parsing.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
- "translate_ende_wmt_bpe32k": (
- lambda: wmt.ende_bpe_token_generator(
- FLAGS.data_dir, FLAGS.tmp_dir, True),
- lambda: wmt.ende_bpe_token_generator(
- FLAGS.data_dir, FLAGS.tmp_dir, False)),
- "languagemodel_1b32k": (
- lambda: lm1b.generator(FLAGS.tmp_dir, True),
- lambda: lm1b.generator(FLAGS.tmp_dir, False)
- ),
- "languagemodel_1b_characters": (
- lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
- lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
- ),
"inference_snli32k": (
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder
new file mode 100644
index 000000000..5c3eeb293
--- /dev/null
+++ b/tensor2tensor/bin/t2t-decoder
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Decode from trained T2T models.
+
+This binary performs inference using the Estimator API.
+
+Example usage to decode from dataset:
+
+ t2t-decoder \
+ --data_dir ~/data \
+ --problems=algorithmic_identity_binary40 \
+      --model=transformer \
+ --hparams_set=transformer_base
+
+Set FLAGS.decode_interactive or FLAGS.decode_from_file for alternative decode
+sources.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Dependency imports
+
+from tensor2tensor.utils import decoding
+from tensor2tensor.utils import trainer_utils
+from tensor2tensor.utils import usr_dir
+
+import tensorflow as tf
+
+flags = tf.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("t2t_usr_dir", "",
+ "Path to a Python module that will be imported. The "
+ "__init__.py file should include the necessary imports. "
+ "The imported files should contain registrations, "
+ "e.g. @registry.register_model calls, that will then be "
+ "available to the t2t-decoder.")
+
+
+def main(_):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
+ trainer_utils.log_registry()
+ trainer_utils.validate_flags()
+ data_dir = os.path.expanduser(FLAGS.data_dir)
+ output_dir = os.path.expanduser(FLAGS.output_dir)
+
+ hparams = trainer_utils.create_hparams(
+ FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams)
+ estimator, _ = trainer_utils.create_experiment_components(
+ hparams=hparams,
+ output_dir=output_dir,
+ data_dir=data_dir,
+ model_name=FLAGS.model)
+
+ if FLAGS.decode_interactive:
+ decoding.decode_interactively(estimator)
+ elif FLAGS.decode_from_file:
+ decoding.decode_from_file(estimator, FLAGS.decode_from_file)
+ else:
+ decoding.decode_from_dataset(
+ estimator,
+ FLAGS.problems.split("-"),
+ return_beams=FLAGS.decode_return_beams,
+ beam_size=FLAGS.decode_beam_size,
+ max_predictions=FLAGS.decode_num_samples,
+ decode_to_file=FLAGS.decode_to_file,
+ save_images=FLAGS.decode_save_images,
+ identity_output=FLAGS.identity_output)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py
index ec3a9d0af..f9afa895b 100644
--- a/tensor2tensor/data_generators/all_problems.py
+++ b/tensor2tensor/data_generators/all_problems.py
@@ -23,9 +23,11 @@
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import cipher
+from tensor2tensor.data_generators import cnn_dailymail
from tensor2tensor.data_generators import desc2code
from tensor2tensor.data_generators import ice_parsing
from tensor2tensor.data_generators import image
+from tensor2tensor.data_generators import imdb
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import ptb
from tensor2tensor.data_generators import snli
diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py
new file mode 100644
index 000000000..db4deae4e
--- /dev/null
+++ b/tensor2tensor/data_generators/cnn_dailymail.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data generators for the CNN and Daily Mail datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tarfile
+
+# Dependency imports
+
+import six
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+
+# Links to data from http://cs.nyu.edu/~kcho/DMQA/
+_CNN_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ"
+
+_DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"
+
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
+
+def _maybe_download_corpora(tmp_dir):
+ """Download corpora if necessary and unzip them.
+
+ Args:
+ tmp_dir: directory containing dataset.
+
+ Returns:
+ filepath of the downloaded corpus file.
+ """
+ cnn_filename = "cnn_stories.tgz"
+ dailymail_filename = "dailymail_stories.tgz"
+ cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
+ dailymail_finalpath = os.path.join(tmp_dir, "dailymail/stories/")
+ if not tf.gfile.Exists(cnn_finalpath):
+ cnn_file = generator_utils.maybe_download_from_drive(
+ tmp_dir, cnn_filename, _CNN_STORIES_DRIVE_URL)
+ with tarfile.open(cnn_file, "r:gz") as cnn_tar:
+ cnn_tar.extractall(tmp_dir)
+ if not tf.gfile.Exists(dailymail_finalpath):
+ dailymail_file = generator_utils.maybe_download_from_drive(
+        tmp_dir, dailymail_filename, _DAILYMAIL_STORIES_DRIVE_URL)
+ with tarfile.open(dailymail_file, "r:gz") as dailymail_tar:
+ dailymail_tar.extractall(tmp_dir)
+ return [cnn_finalpath, dailymail_finalpath]
+
+
+def story_generator(tmp_dir):
+ paths = _maybe_download_corpora(tmp_dir)
+ for path in paths:
+ for story_file in tf.gfile.Glob(path + "*"):
+ story = u""
+ for line in tf.gfile.Open(story_file):
+ line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
+ story += line
+ yield story
+
+
+def _story_summary_split(story):
+  end_pos = story.find("\n\n")  # Up to the first empty line.
+ assert end_pos != -1
+ return story[:end_pos], story[end_pos:].strip()
+
+
+@registry.register_problem
+class SummarizeCnnDailymail32k(problem.Text2TextProblem):
+ """Summarize CNN and Daily Mail articles to their first paragraph."""
+
+ @property
+ def is_character_level(self):
+ return False
+
+ @property
+ def has_inputs(self):
+ return True
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+ @property
+ def num_shards(self):
+ return 100
+
+ @property
+ def vocab_name(self):
+ return "vocab.cnndailymail"
+
+ @property
+ def use_subword_tokenizer(self):
+ return True
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**15 # 32768
+
+ @property
+ def use_train_shards_for_dev(self):
+ return True
+
+ def generator(self, data_dir, tmp_dir, _):
+ encoder = generator_utils.get_or_generate_vocab_inner(
+ data_dir, self.vocab_file, self.targeted_vocab_size,
+ lambda: story_generator(tmp_dir))
+ for story in story_generator(tmp_dir):
+ summary, rest = _story_summary_split(story)
+ encoded_summary = encoder.encode(summary) + [EOS]
+ encoded_story = encoder.encode(rest) + [EOS]
+ yield {"inputs": encoded_story, "targets": encoded_summary}
diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py
index 71f4f0920..fbe91d70e 100644
--- a/tensor2tensor/data_generators/image.py
+++ b/tensor2tensor/data_generators/image.py
@@ -272,7 +272,8 @@ def hparams(self, defaults, model_hparams):
small_modality = "%s:small_image_modality" % registry.Modalities.IMAGE
modality = small_modality if self.is_small else registry.Modalities.IMAGE
p.input_modality = {"inputs": (modality, None)}
- p.target_modality = (registry.Modalities.CLASS_LABEL, self.num_classes)
+ p.target_modality = ("%s:2d" % registry.Modalities.CLASS_LABEL,
+ self.num_classes)
p.batch_size_multiplier = 4 if self.is_small else 256
p.max_expected_batch_size_per_shard = 8 if self.is_small else 2
p.loss_multiplier = 3.0 if self.is_small else 1.0
diff --git a/tensor2tensor/data_generators/imdb.py b/tensor2tensor/data_generators/imdb.py
new file mode 100644
index 000000000..281a03bee
--- /dev/null
+++ b/tensor2tensor/data_generators/imdb.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IMDB Sentiment Classification Problem."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tarfile
+
+# Dependency imports
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
+
+@registry.register_problem
+class SentimentIMDB(problem.Problem):
+ """IMDB sentiment classification."""
+ URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
+
+ @property
+ def num_shards(self):
+ return 10
+
+ @property
+ def vocab_file(self):
+ return "sentiment_imdb.vocab"
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**13 # 8k vocab suffices for this small dataset.
+
+ def doc_generator(self, imdb_dir, dataset, include_label=False):
+ dirs = [(os.path.join(imdb_dir, dataset, "pos"), True), (os.path.join(
+ imdb_dir, dataset, "neg"), False)]
+
+ for d, label in dirs:
+ for filename in os.listdir(d):
+ with tf.gfile.Open(os.path.join(d, filename)) as imdb_f:
+ doc = imdb_f.read().strip()
+ if include_label:
+ yield doc, label
+ else:
+ yield doc
+
+ def generator(self, data_dir, tmp_dir, train):
+ """Generate examples."""
+ # Download and extract
+ compressed_filename = os.path.basename(self.URL)
+ download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
+ self.URL)
+ imdb_dir = os.path.join(tmp_dir, "aclImdb")
+ if not tf.gfile.Exists(imdb_dir):
+ with tarfile.open(download_path, "r:gz") as tar:
+ tar.extractall(tmp_dir)
+
+ # Generate vocab
+ encoder = generator_utils.get_or_generate_vocab_inner(
+ data_dir, self.vocab_file, self.targeted_vocab_size,
+ lambda: self.doc_generator(imdb_dir, "train"))
+
+ # Generate examples
+ dataset = "train" if train else "test"
+ for doc, label in self.doc_generator(imdb_dir, dataset, include_label=True):
+ yield {
+ "inputs": encoder.encode(doc) + [EOS],
+ "targets": [int(label)],
+ }
+
+ def generate_data(self, data_dir, tmp_dir, task_id=-1):
+ train_paths = self.training_filepaths(
+ data_dir, self.num_shards, shuffled=False)
+ dev_paths = self.dev_filepaths(data_dir, 1, shuffled=False)
+ generator_utils.generate_dataset_and_shuffle(
+ self.generator(data_dir, tmp_dir, True), train_paths,
+ self.generator(data_dir, tmp_dir, False), dev_paths)
+
+ def hparams(self, defaults, model_hparams):
+ p = defaults
+ source_vocab_size = self._encoders["inputs"].vocab_size
+ p.input_modality = {
+ "inputs": (registry.Modalities.SYMBOL, source_vocab_size)
+ }
+ p.target_modality = (registry.Modalities.CLASS_LABEL, 2)
+ p.input_space_id = problem.SpaceID.EN_TOK
+ p.target_space_id = problem.SpaceID.GENERIC
+
+ def feature_encoders(self, data_dir):
+ vocab_filename = os.path.join(data_dir, self.vocab_file)
+ encoder = text_encoder.SubwordTextEncoder(vocab_filename)
+ return {
+ "inputs": encoder,
+ "targets": text_encoder.TextEncoder(),
+ }
+
+ def example_reading_spec(self):
+ data_fields = {
+ "inputs": tf.VarLenFeature(tf.int64),
+ "targets": tf.FixedLenFeature([1], tf.int64),
+ }
+ data_items_to_decoders = None
+ return (data_fields, data_items_to_decoders)
diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py
index a3771e124..d45e4fe1e 100644
--- a/tensor2tensor/data_generators/lm1b.py
+++ b/tensor2tensor/data_generators/lm1b.py
@@ -29,8 +29,10 @@
from six.moves import xrange # pylint: disable=redefined-builtin
from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import tokenizer
+from tensor2tensor.utils import registry
import tensorflow as tf
@@ -53,7 +55,7 @@ def _original_vocab(tmp_dir):
"""
vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/"
"vocab-2016-09-10.txt")
- vocab_filename = os.path.basename(vocab_url)
+ vocab_filename = os.path.basename(vocab_url + ".en")
vocab_filepath = os.path.join(tmp_dir, vocab_filename)
if not os.path.exists(vocab_filepath):
generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
@@ -140,29 +142,80 @@ def _get_or_build_subword_text_encoder(tmp_dir):
return ret
-def generator(tmp_dir, train, characters=False):
- """Generator for lm1b sentences.
-
- Args:
- tmp_dir: a string.
- train: a boolean.
- characters: a boolean
-
- Yields:
- A dictionary {"inputs": [0], "targets": []}
- """
- _maybe_download_corpus(tmp_dir)
- original_vocab = _original_vocab(tmp_dir)
- files = (_train_data_filenames(tmp_dir) if train
- else [_dev_data_filename(tmp_dir)])
- if characters:
- encoder = text_encoder.ByteTextEncoder()
- else:
- encoder = _get_or_build_subword_text_encoder(tmp_dir)
- for filepath in files:
- tf.logging.info("filepath = %s", filepath)
- for line in tf.gfile.Open(filepath):
- tokens = encoder.encode(
- _replace_oov(original_vocab, text_encoder.native_to_unicode(line)))
- tokens.append(EOS)
- yield {"inputs": [0], "targets": tokens}
+@registry.register_problem
+class LanguagemodelLm1b32k(problem.Text2TextProblem):
+ """A language model on the 1B words corpus."""
+
+ @property
+ def is_character_level(self):
+ return False
+
+ @property
+ def has_inputs(self):
+ return True
+
+ @property
+ def input_space_id(self):
+ # Ratio of dev tokens (including eos) to dev words (including eos)
+ # 176884 / 159658 = 1.107893; multiply ppx by this to compare results.
+ return problem.SpaceID.EN_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+ @property
+ def num_shards(self):
+ return 100
+
+ @property
+ def vocab_name(self):
+ return "vocab.lm1b.en"
+
+ @property
+ def use_subword_tokenizer(self):
+ return True
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**15 # 32768
+
+ @property
+ def use_train_shards_for_dev(self):
+ return True
+
+ def generator(self, tmp_dir, train, characters=False):
+ """Generator for lm1b sentences.
+
+ Args:
+ tmp_dir: a string.
+ train: a boolean.
+ characters: a boolean
+
+ Yields:
+ A dictionary {"inputs": [0], "targets": []}
+ """
+ _maybe_download_corpus(tmp_dir)
+ original_vocab = _original_vocab(tmp_dir)
+ files = (_train_data_filenames(tmp_dir) if train
+ else [_dev_data_filename(tmp_dir)])
+ if characters:
+ encoder = text_encoder.ByteTextEncoder()
+ else:
+ encoder = _get_or_build_subword_text_encoder(tmp_dir)
+ for filepath in files:
+ tf.logging.info("filepath = %s", filepath)
+ for line in tf.gfile.Open(filepath):
+ tokens = encoder.encode(
+ _replace_oov(original_vocab, text_encoder.native_to_unicode(line)))
+ tokens.append(EOS)
+ yield {"inputs": [0], "targets": tokens}
+
+
+@registry.register_problem
+class LanguagemodelLm1bCharacters(LanguagemodelLm1b32k):
+ """A language model on the 1B words corpus, character level."""
+
+ @property
+ def is_character_level(self):
+ return True
diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py
index 63b835f38..e002329bc 100644
--- a/tensor2tensor/data_generators/problem_hparams.py
+++ b/tensor2tensor/data_generators/problem_hparams.py
@@ -147,8 +147,8 @@ def default_problem_hparams():
# Modalities used to map from input features to a space compatible with
# chosen model architecture. One modality spec (which is a 2-tuple,
# (modality_full_name, vocab_size)) per feature key. modality_full_name is
- # a string type:name, e.g. class_label:class_label_2d. Leaving off the
- # name uses the default modality for that type (e.g. class_label ==
+ # a string type:name, e.g. class_label:2d. Leaving off the name uses the
+ # default modality for that type (e.g. class_label ==
# class_label:default).
input_modality={},
@@ -267,103 +267,6 @@ def audio_timit_tokens(model_hparams, wrong_vocab_size):
return p
-def audio_wsj_characters(unused_model_hparams):
- """English audio transcription benchmark."""
- p = default_problem_hparams()
- p.input_modality = {
- "inputs": (registry.Modalities.AUDIO, None),
- }
- p.target_modality = (registry.Modalities.SYMBOL, 256)
- p.vocabulary = {
- "inputs": text_encoder.TextEncoder(),
- "targets": text_encoder.ByteTextEncoder(),
- }
- p.batch_size_multiplier = 512
- p.loss_multiplier = 2.0
- p.input_space_id = 13
- p.target_space_id = 2
- return p
-
-
-def audio_wsj_tokens(model_hparams, wrong_vocab_size):
- """English audio transcription benchmark.
-
- Args:
- model_hparams: a tf.contrib.training.HParams
- wrong_vocab_size: a number used in the filename indicating the approximate
- vocabulary size. This is not to be confused with the actual vocabulary
- size.
- Returns:
- a tf.contrib.training.HParams
- """
- p = default_problem_hparams()
- # This vocab file must be present within the data directory.
- vocab_filename = os.path.join(model_hparams.data_dir,
- "vocab.endefr.%d" % wrong_vocab_size)
- subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
- p.input_modality = {
- "inputs": (registry.Modalities.AUDIO, None),
- }
- p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
- p.vocabulary = {
- "inputs": text_encoder.TextEncoder(),
- "targets": subtokenizer,
- }
- p.batch_size_multiplier = 512
- p.loss_multiplier = 2.0
- p.input_space_id = 12
- p.target_space_id = 3
- return p
-
-
-def lm1b_32k(model_hparams):
- """Billion-word language-modeling benchmark, 32k subword vocabulary."""
- p = default_problem_hparams()
- # ratio of dev tokens (including eos) to dev words (including eos)
- # 176884 / 159658 = 1.107893
- p.perplexity_exponent = 1.107893
- p.input_modality = {}
- encoder = text_encoder.SubwordTextEncoder(
- os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder"))
- p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
- p.vocabulary = {"targets": encoder}
- p.target_space_id = 3
- return p
-
-
-def lm1b_characters(unused_model_hparams):
- """Billion-word language-modeling benchmark, 32k subword vocabulary."""
- p = default_problem_hparams()
- # ratio of dev tokens (including eos) to dev words (including eos)
- # 826189 / 159658 = 5.174742
- p.perplexity_exponent = 5.174742
- p.input_modality = {}
- encoder = text_encoder.ByteTextEncoder()
- p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
- p.vocabulary = {"targets": encoder}
- p.target_space_id = 2
- return p
-
-
-def wmt_ende_bpe32k(model_hparams):
- """English to German translation benchmark."""
- p = default_problem_hparams()
- vocab_size = 40960
- modality_spec = (registry.Modalities.SYMBOL, vocab_size)
- p.input_modality = {"inputs": modality_spec}
- p.target_modality = modality_spec
- # This vocab file must be present within the data directory.
- vocab_filename = os.path.join(model_hparams.data_dir, "vocab.bpe.32000")
- p.vocabulary = {
- "inputs": text_encoder.TokenTextEncoder(vocab_filename=vocab_filename),
- "targets": text_encoder.TokenTextEncoder(vocab_filename=vocab_filename),
- }
- p.loss_multiplier = 1.4
- p.input_space_id = 4
- p.target_space_id = 9
- return p
-
-
def wmt_parsing_characters(model_hparams):
"""English to parse tree translation benchmark."""
del model_hparams # Unused.
@@ -472,25 +375,11 @@ def img2img_imagenet(unused_model_hparams):
lambda p: audio_timit_tokens(p, 2**13),
"audio_timit_tokens_8k_test":
lambda p: audio_timit_tokens(p, 2**13),
- "audio_wsj_characters_tune":
- audio_wsj_characters,
- "audio_wsj_characters_test":
- audio_wsj_characters,
- "audio_wsj_tokens_8k_tune":
- lambda p: audio_wsj_tokens(p, 2**13),
- "audio_wsj_tokens_8k_test":
- lambda p: audio_wsj_tokens(p, 2**13),
- "languagemodel_1b_characters":
- lm1b_characters,
- "languagemodel_1b32k":
- lm1b_32k,
"parsing_english_ptb8k":
lambda p: wmt_parsing_tokens(p, 2**13),
"parsing_english_ptb16k":
lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda
p, "wsj", 2**14, 2**9),
- "translate_ende_wmt_bpe32k":
- wmt_ende_bpe32k,
"img2img_imagenet":
img2img_imagenet,
}
diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py
index c8a3bd1f9..ac9260cfa 100644
--- a/tensor2tensor/data_generators/text_encoder.py
+++ b/tensor2tensor/data_generators/text_encoder.py
@@ -161,6 +161,7 @@ def __init__(self,
vocab_filename,
reverse=False,
vocab_list=None,
+ replace_oov=None,
num_reserved_ids=NUM_RESERVED_TOKENS):
"""Initialize from a file or list, one token per line.
@@ -176,10 +177,13 @@ def __init__(self,
and decoding.
vocab_list: If not None, a list of elements of the vocabulary. If this is
not None, then vocab_filename should be None.
+ replace_oov: If not None, every out-of-vocabulary token seen when
+ encoding will be replaced by this string (which must be in vocab).
num_reserved_ids: Number of IDs to save for reserved tokens like .
"""
super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
self._reverse = reverse
+ self._replace_oov = replace_oov
if vocab_filename:
self._init_vocab_from_file(vocab_filename)
else:
@@ -188,7 +192,11 @@ def __init__(self,
def encode(self, sentence):
"""Converts a space-separated string of tokens to a list of ids."""
- ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
+ tokens = sentence.strip().split()
+ if self._replace_oov is not None:
+ tokens = [t if t in self._token_to_id else self._replace_oov
+ for t in tokens]
+ ret = [self._token_to_id[tok] for tok in tokens]
return ret[::-1] if self._reverse else ret
def decode(self, ids):
@@ -639,19 +647,32 @@ def _init_alphabet_from_tokens(self, tokens):
self._alphabet = {c for token in tokens for c in token}
self._alphabet |= _ESCAPE_CHARS
- def _load_from_file(self, filename):
- """Load from a file.
+ def _load_from_file_object(self, f):
+ """Load from a file object.
Args:
- filename: filename to load vocabulary from
+ f: File object to load vocabulary from
"""
subtoken_strings = []
- with tf.gfile.Open(filename) as f:
- for line in f:
- subtoken_strings.append(native_to_unicode(line.strip()[1:-1]))
+ for line in f:
+ s = line.strip()
+ # Some vocab files wrap words in single quotes, but others don't
+ if ((s.startswith("'") and s.endswith("'")) or
+ (s.startswith("\"") and s.endswith("\""))):
+ s = s[1:-1]
+ subtoken_strings.append(native_to_unicode(s))
self._init_subtokens_from_list(subtoken_strings)
self._init_alphabet_from_tokens(subtoken_strings)
+ def _load_from_file(self, filename):
+ """Load from a file.
+
+ Args:
+ filename: Filename to load vocabulary from
+ """
+ with tf.gfile.Open(filename) as f:
+ self._load_from_file_object(f)
+
def store_to_file(self, filename):
with tf.gfile.Open(filename, "w") as f:
for subtoken_string in self._all_subtoken_strings:
diff --git a/tensor2tensor/data_generators/text_encoder_test.py b/tensor2tensor/data_generators/text_encoder_test.py
index eadfcfb5e..c13078808 100644
--- a/tensor2tensor/data_generators/text_encoder_test.py
+++ b/tensor2tensor/data_generators/text_encoder_test.py
@@ -21,6 +21,7 @@
from __future__ import unicode_literals
import collections
+import io
import os
import shutil
@@ -31,22 +32,30 @@
import tensorflow as tf
+class NativeToUnicodeTest(tf.test.TestCase):
+
+ def test_native_to_unicode(self):
+ s = r"foo bar"
+ self.assertIsInstance(text_encoder.native_to_unicode(s), unicode)
+ self.assertEqual(text_encoder.native_to_unicode(s), u"foo bar")
+
+
class EscapeUnescapeTokenTest(tf.test.TestCase):
def test_escape_token(self):
escaped = text_encoder._escape_token(
- 'Foo! Bar.\nunder_score back\\slash',
- set('abcdefghijklmnopqrstuvwxyz .\n') | text_encoder._ESCAPE_CHARS)
+ "Foo! Bar.\nunder_score back\\slash",
+ set("abcdefghijklmnopqrstuvwxyz .\n") | text_encoder._ESCAPE_CHARS)
self.assertEqual(
- '\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_', escaped)
+ "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_", escaped)
def test_unescape_token(self):
unescaped = text_encoder._unescape_token(
- '\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_')
+ "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_")
self.assertEqual(
- 'Foo! Bar.\nunder_score back\\slash', unescaped)
+ "Foo! Bar.\nunder_score back\\slash", unescaped)
class TokenTextEncoderTest(tf.test.TestCase):
@@ -54,7 +63,7 @@ class TokenTextEncoderTest(tf.test.TestCase):
@classmethod
def setUpClass(cls):
"""Make sure the test dir exists and is empty."""
- cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), 'encoder_test')
+ cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test")
shutil.rmtree(cls.test_temp_dir, ignore_errors=True)
os.mkdir(cls.test_temp_dir)
@@ -65,8 +74,8 @@ def test_save_and_reload(self):
that this test size be "large".
"""
- corpus = 'A B C D E F G H I J K L M N O P Q R S T U V W X Y Z'
- vocab_filename = os.path.join(self.test_temp_dir, 'abc.vocab')
+ corpus = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z"
+ vocab_filename = os.path.join(self.test_temp_dir, "abc.vocab")
# Make text encoder from a list and store vocab to fake filesystem.
encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split())
@@ -80,7 +89,7 @@ def test_save_and_reload(self):
def test_reserved_tokens_in_corpus(self):
"""Test that we handle reserved tokens appearing in the corpus."""
- corpus = 'A B {} D E F {} G {}'.format(text_encoder.EOS,
+ corpus = "A B {} D E F {} G {}".format(text_encoder.EOS,
text_encoder.EOS,
text_encoder.PAD)
@@ -97,14 +106,14 @@ class SubwordTextEncoderTest(tf.test.TestCase):
def test_encode_decode(self):
corpus = (
- 'This is a corpus of text that provides a bunch of tokens from which '
- 'to build a vocabulary. It will be used when strings are encoded '
- 'with a TextEncoder subclass. The encoder was coded by a coder.')
- token_counts = collections.Counter(corpus.split(' '))
- alphabet = set(corpus) ^ {' '}
+ "This is a corpus of text that provides a bunch of tokens from which "
+ "to build a vocabulary. It will be used when strings are encoded "
+ "with a TextEncoder subclass. The encoder was coded by a coder.")
+ token_counts = collections.Counter(corpus.split(" "))
+ alphabet = set(corpus) ^ {" "}
- original = 'This is a coded sentence encoded by the SubwordTextEncoder.'
- token_counts.update(original.split(' '))
+ original = "This is a coded sentence encoded by the SubwordTextEncoder."
+ token_counts.update(original.split(" "))
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
100, token_counts, 2, 10)
@@ -118,31 +127,31 @@ def test_encode_decode(self):
# they should appear in the vocabulary even though they are substrings
# of other included strings.
subtoken_strings = {encoder._all_subtoken_strings[i] for i in encoded}
- self.assertIn('encoded_', subtoken_strings)
- self.assertIn('coded_', subtoken_strings)
- self.assertIn('TextEncoder', encoder._all_subtoken_strings)
- self.assertIn('coder', encoder._all_subtoken_strings)
+ self.assertIn("encoded_", subtoken_strings)
+ self.assertIn("coded_", subtoken_strings)
+ self.assertIn("TextEncoder", encoder._all_subtoken_strings)
+ self.assertIn("coder", encoder._all_subtoken_strings)
- # Every character in the corpus should be in the encoder's alphabet and
+    # Every character in the corpus should be in the encoder's alphabet and
# its subtoken vocabulary.
self.assertTrue(alphabet.issubset(encoder._alphabet))
for a in alphabet:
self.assertIn(a, encoder._all_subtoken_strings)
def test_unicode(self):
- corpus = 'Cat emoticons. \U0001F638 \U0001F639 \U0001F63A \U0001F63B'
- token_counts = collections.Counter(corpus.split(' '))
+ corpus = "Cat emoticons. \U0001F638 \U0001F639 \U0001F63A \U0001F63B"
+ token_counts = collections.Counter(corpus.split(" "))
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
100, token_counts, 2, 10)
- self.assertIn('\U0001F638', encoder._alphabet)
- self.assertIn('\U0001F63B', encoder._all_subtoken_strings)
+ self.assertIn("\U0001F638", encoder._alphabet)
+ self.assertIn("\U0001F63B", encoder._all_subtoken_strings)
def test_small_vocab(self):
- corpus = 'The quick brown fox jumps over the lazy dog'
- token_counts = collections.Counter(corpus.split(' '))
- alphabet = set(corpus) ^ {' '}
+ corpus = "The quick brown fox jumps over the lazy dog"
+ token_counts = collections.Counter(corpus.split(" "))
+ alphabet = set(corpus) ^ {" "}
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
10, token_counts, 2, 10)
@@ -155,12 +164,12 @@ def test_small_vocab(self):
self.assertIn(a, encoder._all_subtoken_strings)
def test_encodable_when_not_in_alphabet(self):
- corpus = 'the quick brown fox jumps over the lazy dog'
- token_counts = collections.Counter(corpus.split(' '))
+ corpus = "the quick brown fox jumps over the lazy dog"
+ token_counts = collections.Counter(corpus.split(" "))
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
100, token_counts, 2, 10)
- original = 'This has UPPER CASE letters that are out of alphabet'
+ original = "This has UPPER CASE letters that are out of alphabet"
# Early versions could have an infinite loop when breaking into subtokens
# if there was any out-of-alphabet characters in the encoded string.
@@ -168,24 +177,42 @@ def test_encodable_when_not_in_alphabet(self):
decoded = encoder.decode(encoded)
self.assertEqual(original, decoded)
- encoded_str = ''.join(encoder._all_subtoken_strings[i] for i in encoded)
- self.assertIn('\\84;', encoded_str)
+ encoded_str = "".join(encoder._all_subtoken_strings[i] for i in encoded)
+ self.assertIn("\\84;", encoded_str)
- @mock.patch.object(text_encoder, '_ESCAPE_CHARS', new=set('\\_;13579'))
+ @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579"))
def test_raises_exception_when_not_encodable(self):
- corpus = 'the quick brown fox jumps over the lazy dog'
- token_counts = collections.Counter(corpus.split(' '))
+ corpus = "the quick brown fox jumps over the lazy dog"
+ token_counts = collections.Counter(corpus.split(" "))
# Deliberately exclude some required encoding chars from the alphabet
# and token list, making some strings unencodable.
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
100, token_counts, 2, 10)
- original = 'This has UPPER CASE letters that are out of alphabet'
+ original = "This has UPPER CASE letters that are out of alphabet"
# Previously there was a bug which produced an infinite loop in this case.
with self.assertRaises(AssertionError):
encoder.encode(original)
-
-if __name__ == '__main__':
+ def test_load_from_file(self):
+ # Test a vocab file with words not wrapped with single quotes
+ encoder = text_encoder.SubwordTextEncoder()
+ correct_vocab = ["the", "and", "of"]
+ vocab = io.StringIO("the\n"
+ "and\n"
+ "of\n")
+ encoder._load_from_file_object(vocab)
+ self.assertEqual(encoder._all_subtoken_strings, correct_vocab)
+
+    # Test a vocab file with words wrapped in quotes
+ encoder = text_encoder.SubwordTextEncoder()
+ vocab = io.StringIO("\"the\"\n"
+ "\"and\"\n"
+ "\"of\"\n")
+ encoder._load_from_file_object(vocab)
+ self.assertEqual(encoder._all_subtoken_strings, correct_vocab)
+
+
+if __name__ == "__main__":
tf.test.main()
diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py
index d1c80f2e1..9610cb1d8 100644
--- a/tensor2tensor/data_generators/wiki.py
+++ b/tensor2tensor/data_generators/wiki.py
@@ -31,6 +31,7 @@
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import registry
+import tensorflow as tf
# End-of-sentence marker.
EOS = text_encoder.EOS_ID
@@ -49,7 +50,7 @@ def _maybe_download_corpus(tmp_dir):
"enwiki-20170620-pages-articles-multistream.xml.bz2")
corpus_filename = os.path.basename(corpus_url)
corpus_filepath = os.path.join(tmp_dir, corpus_filename)
- if not os.path.exists(corpus_filepath):
+ if not tf.gfile.Exists(corpus_filepath):
generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url)
return corpus_filepath
diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py
index 93fc27ac5..8d6cdae6f 100644
--- a/tensor2tensor/data_generators/wmt.py
+++ b/tensor2tensor/data_generators/wmt.py
@@ -305,17 +305,44 @@ def _get_wmt_ende_bpe_dataset(directory, filename):
return train_path
-def ende_bpe_token_generator(data_dir, tmp_dir, train):
- """Instance of token generator for the WMT en->de task, training set."""
- dataset_path = ("train.tok.clean.bpe.32000"
- if train else "newstest2013.tok.bpe.32000")
- train_path = _get_wmt_ende_bpe_dataset(tmp_dir, dataset_path)
- token_tmp_path = os.path.join(tmp_dir, "vocab.bpe.32000")
- token_path = os.path.join(data_dir, "vocab.bpe.32000")
- tf.gfile.Copy(token_tmp_path, token_path, overwrite=True)
- token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
- return token_generator(train_path + ".en", train_path + ".de", token_vocab,
- EOS)
+@registry.register_problem
+class TranslateEndeWmtBpe32k(TranslateProblem):
+ """Problem spec for WMT En-De translation, BPE version."""
+
+ @property
+ def targeted_vocab_size(self):
+ return 32000
+
+ @property
+ def vocab_name(self):
+ return "vocab.bpe"
+
+ def feature_encoders(self, data_dir):
+ vocab_filename = os.path.join(data_dir, self.vocab_file)
+ encoder = text_encoder.TokenTextEncoder(vocab_filename, replace_oov="UNK")
+ return {"inputs": encoder, "targets": encoder}
+
+ def generator(self, data_dir, tmp_dir, train):
+ """Instance of token generator for the WMT en->de task, training set."""
+ dataset_path = ("train.tok.clean.bpe.32000"
+ if train else "newstest2013.tok.bpe.32000")
+ train_path = _get_wmt_ende_bpe_dataset(tmp_dir, dataset_path)
+ token_tmp_path = os.path.join(tmp_dir, self.vocab_file)
+ token_path = os.path.join(data_dir, self.vocab_file)
+ tf.gfile.Copy(token_tmp_path, token_path, overwrite=True)
+ with tf.gfile.GFile(token_path, mode="a") as f:
+ f.write("UNK\n") # Add UNK to the vocab.
+ token_vocab = text_encoder.TokenTextEncoder(token_path, replace_oov="UNK")
+ return token_generator(train_path + ".en", train_path + ".de",
+ token_vocab, EOS)
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_BPE_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.DE_BPE_TOK
def _preprocess_sgm(line, is_sgm):
diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py
index d69e68f80..253e9bee5 100644
--- a/tensor2tensor/layers/common_attention.py
+++ b/tensor2tensor/layers/common_attention.py
@@ -30,6 +30,8 @@
import tensorflow as tf
+from tensorflow.python.framework import function
+
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
"""Adds a bunch of sinusoids of different frequencies to a Tensor.
@@ -1008,6 +1010,8 @@ def self_attention_expert(
depth = x.get_shape().as_list()[-1]
length = tf.shape(batch_coordinate)[0]
+ tf.summary.scalar("batch_size", length, family="experts_stats_batch_size")
+
attention_kq_size = attention_kq_size or depth
attention_v_size = attention_v_size or depth
@@ -1059,6 +1063,7 @@ def local_expert_attention(
loss_coef,
attention_num_experts,
train=True,
+ pad_remover=None,
**kwargs
):
"""Attention using a mixture of experts.
@@ -1073,6 +1078,7 @@ def local_expert_attention(
loss_coef: a scalar. A multiplier for the expert loss
attention_num_experts: The number of experts to use
train: a boolean for the current mode
+    pad_remover (PadRemover): A util object holding the padding positions
**kwargs: Arguments to forward to self_attention_expert
Returns:
@@ -1093,4 +1099,153 @@ def local_expert_attention(
loss_coef=loss_coef,
pass_x=True,
pass_gates=False,
- additional_dispatch_params=additional_dispatch_params)
+ additional_dispatch_params=additional_dispatch_params,
+ pad_remover=pad_remover
+ )
+
+
+def scaled_dot_product_attention_simple(q, k, v, bias, name=None):
+ """scaled dot-product attention. One head. One spatial dimension.
+
+ Args:
+ q: a Tensor with shape [batch, length_q, depth_k]
+ k: a Tensor with shape [batch, length_kv, depth_k]
+ v: a Tensor with shape [batch, length_kv, depth_v]
+ bias: optional Tensor broadcastable to [batch, length_q, length_kv]
+ name: an optional string
+
+ Returns:
+ A Tensor.
+ """
+ with tf.variable_scope(
+ name, default_name="scaled_dot_product_attention_simple"):
+ scalar = tf.rsqrt(tf.to_float(tf.shape(q)[2]))
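+    # Computes softmax(Q K^T / sqrt(depth_k) + bias) V for a single head; the
+    # 1/sqrt(depth_k) scaling is folded into q before the matmul.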
+ logits = tf.matmul(q * scalar, k, transpose_b=True)
+ if bias is not None:
+ logits += bias
+ weights = tf.nn.softmax(logits, name="attention_weights")
+ return tf.matmul(weights, v)
+
+
+_function_cache = {}
+
+
+def multihead_self_attention_memory_efficient(x,
+ bias,
+ num_heads,
+ head_size=None,
+ epsilon=1e-6,
+ forget=True,
+ test_vars=None,
+ name=None):
+ """Multihead scaled-dot-product self-attention.
+
+ Includes layer norm.
+
+ Returns multihead-self-attention(layer_norm(x))
+
+ Computes one attention head at a time to avoid exhausting memory.
+
+ If forget=True, then forget all forwards activations and recompute on
+ the backwards pass.
+
+ Args:
+ x: a Tensor with shape [batch, length, input_size]
+ bias: an attention bias tensor broadcastable to [batch, 1, length, length]
+ num_heads: an integer
+ head_size: an optional integer - defaults to input_size/num_heads
+ epsilon: a float, for layer norm
+ forget: a boolean - forget forwards activations and recompute on backprop
+ test_vars: optional tuple of variables for testing purposes
+ name: an optional string
+
+ Returns:
+ A Tensor.
+ """
+ io_size = x.get_shape().as_list()[-1]
+ if head_size is None:
+ assert io_size % num_heads == 0
+    head_size = io_size // num_heads
+
+ def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
+ """Forward function."""
+ n = common_layers.layer_norm_compute_python(
+ x, epsilon, norm_scale, norm_bias)
+ wqkv_split = tf.unstack(wqkv, num=num_heads)
+ wo_split = tf.unstack(wo, num=num_heads)
+ y = 0
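+    # Compute one head at a time; the control dependency on the running sum y
+    # forces the heads to be evaluated sequentially, so only one head's
+    # activations are live at any point.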
+ for h in xrange(num_heads):
+ with tf.control_dependencies([y] if h > 0 else []):
+ combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
+ q, k, v = tf.split(combined, 3, axis=2)
+ o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
+ y += tf.nn.conv1d(o, wo_split[h], 1, "SAME")
+ return y
+
+ key = ("multihead_self_attention_memory_efficient %s %s" %
+ (num_heads, epsilon))
+ if not forget:
+ forward_fn = forward_internal
+ elif key in _function_cache:
+ forward_fn = _function_cache[key]
+ else:
+ @function.Defun(compiled=True)
+ def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
+ with tf.control_dependencies([dy]):
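+        # Nothing was saved on the forward pass, so recompute the layer norm
+        # and each head's forward activations from x, then differentiate
+        # head by head, chaining deps to keep the computation sequential.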
+ n = common_layers.layer_norm_compute_python(
+ x, epsilon, norm_scale, norm_bias)
+ wqkv_split = tf.unstack(wqkv, num=num_heads)
+ wo_split = tf.unstack(wo, num=num_heads)
+ deps = []
+ dwqkvs = []
+ dwos = []
+ dn = 0
+ for h in xrange(num_heads):
+ with tf.control_dependencies(deps):
+ combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
+ q, k, v = tf.split(combined, 3, axis=2)
+ o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
+ partial_y = tf.nn.conv1d(o, wo_split[h], 1, "SAME")
+ pdn, dwqkvh, dwoh = tf.gradients(
+ ys=[partial_y],
+ xs=[n, wqkv_split[h], wo_split[h]],
+ grad_ys=[dy])
+ dn += pdn
+ dwqkvs.append(dwqkvh)
+ dwos.append(dwoh)
+ deps = [dn, dwqkvh, dwoh]
+ dwqkv = tf.stack(dwqkvs)
+ dwo = tf.stack(dwos)
+ with tf.control_dependencies(deps):
+ dx, dnorm_scale, dnorm_bias = tf.gradients(
+ ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
+ return (dx, dwqkv, dwo, tf.zeros_like(attention_bias),
+ dnorm_scale, dnorm_bias)
+
+ @function.Defun(grad_func=grad_fn, compiled=True,
+ separate_compiled_gradients=True)
+ def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
+ return forward_internal(
+ x, wqkv, wo, attention_bias, norm_scale, norm_bias)
+ _function_cache[key] = forward_fn
+
+ if bias is not None:
+ bias = tf.squeeze(bias, 1)
+ with tf.variable_scope(name, default_name="multihead_attention", values=[x]):
+ # TODO(noam): it would be nice to save memory by casting x to float16
+ # here, but this causes problems with the gradients. Figure out if there
+ # is a way to leave the gradients as float32.
+ if test_vars is not None:
+ wqkv, wo, norm_scale, norm_bias = list(test_vars)
+ else:
+ wqkv = tf.get_variable(
+ "wqkv", [num_heads, 1, io_size, 3 * head_size],
+ initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
+ wo = tf.get_variable(
+ "wo", [num_heads, 1, head_size, io_size],
+ initializer=tf.random_normal_initializer(
+ stddev=(head_size * num_heads)**-0.5))
+ norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
+ y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias)
+ y.set_shape(x.get_shape())
+ return y
diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py
index e49999fbb..6664bcc2d 100644
--- a/tensor2tensor/layers/common_attention_test.py
+++ b/tensor2tensor/layers/common_attention_test.py
@@ -23,6 +23,7 @@
import numpy as np
from tensor2tensor.layers import common_attention
+from tensor2tensor.layers import common_layers
import tensorflow as tf
@@ -117,6 +118,49 @@ def testLocalUnmaskedAttention2DMatchingBlockLength(self):
res = session.run(a)
self.assertEqual(res.shape, (5, 4, 25, 25, 16))
+ def testMultiheadSelfAttentionMemoryEfficient(self):
+ num_heads = 4
+ io_size = 16
+ batch = 2
+ length = 7
+ head_size = 5
+ x = np.random.rand(batch, length, io_size)
+ dy = np.random.rand(batch, length, io_size)
+ with self.test_session() as session:
+ x = tf.to_float(x)
+ dy = tf.to_float(dy)
+ bias = common_attention.attention_bias_lower_triangle(length)
+ wqkv = tf.get_variable(
+ "wqkv", [num_heads, 1, io_size, 3 * head_size],
+ initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
+ wo = tf.get_variable(
+ "wo", [num_heads, 1, head_size, io_size],
+ initializer=tf.random_normal_initializer(
+ stddev=(head_size * num_heads)**-0.5))
+ norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
+ y = common_attention.multihead_self_attention_memory_efficient(
+ x, bias, num_heads, head_size=head_size, forget=False,
+ test_vars=(wqkv, wo, norm_scale, norm_bias))
+ y_forget = common_attention.multihead_self_attention_memory_efficient(
+ x, bias, num_heads, head_size=head_size, forget=True,
+ test_vars=(wqkv, wo, norm_scale, norm_bias))
+ dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
+ ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
+ dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
+ ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
+ session.run(tf.global_variables_initializer())
+ (y, y_forget,
+ dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
+ dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run(
+ [y, y_forget,
+ dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
+ dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f])
+ self.assertAllClose(y, y_forget)
+ self.assertAllClose(dwo, dwo_f)
+ self.assertAllClose(dwqkv, dwqkv_f)
+ self.assertAllClose(dnorm_scale, dnorm_scale_f)
+ self.assertAllClose(dnorm_bias, dnorm_bias_f)
+ self.assertAllClose(dx, dx_f)
if __name__ == "__main__":
tf.test.main()
diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py
index d4751bb0d..2e33c9e94 100644
--- a/tensor2tensor/layers/common_hparams.py
+++ b/tensor2tensor/layers/common_hparams.py
@@ -33,14 +33,6 @@ def basic_params1():
"""A set of basic hyperparameters."""
return tf.contrib.training.HParams(
batch_size=4096, # in tokens per batch per gpu
- # This flag controls the number of length buckets in the data reader.
- # Too many buckets slows down data reading - this needs fixing.
- # Too few buckets mean lots of wasted padding.
- # If this value is 1, we have buckets with maximum lengths:
- # [8, 12, 16, 24, 32, 48 ... (max_length or batch_size)]
- # If this value is 2, we have buckets with maximum lengths:
- # [8, 10, 12, 14, 16, 20, 24 ... (max_length or batch_size)]
- batching_mantissa_bits=1,
num_hidden_layers=4,
kernel_height=3,
kernel_width=1,
@@ -98,9 +90,22 @@ def basic_params1():
# epsilon parameter to normalization function
norm_epsilon=1e-6,
symbol_modality_num_shards=16,
- # setting the max length in a minibatch. 0 means default behavior,
- # max_length = hparams.batch_size * length_multiplier
+ # During training, we drop sequences whose inputs or targets are longer
+ # than max_length.
+ # If max_length==0, we use hparams.batch_size instead.
max_length=0,
+ # Maximum length in the smallest length bucket. Setting this
+ # flag too high will result in wasteful padding of short
+ # sequences. Due to some (hopefully) temporary hacks in the
+ # data reading and batching code, setting this flag too low
+ # results in a very long batch-shuffling queue.
+ # TODO(noam): change this once the Datasets API changes.
+ min_length_bucket=8,
+ # This flag controls the number of length buckets in the data
+ # reader. The buckets have maximum lengths from
+      # min_length_bucket to (max_length or batch_size), increasing
+ # (approximately) by factors of length_bucket_step.
+ length_bucket_step=1.1,
# If set to True, drop sequences longer than max_length during eval.
# This affects the validity of the evaluation metrics.
eval_drop_long_sequences=int(False),
diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py
index ad899bfbf..4b09e70cb 100644
--- a/tensor2tensor/layers/common_layers.py
+++ b/tensor2tensor/layers/common_layers.py
@@ -425,6 +425,15 @@ def conv_fn(inputs, filters, kernel_size, **kwargs):
return conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs)
+def layer_norm_vars(filters):
+ """Create Variables for layer norm."""
+ scale = tf.get_variable(
+ "layer_norm_scale", [filters], initializer=tf.ones_initializer())
+ bias = tf.get_variable(
+ "layer_norm_bias", [filters], initializer=tf.zeros_initializer())
+ return scale, bias
+
+
def layer_norm_compute_python(x, epsilon, scale, bias):
"""Layer norm raw computation."""
mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
@@ -1773,7 +1782,7 @@ def smoothing_cross_entropy_factored_grad(op, dy):
b = op.inputs[1]
labels = op.inputs[2]
confidence = op.inputs[3]
- num_splits = 32
+ num_splits = 16
vocab_size = tf.shape(b)[0]
labels = approximate_split(labels, num_splits)
a = approximate_split(a, num_splits)
@@ -1817,7 +1826,7 @@ def smoothing_cross_entropy_factored(a, b, labels, confidence):
Returns:
A Tensor with shape [batch]
"""
- num_splits = 32
+ num_splits = 16
vocab_size = tf.shape(b)[0]
labels = approximate_split(labels, num_splits)
a = approximate_split(a, num_splits)
@@ -1957,3 +1966,113 @@ def identity(*args):
id_out = identity(*(inputs + train_vars + outputs))
return id_out
+
+
+_function_cache = {}
+
+
+def conv_hidden_relu_memory_efficient(x,
+ filter_size,
+ epsilon=1e-6,
+ forget=True,
+ test_vars=None,
+ name=None):
+ """LayerNorm, Conv, ReLU, Conv.
+
+ All convolutions have kernel size 1.
+
+ returns conv(relu(conv(layer_norm(x))))
+
+ Args:
+ x: input Tensor with shape [batch, length, io_size]
+ filter_size: an integer - size of the hidden layer.
+ epsilon: a float (for layer norm)
+ forget: a boolean - forget forwards activations and recompute on backprop
+ test_vars: optional tuple of variables for testing purposes
+ name: an optional string
+
+ Returns:
+ a Tensor with shape [batch, length, io_size]
+ """
+ io_size = x.get_shape().as_list()[-1]
+
+ def forward_internal(x, f1, f2, scale, bias):
+ """Forward function."""
+    # Split batch-wise to avoid exhausting memory in case the batch is large
+ # and the hidden layer is large.
+ num_splits = 4
+ x_flat = tf.reshape(x, [-1, 1, tf.shape(x)[2]])
+ xs = approximate_split(x_flat, num_splits)
+ ys = []
+ for i in xrange(num_splits):
+ with tf.control_dependencies(ys[-1:]):
+ n = layer_norm_compute_python(xs[i], epsilon, scale, bias)
+ y = tf.nn.conv1d(n, f1, 1, "SAME")
+ y = tf.nn.relu(y)
+ y = tf.nn.conv1d(y, f2, 1, "SAME")
+ ys.append(y)
+ y = tf.concat(ys, 0)
+ y = tf.reshape(y, tf.shape(x))
+ return y
+ key = ("conv_hidden_relu_memory_efficient %s" % epsilon)
+ if not forget:
+ forward_fn = forward_internal
+ elif key in _function_cache:
+ forward_fn = _function_cache[key]
+ else:
+ @function.Defun(compiled=True)
+ def grad_fn(x, f1, f2, scale, bias, dy):
+ with tf.control_dependencies([dy]):
+ num_splits = 4
+ x_shape = tf.shape(x)
+ flat_shape = [-1, 1, x_shape[2]]
+ x = tf.reshape(x, flat_shape)
+ dy = tf.reshape(dy, flat_shape)
+ xs = approximate_split(x, num_splits)
+ dys = approximate_split(dy, num_splits)
+ dxs = []
+ df1 = 0
+ df2 = 0
+ dscale = 0
+ dbias = 0
+ deps = []
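+        # Recompute the forward pass for each split and differentiate it;
+        # chaining on deps keeps the splits sequential to bound peak memory.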
+ for i in xrange(num_splits):
+ with tf.control_dependencies(deps):
+ n = layer_norm_compute_python(xs[i], epsilon, scale, bias)
+ y = tf.nn.conv1d(n, f1, 1, "SAME")
+ y = tf.nn.relu(y)
+ y = tf.nn.conv1d(y, f2, 1, "SAME")
+ dxi, pdf1, pdf2, pdscale, pdbias = tf.gradients(
+ ys=[y], xs=[xs[i], f1, f2, scale, bias], grad_ys=[dys[i]])
+ df1 += pdf1
+ df2 += pdf2
+ dscale += pdscale
+ dbias += pdbias
+ dxs.append(dxi)
+ deps = [dxi, df1, df2, dscale, dbias]
+ with tf.control_dependencies(deps):
+ dx = tf.concat(dxs, 0)
+ dx = tf.reshape(dx, x_shape)
+ return dx, df1, df2, dscale, dbias
+
+ @function.Defun(grad_func=grad_fn, compiled=True,
+ separate_compiled_gradients=True)
+ def forward_fn(x, f1, f2, scale, bias):
+ return forward_internal(x, f1, f2, scale, bias)
+
+ with tf.variable_scope(name, default_name="ffn2", values=[x]):
+ # TODO(noam): it would be nice to save memory by casting x to float16
+ # here, but this causes problems with the gradients. Figure out if there
+ # is a way to leave the gradients as float32.
+ if test_vars is not None:
+ f1, f2, scale, bias = list(test_vars)
+ else:
+ f1 = tf.get_variable("f1", [1, io_size, filter_size])
+ f2 = tf.get_variable("f2", [1, filter_size, io_size])
+ scale, bias = layer_norm_vars(io_size)
+ if forget:
+ y = forward_fn(x, f1, f2, scale, bias)
+ else:
+ y = forward_internal(x, f1, f2, scale, bias)
+ y.set_shape(x.get_shape())
+ return y
diff --git a/tensor2tensor/layers/common_layers_test.py b/tensor2tensor/layers/common_layers_test.py
index 61023938f..d11f8ce2c 100644
--- a/tensor2tensor/layers/common_layers_test.py
+++ b/tensor2tensor/layers/common_layers_test.py
@@ -474,6 +474,43 @@ def testFactoredTensorImplicitConversion(self):
out = session.run(d)
self.assertEqual(out.shape, (3, 4, 6))
+ def testConvHiddenReluMemoryEfficient(self):
+ batch = 3
+ length = 23
+ io_size = 16
+ filter_size = 7
+ x = np.random.rand(batch, length, io_size)
+ dy = np.random.rand(batch, length, io_size)
+ with self.test_session() as session:
+ x = tf.to_float(x)
+ dy = tf.to_float(dy)
+ f1 = tf.get_variable("f1", [1, io_size, filter_size])
+ f2 = tf.get_variable("f2", [1, filter_size, io_size])
+ norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
+ y = common_layers.conv_hidden_relu_memory_efficient(
+ x, filter_size, forget=False,
+ test_vars=(f1, f2, norm_scale, norm_bias))
+ y_forget = common_layers.conv_hidden_relu_memory_efficient(
+ x, filter_size, forget=True,
+ test_vars=(f1, f2, norm_scale, norm_bias))
+ dx, df1, df2, dnorm_scale, dnorm_bias = tf.gradients(
+ ys=[y], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
+ dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
+ ys=[y_forget], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
+ session.run(tf.global_variables_initializer())
+ (y, y_forget,
+ dx, df1, df2, dnorm_scale, dnorm_bias,
+ dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f) = session.run(
+ [y, y_forget,
+ dx, df1, df2, dnorm_scale, dnorm_bias,
+ dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f])
+ self.assertAllClose(y, y_forget)
+ self.assertAllClose(df2, df2_f)
+ self.assertAllClose(df1, df1_f)
+ self.assertAllClose(dnorm_scale, dnorm_scale_f)
+ self.assertAllClose(dnorm_bias, dnorm_bias_f)
+ self.assertAllClose(dx, dx_f)
+
class FnWithCustomGradTest(tf.test.TestCase):
diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py
index 57652dbec..e03e6835e 100644
--- a/tensor2tensor/layers/modalities.py
+++ b/tensor2tensor/layers/modalities.py
@@ -361,9 +361,9 @@ def xnet_resblock(x, filters, res_relu, name):
"compress_block_final")
-@registry.register_class_label_modality("default")
+@registry.register_class_label_modality("2d")
class ClassLabelModality(modality.Modality):
- """Used for label data."""
+ """Used for label data; if is2d=True, uses Xception flow to logits."""
def __init__(self, model_hparams, vocab_size, is2d=True):
super(ClassLabelModality, self).__init__(model_hparams, vocab_size)
@@ -397,9 +397,11 @@ def targets_bottom(self, x):
def top(self, body_output, _):
"""Transform inputs from model space to target space.
- Perform the Xception "Exit flow", consisting of a single residual block and
- two separable convolutional upscalings followed by global spatial average
- pooling.
+ If instantiated with is2d=True, perform the Xception "Exit flow", consisting
+ of a single residual block and two separable convolutional upscalings
+ followed by global spatial average pooling.
+
+    Otherwise, apply a single linear layer to produce the logits.
Args:
body_output: A Tensor with shape [batch, ?, ?, body_output_size].
@@ -417,11 +419,12 @@ def top(self, body_output, _):
spatial_dim = tf.to_int32(spatial_dim_float)
x_depth = int(x.get_shape()[3])
x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth])
- x = common_layers.conv_block_downsample(x, self._kernel, self._strides,
- self._padding)
- x = tf.nn.relu(x)
- x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
- res = common_layers.conv(x, self._vocab_size, (1, 1))
+ x = common_layers.conv_block_downsample(x, self._kernel, self._strides,
+ self._padding)
+ x = tf.nn.relu(x)
+ x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
+
+ res = tf.layers.dense(x, self._vocab_size)
return tf.expand_dims(res, 3)
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
@@ -431,7 +434,7 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
top_out, targets, weights_fn=weights_fn)
-@registry.register_class_label_modality("class_label_2d")
+@registry.register_class_label_modality("default")
class ClassLabel1DModality(ClassLabelModality):
"""Used for label data."""
diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py
index 5bb63c303..3b72ea9c2 100644
--- a/tensor2tensor/models/attention_lm_moe.py
+++ b/tensor2tensor/models/attention_lm_moe.py
@@ -40,16 +40,18 @@
import tensorflow as tf
-class AttentionMoeType(object):
- NONE = "none"
- LOCAL = "local"
- GLOBAL = "global"
+class AttentionType(object):
+ MULTIHEAD = "multihead"
+ LOCAL_EXPERTS = "local_experts"
+ GLOBAL_MOE = "global_experts"
+ MEMORY_EFFICIENT = "memory_efficient"
@staticmethod
def get_choices():
return [
- AttentionMoeType.NONE,
- AttentionMoeType.LOCAL,
+ AttentionType.MULTIHEAD,
+ AttentionType.LOCAL_EXPERTS,
+ AttentionType.MEMORY_EFFICIENT,
]
@@ -70,7 +72,7 @@ def preprocess(x):
def postprocess(x, y):
return dp(common_layers.layer_postprocess, x, y, hparams)
- (decoder_input, decoder_self_attention_bias) = dp(
+ (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
attention_lm_moe_prepare_decoder, targets, hparams)
x = dp(tf.nn.dropout, decoder_input,
@@ -87,15 +89,15 @@ def _diet_expert(x):
else:
expert_fn = expert_utils.ffn_expert_fn(
hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
+
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
with tf.variable_scope(
- "attention_{}".format(hparams.attention_moe_type)):
- x = preprocess(x)
- if hparams.attention_moe_type == AttentionMoeType.NONE:
+ "attention_{}".format(hparams.attention_type)):
+ if hparams.attention_type == AttentionType.MULTIHEAD:
y = dp(
common_attention.multihead_attention,
- x,
+ preprocess(x),
None,
decoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
@@ -104,14 +106,23 @@ def _diet_expert(x):
hparams.num_heads,
hparams.attention_dropout,
name="decoder_self_attention")
- elif hparams.attention_moe_type == AttentionMoeType.LOCAL:
+ elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT:
+ assert hparams.layer_preprocess_sequence == "n"
+ y = dp(
+ common_attention.multihead_self_attention_memory_efficient,
+ x,
+ decoder_self_attention_bias,
+ hparams.num_heads,
+ name="decoder_self_attention")
+ elif hparams.attention_type == AttentionType.LOCAL_EXPERTS:
y, loss = dp(
common_attention.local_expert_attention,
- x,
+ preprocess(x),
k=2,
- loss_coef=1e-2,
+ loss_coef=hparams.attention_load_balance,
attention_num_experts=hparams.attention_num_experts,
train=hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
+ pad_remover=pad_remover,
mask_right=True,
attention_kq_size=hparams.attention_kq_size,
attention_v_size=hparams.attention_v_size)
@@ -119,7 +130,7 @@ def _diet_expert(x):
extra_loss += tf.add_n(loss) / dp.n
else:
raise ValueError("Only {} supported for now.".format(
- AttentionMoeType.get_choices()))
+ AttentionType.get_choices()))
x = postprocess(x, y)
with tf.variable_scope("ffn"):
if str(layer) in hparams.moe_layers.split(","):
@@ -134,6 +145,12 @@ def _diet_expert(x):
k=hparams.moe_k,
loss_coef=hparams.moe_loss_coef)
extra_loss += loss
+ elif hparams.memory_efficient_ffn:
+ assert hparams.layer_preprocess_sequence == "n"
+ y = dp(
+ common_layers.conv_hidden_relu_memory_efficient,
+ x,
+ hparams.filter_size)
else:
y = dp(
common_layers.conv_hidden_relu,
@@ -158,17 +175,22 @@ def attention_lm_moe_prepare_decoder(targets, hparams):
decoder_input: a Tensor, bottom of decoder stack
decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal alignments
+    pad_remover (expert_utils.PadRemover): a util object to remove padding
"""
+ targets_pad_mask = common_attention.embedding_to_padding(targets)
+ with tf.name_scope("pad_remover"):
+ pad_remover = expert_utils.PadRemover(targets_pad_mask)
+
if hparams.prepend_mode == "prepend_inputs_full_attention":
- decoder_self_attention_bias = (common_attention.attention_bias_prepended(
- common_attention.embedding_to_padding(targets)))
+ decoder_self_attention_bias = (
+ common_attention.attention_bias_prepended(targets_pad_mask))
else:
decoder_self_attention_bias = (
common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
decoder_input = common_layers.shift_left_3d(targets)
if hparams.pos == "timing":
decoder_input = common_attention.add_timing_signal_1d(decoder_input)
- return (decoder_input, decoder_self_attention_bias)
+ return (decoder_input, decoder_self_attention_bias, pad_remover)
@registry.register_hparams
@@ -215,12 +237,15 @@ def attention_lm_moe_base():
hparams.add_hparam("pos", "timing") # timing, none
hparams.add_hparam("moe_layers", "2") # comma separated list of layer numbers
# moe params. local attention moe.
- hparams.add_hparam("attention_moe_type", AttentionMoeType.NONE)
+ hparams.add_hparam("attention_type", AttentionType.MULTIHEAD)
hparams.add_hparam("attention_num_experts", 16)
# Key, query and value dimensions for the attention
- hparams.add_hparam("attention_kq_size", 64)
- hparams.add_hparam("attention_v_size", 64)
+ hparams.add_hparam("attention_kq_size", 128)
+ hparams.add_hparam("attention_v_size", 256)
+ # Loss coef for load balancing
+ hparams.add_hparam("attention_load_balance", 2e-2)
hparams.add_hparam("diet_experts", int(False))
+ hparams.add_hparam("memory_efficient_ffn", int(False))
return hparams
@@ -228,9 +253,12 @@ def attention_lm_moe_base():
def attention_lm_moe_base_ae():
"""Base model with attention expert."""
hparams = attention_lm_moe_base()
- hparams.attention_moe_type = AttentionMoeType.LOCAL
+ hparams.attention_type = AttentionType.LOCAL_EXPERTS
hparams.max_length = hparams.batch_size
hparams.eval_drop_long_sequences = int(True)
+ hparams.learning_rate = 0.05
+ hparams.learning_rate_warmup_steps = 10000
return hparams
@@ -279,7 +307,7 @@ def attention_lm_attention_moe_tiny():
hparams.moe_layers = ""
hparams.attention_num_experts = 128
hparams.filter_size = 8192
- hparams.attention_moe_type = AttentionMoeType.LOCAL
+ hparams.attention_type = AttentionType.LOCAL_EXPERTS
return hparams
@@ -335,6 +363,21 @@ def attention_lm_moe_large_diet():
return hparams
+@registry.register_hparams
+def attention_lm_moe_memory_efficient():
+ """Memory-efficient version."""
+ hparams = attention_lm_moe_large()
+ hparams.diet_experts = int(True)
+ hparams.layer_preprocess_sequence = "n"
+ hparams.layer_postprocess_sequence = "da"
+ hparams.layer_prepostprocess_dropout = 0.0
+  hparams.memory_efficient_ffn = int(True)
+ hparams.attention_type = AttentionType.MEMORY_EFFICIENT
+ hparams.num_heads = 8
+ hparams.factored_logits = int(True)
+ return hparams
+
+
@registry.register_hparams
def attention_lm_moe_32b_diet():
"""Unnecessarily large model with 32B params - because we can."""
diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index 47db28c30..105d9eb32 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -378,7 +378,6 @@ def transformer_big():
hparams.hidden_size = 1024
hparams.filter_size = 4096
hparams.num_heads = 16
- hparams.batching_mantissa_bits = 2
hparams.layer_prepostprocess_dropout = 0.3
return hparams
@@ -390,7 +389,6 @@ def transformer_big_single_gpu():
hparams.layer_prepostprocess_dropout = 0.1
hparams.learning_rate_warmup_steps = 16000
hparams.optimizer_adam_beta2 = 0.998
- hparams.batching_mantissa_bits = 3
return hparams
@@ -400,7 +398,6 @@ def transformer_base_single_gpu():
hparams = transformer_base()
hparams.batch_size = 2048
hparams.learning_rate_warmup_steps = 16000
- hparams.batching_mantissa_bits = 2
return hparams
@@ -593,7 +590,6 @@ def transformer_big_dr1():
hparams.filter_size = 4096
hparams.num_heads = 16
hparams.layer_prepostprocess_dropout = 0.1
- hparams.batching_mantissa_bits = 2
return hparams
diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py
index fa6b3f397..1c566e996 100644
--- a/tensor2tensor/models/transformer_vae.py
+++ b/tensor2tensor/models/transformer_vae.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""VAE Transformer."""
+"""AE Transformer."""
from __future__ import absolute_import
from __future__ import division
@@ -32,10 +32,9 @@
import tensorflow as tf
-def residual_conv(x, repeat, hparams, name, reuse=None):
+def residual_conv(x, repeat, k, hparams, name, reuse=None):
"""A stack of convolution blocks with residual connections."""
with tf.variable_scope(name, reuse=reuse):
- k = (3, 1)
dilations_and_kernels = [((1, 1), k) for _ in xrange(3)]
for i in xrange(repeat):
with tf.variable_scope("repeat_%d" % i):
@@ -72,15 +71,19 @@ def interleave(x, y, axis=1):
return tf.concat([x, y], axis=axis+1)
-def decompress_step(source, c, hparams, first_relu, name):
+def decompress_step(source, c, hparams, first_relu, is_2d, name):
"""Decompression function."""
with tf.variable_scope(name):
shape = tf.shape(source)
if c is not None:
source = attend(source, c, hparams, "decompress_attend")
+ multiplier = 4 if is_2d else 2
+ kernel = (1, 1) if is_2d else (1, 1)
thicker = common_layers.conv_block(
- source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
+ source, hparams.hidden_size * multiplier, [((1, 1), kernel)],
first_relu=first_relu, name="decompress_conv")
+ if is_2d:
+ return tf.depth_to_space(thicker, 2)
return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
@@ -90,7 +93,7 @@ def gumbel_sample(shape):
return -tf.log(-tf.log(uniform_samples))
-def dvae(x, hparams, name):
+def dae(x, hparams, name):
with tf.variable_scope(name):
m = tf.layers.dense(x, hparams.v_size, name="mask")
logsm = tf.nn.log_softmax(m)
@@ -128,7 +131,7 @@ def nearest(x, means, hparams):
_, nearest_idx = tf.nn.top_k(- dist, k=1)
nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size)
nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1],
- 1, hparams.v_size])
+ tf.shape(x)[2], hparams.v_size])
return tf.stop_gradient(nearest_hot)
@@ -137,21 +140,23 @@ def kmeans(x, means, hparams, name):
x_means_hot = nearest(x, means, hparams)
x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
kl = tf.reduce_sum(tf.square(x - x_means), axis=-1)
- return x_means_hot, x_means_hot, tf.reduce_mean(kl) * 100.0
+ return x_means_hot, tf.reduce_mean(kl) * 10.0
-def compress(x, c, hparams, name):
+def compress(x, c, is_2d, hparams, name):
"""Compress."""
with tf.variable_scope(name):
# Run compression by strided convs.
cur = x
+ k1 = (3, 3) if is_2d else (3, 1)
+ k2 = (2, 2) if is_2d else (2, 1)
for i in xrange(hparams.num_compress_steps):
if c is not None:
cur = attend(cur, c, hparams, "compress_attend_%d" % i)
- cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
+ cur = residual_conv(cur, 1, k1, hparams, "compress_rc_%d" % i)
cur = common_layers.conv_block(
- cur, hparams.hidden_size, [((1, 1), (2, 1))],
- strides=(2, 1), name="compress_%d" % i)
+ cur, hparams.hidden_size, [((1, 1), k2)],
+ strides=k2, name="compress_%d" % i)
return cur
@@ -188,7 +193,7 @@ def decode(cond_vec, cond_add, gold, c, ed, hparams):
decoder_input = tf.squeeze(decoder_input, axis=2)
decoder_input = common_attention.add_timing_signal_1d(decoder_input)
bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
- if c is not None:
+ if c is not None and len(c.get_shape()) > 3:
c = tf.squeeze(c, axis=2)
return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
@@ -205,62 +210,62 @@ def expand_batch(x, mul):
return tf.reshape(cx, res_shape)
-def vae_compress(x, c, ed, hparams, compress_name, decompress_name, reuse=None):
- """Compress, then VAE."""
- with tf.variable_scope(compress_name, reuse=reuse):
- cur = compress(x, None, hparams, "compress")
+def ae_compress(x, is_2d, hparams, name, reuse=None):
+ """Compress, then AE."""
+ with tf.variable_scope(name, reuse=reuse):
+ cur = compress(x, None, is_2d, hparams, "compress")
# Convolve and ReLu to get state.
cur = common_layers.conv_block(
cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
cur = tf.nn.l2_normalize(cur, dim=3)
+ cur_n = hparams.kmeans_lr_factor * cur
+ cur_n += (1.0 - hparams.kmeans_lr_factor) * tf.stop_gradient(cur)
means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
- # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
- # z_true, z_sample, kl_loss = dvae(cur, hparams, name="dvae")
- z_true, z_sample, kl_loss = kmeans(cur, means, hparams, name="kmeans")
-
- # Compress context.
- with tf.variable_scope(compress_name, reuse=reuse):
- compress_c = compress(c, None, hparams, "compress_context")
- dec_c = decode(None, compress_c, cur, None, None, hparams)
- c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
- reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
- labels=z_true, logits=c_z)
+ hot, loss = kmeans(cur_n, means, hparams, name="kmeans")
+ # We need a linear layer to undo the l2-normalization.
+ cur = tf.layers.dense(cur, hparams.hidden_size, name="unnormalize")
+ return cur, hot, loss
- # If not training, use the predicted z instead of the autoregressive one.
- if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
- z = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
- with tf.variable_scope(decompress_name, reuse=reuse):
- # Decompress.
- z_sample_flat = tf.reshape(z_sample, [-1, hparams.v_size])
- z = tf.matmul(z_sample_flat, means)
- z = tf.reshape(z, [tf.shape(z_sample)[0], tf.shape(z_sample)[1],
- 1, hparams.hidden_size])
+def ae_embed(hot, hparams, name, reuse=None):
+ with tf.variable_scope(name, reuse=reuse):
+ means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
+ hot_flat = tf.reshape(hot, [-1, hparams.v_size])
+ emb = tf.matmul(hot_flat, means)
+ emb = tf.reshape(emb, [tf.shape(hot)[0], tf.shape(hot)[1],
+ tf.shape(hot)[2], hparams.hidden_size])
+ return tf.layers.dense(emb, hparams.hidden_size,
+ name="unnormalize", reuse=reuse)
+
+def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
+ """Decompress from z, leaking from ae."""
+ with tf.variable_scope(name + "_decompress", reuse=reuse):
# Leak at the beginning to help train.
- z = mix(z, cur, hparams.startup_steps)
+ z = mix(z, ae, hparams.startup_steps)
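+    # With probability prob_z (annealed towards 0.8 over startup_steps during
+    # training, fixed at 1.0 otherwise) keep the discretized code z; otherwise
+    # fall back to the continuous autoencoder output ae.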
+ prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
+ prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0
+ z = tf.cond(tf.less(tf.random_uniform([]), prob_z),
+ lambda: z, lambda: ae)
# Dropout for better autoencoding.
- z = tf.nn.dropout(z, keep_prob=0.9)
+ z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)
# Decompress.
d = z
for i in xrange(hparams.num_compress_steps):
j = hparams.num_compress_steps - i - 1
- d = residual_conv(d, 1, hparams, "decompress_rc_%d" % j)
- d = decompress_step(d, c, hparams, i > 0, "decompress_step_%d" % j)
+ d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+ d = decompress_step(d, None, hparams, i > 0, is_2d, "decompress_%d" % j)
k = 2**hparams.num_compress_steps
z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
- # dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
- c = expand_batch(c, tf.shape(x_batch)[0] / tf.shape(x)[0])
- ed = expand_batch(ed, tf.shape(x_batch)[0] / tf.shape(x)[0])
- dec_batch = decode(z_batch, d_batch, x_batch, c, ed, hparams)
+ dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size])
- return z, kl_loss, reconstruct_loss
+ return z
def ffn(x, hparams, name):
@@ -270,35 +275,42 @@ def ffn(x, hparams, name):
return common_layers.layer_postprocess(x, y, hparams)
-def vae_transformer_internal(inputs, targets, target_space, hparams):
- """VAE Transformer, main step used for training."""
- with tf.variable_scope("vae_transformer"):
- # Prepare inputs, targets, and k.
- inputs = common_layers.flatten4d3d(inputs)
- input_len = tf.shape(inputs)[1] # Double input size to cover targets.
- inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
- inputs.set_shape([None, None, hparams.hidden_size])
- targets = common_layers.flatten4d3d(targets)
+def ae_transformer_internal(inputs, targets, target_space, hparams):
+ """AE Transformer, main step used for training."""
+ with tf.variable_scope("ae_transformer"):
+ # Prepare inputs, targets, k.
k = 2**hparams.num_compress_steps
- inputs, targets = common_layers.pad_to_same_length(
- inputs, targets, final_length_divisible_by=k)
- inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc")
-
- # Compress and vae.
- z, kl, r = vae_compress(tf.expand_dims(targets, axis=2),
- tf.expand_dims(inputs, axis=2),
- ed_bias, hparams, "vae_compress", "vae_decompress")
+ _, targets = common_layers.pad_to_same_length(
+ targets, targets, final_length_divisible_by=k)
+ inputs = common_layers.flatten4d3d(inputs)
+ inputs, ed = encode(inputs, target_space, hparams, "input_enc")
+
+ # Compress and ae.
+ ae, hot, kl = ae_compress(targets, False, hparams, "ae")
+ emb = ae_embed(hot, hparams, "ae", reuse=True)
+
+ # Compress context and run autoregressive decoder on emb-hot.
+ dec_c = decode(None, None, emb, inputs, ed, hparams)
+ c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
+ reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
+ labels=hot, logits=c_z)
+ # If not training, use the predicted z instead of the autoregressive one.
+ if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
+ hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
+
+ # Decompress, pass for ae loss.
+ z = ae_decompress(emb, ae, targets, False, hparams, "ae")
kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
- r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 2.0))
- losses = {"kl": kl, "reconstruction": r}
+ reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps)
+ losses = {"kl": kl, "reconstruction": reconstruct_loss}
return z, losses
@registry.register_model
-class TransformerVAE(t2t_model.T2TModel):
+class TransformerAE(t2t_model.T2TModel):
def model_fn_body(self, features):
- return vae_transformer_internal(
+ return ae_transformer_internal(
features["inputs"], features["targets"], features["target_space_id"],
self._hparams)
@@ -341,7 +353,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
@registry.register_hparams
-def transformer_vae_small():
+def transformer_ae_small():
"""Set of hyperparameters."""
hparams = transformer.transformer_small()
hparams.batch_size = 2048
@@ -351,13 +363,15 @@ def transformer_vae_small():
hparams.add_hparam("num_compress_steps", 4)
hparams.add_hparam("kl_warmup_steps", 60000)
hparams.add_hparam("startup_steps", 30000)
+ hparams.add_hparam("kmeans_lr_factor", 0.002)
+ hparams.add_hparam("z_dropout", 0.1)
return hparams
@registry.register_hparams
-def transformer_vae_base():
+def transformer_ae_base():
"""Set of hyperparameters."""
- hparams = transformer_vae_small()
+ hparams = transformer_ae_small()
hparams.hidden_size = 512
hparams.filter_size = 2048
hparams.attention_dropout = 0.0
diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py
index dbbd8e936..d55911f19 100644
--- a/tensor2tensor/utils/data_reader.py
+++ b/tensor2tensor/utils/data_reader.py
@@ -18,8 +18,6 @@
from __future__ import division
from __future__ import print_function
-import fractions
-import math
import os
import random
@@ -271,10 +269,11 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams,
dataset = bucket_by_sequence_length(dataset, _example_length,
batching_scheme["boundaries"],
- batching_scheme["batch_sizes"])
- max_batch_size = max(batching_scheme["batch_sizes"])
+ batching_scheme["batch_sizes"],
+ batching_scheme["window_size"])
# We reshuffle the batches to prevent many long-sequence batches at once.
- dataset = dataset.shuffle(max_batch_size * 3)
+ if batching_scheme["shuffle_queue_size"] is not None:
+ dataset = dataset.shuffle(batching_scheme["shuffle_queue_size"])
batched_examples = dataset.make_one_shot_iterator().get_next()
return batched_examples
@@ -308,38 +307,8 @@ def _example_too_big(example, max_length):
return tf.less_equal(_example_length(example), max_length)
-def _lcm(l):
- """Least common multiple of integers in a list."""
- if not l:
- raise ValueError("LCD of an empty list.")
- if len(l) == 1:
- return l[0]
- x = l[0]
- y = _lcm(l[1:])
- return x * y // fractions.gcd(x, y)
-
-
-def _closest_small_primes(x):
- """Closest number to x which has only 2, 3, 5 as prime factors, 3,5 once."""
- assert x > 0
- def is_small_primes(x, covered3, covered5):
- if x % 2 == 0:
- return is_small_primes(x // 2, covered3, covered5)
- if x % 3 == 0 and not covered3:
- return is_small_primes(x // 3, True, covered5)
- if x % 5 == 0 and not covered5:
- return is_small_primes(x // 5, covered3, True)
- return x == 1
- for i in xrange(x):
- if is_small_primes(x - i, False, False):
- return x - i
- # We search for higher numbers too, but only 8 of them to not increase much.
- if i < 9 and is_small_primes(x + i, False, False):
- return x + i
-
-
def bucket_by_sequence_length(dataset, example_length_fn, bucket_boundaries,
- bucket_batch_sizes):
+ bucket_batch_sizes, window_size):
"""Bucket entries in dataset by length.
Args:
@@ -348,18 +317,11 @@ def bucket_by_sequence_length(dataset, example_length_fn, bucket_boundaries,
the example, which will determine the bucket it goes into.
bucket_boundaries: list, boundaries of the buckets.
bucket_batch_sizes: list, batch size per bucket.
+ window_size: an integer divisible by all elements of bucket_batch_sizes
Returns:
Dataset of padded and batched examples.
"""
- # Since the Datasets API only allows a single constant for window_size,
- # and it needs divide all bucket_batch_sizes, we first make sure they only
- # have a few primes in them so that their LCM doesn't explode quickly.
- # TODO(lukaszkaiser): remove this adjustment when Dataset API improves.
- bucket_batch_sizes1 = [_closest_small_primes(b) for b in bucket_batch_sizes]
- tf.logging.info("Corrected bucket_batch_sizes from %s to %s."
- % (str(bucket_batch_sizes), str(bucket_batch_sizes1)))
- bucket_batch_sizes = bucket_batch_sizes1
with tf.name_scope("bucket_by_seq_length"):
def example_to_bucket_id(example):
@@ -386,25 +348,27 @@ def batching_fn(bucket_id, grouped_dataset):
for name, shape in grouped_dataset.output_shapes.items()])
return grouped_dataset.padded_batch(batch_size, padded_shapes)
- window_size = _lcm(bucket_batch_sizes)
dataset = dataset.group_by_window(example_to_bucket_id, batching_fn,
window_size)
return dataset
-def _bucket_boundaries(max_length, min_length=8, mantissa_bits=2):
+def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
"""A default set of length-bucket boundaries."""
+ assert min_length <= max_length
+ assert length_bucket_step > 1.0
x = min_length
boundaries = []
while x < max_length:
boundaries.append(x)
- x += 2**max(0, int(math.log(x, 2)) - mantissa_bits)
+ x = max(x + 1, int(x * length_bucket_step))
return boundaries
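+# For example, with min_length=8 and length_bucket_step=1.1 the boundaries grow
+# roughly geometrically: 8, 9, 10, ..., 20, 22, 24, 26, 28, 30, 33, 36, 39, ...
+# (see expected_boundaries in data_reader_test.py).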
-def _batching_scheme(batch_size=16 * 256,
- max_length=None,
- batching_mantissa_bits=1,
+def _batching_scheme(batch_size,
+ max_length,
+ min_length_bucket,
+ length_bucket_step,
drop_long_sequences=False,
shard_multiplier=1,
length_multiplier=1):
@@ -416,7 +380,8 @@ def _batching_scheme(batch_size=16 * 256,
batch_size: int, total number of tokens in a batch.
max_length: int, sequences longer than this will be skipped. Defaults to
batch_size.
- batching_mantissa_bits: int, ??.
+ min_length_bucket: int
+ length_bucket_step: float greater than 1.0
drop_long_sequences: bool, if True, then sequences longer than
`max_length` are dropped. This prevents generating batches with
more than the usual number of tokens, which can cause out-of-memory
@@ -434,19 +399,47 @@ def _batching_scheme(batch_size=16 * 256,
"""
max_length = max_length or batch_size
boundaries = _bucket_boundaries(
- max_length, mantissa_bits=batching_mantissa_bits)
+ max_length, min_length_bucket, length_bucket_step)
boundaries = [boundary * length_multiplier for boundary in boundaries]
max_length *= length_multiplier
-
batch_sizes = [
- max(1, batch_size // length) * shard_multiplier
- for length in boundaries + [max_length]
+ max(1, batch_size // length) for length in boundaries + [max_length]
]
- return {
+ max_batch_size = max(batch_sizes)
+ # Since the Datasets API only allows a single constant for window_size,
+  # and it needs to divide all bucket_batch_sizes, we pick a highly composite
+ # window size and then round down all batch sizes to divisors of that window
+ # size, so that a window can always be divided evenly into batches.
+ # TODO(noam): remove this when Dataset API improves.
+ highly_composite_numbers = [
+ 1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
+ 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
+ 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
+ 720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
+ 7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
+ 36756720, 43243200, 61261200, 73513440, 110270160]
+ window_size = max([
+ i for i in highly_composite_numbers if i <= 3 * max_batch_size])
+ divisors = [i for i in xrange(1, window_size + 1) if window_size % i == 0]
+ batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
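+  # For example, batch_size=128 with min_length_bucket=8 gives max_batch_size
+  # 16, hence window_size 48 (the largest highly composite number <= 3 * 16);
+  # each bucket's batch size is then rounded down to a divisor of 48,
+  # e.g. 128 // 9 = 14 -> 12.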
+ window_size *= shard_multiplier
+ batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
+ # The Datasets API splits one window into multiple batches, which
+ # produces runs of many consecutive batches of the same size. This
+ # is bad for training. To solve this, we will shuffle the batches
+ # using a queue which must be several times as large as the maximum
+ # number of batches per window.
+ max_batches_per_window = window_size // min(batch_sizes)
+ shuffle_queue_size = max_batches_per_window * 3
+ ret = {
"boundaries": boundaries,
"batch_sizes": batch_sizes,
- "max_length": (max_length if drop_long_sequences else 10**9)
+ "max_length": (max_length if drop_long_sequences else 10**9),
+ "shuffle_queue_size": shuffle_queue_size,
+ "window_size": window_size,
}
+ tf.logging.info("batching_scheme = %s" % ret)
+ return ret
def hparams_to_batching_scheme(hparams,
@@ -455,9 +448,10 @@ def hparams_to_batching_scheme(hparams,
length_multiplier=1):
"""Wrapper around _batching_scheme with hparams."""
return _batching_scheme(
- max_length=hparams.max_length,
batch_size=hparams.batch_size,
- batching_mantissa_bits=hparams.batching_mantissa_bits,
+ max_length=hparams.max_length,
+ min_length_bucket=hparams.min_length_bucket,
+ length_bucket_step=hparams.length_bucket_step,
drop_long_sequences=drop_long_sequences,
shard_multiplier=shard_multiplier,
length_multiplier=length_multiplier)
@@ -477,7 +471,9 @@ def constant_batching_scheme(constant_batch_size_in_sequences):
return {
"boundaries": boundaries,
"batch_sizes": batch_sizes,
- "max_length": 10**9
+ "max_length": 10**9,
+ "shuffle_queue_size": None,
+ "window_size": constant_batch_size_in_sequences,
}
diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py
index 318fb1cab..991669a99 100644
--- a/tensor2tensor/utils/data_reader_test.py
+++ b/tensor2tensor/utils/data_reader_test.py
@@ -169,36 +169,62 @@ def testLengthFilter(self):
def testBatchingSchemeMaxLength(self):
scheme = data_reader._batching_scheme(
- batch_size=20, max_length=None, drop_long_sequences=False)
+ batch_size=20, max_length=None,
+ min_length_bucket=8, length_bucket_step=1.1,
+ drop_long_sequences=False)
self.assertGreater(scheme["max_length"], 10000)
scheme = data_reader._batching_scheme(
- batch_size=20, max_length=None, drop_long_sequences=True)
+ batch_size=20, max_length=None,
+ min_length_bucket=8, length_bucket_step=1.1,
+ drop_long_sequences=True)
self.assertEqual(scheme["max_length"], 20)
scheme = data_reader._batching_scheme(
- batch_size=20, max_length=15, drop_long_sequences=True)
+ batch_size=20, max_length=15,
+ min_length_bucket=8, length_bucket_step=1.1,
+ drop_long_sequences=True)
self.assertEqual(scheme["max_length"], 15)
scheme = data_reader._batching_scheme(
- batch_size=20, max_length=15, drop_long_sequences=False)
+ batch_size=20, max_length=15,
+ min_length_bucket=8, length_bucket_step=1.1,
+ drop_long_sequences=False)
self.assertGreater(scheme["max_length"], 10000)
def testBatchingSchemeBuckets(self):
- scheme = data_reader._batching_scheme(batch_size=128)
+ scheme = data_reader._batching_scheme(
+ batch_size=128,
+ max_length=0,
+ min_length_bucket=8,
+ length_bucket_step=1.1)
boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"]
self.assertEqual(len(boundaries), len(batch_sizes) - 1)
- expected_boundaries = [8, 12, 16, 24, 32, 48, 64, 96]
+ expected_boundaries = [
+ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28,
+ 30, 33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124]
self.assertEqual(expected_boundaries, boundaries)
- expected_batch_sizes = [16, 10, 8, 5, 4, 2, 2, 1, 1]
+ expected_batch_sizes = [
+ 16, 12, 12, 8, 8, 8, 8, 8, 8, 6, 6, 6, 6, 4, 4, 4, 4, 4, 3, 3, 3,
+ 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]
self.assertEqual(expected_batch_sizes, batch_sizes)
- scheme = data_reader._batching_scheme(batch_size=128, shard_multiplier=2)
+ scheme = data_reader._batching_scheme(
+ batch_size=128,
+ max_length=0,
+ min_length_bucket=8,
+ length_bucket_step=1.1,
+ shard_multiplier=2)
boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"]
self.assertAllEqual([bs * 2 for bs in expected_batch_sizes], batch_sizes)
self.assertEqual(expected_boundaries, boundaries)
- scheme = data_reader._batching_scheme(batch_size=128, length_multiplier=2)
+ scheme = data_reader._batching_scheme(
+ batch_size=128,
+ max_length=0,
+ min_length_bucket=8,
+ length_bucket_step=1.1,
+ length_multiplier=2)
boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"]
self.assertAllEqual([b * 2 for b in expected_boundaries], boundaries)
self.assertEqual([max(1, bs // 2)
@@ -211,14 +237,16 @@ def example_len(ex):
boundaries = [10, 20, 30]
batch_sizes = [10, 8, 4, 2]
+ window_size = 40
dataset = data_reader.read_examples(
self.problem,
self.filepatterns[0],
32,
mode=tf.contrib.learn.ModeKeys.EVAL)
- dataset = data_reader.bucket_by_sequence_length(dataset, example_len,
- boundaries, batch_sizes)
+ dataset = data_reader.bucket_by_sequence_length(
+ dataset, example_len,
+ boundaries, batch_sizes, window_size)
batch = dataset.make_one_shot_iterator().get_next()
input_vals = []
diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py
index 2e430a204..3f00c25a9 100644
--- a/tensor2tensor/utils/decoding.py
+++ b/tensor2tensor/utils/decoding.py
@@ -37,85 +37,114 @@
FLAGS = tf.flags.FLAGS
-def decode_from_dataset(estimator):
+def _decode_from_dataset_log_results(inputs,
+ targets,
+ outputs,
+ problem_name,
+ prediction_idx,
+ inputs_vocab,
+ targets_vocab,
+ save_images=False,
+ model_dir=None,
+ identity_output=False):
+ """Log inference results."""
+ if "image" in problem_name and save_images:
+ save_path = os.path.join(model_dir, "%s_prediction_%d.jpg" %
+ (problem_name, prediction_idx))
+ show_and_save_image(inputs / 255., save_path)
+ elif inputs_vocab:
+ decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
+ tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
+
+ if identity_output:
+ decoded_outputs = "".join(map(str, outputs.flatten()))
+ decoded_targets = "".join(map(str, targets.flatten()))
+ else:
+ decoded_outputs = "".join(
+ map(str, targets_vocab.decode(_save_until_eos(outputs.flatten()))))
+ decoded_targets = "".join(
+ map(str, targets_vocab.decode(_save_until_eos(targets.flatten()))))
+
+ tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
+ tf.logging.info("Inference results TARGET: %s" % decoded_targets)
+ return decoded_outputs, decoded_targets
+
+
+def decode_from_dataset(estimator,
+ problem_names,
+ return_beams=False,
+ beam_size=1,
+ max_predictions=-1,
+ decode_to_file=None,
+ save_images=False,
+ identity_output=False):
+ tf.logging.info("Performing local inference from dataset for %s.",
+ str(problem_names))
hparams = estimator.hparams
- for i, problem in enumerate(FLAGS.problems.split("-")):
- inputs_vocab = hparams.problems[i].vocabulary.get("inputs", None)
- targets_vocab = hparams.problems[i].vocabulary["targets"]
- tf.logging.info("Performing local inference.")
+
+ for problem_idx, problem_name in enumerate(problem_names):
+ # Build the inference input function
infer_problems_data = data_reader.get_data_filepatterns(
- FLAGS.problems, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)
+ problem_name, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)
infer_input_fn = input_fn_builder.build_input_fn(
mode=tf.contrib.learn.ModeKeys.INFER,
hparams=hparams,
data_file_patterns=infer_problems_data,
num_datashards=devices.data_parallelism().n,
- fixed_problem=i)
-
- def log_fn(inputs,
- targets,
- outputs,
- problem,
- j,
- inputs_vocab=inputs_vocab,
- targets_vocab=targets_vocab):
- """Log inference results."""
- if "image" in problem and FLAGS.decode_save_images:
- save_path = os.path.join(estimator.model_dir,
- "%s_prediction_%d.jpg" % (problem, j))
- show_and_save_image(inputs / 255., save_path)
- elif inputs_vocab:
- decoded_inputs = inputs_vocab.decode(
- _save_until_eos(inputs.flatten()))
- tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
+ fixed_problem=problem_idx)
- if FLAGS.identity_output:
- decoded_outputs = " ".join(map(str, outputs.flatten()))
- decoded_targets = " ".join(map(str, targets.flatten()))
- else:
- decoded_outputs = " ".join(map(
- str, targets_vocab.decode(_save_until_eos(outputs.flatten()))))
- decoded_targets = " ".join(map(
- str, targets_vocab.decode(_save_until_eos(targets.flatten()))))
-
- tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
- tf.logging.info("Inference results TARGET: %s" % decoded_targets)
- return decoded_outputs, decoded_targets
-
- result_iter = estimator.predict(input_fn=infer_input_fn, as_iterable=True)
- count = 0
- agg_outputs = []
- agg_targets = []
- for result in result_iter:
- # predictions from the test input. We use it to log inputs and decodes.
- inputs = result["inputs"]
- targets = result["targets"]
- outputs = result["outputs"]
- if FLAGS.decode_return_beams:
- output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
- for k, beam in enumerate(output_beams):
- tf.logging.info("BEAM %d:" % k)
- o, t = log_fn(inputs, targets, beam, problem, count)
- agg_outputs.append(o)
- agg_targets.append(t)
- else:
- o, t = log_fn(inputs, targets, outputs, problem, count)
- agg_outputs.append(o)
- agg_targets.append(t)
+ # Get the predictions as an iterable
+ predictions = estimator.predict(input_fn=infer_input_fn, as_iterable=True)
+
+ # Prepare output file writers if decode_to_file passed
+ if decode_to_file:
+ output_filepath = decode_to_file + ".outputs." + problem_name
+ target_filepath = decode_to_file + ".targets." + problem_name
- count += 1
- if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples:
- break
- if FLAGS.decode_to_file:
- output_filepath = FLAGS.decode_to_file + ".outputs." + problem
output_file = tf.gfile.Open(output_filepath, "w")
- target_filepath = FLAGS.decode_to_file + ".targets." + problem
target_file = tf.gfile.Open(target_filepath, "w")
- for o, t in zip(agg_outputs, agg_targets):
- output_file.write(str(o)+"\n")
- target_file.write(str(t)+"\n")
- tf.logging.info("Completed inference on %d samples." % count)
+
+ problem_hparams = hparams.problems[problem_idx]
+ inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
+ targets_vocab = problem_hparams.vocabulary["targets"]
+ for num_predictions, prediction in enumerate(predictions):
+ inputs = prediction["inputs"]
+ targets = prediction["targets"]
+ outputs = prediction["outputs"]
+
+ # Log predictions
+ decoded_outputs = []
+ if return_beams:
+ output_beams = np.split(outputs, beam_size, axis=0)
+ for i, beam in enumerate(output_beams):
+ tf.logging.info("BEAM %d:" % i)
+ decoded = _decode_from_dataset_log_results(
+ inputs, targets, beam, problem_name, num_predictions,
+ inputs_vocab, targets_vocab, save_images, estimator.model_dir,
+ identity_output)
+ decoded_outputs.append(decoded)
+ else:
+ decoded = _decode_from_dataset_log_results(
+ inputs, targets, outputs, problem_name, num_predictions,
+ inputs_vocab, targets_vocab, save_images, estimator.model_dir,
+ identity_output)
+ decoded_outputs.append(decoded)
+
+ # Write out predictions if decode_to_file passed
+ if decode_to_file:
+ for decoded_output, decoded_target in decoded_outputs:
+ output_file.write(str(decoded_output) + "\n")
+ target_file.write(str(decoded_target) + "\n")
+
+ if max_predictions >= 0 and num_predictions >= max_predictions:
+ break
+
+ if decode_to_file:
+ output_file.close()
+ target_file.close()
+
+ tf.logging.info("Completed inference on %d samples." % num_predictions) # pylint: disable=undefined-loop-variable
def decode_from_file(estimator, filename):
diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py
index 6f26f20fa..fb1d1fac0 100644
--- a/tensor2tensor/utils/expert_utils.py
+++ b/tensor2tensor/utils/expert_utils.py
@@ -436,6 +436,87 @@ def noisy_top_k_gating(x,
return gates, load
+class PadRemover(object):
+ """Helper to remove padding from a tensor before sending to the experts.
+
+ The padding is computed for one reference tensor containing the padding mask
+ and then can be applied to any other tensor of shape [dim_origin,...].
+
+ Ex:
+ input = [
+ [tok1, tok2],
+ [tok3, tok4],
+ [0, 0],
+ [0, 0],
+ [tok5, tok6],
+ [0, 0],
+ ]
+ output = [
+ [tok1, tok2],
+ [tok3, tok4],
+ [tok5, tok6],
+ ]
+ """
+
+ def __init__(self, pad_mask):
+ """Compute and store the location of the padding.
+
+ Args:
+      pad_mask (tf.Tensor): Reference padding tensor of shape
+        [batch_size, length] or [dim_origin] (dim_origin=batch_size*length)
+        containing non-zero positive values to indicate padding locations.
+ """
+ self.nonpad_ids = None
+ self.dim_origin = None
+
+ with tf.name_scope("pad_reduce/get_ids"):
+ pad_mask = tf.reshape(pad_mask, [-1]) # Flatten the batch
+      # nonpad_ids contains the coordinates of the zero rows. As pad_mask is
+      # float32, zero equality is checked with |x| < epsilon (epsilon=1e-9 is
+      # the standard choice); pad_mask only contains non-negative values, so
+      # tf.abs would be redundant here.
+ self.nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9))
+ self.dim_origin = tf.shape(pad_mask)[:1]
+
+ def remove(self, x):
+ """Remove padding from the given tensor.
+
+ Args:
+ x (tf.Tensor): of shape [dim_origin,...]
+
+ Returns:
+ a tensor of shape [dim_compressed,...] with dim_compressed <= dim_origin
+ """
+ with tf.name_scope("pad_reduce/remove"):
+ x_shape = x.get_shape().as_list()
+ x = tf.gather_nd(
+ x,
+ indices=self.nonpad_ids,
+ )
+      # This is a hack: for some reason, gather_nd returns a tensor of
+      # undefined shape, so the shape is set manually here.
+ x.set_shape([None] + x_shape[1:])
+ return x
+
+ def restore(self, x):
+ """Add padding back to the given tensor.
+
+ Args:
+ x (tf.Tensor): of shape [dim_compressed,...]
+
+ Returns:
+      a tensor of shape [dim_origin,...] with dim_origin >= dim_compressed. The
+      dim is restored from the original reference tensor.
+ """
+ with tf.name_scope("pad_reduce/restore"):
+ x = tf.scatter_nd(
+ indices=self.nonpad_ids,
+ updates=x,
+ shape=tf.concat([self.dim_origin, tf.shape(x)[1:]], axis=0),
+ )
+ return x
+
+
class SparseDispatcher(object):
"""Helper for implementing a mixture of experts.
@@ -766,6 +847,7 @@ def local_moe(x,
pass_x=True,
pass_gates=False,
additional_dispatch_params=None,
+ pad_remover=None,
name=None):
"""Call a local mixture of experts.
@@ -782,6 +864,8 @@ def local_moe(x,
additional_dispatch_params: The extra tensors that need to be sent to each
       expert. Examples include batch coordinates (see
common_attention.local_expert_attention)
+ pad_remover (PadRemover): If given, the padding is removed/restored before
+ sending to the experts
name: a string
Returns:
@@ -791,8 +875,18 @@ def local_moe(x,
training loss of the model. The backpropagation of this loss
encourages all experts to be approximately equally used across a batch.
"""
+
with tf.variable_scope(name, default_name="local_moe"):
x_flat = flatten_all_but_last(x)
+
+ # Remove the padding tokens
+ if pad_remover:
+ x_flat = pad_remover.remove(x_flat)
+ tf.summary.scalar( # Should match the targets_nonpadding_tokens
+ "nonpadding_tokens",
+ tf.shape(x_flat)[0],
+ family="experts_stats")
+
# The gates indicate which batch elements go to which tensors.
# load is a measure of approximately how many examples go to each expert
gates, load = noisy_top_k_gating(
@@ -805,17 +899,27 @@ def local_moe(x,
noise_epsilon=1e-2)
# This magic object helps us shuffle data between datashards and experts.
dispatcher = SparseDispatcher(num_experts, gates)
+
+ # Set up expert_fn arguments
expert_kwargs = {}
if pass_x:
expert_kwargs["x"] = dispatcher.dispatch(x_flat)
if pass_gates:
expert_kwargs["gates"] = dispatcher.expert_to_gates()
for k, v in six.iteritems(additional_dispatch_params or {}):
- expert_kwargs[k] = dispatcher.dispatch(flatten_all_but_last(v))
+ v = flatten_all_but_last(v)
+ if pad_remover:
+ v = pad_remover.remove(v)
+ expert_kwargs[k] = dispatcher.dispatch(v)
+
ep = Parallelism([DEFAULT_DEV_STRING] * num_experts)
expert_outputs = ep(expert_fn, **expert_kwargs)
+
y_flat = dispatcher.combine(expert_outputs)
+ if pad_remover:
+ y_flat = pad_remover.restore(y_flat)
y = reshape_like(y_flat, x)
+
importance = tf.reduce_sum(gates, 0)
loss = loss_coef * (cv_squared(importance) + cv_squared(load))
return y, loss
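
Not part of the patch, but a minimal sketch of how the new `PadRemover` composes with an expert computation; the toy tensor and the stand-in expert below are purely illustrative, and the same object can be handed to `local_moe` through its new `pad_remover` argument so that padding rows never reach the experts.

```
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.utils import expert_utils

# A flat batch of embeddings, shape [batch * length, depth];
# all-zero rows are padding positions.
x = tf.constant([[1., 1.], [2., 2.], [0., 0.], [3., 3.], [0., 0.]])

# embedding_to_padding returns 1.0 where the embedding is all zeros.
pad_mask = common_attention.embedding_to_padding(x)
pad_remover = expert_utils.PadRemover(pad_mask)

x_nopad = pad_remover.remove(x)    # shape [3, 2]: padding rows dropped
y_nopad = x_nopad * 10.0           # stand-in for the expert computation
y = pad_remover.restore(y_nopad)   # shape [5, 2]: zeros back at pad rows

with tf.Session() as sess:
  print(sess.run(y))
```
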
diff --git a/tensor2tensor/utils/expert_utils_test.py b/tensor2tensor/utils/expert_utils_test.py
new file mode 100644
index 000000000..93af9c78c
--- /dev/null
+++ b/tensor2tensor/utils/expert_utils_test.py
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tensor2tensor.utils.expert_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+from tensor2tensor.layers import common_attention
+from tensor2tensor.utils import expert_utils
+import tensorflow as tf
+
+
+class ExpertUtilsTest(tf.test.TestCase):
+
+ def _verify_value(self, sess, tensor, expected):
+ output = sess.run(tensor)
+ self.assertAllClose(output, expected, 1e-9)
+
+ def testPadRemover(self):
+ """Check that the padding remover is working correctly."""
+ x_1 = tf.constant([
+ [1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9],
+ [0, 0, 0], # pad
+ [0, 0, 0], # pad
+ [0, 0, 0], # pad
+ [10, 11, 12],
+ [13, 14, 15],
+ [0, 0, 0], # pad
+ ], dtype=tf.float32)
+ # Get padding mask
+ x_pad_mask = common_attention.embedding_to_padding(x_1)
+ x_2 = tf.constant([
+ [1],
+ [2],
+ [3],
+ [4], # pad
+ [5], # pad
+ [6], # pad
+ [7],
+ [8],
+ [9], # pad
+ ], dtype=tf.float32)
+ x_3 = tf.constant([
+ 1,
+ 2,
+ 3,
+ 4, # pad
+ 5, # pad
+ 6, # pad
+ 7,
+ 8,
+ 9, # pad
+ ], dtype=tf.float32)
+
+ pad_remover = expert_utils.PadRemover(x_pad_mask)
+
+ y_1 = pad_remover.remove(x_1)
+ y_2 = pad_remover.remove(x_2)
+ y_3 = pad_remover.remove(x_3)
+
+ z_1 = pad_remover.restore(y_1 * 2)
+ z_2 = pad_remover.restore(y_2 * 2)
+ z_3 = pad_remover.restore(y_3 * 2)
+
+ with self.test_session() as sess:
+ # Padding should have been removed
+ self._verify_value(sess, y_1, [
+ [1., 2., 3.],
+ [4., 5., 6.],
+ [7., 8., 9.],
+ [10., 11., 12.],
+ [13., 14., 15.],
+ ])
+ self._verify_value(sess, y_2, [
+ [1.],
+ [2.],
+ [3.],
+ [7.],
+ [8.],
+ ])
+ self._verify_value(sess, y_3, [
+ 1.,
+ 2.,
+ 3.,
+ 7.,
+ 8.,
+ ])
+
+ # Padding should have been restored
+ self._verify_value(sess, z_1, [
+ [2., 4., 6.],
+ [8., 10., 12.],
+ [14., 16, 18.],
+ [0., 0., 0.],
+ [0., 0., 0.],
+ [0., 0., 0.],
+ [20., 22., 24.],
+ [26., 28., 30.],
+ [0., 0., 0.],
+ ])
+ self._verify_value(sess, z_2, [
+ [2.],
+ [4.],
+ [6.],
+ [0.], # pad
+ [0.], # pad
+ [0.], # pad
+ [14.],
+ [16.],
+ [0.], # pad
+ ])
+ self._verify_value(sess, z_3, [
+ 2.,
+ 4.,
+ 6.,
+ 0., # pad
+ 0., # pad
+ 0., # pad
+ 14.,
+ 16.,
+ 0., # pad
+ ])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py
index c31ba0f31..bef95d58f 100644
--- a/tensor2tensor/utils/input_fn_builder.py
+++ b/tensor2tensor/utils/input_fn_builder.py
@@ -183,6 +183,12 @@ def input_fn():
if mode == tf.contrib.learn.ModeKeys.INFER:
rand_feature_map["infer_targets"] = rand_target
rand_target = None
+ # This is because of a bug in the tf.contrib.learn Estimator that
+ # short-circuits prediction if it doesn't see a QueueRunner.
+ # DummyQueueRunner implements the minimal expected interface but does
+ # nothing.
+ # TODO(rsepassi): Remove once we move to core Estimator.
+ tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, DummyQueueRunner())
return rand_feature_map, rand_target
return input_fn
@@ -195,3 +201,14 @@ def cond_on_index(fn, index_tensor, cur_idx, max_idx):
return tf.cond(
tf.equal(index_tensor, cur_idx), lambda: fn(cur_idx),
lambda: cond_on_index(fn, index_tensor, cur_idx + 1, max_idx))
+
+
+class DummyQueueRunner(object):
+ """Can stand-in for a QueueRunner but does nothing."""
+
+ def __init__(self):
+ pass
+
+ def create_threads(self, sess, coord=None, daemon=False, start=False):
+ del sess, coord, daemon, start
+ return []
diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index e5cb88ddf..baff66669 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -42,6 +42,7 @@ class Metrics(object):
R2 = "r_squared"
ROUGE_2_F = "rouge_2_fscore"
ROUGE_L_F = "rouge_L_fscore"
+ EDIT_DISTANCE = "edit_distance"
def padded_rmse(predictions, labels, weights_fn=common_layers.weights_all):
@@ -122,6 +123,50 @@ def padded_sequence_accuracy(predictions,
return correct_seq, tf.constant(1.0)
+def sequence_edit_distance(predictions,
+ labels,
+ weights_fn=common_layers.weights_nonzero):
+ """Average edit distance, ignoring padding 0s.
+
+  The score returned is the edit distance divided by the total length of the
+  reference truth, and the weight returned is the total length of the truth.
+
+ Args:
+ predictions: Tensor of shape [`batch_size`, `length`, 1, `num_classes`] and
+ type tf.float32 representing the logits, 0-padded.
+ labels: Tensor of shape [`batch_size`, `length`, 1, 1] and type tf.int32
+ representing the labels of same length as logits and 0-padded.
+ weights_fn: ignored. The weights returned are the total length of the ground
+      truth labels, excluding 0-padding.
+
+ Returns:
+ (edit distance / reference length, reference length)
+
+ Raises:
+ ValueError: if weights_fn is not common_layers.weights_nonzero.
+ """
+ if weights_fn is not common_layers.weights_nonzero:
+ raise ValueError("Only weights_nonzero can be used for this metric.")
+
+ with tf.variable_scope("edit_distance", values=[predictions, labels]):
+ # Transform logits into sequence classes by taking max at every step.
+ predictions = tf.to_int32(
+ tf.squeeze(tf.argmax(predictions, axis=-1), axis=(2, 3)))
+ nonzero_idx = tf.where(tf.not_equal(predictions, 0))
+ sparse_outputs = tf.SparseTensor(nonzero_idx,
+ tf.gather_nd(predictions, nonzero_idx),
+ tf.shape(predictions, out_type=tf.int64))
+ labels = tf.squeeze(labels, axis=(2, 3))
+ nonzero_idx = tf.where(tf.not_equal(labels, 0))
+ label_sparse_outputs = tf.SparseTensor(nonzero_idx,
+ tf.gather_nd(labels, nonzero_idx),
+ tf.shape(labels, out_type=tf.int64))
+ distance = tf.reduce_sum(
+ tf.edit_distance(sparse_outputs, label_sparse_outputs, normalize=False))
+ reference_length = tf.to_float(tf.shape(nonzero_idx)[0])
+ return distance / reference_length, reference_length
+
+
def padded_neg_log_perplexity(predictions,
labels,
weights_fn=common_layers.weights_nonzero):
@@ -234,4 +279,5 @@ def problem_metric_fn(predictions, labels, weights):
Metrics.R2: padded_variance_explained,
Metrics.ROUGE_2_F: rouge.rouge_2_fscore,
Metrics.ROUGE_L_F: rouge.rouge_l_fscore,
+ Metrics.EDIT_DISTANCE: sequence_edit_distance,
}
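
For orientation, a plain-Python sketch of the quantity `sequence_edit_distance` computes: the summed per-sequence Levenshtein distance divided by the total non-padding reference length. The helper names are illustrative, and the sample values mirror the unit test added below.

```
from __future__ import division


def levenshtein(a, b):
  """Classic dynamic-programming edit distance between two sequences."""
  dp = list(range(len(b) + 1))
  for i, ai in enumerate(a, 1):
    prev, dp[0] = dp[0], i
    for j, bj in enumerate(b, 1):
      prev, dp[j] = dp[j], min(dp[j] + 1,          # deletion
                               dp[j - 1] + 1,      # insertion
                               prev + (ai != bj))  # substitution
  return dp[-1]


def sequence_edit_distance_score(predictions, labels):
  """Returns (total edit distance / reference length, reference length)."""
  def strip_pad(seq):
    return [t for t in seq if t != 0]  # drop 0-padding
  distance = sum(levenshtein(strip_pad(p), strip_pad(l))
                 for p, l in zip(predictions, labels))
  ref_length = sum(len(strip_pad(l)) for l in labels)
  return distance / ref_length, ref_length


preds = [[3, 4, 5, 1, 0, 0], [2, 1, 3, 4, 0, 0], [2, 1, 3, 4, 0, 0]]
tgts = [[5, 4, 5, 1, 0, 0], [2, 5, 3, 4, 1, 0], [2, 1, 3, 4, 0, 0]]
print(sequence_edit_distance_score(preds, tgts))  # -> (3 / 13, 13)
```
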
diff --git a/tensor2tensor/utils/metrics_test.py b/tensor2tensor/utils/metrics_test.py
index 0d78e632c..528fd4755 100644
--- a/tensor2tensor/utils/metrics_test.py
+++ b/tensor2tensor/utils/metrics_test.py
@@ -72,6 +72,29 @@ def testSequenceAccuracyMetric(self):
actual = session.run(a)
self.assertEqual(actual, expected)
+ def testSequenceEditDistanceMetric(self):
+ predictions = np.array([[3, 4, 5, 1, 0, 0],
+ [2, 1, 3, 4, 0, 0],
+ [2, 1, 3, 4, 0, 0]])
+    # Targets are just a bit different:
+    #  - first sequence has one different token (edit distance 1)
+    #  - second sequence has one different token and one extra step
+    #    (edit distance 2)
+    #  - third sequence is identical (edit distance 0)
+    # Total edit distance is 3; total non-padding reference length is
+    # 4 + 5 + 4 = 13.
+ targets = np.array([[5, 4, 5, 1, 0, 0],
+ [2, 5, 3, 4, 1, 0],
+ [2, 1, 3, 4, 0, 0]])
+ # Reshape to match expected input format by metric fns.
+ predictions = np.reshape(predictions, [3, 6, 1, 1])
+ targets = np.reshape(targets, [3, 6, 1, 1])
+ with self.test_session() as session:
+ scores, weight = metrics.sequence_edit_distance(
+ tf.one_hot(predictions, depth=6, dtype=tf.float32),
+ tf.constant(targets, dtype=tf.int32))
+ session.run(tf.global_variables_initializer())
+ actual_scores, actual_weight = session.run([scores, weight])
+ self.assertAlmostEqual(actual_scores, 3.0 / 13)
+ self.assertEqual(actual_weight, 13)
+
def testNegativeLogPerplexity(self):
predictions = np.random.randint(4, size=(12, 12, 12, 1))
targets = np.random.randint(4, size=(12, 12, 12, 1))
diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py
index 24c17ca9e..34af6c827 100644
--- a/tensor2tensor/utils/model_builder.py
+++ b/tensor2tensor/utils/model_builder.py
@@ -104,6 +104,14 @@ def learning_rate_decay():
elif hparams.learning_rate_decay_scheme == "cosine":
cycle_steps = hparams.learning_rate_cosine_cycle_steps
return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps))
+ elif hparams.learning_rate_decay_scheme == "cyclelinear10x":
+ # Cycle the rate linearly by 10x every warmup_steps, up and down.
+ cycle_steps = hparams.learning_rate_warmup_steps
+ cycle_position = step % (2 * cycle_steps)
+ cycle_position = tf.to_float( # Normalize to the interval [-1, 1].
+ cycle_position - cycle_steps) / float(cycle_steps)
+ cycle_position = 1.0 - tf.abs(cycle_position) # 0 to 1 and back to 0.
+ return (cycle_position + 0.1) * 3.0 # 10x difference each cycle (0.3-3).
inv_base = tf.exp(tf.log(0.01) / warmup_steps)
inv_decay = inv_base**(warmup_steps - step)
@@ -156,12 +164,6 @@ def model_fn(features, targets, mode):
features = _interactive_input_tensor_to_features_dict(features, my_hp)
elif FLAGS.decode_from_file:
features = _decode_input_tensor_to_features_dict(features, my_hp)
- # A dictionary containing:
- # - problem_choice: A Tensor containing an integer indicating which problem
- # was selected for this run.
- # - predictions: A Tensor containing the model's output predictions.
- run_info = dict()
- run_info["problem_choice"] = features["problem_choice"]
if targets is not None:
features["targets"] = targets
@@ -185,17 +187,20 @@ def model_fn(features, targets, mode):
tf.summary.scalar("%s_nonpadding_fraction" % k,
tf.reduce_mean(nonpadding))
- # The new data reader occasionally emits very small batches, which
- # cause the examples in those batches to be grossly overweighted.
- # We decrease the loss proportionally to the ratio of the size of this
- # batch to the size of the largest training batch ever.
- # TODO(noam): to be more sophisticated, we could keep separate
- # maxima based on problem choice.
- max_nonpadding_var = tf.get_variable(
- "max_nonpadding", shape=[],
- initializer=tf.ones_initializer(), trainable=False)
- max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
if is_training:
+ # The new data reader occasionally emits very small batches, which
+ # cause the examples in those batches to be grossly overweighted.
+ # We decrease the loss proportionally to the ratio of the size of this
+ # batch to the size of the largest training batch ever.
+ # TODO(noam): to be more sophisticated, we could keep separate
+ # maxima based on problem choice.
+ max_nonpadding_var = tf.get_variable(
+ "max_nonpadding",
+ shape=[],
+ initializer=tf.ones_initializer(),
+ trainable=False)
+ max_nonpadding = tf.maximum(max_nonpadding_var,
+ targets_nonpadding_tokens)
with tf.control_dependencies(
[tf.assign(max_nonpadding_var, max_nonpadding)]):
small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
@@ -203,6 +208,7 @@ def model_fn(features, targets, mode):
# Get multi-problem logits and loss based on features["problem_choice"].
loss_variable_names = []
+
def nth_model(n):
"""Build the model for the n-th problem, plus some added variables."""
model_class = registry.model(model)(
@@ -249,8 +255,8 @@ def nth_model(n):
# Total loss was already constructed on input.
loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
except ValueError:
- loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n,
- initializer=100.0, trainable=False)
+ loss_moving_avg = tf.get_variable(
+ "problem_%d/total_loss" % n, initializer=100.0, trainable=False)
ops.append(
loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1))
with tf.variable_scope("train_stats"): # Count steps for this problem.
@@ -287,11 +293,13 @@ def nth_model(n):
sharded_logits, total_loss = result_list[1:], result_list[0]
if mode == tf.contrib.learn.ModeKeys.EVAL:
- logits = tf.concat(sharded_logits, 0)
# For evaluation, return the logits layer as our predictions.
- run_info["predictions"] = logits
- train_op = None
- return run_info, total_loss, None
+ logits = tf.concat(sharded_logits, 0)
+ ret = {
+ "predictions": logits,
+ "problem_choice": features["problem_choice"],
+ }
+ return ret, total_loss, None
assert mode == tf.contrib.learn.ModeKeys.TRAIN
@@ -373,7 +381,7 @@ def nth_model(n):
del summaries[i]
tf.logging.info("Global model_fn finished.")
- return run_info, total_loss, train_op
+ return {"problem_choice": features["problem_choice"]}, total_loss, train_op
return model_fn
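
As a side note on the cyclelinear10x scheme added above, the multiplier is a triangle wave in the global step; a plain-Python sketch (warmup_steps and the sampled steps are illustrative) shows it rising from roughly 0.3 at the cycle edges to roughly 3.3 at the midpoint and back.

```
def cyclelinear10x_multiplier(step, warmup_steps):
  """Mirrors the graph code above as a plain function of the step."""
  cycle_position = step % (2 * warmup_steps)
  # Normalize to [-1, 1], then fold so the value is 0 at the cycle edges
  # and 1 at the midpoint.
  cycle_position = float(cycle_position - warmup_steps) / warmup_steps
  cycle_position = 1.0 - abs(cycle_position)
  return (cycle_position + 0.1) * 3.0


for step in (0, 4000, 8000, 12000, 16000):
  print("%6d -> %.1f" % (step, cyclelinear10x_multiplier(step, 8000)))
# approx: 0 -> 0.3, 4000 -> 1.8, 8000 -> 3.3, 12000 -> 1.8, 16000 -> 0.3
```
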
diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py
index fa9d9233e..a747b9a09 100644
--- a/tensor2tensor/utils/trainer_utils.py
+++ b/tensor2tensor/utils/trainer_utils.py
@@ -27,7 +27,6 @@
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import models # pylint: disable=unused-import
from tensor2tensor.utils import data_reader
-from tensor2tensor.utils import decoding
from tensor2tensor.utils import devices
from tensor2tensor.utils import input_fn_builder
from tensor2tensor.utils import metrics
@@ -101,16 +100,13 @@
flags.DEFINE_string("ps_job", "/job:ps", "name of ps job")
flags.DEFINE_integer("ps_replicas", 0, "How many ps replicas.")
-# Decode flags
-# Set one of {decode_from_dataset, decode_interactive, decode_from_file} to
-# decode.
-flags.DEFINE_bool("decode_from_dataset", False, "Decode from dataset on disk.")
-flags.DEFINE_bool("decode_use_last_position_only", False,
- "In inference, use last position only for speedup.")
+# Decoding flags
+flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_bool("decode_interactive", False,
"Interactive local inference mode.")
+flags.DEFINE_bool("decode_use_last_position_only", False,
+ "In inference, use last position only for speedup.")
flags.DEFINE_bool("decode_save_images", False, "Save inference input images.")
-flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_string("decode_to_file", None, "Path to inference output file")
flags.DEFINE_integer("decode_shards", 1, "How many shards to decode.")
flags.DEFINE_integer("decode_problem_id", 0, "Which problem to decode.")
@@ -128,7 +124,7 @@
"Maximum number of ids in input. Or <= 0 for no max.")
flags.DEFINE_bool("identity_output", False, "To print the output as identity")
flags.DEFINE_integer("decode_num_samples", -1,
- "Number of samples to decode. Currently used in"
+ "Number of samples to decode. Currently used in "
"decode_from_dataset. Use -1 for all.")
@@ -149,8 +145,8 @@ def experiment_fn(output_dir):
def create_experiment(output_dir, data_dir, model_name, train_steps,
eval_steps):
"""Create Experiment."""
- hparams = create_hparams(FLAGS.hparams_set, FLAGS.problems, data_dir,
- passed_hparams=FLAGS.hparams)
+ hparams = create_hparams(
+ FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams)
estimator, input_fns = create_experiment_components(
hparams=hparams,
output_dir=output_dir,
@@ -303,7 +299,6 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule):
if exp.train_steps > 0 or exp.eval_steps > 0:
tf.logging.info("Performing local training and evaluation.")
exp.train_and_evaluate()
- decode(exp.estimator)
else:
# Perform distributed training/evaluation.
learn_runner.run(
@@ -350,12 +345,3 @@ def session_config():
def get_data_filepatterns(data_dir, mode):
return data_reader.get_data_filepatterns(FLAGS.problems, data_dir, mode)
-
-
-def decode(estimator):
- if FLAGS.decode_interactive:
- decoding.decode_interactively(estimator)
- elif FLAGS.decode_from_file is not None and FLAGS.decode_from_file is not "":
- decoding.decode_from_file(estimator, FLAGS.decode_from_file)
- elif FLAGS.decode_from_dataset:
- decoding.decode_from_dataset(estimator)
diff --git a/tensor2tensor/visualization/TransformerVisualization.ipynb b/tensor2tensor/visualization/TransformerVisualization.ipynb
index ef1c7b45d..e3fb8f958 100644
--- a/tensor2tensor/visualization/TransformerVisualization.ipynb
+++ b/tensor2tensor/visualization/TransformerVisualization.ipynb
@@ -86,7 +86,7 @@
"import os\n",
"# PUT THE MODEL YOU WANT TO LOAD HERE!\n",
"\n",
- "PROBLEM = 'wmt_ende_tokens_32k'\n",
+ "PROBLEM = 'translate_ende_wmt32k'\n",
"MODEL = 'transformer'\n",
"HPARAMS = 'transformer_base_single_gpu'\n",
"\n",
@@ -118,7 +118,7 @@
}
],
"source": [
- "hparams = utils.create_hparams(HPARAMS, DATA_DIR)\n",
+ "hparams = utils.create_hparams(HPARAMS, PROBLEM, DATA_DIR)\n",
"\n",
"# SET EXTRA HYPER PARAMS HERE!\n",
"# e.g.\n",
@@ -381,8 +381,7 @@
},
"outputs": [],
"source": [
- "der = decode(beam_decode[0])\n",
- "output_ids = encode(der)\n",
+ "output_ids = beam_decode\n",
"\n",
"# Get attentions\n",
"np_enc_atts, np_dec_atts, np_encdec_atts = sess.run([enc_atts, dec_atts, encdec_atts], {\n",
diff --git a/tensor2tensor/visualization/__init__.py b/tensor2tensor/visualization/__init__.py
new file mode 100644
index 000000000..b62605264
--- /dev/null
+++ b/tensor2tensor/visualization/__init__.py
@@ -0,0 +1,16 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
diff --git a/tensor2tensor/visualization/attention.py b/tensor2tensor/visualization/attention.py
index 2c1f61c9c..bc4238081 100644
--- a/tensor2tensor/visualization/attention.py
+++ b/tensor2tensor/visualization/attention.py
@@ -21,8 +21,7 @@
import json
import os
-from IPython.display import HTML
-from IPython.display import Javascript
+import IPython.display as display
import numpy as np
@@ -53,9 +52,9 @@ def show(inp_text, out_text, enc_atts, dec_atts, encdec_atts):
def _show_attention(att_json):
- display(HTML(vis_html)) # pylint: disable=undefined-variable
- display(Javascript('window.attention = %s' % att_json)) # pylint: disable=undefined-variable
- display(Javascript(vis_js)) # pylint: disable=undefined-variable
+ display.display(display.HTML(vis_html))
+ display.display(display.Javascript('window.attention = %s' % att_json))
+ display.display(display.Javascript(vis_js))
def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):
@@ -88,8 +87,8 @@ def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):
"""
def get_full_attention(layer):
"""Get the full input+output - input+output attentions."""
- enc_att = enc_atts[layer][0],
- dec_att = dec_atts[layer][0],
+ enc_att = enc_atts[layer][0]
+ dec_att = dec_atts[layer][0]
encdec_att = encdec_atts[layer][0]
enc_att = np.transpose(enc_att, [0, 2, 1])
dec_att = np.transpose(dec_att, [0, 2, 1])