diff --git a/python/aitemplate/frontend/nn/attention.py b/python/aitemplate/frontend/nn/attention.py index 1f1240762..7c9f36279 100644 --- a/python/aitemplate/frontend/nn/attention.py +++ b/python/aitemplate/frontend/nn/attention.py @@ -16,6 +16,7 @@ Frontend for attention module """ from aitemplate.compiler import ops +from aitemplate.compiler.base import IntVar from aitemplate.compiler.ops import flash_attention from aitemplate.compiler.ops.common.epilogue import FuncEnum from aitemplate.frontend import Tensor @@ -256,7 +257,7 @@ def attention(self, x): ) return out - def forward(self, *args): + def forward(self, *args, seqlens=None): """forward pass for calling mha module""" assert len(args) >= 1 x = args[0] @@ -361,7 +362,7 @@ def __init__( ) self.proj_drop = Dropout(proj_drop, dtype=dtype) - def attention(self, q, k, v): + def attention(self, q, k, v, seqlens=None): batch = q.shape()[0] head_dim = self.dim // self.num_heads @@ -369,24 +370,47 @@ def attention(self, q, k, v): key = self.proj_k(k) value = self.proj_v(v) - query = ops.permute()( - ops.reshape()(query, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3] - ) - key = ops.permute()( - ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3] - ) - value = ops.permute()( - ops.reshape()(value, [batch, -1, self.num_heads, head_dim]), - [0, 2, 1, 3], + if detect_target().name() == "cuda": + query = ops.permute()( + ops.reshape()(query, [batch, -1, self.num_heads, head_dim]), + [0, 2, 1, 3], + ) + key = ops.permute()( + ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3] + ) + value = ops.permute()( + ops.reshape()(value, [batch, -1, self.num_heads, head_dim]), + [0, 2, 1, 3], + ) + return self.op(query, key, value) + elif seqlens: + query = ops.reshape()(query, [batch, self.num_heads, -1, head_dim]) + key = ops.reshape()(key, [batch, self.num_heads, -1, head_dim]) + value = ops.reshape()(value, [batch, self.num_heads, -1, head_dim]) + return self.op(query, key, value, seqlens) + + query = ops.reshape()(query, [batch, -1, self.num_heads, head_dim]) + query = ops.transpose()(query, 1, 2) + query = ops.reshape()(query, [-1, query.shape()[2], head_dim]) + key = ops.reshape()(key, [batch, -1, self.num_heads, head_dim]) + key = ops.transpose()(key, 1, 2) + key = ops.reshape()(key, [-1, key.shape()[2], head_dim]) + value = ops.reshape()(value, [batch, -1, self.num_heads, head_dim]) + value = ops.transpose()(value, 1, 2) + value = ops.reshape()(value, [-1, value.shape()[2], head_dim]) + OP = ops.bmm_softmax_bmm_permute( + shape=(self.num_heads,), + scale=head_dim**-0.5, + causal=self.causal, ) - return self.op(query, key, value) + return OP(query, key, value) - def forward(self, *args): + def forward(self, *args, seqlens=None): """forward pass for calling mha module""" assert len(args) >= 3 x = args[0] batch = x.shape()[0] - attn_output = self.attention(args[0], args[1], args[2]) + attn_output = self.attention(args[0], args[1], args[2], seqlens=seqlens) attn_output = ops.reshape()(attn_output, [batch, -1, self.dim]) if self.has_residual: @@ -395,7 +419,8 @@ def forward(self, *args): else: x = self.proj(attn_output) x = self.proj_drop(x) - x = ops.reshape()(x, [batch, -1, self.dim]) + if not isinstance(batch, IntVar): + x = ops.reshape()(x, [batch, -1, self.dim]) return x diff --git a/python/aitemplate/frontend/nn/embedding.py b/python/aitemplate/frontend/nn/embedding.py index f5144eca1..fc0b29b55 100644 --- a/python/aitemplate/frontend/nn/embedding.py +++ b/python/aitemplate/frontend/nn/embedding.py 
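Illustrative sketch, not part of the patch: the attention.py hunks above add an optional seqlens argument and a generic fallback that lowers attention to bmm_softmax_bmm_permute when the target is not CUDA. A minimal stand-alone version of that fallback is below; the wrapper name and shape variables are assumptions for illustration, while the ops calls and the head_dim**-0.5 scale are taken from the diff.

from aitemplate.compiler import ops

def generic_attention_sketch(q, k, v, batch, num_heads, head_dim, causal=False):
    # [batch, seq, num_heads * head_dim] -> [batch * num_heads, seq, head_dim],
    # mirroring the reshape/transpose/reshape sequence in the new non-CUDA branch.
    def split_heads(t):
        t = ops.reshape()(t, [batch, -1, num_heads, head_dim])
        t = ops.transpose()(t, 1, 2)
        return ops.reshape()(t, [-1, t.shape()[2], head_dim])

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    attn = ops.bmm_softmax_bmm_permute(
        shape=(num_heads,),      # re-introduces the head dimension on the output
        scale=head_dim ** -0.5,  # same softmax scaling as in the diff
        causal=causal,
    )
    # The caller then reshapes the result back to [batch, -1, dim], as forward() does above.
    return attn(q, k, v)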
@@ -13,12 +13,10 @@ # limitations under the License. # from aitemplate.compiler import ops -from aitemplate.compiler.public import FuncEnum from aitemplate.frontend.nn.dropout import Dropout from aitemplate.frontend.nn.layer_norm import LayerNorm from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter -from aitemplate.testing import detect_target class Embedding(Module): @@ -61,8 +59,6 @@ def __init__( dtype="float16", ): super().__init__() - if BertEmbeddings.USE_CUDA is None: - BertEmbeddings.USE_CUDA = detect_target().name() == "cuda" assert ( hidden_dropout_prob == 0.0 ), "Dropout rate larger than 0 is not supported yet." @@ -85,48 +81,16 @@ def forward( token_type_ids, # [B, S] position_ids, # [B, S] ): - if self.USE_CUDA: - embeddings = ops.bert_embeddings()( - input_ids, - token_type_ids, - position_ids, - self.word_embeddings.weight.tensor(), - self.token_type_embeddings.weight.tensor(), - self.position_embeddings.weight.tensor(), - self.LayerNorm.weight.tensor(), - self.LayerNorm.bias.tensor(), - self.LayerNorm.eps, - ) - embeddings = self.dropout(embeddings) - return embeddings - - input_shape = ops.size()(input_ids) - - # [B * S] - input_ids = ops.reshape()(input_ids, [-1]) - token_type_ids = ops.reshape()(token_type_ids, [-1]) - position_ids = ops.reshape()(position_ids, [-1]) - - # [B * S, H] - input_embeddings = ops.batch_gather()(self.word_embeddings.tensor(), input_ids) - token_type_embeddings = ops.batch_gather()( - self.token_type_embeddings.tensor(), token_type_ids - ) - position_embeddings = ops.batch_gather()( - self.position_embeddings.tensor(), position_ids - ) - - # add - embeddings = ops.elementwise(FuncEnum.ADD)( - input_embeddings, token_type_embeddings + embeddings = ops.bert_embeddings()( + input_ids, + token_type_ids, + position_ids, + self.word_embeddings.weight.tensor(), + self.token_type_embeddings.weight.tensor(), + self.position_embeddings.weight.tensor(), + self.LayerNorm.weight.tensor(), + self.LayerNorm.bias.tensor(), + self.LayerNorm.eps, ) - - embeddings = ops.elementwise(FuncEnum.ADD)(embeddings, position_embeddings) - - # norm - embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) - - embeddings = ops.reshape()(embeddings, input_shape + [-1]) - return embeddings diff --git a/python/aitemplate/frontend/nn/module.py b/python/aitemplate/frontend/nn/module.py index c51a49db9..391d9d5d7 100644 --- a/python/aitemplate/frontend/nn/module.py +++ b/python/aitemplate/frontend/nn/module.py @@ -296,7 +296,6 @@ def get_submodule(self, target: str) -> "Module": mod: Module = self for item in atoms: - if not hasattr(mod, item): raise AttributeError( mod._get_name() + " has no " "attribute `" + item + "`" diff --git a/python/aitemplate/frontend/nn/multiscale_attention.py b/python/aitemplate/frontend/nn/multiscale_attention.py index 848cb1637..53fe02300 100644 --- a/python/aitemplate/frontend/nn/multiscale_attention.py +++ b/python/aitemplate/frontend/nn/multiscale_attention.py @@ -375,7 +375,6 @@ def __init__( ## TODO: add pool mode support for {"max", "avg"} elif pool_mode == "conv": - self.pool_q = ( Conv3d( head_dim, diff --git a/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py b/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py index 931651b99..4c46deeb2 100644 --- a/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py +++ b/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py @@ -266,7 +266,6 @@ def accumulator_type(self): return library.DataType.f32 def emit(self) -> str: - 
template = jinja2.Template( """ using {{name}} = {{xdl_op_type}}< @@ -285,7 +284,7 @@ def emit(self) -> str: {{WeiLayout}}, // WeiLayout {% if func=="PT" %} ck::Tuple<>, - {% elif func=="AAR" %} + {% elif func in ["AA", "AAR"] %} ck::Tuple<{{OutLayout}}, {{OutLayout}}>, // BiasLayout {% else %} {% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in xdl_op_type %} @@ -301,7 +300,7 @@ def emit(self) -> str: {{CShuffleDType}}, // CShuffleDataType {% if func=="PT" %} ck::Tuple<>, - {% elif func=="AAR" %} + {% elif func in ["AA", "AAR"] %} ck::Tuple<{{CDType}}, {{CDType}}>, // BiasLayout {% else %} ck::Tuple<{{CDType}}>, // BiasDataType diff --git a/python/aitemplate/utils/mk_ck_lib/gemm_operation.py b/python/aitemplate/utils/mk_ck_lib/gemm_operation.py index 28b44f308..dc1557a5b 100644 --- a/python/aitemplate/utils/mk_ck_lib/gemm_operation.py +++ b/python/aitemplate/utils/mk_ck_lib/gemm_operation.py @@ -51,6 +51,7 @@ class XdlOpType(enum.Enum): DeviceBatchedContractionMultipleD_Xdl_CShuffle = auto() DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle = auto() DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle = auto() + DeviceBatchedGemmMultiD_Xdl = auto() XdlOpTag = { @@ -62,6 +63,7 @@ class XdlOpType(enum.Enum): XdlOpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Xdl_CShuffle", XdlOpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle", XdlOpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle", + XdlOpType.DeviceBatchedGemmMultiD_Xdl: "ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl", } @@ -247,7 +249,11 @@ def __str__(self) -> str: _{{n_xdl_per_wave}} {{m_n_block_wave_per_xdl|join('_')}}S {{scalar_per_vector}} - {{causal_mask}} + {% if causal_mask == 1 %} + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask + {% else %} + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask + {% endif %} """, trim_blocks=True, lstrip_blocks=True, @@ -264,7 +270,11 @@ def emit(self) -> str: {{n_xdl_per_wave}}, // n_xdl_per_wave ck::Sequence<{{m_n_block_wave_per_xdl|join(',')}}>, // m_n_block_wave_per_xdl {{scalar_per_vector}}, // scalar_per_vector - {{causal_mask}} // causal_mask + {% if causal_mask == 1 %} + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask + {% else %} + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask + {% endif %} """, trim_blocks=True, lstrip_blocks=True, @@ -392,21 +402,38 @@ def emit(self) -> str: ck::Tuple, {% endif %} ck::half_t, -{% elif xdl_op_type_value in [7, 8] %} +{% elif xdl_op_type_value == 7 %} {{ALayout}}, {{BLayout}}, {{CLayout}}, - {% if xdl_op_type_value == 8 %} - ck::Sequence<2,1,1>, - {% else %} {{CLayout}}, - {% endif %} {{ADType}}, {{BDType}}, {{BDType}}, {{CDType}}, {{AccDType}}, float, // CShuffleDType, +{% elif xdl_op_type_value == 8 %} + 2, 1, 1, 1, 1, + {{ADType}}, + {{BDType}}, + {{BDType}}, + {{CDType}}, + ck::Tuple<>, + ck::Tuple<>, + {{AccDType}}, + float, // CShuffleDType, +{% elif xdl_op_type_value == 9 %} + {{ALayout}}, + {{BLayout}}, + ck::Tuple<{{DsLayout}}>, // DsLayout + {{CLayout}}, + {{ADType}}, + {{BDType}}, + {{AccDType}}, + {{CShuffleDType}}, + ck::Tuple<{{DsDType}}>, // DsType + {{EDType}}, {% endif %} {% if xdl_op_type_value in [7, 8] %} {{A_elem_op}}, @@ -423,6 +450,11 @@ def emit(self) 
-> str: ck::tensor_operation::device::TensorSpecialization::Packed, ck::tensor_operation::device::TensorSpecialization::Packed, ck::tensor_operation::device::TensorSpecialization::Default, + {% elif xdl_op_type_value==8 %} + ck::tensor_operation::device::TensorSpecialization::Default, + ck::tensor_operation::device::TensorSpecialization::Default, + ck::tensor_operation::device::TensorSpecialization::Default, + ck::tensor_operation::device::TensorSpecialization::Default, {% endif %} 1, {% endif %} diff --git a/python/aitemplate/utils/mk_ck_lib/generator.py b/python/aitemplate/utils/mk_ck_lib/generator.py index 91b44ea7d..e8f89f666 100644 --- a/python/aitemplate/utils/mk_ck_lib/generator.py +++ b/python/aitemplate/utils/mk_ck_lib/generator.py @@ -24,6 +24,7 @@ softmax_operation as softmax, ) + ########################################################################################################### # Convolution for 2D Fwd operations def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_op=""): @@ -1390,7 +1391,6 @@ def CreateBmmSoftmaxBmmOperator( ] c_block_descriptions, b1_block_descriptions = [], [] for i in range(len(tile_descriptions)): - if i in [0, 2, 4, 5, 9, 11]: block_transfer = [16, 16, 1] else: @@ -1505,7 +1505,6 @@ def CreateBmmSoftmaxBmmPermOperator( c_block_descriptions, b1_block_descriptions = [], [] for i in range(len(tile_descriptions)): - if i in [0, 2, 4, 5, 9, 11]: block_transfer = [16, 16, 1] else: @@ -1667,6 +1666,354 @@ def CreateBmmRRROperator(manifest): return operations +def CreateBmmRRRBillinearOperator(manifest, c_element_op): + operation_kind = library.GemmKind.BatchGemm + a_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.RowMajor + ) + b_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.RowMajor + ) + c_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.RowMajor + ) + element_op = library.TensorOperation.PassThrough + # 0 indicates not print + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 8, 2, 32, 32, 4, 2), + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 2, 32, 32, 2, 4), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 8, 2, 32, 32, 4, 2), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 2, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 2, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 2, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 8, 2, 32, 32, 2, 1), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 8, 2, 32, 32, 1, 2), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + ] + + a_block_descriptions = [] + c_block_descriptions = [] + for t in tile_descriptions: + a_block_transfer = -1 + c_block_transfer = -1 + if t.block_size == 256: + a_block_transfer = [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + if t.block_size == 128 and t.n_per_block != 64: + a_block_transfer = [4, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + if t.block_size == 128 and t.n_per_block == 64: + a_block_transfer = [4, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + + assert ( + a_block_transfer != -1 + and 
c_block_transfer != -1 + and "Cannot determine block_transfer_size with block_size " + + str(t.block_size) + ) + a_block_descriptions.append( + gemm.BlockTransferDesc( + a_block_transfer, [1, 0, 2], [1, 0, 2], 2, 8, 8, 1, True + ) + ) + c_block_descriptions.append(c_block_transfer) + b_block_descriptions = [ + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([8, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([16, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 1, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + ] + gemm_specialization = [ + gemm.GemmSpecialization.GemmDefault, + gemm.GemmSpecialization.MNKPadding, + ] + operations = [] + ds_dtype = [library.DataType.f16] + ds_layout = [library.LayoutType.RowMajor] + e_dtype = library.DataType.f16 + for gemm_spec in gemm_specialization: + for tile_desc, a_block_desc, b_block_desc, c_block_desc in zip( + tile_descriptions, + a_block_descriptions, + b_block_descriptions, + c_block_descriptions, + ): + new_operation = gemm.GemmOperation( + operation_kind=operation_kind, + extra_kind=c_element_op, + xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmMultiD_Xdl, + A=a_element_desc, + B=b_element_desc, + C=c_element_desc, + a_elem_op=element_op, + b_elem_op=element_op, + epilogue_functor=c_element_op, + gemm_specialization=gemm_spec, + tile_desc=tile_desc, + a_block_transfer=a_block_desc, + b_block_transfer=b_block_desc, + c_block_transfer=c_block_desc, + ds_dtype=ds_dtype, + ds_layout=ds_layout, + e_dtype=e_dtype, + ) + manifest.append(new_operation) + operations.append(new_operation) + return operations + + +def CreateBmmCCRBillinearOperator(manifest, c_element_op): + operation_kind = library.GemmKind.BatchGemm + a_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.ColumnMajor + ) + b_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.ColumnMajor + ) + c_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.RowMajor + ) + element_op = library.TensorOperation.PassThrough + # 0 indicates not print + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 2, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 2, 8, 32, 32, 2, 4), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 2, 8, 32, 32, 4, 2), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 2, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 
2, 2), + gemm.TileDesc(128, 128, 64, 32, 2, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 2, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 2, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 2, 8, 32, 32, 1, 2), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + ] + + b_block_descriptions = [] + c_block_descriptions = [] + for t in tile_descriptions: + b_block_transfer = -1 + c_block_transfer = -1 + if t.block_size == 256: + b_block_transfer = [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + if t.block_size == 128 and t.n_per_block != 64: + b_block_transfer = [4, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + if t.block_size == 128 and t.n_per_block == 64: + b_block_transfer = [4, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + + assert ( + b_block_transfer != -1 + and c_block_transfer != -1 + and "Cannot determine block_transfer_size with block_size " + + str(t.block_size) + ) + b_block_descriptions.append( + gemm.BlockTransferDesc( + b_block_transfer, [1, 0, 2], [1, 0, 2], 2, 8, 8, 1, True + ) + ) + c_block_descriptions.append(c_block_transfer) + a_block_descriptions = [ + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([16, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 1, 8, 1, True), + ] + gemm_specialization = [ + gemm.GemmSpecialization.GemmDefault, + gemm.GemmSpecialization.MNKPadding, + ] + operations = [] + ds_dtype = [library.DataType.f16] + ds_layout = [library.LayoutType.RowMajor] + e_dtype = library.DataType.f16 + for gemm_spec in gemm_specialization: + for tile_desc, a_block_desc, b_block_desc, c_block_desc in zip( + tile_descriptions, + a_block_descriptions, + b_block_descriptions, + c_block_descriptions, + ): + new_operation = gemm.GemmOperation( + operation_kind=operation_kind, + extra_kind=c_element_op, + xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmMultiD_Xdl, + A=a_element_desc, + B=b_element_desc, + C=c_element_desc, + a_elem_op=element_op, + b_elem_op=element_op, + epilogue_functor=c_element_op, + gemm_specialization=gemm_spec, + tile_desc=tile_desc, + a_block_transfer=a_block_desc, + b_block_transfer=b_block_desc, + c_block_transfer=c_block_desc, + ds_dtype=ds_dtype, + 
ds_layout=ds_layout, + e_dtype=e_dtype, + ) + manifest.append(new_operation) + operations.append(new_operation) + return operations + + +def CreateBmmCRRBillinearOperator(manifest, c_element_op): + operation_kind = library.GemmKind.BatchGemm + a_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.ColumnMajor + ) + b_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.RowMajor + ) + c_element_desc = library.TensorDesc( + library.DataType.f16, library.LayoutType.RowMajor + ) + element_op = library.TensorOperation.PassThrough + # 0 indicates not print + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 2, 2, 32, 32, 4, 2), + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 2, 2, 32, 32, 2, 4), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 2, 2, 32, 32, 4, 2), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 2, 2, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 2, 2, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 2, 2, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 2, 2, 32, 32, 2, 1), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 2, 2, 32, 32, 1, 2), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + ] + + b_block_descriptions = [] + c_block_descriptions = [] + for t in tile_descriptions: + b_block_transfer = -1 + c_block_transfer = -1 + if t.block_size == 256: + b_block_transfer = [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + if t.block_size == 128 and t.n_per_block != 64: + b_block_transfer = [4, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + if t.block_size == 128 and t.n_per_block == 64: + b_block_transfer = [4, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + + assert ( + b_block_transfer != -1 + and c_block_transfer != -1 + and "Cannot determine block_transfer_size with block_size " + + str(t.block_size) + ) + b_block_descriptions.append( + gemm.BlockTransferDesc( + b_block_transfer, [1, 0, 2], [1, 0, 2], 2, 8, 8, 1, True + ) + ) + c_block_descriptions.append(c_block_transfer) + a_block_descriptions = [ + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1, True), + gemm.BlockTransferDesc([8, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1, True), + 
gemm.BlockTransferDesc([16, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0, True), + gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 1, 8, 1, True), + ] + gemm_specialization = [ + gemm.GemmSpecialization.GemmDefault, + gemm.GemmSpecialization.MNKPadding, + ] + operations = [] + ds_dtype = [library.DataType.f16] + ds_layout = [library.LayoutType.RowMajor] + e_dtype = library.DataType.f16 + for gemm_spec in gemm_specialization: + for tile_desc, a_block_desc, b_block_desc, c_block_desc in zip( + tile_descriptions, + a_block_descriptions, + b_block_descriptions, + c_block_descriptions, + ): + new_operation = gemm.GemmOperation( + operation_kind=operation_kind, + extra_kind=c_element_op, + xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmMultiD_Xdl, + A=a_element_desc, + B=b_element_desc, + C=c_element_desc, + a_elem_op=element_op, + b_elem_op=element_op, + epilogue_functor=c_element_op, + gemm_specialization=gemm_spec, + tile_desc=tile_desc, + a_block_transfer=a_block_desc, + b_block_transfer=b_block_desc, + c_block_transfer=c_block_desc, + ds_dtype=ds_dtype, + ds_layout=ds_layout, + e_dtype=e_dtype, + ) + manifest.append(new_operation) + operations.append(new_operation) + return operations + + def CreateBmmRRRPermOperator(manifest): operation_kind = library.GemmKind.BatchGemmPermute a_element_desc = library.TensorDesc( @@ -1980,6 +2327,8 @@ def CreateLayerNormOperator(manifest, rank=2): layernorm.TileDesc(256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8), layernorm.TileDesc(256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8), layernorm.TileDesc(256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8), + layernorm.TileDesc(1024, 1, 1024, 1, 32, 1, 8, 1, 8, 1, 8, 8), + layernorm.TileDesc(1024, 1, 1024, 1, 8, 1, 2, 1, 2, 1, 2, 2), ] operations = [] @@ -2056,6 +2405,12 @@ def GenerateTensorOp(manifest): library.TensorOperation.AddRelu, library.MemoryDataOperation.MemorySet, ) + # Conv2dBiasAdd + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasRelu, + library.TensorOperation.AddAdd, + ) # Conv2dBiasReluAdd CreateConv2dFwdOperator( manifest, @@ -2101,8 +2456,10 @@ def GenerateTensorOp(manifest): CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddTanh) # GemmRCRBiasTanh CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddFastGelu) - # GemmRCRBiasSwish + # GemmRCRBiasHardswish CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddHardswish) + # GemmRCRBiasSwish + CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddSwish) # GemmRCRBiasSigmoid CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddSigmoid) # GemmRCRBiasAdd @@ -2127,6 +2484,12 @@ def GenerateTensorOp(manifest): CreateBmmRCROperator(manifest) # BmmRRR CreateBmmRRROperator(manifest) + # BmmRRRAdd + CreateBmmRRRBillinearOperator(manifest, library.TensorOperation.Add) + # BmmCRRAdd + CreateBmmCRRBillinearOperator(manifest, library.TensorOperation.Add) + # BmmCRRAdd + CreateBmmCCRBillinearOperator(manifest, library.TensorOperation.Add) # BmmCCR CreateBmmCCROperator(manifest) # BmmCRR diff --git a/python/aitemplate/utils/mk_ck_lib/groupnorm_operation.py b/python/aitemplate/utils/mk_ck_lib/groupnorm_operation.py index e61fa7ef9..969efc6ed 100644 --- a/python/aitemplate/utils/mk_ck_lib/groupnorm_operation.py +++ b/python/aitemplate/utils/mk_ck_lib/groupnorm_operation.py @@ -78,7 +78,7 @@ def accumulator_type(self): def emit(self) -> str: template = jinja2.Template( """ -using {{name}} = ck::tensor_operation::device::DeviceLayernormImpl< +using {{name}} = 
ck::tensor_operation::device::DeviceNormalizationImpl< {{InDType}}, {{InDType}}, {{InDType}}, diff --git a/python/aitemplate/utils/mk_ck_lib/layernorm_operation.py b/python/aitemplate/utils/mk_ck_lib/layernorm_operation.py index 6e28da94f..264cba714 100644 --- a/python/aitemplate/utils/mk_ck_lib/layernorm_operation.py +++ b/python/aitemplate/utils/mk_ck_lib/layernorm_operation.py @@ -78,7 +78,7 @@ def accumulator_type(self): def emit(self) -> str: template = jinja2.Template( """ -using {{name}} = ck::tensor_operation::device::DeviceLayernormImpl< +using {{name}} = ck::tensor_operation::device::DeviceNormalizationImpl< {{InDType}}, {{InDType}}, {{InDType}}, @@ -94,7 +94,7 @@ def emit(self) -> str: return template.render( name=self.__str__(), InDType=library.DataTypeTag[self.In], - AccDType=library.DataTypeTag[library.DataType.f32], + AccDType=library.DataTypeTag[self.accumulator_type()], OutDType=library.DataTypeTag[self.Out], Rank=self.Rank, NumReduceDim=self.NumReduceDim, # we only need softmax(dim=-1) at this moment diff --git a/python/aitemplate/utils/mk_ck_lib/library.py b/python/aitemplate/utils/mk_ck_lib/library.py index a3fdb1c00..4b6a357b9 100644 --- a/python/aitemplate/utils/mk_ck_lib/library.py +++ b/python/aitemplate/utils/mk_ck_lib/library.py @@ -201,6 +201,7 @@ class LayoutType(enum.Enum): LayoutType.GNWK: "GNWK", } + # class OperationKind(enum.Enum): Gemm = auto() @@ -282,6 +283,7 @@ class TensorOperation(enum.Enum): AddFastGelu = auto() AddTanh = auto() AddHardswish = auto() + AddSwish = auto() AddSigmoid = auto() AddReluAdd = auto() AddAddRelu = auto() @@ -312,6 +314,7 @@ class TensorOperation(enum.Enum): TensorOperation.AddTanh: "ck::tensor_operation::element_wise::AddTanh", TensorOperation.AddSigmoid: "ck::tensor_operation::element_wise::AddSigmoid", TensorOperation.AddHardswish: "ck::tensor_operation::element_wise::AddHardswish", + TensorOperation.AddSwish: "ck::tensor_operation::element_wise::AddSwish", TensorOperation.AddReluAdd: "ck::tensor_operation::element_wise::AddReluAdd", TensorOperation.AddAddRelu: "ck::tensor_operation::element_wise::AddAddRelu", TensorOperation.AddHardswishAdd: "ck::tensor_operation::element_wise::AddHardswishAdd", @@ -341,6 +344,7 @@ class TensorOperation(enum.Enum): TensorOperation.AddTanh: "AT", TensorOperation.AddSigmoid: "AS", TensorOperation.AddHardswish: "AH", + TensorOperation.AddSwish: "ASW", TensorOperation.AddReluAdd: "ARA", TensorOperation.AddAddRelu: "AAR", TensorOperation.AddHardswishAdd: "AHA", diff --git a/python/aitemplate/utils/mk_ck_lib/manifest.py b/python/aitemplate/utils/mk_ck_lib/manifest.py index 077ee9103..c572737d8 100644 --- a/python/aitemplate/utils/mk_ck_lib/manifest.py +++ b/python/aitemplate/utils/mk_ck_lib/manifest.py @@ -87,7 +87,6 @@ def get_kernel_filters(self, kernelListFile): return [] def filter_out_kernels(self, kernel_name, kernel_filter_list): - for kernel_filter_re in kernel_filter_list: if kernel_filter_re.search(kernel_name) is not None: return True
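Illustrative sketch, not part of the patch: the causal_mask change in gemm_operation.py earlier in this diff replaces the raw integer in the emitted CK template with a MaskingSpecialization enumerator. The snippet below renders just that fragment to show the two possible outputs; the template text is copied from the diff, and the surrounding script is an assumption for demonstration only.

import jinja2

causal_mask_fragment = jinja2.Template(
    """\
{% if causal_mask == 1 %}
    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask
{% else %}
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask
{% endif %}""",
    trim_blocks=True,
    lstrip_blocks=True,
)

# causal attention -> upper triangle masked out; otherwise masking disabled
print(causal_mask_fragment.render(causal_mask=1))
print(causal_mask_fragment.render(causal_mask=0))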