Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #172 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.1.1
  • Loading branch information
lukaszkaiser authored Jul 20, 2017
2 parents 47d556a + 2fd79ec commit 668e385
Show file tree
Hide file tree
Showing 19 changed files with 510 additions and 163 deletions.
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,25 @@ send along a pull request to add your dataset or model.
See [our contribution
doc](CONTRIBUTING.md) for details and our [open
issues](https://github.com/tensorflow/tensor2tensor/issues).
And chat with us and other users on
[Gitter](https://gitter.im/tensor2tensor/Lobby).
You can chat with us and other users on
[Gitter](https://gitter.im/tensor2tensor/Lobby) and please join our
[Google Group](https://groups.google.com/forum/#!forum/tensor2tensor) to keep up
with T2T announcements.

Here is a one-command version that installs tensor2tensor, downloads the data,
trains an English-German translation model, and lets you use it interactively:
```
pip install tensor2tensor && t2t-trainer \
--generate_data \
--data_dir=~/t2t_data \
--problems=wmt_ende_tokens_32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
--output_dir=~/t2t_train/base \
--decode_interactive
```

See the [Walkthrough](#walkthrough) below for more details on each step.

### Contents

Expand Down Expand Up @@ -72,8 +89,6 @@ t2t-datagen \
--num_shards=100 \
--problem=$PROBLEM
cp $TMP_DIR/tokens.vocab.* $DATA_DIR
# Train
# * If you run out of memory, add --hparams='batch_size=2048' or even 1024.
t2t-trainer \
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.1.0',
version='1.1.1',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
137 changes: 49 additions & 88 deletions tensor2tensor/bin/t2t-datagen
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import tempfile

Expand Down Expand Up @@ -79,24 +80,30 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
"ice_parsing_tokens": (
lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
True, "ice", 2**13, 2**8),
lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
False, "ice", 2**13, 2**8)),
lambda: wmt.tabbed_parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
lambda: wmt.tabbed_parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
"ice_parsing_characters": (
lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
lambda: wmt.tabbed_parsing_character_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True),
lambda: wmt.tabbed_parsing_character_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False)),
"wmt_parsing_tokens_8k": (
lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
lambda: wmt.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
lambda: wmt.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
"wsj_parsing_tokens_16k": (
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
2**14, 2**9),
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
2**14, 2**9)),
lambda: wsj_parsing.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
lambda: wsj_parsing.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
"wmt_ende_bpe32k": (
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
lambda: wmt.ende_bpe_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True),
lambda: wmt.ende_bpe_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False)),
"lm1b_32k": (
lambda: lm1b.generator(FLAGS.tmp_dir, True),
lambda: lm1b.generator(FLAGS.tmp_dir, False)
Expand All @@ -118,98 +125,50 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
"image_mscoco_characters_test": (
lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 80000),
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 40000)),
"image_celeba_tune": (
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
"image_mscoco_tokens_8k_test": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
80000,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13),
FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
False,
40000,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13)),
FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)),
"image_mscoco_tokens_32k_test": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
80000,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15),
FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
False,
40000,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15)),
FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
"snli_32k": (
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
),
"audio_timit_characters_tune": (
lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1374),
lambda: audio.timit_generator(FLAGS.tmp_dir, True, 344, 1374)),
"audio_timit_characters_test": (
lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1718),
lambda: audio.timit_generator(FLAGS.tmp_dir, False, 626)),
"audio_timit_tokens_8k_tune": (
lambda: audio.timit_generator(
FLAGS.tmp_dir,
True,
1374,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13),
FLAGS.data_dir, FLAGS.tmp_dir, True, 1718),
lambda: audio.timit_generator(
FLAGS.tmp_dir,
True,
344,
1374,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13)),
FLAGS.data_dir, FLAGS.tmp_dir, False, 626)),
"audio_timit_tokens_8k_test": (
lambda: audio.timit_generator(
FLAGS.tmp_dir,
True,
1718,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13),
lambda: audio.timit_generator(
FLAGS.tmp_dir,
False,
626,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13)),
"audio_timit_tokens_32k_tune": (
lambda: audio.timit_generator(
FLAGS.tmp_dir,
True,
1374,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15),
FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
lambda: audio.timit_generator(
FLAGS.tmp_dir,
True,
344,
1374,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15)),
FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)),
"audio_timit_tokens_32k_test": (
lambda: audio.timit_generator(
FLAGS.tmp_dir,
True,
1718,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15),
FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
lambda: audio.timit_generator(
FLAGS.tmp_dir,
False,
626,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15)),
FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
"lmptb_10k": (
lambda: ptb.train_generator(
FLAGS.tmp_dir,
Expand Down Expand Up @@ -317,7 +276,9 @@ def generate_data_for_problem(problem):

def generate_data_for_registered_problem(problem_name):
problem = registry.problem(problem_name)
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir, FLAGS.num_shards)
problem.generate_data(os.path.expanduser(FLAGS.data_dir),
os.path.expanduser(FLAGS.tmp_dir),
FLAGS.num_shards)


if __name__ == "__main__":
Expand Down
24 changes: 20 additions & 4 deletions tensor2tensor/bin/t2t-trainer
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ from __future__ import print_function

# Dependency imports

from tensor2tensor.utils import trainer_utils as utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_utils
from tensor2tensor.utils import usr_dir

import tensorflow as tf
Expand All @@ -45,14 +46,29 @@ flags.DEFINE_string("t2t_usr_dir", "",
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-trainer.")
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
"Temporary storage directory.")
flags.DEFINE_bool("generate_data", False, "Generate data before training?")


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
utils.log_registry()
utils.validate_flags()
utils.run(
trainer_utils.log_registry()
trainer_utils.validate_flags()
tf.gfile.MakeDirs(FLAGS.output_dir)

# Generate data if requested.
if FLAGS.generate_data:
tf.gfile.MakeDirs(FLAGS.data_dir)
tf.gfile.MakeDirs(FLAGS.tmp_dir)
for problem_name in FLAGS.problems.split("-"):
tf.logging.info("Generating data for %s" % problem_name)
problem = registry.problem(problem_name)
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

# Run the trainer.
trainer_utils.run(
data_dir=FLAGS.data_dir,
model=FLAGS.model,
output_dir=FLAGS.output_dir,
Expand Down
9 changes: 9 additions & 0 deletions tensor2tensor/data_generators/all_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,13 @@
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing

# Problem modules that require optional dependencies
# pylint: disable=g-import-not-at-top
try:
# Requires h5py
from tensor2tensor.data_generators import genetics
except ImportError:
pass
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import
6 changes: 4 additions & 2 deletions tensor2tensor/data_generators/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def _get_text_data(filepath):
return " ".join(words)


def timit_generator(tmp_dir,
def timit_generator(data_dir,
tmp_dir,
training,
how_many,
start_from=0,
Expand All @@ -107,6 +108,7 @@ def timit_generator(tmp_dir,
"""Data generator for TIMIT transcription problem.
Args:
data_dir: path to the data directory.
tmp_dir: path to temporary storage directory.
training: a Boolean; if true, we use the train set, otherwise the test set.
how_many: how many inputs and labels to generate.
Expand All @@ -128,7 +130,7 @@ def timit_generator(tmp_dir,
eos_list = [1] if eos_list is None else eos_list
if vocab_filename is not None:
vocab_symbolizer = generator_utils.get_or_generate_vocab(
tmp_dir, vocab_filename, vocab_size)
data_dir, tmp_dir, vocab_filename, vocab_size)
_get_timit(tmp_dir)
datasets = (_TIMIT_TRAIN_DATASETS if training else _TIMIT_TEST_DATASETS)
i = 0
Expand Down
14 changes: 6 additions & 8 deletions tensor2tensor/data_generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,13 @@ def gunzip_file(gz_path, new_path):
"http://www.statmt.org/wmt13/training-parallel-un.tgz",
["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"]
],
[
"https://github.com/stefan-it/nmt-mk-en/raw/master/data/setimes.mk-en.train.tgz", # pylint: disable=line-too-long
["train.mk", "train.en"]
],
]


def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
def get_or_generate_vocab(data_dir, tmp_dir,
vocab_filename, vocab_size, sources=None):
"""Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
vocab_filepath = os.path.join(tmp_dir, vocab_filename)
vocab_filepath = os.path.join(data_dir, vocab_filename)
if tf.gfile.Exists(vocab_filepath):
tf.logging.info("Found vocab file: %s", vocab_filepath)
vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
Expand Down Expand Up @@ -304,7 +301,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
return vocab


def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
def get_or_generate_tabbed_vocab(data_dir, tmp_dir, source_filename,
index, vocab_filename, vocab_size):
r"""Generate a vocabulary from a tabbed source file.
Expand All @@ -313,6 +310,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
The index parameter specifies 0 for the source or 1 for the target.
Args:
data_dir: path to the data directory.
tmp_dir: path to the temporary directory.
source_filename: the name of the tab-separated source file.
index: index.
Expand All @@ -322,7 +320,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
Returns:
The vocabulary.
"""
vocab_filepath = os.path.join(tmp_dir, vocab_filename)
vocab_filepath = os.path.join(data_dir, vocab_filename)
if os.path.exists(vocab_filepath):
vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
return vocab
Expand Down
Loading

0 comments on commit 668e385

Please sign in to comment.