Skip to content

Commit

Permalink
Add index_select AIT converter (facebookincubator#947)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#947

`index_select` op has been [available](https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/compiler/ops/tensor/index_select.py) in AIT for a while, but we never had an fx2ait converter for it. This diff adds one.

Reviewed By: qxy11

Differential Revision: D50119460

fbshipit-source-id: d1a09c2492504b849edf37af156938effd54666e
  • Loading branch information
aakhundov authored and facebook-github-bot committed Oct 11, 2023
1 parent a530636 commit 9452d5c
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
22 changes: 22 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
getitem,
group_norm,
identity,
index_select,
IntImm,
IntVar,
IntVarTensor,
Expand Down Expand Up @@ -1829,3 +1830,24 @@ def acc_ops_masked_select(
mask = kwargs["mask"]

return masked_select()(input_val, mask)


@ait_converter(acc_ops.index_select)
def acc_ops_index_select(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
input_val = args[0] if len(args) >= 1 else kwargs["input"]
dim = args[1] if len(args) >= 2 else kwargs["dim"]
index = args[2] if len(args) >= 3 else kwargs["index"]

if not isinstance(input_val, AITTensor):
raise RuntimeError(f"Non-tensor 'input' for {name}: {input_val}")
if not isinstance(dim, int):
raise RuntimeError(f"Non-int 'dim' for {name}: {dim}")
if not isinstance(index, AITTensor):
raise RuntimeError(f"Non-tensor 'index' for {name}: {index}")

return index_select(dim=dim)(x=input_val, dim_idxs=index)
56 changes: 56 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_index_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from fx2ait.acc_tracer import acc_ops
from fx2ait.tools.common_fx2ait import AITTestCase
from parameterized import param, parameterized


class TestIndexSelectConverter(AITTestCase):
@parameterized.expand(
[
param(
"first_dim",
torch.randn(5, 10, 20),
0,
torch.randint(low=0, high=5, size=(3,)),
),
param(
"mid_dim",
torch.randn(5, 10, 20),
1,
torch.randint(low=0, high=10, size=(20,)),
),
param(
"last_dim",
torch.randn(5, 10, 20),
2,
torch.randint(low=0, high=20, size=(10,)),
),
]
)
def test_index_select(self, _, inp, dim, index):
class TestModule(torch.nn.Module):
def forward(
self,
inp: torch.Tensor,
index: torch.Tensor,
) -> torch.Tensor:
return torch.index_select(inp, dim, index=index)

model = TestModule().eval().half().cuda()
inputs = [inp.cuda(), index.cuda()]

self.run_test(model, inputs, expected_ops={acc_ops.index_select})
1 change: 1 addition & 0 deletions python/aitemplate/compiler/public/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from aitemplate.compiler.ops.pool.avg_pool2d import avg_pool2d
from aitemplate.compiler.ops.pool.max_pool2d import max_pool2d
from aitemplate.compiler.ops.softmax.softmax import softmax
from aitemplate.compiler.ops.tensor.index_select import index_select
from aitemplate.compiler.ops.tensor.masked_select import masked_select
from aitemplate.compiler.ops.tensor.size import size
from aitemplate.compiler.ops.tensor.topk import topk
Expand Down

0 comments on commit 9452d5c

Please sign in to comment.