From b735509e4b890b7a682aaff71a3cfa030e280079 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Wed, 12 Jul 2023 03:29:26 -0700 Subject: [PATCH] Fix group_layernorm alignment and reshape fusion Summary: The group_layernorm back-end implementation is not well-aligned with the input TensorAccessors which can cause CUDA illegal memory access under certain fusion scenarios (e.g., reshape fusion). This diff addresses these issues. Differential Revision: D47395663 fbshipit-source-id: 0369a305e52f19515efc1867fb34ac0f1b088f18 --- .../group_layernorm_sigmoid_mul.py | 6 +- .../layernorm_sigmoid_mul_kernel.cuh | 11 +- .../compiler/test_strided_group_layernorm.py | 144 +++++++++++++++++- 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/group_layernorm_sigmoid_mul.py b/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/group_layernorm_sigmoid_mul.py index 2bf72cb3b..e11eb4c28 100644 --- a/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/group_layernorm_sigmoid_mul.py +++ b/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/group_layernorm_sigmoid_mul.py @@ -269,7 +269,11 @@ def group_layernorm_sigmoid_mul_gen_function_call(func_attrs, indent=" "): all_shape_funcs = [] # all Ms are the same - input_0_shapes = inputs[0]._attrs["shape"] + if func_attrs.get("input_accessors", None): + input_accessor = func_attrs["input_accessors"][0] + input_0_shapes = input_accessor.original_shapes + else: + input_0_shapes = inputs[0]._attrs["shape"] norm_ndim = len(func_attrs["normalized_shape"][0]) m_name = "M" m_shape_func = layernorm_common.generate_m_shape_func( diff --git a/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_sigmoid_mul_kernel.cuh b/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_sigmoid_mul_kernel.cuh index a29179ea8..1b562d124 100644 --- a/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_sigmoid_mul_kernel.cuh +++ b/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_sigmoid_mul_kernel.cuh @@ -2142,10 +2142,19 @@ cudaError_t invokeGroupLayernormSigmoidMul( return cudaSuccess; } + bool accessors_aligned_to_4 = true; + for (size_t i = 0; i < b; ++i) { + if (!input_accessors[i].is_valid_alignment(4) || + !output_accessors[i].is_valid_alignment(4)) { + accessors_aligned_to_4 = false; + break; + } + } + dim3 grid(m, b); // TODO: implement float4 group kernel if (std::is_same::value && n_is_multiple_of_4 && (min_n >= 128) && - (max_n <= 4096)) { + (max_n <= 4096) && accessors_aligned_to_4) { dim3 block(min_n); // round up to multiples of 32 to make warp shuffles safe block.x = (block.x / 4 + 31) / 32 * 32; diff --git a/tests/unittest/compiler/test_strided_group_layernorm.py b/tests/unittest/compiler/test_strided_group_layernorm.py index 02f24a795..afdf599d6 100644 --- a/tests/unittest/compiler/test_strided_group_layernorm.py +++ b/tests/unittest/compiler/test_strided_group_layernorm.py @@ -18,8 +18,12 @@ import torch from aitemplate.compiler import compile_model, ops -from aitemplate.frontend import Tensor +from aitemplate.frontend import IntImm, IntVar, Tensor from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import ( + get_random_torch_tensor, + get_torch_empty_tensor, +) from aitemplate.utils import shape_utils, torch_utils @@ -369,6 +373,144 @@ def test_slice_group_layer_norm_float(self): dtype="float32", ) + @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") + def test_group_layernorm_no_cuda_illegal_memory_access(self): + """ + This subgraph has led to CUDA illegal memory issues before. + Adding it as a unit test to ensure there are no regressions. + """ + batch_size = IntVar(values=[1, 2048], name="batch_size") + + unsqueeze_46_0 = Tensor( + shape=[batch_size, 7680, 1], + is_input=True, + name="unsqueeze_46_0", + ) + unsqueeze_58_0 = Tensor( + shape=[batch_size, 7680, 1], + is_input=True, + name="unsqueeze_58_0", + ) + unsqueeze_70_0 = Tensor( + shape=[batch_size, 7680, 1], + is_input=True, + name="unsqueeze_70_0", + ) + unsqueeze_131_0 = Tensor( + shape=[batch_size, 3, 1], + is_input=True, + name="unsqueeze_131_0", + ) + main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight = Tensor( + shape=[IntImm(256)], + is_input=True, + name="main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight", + ) + main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias = Tensor( + shape=[IntImm(256)], + is_input=True, + name="main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias", + ) + + unsqueeze_83_0 = Tensor( + shape=[batch_size, 7680, 1], + is_input=True, + name="unsqueeze_83_0", + ) + unsqueeze_95_0 = Tensor( + shape=[batch_size, 7680, 1], + is_input=True, + name="unsqueeze_95_0", + ) + unsqueeze_107_0 = Tensor( + shape=[batch_size, 7680, 1], + is_input=True, + name="unsqueeze_107_0", + ) + unsqueeze_358_0 = Tensor( + shape=[batch_size, 3, 1], + is_input=True, + name="unsqueeze_358_0", + ) + main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight = Tensor( + shape=[IntImm(256)], + is_input=True, + name="main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight", + ) + main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias = Tensor( + shape=[IntImm(256)], + is_input=True, + name="main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias", + ) + + concatenate_71_0 = ops.concatenate()( + inputs=[unsqueeze_46_0, unsqueeze_58_0, unsqueeze_70_0], + dim=2, + ) + bmm_rrr_132_0 = ops.bmm_rrr()(concatenate_71_0, unsqueeze_131_0) + reshape_133_0 = ops.reshape()(bmm_rrr_132_0, shape=[-1, 30, 256]) + layernorm_134_0 = ops.layernorm(normalized_shape=[IntImm(256)])( + reshape_133_0, + main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight, + main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias, + ) + permute021_136_0 = ops.permute021()(layernorm_134_0) + + concatenate_108_0 = ops.concatenate()( + inputs=[unsqueeze_83_0, unsqueeze_95_0, unsqueeze_107_0], + dim=2, + ) + bmm_rrr_359_0 = ops.bmm_rrr()(concatenate_108_0, unsqueeze_358_0) + reshape_360_0 = ops.reshape()(bmm_rrr_359_0, shape=[-1, 30, 256]) + layernorm_361_0 = ops.layernorm(normalized_shape=[IntImm(256)])( + reshape_360_0, + main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight, + main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias, + ) + permute021_363_0 = ops.permute021()(layernorm_361_0) + + outputs = [permute021_136_0, permute021_363_0] + + for i, output in enumerate(outputs): + output._attrs["is_output"] = True + output._attrs["name"] = f"output_{i}" + + model = compile_model( + outputs, + detect_target(), + "./tmp", + "test_group_layernorm_repro", + ) + + pt_inputs = { + "unsqueeze_46_0": get_random_torch_tensor(shape=[1024, 7680, 1]), + "unsqueeze_58_0": get_random_torch_tensor(shape=[1024, 7680, 1]), + "unsqueeze_70_0": get_random_torch_tensor(shape=[1024, 7680, 1]), + "unsqueeze_131_0": get_random_torch_tensor(shape=[1024, 3, 1]), + "main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_weight": get_random_torch_tensor( + shape=[256] + ), + "main_module_base_forward_module_over_arch_bottom_arch_list_1_dime_shared_arch_layer_norm__norm_bias": get_random_torch_tensor( + shape=[256] + ), + "unsqueeze_83_0": get_random_torch_tensor(shape=[1024, 7680, 1]), + "unsqueeze_95_0": get_random_torch_tensor(shape=[1024, 7680, 1]), + "unsqueeze_107_0": get_random_torch_tensor(shape=[1024, 7680, 1]), + "unsqueeze_358_0": get_random_torch_tensor(shape=[1024, 3, 1]), + "main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_weight": get_random_torch_tensor( + shape=[256] + ), + "main_module_base_forward_module_over_arch_bottom_arch_list_0_dime_shared_arch_layer_norm__norm_bias": get_random_torch_tensor( + shape=[256] + ), + } + pt_outputs = { + "output_0": get_torch_empty_tensor(shape=[1024, 256, 30]), + "output_1": get_torch_empty_tensor(shape=[1024, 256, 30]), + } + + model.run_with_tensors(pt_inputs, pt_outputs) + if __name__ == "__main__": torch.manual_seed(0)