Merge pull request #54 from rsepassi/push
Push v1.0.8
lukaszkaiser authored Jun 27, 2017
2 parents 49f5afc + 8226f15 commit de38b16
Showing 30 changed files with 156 additions and 143 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -4,6 +4,6 @@
# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info

# PyPI distribution artifacts
# PyPI distribution artificats
build/
dist/
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.7',
version='1.0.8',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
3 changes: 2 additions & 1 deletion tensor2tensor/bin/make_tf_configs.py
@@ -55,7 +55,7 @@ def main(_):
for idx, job in enumerate(jobs):
if task_type == "worker":
cmd_line_flags = " ".join([
"--master=%s" % job,
"--master=grpc://%s" % job,
"--ps_replicas=%d" % len(ps),
"--worker_replicas=%d" % len(workers),
"--worker_gpu=1",
@@ -66,6 +66,7 @@ def main(_):
])
else:
cmd_line_flags = " ".join([
"--master=grpc://%s" % job,
"--schedule=run_std_server",
])

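For context, a minimal sketch of the flag strings this loop now produces, assuming one worker and one parameter server at illustrative addresses (the addresses and the single-GPU setting are examples, not values from the commit):

```
# Sketch only: mirrors the flag construction in make_tf_configs.py above,
# using hypothetical job addresses.
workers = ["10.0.0.1:8000"]
ps = ["10.0.0.2:8000"]

for job in workers:
  worker_flags = " ".join([
      "--master=grpc://%s" % job,  # gRPC prefix added by this commit
      "--ps_replicas=%d" % len(ps),
      "--worker_replicas=%d" % len(workers),
      "--worker_gpu=1",
  ])
  print(worker_flags)
  # --master=grpc://10.0.0.1:8000 --ps_replicas=1 --worker_replicas=1 --worker_gpu=1

for job in ps:
  ps_flags = " ".join([
      "--master=grpc://%s" % job,  # now also set for parameter servers
      "--schedule=run_std_server",
  ])
  print(ps_flags)
  # --master=grpc://10.0.0.2:8000 --schedule=run_std_server
```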
14 changes: 7 additions & 7 deletions tensor2tensor/data_generators/algorithmic_math.py
@@ -570,16 +570,16 @@ def calculus_integrate(alphabet_size=26,

functions = {"log": "L"}
alg_cfg = math_dataset_init(alphabet_size, digits=5, functions=functions)
nbr_case=0
nbr_case = 0
while nbr_case < nbr_cases:
try:
sample, target = generate_calculus_integrate_sample(
alg_cfg.vlist,
list(alg_cfg.ops.values()), min_depth, max_depth, alg_cfg.functions)
alg_cfg.vlist,
list(alg_cfg.ops.values()), min_depth, max_depth, alg_cfg.functions)
yield {
"inputs": alg_cfg.int_encoder(sample),
"targets": alg_cfg.int_encoder(target)
"inputs": alg_cfg.int_encoder(sample),
"targets": alg_cfg.int_encoder(target)
}
except:
except: # pylint:disable=bare-except
continue
nbr_case = nbr_case + 1
nbr_case += 1
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/generator_utils.py
File mode changed 100755 → 100644
@@ -27,7 +27,7 @@

import six
from six.moves import xrange # pylint: disable=redefined-builtin
import six.moves.urllib_request as urllib # Imports urllib on Python2, urllib.request on Python3
import six.moves.urllib_request as urllib # Imports urllib on Python2, urllib.request on Python3

from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
from tensor2tensor.data_generators.tokenizer import Tokenizer
2 changes: 2 additions & 0 deletions tensor2tensor/data_generators/wmt.py
File mode changed 100755 → 100644
@@ -28,6 +28,7 @@

import tensorflow as tf


# End-of-sentence marker (should correspond to the position of EOS in the
# RESERVED_TOKENS list in text_encoder.py)
EOS = 1
@@ -44,6 +45,7 @@ def character_generator(source_path, target_path, character_vocab, eos=None):
Args:
source_path: path to the file with source sentences.
target_path: path to the file with target sentences.
character_vocab: a TextEncoder to encode the characters.
eos: integer to append at the end of each sequence (default: None).
Yields:
11 changes: 5 additions & 6 deletions tensor2tensor/data_generators/wmt_test.py
File mode changed 100755 → 100644
@@ -25,8 +25,8 @@
# Dependency imports

import six
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import wmt

import tensorflow as tf

@@ -40,7 +40,7 @@ def testCharacterGenerator(self):
if six.PY2:
enc_f = lambda s: s
else:
enc_f = lambda s: s.encode('utf-8')
enc_f = lambda s: s.encode("utf-8")
with io.open(tmp_file_path + ".src", "wb") as src_file:
src_file.write(enc_f("source1\n"))
src_file.write(enc_f("source2\n"))
@@ -51,16 +51,15 @@ def testCharacterGenerator(self):
# Call character generator on the generated files.
results_src, results_tgt = [], []
character_vocab = text_encoder.ByteTextEncoder()
for dictionary in wmt.character_generator(tmp_file_path + ".src",
tmp_file_path + ".tgt",
character_vocab):
for dictionary in wmt.character_generator(
tmp_file_path + ".src", tmp_file_path + ".tgt", character_vocab):
self.assertEqual(sorted(list(dictionary)), ["inputs", "targets"])
results_src.append(dictionary["inputs"])
results_tgt.append(dictionary["targets"])

# Check that the results match the files.
# First check that the results match the encoded original strings;
# this is a comparison of integer arrays
# this is a comparison of integer arrays.
self.assertEqual(len(results_src), 2)
self.assertEqual(results_src[0],
character_vocab.encode("source1"))
13 changes: 12 additions & 1 deletion tensor2tensor/docs/distributed_training.md
@@ -35,7 +35,7 @@ os.environ['TF_CONFIG'] = json.dumps({
The following T2T command-line flags must also be set on the workers for
distributed training:

- `--master=$ADDRESS`
- `--master=grpc://$ADDRESS`
- `--worker_replicas=$NUM_WORKERS`
- `--worker_gpu=$NUM_GPUS_PER_WORKER`
- `--worker_id=$WORKER_ID`
@@ -55,6 +55,17 @@ Parameter servers only need `--schedule=run_std_server`.
generates the `TF_CONFIG` json strings and the above-mentioned command-line
flags for the workers and parameter servers.

Given a set of worker and parameter server addresses, the script outputs, for
each job, a line with the `TF_CONFIG` environment variable and the command-line
flags necessary for distributed training. For each job, you should invoke the
`t2t-trainer` with the `TF_CONFIG` value and flags that are output.

For example:

```
TF_CONFIG=$JOB_TF_CONFIG t2t-trainer $JOB_FLAGS --model=transformer ...
```
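
For reference, a sketch of the kind of `TF_CONFIG` value a single worker might receive; the cluster addresses and task index below are illustrative assumptions, not output of the `make_tf_configs.py` script:

```
# Illustrative cluster with two workers and one parameter server; the
# addresses and index are assumptions made for the sake of the example.
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "ps": ["10.0.0.3:8000"],
        "worker": ["10.0.0.1:8000", "10.0.0.2:8000"],
    },
    "task": {"type": "worker", "index": 0},  # this process is worker 0
})
```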

## Command-line flags for eval jobs

Eval jobs should set the following flags and do not need the `TF_CONFIG`
14 changes: 5 additions & 9 deletions tensor2tensor/models/attention_lm.py
@@ -24,8 +24,6 @@
from __future__ import division
from __future__ import print_function

import copy

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin
@@ -43,13 +41,9 @@
class AttentionLM(t2t_model.T2TModel):
"""Attention net. See file docstring."""

def model_fn_body(self, features, train):
def model_fn_body(self, features):
# Remove dropout if not training
hparams = copy.copy(self._hparams)
if not train:
hparams.attention_dropout = 0.
hparams.relu_dropout = 0.
hparams.residual_dropout = 0.
hparams = self._hparams
targets = features["targets"]
targets = tf.squeeze(targets, 2)

@@ -162,8 +156,10 @@ def attention_lm_base():
hparams.add_hparam("num_heads", 8)
hparams.add_hparam("attention_key_channels", 0)
hparams.add_hparam("attention_value_channels", 0)
# All hyperparameters ending in "dropout" are automatically set to 0.0
# when not in training mode.
hparams.add_hparam("attention_dropout", 0.0)
hparams.add_hparam("relu_dropout", 0.0)
hparams.add_hparam("pos", "timing") # timing, none
hparams.add_hparam("residual_dropout", 0.1)
hparams.add_hparam("pos", "timing") # timing, none
return hparams
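
The removed dropout-zeroing block in `model_fn_body` is replaced by the convention described in the new comment above. A minimal sketch of that convention, assuming an hparams object whose fields are plain attributes; this illustrates the described behavior and is not code from the commit:

```
def zero_dropout_hparams(hparams, is_training):
  """Sketch: reset *_dropout hyperparameters to 0.0 outside of training."""
  if is_training:
    return hparams
  for name in ("attention_dropout", "relu_dropout", "residual_dropout"):
    # HParams-style objects store values as plain attributes, so setattr
    # is enough for this illustration.
    if hasattr(hparams, name):
      setattr(hparams, name, 0.0)
  return hparams
```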
18 changes: 8 additions & 10 deletions tensor2tensor/models/attention_lm_moe.py
@@ -24,8 +24,6 @@
from __future__ import division
from __future__ import print_function

import copy

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin
@@ -43,13 +41,9 @@
class AttentionLmMoe(t2t_model.T2TModel):
"""Attention net. See file docstring."""

def model_fn_body_sharded(self, sharded_features, train):
def model_fn_body_sharded(self, sharded_features):
# Remove dropout if not training
hparams = copy.copy(self._hparams)
if not train:
hparams.attention_dropout = 0.
hparams.relu_dropout = 0.
hparams.residual_dropout = 0.
hparams = self._hparams
dp = self._data_parallelism
targets = sharded_features["targets"]
targets = dp(tf.squeeze, targets, 2)
@@ -81,7 +75,9 @@ def residual_fn(x, y):
with tf.variable_scope("ffn"):
if str(layer) in hparams.moe_layers.split(","):
y, loss = common_layers.moe_layer(
dp, self._ps_devices, x, train, hparams.hidden_size,
dp, self._ps_devices, x,
hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
hparams.hidden_size,
hparams.moe_hidden_size, hparams.moe_n1, hparams.moe_n2,
hparams.moe_loss_coef)
extra_loss += loss
@@ -162,10 +158,12 @@ def attention_lm_moe_base():
hparams.add_hparam("num_heads", 8)
hparams.add_hparam("attention_key_channels", 0)
hparams.add_hparam("attention_value_channels", 0)
# All hyperparameters ending in "dropout" are automatically set to 0.0
# when not in training mode.
hparams.add_hparam("attention_dropout", 0.0)
hparams.add_hparam("relu_dropout", 0.0)
hparams.add_hparam("pos", "timing") # timing, none
hparams.add_hparam("residual_dropout", 0.1)
hparams.add_hparam("pos", "timing") # timing, none
return hparams


43 changes: 21 additions & 22 deletions tensor2tensor/models/bluenet.py
@@ -30,7 +30,7 @@
import tensorflow as tf


def residual_module(x, hparams, train, n, sep):
def residual_module(x, hparams, n, sep):
"""A stack of convolution blocks with residual connection."""
k = (hparams.kernel_height, hparams.kernel_width)
dilations_and_kernels = [((1, 1), k) for _ in xrange(n)]
@@ -43,56 +43,55 @@ def residual_module(x, hparams, train, n, sep):
separability=sep,
name="block")
x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
return tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train))
return tf.nn.dropout(x, 1.0 - hparams.dropout)


def residual_module1(x, hparams, train):
return residual_module(x, hparams, train, 1, 1)
def residual_module1(x, hparams):
return residual_module(x, hparams, 1, 1)


def residual_module1_sep(x, hparams, train):
return residual_module(x, hparams, train, 1, 0)
def residual_module1_sep(x, hparams):
return residual_module(x, hparams, 1, 0)


def residual_module2(x, hparams, train):
return residual_module(x, hparams, train, 2, 1)
def residual_module2(x, hparams):
return residual_module(x, hparams, 2, 1)


def residual_module2_sep(x, hparams, train):
return residual_module(x, hparams, train, 2, 0)
def residual_module2_sep(x, hparams):
return residual_module(x, hparams, 2, 0)


def residual_module3(x, hparams, train):
return residual_module(x, hparams, train, 3, 1)
def residual_module3(x, hparams):
return residual_module(x, hparams, 3, 1)


def residual_module3_sep(x, hparams, train):
return residual_module(x, hparams, train, 3, 0)
def residual_module3_sep(x, hparams):
return residual_module(x, hparams, 3, 0)


def norm_module(x, hparams, train):
del train # Unused.
def norm_module(x, hparams):
return common_layers.layer_norm(x, hparams.hidden_size, name="norm_module")


def identity_module(x, hparams, train):
del hparams, train # Unused.
def identity_module(x, hparams):
del hparams # Unused.
return x


def run_modules(blocks, cur, hparams, train, dp):
def run_modules(blocks, cur, hparams, dp):
"""Run blocks in parallel using dp as data_parallelism."""
assert len(blocks) % dp.n == 0
res = []
for i in xrange(len(blocks) // dp.n):
res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams, train))
res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams))
return res


@registry.register_model
class BlueNet(t2t_model.T2TModel):

def model_fn_body_sharded(self, sharded_features, train):
def model_fn_body_sharded(self, sharded_features):
dp = self._data_parallelism
dp._reuse = False # pylint:disable=protected-access
hparams = self._hparams
@@ -106,7 +105,7 @@ def model_fn_body_sharded(self, sharded_features, train):
cur_shape = cur.get_shape()
for i in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % i):
processed = run_modules(blocks, cur, hparams, train, dp)
processed = run_modules(blocks, cur, hparams, dp)
cur = common_layers.shakeshake(processed)
cur.set_shape(cur_shape)

5 changes: 3 additions & 2 deletions tensor2tensor/models/bluenet_test.py
@@ -42,8 +42,9 @@ def testBlueNet(self):
"inputs": tf.constant(x, dtype=tf.int32),
"targets": tf.constant(y, dtype=tf.int32),
}
model = bluenet.BlueNet(hparams, p_hparams)
sharded_logits, _, _ = model.model_fn(features, True)
model = bluenet.BlueNet(
hparams, tf.contrib.learn.ModeKeys.TRAIN, p_hparams)
sharded_logits, _, _ = model.model_fn(features)
logits = tf.concat(sharded_logits, 0)
session.run(tf.global_variables_initializer())
res = session.run(logits)
14 changes: 7 additions & 7 deletions tensor2tensor/models/bytenet.py
@@ -30,7 +30,7 @@
import tensorflow as tf


def residual_dilated_conv(x, repeat, padding, name, hparams, train):
def residual_dilated_conv(x, repeat, padding, name, hparams):
"""A stack of convolution blocks with residual connections."""
with tf.variable_scope(name):
k = (hparams.kernel_height, hparams.kernel_width)
@@ -45,11 +45,11 @@ def residual_dilated_conv(x, repeat, padding, name, hparams, train):
padding=padding,
name="residual_conv")
x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
x = tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train))
x = tf.nn.dropout(x, hparams.dropout)
return x


def bytenet_internal(inputs, targets, hparams, train):
def bytenet_internal(inputs, targets, hparams):
"""ByteNet, main step used for training."""
with tf.variable_scope("bytenet"):
# Flatten inputs and extend length by 50%.
@@ -63,7 +63,7 @@ def bytenet_internal(inputs, targets, hparams, train):
inputs, targets = common_layers.pad_to_same_length(
inputs, targets, final_length_divisible_by=50)
final_encoder = residual_dilated_conv(
inputs, hparams.num_block_repeat, "SAME", "encoder", hparams, train)
inputs, hparams.num_block_repeat, "SAME", "encoder", hparams)

shifted_targets = common_layers.shift_left(targets)
kernel = (hparams.kernel_height, hparams.kernel_width)
@@ -74,15 +74,15 @@ def bytenet_internal(inputs, targets, hparams, train):

return residual_dilated_conv(
decoder_start, hparams.num_block_repeat,
"LEFT", "decoder", hparams, train)
"LEFT", "decoder", hparams)


@registry.register_model
class ByteNet(t2t_model.T2TModel):

def model_fn_body(self, features, train):
def model_fn_body(self, features):
return bytenet_internal(features["inputs"], features["targets"],
self._hparams, train)
self._hparams)


@registry.register_hparams