From 02733de118ae60fb32e9c8edf3766a6d0326d7b7 Mon Sep 17 00:00:00 2001 From: hhh Date: Fri, 7 Jul 2023 23:52:22 -0700 Subject: [PATCH] Add converter for acc_ops.interpolate Differential Revision: D47312425 fbshipit-source-id: ce9c8da99f27643cc11000786da18512984c77c7 --- fx2ait/fx2ait/converters/ait_converters.py | 28 +++++++++++++ .../test/converters/test_ait_upsampling2d.py | 42 +++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 fx2ait/fx2ait/test/converters/test_ait_upsampling2d.py diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index ab4418987..b5486365a 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -67,6 +67,8 @@ vector_norm, ) +from aitemplate.frontend.nn import Upsampling2d + from fx2ait.acc_tracer import acc_ops, ait_acc_ops from torch.fx.node import Argument, Target @@ -1124,6 +1126,32 @@ def _is_int_list(iterable): return expand()(input_val, shape) +@ait_converter(acc_ops.interpolate) +def ait_acc_ops_interpolate( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = kwargs["input"] + + if not isinstance(input_val, AITTensor): + raise ValueError(f"Non-tensor inputs for {name}: {input_val}") + + scale_factor = kwargs["scale_factor"] + if not scale_factor: + raise ValueError("scale_factor cannot be empty") + + mode = kwargs["mode"] + if not mode: + raise ValueError("mode cannot be empty") + + op = Upsampling2d(scale_factor=scale_factor, mode=mode) + + res = op(ait_nchw2nhwc(input_val)) + return ait_nhwc2nchw(res) + + @ait_converter(acc_ops.batch_norm) def acc_ops_batch_norm( target: Target, diff --git a/fx2ait/fx2ait/test/converters/test_ait_upsampling2d.py b/fx2ait/fx2ait/test/converters/test_ait_upsampling2d.py new file mode 100644 index 000000000..20b506fc0 --- /dev/null +++ b/fx2ait/fx2ait/test/converters/test_ait_upsampling2d.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +from fx2ait.acc_tracer import acc_ops +from fx2ait.tools.common_fx2ait import AITTestCase +from parameterized import param, parameterized + + +class TestInterpolateConverter(AITTestCase): + @parameterized.expand( + [ + param(scale_factor=1, mode="nearest"), + param(scale_factor=2, mode="nearest"), + param(scale_factor=2, mode="bilinear"), + ] + ) + def test_interpolate(self, scale_factor, mode): + class TestModule(torch.nn.Module): + def forward(self, y: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate( + y, scale_factor=scale_factor, mode=mode + ) + return x + + model = TestModule().cuda().half() + inputs = [ + torch.randn([2, 8, 16, 16]).half().cuda(), + ] + + self.run_test(model, inputs, expected_ops={acc_ops.interpolate})