diff --git a/python/aitemplate/backend/cuda/tensor/__init__.py b/python/aitemplate/backend/cuda/tensor/__init__.py
index ab5f5ffe8..08dcceaee 100644
--- a/python/aitemplate/backend/cuda/tensor/__init__.py
+++ b/python/aitemplate/backend/cuda/tensor/__init__.py
@@ -18,6 +18,7 @@ from aitemplate.backend.cuda.tensor import (
     argmax,
     batch_gather,
+    cast,
     concatenate,
     concatenate_tanh,
     dynamic_slice,
@@ -42,6 +43,7 @@ __all__ = [
     "argmax",
     "batch_gather",
+    "cast",
     "concatenate",
     "concatenate_tanh",
     "dynamic_slice",
diff --git a/python/aitemplate/backend/cuda/tensor/cast.py b/python/aitemplate/backend/cuda/tensor/cast.py
new file mode 100644
index 000000000..bab6388fc
--- /dev/null
+++ b/python/aitemplate/backend/cuda/tensor/cast.py
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict

import jinja2

from aitemplate.backend import registry
from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.common.elementwise_common import gen_int_var_product_str

CUDA_HEADER_FILES = """
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cmath>
"""

CONSTANT_TEMPLATE = jinja2.Template(
    """
#define N_THREADS_PER_BLOCK 256
    """
)

FUNC_DECL_TEMPLATE = jinja2.Template(
    """
void invoke_{{func_name}}(
    void* y,
    const void* x,
    {{index_type}} n_elements,
    {{prefix}}Stream_t stream);
    """
)


FUNC_CALL_TEMPLATE = jinja2.Template(
    """
{{indent}}{
{{indent}}  const {{index_type}} {{func_name}}_n_elements = {{calculate_n}};
{{indent}}  invoke_{{func_name}}({{output}}, {{input}}, {{func_name}}_n_elements, stream);
{{indent}}}
    """
)


FUNC_TEMPLATE = jinja2.Template(
    """
{{header_files}}

namespace {

{{constant}}

__global__ void cast_op(
    {{output_type}}* output,
    const {{input_type}}* input,
    {{index_type}} n_elements
) {
  const {{index_type}} idx = (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx >= n_elements) {
    return;
  }
  output[idx] = {{cast_func_call}}
}

} // namespace

void invoke_{{func_name}}(void* output, const void* input,
    {{index_type}} n_elements, {{prefix}}Stream_t stream) {
  if (n_elements == 0) {
    return;
  }
  int grid_size = static_cast<int>(std::ceil(static_cast<double>(n_elements) / N_THREADS_PER_BLOCK));
  cast_op<<<grid_size, N_THREADS_PER_BLOCK, 0, stream>>>(
      reinterpret_cast<{{output_type}}*>(output),
      reinterpret_cast<const {{input_type}}*>(input),
      n_elements
  );
}
    """
)

CAST_FUNCS = {
    "half": {
        "bfloat16": "__float2bfloat16_rn(__half2float(input[idx]));",
        "float": "__half2float(input[idx]);",
    },
    "bfloat16": {
        "half": "__float2half_rn(__bfloat162float(input[idx]));",
        "float": "__bfloat162float(input[idx]);",
    },
    "float": {
        "bfloat16": "__float2bfloat16_rn(input[idx]);",
        "half": "__float2half_rn(input[idx]);",
    },
}


@registry.reg("cuda.cast.gen_function")
def gen_function(func_attrs: Dict[str, Any]) -> str:
    input_ = func_attrs["inputs"][0]
    output = func_attrs["outputs"][0]
    backend_spec = CUDASpec()
    output_dtype = output.dtype()
    output_type = backend_spec.dtype_to_backend_type(output_dtype)
    input_type = backend_spec.dtype_to_backend_type(input_.dtype())
    cast_func_call = CAST_FUNCS[input_type][output_type]

    return FUNC_TEMPLATE.render(
        header_files=backend_spec.header_src_template.render(
            extra_header=CUDA_HEADER_FILES
        ),
        constant=CONSTANT_TEMPLATE.render(),
        func_name=func_attrs["name"],
        input_type=input_type,
        output_type=output_type,
        index_type=backend_spec.index_type,
        cast_func_call=cast_func_call,
        prefix=backend_spec.prefix,
    )


@registry.reg("cuda.cast.func_decl")
def gen_function_decl(func_attrs: Dict[str, Any]) -> str:
    backend_spec = CUDASpec()
    return FUNC_DECL_TEMPLATE.render(
        func_name=func_attrs["name"],
        prefix=backend_spec.prefix,
        index_type=backend_spec.index_type,
    )


@registry.reg("cuda.cast.func_call")
def gen_function_call(func_attrs: Dict[str, Any], indent="  ") -> str:
    backend_spec = CUDASpec()
    return FUNC_CALL_TEMPLATE.render(
        func_name=func_attrs["name"],
        output=func_attrs["outputs"][0]._attrs["name"],
        input=func_attrs["inputs"][0]._attrs["name"],
        calculate_n=gen_int_var_product_str(func_attrs["inputs"][0].shape()),
        index_type=backend_spec.index_type,
        indent=indent,
    )
diff --git a/python/aitemplate/compiler/ops/tensor/__init__.py b/python/aitemplate/compiler/ops/tensor/__init__.py
index 569a82aef..265f3293d 100644
--- a/python/aitemplate/compiler/ops/tensor/__init__.py
+++ b/python/aitemplate/compiler/ops/tensor/__init__.py
@@ -18,6 +18,7 @@
 """
 from aitemplate.compiler.ops.tensor.argmax import argmax
 from aitemplate.compiler.ops.tensor.batch_gather import batch_gather
+from aitemplate.compiler.ops.tensor.cast import cast
 from aitemplate.compiler.ops.tensor.chunk import chunk
 from aitemplate.compiler.ops.tensor.concatenate import concatenate
 from aitemplate.compiler.ops.tensor.concatenate_tanh import concatenate_tanh
diff --git a/python/aitemplate/compiler/ops/tensor/cast.py b/python/aitemplate/compiler/ops/tensor/cast.py
new file mode 100644
index 000000000..1850367a6
--- /dev/null
+++ b/python/aitemplate/compiler/ops/tensor/cast.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.dtype import normalize_dtype


class cast(Operator):
    """
    Returns the input tensor cast to the specified dtype.
    Only conversions between the float16, bfloat16, and
    float32 dtypes are supported.

    Args:
        x (Tensor): the source tensor
        dtype (str): the target dtype for the cast

    Returns:
        Tensor: a tensor of the same shape as x, with its
        values converted to the specified dtype.
    """

    def __init__(self) -> None:
        super().__init__()

        self._attrs["op"] = "cast"
        self._attrs["has_profiler"] = False

    def __call__(
        self,
        x: Tensor,
        dtype: str = "bfloat16",
    ) -> Tensor:
        x_dtype = normalize_dtype(x._attrs["dtype"])
        dtype = normalize_dtype(dtype)
        if x_dtype not in ("float16", "bfloat16", "float32"):
            raise TypeError(
                f"Expected dtype of x to be float16, bfloat16, or float32, but got {x_dtype}."
            )

        if dtype not in ("float16", "bfloat16", "float32"):
            raise TypeError(
                f"Expected cast dtype to be float16, bfloat16, or float32, but got {dtype}."
            )
        if dtype == x_dtype:
            return x

        self._attrs["inputs"] = [x]
        self._attrs["cast_dtype"] = dtype
        self._set_depth()

        output_shape = x._attrs["shape"]
        output = Tensor(
            output_shape,
            src_ops={self},
            dtype=dtype,
        )
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        target = backend.target.Target.current()
        func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
        func = registry.get(func_key)
        return func(self._attrs)
diff --git a/tests/unittest/ops/test_cast.py b/tests/unittest/ops/test_cast.py
new file mode 100644
index 000000000..715befc73
--- /dev/null
+++ b/tests/unittest/ops/test_cast.py
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import ( + get_random_torch_tensor, + get_torch_empty_tensor, +) +from aitemplate.utils.torch_utils import string_to_torch_dtype +from parameterized import param, parameterized + + +class TestCast(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._test_id = 0 + + def _test_cast( + self, + shape, + dtype="float32", + cast_dtype="bfloat16", + test_name="cast", + ) -> None: + if not isinstance(shape, list): + shape = [shape] + + X = Tensor( + shape=shape, + name="X", + dtype=dtype, + is_input=True, + ) + + Y = ops.cast()(X, cast_dtype) + Y._attrs["name"] = "Y" + Y._attrs["is_output"] = True + + target = detect_target() + module = compile_model(Y, target, "./tmp", f"{test_name}_{self._test_id}") + self._test_id += 1 + + x = get_random_torch_tensor(shape, dtype=dtype) + y = get_torch_empty_tensor(shape, dtype=cast_dtype) + inputs = {"X": x} + outputs = {"Y": y} + module.run_with_tensors(inputs, outputs) + + y_pt = x.to(string_to_torch_dtype(cast_dtype)) + torch.testing.assert_close(y, y_pt, atol=1e-2, rtol=1e-2) + + @parameterized.expand( + [ + param(1, "float16", "bfloat16", [1], "float16_to_bfloat16"), + param(2, "float16", "float32", [10, 20], "float16_to_float32"), + param(3, "bfloat16", "float16", [10, 20, 30], "bfloat16_to_float16"), + param(4, "bfloat16", "float32", 123, "bfloat16_to_float32"), + param(5, "float32", "float16", [20, 30], "float32_to_float16"), + param(6, "float32", "bfloat16", [1, 128], "float32_to_bfloat16"), + ] + ) + def test_cast( + self, + i, + dtype, + cast_dtype, + shape, + test_name, + ): + self._test_cast( + shape=shape, + dtype=dtype, + cast_dtype=cast_dtype, + test_name=test_name, + ) + + +if __name__ == "__main__": + torch.manual_seed(0) + unittest.main()
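
For reference, a minimal end-to-end usage sketch of the new cast op, mirroring the unit test above. It assumes a CUDA device and a detectable build target; the tensor names ("X", "Y"), the shape, and the "./tmp" work directory are illustrative, not mandated by the diff.

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

# Build a one-op graph that casts a float32 input to bfloat16.
X = Tensor(shape=[10, 20], name="X", dtype="float32", is_input=True)
Y = ops.cast()(X, "bfloat16")
Y._attrs["name"] = "Y"
Y._attrs["is_output"] = True

# Compile and run, checking the result against PyTorch's .to().
module = compile_model(Y, detect_target(), "./tmp", "cast_example")
x = torch.randn(10, 20, dtype=torch.float32, device="cuda")
y = torch.empty(10, 20, dtype=torch.bfloat16, device="cuda")
module.run_with_tensors({"X": x}, {"Y": y})
torch.testing.assert_close(y, x.to(torch.bfloat16), atol=1e-2, rtol=1e-2)

The generated kernel is a flat elementwise launch, one thread per element with the grid sized as ceil(n_elements / 256), which is why the op sets has_profiler to False and requires no tuning.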