Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
TPU compatibility for data chunking
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 237326592
  • Loading branch information
T2T Team authored and Copybara-Service committed Mar 7, 2019
1 parent 6671139 commit 83d98cd
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tensor2tensor/utils/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,17 +490,27 @@ def is_nonzero_chunk(example):
def split_on_length(example):
"""Split a batch of dicts on length."""
x = example["targets"]
# TODO(kitaev): This code breaks if chunk_length * max_chunks < batch_size
length_diff = chunk_length * max_chunks - tf.shape(x)[1]
padded_x = tf.pad(x, [(0, 0), (0, length_diff), (0, 0), (0, 0)])
chunks = [padded_x[:, i*chunk_length:(i+1)*chunk_length, :, :]
for i in range(max_chunks - 1)]
chunks.append(padded_x[:, (max_chunks - 1)*chunk_length:, :, :])
new_example = {}
new_example["chunk_number"] = tf.range(max_chunks)
# Setting chunk_number to be tf.range(max_chunks) is incompatible with TPU
new_example["chunk_number"] = tf.concat([
tf.expand_dims(tf.ones_like(c) * n, axis=0)
for n, c in enumerate(chunks)
],
axis=0)
new_example["targets"] = tf.concat(
[tf.expand_dims(c, axis=0) for c in chunks], axis=0)
for k in example:
if k != "targets":
assert k != "chunk_number", (
"Chunking code expects the chunk_number feature name to be "
"available"
)
new_example[k] = tf.concat(
[tf.expand_dims(example[k], axis=0) for _ in range(max_chunks)],
axis=0)
Expand Down

0 comments on commit 83d98cd

Please sign in to comment.