Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #449 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.3
  • Loading branch information
lukaszkaiser authored Nov 29, 2017
2 parents 92983ea + 69701e4 commit e3cd447
Show file tree
Hide file tree
Showing 59 changed files with 2,656 additions and 2,105 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.2.9',
version='1.3.0',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/bin/t2t-datagen
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ _SUPPORTED_PROBLEM_GENERATORS = {
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
lambda: audio.timit_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
}

# pylint: enable=g-long-lambda
Expand Down
6 changes: 4 additions & 2 deletions tensor2tensor/bin/t2t-decoder
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("output_dir", "", "Training directory to load from.")
flags.DEFINE_string("decode_from_file", None, "Path to the source file for decoding")
flags.DEFINE_string("decode_to_file", None, "Path to the decoded (output) file")
flags.DEFINE_string("decode_from_file", None,
"Path to the source file for decoding")
flags.DEFINE_string("decode_to_file", None,
"Path to the decoded (output) file")
flags.DEFINE_bool("decode_interactive", False,
"Interactive local inference mode.")
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
Expand Down
21 changes: 10 additions & 11 deletions tensor2tensor/bin/t2t-trainer
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ flags.DEFINE_string("schedule", "train_and_evaluate",
"Method of tf.contrib.learn.Experiment to run.")
flags.DEFINE_bool("profile", False, "Profile performance?")


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
Expand All @@ -85,22 +86,20 @@ def main(_):
# Run the trainer.
def run_experiment():
trainer_utils.run(
data_dir=data_dir,
model=FLAGS.model,
output_dir=output_dir,
train_steps=FLAGS.train_steps,
eval_steps=FLAGS.eval_steps,
schedule=FLAGS.schedule)
data_dir=data_dir,
model=FLAGS.model,
output_dir=output_dir,
train_steps=FLAGS.train_steps,
eval_steps=FLAGS.eval_steps,
schedule=FLAGS.schedule)

if FLAGS.profile:
with tf.contrib.tfprof.ProfileContext('t2tprof',
with tf.contrib.tfprof.ProfileContext("t2tprof",
trace_steps=range(100),
dump_steps=range(100)) as pctx:
opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
pctx.add_auto_profiling('op', opts, range(100))

pctx.add_auto_profiling("op", opts, range(100))
run_experiment()

else:
run_experiment()

Expand Down
28 changes: 19 additions & 9 deletions tensor2tensor/data_generators/cnn_dailymail.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from __future__ import division
from __future__ import print_function

import hashlib
import io
import os
import tarfile
import hashlib

# Dependency imports

Expand Down Expand Up @@ -129,7 +129,9 @@ def generate_hash(inp):

return filelist


def example_generator(all_files, urls_path, sum_token):
"""Generate examples."""
def fix_run_on_sents(line):
if u"@highlight" in line:
return line
Expand Down Expand Up @@ -168,30 +170,37 @@ def fix_run_on_sents(line):

yield " ".join(story) + story_summary_split_token + " ".join(summary)


def _story_summary_split(story):
split_str = u" <summary> "
split_str_len = len(split_str)
split_pos = story.find(split_str)
return story[:split_pos], story[split_pos+split_str_len:] # story, summary

def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir,
                            is_training):
  """Write the raw stories and summaries to parallel text files.

  Produces <prefix>.source (stories) and <prefix>.target (summaries) in
  data_dir, one example per line. The prefix is "cnndm.train" for the
  training split, otherwise "cnndm.dev"; the non-training call also writes
  the "cnndm.test" split after downloading its URL list.

  Args:
    all_files: list of paths to the story files.
    urls_path: path to the file listing the URLs of this split.
    data_dir: directory where the text files are written.
    tmp_dir: temporary directory used for the test-URL download.
    is_training: bool, whether this is the training split.
  """
  def write_to_file(all_files, urls_path, data_dir, filename):
    # Story on line i of .source pairs with its summary on line i of .target.
    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory:
      with io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
        for example in example_generator(all_files, urls_path, sum_token=True):
          story, summary = _story_summary_split(example)
          fstory.write(story+"\n")
          fsummary.write(summary+"\n")

  filename = "cnndm.train" if is_training else "cnndm.dev"
  tf.logging.info("Writing %s" % filename)
  write_to_file(all_files, urls_path, data_dir, filename)

  if not is_training:
    # The dev pass also materializes the test split.
    test_urls_path = generator_utils.maybe_download(
        tmp_dir, "all_test.txt", _TEST_URLS)
    filename = "cnndm.test"
    tf.logging.info("Writing %s" % filename)
    write_to_file(all_files, test_urls_path, data_dir, filename)


