diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 1839adcbf..a21fed542 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -42,6 +42,7 @@ getitem, group_norm, identity, + index_select, IntImm, IntVar, IntVarTensor, @@ -1829,3 +1830,24 @@ def acc_ops_masked_select( mask = kwargs["mask"] return masked_select()(input_val, mask) + + +@ait_converter(acc_ops.index_select) +def acc_ops_index_select( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = args[0] if len(args) >= 1 else kwargs["input"] + dim = args[1] if len(args) >= 2 else kwargs["dim"] + index = args[2] if len(args) >= 3 else kwargs["index"] + + if not isinstance(input_val, AITTensor): + raise RuntimeError(f"Non-tensor 'input' for {name}: {input_val}") + if not isinstance(dim, int): + raise RuntimeError(f"Non-int 'dim' for {name}: {dim}") + if not isinstance(index, AITTensor): + raise RuntimeError(f"Non-tensor 'index' for {name}: {index}") + + return index_select(dim=dim)(x=input_val, dim_idxs=index) diff --git a/fx2ait/fx2ait/test/converters/test_ait_index_select.py b/fx2ait/fx2ait/test/converters/test_ait_index_select.py new file mode 100644 index 000000000..d3a85c15a --- /dev/null +++ b/fx2ait/fx2ait/test/converters/test_ait_index_select.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +from fx2ait.acc_tracer import acc_ops +from fx2ait.tools.common_fx2ait import AITTestCase +from parameterized import param, parameterized + + +class TestIndexSelectConverter(AITTestCase): + @parameterized.expand( + [ + param( + "first_dim", + torch.randn(5, 10, 20), + 0, + torch.randint(low=0, high=5, size=(3,)), + ), + param( + "mid_dim", + torch.randn(5, 10, 20), + 1, + torch.randint(low=0, high=10, size=(20,)), + ), + param( + "last_dim", + torch.randn(5, 10, 20), + 2, + torch.randint(low=0, high=20, size=(10,)), + ), + ] + ) + def test_index_select(self, _, inp, dim, index): + class TestModule(torch.nn.Module): + def forward( + self, + inp: torch.Tensor, + index: torch.Tensor, + ) -> torch.Tensor: + return torch.index_select(inp, dim, index=index) + + model = TestModule().eval().half().cuda() + inputs = [inp.cuda(), index.cuda()] + + self.run_test(model, inputs, expected_ops={acc_ops.index_select}) diff --git a/python/aitemplate/compiler/public/__init__.py b/python/aitemplate/compiler/public/__init__.py index 9d9f9bc33..17e031ee2 100644 --- a/python/aitemplate/compiler/public/__init__.py +++ b/python/aitemplate/compiler/public/__init__.py @@ -73,6 +73,7 @@ from aitemplate.compiler.ops.pool.avg_pool2d import avg_pool2d from aitemplate.compiler.ops.pool.max_pool2d import max_pool2d from aitemplate.compiler.ops.softmax.softmax import softmax +from aitemplate.compiler.ops.tensor.index_select import index_select from aitemplate.compiler.ops.tensor.masked_select import masked_select from aitemplate.compiler.ops.tensor.size import size from aitemplate.compiler.ops.tensor.topk import topk