Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #201 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.1.4
  • Loading branch information
lukaszkaiser authored Aug 2, 2017
2 parents 0df0f50 + 41bca68 commit c35c7a3
Show file tree
Hide file tree
Showing 43 changed files with 1,012 additions and 657 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ _pycache__/
# PyPI distribution artifacts.
build/
dist/
data/

# Sublime project files
*.sublime-project
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.1.3',
version='1.1.4',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
31 changes: 0 additions & 31 deletions tensor2tensor/bin/t2t-datagen
Original file line number Diff line number Diff line change
Expand Up @@ -118,40 +118,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: wiki.generator(FLAGS.tmp_dir, True),
1000
),
"image_mnist_tune": (
lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
"image_mnist_test": (
lambda: image.mnist_generator(FLAGS.tmp_dir, True, 60000),
lambda: image.mnist_generator(FLAGS.tmp_dir, False, 10000)),
"image_cifar10_tune": (
lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 48000),
lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 2000, 48000)),
"image_cifar10_test": (
lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
"image_mscoco_characters_test": (
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 80000),
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 40000)),
"image_celeba_tune": (
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
"image_mscoco_tokens_8k_test": (
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)),
"image_mscoco_tokens_32k_test": (
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
lambda: image.mscoco_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
"snli_32k": (
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ for an example.

`Problem`s support data generation, training, and decoding.

Data generation is handles by `Problem.generate_data` which should produce 2
Data generation is handled by `Problem.generate_data` which should produce 2
datasets, training and dev, which should be named according to
`Problem.training_filepaths` and `Problem.dev_filepaths`.
`Problem.generate_data` should also produce any other files that may be required
Expand Down
3 changes: 2 additions & 1 deletion tensor2tensor/data_generators/gene_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,9 @@ def example_reading_spec(self):
data_items_to_decoders = None
return (data_fields, data_items_to_decoders)

def preprocess_examples(self, examples, mode):
def preprocess_examples(self, examples, mode, hparams):
del mode
del hparams

# Reshape targets
examples["targets"] = tf.reshape(examples["targets"],
Expand Down
51 changes: 51 additions & 0 deletions tensor2tensor/data_generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# Dependency imports

import requests
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import six.moves.urllib_request as urllib # Imports urllib on Python2, urllib.request on Python3
Expand Down Expand Up @@ -196,6 +197,56 @@ def maybe_download(directory, filename, url):
return filepath


def maybe_download_from_drive(directory, filename, url):
"""Download filename from google drive unless it's already in directory.
Args:
directory: path to the directory that will be used.
filename: name of the file to download to (do nothing if it already exists).
url: URL to download from.
Returns:
The path to the downloaded file.
"""
if not tf.gfile.Exists(directory):
tf.logging.info("Creating directory %s" % directory)
os.mkdir(directory)
filepath = os.path.join(directory, filename)
confirm_token = None
if tf.gfile.Exists(filepath):
tf.logging.info("Not downloading, file already found: %s" % filepath)
return filepath

# Since the file is big, drive will scan it for virus and take it to a
# warning page. We find the confirm token on this page and append it to the
# URL to start the download process.
confirm_token = None
session = requests.Session()
response = session.get(url, stream=True)
for k, v in response.cookies.items():
if k.startswith("download_warning"):
confirm_token = v

if confirm_token:
url = url + "&confirm=" + confirm_token
tf.logging.info("Downloading %s to %s" % (url, filepath))

response = session.get(url, stream=True)
# Now begin the download.
chunk_size = 16 * 1024
with open(filepath, "wb") as f:
for chunk in response.iter_content(chunk_size):
if chunk:
f.write(chunk)

# Print newline to clear the carriage return from the download progress
print()
statinfo = os.stat(filepath)
tf.logging.info("Succesfully downloaded %s, %s bytes." % (filename,
statinfo.st_size))
return filepath


def gunzip_file(gz_path, new_path):
"""Unzips from gz_path into new_path.
Expand Down
14 changes: 14 additions & 0 deletions tensor2tensor/data_generators/generator_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ def testMaybeDownload(self):
os.remove(tmp_file_path + ".http")
os.remove(tmp_file_path)

def testMaybeDownloadFromDrive(self):
tmp_dir = self.get_temp_dir()
(_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
tmp_file_name = os.path.basename(tmp_file_path)

# Download Google index to the temporary file.http.
res_path = generator_utils.maybe_download_from_drive(
tmp_dir, tmp_file_name + ".http", "http://drive.google.com")
self.assertEqual(res_path, tmp_file_path + ".http")

# Clean up.
os.remove(tmp_file_path + ".http")
os.remove(tmp_file_path)

def testGunzipFile(self):
tmp_dir = self.get_temp_dir()
(_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
Expand Down
Loading

0 comments on commit c35c7a3

Please sign in to comment.