
Commit

Re-sync with internal repository (#793)
Co-authored-by: Facebook Community Bot <[email protected]>
facebook-github-bot committed Jun 24, 2023
1 parent 9336061 commit 57c5e03
Showing 5 changed files with 336 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/aitemplate/backend/cuda/tensor/__init__.py
@@ -18,6 +18,7 @@
from aitemplate.backend.cuda.tensor import (
    argmax,
    batch_gather,
    cast,
    concatenate,
    concatenate_tanh,
    dynamic_slice,
@@ -42,6 +43,7 @@
__all__ = [
    "argmax",
    "batch_gather",
    "cast",
    "concatenate",
    "concatenate_tanh",
    "dynamic_slice",
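Context for this two-line change: the backend modules listed here are imported for their side effects. Importing `cast` executes the `@registry.reg` decorators in the new `cast.py` below, which is what makes the codegen functions resolvable by key at compile time. A minimal sketch of that lookup (not part of the diff; assumes an AITemplate install with this commit applied):

```python
# Sketch: the registry lookup that the import list above enables.
from aitemplate.backend import registry

# Importing the package runs cast.py's @registry.reg decorators.
import aitemplate.backend.cuda.tensor  # noqa: F401

# The key matches the one gen_function() builds: "<target>.<op>.gen_function".
gen_fn = registry.get("cuda.cast.gen_function")
```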
155 changes: 155 additions & 0 deletions python/aitemplate/backend/cuda/tensor/cast.py
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict

import jinja2

from aitemplate.backend import registry
from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.common.elementwise_common import gen_int_var_product_str

CUDA_HEADER_FILES = """
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
"""

CONSTANT_TEMPLATE = jinja2.Template(
    """
#define N_THREADS_PER_BLOCK 256
"""
)

FUNC_DECL_TEMPLATE = jinja2.Template(
    """
void invoke_{{func_name}}(
    void* y,
    const void* x,
    {{index_type}} n_elements,
    {{prefix}}Stream_t stream);
"""
)


FUNC_CALL_TEMPLATE = jinja2.Template(
    """
{{indent}}{
{{indent}}const {{index_type}} {{func_name}}_n_elements = {{calculate_n}};
{{indent}}invoke_{{func_name}}({{output}}, {{input}}, {{func_name}}_n_elements, stream);
{{indent}}}
"""
)


FUNC_TEMPLATE = jinja2.Template(
    """
{{header_files}}
namespace {
{{constant}}
__global__ void cast_op(
    {{output_type}}* output,
    const {{input_type}}* input,
    {{index_type}} n_elements
) {
    const {{index_type}} idx = (blockIdx.x * blockDim.x + threadIdx.x);
    if (idx >= n_elements) {
        return;
    }
    output[idx] = {{cast_func_call}}
}
} // namespace
void invoke_{{func_name}}(void* output, const void* input,
                          {{index_type}} n_elements, {{prefix}}Stream_t stream) {
    if (n_elements == 0) {
        return;
    }
    int grid_size = static_cast<int>(std::ceil(static_cast<double>(n_elements) / N_THREADS_PER_BLOCK));
    cast_op<<<grid_size, N_THREADS_PER_BLOCK, 0, stream>>>(
        reinterpret_cast<{{output_type}}*>(output),
        reinterpret_cast<const {{input_type}}*>(input),
        n_elements
    );
}
"""
)

# Maps input dtype -> output dtype -> CUDA cast expression. Each entry
# already carries its trailing ';', which FUNC_TEMPLATE relies on.
CAST_FUNCS = {
    "half": {
        "bfloat16": "__float2bfloat16_rn(__half2float(input[idx]));",
        "float": "__half2float(input[idx]);",
    },
    "bfloat16": {
        "half": "__float2half_rn(__bfloat162float(input[idx]));",
        "float": "__bfloat162float(input[idx]);",
    },
    "float": {
        "bfloat16": "__float2bfloat16_rn(input[idx]);",
        "half": "__float2half_rn(input[idx]);",
    },
}


@registry.reg("cuda.cast.gen_function")
def gen_function(func_attrs: Dict[str, Any]) -> str:
    input_ = func_attrs["inputs"][0]
    output = func_attrs["outputs"][0]
    backend_spec = CUDASpec()
    output_dtype = output.dtype()
    output_type = backend_spec.dtype_to_backend_type(output_dtype)
    input_type = backend_spec.dtype_to_backend_type(input_.dtype())
    cast_func_call = CAST_FUNCS[input_type][output_type]

    return FUNC_TEMPLATE.render(
        header_files=backend_spec.header_src_template.render(
            extra_header=CUDA_HEADER_FILES
        ),
        constant=CONSTANT_TEMPLATE.render(),
        func_name=func_attrs["name"],
        input_type=input_type,
        output_type=output_type,
        index_type=backend_spec.index_type,
        cast_func_call=cast_func_call,
        prefix=backend_spec.prefix,
    )


@registry.reg("cuda.cast.func_decl")
def gen_function_decl(func_attrs: Dict[str, Any]) -> str:
    backend_spec = CUDASpec()
    return FUNC_DECL_TEMPLATE.render(
        func_name=func_attrs["name"],
        prefix=backend_spec.prefix,
        index_type=backend_spec.index_type,
    )


@registry.reg("cuda.cast.func_call")
def gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
    backend_spec = CUDASpec()
    return FUNC_CALL_TEMPLATE.render(
        func_name=func_attrs["name"],
        output=func_attrs["outputs"][0]._attrs["name"],
        input=func_attrs["inputs"][0]._attrs["name"],
        calculate_n=gen_int_var_product_str(func_attrs["inputs"][0].shape()),
        index_type=backend_spec.index_type,
        indent=indent,
    )
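To preview what this codegen emits, the templates can be rendered standalone. The sketch below is not AITemplate's public API: the `int64_t` index type and `cuda` prefix are assumptions standing in for what `CUDASpec` supplies. Note the launch math in the generated launcher: for `n_elements = 1000`, `grid_size = ceil(1000 / 256) = 4` blocks of 256 threads.

```python
# Standalone preview of the generated CUDA source for a half -> float cast.
from aitemplate.backend.cuda.tensor.cast import (
    CAST_FUNCS,
    CUDA_HEADER_FILES,
    FUNC_TEMPLATE,
)

src = FUNC_TEMPLATE.render(
    header_files=CUDA_HEADER_FILES,  # normally backend_spec.header_src_template
    constant="#define N_THREADS_PER_BLOCK 256",
    func_name="cast_half_to_float",  # illustrative name
    input_type="half",
    output_type="float",
    index_type="int64_t",            # assumption: CUDASpec.index_type
    cast_func_call=CAST_FUNCS["half"]["float"],
    prefix="cuda",                   # assumption: CUDASpec.prefix
)
print(src)  # kernel cast_op plus launcher invoke_cast_half_to_float
```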
1 change: 1 addition & 0 deletions python/aitemplate/compiler/ops/tensor/__init__.py
@@ -18,6 +18,7 @@
"""
from aitemplate.compiler.ops.tensor.argmax import argmax
from aitemplate.compiler.ops.tensor.batch_gather import batch_gather
from aitemplate.compiler.ops.tensor.cast import cast
from aitemplate.compiler.ops.tensor.chunk import chunk
from aitemplate.compiler.ops.tensor.concatenate import concatenate
from aitemplate.compiler.ops.tensor.concatenate_tanh import concatenate_tanh
81 changes: 81 additions & 0 deletions python/aitemplate/compiler/ops/tensor/cast.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.dtype import normalize_dtype


class cast(Operator):
    """
    Returns the input tensor cast to the specified dtype. Only
    conversions among the float16, bfloat16, and float32 dtypes
    are supported.

    Args:
        x (Tensor): the source tensor
        dtype (str): the target dtype for the cast

    Returns:
        Tensor: a tensor with its values converted to the specified dtype.
    """

    def __init__(self) -> None:
        super().__init__()

        self._attrs["op"] = "cast"
        self._attrs["has_profiler"] = False

    def __call__(
        self,
        x: Tensor,
        dtype: str = "bfloat16",
    ) -> Tensor:
        x_dtype = normalize_dtype(x._attrs["dtype"])
        dtype = normalize_dtype(dtype)
        if x_dtype not in ("float16", "bfloat16", "float32"):
            raise TypeError(
                f"Expected dtype of x to be float16, bfloat16, or float32, but got {x_dtype}."
            )

        if dtype not in ("float16", "bfloat16", "float32"):
            raise TypeError(
                f"Expected cast dtype to be float16, bfloat16, or float32, but got {dtype}."
            )
        if dtype == x_dtype:
            return x

        self._attrs["inputs"] = [x]
        self._attrs["cast_dtype"] = dtype
        self._set_depth()

        output_shape = x._attrs["shape"]
        output = Tensor(
            output_shape,
            src_ops={self},
            dtype=dtype,
        )
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        target = backend.target.Target.current()
        func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
        func = registry.get(func_key)
        return func(self._attrs)
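End to end, the new op is used through the frontend like any other tensor op. A minimal sketch mirroring the unit test below; the `./tmp` workdir and model name are illustrative, and a CUDA target is assumed:

```python
from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

X = Tensor(shape=[8, 16], name="X", dtype="float32", is_input=True)
Y = ops.cast()(X, "bfloat16")  # returns X unchanged if dtypes already match
Y._attrs["name"] = "Y"
Y._attrs["is_output"] = True

# Compiles the graph for the detected (CUDA) target.
module = compile_model(Y, detect_target(), "./tmp", "cast_example")
```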
97 changes: 97 additions & 0 deletions tests/unittest/ops/test_cast.py
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
    get_random_torch_tensor,
    get_torch_empty_tensor,
)
from aitemplate.utils.torch_utils import string_to_torch_dtype
from parameterized import param, parameterized


class TestCast(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._test_id = 0

    def _test_cast(
        self,
        shape,
        dtype="float32",
        cast_dtype="bfloat16",
        test_name="cast",
    ) -> None:
        if not isinstance(shape, list):
            shape = [shape]

        X = Tensor(
            shape=shape,
            name="X",
            dtype=dtype,
            is_input=True,
        )

        Y = ops.cast()(X, cast_dtype)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = detect_target()
        module = compile_model(Y, target, "./tmp", f"{test_name}_{self._test_id}")
        self._test_id += 1

        x = get_random_torch_tensor(shape, dtype=dtype)
        y = get_torch_empty_tensor(shape, dtype=cast_dtype)
        inputs = {"X": x}
        outputs = {"Y": y}
        module.run_with_tensors(inputs, outputs)

        y_pt = x.to(string_to_torch_dtype(cast_dtype))
        torch.testing.assert_close(y, y_pt, atol=1e-2, rtol=1e-2)

    @parameterized.expand(
        [
            param(1, "float16", "bfloat16", [1], "float16_to_bfloat16"),
            param(2, "float16", "float32", [10, 20], "float16_to_float32"),
            param(3, "bfloat16", "float16", [10, 20, 30], "bfloat16_to_float16"),
            param(4, "bfloat16", "float32", 123, "bfloat16_to_float32"),
            param(5, "float32", "float16", [20, 30], "float32_to_float16"),
            param(6, "float32", "bfloat16", [1, 128], "float32_to_bfloat16"),
        ]
    )
    def test_cast(
        self,
        i,
        dtype,
        cast_dtype,
        shape,
        test_name,
    ):
        self._test_cast(
            shape=shape,
            dtype=dtype,
            cast_dtype=cast_dtype,
            test_name=test_name,
        )


if __name__ == "__main__":
    torch.manual_seed(0)
    unittest.main()
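For intuition on the `atol=rtol=1e-2` tolerance above: a float32 to bfloat16 cast keeps far fewer significand bits than float32, so exact equality would be too strict a bar. A PyTorch-only sketch (no AITemplate required):

```python
import torch

x = torch.randn(10, 20, dtype=torch.float32)
y_ref = x.to(torch.bfloat16)  # the reference path the test compares against

# Round-tripping exposes the precision loss bfloat16 introduces.
err = (x - y_ref.to(torch.float32)).abs().max()
print(f"max round-trip error: {err:.4f}")  # typically on the order of 1e-2 for randn inputs
```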
