[CODE SHARING] Insertion of custom LLVM IR and AMDGCN code into Triton #610

Draft · wants to merge 10 commits into base: sjw-pipeline-infra

Conversation


@ravil-mobile ravil-mobile commented Jul 10, 2024

Triton Kernel and driver

#!/opt/conda/envs/py_3.10/bin/python3


import torch
import triton
import triton.language as tl
import time
import argparse
import os
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("-b", "--benchmark", action='store_true', help='run benchmark')
parser.add_argument("-t", "--type", choices=['fp16', 'fp32'],
                    default="fp16",
                    help="data type under the test")
parser.add_argument("--dump-ir", choices=['llir', 'amdgcn'],
                    default="amdgcn",
                    help="dump IR format")
parser.add_argument("-p", "--prefix", type=str, default=None, help="prefix for the dumped files")
parser.add_argument("-f", "--file", type=str, default=None, help="load gemm parameters from file")
parser.add_argument("-d", "--double-buffering", action='store_true', help="enable double buffering")
parser.add_argument("-m", "--use-mask", action='store_true', help='use masked load/store')
parser.add_argument("-v", "--verbose", action='store_true', help='verbose output')
args = parser.parse_args()

curr_dir = os.path.dirname(os.path.abspath(__file__))

print(f"DOUBLE BUFFERING: {'enabled' if args.double_buffering else 'disabled'}")
print(f"MASKING load/store: {'enabled' if args.use_mask else 'disabled'}")

gemm_config = None
if args.file:
    # read GEMM config from a file
    if not os.path.exists(args.file):
        raise RuntimeError(f'cannot open `{args.file}`')
    with open(args.file, 'r') as file:
        gemm_config = yaml.safe_load(file)


def get_hip_autotune_config():
    num_stages = 3 if args.double_buffering else 2
    if gemm_config and 'tuning' in gemm_config.keys():
        tuning_config = gemm_config['tuning']
        # apply the tuning parameters specified in the file
        # essentially, it disables auto-tuning and applies specific parameters
        return [
           triton.Config(
              {'BLOCK_SIZE_M': tuning_config['BLOCK_SIZE_M'], 
               'BLOCK_SIZE_N': tuning_config['BLOCK_SIZE_N'], 
               'BLOCK_SIZE_K': tuning_config['BLOCK_SIZE_K'], 
               'GROUP_SIZE_M': tuning_config['GROUP_SIZE_M'], 
               'waves_per_eu': tuning_config['waves_per_eu']},
               num_warps=tuning_config['num_warps'],
               num_stages=tuning_config['num_stages'],
               num_ctas=tuning_config['num_ctas'])
        ]
    else:
        return [
            triton.Config(
                {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                num_warps=4, num_stages=num_stages),
            triton.Config(
                {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
                num_warps=8, num_stages=num_stages),
            triton.Config(
                {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                num_warps=8, num_stages=num_stages),
            triton.Config(
                {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
                num_warps=4, num_stages=num_stages),
            triton.Config(
                {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
                num_warps=4, num_stages=num_stages),
        ]


@triton.autotune(
    configs=get_hip_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, bias_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_bias,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr, USE_MASK: tl.constexpr
):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        # simple row-major mapping of the 1D program id to a (pid_m, pid_n) tile
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        # grouped ordering: consecutive program ids work on tiles within the same
        # group of GROUP_SIZE_M rows, which improves L2 reuse of the B operand
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
    if SPLIT_K == 1:
        offs_k = tl.arange(0, BLOCK_SIZE_K)
    else:
        offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    if BIAS:
        bias_ptrs = bias_ptr + offs_am * stride_bias
        bias = tl.load(bias_ptrs, mask=offs_am < M if USE_MASK else None, other=0.0)
    acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    # main loop over K: each iteration loads one BLOCK_SIZE_K slice of A and B
    # and accumulates the partial products
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
    c = accumulator.to(c_ptr.type.element_ty)
    if BIAS:
        c += bias[:, None]
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    if SPLIT_K == 1:
        tl.store(c_ptrs, c, mask=c_mask if USE_MASK else None)
    else:
        # with split-K, multiple programs accumulate partial results into the same C tile
        tl.atomic_add(c_ptrs, c, mask=c_mask if USE_MASK else None)


def matmul(a, b, c, bias, use_bias=False):
  M, K = a.shape
  K, N = b.shape
  stride_bias = N

  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
  
  handle = matmul_kernel[grid](
    a, b, c, bias,
    M, N, K,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    stride_bias,
    SPLIT_K=1,
    BIAS=use_bias,
    USE_MASK=args.use_mask,
  )

  if args.verbose:
    print(handle.asm.keys())

  if args.dump_ir:
    filename = f"matmul_kernel.{args.dump_ir}"
    filename = f"{args.prefix}-{filename}" if args.prefix else filename
    with open(os.path.join(curr_dir, filename), "w") as file:
       file.write(handle.asm[args.dump_ir])

if gemm_config:
    M = gemm_config['M']
    N = gemm_config['N']
    K = gemm_config['K']
else:
    M = 8192
    N = 8192
    K = 8192

if args.type == 'fp16':
    tested_dtype = torch.float16
elif args.type == 'fp32':
    tested_dtype = torch.float32
else:
    raise RuntimeError(f'unsupported data type `{args.type}`')

a = torch.randn((M, K), device='cuda', dtype=tested_dtype)
b = torch.randn((K, N), device='cuda', dtype=tested_dtype)
c = torch.empty((M, N), device='cuda', dtype=tested_dtype)
bias = torch.empty((M, N), device='cuda', dtype=tested_dtype)

enable_benchmark = args.benchmark

if not enable_benchmark:
  matmul(a, b, c, bias, False)

configs = []
configs.append(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[2**(i+10) for i in range(6)],
        line_arg="provider",
        line_vals=["triton"],
        line_names=["Triton"],
        styles=[("green", "-"),],
        ylabel="TFLOPS",
        plot_name=f"matmul-performance-{args.type}",
        args={},
    ))

@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider):
  a = torch.randn((M, K), device='cuda', dtype=tested_dtype)
  b = torch.randn((K, N), device='cuda', dtype=tested_dtype)
  c = torch.empty((M, N), device='cuda', dtype=tested_dtype)
  bias = torch.empty((M, N), device='cuda', dtype=tested_dtype)
  quantiles = [0.5, 0.2, 0.8]
    
  ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, bias, False), quantiles=quantiles)
    
  perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
  return perf(ms), perf(max_ms), perf(min_ms)



if enable_benchmark:
  benchmark.run(show_plots=True, print_data=True)

Kernel Config

$ cat ./gemm-config.yaml
M: 8192
N: 8192
K: 8192
tuning:
  BLOCK_SIZE_M: 256
  BLOCK_SIZE_N: 256
  BLOCK_SIZE_K: 16
  GROUP_SIZE_M: 4
  waves_per_eu: 2
  num_warps: 8
  num_ctas: 1
  num_stages: 2

Recording

python3 ./gemm.py -f ./gemm-config.yaml --type "fp16" --dump-ir "llir"
python3 ./gemm.py -f ./gemm-config.yaml --type "fp16" --dump-ir "amdgcn"

# `matmul_kernel.llir` and `matmul_kernel.amdgcn` will be produced
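
Before re-inserting, it can help to sanity-check the dumped files. A minimal sketch using plain file IO (no Triton internals assumed):

with open("./matmul_kernel.llir") as f:
    llir = f.read()
assert "matmul_kernel" in llir, "kernel symbol not found in the dumped LLVM IR"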

Insertion

AMD_INSERT_LLVM_IR=./matmul_kernel.llir python3 ./gemm.py -f ./gemm-config.yaml --type "fp16"
AMD_INSERT_AMDGCN=./matmul_kernel.amdgcn python3 ./gemm.py -f ./gemm-config.yaml --type "fp16"
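
Conceptually, the override amounts to swapping the compiler-generated text for the contents of the file named by the environment variable. The sketch below is only an illustration of that idea, not the actual hook added by this PR; `maybe_override_ir` is a hypothetical helper, not a Triton API.

import os

def maybe_override_ir(stage_output: str, env_var: str) -> str:
    # If the environment variable is set, use the file it points to instead of
    # the compiler-generated IR/assembly for that stage; otherwise keep it.
    path = os.environ.get(env_var)
    if path is None:
        return stage_output
    with open(path, "r") as f:
        return f.read()

# e.g. llir   = maybe_override_ir(llir, "AMD_INSERT_LLVM_IR")
#      amdgcn = maybe_override_ir(amdgcn, "AMD_INSERT_AMDGCN")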

sjw36 and others added 10 commits June 25, 2024 16:37
…structure

    - Copied scheduler from MatmulLoopPipeline (much could be consolidated)
    - Enable register buffering (even though it may increase register pressure)
    - Enable num_stages=2+, including multi-buffering, and make `2` the default
    - Updated the tutorial for the new tuning default
    - Added lit tests
- Also move `triton_gpu.local_store` ops that are independent of the loop-carried buffer as early as possible
- Check for the last atomic (sync?)
- Also check for other accesses to the source
@ravil-mobile ravil-mobile changed the base branch from main to sjw-pipeline-infra July 10, 2024 15:47
@sjw36 sjw36 force-pushed the sjw-pipeline-infra branch 2 times, most recently from 4eeb8cc to faf95cb July 18, 2024 14:35