Add test and benchmark for explicit dot GEMM #637

Open: brunomazzottiamd wants to merge 1 commit into main_perf from 271_add_test_and_bench_to_multreduce_matmul

@brunomazzottiamd (Collaborator) commented Sep 11, 2024

Summary

This PR adds the missing test and benchmark features to the explicit dot GEMM Triton kernel developed under https://github.com/ROCm/triton-internal/issues/169. It also adds a GEMM implemented with tl.dot, so that all three implementations (PyTorch GEMM, Triton GEMM with tl.dot, and Triton GEMM with explicit dot) can be compared for both correctness and performance.
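For readers unfamiliar with the distinction: tl.dot relies on a built-in matrix-multiply primitive, while an explicit dot ("multreduce") formulation presumably forms elementwise products and reduces them along the shared dimension. The NumPy sketch below only illustrates that idea on the host. It is not the Triton kernel from this PR, and the function names are hypothetical.

```python
import numpy as np


def matmul_dot(a, b):
    """Reference GEMM using the library matrix product (stand-in for tl.dot)."""
    return a @ b


def matmul_multreduce(a, b):
    """GEMM as a broadcast multiply followed by a sum over the shared dimension.

    a has shape (rows, inner) and b has shape (inner, cols).
    """
    # products[i, k, j] = a[i, k] * b[k, j]
    products = a[:, :, None] * b[None, :, :]
    # Reduce over the shared (inner) dimension to get shape (rows, cols).
    return products.sum(axis=1)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 64)).astype(np.float32)
b = rng.standard_normal((64, 32)).astype(np.float32)
# Both formulations should agree up to floating-point rounding.
assert np.allclose(matmul_dot(a, b), matmul_multreduce(a, b), atol=1e-3)
```

The broadcast form materializes every partial product before reducing, which is only reasonable for small shapes; that trade-off is likely why it is worth benchmarking against tl.dot.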

Help / How to run

python python/perf-kernels/multreduce_matmul_kernel.py -h
usage: multreduce_matmul_kernel.py [-h] [-M M] [-N N] [-K K] [--use-bias] [--use-dot] {run,bench}

C = A * B + BIAS matrix multiplication kernel for small matrices

positional arguments:
  {run,bench}  mode of operation:
                 run: run Triton kernel for a given (M, N, K) shape
                 bench: benchmark performance for target shapes

options:
  -h, --help   show this help message and exit

kernel shape arguments:
  -M M         rows of matrix A
  -N N         columns of matrix A / rows of matrix B
  -K K         columns of matrix B
  --use-bias   use BIAS vector
  --use-dot    use tl.dot for dot product
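A command-line interface of this shape can be sketched with argparse as follows. This is a reconstruction from the help text above, not the script's actual code: the build_parser name, the integer argument types, and the missing defaults are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the help output shown above."""
    parser = argparse.ArgumentParser(
        prog="multreduce_matmul_kernel.py",
        description="C = A * B + BIAS matrix multiplication kernel for small matrices",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "mode",
        choices=["run", "bench"],
        help="mode of operation:\n"
        "  run: run Triton kernel for a given (M, N, K) shape\n"
        "  bench: benchmark performance for target shapes",
    )
    # Options grouped under their own heading, as in the help output.
    shape = parser.add_argument_group("kernel shape arguments")
    shape.add_argument("-M", type=int, help="rows of matrix A")
    shape.add_argument("-N", type=int, help="columns of matrix A / rows of matrix B")
    shape.add_argument("-K", type=int, help="columns of matrix B")
    shape.add_argument("--use-bias", action="store_true", help="use BIAS vector")
    shape.add_argument("--use-dot", action="store_true", help="use tl.dot for dot product")
    return parser


args = build_parser().parse_args(["run", "-M", "1", "-N", "8192", "-K", "28672"])
print(args.mode, args.M, args.N, args.K, args.use_bias, args.use_dot)
# prints: run 1 8192 28672 False False
```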

Running the Triton kernel for a single shape

python python/perf-kernels/multreduce_matmul_kernel.py run -M 1 -N 8192 -K 28672

Checking correctness of a single target shape

pytest -vvv 'python/perf-kernels/multreduce_matmul_kernel.py::test_matmul[1-4096-4096-True]'

Checking correctness of all target shapes

pytest -vvv python/perf-kernels/multreduce_matmul_kernel.py

Benchmarking all target shapes

python python/perf-kernels/multreduce_matmul_kernel.py bench

Sample benchmark result on MI300X:

     M        N        K  PyTorch (GiB/s)  Triton Dot (GiB/s)  Triton Multreduce (GiB/s)
0  1.0   8192.0  28672.0          3863.76             2493.81                    2393.79
1  1.0   6144.0   6144.0          2138.50             2648.17                    1935.80
2  1.0   4096.0   4096.0          1960.94             1154.51                    1434.65
3  2.0  16384.0  16384.0          2411.74             2782.96                    2704.58
4  1.0   4096.0   3078.0          1066.43              775.85                     329.33
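The PR does not say how the GiB/s figures are derived. One plausible way to obtain an effective-bandwidth number from a measured kernel time is sketched below, assuming each matrix is moved through memory exactly once, 2-byte elements (for example fp16), and no bias. The gemm_gibps helper and all of these assumptions are hypothetical and not taken from the benchmark code.

```python
def gemm_gibps(m, n, k, seconds, bytes_per_element=2):
    """Effective bandwidth in GiB/s for one GEMM call (hypothetical helper).

    Follows the help text's naming: A is (m, n), B is (n, k), C is (m, k).
    Assumes A and B are each read once and C is written once.
    """
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    elements = m * n + n * k + m * k
    total_bytes = elements * bytes_per_element
    return total_bytes / seconds / 2**30


# Example call with a made-up timing of 10 microseconds.
print(f"{gemm_gibps(1, 4096, 4096, 10e-6):.2f} GiB/s")
```

For shapes with M = 1 or 2, the B matrix dominates the byte count, so a bandwidth metric is a natural fit for these memory-bound cases.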

@brunomazzottiamd brunomazzottiamd self-assigned this Sep 11, 2024
@brunomazzottiamd brunomazzottiamd marked this pull request as ready for review September 11, 2024 21:26
@brunomazzottiamd brunomazzottiamd force-pushed the 271_add_test_and_bench_to_multreduce_matmul branch 2 times, most recently from c0c697a to f6209e7, September 24, 2024 14:26
@brunomazzottiamd brunomazzottiamd force-pushed the 271_add_test_and_bench_to_multreduce_matmul branch from f6209e7 to 14a7e75, October 1, 2024 14:03