Add USE_TANH_FOR_SIGMOID flag for CUDA codegen. (#881)
Summary:
Pull Request resolved: #881

We surface USE_TANH_FOR_SIGMOID at the AIT level, so that it applies not only to CUTLASS kernels but to AIT-generated computation as well.

Reviewed By: amateurcoffee, aakhundov

Differential Revision: D48097853

fbshipit-source-id: df9ed1696aefe1c3e01a85e074f60c4adf5a14f2
muchulee8 authored and facebook-github-bot committed Aug 9, 2023
1 parent 46f170a commit 38aebd9
Showing 7 changed files with 92 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/env.rst
@@ -70,3 +70,5 @@ Miscellaneous
**AIT_PLOT_SHORTEN_TENSOR_NAMES**: If set to "1", overly long tensor names are shortened in the plot of a model graph, making the plot much easier to analyze visually. "0" by default.

**AIT_USE_FAST_MATH**: If set to "0", fast math will not be used for device code generation. Default value is "1".

**AIT_USE_TANH_FOR_SIGMOID**: If set to "1", tanh will be used to approximate sigmoid during device code generation. Default value is "0".
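
For illustration only (a minimal sketch, not part of this commit), the documented variables can be set from Python before any compilation is triggered; the variable names are taken from this file, while the placement in user code is an assumption:

import os

# Enable the tanh-based sigmoid approximation for subsequently compiled models.
os.environ["AIT_USE_TANH_FOR_SIGMOID"] = "1"

# Fast math defaults to "1"; setting it to "0" disables approximate device math.
# The two knobs are independent of each other.
os.environ["AIT_USE_FAST_MATH"] = "0"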
4 changes: 4 additions & 0 deletions fx2ait/fx2ait/fx2ait.py
Expand Up @@ -69,6 +69,7 @@ def __init__(
save_remote_cache: Optional[bool] = False,
do_optimize_graph: bool = True,
use_fast_math: bool = True,
use_tanh_for_sigmoid: bool = False,
profile_timeout: int = 500,
optimize_for_compilation_time: bool = False,
):
@@ -89,6 +90,7 @@ def __init__(
remote_cache_file_path: AITemplate profiling cache location
save_remote_cache: whether to save the updated cache
use_fast_math: whether to use fast math in CUDA kernels
use_tanh_for_sigmoid: whether to use tanh to approximate sigmoid in CUDA kernels
profile_timeout: timeout in seconds for AIT profilers to complete
optimize_for_compilation_time: we use O1 and disable the ProfileImpl function to reduce compilation time.
"""
@@ -114,6 +116,7 @@ def __init__(
_LOGGER.info(f"Set CACHE_DIR to {self.cache_dir}")
self.use_fp16_acc = use_fp16_acc
self.use_fast_math = use_fast_math
self.use_tanh_for_sigmoid = use_tanh_for_sigmoid
self.optimize_for_compilation_time = optimize_for_compilation_time
self.hardware_target = self._create_target()
self.input_specs = input_specs
@@ -141,6 +144,7 @@ def _create_target(self):
use_fp16_acc=self.use_fp16_acc,
remote_cache_bytes=self.remote_cache_bytes,
use_fast_math=self.use_fast_math,
use_tanh_for_sigmoid=self.use_tanh_for_sigmoid,
optimize_for_compilation_time=self.optimize_for_compilation_time,
)

2 changes: 2 additions & 0 deletions fx2ait/fx2ait/lower/lower_settings.py
@@ -85,3 +85,5 @@ class LowerSettings:
trace_ait_module: bool = True
# If True, optimize for compilation time (i.e. compile with -O1 rather than -O3 and skip profiling codegen)
optimize_for_compilation_time: bool = False
# If True, use tanh to approximate sigmoid in CUDA kernels
use_tanh_for_sigmoid: bool = False
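
A hedged example of using the new field (assuming the import path matches this file's location and that the other LowerSettings fields keep their defaults):

from fx2ait.lower.lower_settings import LowerSettings

# Sketch only: request the tanh-based sigmoid approximation at lowering time.
settings = LowerSettings(use_tanh_for_sigmoid=True)
assert settings.use_tanh_for_sigmoid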
10 changes: 5 additions & 5 deletions python/aitemplate/backend/cuda/elementwise/custom_math.cuh
@@ -161,15 +161,15 @@ __device__ bfloat16 fast_tanh(bfloat16 x) {
}

__device__ float fsigmoid_custom(const float a) {
#if defined(AIT_USE_FAST_MATH)
#if defined(AIT_USE_TANH_FOR_SIGMOID)
return (cutlass::fast_tanh(a * 0.5f) + 1.0f) * 0.5f;
#else
return 1.0f / (1.0f + expf(-a));
#endif
}

__device__ half hsigmoid_custom(const half a) {
#if defined(AIT_USE_FAST_MATH)
#if defined(AIT_USE_TANH_FOR_SIGMOID)
return __hmul(
(__hadd(fast_tanh(__hmul(a, CUDA_FP16_ONE_HALF)), CUDA_FP16_ONE)),
CUDA_FP16_ONE_HALF);
@@ -179,7 +179,7 @@ __device__ half hsigmoid_custom(const half a) {
}

__device__ half2 h2sigmoid_custom(const half2 a) {
#if defined(AIT_USE_FAST_MATH)
#if defined(AIT_USE_TANH_FOR_SIGMOID)
const auto halfX2 = half2(CUDA_FP16_ONE_HALF, CUDA_FP16_ONE_HALF);
const auto oneX2 = half2(CUDA_FP16_ONE, CUDA_FP16_ONE);
return __hmul2((__hadd2(fast_tanh(__hmul2(a, halfX2)), oneX2)), halfX2);
@@ -192,7 +192,7 @@ __device__ half2 h2sigmoid_custom(const half2 a) {
__device__ bfloat16 hsigmoid_custom(const bfloat16 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

#if defined(AIT_USE_FAST_MATH)
#if defined(AIT_USE_TANH_FOR_SIGMOID)
return __hmul(
(__hadd(fast_tanh(__hmul(a, CUDA_BF16_ONE_HALF)), CUDA_BF16_ONE)),
CUDA_BF16_ONE_HALF);
@@ -208,7 +208,7 @@ __device__ bfloat16 hsigmoid_custom(const bfloat16 a) {
__device__ bfloat16_2 h2sigmoid_custom(const bfloat16_2 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

#if defined(AIT_USE_FAST_MATH)
#if defined(AIT_USE_TANH_FOR_SIGMOID)
const auto halfX2 = bfloat16_2(CUDA_BF16_ONE_HALF, CUDA_BF16_ONE_HALF);
const auto oneX2 = bfloat16_2(CUDA_BF16_ONE, CUDA_BF16_ONE);
return __hmul2((__hadd2(fast_tanh(__hmul2(a, halfX2)), oneX2)), halfX2);
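
The rewritten branches rely on the exact identity sigmoid(x) = (tanh(x / 2) + 1) / 2, so any accuracy loss on the AIT_USE_TANH_FOR_SIGMOID path comes from the fast_tanh approximation itself rather than from the reformulation. A small NumPy sketch (not part of the commit) checking the identity:

import numpy as np

x = np.linspace(-12.0, 12.0, 7)
reference = 1.0 / (1.0 + np.exp(-x))       # exact sigmoid
via_tanh = (np.tanh(x * 0.5) + 1.0) * 0.5  # reformulation used above
assert np.allclose(reference, via_tanh, atol=1e-12)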
22 changes: 20 additions & 2 deletions python/aitemplate/backend/cuda/target_def.py
@@ -178,10 +178,19 @@ def _build_nvcc_compiler_options(self) -> List[str]:
options.extend(
[
"--use_fast_math",
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
"-DAIT_USE_FAST_MATH=1",
]
)
if (
self._kwargs.get("use_tanh_for_sigmoid", False)
or environ.use_tanh_for_sigmoid()
):
options.extend(
[
"-DAIT_USE_TANH_FOR_SIGMOID=1",
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
]
)
return options

def get_device_compiler_options(self) -> List[str]:
@@ -445,10 +454,19 @@ def _build_compile_options(self):
compile_options.extend(
[
"--use_fast_math",
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
"-DAIT_USE_FAST_MATH=1",
]
)
if (
self._kwargs.get("use_tanh_for_sigmoid", False)
or environ.use_tanh_for_sigmoid()
):
compile_options.extend(
[
"-DAIT_USE_TANH_FOR_SIGMOID=1",
"-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
]
)
compile_options_str = " ".join(compile_options)
_LOGGER.info(f"The compile options are: {compile_options_str}")
return compile_options_str
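
With these changes, the tanh-based sigmoid can be requested either through the target kwarg or through the AIT_USE_TANH_FOR_SIGMOID environment variable, independently of use_fast_math; when enabled, both -DAIT_USE_TANH_FOR_SIGMOID=1 and -DCUTLASS_USE_TANH_FOR_SIGMOID=1 are added to the compile options. A minimal sketch mirroring the new unit test (import path assumed from the test module):

from aitemplate.testing import detect_target

# The kwarg lands in Target._kwargs and is picked up by
# _build_nvcc_compiler_options / _build_compile_options above.
target = detect_target(use_tanh_for_sigmoid=True)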
13 changes: 11 additions & 2 deletions python/aitemplate/utils/environ.py
@@ -36,16 +36,25 @@ def get_compiler_opt_level() -> str:
return compiler_opt


def use_fast_math() -> str:
def use_fast_math() -> bool:
"""
Whether the fast math option should be used for the device code generation.
Fast math implies the use of approximate math operations (say,
a division operation), trading accuracy for speed.
Default value is "1".
The default value read from the environment variable is "1".
"""
return os.getenv("AIT_USE_FAST_MATH", "1") == "1"


def use_tanh_for_sigmoid() -> bool:
"""
Whether tanh should be used to approximate sigmoid during device code generation.
This controls code generation for both AITemplate kernels and CUTLASS.
The default value read from the environment variable is "0".
"""
return os.getenv("AIT_USE_TANH_FOR_SIGMOID", "0") == "1"


def enable_cuda_lto() -> bool:
"""
nvcc will use LTO flags during compilation
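
Since the new helper only reads AIT_USE_TANH_FOR_SIGMOID, the variable has to be set before the target queries it, and the target-level kwarg takes effect independently (the two are OR-ed in target_def.py). A quick sketch of the expected helper behavior:

import os
from aitemplate.utils import environ

os.environ["AIT_USE_TANH_FOR_SIGMOID"] = "1"

# The helper returns True only when the variable is exactly "1".
assert environ.use_tanh_for_sigmoid() is True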
48 changes: 48 additions & 0 deletions tests/unittest/ops/test_activation.py
@@ -671,6 +671,54 @@ def test_sigmoid(self, dtype):
dtype=dtype,
)

def _test_fast_sigmoid_extreme_value(
self, atol, rtol, use_tanh_for_sigmoid, testname
):
X1 = Tensor(
shape=[IntImm(2), IntImm(2)],
dtype="float16",
name="input0",
is_input=True,
)
sigmoid_op = ops.elementwise(FuncEnum.SIGMOID)
X2 = sigmoid_op(X1)
X2._attrs["is_output"] = True
X2._attrs["name"] = "output0"

target = detect_target(use_tanh_for_sigmoid=use_tanh_for_sigmoid)
module = compile_model(X2, target, "./tmp", f"fast_sigmoid_{testname}")

x1_pt = get_random_torch_tensor((2, 2), "float16")
x1_pt[0, 0] = -6
x1_pt[0, 1] = -8
x1_pt[1, 0] = -10
x1_pt[1, 1] = -12

x2_pt = torch.sigmoid(x1_pt)
x2 = torch.empty_like(x2_pt)
module.run_with_tensors([x1_pt], [x2])
torch.testing.assert_close(x2, x2_pt, atol=atol, rtol=rtol, equal_nan=True)

def test_fast_sigmoid(self):
self._test_fast_sigmoid_extreme_value(
atol=1e-2, rtol=1e-2, use_tanh_for_sigmoid=True, testname="use_tanh"
)
self._test_fast_sigmoid_extreme_value(
atol=1e-6, rtol=1e-2, use_tanh_for_sigmoid=False, testname="original"
)

@unittest.skipIf(
detect_target().name() == "rocm", "ROCM doesn't use tanh for sigmoid."
)
def test_fast_sigmoid_fail(self):
with self.assertRaises(AssertionError):
self._test_fast_sigmoid_extreme_value(
atol=1e-6,
rtol=1e-2,
use_tanh_for_sigmoid=True,
testname="use_tanh_fail",
)

@parameterized.expand(
**filter_test_cases_by_params(
{
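
A note on the extreme-value tests above: the chosen inputs sit deep in sigmoid's saturated tail, where the exact output is tiny, so the absolute error introduced by the half-precision fast_tanh path is expected to exceed the tight atol=1e-6 used for the exact path; this is why test_fast_sigmoid_fail expects an AssertionError. A quick sketch of the reference magnitudes (not part of the commit):

import math

for v in (-6.0, -8.0, -10.0, -12.0):
    print(v, 1.0 / (1.0 + math.exp(-v)))
# roughly 2.5e-3, 3.4e-4, 4.5e-5, 6.1e-6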
