From 4667de0e763a542e3fe48a09c43bd572f121bc86 Mon Sep 17 00:00:00 2001 From: Henry Hu Date: Mon, 10 Jul 2023 21:30:57 -0700 Subject: [PATCH] Add acc_ops.exp support (#820) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/820 Reviewed By: frank-wei Differential Revision: D47348835 fbshipit-source-id: 7e2a9645c174d8e7ae7f65fc2d450e1ae58b3c0f --- fx2ait/fx2ait/converters/ait_converters.py | 14 ++++++++++++++ .../fx2ait/test/converters/test_ait_unary_ops.py | 1 + 2 files changed, 15 insertions(+) diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index b5486365a..a38f53c17 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -382,6 +382,20 @@ def acc_ops_abs( return elementwise(FuncEnum.ABS)(input_val) +@ait_converter(acc_ops.exp) +def acc_ops_exp( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = kwargs["input"] + if not isinstance(input_val, AITTensor): + raise RuntimeError(f"Unexpected input for {name}: {input_val}") + + return elementwise(FuncEnum.EXP)(input_val) + + @ait_converter(acc_ops.log) def acc_ops_log( target: Target, diff --git a/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py b/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py index 23be04bc5..672fecd30 100644 --- a/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py +++ b/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py @@ -37,6 +37,7 @@ (torch.sqrt, acc_ops.sqrt), (torch.clone, acc_ops.clone), (torch.neg, acc_ops.neg), + (torch.exp, acc_ops.exp), ] TestEnvToPrecision: Dict[TestEnv, Set[LowerPrecision]] = {