Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #228 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.1.9
  • Loading branch information
lukaszkaiser authored Aug 17, 2017
2 parents 45a787e + f5d5405 commit 8f3a7fd
Show file tree
Hide file tree
Showing 33 changed files with 1,464 additions and 320 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.1.8',
version='1.1.9',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
22 changes: 14 additions & 8 deletions tensor2tensor/bin/t2t-datagen
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
"registered Problems.")
flags.DEFINE_integer("max_cases", 0,
"Maximum number of cases to generate (unbounded if 0).")
flags.DEFINE_bool("only_list", False,
"If true, we only list the problems that will be generated.")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
flags.DEFINE_integer("task_id", -1, "For distributed data generation.")
flags.DEFINE_string("t2t_usr_dir", "",
Expand All @@ -81,33 +83,33 @@ _SUPPORTED_PROBLEM_GENERATORS = {
"algorithmic_algebra_inverse": (
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
"wmt_parsing_tokens_8k": (
"parsing_english_ptb8k": (
lambda: wmt.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
lambda: wmt.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
"wsj_parsing_tokens_16k": (
"parsing_english_ptb16k": (
lambda: wsj_parsing.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
lambda: wsj_parsing.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
"wmt_ende_bpe32k": (
"translate_ende_wmt_bpe32k": (
lambda: wmt.ende_bpe_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True),
lambda: wmt.ende_bpe_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False)),
"lm1b_32k": (
"languagemodel_1b32k": (
lambda: lm1b.generator(FLAGS.tmp_dir, True),
lambda: lm1b.generator(FLAGS.tmp_dir, False)
),
"lm1b_characters": (
"languagemodel_1b_characters": (
lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
),
"image_celeba_tune": (
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
"snli_32k": (
"inference_snli32k": (
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
),
Expand Down Expand Up @@ -181,7 +183,11 @@ def main(_):
"Data will be written to default data_dir=%s.",
FLAGS.data_dir)

tf.logging.info("Generating problems:\n * %s\n" % "\n * ".join(problems))
tf.logging.info("Generating problems:\n%s"
% registry.display_list_by_prefix(problems,
starting_spaces=4))
if FLAGS.only_list:
return
for problem in problems:
set_random_seed()

Expand Down Expand Up @@ -210,7 +216,7 @@ def generate_data_for_problem(problem):


def generate_data_for_registered_problem(problem_name):
tf.logging.info("Generating training data for %s.", problem_name)
tf.logging.info("Generating data for %s.", problem_name)
if FLAGS.num_shards:
raise ValueError("--num_shards should not be set for registered Problem.")
problem = registry.problem(problem_name)
Expand Down
8 changes: 4 additions & 4 deletions tensor2tensor/data_generators/cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


@registry.register_problem
class CipherShift5(algorithmic.AlgorithmicProblem):
class AlgorithmicCipherShift5(algorithmic.AlgorithmicProblem):
"""Shift cipher."""

@property
Expand Down Expand Up @@ -62,7 +62,7 @@ def dev_length(self):


@registry.register_problem
class CipherVigenere5(algorithmic.AlgorithmicProblem):
class AlgorithmicCipherVigenere5(algorithmic.AlgorithmicProblem):
"""Vigenère cipher."""

@property
Expand Down Expand Up @@ -95,7 +95,7 @@ def dev_length(self):


@registry.register_problem
class CipherShift200(CipherShift5):
class AlgorithmicCipherShift200(AlgorithmicCipherShift5):
"""Shift cipher."""

@property
Expand All @@ -110,7 +110,7 @@ def distribution(self):


@registry.register_problem
class CipherVigenere200(CipherVigenere5):
class AlgorithmicCipherVigenere200(AlgorithmicCipherVigenere5):
"""Vigenère cipher."""

@property
Expand Down
8 changes: 4 additions & 4 deletions tensor2tensor/data_generators/desc2code.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def generator_target():
}


@registry.register_problem("desc2code_py")
class Desc2CodePyProblem(Desc2CodeProblem):
@registry.register_problem
class ProgrammingDesc2codePy(Desc2CodeProblem):
"""Description2Code for python problem."""

@property
Expand All @@ -222,8 +222,8 @@ def preprocess_target(self, target):
return target.replace("\t", " ")


@registry.register_problem("desc2code_cpp")
class Desc2CodeCppProblem(Desc2CodeProblem):
@registry.register_problem
class ProgrammingDesc2codeCpp(Desc2CodeProblem):
"""Description2Code for C++ problem."""

@property
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/desc2code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Desc2codeTest(tf.test.TestCase):

def testCppPreprocess(self):
"""Check that the file correctly preprocess the code source."""
cpp_pb = desc2code.Desc2CodeCppProblem()
cpp_pb = desc2code.ProgrammingDesc2codeCpp()

self.assertEqual(  # Add space between two lines
cpp_pb.preprocess_target("firstline//comm1\nsecondline//comm2\n"),
Expand Down
12 changes: 6 additions & 6 deletions tensor2tensor/data_generators/gene_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def eval_metrics(self):
return [metrics.Metrics.LOG_POISSON, metrics.Metrics.R2]


@registry.register_problem("gene_expression_cage10")
class GeneExpressionCAGE10(GeneExpressionProblem):
@registry.register_problem
class GenomicsExpressionCage10(GeneExpressionProblem):

@property
def download_url(self):
Expand All @@ -188,8 +188,8 @@ def h5_file(self):
return "cage10.h5"


@registry.register_problem("gene_expression_gm12878")
class GeneExpressionGM12878(GeneExpressionProblem):
@registry.register_problem
class GenomicsExpressionGm12878(GeneExpressionProblem):

@property
def download_url(self):
Expand All @@ -200,8 +200,8 @@ def h5_file(self):
return "gm12878.h5"


@registry.register_problem("gene_expression_l262k")
class GeneExpressionL262k(GeneExpressionProblem):
@registry.register_problem
class GenomicsExpressionL262k(GeneExpressionProblem):

@property
def h5_file(self):
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/data_generators/ice_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def tabbed_parsing_character_generator(tmp_dir, train):
return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)


@registry.register_problem("ice_parsing_tokens")
class IceParsingTokens(problem.Problem):
@registry.register_problem
class ParsingIcelandic16k(problem.Problem):
"""Problem spec for parsing tokenized Icelandic text to constituency trees."""

@property
Expand Down
19 changes: 15 additions & 4 deletions tensor2tensor/data_generators/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,21 @@ def dataset_filename(self):
def is_small(self):
return True # Modalities like for CIFAR.

def preprocess_examples(self, examples, mode):
examples = imagenet_preprocess_examples(examples, mode)
examples["inputs"] = tf.to_int64(
tf.image.resize_images(examples["inputs"], [32, 32]))
@property
def num_classes(self):
return 1000

def preprocess_examples(self, examples, mode, hparams):
# Just resize with area.
if self._was_reversed:
examples["inputs"] = tf.to_int64(
tf.image.resize_images(examples["inputs"], [32, 32],
tf.image.ResizeMethod.AREA))
else:
examples = imagenet_preprocess_examples(examples, mode)
examples["inputs"] = tf.to_int64(
tf.image.resize_images(examples["inputs"], [32, 32]))
return examples


def image_generator(images, labels):
Expand Down
13 changes: 7 additions & 6 deletions tensor2tensor/data_generators/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def preprocess_examples_common(examples, hparams):
examples["inputs"] = examples["inputs"][:hparams.max_input_seq_length]
if hparams.max_target_seq_length > 0:
examples["targets"] = examples["targets"][:hparams.max_target_seq_length]
if hparams.prepend_inputs_to_targets:
if hparams.prepend_mode != "none":
examples["targets"] = tf.concat(
[examples["inputs"], [0], examples["targets"]], 0)
return examples
Expand Down Expand Up @@ -410,11 +410,12 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
generator_utils.generate_files(
self.generator(data_dir, tmp_dir, True), all_paths)
generator_utils.shuffle_dataset(all_paths)
generator_utils.generate_dataset_and_shuffle(
self.generator(data_dir, tmp_dir, True),
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
self.generator(data_dir, tmp_dir, False),
self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False))
else:
generator_utils.generate_dataset_and_shuffle(
self.generator(data_dir, tmp_dir, True),
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
self.generator(data_dir, tmp_dir, False),
self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False))

def feature_encoders(self, data_dir):
if self.is_character_level:
Expand Down
10 changes: 5 additions & 5 deletions tensor2tensor/data_generators/problem_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,16 @@ def image_celeba(unused_model_hparams):
lambda p: audio_wsj_tokens(p, 2**13),
"audio_wsj_tokens_8k_test":
lambda p: audio_wsj_tokens(p, 2**13),
"lm1b_characters":
"languagemodel_1b_characters":
lm1b_characters,
"lm1b_32k":
"languagemodel_1b32k":
lm1b_32k,
"wmt_parsing_tokens_8k":
"parsing_english_ptb8k":
lambda p: wmt_parsing_tokens(p, 2**13),
"wsj_parsing_tokens_16k":
"parsing_english_ptb16k":
lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda
p, "wsj", 2**14, 2**9),
"wmt_ende_bpe32k":
"translate_ende_wmt_bpe32k":
wmt_ende_bpe32k,
"image_celeba_tune":
image_celeba,
Expand Down
6 changes: 3 additions & 3 deletions tensor2tensor/data_generators/ptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def _generator(self, filename, encoder):
yield {"inputs": [0], "targets": tok}


@registry.register_problem("lm_ptb_10k")
class LmPtb10k(PTBProblem):
@registry.register_problem
class LanguagemodelPtb10k(PTBProblem):
"""A class for generating PTB data, 10k vocab."""

@property
Expand All @@ -167,7 +167,7 @@ def is_character_level(self):


@registry.register_problem
class LmPtbCharacters(PTBProblem):
class LanguagemodelPtbCharacters(PTBProblem):
"""A class for generating PTB data, character-level."""

@property
Expand Down
13 changes: 11 additions & 2 deletions tensor2tensor/data_generators/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def _page_title(page):


@registry.register_problem
class Wiki32k(problem.Text2TextProblem):
"""A class for generating PTB data."""
class LanguagemodelWikiFull32k(problem.Text2TextProblem):
"""A language model on full English Wikipedia."""

@property
def is_character_level(self):
Expand Down Expand Up @@ -129,3 +129,12 @@ def generator(self, data_dir, tmp_dir, _):
encoded = encoder.encode(page) + [EOS]
encoded_title = encoder.encode(title) + [EOS]
yield {"inputs": encoded_title, "targets": encoded}


@registry.register_problem
class LanguagemodelWikiFull8k(problem.Text2TextProblem):
  """Language model trained on the full English Wikipedia text."""

  @property
  def targeted_vocab_size(self):
    # 8k-subword vocabulary variant (cf. the 32k variant of this problem).
    return 2**13  # 8192
Loading

0 comments on commit 8f3a7fd

Please sign in to comment.