From ac6c3ad76c482f69f75d047109f1a472b943afb4 Mon Sep 17 00:00:00 2001
From: Shuqi Yang
Date: Thu, 29 Jun 2023 15:54:55 -0700
Subject: [PATCH] Split slice_scatter into multiple ones if it has too many
 inputs (#801)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/801

Split slice_scatter into multiple ones if it has too many inputs. The
process is very similar to the one for splitting slice_reshape_scatter.

Added a TensorAccessor attribute to the slice_scatter op (only its offset
field will be used) to make the split logic work.

Differential Revision: D46962881

fbshipit-source-id: 6e642d3e0ad51d7f8b97c743d6d92925a4392545
---
 .../backend/cuda/tensor/slice_scatter.py      |   8 +-
 .../compiler/ops/tensor/slice_scatter.py      |  25 ++--
 .../split_large_slice_scatter_ops.py          |  11 +-
 .../transform/transform_strided_ops.py        |   2 +-
 .../test_split_large_slice_scatter.py         | 113 ++++++++++++++++++
 5 files changed, 145 insertions(+), 14 deletions(-)
 create mode 100644 tests/unittest/compiler/test_split_large_slice_scatter.py

diff --git a/python/aitemplate/backend/cuda/tensor/slice_scatter.py b/python/aitemplate/backend/cuda/tensor/slice_scatter.py
index 254193f59..8cf968415 100644
--- a/python/aitemplate/backend/cuda/tensor/slice_scatter.py
+++ b/python/aitemplate/backend/cuda/tensor/slice_scatter.py
@@ -53,8 +53,14 @@ def gen_function(func_attrs):
     """
     # TODO: consider to profile elems_per_thread
     elems_per_thread = 8 if len(func_attrs["inputs"]) == 1 else 256
+    output_accessor = func_attrs["output_accessors"][0]
+    output_offset = output_accessor.offset
     return slice_common.gen_function(
-        func_attrs, backend_spec=CUDASpec(), elems_per_thread=elems_per_thread
+        func_attrs,
+        backend_spec=CUDASpec(),
+        elems_per_thread=elems_per_thread,
+        output_offset=output_offset,
+        update_output_shape=False,
     )


diff --git a/python/aitemplate/compiler/ops/tensor/slice_scatter.py b/python/aitemplate/compiler/ops/tensor/slice_scatter.py
index 1c5cf5393..8c2d4f008 100644
--- a/python/aitemplate/compiler/ops/tensor/slice_scatter.py
+++ b/python/aitemplate/compiler/ops/tensor/slice_scatter.py
@@ -20,6 +20,7 @@
 from aitemplate.backend import registry
 from aitemplate.compiler.base import Operator
 from aitemplate.compiler.stable_set import StableSet
+from aitemplate.compiler.tensor_accessor import TensorAccessor

 # pylint: disable=C0103,W0221
@@ -62,6 +63,10 @@ def _update_inputs_outputs(self, cat_op):
             input_tensor._attrs["dst_ops"].add(self)
             self._attrs["inputs"].append(input_tensor)

+        # The original output of this slice_scatter op is the output of the cat_op.
+        # We set the TensorAccessor, but only its offset field will be used in the backend.
+        self._attrs["output_accessors"] = [TensorAccessor(cat_op._attrs["outputs"][0])]
+
         self._attrs["outputs"] = cat_op._attrs["outputs"]
         for y in self._attrs["outputs"]:
             y._attrs["src_ops"] = StableSet({self})
@@ -74,23 +79,27 @@ def _update_inputs_outputs(self, cat_op):
             x._attrs["src_ops"] = StableSet()
             x._attrs["dst_ops"] = StableSet()

-    def __init__(self, cat_op: Operator) -> None:
+    def __init__(self, scatter_dim: int) -> None:
         super().__init__()
-        assert slice_scatter.is_valid(cat_op)
-
         self._attrs["op"] = "slice_scatter"
         self._attrs["has_profiler"] = False
-        self._attrs["scatter_dim"] = cat_op._attrs["concat_dim"]
+        self._attrs["scatter_dim"] = scatter_dim
+
+    @staticmethod
+    def make_op(cat_op: Operator) -> Operator:
+        assert slice_scatter.is_valid(cat_op)
+        scatter_dim = cat_op._attrs["concat_dim"]
+        new_op = slice_scatter(scatter_dim)
         slice_ops = []
         for x in cat_op._attrs["inputs"]:
             src_ops = x.src_ops()
             assert len(src_ops) == 1
             slice_op = list(src_ops)[0]
             slice_ops.append(slice_op)
-        self._attrs["slice_ops"] = slice_ops
-
-        self._update_inputs_outputs(cat_op)
-        self._set_depth()
+        new_op._attrs["slice_ops"] = slice_ops
+        new_op._update_inputs_outputs(cat_op)
+        new_op._set_depth()
+        return new_op

     def __call__(self):
         raise RuntimeError("op {} cannot be called directly".format(self._attrs["op"]))
diff --git a/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py b/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py
index 911f57656..864d14ec5 100644
--- a/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py
+++ b/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py
@@ -63,8 +63,7 @@ def split_large_slice_scatter_ops(sorted_graph: List[Tensor], _: str) -> List[Te
     """
     sorted_ops = graph_utils.get_sorted_ops(sorted_graph)
     for op in sorted_ops:
-        # TODO: enable slice_scatter later
-        if not op._attrs["op"].startswith("slice_reshape_scatter"):
+        if op._attrs["op"] not in ["slice_reshape_scatter", "slice_scatter"]:
             continue
         slice_scatter_op = op
         # We create InputMeta for inputs that need to copy data.
@@ -96,10 +95,14 @@ def split_large_slice_scatter_ops(sorted_graph: List[Tensor], _: str) -> List[Te
         has_profiler = slice_scatter_op._attrs["has_profiler"]
         local_output_offset = 0
         orig_name = slice_scatter_op._attrs["name"]
-        element_func = slice_scatter_op._attrs["element_func"]
         slice_ops = slice_scatter_op._attrs["slice_ops"]
         for split_idx, new_inputs_size in enumerate(split_sizes):
-            new_slice_scatter_op = ops.slice_reshape_scatter(scatter_dim, element_func)
+            if op._attrs["op"] == "slice_scatter":
+                new_slice_scatter_op = ops.slice_scatter(scatter_dim)
+            elif op._attrs["op"] == "slice_reshape_scatter":
+                new_slice_scatter_op = ops.slice_reshape_scatter(
+                    scatter_dim, slice_scatter_op._attrs["element_func"]
+                )
             new_name = f"{orig_name}_split_{split_idx}"
             new_slice_scatter_op._attrs["name"] = new_name
             new_slice_scatter_op._attrs["original_name"] = new_name
diff --git a/python/aitemplate/compiler/transform/transform_strided_ops.py b/python/aitemplate/compiler/transform/transform_strided_ops.py
index 5174ba389..2de95d9be 100644
--- a/python/aitemplate/compiler/transform/transform_strided_ops.py
+++ b/python/aitemplate/compiler/transform/transform_strided_ops.py
@@ -51,7 +51,7 @@ def _fuse_slices_concat(sorted_graph: List[Tensor]) -> List[Tensor]:
             continue
         concat_op = src_op
         if slice_scatter.is_valid(concat_op):
-            slice_scatter(concat_op)
+            slice_scatter.make_op(concat_op)

     return transform_utils.sanitize_sorted_graph(sorted_graph)

diff --git a/tests/unittest/compiler/test_split_large_slice_scatter.py b/tests/unittest/compiler/test_split_large_slice_scatter.py
new file mode 100644
index 000000000..e47152dbe
--- /dev/null
+++ b/tests/unittest/compiler/test_split_large_slice_scatter.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import torch
+
+from aitemplate.compiler import compile_model, ops
+from aitemplate.frontend import Tensor
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import (
+    get_random_torch_tensor,
+    get_torch_empty_tensor,
+)
+
+
+class SliceScatterLargeInputsTestCase(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(SliceScatterLargeInputsTestCase, self).__init__(*args, **kwargs)
+        self.test_count = 1
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        torch.manual_seed(0)
+
+    def _test_slice_scatter(
+        self, input_shape, start_indices, end_indices, concat_dim, dtype
+    ):
+        num_slices = 140
+        slice_outputs = [
+            ops.dynamic_slice()(
+                Tensor(
+                    shape=input_shape, dtype=dtype, name=f"input{idx}", is_input=True
+                ),
+                start_indices=start_indices,
+                end_indices=end_indices,
+            )
+            for idx in range(num_slices)
+        ]
+
+        Y = ops.concatenate()(slice_outputs, concat_dim)
+
+        Y._attrs["name"] = "y"
+        Y._attrs["is_output"] = True
+
+        target = detect_target()
+        dll_name = f"test_{self.test_count}.so"
+        test_name = f"slice_scatter_large_inputs_{self.test_count}"
+
+        module = compile_model(Y, target, "./tmp", test_name, dll_name=dll_name)
+
+        Y_src_ops = list(Y._attrs["src_ops"])
+        self.assertEqual(len(Y_src_ops), 5)
+        self.assertTrue(all(op._attrs["op"] == "slice_scatter" for op in Y_src_ops))
+
+        input_pt = [
+            get_random_torch_tensor(input_shape, dtype) for _ in range(num_slices)
+        ]
+        slice_indices = [slice(i, j) for i, j in zip(start_indices, end_indices)]
+        slice_outputs_pt = [input_i[tuple(slice_indices)] for input_i in input_pt]
+        y_pt = torch.cat(slice_outputs_pt, concat_dim)
+
+        inputs = {f"input{idx}": input_pt[idx] for idx in range(num_slices)}
+        y = get_torch_empty_tensor(y_pt.size(), dtype)
+        module.run_with_tensors(inputs, [y])
+        self.assertTrue(torch.allclose(y_pt, y, atol=1e-2, rtol=1e-2))
+
+        self.test_count += 1
+
+    def test_slice_scatter_float(self):
+        self._test_slice_scatter(
+            input_shape=[3, 7, 10],
+            start_indices=[0, 0, 0],
+            end_indices=[2, 1, 4],
+            concat_dim=0,
+            dtype="float",
+        )
+        self._test_slice_scatter(
+            input_shape=[3, 7, 10],
+            start_indices=[0, 0, 0],
+            end_indices=[2, 1, 4],
+            concat_dim=1,
+            dtype="float",
+        )
+        self._test_slice_scatter(
+            input_shape=[3, 7, 10],
+            start_indices=[0, 0, 0],
+            end_indices=[2, 1, 4],
+            concat_dim=2,
+            dtype="float",
+        )
+        self._test_slice_scatter(
+            input_shape=[3, 7, 10],
+            start_indices=[0, 0, 0],
+            end_indices=[2, 1, 4],
+            concat_dim=1,
+            dtype="float16",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
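
For readers who want the splitting strategy outside the patch context: the pass
chunks a too-large op's inputs into groups below a size limit, creates one
scatter op per group, and points each new op at the shared output through a
write offset (the only field of the new TensorAccessor the backend reads).
The sketch below restates that idea in plain PyTorch. It is a minimal sketch,
not AITemplate code: MAX_INPUTS and split_slice_scatter are hypothetical names,
and the real pass computes its own split sizes.

    import torch

    MAX_INPUTS = 32  # assumed per-op input limit; the real pass derives split sizes itself


    def split_slice_scatter(inputs, dim, output):
        """Write `inputs` into `output` along `dim`, in chunks of at most MAX_INPUTS."""
        offset = 0  # running write offset along `dim`, the analogue of local_output_offset
        for start in range(0, len(inputs), MAX_INPUTS):
            chunk = inputs[start : start + MAX_INPUTS]
            # Each chunk plays the role of one split-off slice_scatter op: it
            # concatenates its own inputs and writes them at the shared offset.
            part = torch.cat(chunk, dim)
            output.narrow(dim, offset, part.size(dim)).copy_(part)
            offset += part.size(dim)
        return output


    # 140 inputs of extent 2 along dim 0, as in the unit test above: with an
    # assumed limit of 32 this yields ceil(140 / 32) = 5 chunks, matching the
    # 5 slice_scatter ops the test asserts after compilation.
    xs = [torch.randn(2, 1, 4) for _ in range(140)]
    out = torch.empty(280, 1, 4)
    assert torch.equal(split_slice_scatter(xs, 0, out), torch.cat(xs, 0))

Writing each chunk at an offset into the single output buffer, rather than
concatenating partial results afterwards, avoids an extra copy; that is why
the backend change above only needs the accessor's offset and sets
update_output_shape=False.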