-
Notifications
You must be signed in to change notification settings - Fork 363
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #796 Reviewed By: wushirong Differential Revision: D47223564 Pulled By: aakhundov fbshipit-source-id: cb7b41846ac6eca95f9fc164828a2b13eb3325b4
- Loading branch information
1 parent
d947681
commit a706541
Showing
4 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
205 changes: 205 additions & 0 deletions
205
python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
""" | ||
GEMM Specialization for | ||
C = GeMM(A, B) + bias | ||
where A[RowMajor][M, K], B[RowMajor][K, N], bias[RowMajor][N] | ||
""" | ||
import jinja2 | ||
|
||
from aitemplate.backend import registry | ||
|
||
from aitemplate.backend.backend_spec import CUDASpec | ||
from aitemplate.backend.cuda.gemm_universal import common, common_bias, gemm_rrr | ||
from aitemplate.backend.cuda.gemm_universal.layout import RRR | ||
|
||
# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 | ||
|
||
|
||
# C++ ``using`` aliases for the element types. Rendered with the lib type
# names of the first input and first output dtype, then handed to the code
# generators through their ``extra_code`` argument.
EXTRA_CODE = jinja2.Template(
    """
using elem_input_type = {{elem_input_type}};
using elem_output_type = {{elem_output_type}};
"""
)
|
||
|
||
# Used for real execution: rendered in gen_function and passed on as
# ``problem_args``. A/B pointers and strides come from the input_a_* /
# input_b_* variables, and the output pointer is shifted by output_offset.
# The bias pointer sits in the ptr_C slot with ldc = 0 and batch_stride_C = N
# (presumably so the length-N bias row is re-read for every output row --
# TODO confirm against the CUTLASS GemmUniversal argument semantics).
PROBLEM_ARGS_TEMPLATE = jinja2.Template(
    """
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
cutlass::gemm::GemmCoord{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K)
}, // GemmCoord problem_size
split_k, // int batch_count
{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, // typename EpilogueOutputOp::Params epilogue
({{elem_input_type}}*)(a_ptr) + input_a_offset, // void const * ptr_A
({{elem_input_type}}*)(b_ptr) + input_b_offset, // void const * ptr_B
({{elem_input_type}}*)(bias_ptr), // void const * ptr_C
({{elem_output_type}}*)(c_ptr) + output_offset, // void * ptr_D
input_a_batch_stride, // int64_t batch_stride_A
input_b_batch_stride, // int64_t batch_stride_B
N, // int64_t batch_stride_C
M * N, // int64_t batch_stride_D
input_a_stride, // typename LayoutA::Stride::LongIndex lda
input_b_stride, // typename LayoutB::Stride::LongIndex ldb
0, // typename LayoutC::Stride::LongIndex ldc
output_stride, // typename LayoutC::Stride::LongIndex ldd
"""
)
|
||
|
||
# For the profiler: no need to include TensorAccessor, so the A/B pointers
# are used without offsets and their strides are written directly in terms of
# M, N and K (lda = K, ldb = N). The output still goes through output_offset /
# output_stride. This template is handed unrendered to
# gemm_rrr.common_gen_profiler as ``problem_args_template``.
PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template(
    """
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
cutlass::gemm::GemmCoord{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K)
}, // GemmCoord problem_size
split_k, // int batch_count
{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, // typename EpilogueOutputOp::Params epilogue
({{elem_input_type}}*)(a_ptr), // void const * ptr_A
({{elem_input_type}}*)(b_ptr), // void const * ptr_B
({{elem_input_type}}*)(bias_ptr), // void const * ptr_C
({{elem_output_type}}*)(c_ptr) + output_offset, // void * ptr_D
M * K, // int64_t batch_stride_A
K * N, // int64_t batch_stride_B
N, // int64_t batch_stride_C
M * N, // int64_t batch_stride_D
K, // typename LayoutA::Stride::LongIndex lda
N, // typename LayoutB::Stride::LongIndex ldb
0, // typename LayoutC::Stride::LongIndex ldc
output_stride, // typename LayoutC::Stride::LongIndex ldd
"""
)
|
||
|
||
@registry.reg("cuda.gemm_rrr_bias.config")
def gemm_rrr_config(func_attrs, dtype="float16"):
    """Configure the op for the ``RRR`` layout.

    Parameters
    ----------
    func_attrs : Dict
        Stores the operation attributes; forwarded to ``common.make_fproc``.
    dtype : str, optional
        Accepted for interface compatibility; not read by this function.
    """
    layout = RRR
    common.make_fproc(func_attrs, layout)
