Implement ring attention backward pass. More tests.
nshepperd committed Feb 28, 2024
1 parent 9f80518 commit bc9a01d
Showing 4 changed files with 177 additions and 81 deletions.
16 changes: 14 additions & 2 deletions README.md
@@ -11,10 +11,11 @@ Please cite (see below) and credit FlashAttention if you use it.
## Installation and features

Requirements:
- CUDA 11.6 and above.
- CUDA 11.8 and above.
- Linux. Same story as with the pytorch repo; I haven't tested compiling the jax bindings on Windows.
- JAX >= `0.4.24`. The custom sharding used for ring attention requires some fairly advanced features.

To install: TODO
To install: For now, download the appropriate release from the releases page and install it with pip.

Interface: `src/flash_attn_jax/flash.py`

@@ -28,6 +29,17 @@ Accepts q,k,v with shape `[n, l, h, d]`, and returns `[n, l, h, d]`. `softmax_scale` is the
multiplier for the softmax, defaulting to `1/sqrt(d)`. Set `window_size`
to positive values for sliding window attention.
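
For reference, a minimal call sketch (the array shapes and values here are illustrative, not taken from the repo):

```py
import jax.numpy as jnp
from flash_attn_jax import flash_mha

q = jnp.zeros([2, 4096, 8, 64], jnp.float16)  # [n, l, h, d]
k = jnp.zeros([2, 4096, 8, 64], jnp.float16)
v = jnp.zeros([2, 4096, 8, 64], jnp.float16)

out = flash_mha(q, k, v, is_causal=True)            # causal attention
local = flash_mha(q, k, v, window_size=(128, 128))  # sliding-window attention
```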

### Now Supports Ring Attention

Shard your tensors along the length dimension with `jax.Array` (e.g. via `NamedSharding`), and `flash_mha` will automatically use the ring attention algorithm:

```py
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

with Mesh(np.array(jax.devices()), axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len', None))  # n l d
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward as usual
```

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
2 changes: 1 addition & 1 deletion src/flash_attn_jax/__init__.py
@@ -1,2 +1,2 @@
from .flash import flash_mha
__version__ = 'v2.5.0'
__version__ = 'v2.5.5'
178 changes: 123 additions & 55 deletions src/flash_attn_jax/flash_sharding.py
@@ -30,53 +30,6 @@

from jax._src.ad_checkpoint import _optimization_barrier

def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
[n,l,h,d] = q.shape

q_ix = jax.lax.axis_index(axis_name)
k_ix = jax.lax.axis_index(axis_name)

o = jnp.zeros([n,l,h,d], jnp.float32)
lse = jnp.full([n,h,l], float('-inf'), jnp.float32)

# scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
def f(c, a):
(k, v, o, lse, k_ix) = c

o1, lse1 = o, lse
if is_causal:
o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
[
lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
], q, k, v)
else:
o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
o2 = o2.astype(jnp.float32)

mx = jnp.maximum(lse1,lse2)
mn = jnp.minimum(lse1,lse2)
lse = jnp.log1p(jnp.exp(mn-mx)) + mx

o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))

k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])

return ((k2, v2, o, lse, k_ix), None)
acc = (k,v,o,lse,k_ix)
# We sadly have to unroll this loop manually, because scan breaks the axis context and prevents us from using ppermute (unroll=axis_size doesn't help either).
# The optimization barrier prevents instruction reordering, so that ppermute and flash_mha can execute concurrently.
for _ in range(axis_size):
acc, _ = f(acc, None)
acc = _optimization_barrier(acc)
(_,_,o,lse,_) = acc
# (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
return o.astype(q.dtype), lse

def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
@@ -147,21 +100,136 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
o_sharding = arg_shardings[4]
lse_sharding = arg_shardings[5]
if isinstance(q_sharding, PositionalSharding):
do_sharding = q_sharding.replicate((1,3))
[n, l, h, d] = do_sharding.shape
lse_sharding = do_sharding.reshape(n,l,h).transpose(0,2,1) # n h l
result_shardings = (do_sharding,)*3
arg_shardings = (do_sharding,)*5 + (lse_sharding,)
assert q_sharding == k_sharding, "Expect q and k sharding to match"
assert q_sharding == v_sharding, "Expect q and v sharding to match"
[n, l, h, d] = q_sharding.shape
assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
lse_sharding = q_sharding.reshape(n,h,1) # n h l
result_shardings = (q_sharding,)*3
arg_shardings = (q_sharding,)*5 + (lse_sharding,)
elif isinstance(q_sharding, NamedSharding):
mesh = q_sharding.mesh
[n,l,h,d] = q_sharding.spec
do_sharding = NamedSharding(mesh, P(n,None,h,None))
lse_sharding = NamedSharding(mesh, P(n,h,None))
result_shardings = (do_sharding,)*3
assert d == None, "Sharding across `d` won't be efficient, so it's not supported."
if l != None:
# assert not is_causal and window_size == (-1,-1), "Ring attention doesn't support causal or local masking yet."
assert window_size == (-1,-1), "Ring attention doesn't support local masking yet."
result_shardings = q_sharding, q_sharding, q_sharding
lse_sharding = NamedSharding(mesh, P(n,h,l))
arg_shardings = (q_sharding,)*5 + (lse_sharding,)
axis_name = l
axis_size = mesh.shape[axis_name]
# ring attention
return mesh, partial(ring_bwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
else:
result_shardings = q_sharding, q_sharding, q_sharding
lse_sharding = NamedSharding(mesh, P(n,h,l))
arg_shardings = (q_sharding,)*5 + (lse_sharding,)
def fwd(*args):
return _flash_mha_bwd_hlo(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
return mesh, fwd, result_shardings, arg_shardings

_flash_mha_bwd_hlo_sharded.def_partition(
infer_sharding_from_operands=infer_sharding_bwd,
partition=partition_bwd)
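
For context, the sharded kernels above are registered through JAX's custom-partitioning mechanism (`jax.experimental.custom_partitioning`). Below is a minimal, self-contained sketch of that registration pattern on a toy elementwise-add function; the names `sharded_add`, `infer_sharding`, `partition`, and `add_shard` are illustrative and not part of this repo.

```py
import jax
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def sharded_add(x, y):
    return x + y  # reference (unpartitioned) implementation

def infer_sharding(mesh, arg_shapes, result_shape):
    # result follows x's sharding
    return arg_shapes[0].sharding

def partition(mesh, arg_shapes, result_shape):
    x_sharding = arg_shapes[0].sharding
    # reshard both inputs to x's sharding so the per-shard add lines up elementwise
    arg_shardings = (x_sharding, x_sharding)
    def add_shard(x, y):  # runs on the local shard of each device
        return x + y
    return mesh, add_shard, x_sharding, arg_shardings

sharded_add.def_partition(
    infer_sharding_from_operands=infer_sharding,
    partition=partition)
```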

# ==== Ring Forward ====

def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
[n,l,h,d] = q.shape

q_ix = jax.lax.axis_index(axis_name)
k_ix = jax.lax.axis_index(axis_name)

o = jnp.zeros([n,l,h,d], jnp.float32)
lse = jnp.full([n,h,l], float('-inf'), jnp.float32)

# scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
def f(c, a):
(k, v, o, lse, k_ix) = c

o1, lse1 = o, lse
if is_causal:
o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
[
lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
], q, k, v)
else:
o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
o2 = o2.astype(jnp.float32)

mx = jnp.maximum(lse1,lse2)
mn = jnp.minimum(lse1,lse2)
lse = jnp.log1p(jnp.exp(mn-mx)) + mx

o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))

k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])

return ((k2, v2, o, lse, k_ix), None)
acc = (k,v,o,lse,k_ix)
# We sadly have to unroll this loop manually, because scan breaks the axis context and prevents us from using ppermute (unroll=axis_size doesn't help either).
# The optimization barrier prevents instruction reordering, so that ppermute and flash_mha can execute concurrently.
for _ in range(axis_size):
acc, _ = f(acc, None)
acc = _optimization_barrier(acc)
(_,_,o,lse,_) = acc
# (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
return o.astype(q.dtype), lse
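
The `lse`/`o` update in the loop above is the standard online-softmax merge: `lse` accumulates the log-sum-exp over the key blocks seen so far, and the partial outputs are reweighted by `exp(lse_i - lse)`. A quick standalone check of the log-sum-exp identity used there (values arbitrary):

```py
import jax.numpy as jnp

lse1, lse2 = jnp.float32(2.0), jnp.float32(5.0)
mx, mn = jnp.maximum(lse1, lse2), jnp.minimum(lse1, lse2)
merged = jnp.log1p(jnp.exp(mn - mx)) + mx  # same form as in ring_fwd
assert jnp.allclose(merged, jnp.logaddexp(lse1, lse2))
```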

# ==== Ring Backward ====

# This doesn't seem like the most efficient way to do this: we waste compute by calculating every dq,dk,dv twice.
# Should we instead send the dk,dv accumulators across devices, relying on the fact that after a full cycle they return to the starting device?
def ring_bwd(softmax_scale, is_causal, axis_name, axis_size, do,q,k,v,o,lse):
[n,l,h,d] = q.shape

ix = jax.lax.axis_index(axis_name)

dq = jnp.zeros([n,l,h,d], jnp.float32)
dk = jnp.zeros([n,l,h,d], jnp.float32)
dv = jnp.zeros([n,l,h,d], jnp.float32)

# scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
def f(acc, a):
(do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc

cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
# 0: ix < ix2
# 1: ix = ix2
# 2: ix > ix2
if is_causal:
dqa = jax.lax.switch(cmp, [
lambda q,k,v: jnp.zeros([n,l,h,d], q.dtype),
lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[0],
lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[0],
], q, k, v)
dka,dva = jax.lax.switch(cmp, [
lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[1:],
lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[1:],
lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype),jnp.zeros([n,l,h,d], q.dtype)),
], q, k, v)
else:
dqa,_,_ = _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
_,dka,dva = _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))

dq += dqa
dk += dka
dv += dva

(do2,q2,k2,v2,o2,lse2,ix2) = jax.lax.ppermute((do2,q2,k2,v2,o2,lse2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])

return ((do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv), None)
acc = (do,q,k,v,o,lse,ix,dq,dk,dv)
# Unrolled as above.
for _ in range(axis_size):
acc, _ = f(acc, None)
acc = _optimization_barrier(acc)
(do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
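
Both the forward and backward ring passes classify the remote block with the same `(k_ix < q_ix) + (k_ix <= q_ix)` trick before choosing a kernel in `jax.lax.switch`. A small standalone sketch of that mapping (the helper name `block_cmp` is hypothetical):

```py
import jax.numpy as jnp

def block_cmp(q_ix, k_ix):
    # 2: k block strictly before the q block -> no mask needed
    # 1: same block                          -> apply the causal mask inside the kernel
    # 0: k block strictly after the q block  -> fully masked, contributes nothing
    return (k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32)

assert int(block_cmp(jnp.int32(3), jnp.int32(1))) == 2
assert int(block_cmp(jnp.int32(3), jnp.int32(3))) == 1
assert int(block_cmp(jnp.int32(3), jnp.int32(5))) == 0
```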
62 changes: 39 additions & 23 deletions tests/test_sharding.py
@@ -103,8 +103,7 @@ def with_sharding(q_sharding, kv_sharding=None):
@pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize("h", [4])
@pytest.mark.parametrize("seqlen", [128])
@pytest.mark.parametrize("shard_dim", [0,2])
def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype, shard_dim):
def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

devices = jax.local_devices()[:4]
@@ -117,19 +116,35 @@ def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype, shard_dim):
def flash(qkv):
return (flash_mha(*qkv, is_causal=bool(causal), window_size=window_size)**2).sum()

q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=dtype)
k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=dtype)
v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=dtype)
def with_sharding(sharding):
q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=dtype)
k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=dtype)
v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=dtype)
(q,k,v) = jax.device_put((q,k,v), sharding)
hlo = flash.lower((q,k,v)).compile().as_text()
return hlo

shape = [1,1,1,1]
shape[shard_dim] = n
sharding = PositionalSharding(devices).reshape(shape)
hlo = with_sharding(PositionalSharding(devices).reshape(n,1,1,1))
assert 'all-gather' not in hlo
assert 'dynamic-slice' not in hlo

q,k,v = jax.device_put((q,k,v), sharding)
hlo = flash.lower((q,k,v)).compile().as_text()
hlo = with_sharding(PositionalSharding(devices).reshape(1,1,n,1))
assert 'all-gather' not in hlo
assert 'dynamic-slice' not in hlo

if not local:
with Mesh(np.array(devices), axis_names=('x',)) as mesh:
sharding = NamedSharding(mesh, P(None,'x',None,None))
hlo = with_sharding(sharding)
# No resharding should occur, only manual collective-permute.
assert 'all-gather' not in hlo
assert 'dynamic-slice' not in hlo
assert 'collective-permute' in hlo
# The kernel and the permute should always overlap, i.e. each custom-call appears between a collective-permute-start and its matching collective-permute-done.
import re
collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
assert 'collective-permute-start collective-permute-done' not in collectives, hlo

@pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local',''])
@@ -181,8 +196,7 @@ def check_sharding(sharding,q,k,v):
@pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize("h", [4, 8])
@pytest.mark.parametrize("seqlen", [128])
@pytest.mark.parametrize("shard_dim", [0,2])
def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype, shard_dim):
def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

devices = jax.local_devices()
@@ -200,23 +214,25 @@ def flash(qkv):
k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)

if q.shape[shard_dim] % n != 0:
pytest.skip(f"{q.shape[shard_dim]} doesn't divide into {n} so we can't shard it.")

ref_out = ref((q,k,v))
q = q.astype(dtype)
k = k.astype(dtype)
v = v.astype(dtype)
repl_out = flash((q,k,v))
ref16_out = flash((q,k,v))

def check_sharding(sharding,q,k,v):
(q,k,v) = jax.device_put((q,k,v), sharding)
out = flash((q,k,v))
check(ref_out,ref16_out,out)

shape = [1,1,1,1]
shape[shard_dim] = n
sharding = PositionalSharding(devices).reshape(shape)
check_sharding(PositionalSharding(devices).reshape(n,1,1,1),q,k,v)
check_sharding(PositionalSharding(devices).reshape(1,1,n,1),q,k,v)

(q,k,v) = jax.device_put((q,k,v), sharding)
hlo = flash.lower((q,k,v)).compile().as_text()
out = flash((q,k,v))
check(ref_out, repl_out, out)
if not local:
# Ring attention
with Mesh(np.array(devices), axis_names=('x',)) as mesh:
sharding = NamedSharding(mesh, P(None,'x',None,None))
check_sharding(sharding,q,k,v)

if __name__ == '__main__':
test_flash_fwd_sharded_hlo(128,4,32,False,False,jnp.float16)
