From fa04b43642ab849778252737ad6637979abda022 Mon Sep 17 00:00:00 2001 From: Quinn Zhu Date: Fri, 1 Mar 2024 01:59:40 -0800 Subject: [PATCH] Add log1p elementwise op Summary: `log1p(x)` is more precise than `log(1+x)` when `x` is close to 0. We utilize cuda `log1pf` implementation for fp32. For other precision types, input is first converted to float, then `log1pf` is computed, finally output is converted back to original precision. CUDA log1pf function for float and double: https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html Reviewed By: frank-wei Differential Revision: D54176180 fbshipit-source-id: 42fc23ed7b461e80e4aee07f8bea4f5c003016ed --- fx2ait/fx2ait/acc_tracer/acc_ops.py | 6 +++++ fx2ait/fx2ait/converters/ait_converters.py | 14 +++++++++++ python/aitemplate/backend/backend_spec.py | 7 ++++++ .../backend/cuda/elementwise/custom_math.cuh | 24 +++++++++++++++++++ .../compiler/ops/common/epilogue.py | 1 + python/aitemplate/compiler/ops/common/math.py | 4 ++++ 6 files changed, 56 insertions(+) diff --git a/fx2ait/fx2ait/acc_tracer/acc_ops.py b/fx2ait/fx2ait/acc_tracer/acc_ops.py index f3a787cb0..d1c159f12 100644 --- a/fx2ait/fx2ait/acc_tracer/acc_ops.py +++ b/fx2ait/fx2ait/acc_tracer/acc_ops.py @@ -1782,6 +1782,12 @@ def log(*, input): return torch.log(input=input) +@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) +@register_acc_op +def log1p(*, input): + return torch.log1p(input=input) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sqrt)) @register_acc_op_mapping(op_and_target=("call_method", "sqrt")) diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 129d778d6..40794c9c6 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -479,6 +479,20 @@ def acc_ops_log( return elementwise(FuncEnum.LOGE)(input_val) +@ait_converter(acc_ops.log1p) +def acc_ops_log1p( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = kwargs["input"] + if not isinstance(input_val, AITTensor): + raise RuntimeError(f"Unexpected input for {name}: {input_val}") + + return elementwise(FuncEnum.LOG1P)(input_val) + + @ait_converter(acc_ops.var) def acc_ops_var( target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str diff --git a/python/aitemplate/backend/backend_spec.py b/python/aitemplate/backend/backend_spec.py index 47d8618e1..e101a77d3 100644 --- a/python/aitemplate/backend/backend_spec.py +++ b/python/aitemplate/backend/backend_spec.py @@ -183,6 +183,13 @@ class GPUBackendSpec(BackendSpec): "bfloat16": "hlog", "float": "logf", }, + FuncEnum.LOG1P: { + "half2": "h2log1p", + "bfloat16_2": "h2log1p", + "half": "hlog1p", + "bfloat16": "hlog1p", + "float": "log1pf", + }, FuncEnum.EXP: { "half2": "h2exp", "bfloat16_2": "h2exp", diff --git a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh index 57c8d3a4f..eebe344f0 100644 --- a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh +++ b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh @@ -1075,4 +1075,28 @@ __device__ bfloat16_2 h2celu(const bfloat16_2 a, const bfloat16_2 alpha) { #endif } +__device__ half hlog1p(const half a) { + return half(log1pf(float(a))); +} + +__device__ bfloat16 hlog1p(const bfloat16 a) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16(log1pf(float(a))); +#else + NOT_IMPLEMENTED(); +#endif +} + +__device__ half2 h2log1p(const half2 a) { + return half2(log1pf(float(a.x)), log1pf(float(a.y))); +} + +__device__ bfloat16_2 h2log1p(const bfloat16_2 a) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_2(log1pf(float(a.x)), log1pf(float(a.y))); +#else + NOT_IMPLEMENTED(); +#endif +} + #endif diff --git a/python/aitemplate/compiler/ops/common/epilogue.py b/python/aitemplate/compiler/ops/common/epilogue.py index 18e41a35c..f634b44b6 100644 --- a/python/aitemplate/compiler/ops/common/epilogue.py +++ b/python/aitemplate/compiler/ops/common/epilogue.py @@ -66,3 +66,4 @@ class FuncEnum(Enum): FLOOR_DIV = 28 CELU = 29 FLOOR = 30 + LOG1P = 31 diff --git a/python/aitemplate/compiler/ops/common/math.py b/python/aitemplate/compiler/ops/common/math.py index 0b9194247..b4af047e9 100644 --- a/python/aitemplate/compiler/ops/common/math.py +++ b/python/aitemplate/compiler/ops/common/math.py @@ -47,6 +47,10 @@ def log(tensor: Any) -> Tensor: return OP_REGISTRY.get("LOGE")(tensor) +def log1p(tensor: Any) -> Tensor: + return OP_REGISTRY.get("LOG1P")(tensor) + + def exp(tensor: Any) -> Tensor: return OP_REGISTRY.get("EXP")(tensor)