Commit

update frontend and mk_ck_lib

fsx950223 committed Jun 16, 2023
1 parent e98d2dd commit a4b1e38
Showing 11 changed files with 464 additions and 80 deletions.
55 changes: 40 additions & 15 deletions python/aitemplate/frontend/nn/attention.py
@@ -16,6 +16,7 @@
Frontend for attention module
"""
from aitemplate.compiler import ops
from aitemplate.compiler.base import IntVar
from aitemplate.compiler.ops import flash_attention
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
@@ -256,7 +257,7 @@ def attention(self, x):
)
return out

def forward(self, *args):
def forward(self, *args, seqlens=None):
"""forward pass for calling mha module"""
assert len(args) >= 1
x = args[0]
@@ -361,32 +362,55 @@ def __init__(
)
self.proj_drop = Dropout(proj_drop, dtype=dtype)

def attention(self, q, k, v):
def attention(self, q, k, v, seqlens=None):
batch = q.shape()[0]
head_dim = self.dim // self.num_heads

query = self.proj_q(q)
key = self.proj_k(k)
value = self.proj_v(v)

query = ops.permute()(
ops.reshape()(query, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
)
key = ops.permute()(
ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
)
value = ops.permute()(
ops.reshape()(value, [batch, -1, self.num_heads, head_dim]),
[0, 2, 1, 3],
if detect_target().name() == "cuda":
query = ops.permute()(
ops.reshape()(query, [batch, -1, self.num_heads, head_dim]),
[0, 2, 1, 3],
)
key = ops.permute()(
ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
)
value = ops.permute()(
ops.reshape()(value, [batch, -1, self.num_heads, head_dim]),
[0, 2, 1, 3],
)
return self.op(query, key, value)
elif seqlens:
query = ops.reshape()(query, [batch, self.num_heads, -1, head_dim])
key = ops.reshape()(key, [batch, self.num_heads, -1, head_dim])
value = ops.reshape()(value, [batch, self.num_heads, -1, head_dim])
return self.op(query, key, value, seqlens)

query = ops.reshape()(query, [batch, -1, self.num_heads, head_dim])
query = ops.transpose()(query, 1, 2)
query = ops.reshape()(query, [-1, query.shape()[2], head_dim])
key = ops.reshape()(key, [batch, -1, self.num_heads, head_dim])
key = ops.transpose()(key, 1, 2)
key = ops.reshape()(key, [-1, key.shape()[2], head_dim])
value = ops.reshape()(value, [batch, -1, self.num_heads, head_dim])
value = ops.transpose()(value, 1, 2)
value = ops.reshape()(value, [-1, value.shape()[2], head_dim])
OP = ops.bmm_softmax_bmm_permute(
shape=(self.num_heads,),
scale=head_dim**-0.5,
causal=self.causal,
)
return self.op(query, key, value)
return OP(query, key, value)

def forward(self, *args):
def forward(self, *args, seqlens=None):
"""forward pass for calling mha module"""
assert len(args) >= 3
x = args[0]
batch = x.shape()[0]
attn_output = self.attention(args[0], args[1], args[2])
attn_output = self.attention(args[0], args[1], args[2], seqlens=seqlens)
attn_output = ops.reshape()(attn_output, [batch, -1, self.dim])

if self.has_residual:
@@ -395,7 +419,8 @@ def forward(self, *args):
else:
x = self.proj(attn_output)
x = self.proj_drop(x)
x = ops.reshape()(x, [batch, -1, self.dim])
if not isinstance(batch, IntVar):
x = ops.reshape()(x, [batch, -1, self.dim])
return x


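Note on the attention.py hunks above: the cross-attention style module's attention() is now backend-aware. On CUDA it keeps the existing [B, H, S, D] permute plus the fused attention op; on ROCm with per-sample sequence lengths it stays 4-D and forwards seqlens to the op; otherwise it flattens batch and heads and uses a fused bmm_softmax_bmm_permute. The condensed sketch below is not the committed code: query, key, and value are assumed to be already projected, and op, num_heads, head_dim, and causal stand in for the module's attributes, but every op call it makes appears verbatim in the diff.

from aitemplate.compiler import ops
from aitemplate.testing import detect_target


def _to_bhsd(t, batch, num_heads, head_dim):
    # [B, S, H*D] -> [B, H, S, D] (CUDA path in the diff).
    return ops.permute()(
        ops.reshape()(t, [batch, -1, num_heads, head_dim]), [0, 2, 1, 3]
    )


def _to_flat_heads(t, batch, num_heads, head_dim):
    # [B, S, H*D] -> [B*H, S, D] (ROCm bmm_softmax_bmm_permute path in the diff).
    t = ops.reshape()(t, [batch, -1, num_heads, head_dim])
    t = ops.transpose()(t, 1, 2)
    return ops.reshape()(t, [-1, t.shape()[2], head_dim])


def dispatch_attention(op, query, key, value, batch, num_heads, head_dim, causal, seqlens=None):
    if detect_target().name() == "cuda":
        # CUDA keeps the pre-existing layout and fused attention op.
        return op(
            _to_bhsd(query, batch, num_heads, head_dim),
            _to_bhsd(key, batch, num_heads, head_dim),
            _to_bhsd(value, batch, num_heads, head_dim),
        )
    if seqlens:
        # ROCm with per-sample sequence lengths: stay 4-D and forward seqlens.
        return op(
            ops.reshape()(query, [batch, num_heads, -1, head_dim]),
            ops.reshape()(key, [batch, num_heads, -1, head_dim]),
            ops.reshape()(value, [batch, num_heads, -1, head_dim]),
            seqlens,
        )
    # ROCm fallback: flatten batch x heads and run the fused softmax-GEMM kernel.
    fused = ops.bmm_softmax_bmm_permute(
        shape=(num_heads,), scale=head_dim**-0.5, causal=causal
    )
    return fused(
        _to_flat_heads(query, batch, num_heads, head_dim),
        _to_flat_heads(key, batch, num_heads, head_dim),
        _to_flat_heads(value, batch, num_heads, head_dim),
    )

The companion forward() change skips the trailing reshape to [batch, -1, dim] whenever batch is a dynamic IntVar.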
56 changes: 10 additions & 46 deletions python/aitemplate/frontend/nn/embedding.py
@@ -13,12 +13,10 @@
# limitations under the License.
#
from aitemplate.compiler import ops
from aitemplate.compiler.public import FuncEnum
from aitemplate.frontend.nn.dropout import Dropout
from aitemplate.frontend.nn.layer_norm import LayerNorm
from aitemplate.frontend.nn.module import Module
from aitemplate.frontend.nn.parameter import Parameter
from aitemplate.testing import detect_target


class Embedding(Module):
@@ -61,8 +59,6 @@ def __init__(
dtype="float16",
):
super().__init__()
if BertEmbeddings.USE_CUDA is None:
BertEmbeddings.USE_CUDA = detect_target().name() == "cuda"
assert (
hidden_dropout_prob == 0.0
), "Dropout rate larger than 0 is not supported yet."
@@ -85,48 +81,16 @@ def forward(
token_type_ids, # [B, S]
position_ids, # [B, S]
):
if self.USE_CUDA:
embeddings = ops.bert_embeddings()(
input_ids,
token_type_ids,
position_ids,
self.word_embeddings.weight.tensor(),
self.token_type_embeddings.weight.tensor(),
self.position_embeddings.weight.tensor(),
self.LayerNorm.weight.tensor(),
self.LayerNorm.bias.tensor(),
self.LayerNorm.eps,
)
embeddings = self.dropout(embeddings)
return embeddings

input_shape = ops.size()(input_ids)

# [B * S]
input_ids = ops.reshape()(input_ids, [-1])
token_type_ids = ops.reshape()(token_type_ids, [-1])
position_ids = ops.reshape()(position_ids, [-1])

# [B * S, H]
input_embeddings = ops.batch_gather()(self.word_embeddings.tensor(), input_ids)
token_type_embeddings = ops.batch_gather()(
self.token_type_embeddings.tensor(), token_type_ids
)
position_embeddings = ops.batch_gather()(
self.position_embeddings.tensor(), position_ids
)

# add
embeddings = ops.elementwise(FuncEnum.ADD)(
input_embeddings, token_type_embeddings
embeddings = ops.bert_embeddings()(
input_ids,
token_type_ids,
position_ids,
self.word_embeddings.weight.tensor(),
self.token_type_embeddings.weight.tensor(),
self.position_embeddings.weight.tensor(),
self.LayerNorm.weight.tensor(),
self.LayerNorm.bias.tensor(),
self.LayerNorm.eps,
)

embeddings = ops.elementwise(FuncEnum.ADD)(embeddings, position_embeddings)

# norm
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)

embeddings = ops.reshape()(embeddings, input_shape + [-1])

return embeddings
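Note on the embedding.py hunk: BertEmbeddings.forward is collapsed to a single path that always calls the fused ops.bert_embeddings() kernel, so the USE_CUDA probe and the batch_gather plus elementwise-add plus LayerNorm fallback are removed. A minimal sketch of the resulting call, where emb is assumed to be a BertEmbeddings-like module exposing the same parameters the diff passes through:

from aitemplate.compiler import ops


def fused_bert_embeddings(input_ids, token_type_ids, position_ids, emb):
    # input_ids, token_type_ids, position_ids are [B, S] integer id tensors;
    # the fused kernel gathers, adds, and layer-normalizes in one op.
    return ops.bert_embeddings()(
        input_ids,
        token_type_ids,
        position_ids,
        emb.word_embeddings.weight.tensor(),
        emb.token_type_embeddings.weight.tensor(),
        emb.position_embeddings.weight.tensor(),
        emb.LayerNorm.weight.tensor(),
        emb.LayerNorm.bias.tensor(),
        emb.LayerNorm.eps,
    )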
1 change: 0 additions & 1 deletion python/aitemplate/frontend/nn/module.py
@@ -296,7 +296,6 @@ def get_submodule(self, target: str) -> "Module":
mod: Module = self

for item in atoms:

if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
1 change: 0 additions & 1 deletion python/aitemplate/frontend/nn/multiscale_attention.py
@@ -375,7 +375,6 @@ def __init__(
## TODO: add pool mode support for {"max", "avg"}

elif pool_mode == "conv":

self.pool_q = (
Conv3d(
head_dim,
5 changes: 2 additions & 3 deletions python/aitemplate/utils/mk_ck_lib/conv2d_operation.py
@@ -266,7 +266,6 @@ def accumulator_type(self):
return library.DataType.f32

def emit(self) -> str:

template = jinja2.Template(
"""
using {{name}} = {{xdl_op_type}}<
@@ -285,7 +284,7 @@ def emit(self) -> str:
{{WeiLayout}}, // WeiLayout
{% if func=="PT" %}
ck::Tuple<>,
{% elif func=="AAR" %}
{% elif func in ["AA", "AAR"] %}
ck::Tuple<{{OutLayout}}, {{OutLayout}}>, // BiasLayout
{% else %}
{% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in xdl_op_type %}
@@ -301,7 +300,7 @@ def emit(self) -> str:
{{CShuffleDType}}, // CShuffleDataType
{% if func=="PT" %}
ck::Tuple<>,
{% elif func=="AAR" %}
{% elif func in ["AA", "AAR"] %}
ck::Tuple<{{CDType}}, {{CDType}}>, // BiasLayout
{% else %}
ck::Tuple<{{CDType}}>, // BiasDataType
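Note on the conv2d_operation.py hunk: the template condition is widened from func == "AAR" to func in ["AA", "AAR"], so the add-add epilogue also emits a two-element ck::Tuple for its extra D tensors. A small, self-contained rendering check; the template below is a simplified excerpt rather than the full emit() body, and the layout value is an illustrative placeholder.

import jinja2

# Simplified excerpt of the bias-layout branch touched by the diff.
template = jinja2.Template(
    """
{% if func == "PT" %}
ck::Tuple<>,
{% elif func in ["AA", "AAR"] %}
ck::Tuple<{{OutLayout}}, {{OutLayout}}>,  // two extra D tensors
{% else %}
ck::Tuple<{{OutLayout}}>,
{% endif %}
""",
    trim_blocks=True,
    lstrip_blocks=True,
)

for func in ("PT", "AA", "AAR", "other"):
    # "AA" and "AAR" now render identically; everything else is unchanged.
    print(func, "->", template.render(func=func, OutLayout="OutLayout").strip())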
46 changes: 39 additions & 7 deletions python/aitemplate/utils/mk_ck_lib/gemm_operation.py
@@ -51,6 +51,7 @@ class XdlOpType(enum.Enum):
DeviceBatchedContractionMultipleD_Xdl_CShuffle = auto()
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle = auto()
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle = auto()
DeviceBatchedGemmMultiD_Xdl = auto()


XdlOpTag = {
@@ -62,6 +63,7 @@ class XdlOpType(enum.Enum):
XdlOpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Xdl_CShuffle",
XdlOpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle",
XdlOpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle",
XdlOpType.DeviceBatchedGemmMultiD_Xdl: "ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl",
}


@@ -247,7 +249,11 @@ def __str__(self) -> str:
_{{n_xdl_per_wave}}
{{m_n_block_wave_per_xdl|join('_')}}S
{{scalar_per_vector}}
{{causal_mask}}
{% if causal_mask == 1 %}
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask
{% else %}
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask
{% endif %}
""",
trim_blocks=True,
lstrip_blocks=True,
Expand All @@ -264,7 +270,11 @@ def emit(self) -> str:
{{n_xdl_per_wave}}, // n_xdl_per_wave
ck::Sequence<{{m_n_block_wave_per_xdl|join(',')}}>, // m_n_block_wave_per_xdl
{{scalar_per_vector}}, // scalar_per_vector
{{causal_mask}} // causal_mask
{% if causal_mask == 1 %}
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask
{% else %}
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask
{% endif %}
""",
trim_blocks=True,
lstrip_blocks=True,
@@ -392,21 +402,38 @@ def emit(self) -> str:
ck::Tuple<ck::half_t>,
{% endif %}
ck::half_t,
{% elif xdl_op_type_value in [7, 8] %}
{% elif xdl_op_type_value == 7 %}
{{ALayout}},
{{BLayout}},
{{CLayout}},
{% if xdl_op_type_value == 8 %}
ck::Sequence<2,1,1>,
{% else %}
{{CLayout}},
{% endif %}
{{ADType}},
{{BDType}},
{{BDType}},
{{CDType}},
{{AccDType}},
float, // CShuffleDType,
{% elif xdl_op_type_value == 8 %}
2, 1, 1, 1, 1,
{{ADType}},
{{BDType}},
{{BDType}},
{{CDType}},
ck::Tuple<>,
ck::Tuple<>,
{{AccDType}},
float, // CShuffleDType,
{% elif xdl_op_type_value == 9 %}
{{ALayout}},
{{BLayout}},
ck::Tuple<{{DsLayout}}>, // DsLayout
{{CLayout}},
{{ADType}},
{{BDType}},
{{AccDType}},
{{CShuffleDType}},
ck::Tuple<{{DsDType}}>, // DsType
{{EDType}},
{% endif %}
{% if xdl_op_type_value in [7, 8] %}
{{A_elem_op}},
Expand All @@ -423,6 +450,11 @@ def emit(self) -> str:
ck::tensor_operation::device::TensorSpecialization::Packed,
ck::tensor_operation::device::TensorSpecialization::Packed,
ck::tensor_operation::device::TensorSpecialization::Default,
{% elif xdl_op_type_value==8 %}
ck::tensor_operation::device::TensorSpecialization::Default,
ck::tensor_operation::device::TensorSpecialization::Default,
ck::tensor_operation::device::TensorSpecialization::Default,
ck::tensor_operation::device::TensorSpecialization::Default,
{% endif %}
1,
{% endif %}
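Note on the gemm_operation.py hunks: a DeviceBatchedGemmMultiD_Xdl entry is added to XdlOpType and its tag map (it appears to correspond to the xdl_op_type_value == 9 branches in the emit template), and the causal flag is now rendered as a ck MaskingSpecialization value instead of a raw integer. A minimal rendering sketch that reuses the exact branch added in the diff:

import jinja2

mask_template = jinja2.Template(
    """
{% if causal_mask == 1 %}
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle // causal_mask
{% else %}
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled // causal_mask
{% endif %}
""",
    trim_blocks=True,
    lstrip_blocks=True,
)

print(mask_template.render(causal_mask=1).strip())  # ...MaskOutUpperTriangle // causal_mask
print(mask_template.render(causal_mask=0).strip())  # ...MaskDisabled // causal_mask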