Skip to content

Commit

Permalink
Fix group_layernorm alignment and reshape fusion
Browse files Browse the repository at this point in the history
Summary:
The group_layernorm back-end implementation is not well-aligned with the input TensorAccessors which can cause CUDA illegal memory access under certain fusion scenarios (e.g., reshape fusion).

This diff addresses these issues.

Differential Revision: D47395663

fbshipit-source-id: 0369a305e52f19515efc1867fb34ac0f1b088f18
  • Loading branch information
aakhundov authored and facebook-github-bot committed Jul 12, 2023
1 parent ecb705a commit b735509
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ def group_layernorm_sigmoid_mul_gen_function_call(func_attrs, indent=" "):

all_shape_funcs = []
# all Ms are the same
input_0_shapes = inputs[0]._attrs["shape"]
if func_attrs.get("input_accessors", None):
input_accessor = func_attrs["input_accessors"][0]
input_0_shapes = input_accessor.original_shapes
else:
input_0_shapes = inputs[0]._attrs["shape"]
norm_ndim = len(func_attrs["normalized_shape"][0])
m_name = "M"
m_shape_func = layernorm_common.generate_m_shape_func(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2142,10 +2142,19 @@ cudaError_t invokeGroupLayernormSigmoidMul(
return cudaSuccess;
}

bool accessors_aligned_to_4 = true;
for (size_t i = 0; i < b; ++i) {
if (!input_accessors[i].is_valid_alignment(4) ||
!output_accessors[i].is_valid_alignment(4)) {
accessors_aligned_to_4 = false;
break;
}
}

dim3 grid(m, b);
// TODO: implement float4 group kernel
if (std::is_same<T, half>::value && n_is_multiple_of_4 && (min_n >= 128) &&
(max_n <= 4096)) {
(max_n <= 4096) && accessors_aligned_to_4) {
dim3 block(min_n);
// round up to multiples of 32 to make warp shuffles safe
block.x = (block.x / 4 + 31) / 32 * 32;
Expand Down
144 changes: 143 additions & 1 deletion tests/unittest/compiler/test_strided_group_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@

import torch
from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.frontend import IntImm, IntVar, Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
get_random_torch_tensor,
get_torch_empty_tensor,
)
from aitemplate.utils import shape_utils, torch_utils


Expand Down Expand Up @@ -369,6 +373,144 @@ def test_slice_group_layer_norm_float(self):
dtype="float32",
)

@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.")
def test_group_layernorm_no_cuda_illegal_memory_access(self):
"""
This subgraph has led to CUDA illegal memory issues before.
Adding it as a unit test to ensure there are no regressions.
"""
batch_size = IntVar(values=[1, 2048], name="batch_size")

unsqueeze_46_0 = Tensor(
shape=[batch_size, 7680, 1],
is_input=True,
name="unsqueeze_46_0",
)
unsqueeze_58_0 = Tensor(
shape=[batch_size, 7680, 1],
is_input=True,
name="unsqueeze_58_0",
)
unsqueeze_70_0 = Tensor(
shape=[batch_size, 7680, 1],
is_input=True,
name="unsqueeze_70_0",
)
unsqueeze_131_0 = Tensor(
shape=[batch_size, 3, 1],
is_input=True,
name="unsqueeze_131_0",
)
main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight = Tensor(
shape=[IntImm(256)],
is_input=True,
name="main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight",
)
main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias = Tensor(
shape=[IntImm(256)],
is_input=True,
name="main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias",
)

unsqueeze_83_0 = Tensor(
shape=[batch_size, 7680, 1],
is_input=True,
name="unsqueeze_83_0",
)
unsqueeze_95_0 = Tensor(
shape=[batch_size, 7680, 1],
is_input=True,
name="unsqueeze_95_0",
)
unsqueeze_107_0 = Tensor(
shape=[batch_size, 7680, 1],
is_input=True,
name="unsqueeze_107_0",
)
unsqueeze_358_0 = Tensor(
shape=[batch_size, 3, 1],
is_input=True,
name="unsqueeze_358_0",
)
main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight = Tensor(
shape=[IntImm(256)],
is_input=True,
name="main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight",
)
main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias = Tensor(
shape=[IntImm(256)],
is_input=True,
name="main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias",
)

concatenate_71_0 = ops.concatenate()(
inputs=[unsqueeze_46_0, unsqueeze_58_0, unsqueeze_70_0],
dim=2,
)
bmm_rrr_132_0 = ops.bmm_rrr()(concatenate_71_0, unsqueeze_131_0)
reshape_133_0 = ops.reshape()(bmm_rrr_132_0, shape=[-1, 30, 256])
layernorm_134_0 = ops.layernorm(normalized_shape=[IntImm(256)])(
reshape_133_0,
main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight,
main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias,
)
permute021_136_0 = ops.permute021()(layernorm_134_0)

concatenate_108_0 = ops.concatenate()(
inputs=[unsqueeze_83_0, unsqueeze_95_0, unsqueeze_107_0],
dim=2,
)
bmm_rrr_359_0 = ops.bmm_rrr()(concatenate_108_0, unsqueeze_358_0)
reshape_360_0 = ops.reshape()(bmm_rrr_359_0, shape=[-1, 30, 256])
layernorm_361_0 = ops.layernorm(normalized_shape=[IntImm(256)])(
reshape_360_0,
main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight,
main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias,
)
permute021_363_0 = ops.permute021()(layernorm_361_0)

outputs = [permute021_136_0, permute021_363_0]

for i, output in enumerate(outputs):
output._attrs["is_output"] = True
output._attrs["name"] = f"output_{i}"

model = compile_model(
outputs,
detect_target(),
"./tmp",
"test_group_layernorm_repro",
)

pt_inputs = {
"unsqueeze_46_0": get_random_torch_tensor(shape=[1024, 7680, 1]),
"unsqueeze_58_0": get_random_torch_tensor(shape=[1024, 7680, 1]),
"unsqueeze_70_0": get_random_torch_tensor(shape=[1024, 7680, 1]),
"unsqueeze_131_0": get_random_torch_tensor(shape=[1024, 3, 1]),
"main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight": get_random_torch_tensor(
shape=[256]
),
"main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias": get_random_torch_tensor(
shape=[256]
),
"unsqueeze_83_0": get_random_torch_tensor(shape=[1024, 7680, 1]),
"unsqueeze_95_0": get_random_torch_tensor(shape=[1024, 7680, 1]),
"unsqueeze_107_0": get_random_torch_tensor(shape=[1024, 7680, 1]),
"unsqueeze_358_0": get_random_torch_tensor(shape=[1024, 3, 1]),
"main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight": get_random_torch_tensor(
shape=[256]
),
"main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias": get_random_torch_tensor(
shape=[256]
),
}
pt_outputs = {
"output_0": get_torch_empty_tensor(shape=[1024, 256, 30]),
"output_1": get_torch_empty_tensor(shape=[1024, 256, 30]),
}

model.run_with_tensors(pt_inputs, pt_outputs)


if __name__ == "__main__":
torch.manual_seed(0)
Expand Down

0 comments on commit b735509

Please sign in to comment.