[Pallas:MGPU] Add support for multiple heads in attention
PiperOrigin-RevId: 694104006
apaszke authored and Google-ML-Automation committed Nov 7, 2024
1 parent 5066712 commit f8dba3c
Showing 2 changed files with 30 additions and 17 deletions.
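For context, a minimal usage sketch of the kernel after this change. The import path follows the file shown below and the (6, 3) GQA head counts come from the updated test; the random inputs and the specific TuningConfig values are illustrative assumptions, and running it requires a GPU supported by Pallas Mosaic GPU.

# Illustrative only: exercises the new multi-head path.
# Per the checks in attention(): q is (batch, q_seq_len, num_q_heads, head_dim),
# k and v are (batch, kv_seq_len, num_kv_heads, head_dim); batch must still be 1.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention_mgpu

batch, q_seq_len, kv_seq_len, head_dim = 1, 4096, 4096, 64
num_q_heads, num_kv_heads = 6, 3  # GQA: pairs of q heads share one kv head
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(k1, (batch, q_seq_len, num_q_heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
config = attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2)
out = attention_mgpu.attention(q, k, v, config)  # same shape as q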
39 changes: 27 additions & 12 deletions jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -32,6 +32,14 @@ class TuningConfig:
block_kv: int
max_concurrent_steps: int

def __post_init__(self):
if self.block_q % 64:
raise ValueError(f"{self.block_q=} must be a multiple of 64")
if self.block_kv % 64:
raise ValueError(f"{self.block_kv=} must be a multiple of 64")
if self.max_concurrent_steps < 2:
raise ValueError(f"{self.max_concurrent_steps=} must be at least 2")


@functools.partial(jax.jit, static_argnames=["config"])
def attention(q, k, v, config: TuningConfig):
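The new __post_init__ turns unsupported tile sizes into an immediate error instead of a later kernel failure. A quick sketch of that behavior (the specific values are arbitrary):

from jax.experimental.pallas.ops.gpu import attention_mgpu

# Valid: both tile sizes are multiples of 64 and at least two steps are in flight.
cfg = attention_mgpu.TuningConfig(block_q=64, block_kv=128, max_concurrent_steps=2)

# Invalid: 96 is not a multiple of 64, so __post_init__ raises ValueError.
try:
  attention_mgpu.TuningConfig(block_q=96, block_kv=128, max_concurrent_steps=2)
except ValueError as err:
  print(err)  # self.block_q=96 must be a multiple of 64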
@@ -46,14 +54,16 @@ def attention(q, k, v, config: TuningConfig):
raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}")
if batch_size != 1 or num_q_heads != 1 or num_kv_heads != 1:
raise NotImplementedError(
"Only batch_size=1, num_q_heads=1, and num_kv_heads=1 are supported,"
f" got: {batch_size=}, {num_q_heads=}, {num_kv_heads=}"
)
if num_q_heads % num_kv_heads:
raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
q_heads_per_kv_head = num_q_heads // num_kv_heads
if head_dim % 64:
raise ValueError(f"{head_dim=} must be divisible by 64")
if batch_size != 1:
raise NotImplementedError(f"Only batch_size=1 is supported, got: {batch_size=}")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
q, k, v = map(lambda x: x[0, :, 0, :], (q, k, v))
q, k, v = map(lambda x: x[0], (q, k, v))
max_concurrent_steps = min(
config.max_concurrent_steps, kv_seq_len // config.block_kv
)
@@ -74,9 +84,10 @@ def _compute_wg():
plgpu.set_max_registers(232, action="increase")
qo_smem = qo_smem2.at[wg_idx]
q_seq_base = lax.axis_index("q") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")

plgpu.copy_gmem_to_smem(
q_ref.at[pl.ds(q_seq_base, block_q)],
q_ref.at[pl.ds(q_seq_base, block_q), q_head],
qo_smem,
barrier=q_barriers.at[wg_idx],
)
@@ -146,21 +157,22 @@ def _wait():
qo_smem[...] = acc.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)],
qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head],
)
plgpu.wait_smem_to_gmem(0)
@pl.when(wg_idx == 2)
def _memory_wg():
plgpu.set_max_registers(40, action="decrease")
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
for i in range(max_concurrent_steps):
s = pl.ds(i * block_kv, block_kv)
s = (pl.ds(i * block_kv, block_kv), kv_head)
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], barrier=k_barriers.at[i])
plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], barrier=v_barriers.at[i])

def kv_loop(kv_step, _):
tma_step = kv_step + max_concurrent_steps
tma_slot = lax.rem(kv_step, max_concurrent_steps)
s = pl.ds(tma_step * block_kv, block_kv)
s = (pl.ds(tma_step * block_kv, block_kv), kv_head)
plgpu.barrier_wait(k_consumed_barrier)
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], barrier=k_barriers.at[tma_slot])
plgpu.barrier_wait(v_consumed_barrier)
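The memory warpgroup selects its KV head by integer-dividing the query-head index by q_heads_per_kv_head (the kv_head computation above). A plain-Python illustration of that grouping for the (6, 3) GQA case used in the tests:

num_q_heads, num_kv_heads = 6, 3
q_heads_per_kv_head = num_q_heads // num_kv_heads  # 2
for q_head in range(num_q_heads):
  kv_head = q_head // q_heads_per_kv_head
  print(f"q head {q_head} reads kv head {kv_head}")
# q heads 0-1 -> kv head 0, q heads 2-3 -> kv head 1, q heads 4-5 -> kv head 2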
@@ -179,7 +191,10 @@ def run(refs):
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
mesh = plgpu.GPUMesh(
grid=(num_q_tiles,), num_threads=3, axis_names=("q", "wg"), approx_math=True,
grid=(num_q_tiles, num_q_heads),
num_threads=3,
axis_names=("q", "heads", "wg"),
approx_math=True,
)
@pl.core_map(mesh)
def _kernel_entry():
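With the added "heads" dimension, each program in the grid now handles one (query tile, query head) pair, still with three warpgroups (two compute, one memory). A quick size check under the test configuration (values assumed for illustration):

# Each q tile covers 2 * block_q query rows, one block_q slice per compute warpgroup.
q_seq_len, num_q_heads, block_q = 4096, 6, 64
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
assert rem == 0  # mirrors the NotImplementedError check above
grid = (num_q_tiles, num_q_heads)  # (32, 6): 192 programs, num_threads=3 each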
@@ -212,7 +227,7 @@ def _kernel_entry():
)

_, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf)))
return out[None, :, None, :]
return out[None]


@jax.jit
8 changes: 3 additions & 5 deletions tests/pallas/mgpu_attention_test.py
@@ -52,11 +52,9 @@ def setUp(self):
batch_size=(1,),
q_seq_len=(4096,),
kv_seq_len=(4096,),
num_q_and_kv_heads=((1, 1),),
# TODO(apaszke): Enable once we support many heads.
# num_q_and_kv_heads=((4, 1), # MQA
# (6, 3), # GQA
# (4, 4),), # MHA
num_q_and_kv_heads=((4, 1), # MQA
(6, 3), # GQA
(4, 4),), # MHA
head_dim=(64, 128, 256),
)
def test_flash_attention(
