Fixed segfault caused by mis-aligned input tensors (#894)
Summary:
Pull Request resolved: #894

This diff fixes two issues:

* Our input tensors may not be sufficiently aligned for our kernels.
For example, our fp16 GEMM kernel may require 16-byte alignment, but
the input tensor passed to the GEMM is not 16-byte aligned. We fix the
issue by cloning misaligned inputs with torch.clone, which almost
always gives us 256-byte-aligned pointers (see the first sketch below).

* It's not uncommon for multiple shape definitions to share the same
dimension values. In such cases, we can keep a single definition.
In some rare cases, this little trick dramatically reduces the
number of lines generated for the relevant concatenate op and thus
may improve compilation time. Currently, we only enable this
optimization (behind the optimize_for_compilation_time target flag)
for cases where we care about compilation time, since in most cases
the unoptimized version generates more readable code while having
little impact on compilation time (see the second sketch below).
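
A minimal Python sketch of the alignment guard described above (the actual fix is the C++ change in AITModelImpl.cpp below; the helper name here is hypothetical):

```python
import torch

def ensure_aligned(t: torch.Tensor) -> torch.Tensor:
    """Clone any contiguous tensor whose data pointer is not aligned to
    8 * element_size bytes, mirroring the conservative C++ guard."""
    t = t.contiguous()
    required = 8 * t.element_size()
    if t.data_ptr() % required != 0:
        # A fresh allocation from torch.clone() is almost always
        # 256-byte aligned, so the clone satisfies the requirement.
        return t.clone()
    return t
```

And a minimal sketch of the shape-definition dedup idea (illustrative only; the real logic in concatenate_common.py keys on IntImm values and dynamic dim names, and the helper below is hypothetical):

```python
def dedup_shape_defs(shapes):
    """Emit one shape definition per distinct dimension tuple and reuse
    its name for later shapes with identical dimensions."""
    seen = {}  # dims key -> name of the first emitted definition
    defs, names = [], []
    for idx, dims in enumerate(shapes):
        key = ",".join(dims)
        if key not in seen:
            name = f"input_shape_{idx}"
            defs.append(f"int64_t {name}[] = {{{', '.join(dims)}}};")
            seen[key] = name
            names.append(name)
        else:
            names.append(seen[key])
    return defs, names

# Example: the second input reuses the first definition, so only two
# shape definitions are emitted for three inputs.
defs, names = dedup_shape_defs([("B", "128"), ("B", "128"), ("B", "256")])
```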

Reviewed By: ipiszy

Differential Revision: D48238260

fbshipit-source-id: 4298f586322c32a3beff46f55ca579de4a40171f
chenyang78 authored and facebook-github-bot committed Aug 11, 2023
1 parent e0e00e2 commit c795b8c
Showing 3 changed files with 128 additions and 22 deletions.
14 changes: 13 additions & 1 deletion fx2ait/fx2ait/csrc/AITModelImpl.cpp
@@ -393,7 +393,19 @@ std::vector<AITData> AITModelImpl::processInputs(
// call in a local!
input = input.to(*floating_point_input_dtype_);
}
inputs_contig.push_back(input.contiguous());
auto t = input.contiguous();
size_t elem_sz = t.element_size();
// Let's be conservative - make sure all input tensor pointers are
// aligned to 8 * sizeof(dtype).
if ((((uint64_t)(t.data_ptr())) % (elem_sz * 8)) != 0) {
LOG(INFO) << "FORCE tensor pointer alignment: input_name: " << input_name
<< ", unaligned addr: " << std::hex << t.data_ptr();
auto t2 = t.clone();
LOG(INFO) << "cloned tensor pointer addr: " << std::hex << t2.data_ptr();
inputs_contig.push_back(t2);
} else {
inputs_contig.push_back(t);
}
auto& input_contig = inputs_contig.back();
auto input_shape_array_ref = input_contig.sizes();
ait_inputs[ait_input_idx] = AITData{
76 changes: 60 additions & 16 deletions python/aitemplate/backend/common/concatenate_common.py
@@ -18,7 +18,9 @@
import jinja2

from aitemplate.backend.common import tensor_accessor_codegen
from aitemplate.backend.target import Target

from aitemplate.compiler.base import IntImm
from aitemplate.compiler.ops.tensor import concatenate

FUNC_DECL_TEMPLATE = jinja2.Template(
@@ -745,18 +747,48 @@ def gen_function_call(
input_names = ",\n ".join([i._attrs["name"] for i in inputs])
real_input_shape_defs = []
real_input_shape_names = []
# It's not uncommon for multiple shape definitions to share the same
# dimension values. In such cases, we can keep a single definition.
# In some rare cases, this little trick can dramatically reduce the
# number of lines generated for the relevant concatenate op and thus
# may improve the compilation time. Currently, we only enable this
# optimization for cases where we care about compilation time, since
# in most cases the unoptimized version generates more readable
# code while having little impact on the compilation time.
seen_input_shape_dims = {}
input_shape_name_subs = {}
optimize_for_compilation_time = Target.current()._kwargs.get(
"optimize_for_compilation_time", False
)

def _make_dims_key(dims):
dim_vals = []
for d in dims:
if isinstance(d, IntImm):
dim_vals.append(str(d.value()))
else:
dim_vals.append(d._attrs["name"])
return ",".join(dim_vals)

for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)):
input_shape_name = f'{i._attrs["name"]}_shape_{idx}'
orig_input_shape = input_accessor.original_shapes
dims = ", ".join([dim._attrs["name"] for dim in orig_input_shape])
one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
indent=" ",
input_shape_name=input_shape_name,
input_dims=dims,
index_type=backend_spec.index_type,
)
real_input_shape_defs.append(one_shape_def)
real_input_shape_names.append(input_shape_name)
dims_key = _make_dims_key(orig_input_shape)
seen_shape_name = seen_input_shape_dims.get(dims_key, None)
if not optimize_for_compilation_time or seen_shape_name is None:
one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
indent=" ",
input_shape_name=input_shape_name,
input_dims=dims,
index_type=backend_spec.index_type,
)
real_input_shape_defs.append(one_shape_def)
real_input_shape_names.append(input_shape_name)
seen_input_shape_dims[dims_key] = input_shape_name
else:
real_input_shape_names.append(seen_shape_name)
input_shape_name_subs[input_shape_name] = seen_shape_name

y_shape = y._attrs["shape"]
y_dim_refs = ", ".join(["&" + dim._attrs["name"] for dim in y_shape])
@@ -781,6 +813,7 @@ def gen_function_call(
# all input shape defs and names, including those that are masked out
all_input_shape_defs = []
all_input_shape_names = []
seen_input_shape_dims = {}
# first, create shape defs for inputs that have been masked off
for (
mask,
Expand All @@ -792,19 +825,30 @@ def gen_function_call(
dims = ", ".join(
[str(dim._attrs["values"][0]) for dim in orig_input._attrs["shape"]]
)
one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
indent=" ",
input_shape_name=orig_input_shape_name,
input_dims=dims,
index_type=backend_spec.index_type,
)
all_input_shape_defs.append(one_shape_def)
all_input_shape_names.append(orig_input_shape_name)
dims_key = _make_dims_key(orig_input._attrs["shape"])
seen_shape_name = seen_input_shape_dims.get(dims_key, None)
if not optimize_for_compilation_time or seen_shape_name is None:
one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render(
indent=" ",
input_shape_name=orig_input_shape_name,
input_dims=dims,
index_type=backend_spec.index_type,
)
all_input_shape_defs.append(one_shape_def)
seen_input_shape_dims[dims_key] = orig_input_shape_name
all_input_shape_names.append(orig_input_shape_name)
else:
all_input_shape_names.append(seen_shape_name)
else:
all_input_shape_names.append(orig_input_shape_name)
else:
all_input_shape_names.append("")
# update all_input_shapes with real input shapes
for idx, (input_tensor, input_index) in enumerate(zip(inputs, input_indices)):
input_shape_name = f'{input_tensor._attrs["name"]}_shape_{idx}'
sub_name = input_shape_name_subs.get(input_shape_name, None)
if sub_name is not None:
input_shape_name = sub_name
all_input_shape_names[input_index] = input_shape_name

return FUNC_CALL_TEMPLATE.render(
60 changes: 55 additions & 5 deletions tests/unittest/ops/test_concatenate.py
@@ -31,7 +31,13 @@ def __init__(self, *args, **kwargs):
self.test_count = 0

def _run_concatenate(
self, *, concatenate_op, input_shapes, dim=None, input_type="float16"
self,
*,
concatenate_op,
input_shapes,
dim=None,
input_type="float16",
optimize_args=False,
):
# generate torch reference result
input_tensors_pt = [
@@ -44,7 +50,10 @@ def _run_concatenate(
else torch.cat(input_tensors_pt, dim)
)

target = detect_target()
if optimize_args:
target = detect_target(optimize_for_compilation_time=True)
else:
target = detect_target()
inputs = [
Tensor(
shape=shape, dtype=input_type, name="input_{}".format(i), is_input=True
@@ -68,9 +77,19 @@ def _run_concatenate(
self.test_count += 1

def _run_batch_concatenate(
self, *, batch_sizes, concatenate_op, input_shapes, dim=0, input_type="float16"
self,
*,
batch_sizes,
concatenate_op,
input_shapes,
dim=0,
input_type="float16",
optimize_args=False,
):
target = detect_target()
if optimize_args:
target = detect_target(optimize_for_compilation_time=True)
else:
target = detect_target()
BATCH_DIM_NAME = "input_batch"
batch_dim = shape_utils.gen_int_var_min_max(
values=batch_sizes, name=BATCH_DIM_NAME
@@ -122,6 +141,7 @@ def _run_masked_concatenate(
input_masks,
dim=None,
input_type="float16",
optimize_args=False,
):
# generate torch reference result
input_tensors_pt = [
@@ -134,7 +154,10 @@ def _run_masked_concatenate(
else torch.cat(input_tensors_pt, dim)
)

target = detect_target()
if optimize_args:
target = detect_target(optimize_for_compilation_time=True)
else:
target = detect_target()
inputs = [
Tensor(
shape=shape, dtype=input_type, name="input_{}".format(i), is_input=True
@@ -184,6 +207,13 @@ def test_batch_cat(self):
input_shapes=([1], [1]),
dim=0,
)
self._run_batch_concatenate(
batch_sizes=[1, 1],
concatenate_op=ops.concatenate(),
input_shapes=([1], [1]),
dim=0,
optimize_args=True,
)
self._run_batch_concatenate(
batch_sizes=[1, 1],
concatenate_op=ops.concatenate(),
@@ -228,6 +258,12 @@ def test_cat(self):
self._run_concatenate(
concatenate_op=ops.concatenate(), input_shapes=([1, 1], [1, 1]), dim=0
)
self._run_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([1, 1], [1, 1]),
dim=0,
optimize_args=True,
)
self._run_concatenate(
concatenate_op=ops.concatenate(), input_shapes=([1, 1], [1, 1]), dim=1
)
@@ -321,6 +357,13 @@ def test_cat(self):
)

def test_masked_cat(self):
self._run_masked_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([2, 2, 2], [2, 2, 2], [2, 2, 2]),
input_masks=[True, True, False],
dim=2,
optimize_args=True,
)
self._run_masked_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([2], [2]),
@@ -345,6 +388,13 @@ def test_masked_cat(self):
input_masks=[False, True, False],
dim=2,
)
self._run_masked_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([1, 1, 1], [1, 1, 2], [1, 1, 4]),
input_masks=[False, True, False],
dim=2,
optimize_args=True,
)

@parameterized.expand(("float16", "float32", "bfloat16"))
def test_floats(self, dtype):
