From e6083468bf620dda7872b18ffe1a5547c2e052ac Mon Sep 17 00:00:00 2001
From: Henry Hu
Date: Thu, 29 Feb 2024 21:01:57 -0800
Subject: [PATCH] Add floor op

Summary: As title

Reviewed By: frank-wei

Differential Revision: D54332190

fbshipit-source-id: 7f2d25c7e29ab3cde061d759d282452fbaeee3d7
---
 fx2ait/fx2ait/converters/ait_converters.py    | 11 ++++
 fx2ait/fx2ait/converters/utils.py             |  2 +
 python/aitemplate/backend/backend_spec.py     |  7 +++
 .../backend/cuda/elementwise/custom_math.cuh  | 27 ++++++++++
 .../compiler/ops/common/epilogue.py           |  1 +
 python/aitemplate/compiler/ops/common/math.py |  4 ++
 tests/unittest/ops/test_activation.py         | 54 +++++++++++++++++++
 7 files changed, 106 insertions(+)

diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py
index b04448c34..129d778d6 100644
--- a/fx2ait/fx2ait/converters/ait_converters.py
+++ b/fx2ait/fx2ait/converters/ait_converters.py
@@ -155,6 +155,17 @@ def acc_ops_floor_div(
     return create_binary_op(FuncEnum.FLOOR_DIV, args, kwargs, name)
 
 
+@ait_converter(acc_ops.floor)
+def acc_ops_floor(
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> ConverterOutput:
+    input_val = kwargs["input"]
+    return elementwise(FuncEnum.FLOOR)(input_val)
+
+
 @ait_converter(acc_ops.add)
 def acc_ops_add(
     target: Target,
diff --git a/fx2ait/fx2ait/converters/utils.py b/fx2ait/fx2ait/converters/utils.py
index 540024fab..dc8d59a15 100644
--- a/fx2ait/fx2ait/converters/utils.py
+++ b/fx2ait/fx2ait/converters/utils.py
@@ -164,6 +164,8 @@ def get_python_op_from_ait_constant_elementwise_op(
         return math.sqrt
     elif op_type == FuncEnum.FLOOR_DIV:
         return operator.floordiv
+    elif op_type == FuncEnum.FLOOR:
+        return math.floor
     else:
         raise RuntimeError(f"{op_type} is not supported yet!")
 
diff --git a/python/aitemplate/backend/backend_spec.py b/python/aitemplate/backend/backend_spec.py
index 6e9f700ae..47d8618e1 100644
--- a/python/aitemplate/backend/backend_spec.py
+++ b/python/aitemplate/backend/backend_spec.py
@@ -312,6 +312,13 @@ class GPUBackendSpec(BackendSpec):
             "bfloat16": "floor_div",
             "bfloat16_2": "floor_div",
         },
+        FuncEnum.FLOOR: {
+            "float": "__floor",
+            "half": "__floor",
+            "half2": "__floor",
+            "bfloat16": "__floor",
+            "bfloat16_2": "__floor",
+        },
         FuncEnum.CELU: {
             "float": "fcelu",
             "half": "hcelu",
diff --git a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
index 255e420b5..57c8d3a4f 100644
--- a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
+++ b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
@@ -1010,7 +1010,34 @@ __device__ half2 floor_div(const half2 a, const half2 b) {
 __device__ bfloat16_2 floor_div(const bfloat16_2 a, const bfloat16_2 b) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
   return bfloat16_2(floor_div(a.x, b.x), floor_div(a.y, b.y));
+#else
+  NOT_IMPLEMENTED();
+#endif
+}
+__device__ float __floor(const float a) {
+  return floor(a);
+}
+
+__device__ half __floor(const half a) {
+  return hfloor(a);
+}
+
+__device__ bfloat16 __floor(const bfloat16 a) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+  return hfloor(a);
+#else
+  NOT_IMPLEMENTED();
+#endif
+}
+
+__device__ half2 __floor(const half2 a) {
+  return half2(__floor(a.x), __floor(a.y));
+}
+
+__device__ bfloat16_2 __floor(const bfloat16_2 a) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+  return bfloat16_2(__floor(a.x), __floor(a.y));
 #else
   NOT_IMPLEMENTED();
 #endif
 }
diff --git a/python/aitemplate/compiler/ops/common/epilogue.py b/python/aitemplate/compiler/ops/common/epilogue.py
index fd684bf6e..18e41a35c 100644
--- a/python/aitemplate/compiler/ops/common/epilogue.py
+++ b/python/aitemplate/compiler/ops/common/epilogue.py
@@ -65,3 +65,4 @@ class FuncEnum(Enum):
     SOFTSIGN = 27
     FLOOR_DIV = 28
     CELU = 29
+    FLOOR = 30
diff --git a/python/aitemplate/compiler/ops/common/math.py b/python/aitemplate/compiler/ops/common/math.py
index 4534628e6..0b9194247 100644
--- a/python/aitemplate/compiler/ops/common/math.py
+++ b/python/aitemplate/compiler/ops/common/math.py
@@ -117,3 +117,7 @@ def floor_div(tensor: Any) -> Tensor:
 
 def celu(tensor: Any) -> Tensor:
     return OP_REGISTRY.get("CELU")(tensor)
+
+
+def floor(tensor: Any) -> Tensor:
+    return OP_REGISTRY.get("FLOOR")(tensor)
diff --git a/tests/unittest/ops/test_activation.py b/tests/unittest/ops/test_activation.py
index d8c6f04a5..7ec2c7ee7 100644
--- a/tests/unittest/ops/test_activation.py
+++ b/tests/unittest/ops/test_activation.py
@@ -44,6 +44,7 @@
     FuncEnum.SIGMOID: torch.sigmoid,
     FuncEnum.RELU: torch.relu,
     FuncEnum.CELU: torch.celu,
+    FuncEnum.FLOOR: torch.floor,
 }
 
 
@@ -127,6 +128,38 @@ def _test_floor_div(
         module.run_with_tensors([x1_pt], [x2])
         torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)
 
+    def _test_floor(
+        self,
+        input_size,
+        test_name="floor",
+        copy_op=False,
+        dtype="float16",
+    ):
+        assert len(input_size) == 2
+        X1 = Tensor(
+            shape=[IntImm(input_size[0]), IntImm(input_size[1])],
+            dtype=dtype,
+            name="input0",
+            is_input=True,
+        )
+        X2_op = ops.elementwise(FuncEnum.FLOOR)
+
+        if copy_op:
+            X2_op = ops.elementwise(**X2_op._get_op_attributes())
+        X2 = X2_op(X1)
+        X2._attrs["is_output"] = True
+        X2._attrs["name"] = "output0"
+
+        target = detect_target()
+        module = compile_model(X2, target, "./tmp", f"{test_name}_{dtype}")
+
+        x1_pt = get_random_torch_tensor(input_size, dtype)
+        x2_pt = torch.floor(x1_pt)
+
+        x2 = torch.empty_like(x2_pt)
+        module.run_with_tensors([x1_pt], [x2])
+        torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)
+
     def _test_hardtanh(
         self,
         input_size,
@@ -816,6 +849,27 @@ def test_floor_div(self, dtype):
             dtype=dtype,
         )
 
+    @parameterized.expand(
+        **filter_test_cases_by_params(
+            {
+                TestEnv.CUDA_LESS_THAN_SM80: [("float16"), ("float32")],
+                TestEnv.CUDA_SM80: [("bfloat16")],
+                TestEnv.ROCM: [("float16")],
+            }
+        )
+    )
+    def test_floor(self, dtype):
+        self._test_simple_function(
+            [511, 511], FuncEnum.FLOOR, test_name="floor_1", dtype=dtype
+        )
+        self._test_simple_function(
+            [512, 512],
+            FuncEnum.FLOOR,
+            test_name="floor_1_copy_op",
+            copy_op=True,
+            dtype=dtype,
+        )
+
     @parameterized.expand(
         **filter_test_cases_by_params(
             {
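For reference, below is a minimal usage sketch of the new op from the AIT Python frontend, mirroring the _test_floor unit test above. It is not part of the patch: the import paths, the "./tmp" workdir, the "floor_example" test name, and the 512x512 fp16 shape are illustrative assumptions taken from the test-file conventions, and a CUDA-capable AITemplate build is assumed.

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.base import IntImm, Tensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.testing import detect_target

# Build a small graph that applies the new FLOOR elementwise op.
X = Tensor(
    shape=[IntImm(512), IntImm(512)],
    dtype="float16",
    name="input0",
    is_input=True,
)
Y = ops.elementwise(FuncEnum.FLOOR)(X)
Y._attrs["is_output"] = True
Y._attrs["name"] = "output0"

# Compile, run, and compare against torch.floor on the same input.
module = compile_model(Y, detect_target(), "./tmp", "floor_example")
x_pt = torch.randn(512, 512, dtype=torch.float16, device="cuda") * 10
y = torch.empty_like(x_pt)
module.run_with_tensors([x_pt], [y])
torch.testing.assert_close(y, torch.floor(x_pt), atol=1e-2, rtol=1e-2)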