From 38aebd90b10c2c9252f9f83984f7affe490ff8a8 Mon Sep 17 00:00:00 2001
From: Mu-Chu Lee
Date: Tue, 8 Aug 2023 20:19:46 -0700
Subject: [PATCH] Add USE_TANH_FOR_SIGMOID flag for CUDA codegen. (#881)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/881

We surface USE_TANH_FOR_SIGMOID at the AIT level, so that it applies not only to CUTLASS but also to AIT-generated computation.

Reviewed By: amateurcoffee, aakhundov

Differential Revision: D48097853

fbshipit-source-id: df9ed1696aefe1c3e01a85e074f60c4adf5a14f2
---
 docs/source/reference/env.rst                 |  2 +
 fx2ait/fx2ait/fx2ait.py                       |  4 ++
 fx2ait/fx2ait/lower/lower_settings.py         |  2 +
 .../backend/cuda/elementwise/custom_math.cuh  | 10 ++--
 python/aitemplate/backend/cuda/target_def.py  | 22 ++++++++-
 python/aitemplate/utils/environ.py            | 13 ++++-
 tests/unittest/ops/test_activation.py         | 48 +++++++++++++++++++
 7 files changed, 92 insertions(+), 9 deletions(-)

diff --git a/docs/source/reference/env.rst b/docs/source/reference/env.rst
index 392999b33..64f8a473a 100644
--- a/docs/source/reference/env.rst
+++ b/docs/source/reference/env.rst
@@ -70,3 +70,5 @@ Miscellaneous
 **AIT_PLOT_SHORTEN_TENSOR_NAMES**: If set to "1", shorten too long tensor names for a plot of a model graph, thus making a plot much easier to analyze visually. "0" by default.
 
 **AIT_USE_FAST_MATH**: If set to "0", no fast math option will be used for the device code generation. Default value is "1".
+
+**AIT_USE_TANH_FOR_SIGMOID**: If set to "1", tanh will be used to approximate sigmoid during device code generation. Default value is "0".
diff --git a/fx2ait/fx2ait/fx2ait.py b/fx2ait/fx2ait/fx2ait.py
index 45d0e07e2..942ef402a 100644
--- a/fx2ait/fx2ait/fx2ait.py
+++ b/fx2ait/fx2ait/fx2ait.py
@@ -69,6 +69,7 @@ def __init__(
         save_remote_cache: Optional[bool] = False,
         do_optimize_graph: bool = True,
         use_fast_math: bool = True,
+        use_tanh_for_sigmoid: bool = False,
         profile_timeout: int = 500,
         optimize_for_compilation_time: bool = False,
     ):
@@ -89,6 +90,7 @@ def __init__(
             remote_cache_file_path: AITemplate profiling cache location
             save_remote_cache: whether to save the updated cache
             use_fast_math: whether to use fast math in CUDA kernels
+            use_tanh_for_sigmoid: whether to use tanh to approximate sigmoid in CUDA kernels
             profile_timeout: timeout in seconds for AIT profilers to complete
             optimize_for_compilation_time: we use O1 and disable the ProfileImpl function to reduce compilation time.
         """
@@ -114,6 +116,7 @@ def __init__(
         _LOGGER.info(f"Set CACHE_DIR to {self.cache_dir}")
         self.use_fp16_acc = use_fp16_acc
         self.use_fast_math = use_fast_math
+        self.use_tanh_for_sigmoid = use_tanh_for_sigmoid
         self.optimize_for_compilation_time = optimize_for_compilation_time
         self.hardware_target = self._create_target()
         self.input_specs = input_specs
@@ -141,6 +144,7 @@ def _create_target(self):
             use_fp16_acc=self.use_fp16_acc,
             remote_cache_bytes=self.remote_cache_bytes,
             use_fast_math=self.use_fast_math,
+            use_tanh_for_sigmoid=self.use_tanh_for_sigmoid,
             optimize_for_compilation_time=self.optimize_for_compilation_time,
         )
 
diff --git a/fx2ait/fx2ait/lower/lower_settings.py b/fx2ait/fx2ait/lower/lower_settings.py
index cfaef6a37..465c0bb9f 100644
--- a/fx2ait/fx2ait/lower/lower_settings.py
+++ b/fx2ait/fx2ait/lower/lower_settings.py
@@ -85,3 +85,5 @@ class LowerSettings:
     trace_ait_module: bool = True
     # If True, optimize for compilation time (ie. compile w/ -O1 rather than -O3 and skip profiling codegen)
     optimize_for_compilation_time: bool = False
+    # If True, use tanh to approximate sigmoid in CUDA kernels
+    use_tanh_for_sigmoid: bool = False
diff --git a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
index acfaa2018..255e420b5 100644
--- a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
+++ b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
@@ -161,7 +161,7 @@ __device__ bfloat16 fast_tanh(bfloat16 x) {
 }
 
 __device__ float fsigmoid_custom(const float a) {
-#if defined(AIT_USE_FAST_MATH)
+#if defined(AIT_USE_TANH_FOR_SIGMOID)
   return (cutlass::fast_tanh(a * 0.5f) + 1.0f) * 0.5f;
 #else
   return 1.0f / (1.0f + expf(-a));
@@ -169,7 +169,7 @@ __device__ float fsigmoid_custom(const float a) {
 }
 
 __device__ half hsigmoid_custom(const half a) {
-#if defined(AIT_USE_FAST_MATH)
+#if defined(AIT_USE_TANH_FOR_SIGMOID)
   return __hmul(
       (__hadd(fast_tanh(__hmul(a, CUDA_FP16_ONE_HALF)), CUDA_FP16_ONE)),
       CUDA_FP16_ONE_HALF);
@@ -179,7 +179,7 @@ __device__ half hsigmoid_custom(const half a) {
 }
 
 __device__ half2 h2sigmoid_custom(const half2 a) {
-#if defined(AIT_USE_FAST_MATH)
+#if defined(AIT_USE_TANH_FOR_SIGMOID)
   const auto halfX2 = half2(CUDA_FP16_ONE_HALF, CUDA_FP16_ONE_HALF);
   const auto oneX2 = half2(CUDA_FP16_ONE, CUDA_FP16_ONE);
   return __hmul2((__hadd2(fast_tanh(__hmul2(a, halfX2)), oneX2)), halfX2);
@@ -192,7 +192,7 @@ __device__ bfloat16 hsigmoid_custom(const bfloat16 a) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-#if defined(AIT_USE_FAST_MATH)
+#if defined(AIT_USE_TANH_FOR_SIGMOID)
   return __hmul(
       (__hadd(fast_tanh(__hmul(a, CUDA_BF16_ONE_HALF)), CUDA_BF16_ONE)),
       CUDA_BF16_ONE_HALF);
@@ -208,7 +208,7 @@ __device__ bfloat16_2 h2sigmoid_custom(const bfloat16_2 a) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
-#if defined(AIT_USE_FAST_MATH)
+#if defined(AIT_USE_TANH_FOR_SIGMOID)
   const auto halfX2 = bfloat16_2(CUDA_BF16_ONE_HALF, CUDA_BF16_ONE_HALF);
   const auto oneX2 = bfloat16_2(CUDA_BF16_ONE, CUDA_BF16_ONE);
   return __hmul2((__hadd2(fast_tanh(__hmul2(a, halfX2)), oneX2)), halfX2);
diff --git a/python/aitemplate/backend/cuda/target_def.py b/python/aitemplate/backend/cuda/target_def.py
index bc95469a5..17c376596 100644
--- a/python/aitemplate/backend/cuda/target_def.py
+++ b/python/aitemplate/backend/cuda/target_def.py
@@ -178,10 +178,19 @@ def _build_nvcc_compiler_options(self) -> List[str]:
             options.extend(
                 [
                     "--use_fast_math",
-                    "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
                     "-DAIT_USE_FAST_MATH=1",
                 ]
             )
+        if (
+            self._kwargs.get("use_tanh_for_sigmoid", False)
+            or environ.use_tanh_for_sigmoid()
+        ):
+            options.extend(
+                [
+                    "-DAIT_USE_TANH_FOR_SIGMOID=1",
+                    "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
+                ]
+            )
         return options
 
     def get_device_compiler_options(self) -> List[str]:
@@ -445,10 +454,19 @@ def _build_compile_options(self):
             compile_options.extend(
                 [
                     "--use_fast_math",
-                    "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
                     "-DAIT_USE_FAST_MATH=1",
                 ]
             )
+        if (
+            self._kwargs.get("use_tanh_for_sigmoid", False)
+            or environ.use_tanh_for_sigmoid()
+        ):
+            compile_options.extend(
+                [
+                    "-DAIT_USE_TANH_FOR_SIGMOID=1",
+                    "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
+                ]
+            )
         compile_options_str = " ".join(compile_options)
         _LOGGER.info(f"The compile options are: {compile_options_str}")
         return compile_options_str
diff --git a/python/aitemplate/utils/environ.py b/python/aitemplate/utils/environ.py
index 160b56a1e..b2db8b798 100644
--- a/python/aitemplate/utils/environ.py
+++ b/python/aitemplate/utils/environ.py
@@ -36,16 +36,25 @@ def get_compiler_opt_level() -> str:
     return compiler_opt
 
 
-def use_fast_math() -> str:
+def use_fast_math() -> bool:
     """
     Whether the fast math option should be used for the device
     code generation. Fast math implies the use of approximate
     math operations (say, a division operation), allowing to
     gain speed at the cost of accuracy.
 
-    Default value is "1".
+    The default value of the environment variable is "1".
     """
     return os.getenv("AIT_USE_FAST_MATH", "1") == "1"
 
 
+def use_tanh_for_sigmoid() -> bool:
+    """
+    Whether tanh should be used to approximate sigmoid for the device code generation.
+    This controls code generation for both the AITemplate backend and CUTLASS.
+    The default value of the environment variable is "0".
+    """
+    return os.getenv("AIT_USE_TANH_FOR_SIGMOID", "0") == "1"
+
+
 def enable_cuda_lto() -> bool:
     """
     nvcc will use LTO flags during compilation
diff --git a/tests/unittest/ops/test_activation.py b/tests/unittest/ops/test_activation.py
index 8f78dcdcc..d8c6f04a5 100644
--- a/tests/unittest/ops/test_activation.py
+++ b/tests/unittest/ops/test_activation.py
@@ -671,6 +671,54 @@ def test_sigmoid(self, dtype):
             dtype=dtype,
         )
 
+    def _test_fast_sigmoid_extreme_value(
+        self, atol, rtol, use_tanh_for_sigmoid, testname
+    ):
+        X1 = Tensor(
+            shape=[IntImm(2), IntImm(2)],
+            dtype="float16",
+            name="input0",
+            is_input=True,
+        )
+        sigmoid_op = ops.elementwise(FuncEnum.SIGMOID)
+        X2 = sigmoid_op(X1)
+        X2._attrs["is_output"] = True
+        X2._attrs["name"] = "output0"
+
+        target = detect_target(use_tanh_for_sigmoid=use_tanh_for_sigmoid)
+        module = compile_model(X2, target, "./tmp", f"fast_sigmoid_{testname}")
+
+        x1_pt = get_random_torch_tensor((2, 2), "float16")
+        x1_pt[0, 0] = -6
+        x1_pt[0, 1] = -8
+        x1_pt[1, 0] = -10
+        x1_pt[1, 1] = -12
+
+        x2_pt = torch.sigmoid(x1_pt)
+        x2 = torch.empty_like(x2_pt)
+        module.run_with_tensors([x1_pt], [x2])
+        torch.testing.assert_close(x2, x2_pt, atol=atol, rtol=rtol, equal_nan=True)
+
+    def test_fast_sigmoid(self):
+        self._test_fast_sigmoid_extreme_value(
+            atol=1e-2, rtol=1e-2, use_tanh_for_sigmoid=True, testname="use_tanh"
+        )
+        self._test_fast_sigmoid_extreme_value(
+            atol=1e-6, rtol=1e-2, use_tanh_for_sigmoid=False, testname="original"
+        )
+
+    @unittest.skipIf(
+        detect_target().name() == "rocm", "ROCM doesn't use tanh for sigmoid."
+    )
+    def test_fast_sigmoid_fail(self):
+        with self.assertRaises(AssertionError):
+            self._test_fast_sigmoid_extreme_value(
+                atol=1e-6,
+                rtol=1e-2,
+                use_tanh_for_sigmoid=True,
+                testname="use_tanh_fail",
+            )
+
     @parameterized.expand(
         **filter_test_cases_by_params(
             {
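
Example usage (illustrative sketch, not part of the patch): the flag can be enabled either through the AIT_USE_TANH_FOR_SIGMOID environment variable or through the use_tanh_for_sigmoid target kwarg added above. The imports mirror the AITemplate test style; the tensor shape, the "./tmp" workdir, and the model name "sigmoid_tanh_example" are made up for this sketch.

    import os

    import torch

    from aitemplate.compiler import compile_model, ops
    from aitemplate.compiler.base import IntImm, Tensor
    from aitemplate.compiler.ops.common.epilogue import FuncEnum
    from aitemplate.testing import detect_target

    # Option 1: opt in globally through the environment variable read by
    # aitemplate.utils.environ.use_tanh_for_sigmoid().
    os.environ["AIT_USE_TANH_FOR_SIGMOID"] = "1"

    # Option 2: opt in per target through the kwarg, as exercised in
    # tests/unittest/ops/test_activation.py above.
    X = Tensor(shape=[IntImm(2), IntImm(2)], dtype="float16", name="input0", is_input=True)
    Y = ops.elementwise(FuncEnum.SIGMOID)(X)
    Y._attrs["is_output"] = True
    Y._attrs["name"] = "output0"

    target = detect_target(use_tanh_for_sigmoid=True)
    module = compile_model(Y, target, "./tmp", "sigmoid_tanh_example")

    # Run the compiled module; with the flag set, sigmoid is computed with the
    # tanh-based approximation in the generated CUDA code.
    x_pt = torch.randn(2, 2, dtype=torch.float16, device="cuda")
    y = torch.empty_like(x_pt)
    module.run_with_tensors([x_pt], [y])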