diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp
index b2a636637..4fbfe02ef 100644
--- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp
+++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp
@@ -393,7 +393,19 @@ std::vector<torch::Tensor> AITModelImpl::processInputs(
       // call in a local!
       input = input.to(*floating_point_input_dtype_);
     }
-    inputs_contig.push_back(input.contiguous());
+    auto t = input.contiguous();
+    size_t elem_sz = t.element_size();
+    // Let's be conservative - make sure all input tensor pointers are
+    // aligned to 8 * sizeof(dtype).
+    if ((((uint64_t)(t.data_ptr())) % (elem_sz * 8)) != 0) {
+      LOG(INFO) << "FORCE tensor pointer alignment: input_name: " << input_name
+                << ", unaligned addr: " << std::hex << t.data_ptr();
+      auto t2 = t.clone();
+      LOG(INFO) << "cloned tensor pointer addr: " << std::hex << t2.data_ptr();
+      inputs_contig.push_back(t2);
+    } else {
+      inputs_contig.push_back(t);
+    }
     auto& input_contig = inputs_contig.back();
     auto input_shape_array_ref = input_contig.sizes();
     ait_inputs[ait_input_idx] = AITData{
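Note on the alignment guard above: `contiguous()` returns the original tensor when it is already contiguous, so a storage-offset view can still reach this point with a data pointer that is not a multiple of 8 * sizeof(dtype); `clone()` is what forces a fresh, aligned allocation. A minimal PyTorch sketch of that situation (the `is_aligned` helper is illustrative only, not part of this patch):

```python
import torch


def is_aligned(t: torch.Tensor, multiple: int = 8) -> bool:
    # Same predicate as the C++ check: the data pointer must be divisible
    # by multiple * sizeof(dtype).
    return t.data_ptr() % (multiple * t.element_size()) == 0


base = torch.zeros(1025, dtype=torch.float16)
view = base[1:]  # contiguous view whose data pointer is offset by 2 bytes
assert view.is_contiguous()
print(is_aligned(view))          # typically False; contiguous() would not copy it
print(is_aligned(view.clone()))  # True: clone() copies into a freshly allocated buffer
```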
diff --git a/python/aitemplate/backend/common/concatenate_common.py b/python/aitemplate/backend/common/concatenate_common.py
index 025688309..872a6fd94 100644
--- a/python/aitemplate/backend/common/concatenate_common.py
+++ b/python/aitemplate/backend/common/concatenate_common.py
@@ -18,7 +18,9 @@ import jinja2
 
 from aitemplate.backend.common import tensor_accessor_codegen
+from aitemplate.backend.target import Target
+from aitemplate.compiler.base import IntImm
 from aitemplate.compiler.ops.tensor import concatenate
 
 FUNC_DECL_TEMPLATE = jinja2.Template(
@@ -745,18 +747,48 @@ def gen_function_call(
     input_names = ",\n ".join([i._attrs["name"] for i in inputs])
     real_input_shape_defs = []
     real_input_shape_names = []
+    # It's not uncommon that multiple shape definitions share the same
+    # dimension values. In such cases, we can keep a single definition.
+    # In some rare cases, this little trick can dramatically reduce the
+    # number of lines generated for the relevant concatenate op and thus
+    # may improve the compilation time. Currently, we only enable this
+    # optimization for cases where we care about compilation time, as
+    # in most cases the unoptimized version generates more readable
+    # code while having little impact on compilation time.
+    seen_input_shape_dims = {}
+    input_shape_name_subs = {}
+    optimize_for_compilation_time = Target.current()._kwargs.get(
+        "optimize_for_compilation_time", False
+    )
+
+    def _make_dims_key(dims):
+        dim_vals = []
+        for d in dims:
+            if isinstance(d, IntImm):
+                dim_vals.append(str(d.value()))
+            else:
+                dim_vals.append(d._attrs["name"])
+        return ",".join(dim_vals)
+
     for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)):
         input_shape_name = f'{i._attrs["name"]}_shape_{idx}'
         orig_input_shape = input_accessor.original_shapes
         dims = ", ".join([dim._attrs["name"] for dim in orig_input_shape])
-        one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
-            indent=" ",
-            input_shape_name=input_shape_name,
-            input_dims=dims,
-            index_type=backend_spec.index_type,
-        )
-        real_input_shape_defs.append(one_shape_def)
-        real_input_shape_names.append(input_shape_name)
+        dims_key = _make_dims_key(orig_input_shape)
+        seen_shape_name = seen_input_shape_dims.get(dims_key, None)
+        if not optimize_for_compilation_time or seen_shape_name is None:
+            one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
+                indent=" ",
+                input_shape_name=input_shape_name,
+                input_dims=dims,
+                index_type=backend_spec.index_type,
+            )
+            real_input_shape_defs.append(one_shape_def)
+            real_input_shape_names.append(input_shape_name)
+            seen_input_shape_dims[dims_key] = input_shape_name
+        else:
+            real_input_shape_names.append(seen_shape_name)
+            input_shape_name_subs[input_shape_name] = seen_shape_name
 
     y_shape = y._attrs["shape"]
     y_dim_refs = ", ".join(["&" + dim._attrs["name"] for dim in y_shape])
@@ -781,6 +813,7 @@
     # all input shape defs and names, including those that are masked out
     all_input_shape_defs = []
     all_input_shape_names = []
+    seen_input_shape_dims = {}
    # first, create shape defs for inputs that have been masked off
     for (
         mask,
@@ -792,19 +825,30 @@
                 dims = ", ".join(
                     [str(dim._attrs["values"][0]) for dim in orig_input._attrs["shape"]]
                 )
-                one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
-                    indent=" ",
-                    input_shape_name=orig_input_shape_name,
-                    input_dims=dims,
-                    index_type=backend_spec.index_type,
-                )
-                all_input_shape_defs.append(one_shape_def)
-                all_input_shape_names.append(orig_input_shape_name)
+                dims_key = _make_dims_key(orig_input._attrs["shape"])
+                seen_shape_name = seen_input_shape_dims.get(dims_key, None)
+                if not optimize_for_compilation_time or seen_shape_name is None:
+                    one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
+                        indent=" ",
+                        input_shape_name=orig_input_shape_name,
+                        input_dims=dims,
+                        index_type=backend_spec.index_type,
+                    )
+                    all_input_shape_defs.append(one_shape_def)
+                    seen_input_shape_dims[dims_key] = orig_input_shape_name
+                    all_input_shape_names.append(orig_input_shape_name)
+                else:
+                    all_input_shape_names.append(seen_shape_name)
+            else:
+                all_input_shape_names.append(orig_input_shape_name)
         else:
             all_input_shape_names.append("")
     # update all_input_shapes with real input shapes
     for idx, (input_tensor, input_index) in enumerate(zip(inputs, input_indices)):
         input_shape_name = f'{input_tensor._attrs["name"]}_shape_{idx}'
+        sub_name = input_shape_name_subs.get(input_shape_name, None)
+        if sub_name is not None:
+            input_shape_name = sub_name
         all_input_shape_names[input_index] = input_shape_name
 
     return FUNC_CALL_TEMPLATE.render(
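The gist of the codegen change: shape definitions are keyed by their dimension values (IntImm values, or names for dynamic dims), only the first definition per key is emitted, and later duplicates are redirected to it through `input_shape_name_subs`. A self-contained sketch of that bookkeeping with made-up shape names; the emitted strings below merely stand in for whatever `INPUT_SHAPE_DEF_TEMPLATE` actually renders:

```python
def dedupe_shape_defs(named_shapes):
    """named_shapes: list of (shape_name, dims) pairs, dims already stringified."""
    seen = {}      # dims key -> first (canonical) shape name
    subs = {}      # duplicate shape name -> canonical shape name
    emitted = []   # definitions that actually get generated
    for name, dims in named_shapes:
        key = ",".join(dims)
        canonical = seen.get(key)
        if canonical is None:
            seen[key] = name
            emitted.append(f"int64_t {name}[] = {{{', '.join(dims)}}};")
        else:
            subs[name] = canonical
    return emitted, subs


emitted, subs = dedupe_shape_defs(
    [
        ("input_0_shape_0", ["batch", "16"]),
        ("input_1_shape_1", ["batch", "16"]),  # same dims -> no new definition
        ("input_2_shape_2", ["batch", "32"]),
    ]
)
print(len(emitted))  # 2
print(subs)          # {'input_1_shape_1': 'input_0_shape_0'}
```

With many same-shaped concatenate inputs this collapses N near-identical definitions into one, which is where the compilation-time win comes from.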
diff --git a/tests/unittest/ops/test_concatenate.py b/tests/unittest/ops/test_concatenate.py
index a7560f661..15fa91270 100644
--- a/tests/unittest/ops/test_concatenate.py
+++ b/tests/unittest/ops/test_concatenate.py
@@ -31,7 +31,13 @@ def __init__(self, *args, **kwargs):
         self.test_count = 0
 
     def _run_concatenate(
-        self, *, concatenate_op, input_shapes, dim=None, input_type="float16"
+        self,
+        *,
+        concatenate_op,
+        input_shapes,
+        dim=None,
+        input_type="float16",
+        optimize_args=False,
     ):
         # generate torch reference result
         input_tensors_pt = [
@@ -44,7 +50,10 @@ def _run_concatenate(
             else torch.cat(input_tensors_pt, dim)
         )
 
-        target = detect_target()
+        if optimize_args:
+            target = detect_target(optimize_for_compilation_time=True)
+        else:
+            target = detect_target()
         inputs = [
             Tensor(
                 shape=shape, dtype=input_type, name="input_{}".format(i), is_input=True
@@ -68,9 +77,19 @@ def _run_concatenate(
         self.test_count += 1
 
     def _run_batch_concatenate(
-        self, *, batch_sizes, concatenate_op, input_shapes, dim=0, input_type="float16"
+        self,
+        *,
+        batch_sizes,
+        concatenate_op,
+        input_shapes,
+        dim=0,
+        input_type="float16",
+        optimize_args=False,
     ):
-        target = detect_target()
+        if optimize_args:
+            target = detect_target(optimize_for_compilation_time=True)
+        else:
+            target = detect_target()
         BATCH_DIM_NAME = "input_batch"
         batch_dim = shape_utils.gen_int_var_min_max(
             values=batch_sizes, name=BATCH_DIM_NAME
@@ -122,6 +141,7 @@ def _run_masked_concatenate(
         input_masks,
         dim=None,
         input_type="float16",
+        optimize_args=False,
     ):
         # generate torch reference result
         input_tensors_pt = [
@@ -134,7 +154,10 @@ def _run_masked_concatenate(
             else torch.cat(input_tensors_pt, dim)
         )
 
-        target = detect_target()
+        if optimize_args:
+            target = detect_target(optimize_for_compilation_time=True)
+        else:
+            target = detect_target()
         inputs = [
             Tensor(
                 shape=shape, dtype=input_type, name="input_{}".format(i), is_input=True
@@ -184,6 +207,13 @@ def test_batch_cat(self):
             input_shapes=([1], [1]),
             dim=0,
         )
+        self._run_batch_concatenate(
+            batch_sizes=[1, 1],
+            concatenate_op=ops.concatenate(),
+            input_shapes=([1], [1]),
+            dim=0,
+            optimize_args=True,
+        )
         self._run_batch_concatenate(
             batch_sizes=[1, 1],
             concatenate_op=ops.concatenate(),
@@ -228,6 +258,12 @@ def test_cat(self):
         self._run_concatenate(
             concatenate_op=ops.concatenate(), input_shapes=([1, 1], [1, 1]), dim=0
         )
+        self._run_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([1, 1], [1, 1]),
+            dim=0,
+            optimize_args=True,
+        )
         self._run_concatenate(
             concatenate_op=ops.concatenate(), input_shapes=([1, 1], [1, 1]), dim=1
         )
@@ -321,6 +357,13 @@ def test_cat(self):
         )
 
     def test_masked_cat(self):
+        self._run_masked_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([2, 2, 2], [2, 2, 2], [2, 2, 2]),
+            input_masks=[True, True, False],
+            dim=2,
+            optimize_args=True,
+        )
         self._run_masked_concatenate(
             concatenate_op=ops.concatenate(),
             input_shapes=([2], [2]),
@@ -345,6 +388,13 @@ def test_masked_cat(self):
             input_masks=[False, True, False],
             dim=2,
         )
+        self._run_masked_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([1, 1, 1], [1, 1, 2], [1, 1, 4]),
+            input_masks=[False, True, False],
+            dim=2,
+            optimize_args=True,
+        )
 
     @parameterized.expand(("float16", "float32", "bfloat16"))
     def test_floats(self, dtype):
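Outside the unit tests, opting in follows the same path the tests exercise: keyword arguments passed to `detect_target()` are stored on the resulting Target and read back by the concatenate codegen via `Target.current()._kwargs.get("optimize_for_compilation_time", False)`. A sketch, with graph construction elided:

```python
from aitemplate.compiler import compile_model
from aitemplate.testing import detect_target

# Opt in to the deduplicated concatenate codegen; default behavior is unchanged.
target = detect_target(optimize_for_compilation_time=True)

# ... build the AIT graph and obtain its output tensor(s) as usual ...
# module = compile_model(output, target, "./tmp", "concat_model")
```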