add gemm_rrr_bias op support (#796)
Summary: Pull Request resolved: #796

Reviewed By: wushirong

Differential Revision: D47223564

Pulled By: aakhundov

fbshipit-source-id: cb7b41846ac6eca95f9fc164828a2b13eb3325b4
lh-ycx authored and facebook-github-bot committed Jul 5, 2023
1 parent d947681 commit a706541
Showing 4 changed files with 335 additions and 0 deletions.
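At a glance, the new op is used from the AITemplate Python frontend like the other GEMM ops. The sketch below mirrors the added unit test; the shapes, tensor names, and output directory are illustrative and not part of this commit, and the import paths follow the existing tests:

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import IntImm, Tensor
from aitemplate.testing import detect_target

# Row-major A[M, K] times row-major B[K, N], plus a bias of shape [N].
M, K, N = 128, 256, 64
X = Tensor(shape=[IntImm(M), IntImm(K)], dtype="float16", name="input_0", is_input=True)
W = Tensor(shape=[IntImm(K), IntImm(N)], dtype="float16", name="input_1", is_input=True)
B = Tensor(shape=[IntImm(N)], dtype="float16", name="input_2", is_input=True)

Y = ops.gemm_rrr_bias()(X, W, B)  # Y[M, N] = X @ W + B
Y._attrs["name"] = "output_0"
Y._attrs["is_output"] = True

module = compile_model(Y, detect_target(), "./tmp", "gemm_rrr_bias_example")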
1 change: 1 addition & 0 deletions python/aitemplate/backend/cuda/gemm_universal/__init__.py
@@ -32,6 +32,7 @@
    gemm_rcr_permute,
    gemm_rcr_permute_elup1,
    gemm_rrr,
    gemm_rrr_bias,
    gemm_rrr_permute,
    group_gemm_rcr,
    group_gemm_rcr_bias,
62 changes: 62 additions & 0 deletions python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py
@@ -107,6 +107,36 @@ def gemm_rrr_config(func_attrs, dtype="float16"):
        op.C.element = cutlass_lib.library.DataType.void


def common_gen_profiler(
    func_attrs,
    workdir,
    profiler_filename,
    dim_info_dict,
    src_template,
    problem_args_template,
    problem_args_template_cutlass_3x=None,
    bias_ptr_arg=None,
    extra_code="",
):
    output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render(
        stride_dim="*b_dim1"
    )
    return common.gen_profiler(
        func_attrs=func_attrs,
        workdir=workdir,
        profiler_filename=profiler_filename,
        dim_info_dict=dim_info_dict,
        src_template=src_template,
        problem_args_template=problem_args_template,
        problem_args_template_cutlass_3x=problem_args_template_cutlass_3x,
        args_parser_template=ARGS_PARSER_TEMPLATE,
        support_split_k=True,
        output_addr_calculator=output_addr_calculator,
        bias_ptr_arg=bias_ptr_arg,
        extra_code=extra_code,
    )


@registry.reg("cuda.gemm_rrr.gen_profiler")
def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
    output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render(
@@ -126,6 +156,38 @@ def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
    )


def get_input_addr_calculator(func_attrs):
    input_a_batch_stride_dim = "M * K"
    input_a_stride_k_dim = "K"
    input_a_offset = 0
    input_b_batch_stride_dim = "K * N"
    input_b_stride_k_dim = "N"
    input_b_offset = 0

    if "input_accessors" in func_attrs:
        input_a_accessor = func_attrs["input_accessors"][0]
        input_b_accessor = func_attrs["input_accessors"][1]
        if input_a_accessor.is_from_strided_tensor:
            input_a_offset = input_a_accessor.offset
            shapes = input_a_accessor.original_shapes
            input_a_stride_k_dim = input_a_accessor.stride(len(shapes) - 2)

        if input_b_accessor.is_from_strided_tensor:
            input_b_offset = input_b_accessor.offset
            shapes = input_b_accessor.original_shapes
            input_b_stride_k_dim = input_b_accessor.stride(len(shapes) - 2)

    input_addr_calculator = common.INPUT_ADDR_CALCULATOR.render(
        input_a_batch_stride_dim=input_a_batch_stride_dim,
        input_a_stride_dim=input_a_stride_k_dim,
        input_a_offset_val=input_a_offset,
        input_b_batch_stride_dim=input_b_batch_stride_dim,
        input_b_stride_dim=input_b_stride_k_dim,
        input_b_offset_val=input_b_offset,
    )
    return input_addr_calculator


@registry.reg("cuda.gemm_rrr.gen_function")
def gen_function(
    func_attrs,
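A side note on the defaults used by get_input_addr_calculator above: for the RRR layout, a contiguous row-major A[M, K] has leading dimension K and a contiguous row-major B[K, N] has leading dimension N, which is exactly what the function falls back to when no strided TensorAccessor is attached. A minimal, illustrative check (not part of the commit):

import numpy as np

M, K, N = 4, 8, 16
A = np.zeros((M, K), dtype=np.float16)  # contiguous row-major A[M, K]
B = np.zeros((K, N), dtype=np.float16)  # contiguous row-major B[K, N]

assert A.strides[0] // A.itemsize == K  # default lda: "K"
assert B.strides[0] // B.itemsize == N  # default ldb: "N"
assert A.size == M * K                  # default batch stride for A: "M * K"
assert B.size == K * N                  # default batch stride for B: "K * N"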
205 changes: 205 additions & 0 deletions python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
GEMM Specialization for
C = GeMM(A, B) + bias
where A[RowMajor][M, K], B[RowMajor][K, N], bias[RowMajor][N]
"""
import jinja2

from aitemplate.backend import registry

from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.cuda.gemm_universal import common, common_bias, gemm_rrr
from aitemplate.backend.cuda.gemm_universal.layout import RRR

# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703


EXTRA_CODE = jinja2.Template(
"""
using elem_input_type = {{elem_input_type}};
using elem_output_type = {{elem_output_type}};
"""
)


# used for real execution
PROBLEM_ARGS_TEMPLATE = jinja2.Template(
"""
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
cutlass::gemm::GemmCoord{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K)
}, // GemmCoord problem_size
split_k, // int batch_count
{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, // typename EpilogueOutputOp::Params epilogue
({{elem_input_type}}*)(a_ptr) + input_a_offset, // void const * ptr_A
({{elem_input_type}}*)(b_ptr) + input_b_offset, // void const * ptr_B
({{elem_input_type}}*)(bias_ptr), // void const * ptr_C
({{elem_output_type}}*)(c_ptr) + output_offset, // void * ptr_D
input_a_batch_stride, // int64_t batch_stride_A
input_b_batch_stride, // int64_t batch_stride_B
N, // int64_t batch_stride_C
M * N, // int64_t batch_stride_D
input_a_stride, // typename LayoutA::Stride::LongIndex lda
input_b_stride, // typename LayoutB::Stride::LongIndex ldb
0, // typename LayoutC::Stride::LongIndex ldc
output_stride, // typename LayoutC::Stride::LongIndex ldd
"""
)


# for profiler, no need to include TensorAccessor
PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template(
"""
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
cutlass::gemm::GemmCoord{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K)
}, // GemmCoord problem_size
split_k, // int batch_count
{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, // typename EpilogueOutputOp::Params epilogue
({{elem_input_type}}*)(a_ptr), // void const * ptr_A
({{elem_input_type}}*)(b_ptr), // void const * ptr_B
({{elem_input_type}}*)(bias_ptr), // void const * ptr_C
({{elem_output_type}}*)(c_ptr) + output_offset, // void * ptr_D
M * K, // int64_t batch_stride_A
K * N, // int64_t batch_stride_B
N, // int64_t batch_stride_C
M * N, // int64_t batch_stride_D
K, // typename LayoutA::Stride::LongIndex lda
N, // typename LayoutB::Stride::LongIndex ldb
0, // typename LayoutC::Stride::LongIndex ldc
output_stride, // typename LayoutC::Stride::LongIndex ldd
"""
)


@registry.reg("cuda.gemm_rrr_bias.config")
def gemm_rrr_config(func_attrs, dtype="float16"):
    common.make_fproc(func_attrs, RRR)


@registry.reg("cuda.gemm_rrr_bias.gen_profiler")
def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
    backend_spec = CUDASpec()
    elem_input_type = backend_spec.dtype_to_lib_type(
        func_attrs["inputs"][0]._attrs["dtype"]
    )
    elem_output_type = backend_spec.dtype_to_lib_type(
        func_attrs["outputs"][0]._attrs["dtype"]
    )
    extra_code = EXTRA_CODE.render(
        elem_input_type=elem_input_type,
        elem_output_type=elem_output_type,
    )
    return gemm_rrr.common_gen_profiler(
        func_attrs=func_attrs,
        workdir=workdir,
        profiler_filename=profiler_filename,
        dim_info_dict=dim_info_dict,
        src_template=common_bias.SRC_TEMPLATE,
        problem_args_template=PROFILER_PROBLEM_ARGS_TEMPLATE,
        bias_ptr_arg="memory_pool->RequestTensorByIdx(3)",
        extra_code=extra_code,
    )


@registry.reg("cuda.gemm_rrr_bias.gen_function")
def gen_function(
    func_attrs,
    exec_cond_template,
    dim_info_dict,
):
    input_addr_calculator = gemm_rrr.get_input_addr_calculator(func_attrs)
    input_ndims = len(func_attrs["input_accessors"][0].original_shapes)
    weight_ndims = len(func_attrs["input_accessors"][1].original_shapes)
    output_ndims = len(func_attrs["output_accessors"][0].original_shapes)
    backend_spec = CUDASpec()
    elem_input_type = backend_spec.dtype_to_lib_type(
        func_attrs["inputs"][0]._attrs["dtype"]
    )
    elem_output_type = backend_spec.dtype_to_lib_type(
        func_attrs["outputs"][0]._attrs["dtype"]
    )
    problem_args = PROBLEM_ARGS_TEMPLATE.render(
        elem_input_type=elem_input_type,
        elem_output_type=elem_output_type,
    )
    extra_code = EXTRA_CODE.render(
        elem_input_type=elem_input_type,
        elem_output_type=elem_output_type,
    )
    return common.gen_function(
        func_attrs=func_attrs,
        src_template=common_bias.SRC_TEMPLATE,
        exec_cond_template=exec_cond_template,
        problem_args=problem_args,
        input_ndims=input_ndims,
        weight_ndims=weight_ndims,
        output_ndims=output_ndims,
        dim_info_dict=dim_info_dict,
        support_split_k=True,
        input_addr_calculator=input_addr_calculator,
        output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render(
            stride_dim="N", output_accessor=func_attrs["output_accessors"][0]
        ),
        extra_code=extra_code,
    )


@registry.reg("cuda.gemm_rrr_bias.func_decl")
def gen_function_decl(func_attrs):
    func_name = func_attrs["name"]
    input_ndims = len(func_attrs["input_accessors"][0].original_shapes)
    weight_ndims = len(func_attrs["input_accessors"][1].original_shapes)
    return common_bias.FUNC_DECL_TEMPLATE.render(
        func_name=func_name,
        input_ndims=input_ndims,
        weight_ndims=weight_ndims,
        support_split_k=True,
    )


@registry.reg("cuda.gemm_rrr_bias.func_call")
def gen_function_call(func_attrs, indent="  "):
    bias = func_attrs["inputs"][2]
    return common.gen_function_call(
        func_attrs, indent, bias_ptr_arg=bias._attrs["name"]
    )


@registry.reg("cuda.gemm_rrr_bias.filter")
def function_filter(cfg, func_attrs, ab_alignment):
"""Generates function filter.
Parameters
----------
cfg: str
The filename generated for profiler.
func_attrs : Dict
Stores the operation attributes.
ab_alignment:
Input alignments.
Returns
-------
bool
If input cfg should be filtered.
"""
return common.function_filter(cfg, func_attrs, ab_alignment)
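For reference, the bias is handed to the kernel as the epilogue's C operand with ldc = 0 and batch_stride_C = N, so every output row reads the same N bias values. Numerically the op matches the following PyTorch sketch, which mirrors the check in the unit test below (shapes are illustrative):

import torch

M, K, N = 128, 256, 64
X = torch.randn(M, K)   # row-major A[M, K]
W = torch.randn(K, N)   # row-major B[K, N], already in the layout gemm_rrr expects
bias = torch.randn(N)

Y_ref = X @ W + bias    # bias broadcast across the M output rows
# Equivalent to torch.nn.functional.linear(X, W.t(), bias), which is what the unit test uses.
assert Y_ref.shape == (M, N)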
67 changes: 67 additions & 0 deletions tests/unittest/ops/test_gemm_bias.py
@@ -149,6 +149,73 @@ def test_rcr_sm90(self) -> None:
            dtype="bfloat16",
        )

    def _test_rrr(self, Ms, N, K, test_name, dtype="float16"):
        target = detect_target()
        tolerance_limits = _TOLERANCE_LIMITS[dtype]
        MDim = shape_utils.gen_int_var_min_max(Ms, name="m")
        X = Tensor(shape=[MDim, IntImm(K)], dtype=dtype, name="input_0", is_input=True)
        W = Tensor(
            shape=[IntImm(K), IntImm(N)], dtype=dtype, name="input_1", is_input=True
        )
        B = Tensor(shape=[IntImm(N)], dtype=dtype, name="input_2", is_input=True)
        OP = ops.gemm_rrr_bias()
        Y = OP(X, W, B)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(
            Y, target, "./tmp", f"gemm_rrr_bias_{test_name}_{self._test_id}"
        )
        self._test_id += 1

        for M in Ms:
            X_pt = get_random_torch_tensor([M, K], dtype)
            W_pt = get_random_torch_tensor([N, K], dtype)
            B_pt = get_random_torch_tensor([N], dtype)
            Y_pt = torch.nn.functional.linear(X_pt, W_pt, bias=B_pt)

            W_transpose_pt = torch.transpose(W_pt, 0, 1).contiguous()
            y = get_torch_empty_tensor([M, N], dtype)
            module.run_with_tensors(
                {"input_0": X_pt, "input_1": W_transpose_pt, "input_2": B_pt},
                [y],
            )
            if X_pt.nelement() == 0 or W_pt.nelement() == 0:
                pass
            else:
                torch.testing.assert_close(Y_pt, y, **tolerance_limits)

    def test_rrr_zero_size(self):
        target = detect_target()
        # This test triggered a c10 assertion failure internally
        # caffe2/c10/util/SmallVector.h:338:
        # Assertion `idx < size()' failed
        if type(target).__name__ != "FBCUDA":
            self._test_rrr([2], N=64, K=0, test_name="zero_k")
        self._test_rrr([2], N=0, K=4, test_name="zero_n")
        self._test_rrr([0], N=4, K=4, test_name="zero_m")

    def test_rrr_static(self):
        self._test_rrr([4096], N=4, K=4, test_name="static")
        self._test_rrr([1000], N=81, K=1024, test_name="static")
        self._test_rrr([67200], N=3, K=256, test_name="static")

    def test_rrr_static_rocm(self):
        self._test_rrr([4096], N=4, K=4, test_name="static")
        self._test_rrr([1000], N=81, K=1024, test_name="static")
        self._test_rrr([67200], N=3, K=256, test_name="static")

    def test_rrr_bfloat16_bf16(self):
        dtype = "bfloat16"
        self._test_rrr([4], N=2, K=11, test_name=f"static_{dtype}", dtype=dtype)
        self._test_rrr([128], N=64, K=1024, test_name=f"static_{dtype}", dtype=dtype)
        self._test_rrr(
            [1, 7, 64, 127],
            N=64,
            K=1024,
            test_name=f"dynamic_m_{dtype}",
            dtype=dtype,
        )


filter_test_cases_by_test_env(GEMMBiasTestCase)
