From 7d427ca16ed6981781b218182b6b579aa7d7fcad Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Tue, 18 Jun 2024 09:32:11 -0700
Subject: [PATCH] allow concatenating empty tensors (#1010)

Summary:
This PR added support to concatenate empty tensors

Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/1010

Reviewed By: ColinPeppler, 22quinn

Differential Revision: D58646494

Pulled By: chenyang78

fbshipit-source-id: c3075134232e0ab7e88f10e2db3e24f95b61beeb
---
 .../backend/common/concatenate_common.py    | 32 +++++++--
 .../backend/cuda/tensor/concatenate_fast.py |  2 +-
 .../compiler/ops/tensor/concatenate.py      | 65 ++++++++++++++++---
 python/aitemplate/utils/shape_utils.py      |  9 +++
 .../compiler/test_remove_no_op_concats.py   |  2 +-
 tests/unittest/ops/test_concatenate.py      | 29 ++++++++-
 6 files changed, 121 insertions(+), 18 deletions(-)

diff --git a/python/aitemplate/backend/common/concatenate_common.py b/python/aitemplate/backend/common/concatenate_common.py
index 5e2d4f613..0e5235a43 100644
--- a/python/aitemplate/backend/common/concatenate_common.py
+++ b/python/aitemplate/backend/common/concatenate_common.py
@@ -15,6 +15,7 @@
 """
 backend concatenate function common templates.
 """
+from copy import deepcopy
 from typing import List
 
 import jinja2
@@ -24,6 +25,7 @@
 from aitemplate.compiler.base import IntImm
 from aitemplate.compiler.ops.tensor import concatenate
+from aitemplate.utils.shape_utils import is_empty_rank1_tensor
 
 
 FUNC_DECL_TEMPLATE = jinja2.Template(
     """
@@ -685,7 +687,7 @@ def gen_function(
     inputs = func_attrs["inputs"]
     original_inputs = func_attrs["original_inputs"]
     concatenate.check_rank(original_inputs, func_attrs["concat_dim"])
-    orig_x = original_inputs[0]
+    orig_x, _ = concatenate.get_first_non_empty_input_if_any(original_inputs)
     y = func_attrs["outputs"][0]
 
     x_shape = orig_x._attrs["shape"]
@@ -830,7 +832,7 @@ def gen_function_call(
     )
     original_inputs = func_attrs["original_inputs"]
     concatenate.check_rank(original_inputs, func_attrs["concat_dim"])
-    orig_x = original_inputs[0]
+    orig_x, _ = concatenate.get_first_non_empty_input_if_any(original_inputs)
     y = func_attrs["outputs"][0]
     concat_dim = func_attrs["concat_dim"]
 
@@ -857,9 +859,17 @@ def _make_dims_key(dims):
         ]
         return ",".join(dim_vals)
 
+    non_empty_input, non_empty_idx = concatenate.get_first_non_empty_input_if_any(
+        inputs
+    )
+    non_empty_input_accessor = input_accessors[non_empty_idx]
     for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)):
         input_shape_name = f'{i._attrs["name"]}_shape_{idx}'
         orig_input_shape = input_accessor.original_shapes
+        if is_empty_rank1_tensor(orig_input_shape):
+            orig_dim = orig_input_shape[0]
+            orig_input_shape = deepcopy(non_empty_input_accessor.original_shapes)
+            orig_input_shape[concat_dim] = orig_dim
         dims = ", ".join([dim._attrs["name"] for dim in orig_input_shape])
         dims_key = _make_dims_key(orig_input_shape)
         seen_shape_name = seen_input_shape_dims.get(dims_key, None)
@@ -883,14 +893,22 @@ def _make_dims_key(dims):
     input_masks = func_attrs["input_masks"]
     input_indices = [idx for idx, m in enumerate(input_masks) if m is True]
     assert len(inputs) == len(input_indices)
-    concat_dim_sizes = [
-        "-1" if mask else str(original_inputs[idx]._attrs["shape"][concat_dim].value())
-        for idx, mask in enumerate(input_masks)
-    ]
+    concat_dim_sizes = []
+    for idx, mask in enumerate(input_masks):
+        if is_empty_rank1_tensor(original_inputs[idx]._attrs["shape"]):
+            d = "0"
+        elif mask:
+            d = "-1"
+        else:
+            d = str(original_inputs[idx]._attrs["shape"][concat_dim].value())
+        concat_dim_sizes.append(d)
 
     # update dim size for real inputs
     for input_accessor, input_index in zip(input_accessors, input_indices):
-        dim = input_accessor.original_shapes[concat_dim]._attrs["name"]
+        if is_empty_rank1_tensor(input_accessor.original_shapes):
+            dim = input_accessor.original_shapes[0]._attrs["name"]
+        else:
+            dim = input_accessor.original_shapes[concat_dim]._attrs["name"]
         concat_dim_sizes[input_index] = dim
 
     input_mask_values = ["true" if mask is True else "false" for mask in input_masks]
diff --git a/python/aitemplate/backend/cuda/tensor/concatenate_fast.py b/python/aitemplate/backend/cuda/tensor/concatenate_fast.py
index fb58b6bac..6d3263c2f 100644
--- a/python/aitemplate/backend/cuda/tensor/concatenate_fast.py
+++ b/python/aitemplate/backend/cuda/tensor/concatenate_fast.py
@@ -128,7 +128,7 @@ def gen_function(
     inputs = func_attrs["inputs"]
     original_inputs = func_attrs["original_inputs"]
     concatenate.check_rank(original_inputs, func_attrs["concat_dim"])
-    orig_x = original_inputs[0]
+    orig_x, _ = concatenate.get_first_non_empty_input_if_any(original_inputs)
     y = func_attrs["outputs"][0]
 
     x_shape = orig_x._attrs["shape"]
diff --git a/python/aitemplate/compiler/ops/tensor/concatenate.py b/python/aitemplate/compiler/ops/tensor/concatenate.py
index 1606d1ffd..27bda6492 100644
--- a/python/aitemplate/compiler/ops/tensor/concatenate.py
+++ b/python/aitemplate/compiler/ops/tensor/concatenate.py
@@ -15,12 +15,13 @@
 """
 Concatenate.
 """
+from copy import deepcopy
 from functools import reduce
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 from aitemplate import backend
 from aitemplate.backend import registry
-from aitemplate.compiler.base import IntVar, Operator, Tensor
+from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor
 from aitemplate.compiler.tensor_accessor import TensorAccessor
 from aitemplate.utils import shape_utils
 from aitemplate.utils.tensor_utils import wrap_dim
@@ -57,13 +58,38 @@ def __init__(self, fast_cat=True) -> None:
     def _unique(self, vector):
         return sorted(set(vector))
 
+    @staticmethod
+    def get_rank(inputs: List[Tensor]) -> Optional[int]:
+        input_rank = None
+        for inp in inputs:
+            if not shape_utils.is_empty_rank1_tensor(inp._attrs["shape"]):
+                input_rank = inp._rank()
+                break
+        return input_rank
+
+    @staticmethod
+    def get_first_non_empty_input_if_any(inputs: List[Tensor]) -> Tuple[Tensor, int]:
+        """Return the first non-empty input and its index from the list.
+        If all inputs are empty, return the first input.
+        """
+        assert len(inputs) > 0, "len(inputs) must be > 0!"
+        t = None
+        idx = 0
+        for i, inp in enumerate(inputs):
+            if not shape_utils.is_empty_rank1_tensor(inp._attrs["shape"]):
+                return (inp, i)
+        if t is None:
+            t = inputs[0]
+        return (t, idx)
+
     @staticmethod
     def check_rank(inputs: List[Tensor], dim) -> bool:
         """check if the rank is valid"""
         if len(inputs) < 1:
             raise RuntimeError("expected a list of Tensors")
-        x = inputs[0]
-        rank = len(x._attrs["shape"])
+        rank = concatenate.get_rank(inputs)
+        if rank is None:
+            return
         if rank <= 0:
             raise RuntimeError("expected a non-scalar tensor")
         if dim >= rank:
@@ -71,6 +97,8 @@ def check_rank(inputs: List[Tensor], dim) -> bool:
                 f"concat_dim ({dim}) expected to be less than rank ({rank})"
             )
         for t in inputs:
+            if shape_utils.is_empty_rank1_tensor(t._attrs["shape"]):
+                continue
             r = len(t._attrs["shape"])
             if r != rank:
                 raise RuntimeError(
@@ -81,8 +109,22 @@ def _infer_shapes(self, inputs: List[Tensor], dim) -> List[IntVar]:
         """Infers shapes for concatenate."""
         concatenate.check_rank(inputs, dim)
+        rank = concatenate.get_rank(inputs)
+        # all inputs are empty
+        if rank is None:
+            return [IntImm(0)]
 
-        input_shapes = [i._attrs["shape"] for i in inputs]
+        ref_input, _ = concatenate.get_first_non_empty_input_if_any(inputs)
+        # reference shape should come from a non-empty tensor
+        ref_input_shape = ref_input._attrs["shape"]
+        input_shapes = []
+        for t in inputs:
+            if shape_utils.is_empty_rank1_tensor(t._attrs["shape"]):
+                shape = deepcopy(ref_input_shape)
+                shape[dim] = IntImm(0)
+            else:
+                shape = t._attrs["shape"]
+            input_shapes.append(shape)
         output_shape = []
         input_shape_values = [
             [d._attrs["values"] for d in shape] for shape in input_shapes
         ]
@@ -103,8 +145,11 @@ def _infer_shapes(self, inputs: List[Tensor], dim) -> List[IntVar]:
                 )
                 output_shape.append(shape_var)
             else:
-                output_dim = input_shapes[0][idx]
+                output_dim = ref_input_shape[idx]
                 for shape in input_shapes:
+                    # the corresponding input tensor is empty
+                    if shape_utils.is_empty_rank1_tensor(shape):
+                        continue
                     # if output_dim != shape[idx]:
                     if output_dim._attrs["values"] != shape[idx]._attrs["values"]:
                         raise RuntimeError(
@@ -128,8 +173,12 @@ def __call__(self, inputs: List[Tensor], dim=0) -> Tensor:
         self._attrs["original_inputs"] = list(inputs)
         # True means the corresponding tensor will be copied by the concat backend.
         self._attrs["input_masks"] = [True] * len(inputs)
-        input_rank = inputs[0]._rank()
-        dim = wrap_dim(dim, input_rank)
+        input_rank = concatenate.get_rank(inputs)
+        if input_rank is not None:
+            dim = wrap_dim(dim, input_rank)
+        else:
+            # force dim to be 0
+            dim = 0
         self._attrs["concat_dim"] = dim
         self._set_depth()
         output_shape = self._infer_shapes(inputs, dim)
diff --git a/python/aitemplate/utils/shape_utils.py b/python/aitemplate/utils/shape_utils.py
index d02c84112..635ae85ab 100644
--- a/python/aitemplate/utils/shape_utils.py
+++ b/python/aitemplate/utils/shape_utils.py
@@ -234,3 +234,12 @@ def get_static_stride(shape, dim) -> Optional[int]:
             return None
         stride *= d.value()
     return stride
+
+
+def is_empty_rank1_tensor(shape) -> bool:
+    """
+    Return True if the input shape is empty
+    """
+    from aitemplate.compiler.base import IntImm
+
+    return len(shape) == 1 and isinstance(shape[0], IntImm) and shape[0].value() == 0
diff --git a/tests/unittest/compiler/test_remove_no_op_concats.py b/tests/unittest/compiler/test_remove_no_op_concats.py
index dc213b292..44e5f45b7 100644
--- a/tests/unittest/compiler/test_remove_no_op_concats.py
+++ b/tests/unittest/compiler/test_remove_no_op_concats.py
@@ -106,7 +106,7 @@ def test_remove_no_op_concats_exceptions(self):
         # AIT expects all concat inputs to have the same rank.
         with self.assertRaises(RuntimeError):
             self._test_remove_no_op_concats_impl(
-                input_shapes=[[2, 4], [0]],
+                input_shapes=[[2, 4], [1]],
                 should_keep_concat=False,
                 test_name="test_remove_no_op_concats_same_rank",
             )
diff --git a/tests/unittest/ops/test_concatenate.py b/tests/unittest/ops/test_concatenate.py
index 15fa91270..ec2828697 100644
--- a/tests/unittest/ops/test_concatenate.py
+++ b/tests/unittest/ops/test_concatenate.py
@@ -191,7 +191,12 @@ def _run_masked_concatenate(
         y = torch.empty_like(y_pt)
         module.run_with_tensors(inputs, [y])
 
-        split_sections = [shape[dim] for shape in input_shapes]
+        split_sections = []
+        for shape in input_shapes:
+            if len(shape) == 1 and shape[0] == 0:
+                split_sections.append(0)
+            else:
+                split_sections.append(shape[dim])
 
         ys_pt = torch.split(y_pt, split_sections, dim=dim)
         ys = torch.split(y, split_sections, dim=dim)
@@ -252,6 +257,21 @@ def test_batch_cat(self):
         )
 
     def test_cat(self):
+        self._run_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([0], [2, 2, 2]),
+            dim=2,
+        )
+        self._run_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([2, 0, 4], [2, 3, 4], [0], [2, 2, 4]),
+            dim=1,
+        )
+        self._run_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([0], [2, 3, 4], [0], [2, 2, 4]),
+            dim=1,
+        )
         self._run_concatenate(
             concatenate_op=ops.concatenate(), input_shapes=([1], [1]), dim=0
         )
@@ -357,6 +377,13 @@ def test_cat(self):
         )
 
     def test_masked_cat(self):
+        self._run_masked_concatenate(
+            concatenate_op=ops.concatenate(),
+            input_shapes=([0], [2, 2, 2]),
+            input_masks=[False, True],
+            dim=2,
+            optimize_args=True,
+        )
         self._run_masked_concatenate(
             concatenate_op=ops.concatenate(),
             input_shapes=([2, 2, 2], [2, 2, 2], [2, 2, 2]),