
Merge pull request #325 from rsepassi/push
v1.2.4
lukaszkaiser authored Sep 30, 2017
2 parents feb752c + 583356d commit ffc5800
Showing 38 changed files with 2,146 additions and 473 deletions.
17 changes: 15 additions & 2 deletions .travis.yml
@@ -8,9 +8,22 @@ before_install:
install:
- pip install tensorflow
- pip install .[tests]
env:
global:
- T2T_PROBLEM=algorithmic_reverse_binary40_test
- T2T_DATA_DIR=/tmp/t2t-data
- T2T_TRAIN_DIR=/tmp/t2t-train
script:
- pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/utils/trainer_utils_test.py --ignore=tensor2tensor/problems_test.py
- pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/utils/trainer_utils_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py
- pytest tensor2tensor/utils/registry_test.py
- pytest tensor2tensor/utils/trainer_utils_test.py
- t2t-datagen 2>&1 | grep translate && echo passed
- python -c "from tensor2tensor.models import transformer; print(transformer.Transformer.__name__)"
- t2t-trainer --registry_help
- mkdir $T2T_DATA_DIR
- mkdir $T2T_TRAIN_DIR
- t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
- t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
- t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR
git:
depth: 3
depth: 3
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.2.3',
version='1.2.4',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
11 changes: 7 additions & 4 deletions tensor2tensor/bin/t2t-decoder
@@ -75,7 +75,7 @@ def main(_):

hparams = trainer_utils.create_hparams(
FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
estimator, _ = trainer_utils.create_experiment_components(
data_dir=data_dir,
model_name=FLAGS.model,
@@ -90,9 +90,12 @@ def main(_):
decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
FLAGS.decode_to_file)
else:
decoding.decode_from_dataset(estimator,
FLAGS.problems.split("-"), decode_hp,
FLAGS.decode_to_file)
decoding.decode_from_dataset(
estimator,
FLAGS.problems.split("-"),
decode_hp,
decode_to_file=FLAGS.decode_to_file,
dataset_split="test" if FLAGS.eval_use_test_set else None)


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions tensor2tensor/bin/t2t-trainer
@@ -68,6 +68,8 @@ def main(_):
trainer_utils.validate_flags()
output_dir = os.path.expanduser(FLAGS.output_dir)
tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
if not FLAGS.data_dir:
raise ValueError("You must specify a --data_dir")
data_dir = os.path.expanduser(FLAGS.data_dir)
tf.gfile.MakeDirs(output_dir)

54 changes: 36 additions & 18 deletions tensor2tensor/data_generators/algorithmic.py
@@ -62,13 +62,15 @@ def num_shards(self):
return 10

def generate_data(self, data_dir, _, task_id=-1):

def generator_eos(nbr_symbols, max_length, nbr_cases):
"""Shift by NUM_RESERVED_IDS and append EOS token."""
for case in self.generator(nbr_symbols, max_length, nbr_cases):
new_case = {}
for feature in case:
new_case[feature] = [i + text_encoder.NUM_RESERVED_TOKENS
for i in case[feature]] + [text_encoder.EOS_ID]
new_case[feature] = [
i + text_encoder.NUM_RESERVED_TOKENS for i in case[feature]
] + [text_encoder.EOS_ID]
yield new_case

utils.generate_dataset_and_shuffle(
@@ -154,10 +156,7 @@ def generator(self, nbr_symbols, max_length, nbr_cases):
for _ in xrange(nbr_cases):
l = np.random.randint(max_length) + 1
inputs = [np.random.randint(nbr_symbols - shift) for _ in xrange(l)]
yield {
"inputs": inputs,
"targets": [i + shift for i in inputs]
}
yield {"inputs": inputs, "targets": [i + shift for i in inputs]}

@property
def dev_length(self):
@@ -191,10 +190,7 @@ def generator(self, nbr_symbols, max_length, nbr_cases):
for _ in xrange(nbr_cases):
l = np.random.randint(max_length) + 1
inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)]
yield {
"inputs": inputs,
"targets": list(reversed(inputs))
}
yield {"inputs": inputs, "targets": list(reversed(inputs))}


@registry.register_problem
@@ -272,10 +268,7 @@ def reverse_generator_nlplike(nbr_symbols,
for _ in xrange(nbr_cases):
l = int(abs(np.random.normal(loc=max_length / 2, scale=std_dev)) + 1)
inputs = zipf_random_sample(distr_map, l)
yield {
"inputs": inputs,
"targets": list(reversed(inputs))
}
yield {"inputs": inputs, "targets": list(reversed(inputs))}


@registry.register_problem
@@ -287,8 +280,8 @@ def num_symbols(self):
return 8000

def generator(self, nbr_symbols, max_length, nbr_cases):
return reverse_generator_nlplike(
nbr_symbols, max_length, nbr_cases, 10, 1.300)
return reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, 10,
1.300)

@property
def train_length(self):
@@ -308,8 +301,8 @@ def num_symbols(self):
return 32000

def generator(self, nbr_symbols, max_length, nbr_cases):
return reverse_generator_nlplike(
nbr_symbols, max_length, nbr_cases, 10, 1.050)
return reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, 10,
1.050)


def lower_endian_to_number(l, base):
@@ -431,3 +424,28 @@ class AlgorithmicMultiplicationDecimal40(AlgorithmicMultiplicationBinary40):
@property
def num_symbols(self):
return 10


@registry.register_problem
class AlgorithmicReverseBinary40Test(AlgorithmicReverseBinary40):
"""Test Problem with tiny dataset."""

@property
def train_length(self):
return 10

@property
def dev_length(self):
return 10

@property
def train_size(self):
return 1000

@property
def dev_size(self):
return 100

@property
def num_shards(self):
return 1
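
A minimal usage sketch (not part of this commit) for the tiny test problem registered above. The registry lookup mirrors how problems are resolved elsewhere in the repo; the /tmp data directory is an assumption for illustration, and the problem name matches the T2T_PROBLEM used in the new .travis.yml:

from tensor2tensor.data_generators import algorithmic  # noqa: importing runs the register_problem decorator above
from tensor2tensor.utils import registry

# "algorithmic_reverse_binary40_test" is the snake_case registry name derived
# from the AlgorithmicReverseBinary40Test class; the data_dir below is hypothetical.
problem = registry.problem("algorithmic_reverse_binary40_test")
# Writes the tiny train/dev data (1000/100 cases, one shard each); the second
# argument (tmp_dir) is unused by this problem's generate_data.
problem.generate_data("/tmp/t2t-data", None)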
36 changes: 36 additions & 0 deletions tensor2tensor/data_generators/all_problems_test.py
@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Tensor2Tensor's all_problems.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from tensor2tensor.data_generators import all_problems

import tensorflow as tf


class AllProblemsTest(tf.test.TestCase):

def testImport(self):
"""Make sure that importing all_problems doesn't break."""
self.assertIsNotNone(all_problems)


if __name__ == '__main__':
tf.test.main()
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/image.py
@@ -650,6 +650,7 @@ def generator(self, data_dir, tmp_dir, is_training):
class ImageCifar10Plain(ImageCifar10):

def preprocess_example(self, example, mode, unused_hparams):
example["inputs"] = tf.to_int64(example["inputs"])
return example


58 changes: 30 additions & 28 deletions tensor2tensor/data_generators/lm1b.py
@@ -36,7 +36,6 @@

import tensorflow as tf


# End-of-sentence marker (should correspond to the position of EOS in the
# RESERVED_TOKENS list in text_encoder.py)
EOS = 1
@@ -59,9 +58,10 @@ def _original_vocab(tmp_dir):
vocab_filepath = os.path.join(tmp_dir, vocab_filename)
if not os.path.exists(vocab_filepath):
generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
return set(
[text_encoder.native_to_unicode(l.strip()) for l in
tf.gfile.Open(vocab_filepath)])
return set([
text_encoder.native_to_unicode(l.strip())
for l in tf.gfile.Open(vocab_filepath)
])


def _replace_oov(original_vocab, line):
@@ -81,19 +81,19 @@ def _replace_oov(original_vocab, line):


def _train_data_filenames(tmp_dir):
return [os.path.join(
tmp_dir,
"1-billion-word-language-modeling-benchmark-r13output",
"training-monolingual.tokenized.shuffled",
"news.en-%05d-of-00100" % i) for i in xrange(1, 100)]
return [
os.path.join(tmp_dir,
"1-billion-word-language-modeling-benchmark-r13output",
"training-monolingual.tokenized.shuffled",
"news.en-%05d-of-00100" % i) for i in xrange(1, 100)
]


def _dev_data_filename(tmp_dir):
return os.path.join(
tmp_dir,
"1-billion-word-language-modeling-benchmark-r13output",
"heldout-monolingual.tokenized.shuffled",
"news.en.heldout-00000-of-00050")
return os.path.join(tmp_dir,
"1-billion-word-language-modeling-benchmark-r13output",
"heldout-monolingual.tokenized.shuffled",
"news.en.heldout-00000-of-00050")


def _maybe_download_corpus(tmp_dir):
@@ -112,17 +112,18 @@ def _maybe_download_corpus(tmp_dir):
corpus_tar.extractall(tmp_dir)


def _get_or_build_subword_text_encoder(tmp_dir):
def _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath):
"""Builds a SubwordTextEncoder based on the corpus.
Args:
tmp_dir: directory containing dataset.
vocab_filepath: path to store (or load) vocab.
Returns:
a SubwordTextEncoder.
"""
filepath = os.path.join(tmp_dir, "lm1b_32k.subword_text_encoder")
if tf.gfile.Exists(filepath):
return text_encoder.SubwordTextEncoder(filepath)
if tf.gfile.Exists(vocab_filepath):
return text_encoder.SubwordTextEncoder(vocab_filepath)
_maybe_download_corpus(tmp_dir)
original_vocab = _original_vocab(tmp_dir)
token_counts = defaultdict(int)
@@ -138,7 +139,7 @@ def _get_or_build_subword_text_encoder(tmp_dir):
break
ret = text_encoder.SubwordTextEncoder()
ret.build_from_token_counts(token_counts, min_count=5)
ret.store_to_file(filepath)
ret.store_to_file(vocab_filepath)
return ret


@@ -152,7 +153,7 @@ def is_character_level(self):

@property
def has_inputs(self):
return True
return False

@property
def input_space_id(self):
@@ -184,25 +185,26 @@ def targeted_vocab_size(self):
def use_train_shards_for_dev(self):
return True

def generator(self, tmp_dir, train, characters=False):
def generator(self, data_dir, tmp_dir, is_training):
"""Generator for lm1b sentences.
Args:
tmp_dir: a string.
train: a boolean.
characters: a boolean
data_dir: data dir.
tmp_dir: tmp dir.
is_training: a boolean.
Yields:
A dictionary {"inputs": [0], "targets": [<subword ids>]}
"""
_maybe_download_corpus(tmp_dir)
original_vocab = _original_vocab(tmp_dir)
files = (_train_data_filenames(tmp_dir) if train
else [_dev_data_filename(tmp_dir)])
if characters:
files = (_train_data_filenames(tmp_dir)
if is_training else [_dev_data_filename(tmp_dir)])
if self.is_character_level:
encoder = text_encoder.ByteTextEncoder()
else:
encoder = _get_or_build_subword_text_encoder(tmp_dir)
vocab_filepath = os.path.join(data_dir, self.vocab_file)
encoder = _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath)
for filepath in files:
tf.logging.info("filepath = %s", filepath)
for line in tf.gfile.Open(filepath):
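
A hedged sketch (not from this diff) of the reworked lm1b generator wiring above: the point of the change is that the subword vocabulary is now built or loaded at a caller-supplied path under data_dir instead of a hard-coded file in tmp_dir. The problem name and directories below are assumptions for illustration:

from tensor2tensor.data_generators import lm1b  # noqa: importing registers the LM1B problems
from tensor2tensor.utils import registry

# Assumed registry name for the 32k-subword LM1B problem in this release.
problem = registry.problem("languagemodel_lm1b32k")
# generator() now takes (data_dir, tmp_dir, is_training); on first use the
# subword encoder is cached at os.path.join(data_dir, problem.vocab_file).
# Note: building the vocab downloads and scans the LM1B corpus under tmp_dir.
gen = problem.generator("/tmp/t2t-data", "/tmp/t2t-tmp", is_training=True)
first = next(gen)  # a dict whose "targets" holds subword token ids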