allow concatenating empty tensors
This PR adds support for concatenating empty tensors.
chenyang78 committed Jun 17, 2024
1 parent 95697dc commit 795a496
Showing 6 changed files with 121 additions and 18 deletions.
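As a quick orientation before the diffs: a minimal usage sketch of what this change enables. This is hypothetical frontend code (tensor names and shapes are illustrative, not taken from the commit), assuming the standard AITemplate ops/Tensor API.

from aitemplate.compiler import ops
from aitemplate.frontend import Tensor

x0 = Tensor(shape=[2, 3, 4], name="x0", is_input=True)
x_empty = Tensor(shape=[0], name="x_empty", is_input=True)  # rank-1 empty tensor
x1 = Tensor(shape=[2, 2, 4], name="x1", is_input=True)

# With this commit, the empty input no longer trips rank checks or shape
# inference; it simply contributes nothing, so y has shape [2, 5, 4].
y = ops.concatenate()([x0, x_empty, x1], dim=1)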
32 changes: 25 additions & 7 deletions python/aitemplate/backend/common/concatenate_common.py
@@ -15,6 +15,7 @@
"""
backend concatenate function common templates.
"""
from copy import deepcopy
from typing import List

import jinja2
@@ -24,6 +25,7 @@

from aitemplate.compiler.base import IntImm
from aitemplate.compiler.ops.tensor import concatenate
from aitemplate.utils.shape_utils import is_empty_tensor_shape

FUNC_DECL_TEMPLATE = jinja2.Template(
"""
@@ -685,7 +687,7 @@ def gen_function(
inputs = func_attrs["inputs"]
original_inputs = func_attrs["original_inputs"]
concatenate.check_rank(original_inputs, func_attrs["concat_dim"])
orig_x = original_inputs[0]
orig_x, _ = concatenate.get_first_non_empty_input_if_any(original_inputs)
y = func_attrs["outputs"][0]
x_shape = orig_x._attrs["shape"]

@@ -830,7 +832,7 @@ def gen_function_call(
)
original_inputs = func_attrs["original_inputs"]
concatenate.check_rank(original_inputs, func_attrs["concat_dim"])
orig_x = original_inputs[0]
orig_x, _ = concatenate.get_first_non_empty_input_if_any(original_inputs)
y = func_attrs["outputs"][0]
concat_dim = func_attrs["concat_dim"]

@@ -857,9 +859,17 @@ def _make_dims_key(dims):
]
return ",".join(dim_vals)

non_empty_input, non_empty_idx = concatenate.get_first_non_empty_input_if_any(
inputs
)
non_empty_input_accessor = input_accessors[non_empty_idx]
for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)):
input_shape_name = f'{i._attrs["name"]}_shape_{idx}'
orig_input_shape = input_accessor.original_shapes
if is_empty_tensor_shape(orig_input_shape):
orig_dim = orig_input_shape[0]
orig_input_shape = deepcopy(non_empty_input_accessor.original_shapes)
orig_input_shape[concat_dim] = orig_dim
dims = ", ".join([dim._attrs["name"] for dim in orig_input_shape])
dims_key = _make_dims_key(orig_input_shape)
seen_shape_name = seen_input_shape_dims.get(dims_key, None)
@@ -883,14 +893,22 @@ def _make_dims_key(dims):
input_masks = func_attrs["input_masks"]
input_indices = [idx for idx, m in enumerate(input_masks) if m is True]
assert len(inputs) == len(input_indices)
concat_dim_sizes = [
"-1" if mask else str(original_inputs[idx]._attrs["shape"][concat_dim].value())
for idx, mask in enumerate(input_masks)
]
concat_dim_sizes = []
for idx, mask in enumerate(input_masks):
if is_empty_tensor_shape(original_inputs[idx]._attrs["shape"]):
d = "0"
elif mask:
d = "-1"
else:
d = str(original_inputs[idx]._attrs["shape"][concat_dim].value())
concat_dim_sizes.append(d)

# update dim size for real inputs
for input_accessor, input_index in zip(input_accessors, input_indices):
dim = input_accessor.original_shapes[concat_dim]._attrs["name"]
if is_empty_tensor_shape(input_accessor.original_shapes):
dim = input_accessor.original_shapes[0]._attrs["name"]
else:
dim = input_accessor.original_shapes[concat_dim]._attrs["name"]
concat_dim_sizes[input_index] = dim

input_mask_values = ["true" if mask is True else "false" for mask in input_masks]
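A brief illustration of the helper used above (a sketch; the tensors are hypothetical, the import path is the one used in this file):

from aitemplate.compiler.ops.tensor import concatenate
from aitemplate.frontend import Tensor

empty = Tensor(shape=[0], name="empty")
full = Tensor(shape=[2, 3, 4], name="full")

# Returns the first input whose shape is not [0] plus its index, so the
# codegen can borrow a reference shape from it; if every input is empty,
# it falls back to (inputs[0], 0).
ref, idx = concatenate.get_first_non_empty_input_if_any([empty, full])  # -> (full, 1)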
2 changes: 1 addition & 1 deletion python/aitemplate/backend/cuda/tensor/concatenate_fast.py
@@ -128,7 +128,7 @@ def gen_function(
inputs = func_attrs["inputs"]
original_inputs = func_attrs["original_inputs"]
concatenate.check_rank(original_inputs, func_attrs["concat_dim"])
orig_x = original_inputs[0]
orig_x, _ = concatenate.get_first_non_empty_input_if_any(original_inputs)
y = func_attrs["outputs"][0]
x_shape = orig_x._attrs["shape"]

65 changes: 57 additions & 8 deletions python/aitemplate/compiler/ops/tensor/concatenate.py
@@ -15,12 +15,13 @@
"""
Concatenate.
"""
from copy import deepcopy
from functools import reduce
from typing import List, Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import IntVar, Operator, Tensor
from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.utils import shape_utils
from aitemplate.utils.tensor_utils import wrap_dim
@@ -57,20 +58,47 @@ def __init__(self, fast_cat=True) -> None:
def _unique(self, vector):
return sorted(set(vector))

@staticmethod
def get_rank(inputs: List[Tensor]) -> Optional[int]:
input_rank = None
for inp in inputs:
if not shape_utils.is_empty_tensor_shape(inp._attrs["shape"]):
input_rank = inp._rank()
break
return input_rank

@staticmethod
def get_first_non_empty_input_if_any(inputs: List[Tensor]) -> Tuple[Tensor, int]:
"""Return the first non-empty input and its index from the list.
If all inputs are empty, return the first input.
"""
assert len(inputs) > 0, "len(inputs) must be > 0!"
t = None
idx = 0
for i, inp in enumerate(inputs):
if not shape_utils.is_empty_tensor_shape(inp._attrs["shape"]):
return (inp, i)
if t is None:
t = inputs[0]
return (t, idx)

@staticmethod
def check_rank(inputs: List[Tensor], dim) -> bool:
"""check if the rank is valid"""
if len(inputs) < 1:
raise RuntimeError("expected a list of Tensors")
x = inputs[0]
rank = len(x._attrs["shape"])
rank = concatenate.get_rank(inputs)
if rank is None:
return
if rank <= 0:
raise RuntimeError("expected a non-scalar tensor")
if dim >= rank:
raise RuntimeError(
f"concat_dim ({dim}) expected to be less than rank ({rank})"
)
for t in inputs:
if shape_utils.is_empty_tensor_shape(t._attrs["shape"]):
continue
r = len(t._attrs["shape"])
if r != rank:
raise RuntimeError(
@@ -81,8 +109,22 @@ def check_rank(inputs: List[Tensor], dim) -> bool:
def _infer_shapes(self, inputs: List[Tensor], dim) -> List[IntVar]:
"""Infers shapes for concatenate."""
concatenate.check_rank(inputs, dim)
rank = concatenate.get_rank(inputs)
# all inputs are empty
if rank is None:
return [IntImm(0)]

input_shapes = [i._attrs["shape"] for i in inputs]
ref_input, _ = concatenate.get_first_non_empty_input_if_any(inputs)
# reference shape should come from a non-empty tensor
ref_input_shape = ref_input._attrs["shape"]
input_shapes = []
for t in inputs:
if shape_utils.is_empty_tensor_shape(t._attrs["shape"]):
shape = deepcopy(ref_input_shape)
shape[dim] = IntImm(0)
else:
shape = t._attrs["shape"]
input_shapes.append(shape)
output_shape = []
input_shape_values = [
[d._attrs["values"] for d in shape] for shape in input_shapes
@@ -103,8 +145,11 @@ def _infer_shapes(self, inputs: List[Tensor], dim) -> List[IntVar]:
)
output_shape.append(shape_var)
else:
output_dim = input_shapes[0][idx]
output_dim = ref_input_shape[idx]
for shape in input_shapes:
# the corresponding input tensor is empty
if shape_utils.is_empty_tensor_shape(shape):
continue
# if output_dim != shape[idx]:
if output_dim._attrs["values"] != shape[idx]._attrs["values"]:
raise RuntimeError(
@@ -128,8 +173,12 @@ def __call__(self, inputs: List[Tensor], dim=0) -> Tensor:
self._attrs["original_inputs"] = list(inputs)
# True means the corresponding tensor will be copied by the concat backend.
self._attrs["input_masks"] = [True] * len(inputs)
input_rank = inputs[0]._rank()
dim = wrap_dim(dim, input_rank)
input_rank = concatenate.get_rank(inputs)
if input_rank is not None:
dim = wrap_dim(dim, input_rank)
else:
# force dim to be 0
dim = 0
self._attrs["concat_dim"] = dim
self._set_depth()
output_shape = self._infer_shapes(inputs, dim)
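The inference above is meant to match how torch.cat already treats 1-D empty tensors (which the updated unit tests below rely on); a rough reference sketch in PyTorch, included only as a comparison point:

import torch

a = torch.empty(0)            # rank-1 empty tensor, shape [0]
b = torch.randn(2, 3, 4)
c = torch.randn(2, 2, 4)

# torch.cat skips 1-D empty inputs regardless of the concat dim, so the
# result has shape [2, 5, 4]; _infer_shapes reproduces this by giving the
# empty input a size of 0 along concat_dim.
y = torch.cat([b, a, c], dim=1)
assert list(y.shape) == [2, 5, 4]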
9 changes: 9 additions & 0 deletions python/aitemplate/utils/shape_utils.py
@@ -234,3 +234,12 @@ def get_static_stride(shape, dim) -> Optional[int]:
return None
stride *= d.value()
return stride


def is_empty_tensor_shape(shape) -> bool:
"""
Return True if the input shape is empty
"""
from aitemplate.compiler.base import IntImm

return len(shape) == 1 and isinstance(shape[0], IntImm) and shape[0].value() == 0
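Spelled out as a small sketch (assumes only IntImm/IntVar from aitemplate.compiler.base), the convention this helper encodes is: "empty" means a rank-1 shape whose single static dim is 0, nothing else.

from aitemplate.compiler.base import IntImm, IntVar
from aitemplate.utils.shape_utils import is_empty_tensor_shape

assert is_empty_tensor_shape([IntImm(0)])                  # shape [0] -> empty
assert not is_empty_tensor_shape([IntImm(2), IntImm(0)])   # [2, 0] is not "empty" here
assert not is_empty_tensor_shape([IntVar(values=[0, 4])])  # dynamic dim -> not empty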
2 changes: 1 addition & 1 deletion tests/unittest/compiler/test_remove_no_op_concats.py
@@ -106,7 +106,7 @@ def test_remove_no_op_concats_exceptions(self):
# AIT expects all concat inputs to have the same rank.
with self.assertRaises(RuntimeError):
self._test_remove_no_op_concats_impl(
input_shapes=[[2, 4], [0]],
input_shapes=[[2, 4], [1]],
should_keep_concat=False,
test_name="test_remove_no_op_concats_same_rank",
)
29 changes: 28 additions & 1 deletion tests/unittest/ops/test_concatenate.py
@@ -191,7 +191,12 @@ def _run_masked_concatenate(
y = torch.empty_like(y_pt)
module.run_with_tensors(inputs, [y])

split_sections = [shape[dim] for shape in input_shapes]
split_sections = []
for shape in input_shapes:
if len(shape) == 1 and shape[0] == 0:
split_sections.append(0)
else:
split_sections.append(shape[dim])

ys_pt = torch.split(y_pt, split_sections, dim=dim)
ys = torch.split(y, split_sections, dim=dim)
@@ -252,6 +257,21 @@ def test_batch_cat(self):
)

def test_cat(self):
self._run_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([0], [2, 2, 2]),
dim=2,
)
self._run_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([2, 0, 4], [2, 3, 4], [0], [2, 2, 4]),
dim=1,
)
self._run_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([0], [2, 3, 4], [0], [2, 2, 4]),
dim=1,
)
self._run_concatenate(
concatenate_op=ops.concatenate(), input_shapes=([1], [1]), dim=0
)
@@ -357,6 +377,13 @@ def test_cat(self):
)

def test_masked_cat(self):
self._run_masked_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([0], [2, 2, 2]),
input_masks=[False, True],
dim=2,
optimize_args=True,
)
self._run_masked_concatenate(
concatenate_op=ops.concatenate(),
input_shapes=([2, 2, 2], [2, 2, 2], [2, 2, 2]),
