From c795b8c6461e7ca098501c1b9e02a72ed29030d4 Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Thu, 10 Aug 2023 17:15:00 -0700
Subject: [PATCH] Fixed segfault caused by mis-aligned input tensors (#894)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/894

This diff fixes two issues:

* Our input tensors may not be well-aligned with respect to our kernels.
  For example, our fp16-gemm kernel may require 16-byte alignment, but the
  input tensor for the gemm is not aligned to 16 bytes. We fix the issue
  by cloning the misaligned inputs with torch.clone, which almost always
  gives us 256-byte-aligned pointers.

* It's not uncommon that multiple shape definitions share the same
  dimension values. In such cases, we can keep a single definition. In
  some rare cases, this little trick can dramatically reduce the number of
  lines generated for the relevant concatenate op and thus may improve the
  compilation time. Currently, we only enable this optimization for cases
  where we care about compilation time, as in most cases the unoptimized
  version generates more readable code while having little impact on
  compilation time.

Reviewed By: ipiszy

Differential Revision: D48238260

fbshipit-source-id: 4298f586322c32a3beff46f55ca579de4a40171f
---
 fx2ait/fx2ait/csrc/AITModelImpl.cpp           | 14 +++-
 .../backend/common/concatenate_common.py      | 76 +++++++++++++++----
 tests/unittest/ops/test_concatenate.py        | 60 +++++++++++++--
 3 files changed, 128 insertions(+), 22 deletions(-)

diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp
index b2a636637..4fbfe02ef 100644
--- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp
+++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp
@@ -393,7 +393,19 @@ std::vector AITModelImpl::processInputs(
       // call in a local!
       input = input.to(*floating_point_input_dtype_);
     }
-    inputs_contig.push_back(input.contiguous());
+    auto t = input.contiguous();
+    size_t elem_sz = t.element_size();
+    // Let's be conservative - make sure all input tensor pointers are
+    // aligned to 8 * sizeof(dtype) bytes.
+    if ((((uint64_t)(t.data_ptr())) % (elem_sz * 8)) != 0) {
+      LOG(INFO) << "FORCE tensor pointer alignment: input_name: " << input_name
+                << ", unaligned addr: " << std::hex << t.data_ptr();
+      auto t2 = t.clone();
+      LOG(INFO) << "cloned tensor pointer addr: " << std::hex << t2.data_ptr();
+      inputs_contig.push_back(t2);
+    } else {
+      inputs_contig.push_back(t);
+    }
     auto& input_contig = inputs_contig.back();
     auto input_shape_array_ref = input_contig.sizes();
     ait_inputs[ait_input_idx] = AITData{
diff --git a/python/aitemplate/backend/common/concatenate_common.py b/python/aitemplate/backend/common/concatenate_common.py
index 025688309..872a6fd94 100644
--- a/python/aitemplate/backend/common/concatenate_common.py
+++ b/python/aitemplate/backend/common/concatenate_common.py
@@ -18,7 +18,9 @@
 import jinja2
 
 from aitemplate.backend.common import tensor_accessor_codegen
+from aitemplate.backend.target import Target
 
+from aitemplate.compiler.base import IntImm
 from aitemplate.compiler.ops.tensor import concatenate
 
 FUNC_DECL_TEMPLATE = jinja2.Template(
@@ -745,18 +747,48 @@ def gen_function_call(
     input_names = ",\n ".join([i._attrs["name"] for i in inputs])
     real_input_shape_defs = []
     real_input_shape_names = []
+    # It's not uncommon that multiple shape definitions share the same
+    # dimension values. In such cases, we can keep a single definition.
+    # In some rare cases, this little trick can dramatically reduce the
+    # number of lines generated for the relevant concatenate op and thus
+    # may improve the compilation time. Currently, we only enable this
+    # optimization for cases where we care about compilation time, as
+    # in most cases the unoptimized version generates more readable
+    # code while having little impact on compilation time.
+    seen_input_shape_dims = {}
+    input_shape_name_subs = {}
+    optimize_for_compilation_time = Target.current()._kwargs.get(
+        "optimize_for_compilation_time", False
+    )
+
+    def _make_dims_key(dims):
+        dim_vals = []
+        for d in dims:
+            if isinstance(d, IntImm):
+                dim_vals.append(str(d.value()))
+            else:
+                dim_vals.append(d._attrs["name"])
+        return ",".join(dim_vals)
+
     for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)):
         input_shape_name = f'{i._attrs["name"]}_shape_{idx}'
         orig_input_shape = input_accessor.original_shapes
         dims = ", ".join([dim._attrs["name"] for dim in orig_input_shape])
-        one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
-            indent=" ",
-            input_shape_name=input_shape_name,
-            input_dims=dims,
-            index_type=backend_spec.index_type,
-        )
-        real_input_shape_defs.append(one_shape_def)
-        real_input_shape_names.append(input_shape_name)
+        dims_key = _make_dims_key(orig_input_shape)
+        seen_shape_name = seen_input_shape_dims.get(dims_key, None)
+        if not optimize_for_compilation_time or seen_shape_name is None:
+            one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
+                indent=" ",
+                input_shape_name=input_shape_name,
+                input_dims=dims,
+                index_type=backend_spec.index_type,
+            )
+            real_input_shape_defs.append(one_shape_def)
+            real_input_shape_names.append(input_shape_name)
+            seen_input_shape_dims[dims_key] = input_shape_name
+        else:
+            real_input_shape_names.append(seen_shape_name)
+            input_shape_name_subs[input_shape_name] = seen_shape_name
 
     y_shape = y._attrs["shape"]
     y_dim_refs = ", ".join(["&" + dim._attrs["name"] for dim in y_shape])
@@ -781,6 +813,7 @@ def gen_function_call(
     # all input shape defs and names, including those that are masked out
     all_input_shape_defs = []
     all_input_shape_names = []
+    seen_input_shape_dims = {}
     # first, create shape defs for inputs that have been masked off
     for (
         mask,
         orig_input,
         orig_input_shape_name,
     ) in zip(input_masks, original_inputs, all_input_shape_names):
         if mask is False:
             dims = ", ".join(
                 [str(dim._attrs["values"][0]) for dim in orig_input._attrs["shape"]]
             )
-            one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
-                indent=" ",
-                input_shape_name=orig_input_shape_name,
-                input_dims=dims,
-                index_type=backend_spec.index_type,
-            )
-            all_input_shape_defs.append(one_shape_def)
-            all_input_shape_names.append(orig_input_shape_name)
+            dims_key = _make_dims_key(orig_input._attrs["shape"])
+            seen_shape_name = seen_input_shape_dims.get(dims_key, None)
+            if not optimize_for_compilation_time or seen_shape_name is None:
+                one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
+                    indent=" ",
+                    input_shape_name=orig_input_shape_name,
+                    input_dims=dims,
+                    index_type=backend_spec.index_type,
+                )
+                all_input_shape_defs.append(one_shape_def)
+                seen_input_shape_dims[dims_key] = orig_input_shape_name
+                all_input_shape_names.append(orig_input_shape_name)
+            else:
+                all_input_shape_names.append(seen_shape_name)
+        else:
+            all_input_shape_names.append(orig_input_shape_name)
         else:
             all_input_shape_names.append("")
     # update all_input_shapes with real input shapes
     for idx, (input_tensor, input_index) in enumerate(zip(inputs, input_indices)):
         input_shape_name = f'{input_tensor._attrs["name"]}_shape_{idx}'
+        sub_name = input_shape_name_subs.get(input_shape_name, None)
+        if sub_name is not None:
+            input_shape_name = sub_name
         all_input_shape_names[input_index] = input_shape_name
 
     return FUNC_CALL_TEMPLATE.render(
diff --git a/tests/unittest/ops/test_concatenate.py b/tests/unittest/ops/test_concatenate.py
index a7560f661..15fa91270 100644
--- a/tests/unittest/ops/test_concatenate.py
+++ b/tests/unittest/ops/test_concatenate.py
@@ -31,7 +31,13 @@ def __init__(self, *args, **kwargs):
         self.test_count = 0
 
     def _run_concatenate(
-        self, *, concatenate_op, input_shapes, dim=None, input_type="float16"
+        self,
+        *,
+        concatenate_op,
+        input_shapes,
+        dim=None,
+        input_type="float16",
+        optimize_args=False,
     ):
         # generate torch reference result
         input_tensors_pt = [
@@ -44,7 +50,10 @@ def _run_concatenate(
             else torch.cat(input_tensors_pt, dim)
         )
 
-        target = detect_target()
+        if optimize_args:
+            target = detect_target(optimize_for_compilation_time=True)
+        else:
+            target = detect_target()
         inputs = [
             Tensor(
                 shape=shape, dtype=input_type, name="input_{}".format(i), is_input=True
@@ -68,9 +77,19 @@ def _run_concatenate(
         self.test_count += 1
 
     def _run_batch_concatenate(
-        self, *, batch_sizes, concatenate_op, input_shapes, dim=0, input_type="float16"
+        self,
+        *,
+        batch_sizes,
+        concatenate_op,
+        input_shapes,
+        dim=0,
+        input_type="float16",
+        optimize_args=False,
     ):
-        target = detect_target()
+        if optimize_args:
+            target = detect_target(optimize_for_compilation_time=True)
+        else:
+            target = detect_target()
         BATCH_DIM_NAME = "input_batch"
         batch_dim = shape_utils.gen_int_var_min_max(
             values=batch_sizes, name=BATCH_DIM_NAME
@@ -122,6 +141,7 @@ def _run_masked_concatenate(
         input_masks,
         dim=None,
         input_type="float16",
+        optimize_args=False,
     ):
         # generate torch reference result
         input_tensors_pt = [
@@ -134,7 +154,10 @@ def _run_masked_concatenate(
             else torch.cat(input_tensors_pt, dim)
         )
 
-        target = detect_target()
+        if optimize_args:
+            target = detect_target(optimize_for_compilation_time=True)
+        else:
+            target = detect_target()
         inputs = [
             Tensor(
                 shape=shape, dtype=input_type, name="input_{}".format(i), is_input=True
@@ -184,6 +207,13 @@ def test_batch_cat(self):
             input_shapes=([1], [1]),
             dim=0,
         )
+        self._run_batch_concatenate(
+            batch_sizes=[1, 1],
+            concatenate_op=ops.concatenate(),
+            input_shapes=([1], [1]),
+            dim=0,
+            optimize_args=True,
+        )
         self._run_batch_concatenate(
             batch_sizes=[1, 1],
             concatenate_op=ops.concatenate(),
@@ -228,6 +258,12 @@ def test_cat(self):
         self._run_concatenate(
             concatenate_op=ops.concatenate(), input_shapes=([1, 1], [1, 1]), dim=0
         )
+        self._run_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([1, 1], [1, 1]),
+            dim=0,
+            optimize_args=True,
+        )
         self._run_concatenate(
             concatenate_op=ops.concatenate(), input_shapes=([1, 1], [1, 1]), dim=1
         )
@@ -321,6 +357,13 @@ def test_cat(self):
         )
 
     def test_masked_cat(self):
+        self._run_masked_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([2, 2, 2], [2, 2, 2], [2, 2, 2]),
+            input_masks=[True, True, False],
+            dim=2,
+            optimize_args=True,
+        )
         self._run_masked_concatenate(
             concatenate_op=ops.concatenate(),
             input_shapes=([2], [2]),
@@ -345,6 +388,13 @@ def test_masked_cat(self):
             input_masks=[False, True, False],
             dim=2,
         )
+        self._run_masked_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([1, 1, 1], [1, 1, 2], [1, 1, 4]),
+            input_masks=[False, True, False],
+            dim=2,
+            optimize_args=True,
+        )
 
     @parameterized.expand(("float16", "float32", "bfloat16"))
     def test_floats(self, dtype):
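For illustration, here is a minimal PyTorch sketch (not part of the patch) of the
behavior the AITModelImpl.cpp change above relies on: Tensor::contiguous() returns
the same storage when a tensor is already contiguous, so its data pointer can stay
misaligned, whereas clone() copies into freshly allocated storage that PyTorch's
allocators align generously. The tensor size and the 8 * element_size modulus below
are illustrative assumptions chosen to mirror the conservative check in
processInputs; on CUDA the caching allocator aligns new blocks even more
aggressively, which is why the summary notes that torch.clone almost always yields
256-byte-aligned pointers.

    import torch

    # A contiguous fp16 view whose data pointer is offset by one element
    # (2 bytes) from the start of its storage: contiguous, yet not 16-byte
    # aligned.
    base = torch.empty(1025, dtype=torch.float16)
    view = base[1:].contiguous()  # already contiguous -> no copy is made

    elem_sz = view.element_size()
    print(view.data_ptr() % (elem_sz * 8))    # typically 2, i.e. misaligned

    # clone() allocates new storage, which is at least 64-byte aligned on
    # CPU, so the kernel's alignment requirement is satisfied.
    cloned = view.clone()
    print(cloned.data_ptr() % (elem_sz * 8))  # 0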