diff --git a/python/aitemplate/backend/cuda/gemm_universal/__init__.py b/python/aitemplate/backend/cuda/gemm_universal/__init__.py
index 3cf6eecc4..c77d279cc 100644
--- a/python/aitemplate/backend/cuda/gemm_universal/__init__.py
+++ b/python/aitemplate/backend/cuda/gemm_universal/__init__.py
@@ -32,6 +32,7 @@
     gemm_rcr_permute,
     gemm_rcr_permute_elup1,
     gemm_rrr,
+    gemm_rrr_bias,
     gemm_rrr_permute,
     group_gemm_rcr,
     group_gemm_rcr_bias,
diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py
index 696734094..f4b0a0d07 100644
--- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py
+++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py
@@ -107,6 +107,36 @@ def gemm_rrr_config(func_attrs, dtype="float16"):
             op.C.element = cutlass_lib.library.DataType.void
 
 
+def common_gen_profiler(
+    func_attrs,
+    workdir,
+    profiler_filename,
+    dim_info_dict,
+    src_template,
+    problem_args_template,
+    problem_args_template_cutlass_3x=None,
+    bias_ptr_arg=None,
+    extra_code="",
+):
+    output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render(
+        stride_dim="*b_dim1"
+    )
+    return common.gen_profiler(
+        func_attrs=func_attrs,
+        workdir=workdir,
+        profiler_filename=profiler_filename,
+        dim_info_dict=dim_info_dict,
+        src_template=src_template,
+        problem_args_template=problem_args_template,
+        problem_args_template_cutlass_3x=problem_args_template_cutlass_3x,
+        args_parser_template=ARGS_PARSER_TEMPLATE,
+        support_split_k=True,
+        output_addr_calculator=output_addr_calculator,
+        bias_ptr_arg=bias_ptr_arg,
+        extra_code=extra_code,
+    )
+
+
 @registry.reg("cuda.gemm_rrr.gen_profiler")
 def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
     output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render(
@@ -126,6 +156,38 @@ def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
     )
 
 
