Skip to content

Commit

Permalink
Add converter for acc_ops.interpolate
Browse files Browse the repository at this point in the history
Differential Revision: D47312425

fbshipit-source-id: faa1884928a5806552b6c5903036b40d5978925f
  • Loading branch information
henryhu6 authored and facebook-github-bot committed Jul 8, 2023
1 parent 34de7fb commit 6865d87
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
28 changes: 28 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
vector_norm,
)

from aitemplate.frontend.nn import Upsampling2d

from fx2ait.acc_tracer import acc_ops, ait_acc_ops
from torch.fx.node import Argument, Target

Expand Down Expand Up @@ -1124,6 +1126,32 @@ def _is_int_list(iterable):
return expand()(input_val, shape)


@ait_converter(acc_ops.interpolate)
def ait_acc_ops_interpolate(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
input_val = kwargs["input"]

if not isinstance(input_val, AITTensor):
raise ValueError(f"Non-tensor inputs for {name}: {input_val}")

scale_factor = kwargs["scale_factor"]
if not scale_factor:
raise ValueError("scale_factor cannot be empty")

mode = kwargs["mode"]
if not mode:
raise ValueError("mode cannot be empty")

op = Upsampling2d(scale_factor=scale_factor, mode=mode)

res = op(ait_nchw2nhwc(input_val))
return ait_nhwc2nchw(res)


@ait_converter(acc_ops.batch_norm)
def acc_ops_batch_norm(
target: Target,
Expand Down
42 changes: 42 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_upsampling2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from fx2ait.acc_tracer import acc_ops
from fx2ait.tools.common_fx2ait import AITTestCase
from parameterized import param, parameterized


class TestInterpolateConverter(AITTestCase):
@parameterized.expand(
[
param(scale_factor=1, mode="nearest"),
param(scale_factor=2, mode="nearest"),
param(scale_factor=2, mode="bilinear"),
]
)
def test_interpolate(self, scale_factor, mode):
class TestModule(torch.nn.Module):
def forward(self, y: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(
y, scale_factor=scale_factor, mode=mode
)
return x

model = TestModule().cuda().half()
inputs = [
torch.randn([2, 8, 16, 16]).half().cuda(),
]

self.run_test(model, inputs, expected_ops={acc_ops.interpolate})

0 comments on commit 6865d87

Please sign in to comment.