diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index b5486365a..a38f53c17 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -382,6 +382,20 @@ def acc_ops_abs( return elementwise(FuncEnum.ABS)(input_val) +@ait_converter(acc_ops.exp) +def acc_ops_exp( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = kwargs["input"] + if not isinstance(input_val, AITTensor): + raise RuntimeError(f"Unexpected input for {name}: {input_val}") + + return elementwise(FuncEnum.EXP)(input_val) + + @ait_converter(acc_ops.log) def acc_ops_log( target: Target, diff --git a/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py b/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py index 23be04bc5..672fecd30 100644 --- a/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py +++ b/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py @@ -37,6 +37,7 @@ (torch.sqrt, acc_ops.sqrt), (torch.clone, acc_ops.clone), (torch.neg, acc_ops.neg), + (torch.exp, acc_ops.exp), ] TestEnvToPrecision: Dict[TestEnv, Set[LowerPrecision]] = {