Skip to content

Commit

Permalink
Split slice_scatter into multiple ones if it has too many inputs (fac…
Browse files Browse the repository at this point in the history
…ebookincubator#801)

Summary:
Pull Request resolved: facebookincubator#801

Split a slice_scatter op into multiple smaller ones if it has too many inputs. The process is very similar to the one used for splitting slice_reshape_scatter.

Added a TensorAccessor attribute to the slice_scatter op (only its offset field is used) so that the split logic works.

Differential Revision: D46962881

fbshipit-source-id: 6e642d3e0ad51d7f8b97c743d6d92925a4392545
  • Loading branch information
y-sq authored and facebook-github-bot committed Jun 29, 2023
1 parent eb4c375 commit ac6c3ad
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 14 deletions.
8 changes: 7 additions & 1 deletion python/aitemplate/backend/cuda/tensor/slice_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,14 @@ def gen_function(func_attrs):
"""
# TODO: consider to profile elems_per_thread
elems_per_thread = 8 if len(func_attrs["inputs"]) == 1 else 256
output_accessor = func_attrs["output_accessors"][0]
output_offset = output_accessor.offset
return slice_common.gen_function(
func_attrs, backend_spec=CUDASpec(), elems_per_thread=elems_per_thread
func_attrs,
backend_spec=CUDASpec(),
elems_per_thread=elems_per_thread,
output_offset=output_offset,
update_output_shape=False,
)


Expand Down
25 changes: 17 additions & 8 deletions python/aitemplate/compiler/ops/tensor/slice_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator
from aitemplate.compiler.stable_set import StableSet
from aitemplate.compiler.tensor_accessor import TensorAccessor

# pylint: disable=C0103,W0221

Expand Down Expand Up @@ -62,6 +63,10 @@ def _update_inputs_outputs(self, cat_op):
input_tensor._attrs["dst_ops"].add(self)
self._attrs["inputs"].append(input_tensor)

# The original output of this slice_scatter op is the output of the cat_op.
# We set the TensorAccessor, but will only use its offset field in the backend.
self._attrs["output_accessors"] = [TensorAccessor(cat_op._attrs["outputs"][0])]

self._attrs["outputs"] = cat_op._attrs["outputs"]
for y in self._attrs["outputs"]:
y._attrs["src_ops"] = StableSet({self})
Expand All @@ -74,23 +79,27 @@ def _update_inputs_outputs(self, cat_op):
x._attrs["src_ops"] = StableSet()
x._attrs["dst_ops"] = StableSet()

def __init__(self, cat_op: Operator) -> None:
def __init__(self, scatter_dim: int) -> None:
super().__init__()
assert slice_scatter.is_valid(cat_op)

self._attrs["op"] = "slice_scatter"
self._attrs["has_profiler"] = False
self._attrs["scatter_dim"] = cat_op._attrs["concat_dim"]
self._attrs["scatter_dim"] = scatter_dim

@staticmethod
def make_op(cat_op: Operator) -> Operator:
assert slice_scatter.is_valid(cat_op)
scatter_dim = cat_op._attrs["concat_dim"]
new_op = slice_scatter(scatter_dim)
slice_ops = []
for x in cat_op._attrs["inputs"]:
src_ops = x.src_ops()
assert len(src_ops) == 1
slice_op = list(src_ops)[0]
slice_ops.append(slice_op)
self._attrs["slice_ops"] = slice_ops

self._update_inputs_outputs(cat_op)
self._set_depth()
new_op._attrs["slice_ops"] = slice_ops
new_op._update_inputs_outputs(cat_op)
new_op._set_depth()
return new_op

def __call__(self):
raise RuntimeError("op {} cannot be called directly".format(self._attrs["op"]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def split_large_slice_scatter_ops(sorted_graph: List[Tensor], _: str) -> List[Te
"""
sorted_ops = graph_utils.get_sorted_ops(sorted_graph)
for op in sorted_ops:
# TODO: enable slice_scatter later
if not op._attrs["op"].startswith("slice_reshape_scatter"):
if not op._attrs["op"] in ["slice_reshape_scatter", "slice_scatter"]:
continue
slice_scatter_op = op
# We create InputMeta for inputs that need to copy data.
Expand Down Expand Up @@ -96,10 +95,14 @@ def split_large_slice_scatter_ops(sorted_graph: List[Tensor], _: str) -> List[Te
has_profiler = slice_scatter_op._attrs["has_profiler"]
local_output_offset = 0
orig_name = slice_scatter_op._attrs["name"]
element_func = slice_scatter_op._attrs["element_func"]
slice_ops = slice_scatter_op._attrs["slice_ops"]
for split_idx, new_inputs_size in enumerate(split_sizes):
new_slice_scatter_op = ops.slice_reshape_scatter(scatter_dim, element_func)
if op._attrs["op"] == "slice_scatter":
new_slice_scatter_op = ops.slice_scatter(scatter_dim)
elif op._attrs["op"] == "slice_reshape_scatter":
new_slice_scatter_op = ops.slice_reshape_scatter(
scatter_dim, slice_scatter_op._attrs["element_func"]
)
new_name = f"{orig_name}_split_{split_idx}"
new_slice_scatter_op._attrs["name"] = new_name
new_slice_scatter_op._attrs["original_name"] = new_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _fuse_slices_concat(sorted_graph: List[Tensor]) -> List[Tensor]:
continue
concat_op = src_op
if slice_scatter.is_valid(concat_op):
slice_scatter(concat_op)
slice_scatter.make_op(concat_op)

return transform_utils.sanitize_sorted_graph(sorted_graph)

Expand Down
113 changes: 113 additions & 0 deletions tests/unittest/compiler/test_split_large_slice_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
get_random_torch_tensor,
get_torch_empty_tensor,
)


class SliceScatterLargeInputsTestCase(unittest.TestCase):
    """Tests for slice + concatenate graphs that have a large number of inputs.

    Each case builds many ``dynamic_slice`` ops feeding a single
    ``concatenate``, compiles the graph, checks that the output tensor ends up
    being produced by several ``slice_scatter`` ops, and compares the result
    against the PyTorch reference.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Counter used to give every compiled module a unique test/dll name.
        self.test_count = 1

    @classmethod
    def setUpClass(cls) -> None:
        # Fixed seed so that the random inputs are reproducible.
        torch.manual_seed(0)

    def _test_slice_scatter(
        self, input_shape, start_indices, end_indices, concat_dim, dtype
    ):
        """Compile and verify one slice + concatenate graph.

        Args:
            input_shape: shape of every input tensor.
            start_indices: per-dimension slice start, one entry per dimension.
            end_indices: per-dimension slice end, one entry per dimension.
            concat_dim: dimension along which the slices are concatenated.
            dtype: dtype name used for the inputs and the output.
        """
        num_slices = 140
        slice_outputs = [
            ops.dynamic_slice()(
                Tensor(
                    shape=input_shape, dtype=dtype, name=f"input{idx}", is_input=True
                ),
                start_indices=start_indices,
                end_indices=end_indices,
            )
            for idx in range(num_slices)
        ]

        Y = ops.concatenate()(slice_outputs, concat_dim)

        Y._attrs["name"] = "y"
        Y._attrs["is_output"] = True

        target = detect_target()
        dll_name = f"test_{self.test_count}.so"
        test_name = f"slice_scatter_large_inputs_{self.test_count}"

        module = compile_model(Y, target, "./tmp", test_name, dll_name=dll_name)

        # After compilation the output must be written by 5 slice_scatter ops.
        # NOTE(review): 5 presumably follows from the split pass's per-op input
        # limit applied to 140 inputs -- confirm if that limit changes.
        Y_src_ops = list(Y._attrs["src_ops"])
        self.assertEqual(len(Y_src_ops), 5)
        self.assertTrue(all(op._attrs["op"] == "slice_scatter" for op in Y_src_ops))

        input_pt = [
            get_random_torch_tensor(input_shape, dtype) for _ in range(num_slices)
        ]
        # Multidimensional indexing must use a tuple: indexing with a plain
        # list of slices relies on deprecated non-tuple sequence indexing.
        slice_indices = tuple(
            slice(start, end) for start, end in zip(start_indices, end_indices)
        )
        slice_outputs_pt = [input_i[slice_indices] for input_i in input_pt]
        y_pt = torch.cat(slice_outputs_pt, concat_dim)

        inputs = {f"input{idx}": input_pt[idx] for idx in range(num_slices)}
        y = get_torch_empty_tensor(y_pt.size(), dtype)
        module.run_with_tensors(inputs, [y])
        self.assertTrue(torch.allclose(y_pt, y, atol=1e-2, rtol=1e-2))

        self.test_count += 1

    def test_slice_scatter_float(self):
        """Run the check for every concat dim in float, plus one float16 case."""
        cases = (
            (0, "float"),
            (1, "float"),
            (2, "float"),
            (1, "float16"),
        )
        for concat_dim, dtype in cases:
            self._test_slice_scatter(
                input_shape=[3, 7, 10],
                start_indices=[0, 0, 0],
                end_indices=[2, 1, 4],
                concat_dim=concat_dim,
                dtype=dtype,
            )


# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit ac6c3ad

Please sign in to comment.