+def get_input_addr_calculator(func_attrs):
+    input_a_batch_stride_dim = "M * K"
+    input_a_stride_k_dim = "K"
+    input_a_offset = 0
+    input_b_batch_stride_dim = "K * N"
+    input_b_stride_k_dim = "N"
+    input_b_offset = 0
+
+    if "input_accessors" in func_attrs:
+        input_a_accessor = func_attrs["input_accessors"][0]
+        input_b_accessor = func_attrs["input_accessors"][1]
+        if input_a_accessor.is_from_strided_tensor:
+            input_a_offset = input_a_accessor.offset
+            shapes = input_a_accessor.original_shapes
+            input_a_stride_k_dim = input_a_accessor.stride(len(shapes) - 2)
+
+        if input_b_accessor.is_from_strided_tensor:
+            input_b_offset = input_b_accessor.offset
+            shapes = input_b_accessor.original_shapes
+            input_b_stride_k_dim = input_b_accessor.stride(len(shapes) - 2)
+
+    input_addr_calculator = common.INPUT_ADDR_CALCULATOR.render(
+        input_a_batch_stride_dim=input_a_batch_stride_dim,
+        input_a_stride_dim=input_a_stride_k_dim,
+        input_a_offset_val=input_a_offset,
+        input_b_batch_stride_dim=input_b_batch_stride_dim,
+        input_b_stride_dim=input_b_stride_k_dim,
+        input_b_offset_val=input_b_offset,
+    )
+    return input_addr_calculator
+
+
 @registry.reg("cuda.gemm_rrr.gen_function")
 def gen_function(
     func_attrs,
diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py
new file mode 100644
index 000000000..42c675ff0
--- /dev/null
+++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+GEMM Specialization for
+C = GeMM(A, B) + bias
+where A[RowMajor][M, K], B[RowMajor][K, N], bias[RowMajor][N]
+"""
+import jinja2
+
+from aitemplate.backend import registry
+
+from aitemplate.backend.backend_spec import CUDASpec
+from aitemplate.backend.cuda.gemm_universal import common, common_bias, gemm_rrr
+from aitemplate.backend.cuda.gemm_universal.layout import RRR
+
+# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703
+
+
+EXTRA_CODE = jinja2.Template(
+    """
+using elem_input_type = {{elem_input_type}};
+using elem_output_type = {{elem_output_type}};
+"""
+)
+
+
+# used for real execution
+PROBLEM_ARGS_TEMPLATE = jinja2.Template(
+    """
+    cutlass::gemm::GemmUniversalMode::kGemm,                 // GemmUniversalMode mode
+    cutlass::gemm::GemmCoord{
+        static_cast<coord_t>(M),
+        static_cast<coord_t>(N),
+        static_cast<coord_t>(K)
+    },                                                       // GemmCoord problem_size
+    split_k,                                                 // int batch_count
+    {ElementComputeEpilogue(1), ElementComputeEpilogue(1)},  // typename EpilogueOutputOp::Params epilogue
+    ({{elem_input_type}}*)(a_ptr) + input_a_offset,          // void const * ptr_A
+    ({{elem_input_type}}*)(b_ptr) + input_b_offset,          // void const * ptr_B
+    ({{elem_input_type}}*)(bias_ptr),                        // void const * ptr_C
+    ({{elem_output_type}}*)(c_ptr) + output_offset,          // void * ptr_D
+    input_a_batch_stride,                                    // int64_t batch_stride_A
+    input_b_batch_stride,                                    // int64_t batch_stride_B
+    N,                                                       // int64_t batch_stride_C
+    M * N,                                                   // int64_t batch_stride_D
+    input_a_stride,                                          // typename LayoutA::Stride::LongIndex lda
+    input_b_stride,                                          // typename LayoutB::Stride::LongIndex ldb
+    0,                                                       // typename LayoutC::Stride::LongIndex ldc
+    output_stride,                                           // typename LayoutC::Stride::LongIndex ldd
+"""
+)
+
+
+# for profiler, no need to include TensorAccessor
+PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template(
+    """
+    cutlass::gemm::GemmUniversalMode::kGemm,                 // GemmUniversalMode mode
+    cutlass::gemm::GemmCoord{
+        static_cast<coord_t>(M),
+        static_cast<coord_t>(N),
+        static_cast<coord_t>(K)
+    },                                                       // GemmCoord problem_size
+    split_k,                                                 // int batch_count
+    {ElementComputeEpilogue(1), ElementComputeEpilogue(1)},  // typename EpilogueOutputOp::Params epilogue
+    ({{elem_input_type}}*)(a_ptr),                           // void const * ptr_A
+    ({{elem_input_type}}*)(b_ptr),                           // void const * ptr_B
+    ({{elem_input_type}}*)(bias_ptr),                        // void const * ptr_C
+    ({{elem_output_type}}*)(c_ptr) + output_offset,          // void * ptr_D
+    M * K,                                                   // int64_t batch_stride_A
+    K * N,                                                   // int64_t batch_stride_B
+    N,                                                       // int64_t batch_stride_C
+    M * N,                                                   // int64_t batch_stride_D
+    K,                                                       // typename LayoutA::Stride::LongIndex lda
+    N,                                                       // typename LayoutB::Stride::LongIndex ldb
+    0,                                                       // typename LayoutC::Stride::LongIndex ldc
+    output_stride,                                           // typename LayoutC::Stride::LongIndex ldd
+"""
+)
+
+
+@registry.reg("cuda.gemm_rrr_bias.config")
+def gemm_rrr_config(func_attrs, dtype="float16"):
+    common.make_fproc(func_attrs, RRR)
+
+
+@registry.reg("cuda.gemm_rrr_bias.gen_profiler")
+def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
+    backend_spec = CUDASpec()
+    elem_input_type = backend_spec.dtype_to_lib_type(
+        func_attrs["inputs"][0]._attrs["dtype"]
+    )
+    elem_output_type = backend_spec.dtype_to_lib_type(
+        func_attrs["outputs"][0]._attrs["dtype"]
+    )
+    extra_code = EXTRA_CODE.render(
+        elem_input_type=elem_input_type,
+        elem_output_type=elem_output_type,
+    )
+    return gemm_rrr.common_gen_profiler(
+        func_attrs=func_attrs,
+        workdir=workdir,
+        profiler_filename=profiler_filename,
+        dim_info_dict=dim_info_dict,
+        src_template=common_bias.SRC_TEMPLATE,
+        problem_args_template=PROFILER_PROBLEM_ARGS_TEMPLATE,
+        bias_ptr_arg="memory_pool->RequestTensorByIdx(3)",
+        extra_code=extra_code,
+    )
+
+
+@registry.reg("cuda.gemm_rrr_bias.gen_function")
+def gen_function(
+    func_attrs,
+    exec_cond_template,
+    dim_info_dict,
+):
+    input_addr_calculator = gemm_rrr.get_input_addr_calculator(func_attrs)
+    input_ndims = len(func_attrs["input_accessors"][0].original_shapes)
+    weight_ndims = len(func_attrs["input_accessors"][1].original_shapes)
+    output_ndims = len(func_attrs["output_accessors"][0].original_shapes)
+    backend_spec = CUDASpec()
+    elem_input_type = backend_spec.dtype_to_lib_type(
+        func_attrs["inputs"][0]._attrs["dtype"]
+    )
+    elem_output_type = backend_spec.dtype_to_lib_type(
+        func_attrs["outputs"][0]._attrs["dtype"]
+    )
+    problem_args = PROBLEM_ARGS_TEMPLATE.render(
+        elem_input_type=elem_input_type,
+        elem_output_type=elem_output_type,
+    )
+    extra_code = EXTRA_CODE.render(
+        elem_input_type=elem_input_type,
+        elem_output_type=elem_output_type,
+    )
+    return common.gen_function(
+        func_attrs=func_attrs,
+        src_template=common_bias.SRC_TEMPLATE,
+        exec_cond_template=exec_cond_template,
+        problem_args=problem_args,
+        input_ndims=input_ndims,
+        weight_ndims=weight_ndims,
+        output_ndims=output_ndims,
+        dim_info_dict=dim_info_dict,
+        support_split_k=True,
+        input_addr_calculator=input_addr_calculator,
+        output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render(
+            stride_dim="N", output_accessor=func_attrs["output_accessors"][0]
+        ),
+        extra_code=extra_code,
+    )
+
+
+@registry.reg("cuda.gemm_rrr_bias.func_decl")
+def gen_function_decl(func_attrs):
+    func_name = func_attrs["name"]
+    input_ndims = len(func_attrs["input_accessors"][0].original_shapes)
+    weight_ndims = len(func_attrs["input_accessors"][1].original_shapes)
+    return common_bias.FUNC_DECL_TEMPLATE.render(
+        func_name=func_name,
+        input_ndims=input_ndims,
+        weight_ndims=weight_ndims,
+        support_split_k=True,
+    )
+
+
+@registry.reg("cuda.gemm_rrr_bias.func_call")
+def gen_function_call(func_attrs, indent="  "):
+    bias = func_attrs["inputs"][2]
+    return common.gen_function_call(
+        func_attrs, indent, bias_ptr_arg=bias._attrs["name"]
+    )
+
+
+@registry.reg("cuda.gemm_rrr_bias.filter")
+def function_filter(cfg, func_attrs, ab_alignment):
+    """Generates function filter.
+
+    Parameters
+    ----------
+    cfg: str
+        The filename generated for profiler.
+    func_attrs : Dict
+        Stores the operation attributes.
+    ab_alignment:
+        Input alignments.
+
+    Returns
+    -------
+    bool
+        If input cfg should be filtered.
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/tests/unittest/ops/test_gemm_bias.py b/tests/unittest/ops/test_gemm_bias.py index cd276b739..330a030f2 100644 --- a/tests/unittest/ops/test_gemm_bias.py +++ b/tests/unittest/ops/test_gemm_bias.py @@ -149,6 +149,73 @@ def test_rcr_sm90(self) -> None: dtype="bfloat16", ) + def _test_rrr(self, Ms, N, K, test_name, dtype="float16"): + target = detect_target() + tolerance_limits = _TOLERANCE_LIMITS[dtype] + MDim = shape_utils.gen_int_var_min_max(Ms, name="m") + X = Tensor(shape=[MDim, IntImm(K)], dtype=dtype, name="input_0", is_input=True) + W = Tensor( + shape=[IntImm(K), IntImm(N)], dtype=dtype, name="input_1", is_input=True + ) + B = Tensor(shape=[IntImm(N)], dtype=dtype, name="input_2", is_input=True) + OP = ops.gemm_rrr_bias() + Y = OP(X, W, B) + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + module = compile_model( + Y, target, "./tmp", f"gemm_rrr_bias_{test_name}_{self._test_id}" + ) + self._test_id += 1 + + for M in Ms: + X_pt = get_random_torch_tensor([M, K], dtype) + W_pt = get_random_torch_tensor([N, K], dtype) + B_pt = get_random_torch_tensor([N], dtype) + Y_pt = torch.nn.functional.linear(X_pt, W_pt, bias=B_pt) + + W_transpose_pt = torch.transpose(W_pt, 0, 1).contiguous() + y = get_torch_empty_tensor([M, N], dtype) + module.run_with_tensors( + {"input_0": X_pt, "input_1": W_transpose_pt, "input_2": B_pt}, + [y], + ) + if X_pt.nelement() == 0 or W_pt.nelement() == 0: + pass + else: + torch.testing.assert_close(Y_pt, y, **tolerance_limits) + + def test_rrr_zero_size(self): + target = detect_target() + # This test triggered a c10 assertion failure internally + # caffe2/c10/util/SmallVector.h:338: + # Assertion `idx < size()' failed + if type(target).__name__ != "FBCUDA": + self._test_rrr([2], N=64, K=0, test_name="zero_k") + self._test_rrr([2], N=0, K=4, test_name="zero_n") + self._test_rrr([0], N=4, K=4, test_name="zero_m") + + def test_rrr_static(self): + self._test_rrr([4096], N=4, K=4, test_name="static") + self._test_rrr([1000], N=81, K=1024, test_name="static") + self._test_rrr([67200], N=3, K=256, test_name="static") + + def test_rrr_static_rocm(self): + self._test_rrr([4096], N=4, K=4, test_name="static") + self._test_rrr([1000], N=81, K=1024, test_name="static") + self._test_rrr([67200], N=3, K=256, test_name="static") + + def test_rrr_bfloat16_bf16(self): + dtype = "bfloat16" + self._test_rrr([4], N=2, K=11, test_name=f"static_{dtype}", dtype=dtype) + self._test_rrr([128], N=64, K=1024, test_name=f"static_{dtype}", dtype=dtype) + self._test_rrr( + [1, 7, 64, 127], + N=64, + K=1024, + test_name=f"dynamic_m_{dtype}", + dtype=dtype, + ) + filter_test_cases_by_test_env(GEMMBiasTestCase)