update frontend and mk_ck_lib (#777)
Summary: Pull Request resolved: #777

Reviewed By: chenyang78

Differential Revision: D47039764

Pulled By: ipiszy

fbshipit-source-id: 4a2fa9228272ed32544498b68af4f4d42c02a460
fsx950223 authored and facebook-github-bot committed Jun 30, 2023
1 parent 79d10cd commit 039bb9f
Showing 10 changed files with 424 additions and 65 deletions; only the first five files appear below.
56 changes: 10 additions & 46 deletions python/aitemplate/frontend/nn/embedding.py
@@ -13,12 +13,10 @@
 # limitations under the License.
 #
 from aitemplate.compiler import ops
-from aitemplate.compiler.public import FuncEnum
 from aitemplate.frontend.nn.dropout import Dropout
 from aitemplate.frontend.nn.layer_norm import LayerNorm
 from aitemplate.frontend.nn.module import Module
 from aitemplate.frontend.nn.parameter import Parameter
-from aitemplate.testing import detect_target


 class Embedding(Module):
@@ -61,8 +59,6 @@ def __init__(
         dtype="float16",
     ):
         super().__init__()
-        if BertEmbeddings.USE_CUDA is None:
-            BertEmbeddings.USE_CUDA = detect_target().name() == "cuda"
         assert (
             hidden_dropout_prob == 0.0
         ), "Dropout rate larger than 0 is not supported yet."
@@ -85,48 +81,16 @@ def forward(
         token_type_ids,  # [B, S]
         position_ids,  # [B, S]
     ):
-        if self.USE_CUDA:
-            embeddings = ops.bert_embeddings()(
-                input_ids,
-                token_type_ids,
-                position_ids,
-                self.word_embeddings.weight.tensor(),
-                self.token_type_embeddings.weight.tensor(),
-                self.position_embeddings.weight.tensor(),
-                self.LayerNorm.weight.tensor(),
-                self.LayerNorm.bias.tensor(),
-                self.LayerNorm.eps,
-            )
-            embeddings = self.dropout(embeddings)
-            return embeddings
-
-        input_shape = ops.size()(input_ids)
-
-        # [B * S]
-        input_ids = ops.reshape()(input_ids, [-1])
-        token_type_ids = ops.reshape()(token_type_ids, [-1])
-        position_ids = ops.reshape()(position_ids, [-1])
-
-        # [B * S, H]
-        input_embeddings = ops.batch_gather()(self.word_embeddings.tensor(), input_ids)
-        token_type_embeddings = ops.batch_gather()(
-            self.token_type_embeddings.tensor(), token_type_ids
-        )
-        position_embeddings = ops.batch_gather()(
-            self.position_embeddings.tensor(), position_ids
-        )
-
-        # add
-        embeddings = ops.elementwise(FuncEnum.ADD)(
-            input_embeddings, token_type_embeddings
+        embeddings = ops.bert_embeddings()(
+            input_ids,
+            token_type_ids,
+            position_ids,
+            self.word_embeddings.weight.tensor(),
+            self.token_type_embeddings.weight.tensor(),
+            self.position_embeddings.weight.tensor(),
+            self.LayerNorm.weight.tensor(),
+            self.LayerNorm.bias.tensor(),
+            self.LayerNorm.eps,
         )
-
-        embeddings = ops.elementwise(FuncEnum.ADD)(embeddings, position_embeddings)
-
-        # norm
-        embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
-
-        embeddings = ops.reshape()(embeddings, input_shape + [-1])
-
         return embeddings
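For reference, the fallback path deleted above spells out what the fused ops.bert_embeddings() kernel computes: gather rows from the word, token-type, and position embedding tables, add the three results, and apply layer norm (dropout is asserted to be 0.0 in __init__, so it is a no-op). A rough NumPy sketch of that computation, for illustration only — the function name, argument names, and eps default below are assumptions, not part of the AITemplate API:

import numpy as np

def bert_embeddings_reference(
    input_ids,        # [B, S] int indices into word_emb
    token_type_ids,   # [B, S] int indices into token_type_emb
    position_ids,     # [B, S] int indices into position_emb
    word_emb,         # [vocab_size, H]
    token_type_emb,   # [type_vocab_size, H]
    position_emb,     # [max_position, H]
    gamma,            # [H] LayerNorm weight
    beta,             # [H] LayerNorm bias
    eps=1e-12,        # assumed default; the real op receives self.LayerNorm.eps
):
    # Gather each embedding table and sum: [B, S, H]
    x = word_emb[input_ids] + token_type_emb[token_type_ids] + position_emb[position_ids]
    # Layer norm over the hidden dimension
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta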
1 change: 0 additions & 1 deletion python/aitemplate/frontend/nn/module.py
@@ -296,7 +296,6 @@ def get_submodule(self, target: str) -> "Module":
         mod: Module = self

         for item in atoms:
-
             if not hasattr(mod, item):
                 raise AttributeError(
                     mod._get_name() + " has no " "attribute `" + item + "`"
1 change: 0 additions & 1 deletion python/aitemplate/frontend/nn/multiscale_attention.py
@@ -375,7 +375,6 @@ def __init__(
         ## TODO: add pool mode support for {"max", "avg"}

         elif pool_mode == "conv":
-
             self.pool_q = (
                 Conv3d(
                     head_dim,
5 changes: 2 additions & 3 deletions python/aitemplate/utils/mk_ck_lib/conv2d_operation.py
@@ -266,7 +266,6 @@ def accumulator_type(self):
         return library.DataType.f32

     def emit(self) -> str:
-
         template = jinja2.Template(
             """
 using {{name}} = {{xdl_op_type}}<
@@ -285,7 +284,7 @@ def emit(self) -> str:
 {{WeiLayout}}, // WeiLayout
 {% if func=="PT" %}
 ck::Tuple<>,
-{% elif func=="AAR" %}
+{% elif func in ["AA", "AAR"] %}
 ck::Tuple<{{OutLayout}}, {{OutLayout}}>, // BiasLayout
 {% else %}
 {% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in xdl_op_type %}
@@ -301,7 +300,7 @@ def emit(self) -> str:
 {{CShuffleDType}}, // CShuffleDataType
 {% if func=="PT" %}
 ck::Tuple<>,
-{% elif func=="AAR" %}
+{% elif func in ["AA", "AAR"] %}
 ck::Tuple<{{CDType}}, {{CDType}}>, // BiasLayout
 {% else %}
 ck::Tuple<{{CDType}}>, // BiasDataType
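The conv2d_operation.py change above only widens a Jinja condition: profiles whose func is "AA" now render the same two-element ck::Tuple of extra bias/residual operands that "AAR" already received, instead of falling through to the single-operand else branch. A minimal standalone sketch of that selection logic — an assumed example, not the actual emit() template:

import jinja2

# Hypothetical stand-in for the BiasLayout / BiasDataType selection in emit().
template = jinja2.Template(
    """
{% if func == "PT" %}
ck::Tuple<>,
{% elif func in ["AA", "AAR"] %}
ck::Tuple<{{OutLayout}}, {{OutLayout}}>,
{% else %}
ck::Tuple<{{OutLayout}}>,
{% endif %}
""",
    trim_blocks=True,
    lstrip_blocks=True,
)

# After the change, func="AA" renders the two-element tuple, just like "AAR".
print(template.render(func="AA", OutLayout="NHWK"))  # "NHWK" is a placeholder layout name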
46 changes: 39 additions & 7 deletions python/aitemplate/utils/mk_ck_lib/gemm_operation.py
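In the hunks below, the xdl_op_type_value integers that the Jinja templates branch on are the auto()-assigned values of the XdlOpType members, so appending DeviceBatchedGemmMultiD_Xdl to the enum is what makes the new `== 9` branch in emit() reachable (7 and 8 presumably being the existing SoftmaxGemm and SoftmaxGemmPermute ops, judging by the branches already present). A tiny sketch of that relationship — the explicit numbers here are inferred, not taken from the full enum definition:

import enum

class XdlOpType(enum.Enum):
    # ...earlier members elided; enum.auto() numbers members 1, 2, 3, ... in order
    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle = 7          # assumed value
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle = 8   # assumed value
    DeviceBatchedGemmMultiD_Xdl = 9                        # assumed value

op = XdlOpType.DeviceBatchedGemmMultiD_Xdl
xdl_op_type_value = op.value  # 9 -> would select the new branch in the emit() template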
@@ -51,6 +51,7 @@ class XdlOpType(enum.Enum):
     DeviceBatchedContractionMultipleD_Xdl_CShuffle = auto()
     DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle = auto()
     DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle = auto()
+    DeviceBatchedGemmMultiD_Xdl = auto()


 XdlOpTag = {
@@ -62,6 +63,7 @@ class XdlOpType(enum.Enum):
     XdlOpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Xdl_CShuffle",
     XdlOpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle",
     XdlOpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle",
+    XdlOpType.DeviceBatchedGemmMultiD_Xdl: "ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl",
 }

@@ -247,7 +249,11 @@ def __str__(self) -> str:
 _{{n_xdl_per_wave}}
 {{m_n_block_wave_per_xdl|join('_')}}S
 {{scalar_per_vector}}
-{{causal_mask}}
+{% if causal_mask == 1 %}
+ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask
+{% else %}
+ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask
+{% endif %}
             """,
             trim_blocks=True,
             lstrip_blocks=True,
@@ -264,7 +270,11 @@ def emit(self) -> str:
 {{n_xdl_per_wave}}, // n_xdl_per_wave
 ck::Sequence<{{m_n_block_wave_per_xdl|join(',')}}>, // m_n_block_wave_per_xdl
 {{scalar_per_vector}}, // scalar_per_vector
-{{causal_mask}} // causal_mask
+{% if causal_mask == 1 %}
+ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask
+{% else %}
+ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask
+{% endif %}
             """,
             trim_blocks=True,
             lstrip_blocks=True,
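Both template hunks above replace the raw {{causal_mask}} integer in the generated C++ with an explicit CK MaskingSpecialization enumerator, so the emitted instantiation names the masking mode rather than passing a bare 0/1. The mapping the templates now encode, shown as a plain Python helper for illustration (not a function that exists in mk_ck_lib):

def masking_specialization(causal_mask: int) -> str:
    # Mirrors the Jinja branch added above: 1 -> mask out the upper triangle (causal attention).
    if causal_mask == 1:
        return "ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle"
    return "ck::tensor_operation::device::MaskingSpecialization::MaskDisabled"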
@@ -392,21 +402,38 @@ def emit(self) -> str:
 ck::Tuple<ck::half_t>,
 {% endif %}
 ck::half_t,
-{% elif xdl_op_type_value in [7, 8] %}
+{% elif xdl_op_type_value == 7 %}
 {{ALayout}},
 {{BLayout}},
 {{CLayout}},
-{% if xdl_op_type_value == 8 %}
-ck::Sequence<2,1,1>,
-{% else %}
 {{CLayout}},
-{% endif %}
 {{ADType}},
 {{BDType}},
 {{BDType}},
 {{CDType}},
 {{AccDType}},
 float, // CShuffleDType,
+{% elif xdl_op_type_value == 8 %}
+2, 1, 1, 1, 1,
+{{ADType}},
+{{BDType}},
+{{BDType}},
+{{CDType}},
+ck::Tuple<>,
+ck::Tuple<>,
+{{AccDType}},
+float, // CShuffleDType,
+{% elif xdl_op_type_value == 9 %}
+{{ALayout}},
+{{BLayout}},
+ck::Tuple<{{DsLayout}}>, // DsLayout
+{{CLayout}},
+{{ADType}},
+{{BDType}},
+{{AccDType}},
+{{CShuffleDType}},
+ck::Tuple<{{DsDType}}>, // DsType
+{{EDType}},
 {% endif %}
 {% if xdl_op_type_value in [7, 8] %}
 {{A_elem_op}},
@@ -423,6 +450,11 @@ def emit(self) -> str:
 ck::tensor_operation::device::TensorSpecialization::Packed,
 ck::tensor_operation::device::TensorSpecialization::Packed,
 ck::tensor_operation::device::TensorSpecialization::Default,
+{% elif xdl_op_type_value==8 %}
+ck::tensor_operation::device::TensorSpecialization::Default,
+ck::tensor_operation::device::TensorSpecialization::Default,
+ck::tensor_operation::device::TensorSpecialization::Default,
+ck::tensor_operation::device::TensorSpecialization::Default,
 {% endif %}
 1,
 {% endif %}
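The new `xdl_op_type_value == 9` branch above wires DeviceBatchedGemmMultiD_Xdl into the emitter. Unlike the single-output GEMM branches, it passes a tuple of extra D operands (DsLayout / DsDType) plus a separate output type EDType, which is how CK's MultipleD-style kernels take additional tensors for fused epilogues such as bias or residual adds. A rough sketch of what a rendered instantiation could look like — every concrete value below is an illustrative assumption, not taken from this diff:

# Hypothetical substitutions for one DeviceBatchedGemmMultiD_Xdl instance.
subs = {
    "ALayout": "ck::tensor_layout::gemm::RowMajor",
    "BLayout": "ck::tensor_layout::gemm::ColumnMajor",
    "DsLayout": "ck::tensor_layout::gemm::RowMajor",  # one residual/bias operand
    "CLayout": "ck::tensor_layout::gemm::RowMajor",
    "ADType": "ck::half_t",
    "BDType": "ck::half_t",
    "AccDType": "float",
    "CShuffleDType": "ck::half_t",
    "DsDType": "ck::half_t",
    "EDType": "ck::half_t",
}

# Argument order follows the template lines added in the hunk above.
head = (
    "ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl<\n"
    "    {ALayout},\n"
    "    {BLayout},\n"
    "    ck::Tuple<{DsLayout}>,\n"
    "    {CLayout},\n"
    "    {ADType},\n"
    "    {BDType},\n"
    "    {AccDType},\n"
    "    {CShuffleDType},\n"
    "    ck::Tuple<{DsDType}>,\n"
    "    {EDType},\n"
    "    ...remaining elementwise/tile parameters...>"
).format(**subs)
print(head)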