Add reciprocal operator
Summary: Add a reciprocal (1/x) elementwise operator: fx2ait converter, AIT front-end op and FuncEnum entry, CUDA backend implementations, and unit tests.

Differential Revision: D62000543
muchulee8 authored and facebook-github-bot committed Sep 4, 2024
1 parent e2eff07 commit fc8c4fa
Showing 7 changed files with 107 additions and 0 deletions.
11 changes: 11 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
@@ -166,6 +166,17 @@ def acc_ops_floor(
    return elementwise(FuncEnum.FLOOR)(input_val)


@ait_converter(acc_ops.reciprocal)
def acc_ops_reciprocal(
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> ConverterOutput:
    input_val = kwargs["input"]
    return elementwise(FuncEnum.RECIPROCAL)(input_val)


@ait_converter(acc_ops.add)
def acc_ops_add(
    target: Target,
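
For context, the acc tracer is expected to map torch.reciprocal to acc_ops.reciprocal, which this converter then lowers to the AIT elementwise op. A minimal sketch of a module this path should cover (the class name is illustrative):

import torch

class Reciprocal(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Elementwise 1 / x.
        return torch.reciprocal(x)
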
2 changes: 2 additions & 0 deletions fx2ait/fx2ait/converters/utils.py
@@ -166,6 +166,8 @@ def get_python_op_from_ait_constant_elementwise_op(
        return operator.floordiv
    elif op_type == FuncEnum.FLOOR:
        return math.floor
    elif op_type == FuncEnum.RECIPROCAL:
        # The math module has no reciprocal(); fold constants with plain division.
        return lambda x: 1.0 / x
    else:
        raise RuntimeError(f"{op_type} is not supported yet!")

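
A quick sanity check of the new branch (a sketch; the helper's import path is assumed from the file location):

from aitemplate.compiler.ops.common.epilogue import FuncEnum
from fx2ait.converters.utils import get_python_op_from_ait_constant_elementwise_op

# Constant folding of reciprocal on a Python scalar.
recip = get_python_op_from_ait_constant_elementwise_op(FuncEnum.RECIPROCAL)
assert recip(4.0) == 0.25
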
7 changes: 7 additions & 0 deletions python/aitemplate/backend/backend_spec.py
@@ -326,6 +326,13 @@ class GPUBackendSpec(BackendSpec):
            "bfloat16": "__floor",
            "bfloat16_2": "__floor",
        },
        FuncEnum.RECIPROCAL: {
            "float": "__reciprocal",
            "half": "__hreciprocal",
            "half2": "__h2reciprocal",
            "bfloat16": "__breciprocal",
            "bfloat16_2": "__b2reciprocal",
        },
        FuncEnum.CELU: {
            "float": "fcelu",
            "half": "hcelu",
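
The spec is a per-dtype table of device function names; codegen looks up (FuncEnum, dtype) to pick the helper emitted into the generated kernel. A minimal sketch of that dispatch using the names above (illustrative, not the actual BackendSpec code):

RECIPROCAL_FUNCS = {
    "float": "__reciprocal",
    "half": "__hreciprocal",
    "half2": "__h2reciprocal",       # vectorized: two half lanes per call
    "bfloat16": "__breciprocal",
    "bfloat16_2": "__b2reciprocal",  # vectorized bfloat16 pair
}

def reciprocal_func_name(dtype: str) -> str:
    # Raises KeyError for dtypes the backend does not support.
    return RECIPROCAL_FUNCS[dtype]
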
28 changes: 28 additions & 0 deletions python/aitemplate/backend/cuda/elementwise/custom_math.cuh
@@ -1043,6 +1043,34 @@ __device__ bfloat16_2 __floor(const bfloat16_2 a) {
#endif
}

// 1/x in fp32.
__device__ float __reciprocal(const float a) {
  return 1.0f / a;
}

__device__ half __hreciprocal(const half a) {
  return __hdiv(1.0f, a);
}

// bfloat16 arithmetic requires SM80+.
__device__ bfloat16 __breciprocal(const bfloat16 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hdiv(1.0f, a);
#else
  NOT_IMPLEMENTED();
#endif
}

// Vectorized variants: one reciprocal per lane.
__device__ half2 __h2reciprocal(const half2 a) {
  return half2(__hdiv(1.0f, a.x), __hdiv(1.0f, a.y));
}

__device__ bfloat16_2 __b2reciprocal(const bfloat16_2 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_2(__hdiv(1.0f, a.x), __hdiv(1.0f, a.y));
#else
  NOT_IMPLEMENTED();
#endif
}

__device__ float fcelu(const float a, const float alpha) {
  return a > 0.f ? a : alpha * (expf(a / alpha) - 1.0f);
}
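
As a host-side parity check (a sketch, requiring only PyTorch), the device helpers above should agree with torch.reciprocal to within half-precision rounding:

import torch

# Keep inputs away from zero so 1/x stays finite.
x = (torch.rand(1024) + 0.5).half()
ref = torch.reciprocal(x.float()).half()  # fp32 reference rounded to half
got = torch.reciprocal(x)
torch.testing.assert_close(got, ref, atol=1e-2, rtol=1e-2)
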
1 change: 1 addition & 0 deletions python/aitemplate/compiler/ops/common/epilogue.py
@@ -67,3 +67,4 @@ class FuncEnum(Enum):
    CELU = 29
    FLOOR = 30
    LOG1P = 31
    RECIPROCAL = 32
4 changes: 4 additions & 0 deletions python/aitemplate/compiler/ops/common/math.py
@@ -125,3 +125,7 @@ def celu(tensor: Any) -> Tensor:

def floor(tensor: Any) -> Tensor:
    return OP_REGISTRY.get("FLOOR")(tensor)


def reciprocal(tensor: Any) -> Tensor:
    return OP_REGISTRY.get("RECIPROCAL")(tensor)
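
End to end, the new front-end op composes like any other elementwise; a minimal sketch mirroring the unit test below (assuming a CUDA device and the usual AIT imports):

import torch
from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.base import IntImm
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

X = Tensor(shape=[IntImm(32), IntImm(128)], dtype="float16", name="input0", is_input=True)
Y = ops.elementwise(FuncEnum.RECIPROCAL)(X)  # reciprocal() above wraps the same registry entry
Y._attrs["is_output"] = True
Y._attrs["name"] = "output0"

module = compile_model(Y, detect_target(), "./tmp", "reciprocal_demo")
x = torch.rand(32, 128, dtype=torch.float16, device="cuda") + 0.5
y = torch.empty_like(x)
module.run_with_tensors([x], [y])
torch.testing.assert_close(y, torch.reciprocal(x), atol=1e-2, rtol=1e-2)
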
54 changes: 54 additions & 0 deletions tests/unittest/ops/test_activation.py
Expand Up @@ -45,6 +45,7 @@
FuncEnum.RELU: torch.relu,
FuncEnum.CELU: torch.celu,
FuncEnum.FLOOR: torch.floor,
FuncEnum.RECIPROCAL: torch.reciprocal,
}


@@ -437,6 +438,38 @@ def _test_fast_gelu(
        module.run_with_tensors([x1_pt], [x2])
        torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)

    def _test_reciprocal(
        self,
        input_size,
        test_name="reciprocal",
        copy_op=False,
        dtype="float16",
    ):
        assert len(input_size) == 2
        X1 = Tensor(
            shape=[IntImm(input_size[0]), IntImm(input_size[1])],
            dtype=dtype,
            name="input0",
            is_input=True,
        )
        X2_op = ops.elementwise(FuncEnum.RECIPROCAL)

        if copy_op:
            X2_op = ops.elementwise(**X2_op._get_op_attributes())
        X2 = X2_op(X1)
        X2._attrs["is_output"] = True
        X2._attrs["name"] = "output0"

        target = detect_target()
        module = compile_model(X2, target, "./tmp", f"{test_name}_{dtype}")

        x1_pt = get_random_torch_tensor(input_size, dtype)
        x2_pt = torch.reciprocal(x1_pt)

        x2 = torch.empty_like(x2_pt)
        module.run_with_tensors([x1_pt], [x2])
        torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)

    @parameterized.expand(
        **filter_test_cases_by_params(
            {
@@ -921,6 +954,27 @@ def test_fast_gelu(self, dtype):
            [256, 128], test_name="fast_gelu_4_copy_op", copy_op=True, dtype=dtype
        )

    @parameterized.expand(
        **filter_test_cases_by_params(
            {
                TestEnv.CUDA_LESS_THAN_SM80: [("float16"), ("float32")],
                TestEnv.CUDA_SM80: [("bfloat16")],
                TestEnv.ROCM: [("float16")],
            }
        )
    )
    def test_reciprocal(self, dtype):
        self._test_simple_function(
            [32, 128], FuncEnum.RECIPROCAL, test_name="reciprocal", dtype=dtype
        )
        self._test_simple_function(
            [32, 128],
            FuncEnum.RECIPROCAL,
            test_name="reciprocal_copy_op",
            copy_op=True,
            dtype=dtype,
        )


if __name__ == "__main__":
    unittest.main()

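To run just these cases locally (assuming the standard repo layout):

python -m pytest tests/unittest/ops/test_activation.py -k reciprocal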