
Fix attention rng mismatch between forward and reverse direction
PiperOrigin-RevId: 272707157
T2T Team authored and copybara-github committed Oct 3, 2019
1 parent 9f29518 commit 176148c
Showing 4 changed files with 72 additions and 7 deletions.
3 changes: 2 additions & 1 deletion tensor2tensor/trax/layers/attention.py
@@ -344,7 +344,8 @@ def new_params_and_state(self, input_shape, input_dtype, rng):
 class BaseCausalAttention(base.Layer):
   """Base class for variants of causal self-attention."""
 
-  def __init__(self):
+  def __init__(self, mode='train'):
+    del mode
     super(BaseCausalAttention, self).__init__(n_inputs=3)
 
   def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
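Note on the change above: accepting (and immediately discarding) a `mode` argument keeps `BaseCausalAttention.__init__` signature-compatible with attention variants that do use `mode`, so model-construction code can instantiate any attention class the same way. A minimal sketch of how a subclass might consume the argument — the subclass and its dropout handling are hypothetical, not part of this commit:

class MyCausalAttention(BaseCausalAttention):
  """Hypothetical variant that actually uses the mode argument."""

  def __init__(self, dropout=0.1, mode='train'):
    super(MyCausalAttention, self).__init__(mode=mode)
    # Apply dropout only during training; eval/predict modes run deterministically.
    self._dropout = dropout if mode == 'train' else 0.0

# Model-building code can then treat every attention type uniformly, e.g.:
#   attention_layer = attention_type(mode=mode)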
8 changes: 4 additions & 4 deletions tensor2tensor/trax/layers/reversible.py
@@ -101,8 +101,8 @@ def reverse(self, output, params=(), state=(), **kwargs):
     rngs = backend.random.split(rng, self._n_layers)
 
     layer_val = output
-    for layer, p, s, rng in reversed(zip(self.sublayers,
-                                         params, state, rngs)):
+    for layer, p, s, rng in reversed(list(zip(self.sublayers,
+                                              params, state, rngs))):
       layer_val = layer.reverse(layer_val, p, s, rng=rng, **kwargs)
 
     return layer_val
@@ -116,8 +116,8 @@ def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs):
     layer_val = output
     layer_ct = ct
     params_ct = []
-    for layer, p, s, rng in reversed(zip(self.sublayers,
-                                         params, state, rngs)):
+    for layer, p, s, rng in reversed(list(zip(self.sublayers,
+                                              params, state, rngs))):
       layer_val, layer_ct = layer.reverse_and_grad(
           layer_val, layer_ct, p, s, rng=rng, **kwargs)
       layer_ct, p_ct = layer_ct
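The `list()` wrapper in both hunks above is what makes the reverse-direction loop work under Python 3: there `zip()` returns a one-shot iterator, and `reversed()` requires a sequence with a known length, so `reversed(zip(...))` raises a TypeError. A standalone illustration in plain Python, independent of trax:

layers, params, state, rngs = ['l0', 'l1'], [1, 2], [(), ()], ['rng0', 'rng1']

try:
  reversed(zip(layers, params, state, rngs))
except TypeError as e:
  print(e)  # zip objects are not reversible, so this branch runs under Python 3

# Materializing the zip first allows iterating the sublayers back to front,
# with each sublayer still paired with the same params/state/rng it saw in
# the forward pass:
for layer, p, s, rng in reversed(list(zip(layers, params, state, rngs))):
  print(layer, p, s, rng)  # visits ('l1', 2, (), 'rng1') first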
8 changes: 6 additions & 2 deletions tensor2tensor/trax/models/research/reformer.py
@@ -254,14 +254,18 @@ def __init__(self, attention):
     super(ApplyAttentionWrapper, self).__init__(attention, [], [])
     self.attention = attention
 
-  def forward_and_backward(self, inputs, ct, **kwargs):
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.
     qkv = inputs[:3]
     passthrough = inputs[3:]
     out_ct = ct[0]
     passthrough_ct = ct[1:]
+    if rng is not None:
+      # Adjust RNG to match the forward pass.
+      rng = backend.random.split(rng, self._n_layers)[0]
 
-    out, qkv_ct = self.attention.forward_and_backward(qkv, out_ct, **kwargs)
+    out, qkv_ct = self.attention.forward_and_backward(
+        qkv, out_ct, rng=rng, **kwargs)
     return (out,) + passthrough, qkv_ct + passthrough_ct
 
 
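The RNG adjustment above works because key splitting in the JAX-style backend is deterministic: splitting the same key always produces the same subkeys, so the backward pass can re-derive exactly the rng that the wrapped attention layer received during the forward pass (per the code comment, `[0]` picks the subkey handed to the first sublayer). A standalone illustration using `jax.random` directly rather than trax's `backend.random` wrapper:

import jax

rng = jax.random.PRNGKey(0)
n_layers = 3

forward_subkeys = jax.random.split(rng, n_layers)    # subkeys handed out in the forward pass
reverse_subkey = jax.random.split(rng, n_layers)[0]  # re-derived during the backward pass

assert (forward_subkeys[0] == reverse_subkey).all()
# Any rng-dependent computation inside the attention layer (e.g. dropout) thus
# sees identical randomness in both directions, which reversible layers require.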
60 changes: 60 additions & 0 deletions tensor2tensor/trax/models/research/reformer_test.py
@@ -21,10 +21,43 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
+import numpy as onp
 
 from tensor2tensor.trax import backend
 from tensor2tensor.trax import layers as tl
 from tensor2tensor.trax.backend import numpy as np
 from tensor2tensor.trax.models.research import reformer
 
+
+class PoisonOnRNGMismatchAttention(tl.BaseCausalAttention):
+  """Fills gradients with NaNs if reverse rng does not match forward rng."""
+
+  # pylint: disable=protected-access
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
+    assert backend.get_name() == 'jax', (
+        'JAX backend is required to use forward_and_backward.')
+
+    if ct is not None and tl.Layer._STASH_OUT is not None:
+      recovered_rng = tl.Layer._STASH_OUT.pop(self)
+      is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
+      is_same = is_same.astype(np.float32)
+      # Divides by zero if rngs are not the same, which results in NaNs.
+      inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)
+
+    def _do_forward(x):  # pylint: disable=invalid-name
+      res, _ = self.forward(x, rng=rng, **kwargs)
+      return res
+    output, vjpfun = jax.vjp(_do_forward, inputs)
+    return output, vjpfun(ct)[0]
+
+  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
+    if tl.Layer._STASH_IN is not None:
+      tl.Layer._STASH_IN[self] = rng
+    return inputs[2], state
+  # pylint: enable=protected-access
+
 
 class ReformerTest(parameterized.TestCase):
 
   def test_reformer_lm_forward_shape(self):
@@ -39,6 +72,33 @@ def test_reformer_lm_forward_shape(self):
         model, tuple(input_shape), integer_inputs=True)
     self.assertEqual(((1, 8, 16), (1, 8, 16)), final_shape)
 
+  def test_reformer_rng_consistency(self):
+    with backend.use_backend('jax'):
+      vocab_size = 16
+      batch_size = 1
+      input_shape = ((batch_size, 8), (batch_size, 8))
+      model = reformer.ReformerLM(
+          vocab_size, d_model=32, d_ff=64,
+          d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
+          max_len=16, n_chunks=2, n_attention_chunks=1, mode='train',
+          attention_type=PoisonOnRNGMismatchAttention)
+
+      rng = backend.random.get_prng(0)
+      params, state = model.initialize_once(
+          input_shape, (np.int32, np.int32), rng)
+
+      def dummy_loss_fn(params):
+        inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
+        output = model(inputs, params=params, state=state, rng=rng)
+        dummy_loss = backend.numpy.sum(output[0])
+        return dummy_loss
+
+      grad_fn = backend.grad(dummy_loss_fn)
+      grads = grad_fn(params)
+      # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
+      for grad in jax.tree_util.tree_leaves(grads):
+        assert onp.all(onp.isfinite(grad))
+
 
 if __name__ == '__main__':
   absltest.main()
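For readers unfamiliar with the `jax.vjp` call that `PoisonOnRNGMismatchAttention.forward_and_backward` relies on: `jax.vjp` runs the forward computation once and returns its output together with a pullback function that maps an output cotangent back to input cotangents, which is how the test layer produces a forward result and gradients in a single call. A standalone sketch, independent of the test above:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

x = jnp.arange(3, dtype=jnp.float32)
out, pullback = jax.vjp(f, x)           # forward output plus a backward function
(x_ct,) = pullback(jnp.ones_like(out))  # input cotangent for an all-ones output cotangent
# x_ct equals 2 * cos(x) elementwise, matching what jax.grad of sum(f(x)) would give.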
