Add support for dynamic dims in GEMMs
This PR adds tests for dynamic M, N, and K dims in
GEMMs. This works out of the box for the most part
and only requires moving the align index pass to
after scheduling and handling the maximum value of
the induction variable for the reduction loop.
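
A minimal sketch of the induction-variable handling described above, using sympy with assumed symbols K and BLOCK_K and an assumed two-stage schedule (none of these values are taken from the commit): with a dynamic K the trip count stays symbolic, and the count left after peeling the pipeline stages has to be clamped at zero.

```python
import sympy

# Assumed symbols and stage count for illustration; not taken from the commit.
K, BLOCK_K = sympy.symbols("K BLOCK_K", positive=True, integer=True)
num_stages = 2

trip_count = K // BLOCK_K                                      # stays symbolic: floor(K/BLOCK_K)
pipelined_count = sympy.Max(0, trip_count - (num_stages - 1))  # clamp after peeling stages

print(pipelined_count.subs({K: 32, BLOCK_K: 32}))   # 0 -> loop fully peeled
print(pipelined_count.subs({K: 128, BLOCK_K: 32}))  # 3
```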

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Nov 16, 2024
1 parent d8ce8d2 commit b893abe
Showing 6 changed files with 171 additions and 16 deletions.
3 changes: 2 additions & 1 deletion iree/turbine/kernel/compiler/host_codegen.py
@@ -55,7 +55,8 @@ def isolated_test_call(
argument_dims = get_dynamic_dims(
sig.kernel_buffer_input_bindings, dynamic_symbols
)
input_tensors += [IndexType.get() for _ in argument_dims]
# Only add types for the unique dynamic symbols.
input_tensors += [IndexType.get() for _ in set(argument_dims)]

output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings]
output_tensors = memref_to_tensor(output_types)
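A small sketch, with hypothetical binding data, of what the set() in the change above buys: the same dynamic symbol can show up in several buffer bindings, but the host function should take it as a single index argument.

```python
# Hypothetical binding data: the dynamic symbol M backs two different buffers.
argument_dims = ["M", "K", "M", "N"]

per_occurrence = ["index" for _ in argument_dims]          # old behavior: 4 index operands
per_unique_symbol = ["index" for _ in set(argument_dims)]  # new behavior: 3 index operands

print(len(per_occurrence), len(per_unique_symbol))  # 4 3
```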
28 changes: 26 additions & 2 deletions iree/turbine/kernel/wave/codegen.py
@@ -320,7 +320,18 @@ def _floor(value):

def _ceiling(value):
if isinstance(value, _Rational):
value = arith_d.ceildivsi(*_broadcast(value.numerator, value.denominator))
# TODO: This is the expansion of ceildivui, but we should figure
# out how to run the arith-expand pass instead of doing this.
# ceildivui(x, y) = x == 0 ? 0 : ((x - 1) / y) + 1
one = _get_const(1)
zero = _get_const(0)
lhs_minus_one = arith_d.subi(*_broadcast(value.numerator, one))
div = arith_d.divui(*_broadcast(lhs_minus_one, value.denominator))
result = arith_d.addi(*_broadcast(div, one))
cmp = arith_d.cmpi(
arith_d.CmpIPredicate.eq, *_broadcast(value.numerator, zero)
)
value = arith_d.select(cmp, zero, result)

return value
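
A quick pure-Python sanity check of the identity in the TODO above; this only verifies the formula, it is not the MLIR emission itself.

```python
import math

def ceildivui(x: int, y: int) -> int:
    # ceildivui(x, y) = x == 0 ? 0 : ((x - 1) / y) + 1, with y > 0.
    return 0 if x == 0 else ((x - 1) // y) + 1

for x in range(64):
    for y in range(1, 16):
        assert ceildivui(x, y) == math.ceil(x / y)
print("expansion matches math.ceil on the sampled range")
```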

@@ -417,6 +428,16 @@ def _get_const(val):
_enforce_non_rational(lhs, term)
res = arith_d.andi(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.Max():
rhs = stack.pop()
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
if _is_integer_like_type(rhs.type):
res = arith_d.maxsi(*_broadcast(lhs, rhs))
else:
res = arith_d.maximumf(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.logic.boolalg.BooleanFalse():
res = arith_d.constant(IntegerType.get_signless(1), 0)
stack.append(res)
@@ -1062,7 +1083,10 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):

# For now, we assume that dimensions that have tiling constraints on them,
# do not have any other constraints.
end = arith_d.constant(IndexType.get(), int(node.count))
if isinstance(node.count, sympy.Expr):
end = gen_sympy_index(add_emitter_subs(emitter), node.count)
else:
end = arith_d.constant(IndexType.get(), int(node.count))

# Since we divide the end by the tile size, we need to make sure that the
# step is 1.
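A hedged sketch of the branch added to handle_reduction above: when the reduction count is a sympy expression (the dynamic case), it is lowered through gen_sympy_index; a plain int still becomes an index constant. The lower_reduction_count helper and the string return values here are illustrative stand-ins, not code from this commit.

```python
import sympy

def lower_reduction_count(count):
    # Mirrors the added branch: a sympy expression takes the symbolic path,
    # a plain int is still materialized as an index constant.
    if isinstance(count, sympy.Expr):
        return f"gen_sympy_index({count})"           # stand-in for the real emitter call
    return f"arith.constant {int(count)} : index"    # stand-in for arith_d.constant

K, BLOCK_K = sympy.symbols("K BLOCK_K", positive=True, integer=True)
print(lower_reduction_count(4))
print(lower_reduction_count(sympy.Max(0, K // BLOCK_K - 1)))
```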
27 changes: 20 additions & 7 deletions iree/turbine/kernel/wave/scheduling/schedule.py
@@ -15,7 +15,7 @@
from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc
import torch.fx as fx
from ....support.logging import get_logger
import math
import sympy

logger = get_logger("turbine.wave.scheduling.schedule")

@@ -92,12 +92,20 @@ def schedule_reduction(
# to have at least N iterations of the loop where N > num_stages - 1 (because
# we will be peeling off num_stages iterations from the loop).
tiling_constraint = get_tiling_constraint(reduction, constraints)
max_induction_variable = int(
subs_idxc(tiling_constraint.dim) // subs_idxc(tiling_constraint.tile_size)
max_induction_variable = subs_idxc(tiling_constraint.dim) // subs_idxc(
tiling_constraint.tile_size
)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn("Not enough iterations to pipeline the loop. Skipping pipelining.")
return {}

ivar_is_number = max_induction_variable.is_number
if ivar_is_number:
# We can only do a compile-time check if the induction variable
# is not dynamic.
max_induction_variable = int(max_induction_variable)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}

new_reduction = construct_pipelined_loop(
trace,
@@ -112,7 +120,12 @@
)

# Update new reduction count.
new_reduction.count = max_induction_variable - (scheduler.num_stages - 1)
if ivar_is_number:
new_reduction.count = max_induction_variable - (scheduler.num_stages - 1)
else:
new_reduction.count = sympy.Max(
0, max_induction_variable - (scheduler.num_stages - 1)
)


def schedule_graph(
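A small sketch of the is_number split used in schedule_reduction above, with assumed symbols K and BLOCK_K: a fully static bound folds to a sympy Integer, so the pipelining profitability check can stay at compile time, while a dynamic bound stays symbolic and gets the Max clamp instead.

```python
import sympy

K, BLOCK_K = sympy.symbols("K BLOCK_K", positive=True, integer=True)

static_bound = sympy.Integer(128) // sympy.Integer(32)  # folds to Integer(4)
dynamic_bound = K // BLOCK_K                            # stays floor(K/BLOCK_K)

print(static_bound.is_number)   # True  -> check pipelining profitability at compile time
print(dynamic_bound.is_number)  # False -> defer to runtime; clamp the peeled count with Max
```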
10 changes: 5 additions & 5 deletions iree/turbine/kernel/wave/wave.py
@@ -256,11 +256,6 @@ def _trace_and_get_kernel_signature(
# Partition strided operators.
partition_strided_operators(graph, self.constraints)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of pipeline.
align_index_sizes(graph, self.constraints)

# Decompose reduce Ops.
decompose_reduce_ops(graph, self.constraints, idxc.subs)

@@ -278,6 +273,11 @@
use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)
schedule_graph(graph, self.constraints, use_scheduling_barriers)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of pipeline.
align_index_sizes(graph, self.constraints)

# Add shared memory barriers.
add_shared_memory_barriers(graph)

96 changes: 96 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1184,6 +1184,102 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-8: amdgpu.mfma


@run_test
def test_dynamic_gemm_pipelined():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 2, 1),
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def dynamic_gemm_pipelined(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
VALU_DELAY: 1,
SHUFFLE_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
VALU_UNITS: 8,
SHUFFLE_UNITS: 8,
},
canonicalize=True,
schedule=True,
use_scheduling_barriers=True,
dynamic_symbols=(M, N, K),
dynamic_symbols_map={M: 64, N: 128, K: 32},
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
print(dynamic_gemm_pipelined(a, b, c).module_op)

# CHECK: func.func @dynamic_gemm_pipelined
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: arith.maxsi
# CHECK-COUNT-1: scf.for
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-3: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-1: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-1: scf.yield
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-8: vector.load
# CHECK-COUNT-8: amdgpu.mfma


# This test is used to check three things
# 1. Reduction with multiple different types(MMA, ReduceOp) of iterArg works
# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape).
23 changes: 22 additions & 1 deletion tests/kernel/wave/wave_gemm_test.py
@@ -22,6 +22,7 @@
import os
import json
from torch.testing import assert_close
from enum import Enum

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
@@ -60,6 +61,7 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
@pytest.mark.parametrize("enable_scheduling", [False, True])
@pytest.mark.parametrize("dynamic_dims", [False, True])
@pytest.mark.parametrize(
"mfma_variant",
[
@@ -68,7 +70,11 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
],
)
def testGemm(
shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request
shape: tuple[int],
enable_scheduling: bool,
dynamic_dims: bool,
mfma_variant: MMAType,
request,
):
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
@@ -161,6 +167,19 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
dump_perf, "tk_" + perf_filename
)

dynamic_symbols = []
dynamic_symbols_map = {}
if dynamic_dims:
dynamic_symbols_map[M] = hyperparams[M]
dynamic_symbols_map[N] = hyperparams[N]
dynamic_symbols_map[K] = hyperparams[K]
dynamic_symbols.append(M)
dynamic_symbols.append(N)
dynamic_symbols.append(K)
del hyperparams[M]
del hyperparams[N]
del hyperparams[K]

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
@@ -169,6 +188,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
run_config=config,
schedule=enable_scheduling,
use_scheduling_barriers=enable_scheduling_barriers,
dynamic_symbols=dynamic_symbols,
dynamic_symbols_map=dynamic_symbols_map,
):
a = torch.randn(shape[0], shape[2], dtype=torch.float16)
b = torch.randn(shape[1], shape[2], dtype=torch.float16)