|
||
|
||
@registry.reg("cuda.gemm_rrr_bias.gen_profiler")
def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
    """Generate the profiler for the bias variant of the RRR GEMM.

    Resolves the lib type names of the first input and first output, renders
    the type-alias snippet from them, and delegates everything else to
    ``gemm_rrr.common_gen_profiler``.

    Parameters
    ----------
    func_attrs : Dict
        Stores the operation attributes.
    workdir, profiler_filename, dim_info_dict
        Forwarded unchanged to ``gemm_rrr.common_gen_profiler``.

    Returns
    -------
    Whatever ``gemm_rrr.common_gen_profiler`` returns.
    """
    to_lib_type = CUDASpec().dtype_to_lib_type
    # Input type is resolved before the output type, as the aliases are keyed
    # by these exact template variable names.
    type_names = {
        "elem_input_type": to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]),
        "elem_output_type": to_lib_type(func_attrs["outputs"][0]._attrs["dtype"]),
    }
    type_alias_code = EXTRA_CODE.render(**type_names)
    return gemm_rrr.common_gen_profiler(
        func_attrs=func_attrs,
        workdir=workdir,
        profiler_filename=profiler_filename,
        dim_info_dict=dim_info_dict,
        src_template=common_bias.SRC_TEMPLATE,
        problem_args_template=PROFILER_PROBLEM_ARGS_TEMPLATE,
        bias_ptr_arg="memory_pool->RequestTensorByIdx(3)",
        extra_code=type_alias_code,
    )
|
||
|
||
@registry.reg("cuda.gemm_rrr_bias.gen_function")
def gen_function(
    func_attrs,
    exec_cond_template,
    dim_info_dict,
):
    """Generate the function source for the bias variant of the RRR GEMM.

    Collects the ranks of the input, weight and output accessors, renders the
    problem arguments and the type-alias snippet for the op's element types,
    and delegates to ``common.gen_function`` with split-k support enabled.

    Parameters
    ----------
    func_attrs : Dict
        Stores the operation attributes.
    exec_cond_template, dim_info_dict
        Forwarded unchanged to ``common.gen_function``.

    Returns
    -------
    Whatever ``common.gen_function`` returns.
    """
    addr_calc_in = gemm_rrr.get_input_addr_calculator(func_attrs)

    # Ranks are taken from the accessors' original shapes.
    ndims = {
        "input_ndims": len(func_attrs["input_accessors"][0].original_shapes),
        "weight_ndims": len(func_attrs["input_accessors"][1].original_shapes),
        "output_ndims": len(func_attrs["output_accessors"][0].original_shapes),
    }

    to_lib_type = CUDASpec().dtype_to_lib_type
    type_names = {
        "elem_input_type": to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]),
        "elem_output_type": to_lib_type(func_attrs["outputs"][0]._attrs["dtype"]),
    }
    rendered_problem_args = PROBLEM_ARGS_TEMPLATE.render(**type_names)
    type_alias_code = EXTRA_CODE.render(**type_names)

    addr_calc_out = common.OUTPUT_ADDR_CALCULATOR.render(
        stride_dim="N", output_accessor=func_attrs["output_accessors"][0]
    )
    return common.gen_function(
        func_attrs=func_attrs,
        src_template=common_bias.SRC_TEMPLATE,
        exec_cond_template=exec_cond_template,
        problem_args=rendered_problem_args,
        dim_info_dict=dim_info_dict,
        support_split_k=True,
        input_addr_calculator=addr_calc_in,
        output_addr_calculator=addr_calc_out,
        extra_code=type_alias_code,
        **ndims,
    )
|
||
|
||
@registry.reg("cuda.gemm_rrr_bias.func_decl")
def gen_function_decl(func_attrs):
    """Render the function declaration for the bias variant of the RRR GEMM.

    Parameters
    ----------
    func_attrs : Dict
        Stores the operation attributes; its ``name`` and the ranks of the
        first two input accessors are passed to the declaration template.

    Returns
    -------
    The rendered ``common_bias.FUNC_DECL_TEMPLATE``.
    """
    decl_kwargs = {
        "func_name": func_attrs["name"],
        "input_ndims": len(func_attrs["input_accessors"][0].original_shapes),
        "weight_ndims": len(func_attrs["input_accessors"][1].original_shapes),
        "support_split_k": True,
    }
    return common_bias.FUNC_DECL_TEMPLATE.render(**decl_kwargs)
|
||
|
||
@registry.reg("cuda.gemm_rrr_bias.func_call")
def gen_function_call(func_attrs, indent=" "):
    """Render the call site for the bias variant of the RRR GEMM.

    Parameters
    ----------
    func_attrs : Dict
        Stores the operation attributes; the third input is the bias tensor.
    indent : str, optional
        Indentation string forwarded to ``common.gen_function_call``.

    Returns
    -------
    Whatever ``common.gen_function_call`` returns.
    """
    # The bias is passed to the generated call by its tensor name.
    bias_name = func_attrs["inputs"][2]._attrs["name"]
    return common.gen_function_call(func_attrs, indent, bias_ptr_arg=bias_name)
|
||
|
||
@registry.reg("cuda.gemm_rrr_bias.filter")
def function_filter(cfg, func_attrs, ab_alignment):
    """Decide whether a profiler config should be filtered.

    Thin wrapper that defers entirely to ``common.function_filter``.

    Parameters
    ----------
    cfg: str
        The filename generated for profiler.
    func_attrs : Dict
        Stores the operation attributes.
    ab_alignment:
        Input alignments.

    Returns
    -------
    bool
        If input cfg should be filtered.
    """
    should_filter = common.function_filter(cfg, func_attrs, ab_alignment)
    return should_filter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters