Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #222 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.1.8
  • Loading branch information
lukaszkaiser authored Aug 11, 2017
2 parents b669110 + 8abc5d2 commit 45a787e
Show file tree
Hide file tree
Showing 36 changed files with 944 additions and 1,656 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.1.7',
version='1.1.8',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand All @@ -19,6 +19,7 @@
'tensor2tensor/bin/t2t-make-tf-configs',
],
install_requires=[
'bz2file',
'numpy',
'requests',
'sympy',
Expand Down
51 changes: 13 additions & 38 deletions tensor2tensor/bin/t2t-datagen
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
from tensor2tensor.utils import registry
Expand Down Expand Up @@ -105,10 +104,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
),
"wiki_32k": (
lambda: wiki.generator(FLAGS.tmp_dir, True),
1000
),
"image_celeba_tune": (
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
Expand Down Expand Up @@ -170,17 +165,14 @@ def main(_):
# Remove parsing if paths are not given.
if not FLAGS.parsing_path:
problems = [p for p in problems if "parsing" not in p]
# Remove en-de BPE if paths are not given.
if not FLAGS.ende_bpe_path:
problems = [p for p in problems if "ende_bpe" not in p]

if not problems:
problems_str = "\n * ".join(
sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
error_msg = ("You must specify one of the supported problems to "
"generate data for:\n * " + problems_str + "\n")
error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
"--timit_paths, --ende_bpe_path and --parsing_path.")
error_msg += ("TIMIT and parsing need data_sets specified with "
"--timit_paths and --parsing_path.")
raise ValueError(error_msg)

if not FLAGS.data_dir:
Expand All @@ -203,34 +195,17 @@ def generate_data_for_problem(problem):
"""Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]

if isinstance(dev_gen, int):
# The dev set and test sets are generated as extra shards using the
# training generator. The integer specifies the number of training
# shards. FLAGS.num_shards is ignored.
num_training_shards = dev_gen
tf.logging.info("Generating data for %s.", problem)
all_output_files = generator_utils.combined_data_filenames(
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
num_training_shards)
generator_utils.generate_files(training_gen(), all_output_files,
FLAGS.max_cases)
else:
# usual case - train data and dev data are generated using separate
# generators.
num_shards = FLAGS.num_shards or 10
tf.logging.info("Generating training data for %s.", problem)
train_output_files = generator_utils.train_data_filenames(
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
generator_utils.generate_files(training_gen(), train_output_files,
FLAGS.max_cases)
tf.logging.info("Generating development data for %s.", problem)
dev_shards = 10 if "coco" in problem else 1
dev_output_files = generator_utils.dev_data_filenames(
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
generator_utils.generate_files(dev_gen(), dev_output_files)
all_output_files = train_output_files + dev_output_files

tf.logging.info("Shuffling data...")
num_shards = FLAGS.num_shards or 10
tf.logging.info("Generating training data for %s.", problem)
train_output_files = generator_utils.train_data_filenames(
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
generator_utils.generate_files(training_gen(), train_output_files,
FLAGS.max_cases)
tf.logging.info("Generating development data for %s.", problem)
dev_output_files = generator_utils.dev_data_filenames(
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1)
generator_utils.generate_files(dev_gen(), dev_output_files)
all_output_files = train_output_files + dev_output_files
generator_utils.shuffle_dataset(all_output_files)


Expand Down
Empty file modified tensor2tensor/bin/t2t-trainer
100755 → 100644
Empty file.
Loading

0 comments on commit 45a787e

Please sign in to comment.