@registry.register_problem
class SummarizeCnnDailymail32k(problem.Text2TextProblem):
"""Summarize CNN and Daily Mail articles to their summary highlights."""
Expand Down Expand Up @@ -237,7 +246,8 @@ def generator(self, data_dir, tmp_dir, is_training):
encoder = generator_utils.get_or_generate_vocab_inner(
data_dir, self.vocab_file, self.targeted_vocab_size,
example_generator(all_files, urls_path, sum_token=False))
write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir,
is_training)
for example in example_generator(all_files, urls_path, sum_token=True):
story, summary = _story_summary_split(example)
encoded_summary = encoder.encode(summary) + [EOS]
Expand Down
65 changes: 65 additions & 0 deletions tensor2tensor/data_generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,68 @@ def shuffle_dataset(filenames):
out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
write_records(records, out_fname)
tf.gfile.Remove(fname)


def combine_examples_no_inputs(examples, max_length):
  """Combine examples into longer examples.

  Concatenate targets to form target sequences with length up to max_length.
  Target sequences longer than max_length are chopped into multiple sequences.

  Args:
    examples: a generator returning feature dictionaries.
    max_length: an integer.

  Yields:
    feature dictionaries.
  """
  partial = []
  for example in examples:
    x = example["targets"]
    if len(x) + len(partial) > max_length:
      # Appending x would overflow: flush what is buffered so far.
      if partial:
        yield {"inputs": [0], "targets": partial}
        partial = []
    if len(x) > max_length:
      # Chop an over-long example into max_length-sized fragments; the
      # remainder (if any) seeds the next combined example.
      num_fragments = len(x) // max_length
      # range (not xrange) so the code also runs on Python 3.
      for i in range(num_fragments):
        yield {"inputs": [0],
               "targets": x[max_length * i:max_length * (i + 1)]}
      partial = x[max_length * num_fragments:]
    else:
      partial += x
  if partial:
    # Flush the final buffered example.
    yield {"inputs": [0], "targets": partial}


def combine_examples_with_inputs(examples, max_length):
  """Combine examples into longer examples.

  We combine multiple examples by concatenating the inputs and concatenating
  the targets. Sequences where the inputs or the targets are too long are
  emitted as singletons (not chopped).

  Args:
    examples: a generator returning feature dictionaries.
    max_length: an integer.

  Yields:
    feature dictionaries.
  """
  buf_inputs = []
  buf_targets = []
  for example in examples:
    inputs = example["inputs"]
    targets = example["targets"]
    # Would appending this example push either side past the cap?
    overflow = (len(inputs) + len(buf_inputs) > max_length or
                len(targets) + len(buf_targets) > max_length)
    if overflow and (buf_inputs or buf_targets):
      # Flush the buffered combined example before handling this one.
      yield {"inputs": buf_inputs, "targets": buf_targets}
      buf_inputs, buf_targets = [], []
    if max(len(inputs), len(targets)) > max_length:
      # Too long to ever fit in a combined example: emit as a singleton.
      yield {"inputs": inputs, "targets": targets}
    else:
      buf_inputs = buf_inputs + inputs
      buf_targets = buf_targets + targets
  if buf_inputs or buf_targets:
    # Flush whatever remains buffered at the end of the stream.
    yield {"inputs": buf_inputs, "targets": buf_targets}
Loading

0 comments on commit e3cd447

Please sign in to comment.