Skip to content

Commit

Permalink
Add bfloat16 support to index_select op (facebookincubator#948)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#948

ATT

Reviewed By: kadeng

Differential Revision: D50120201

fbshipit-source-id: 8af266d4576d685770356ce81888b27c679cb4d8
  • Loading branch information
aakhundov authored and facebook-github-bot committed Oct 11, 2023
1 parent 8f1af39 commit a530636
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
21 changes: 12 additions & 9 deletions python/aitemplate/backend/cuda/tensor/index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,27 @@
import jinja2

from aitemplate.backend import registry

from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.cuda import cuda_common
from aitemplate.utils import shape_utils


# Header lines emitted verbatim at the top of the generated CUDA source for
# this op. cuda_bf16.h provides nv_bfloat16; the `using` alias at the end lets
# the generated kernel refer to the type as "bfloat16" uniformly with the
# other dtype names. Do not edit the string contents cosmetically — it is
# compiled as-is by nvcc.
header_files = """
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "cutlass/util/host_tensor.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include <cub/cub.cuh>
using bfloat16 = nv_bfloat16;
"""

FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void {{func_name}}(
{{input_type}}* /*output*/,
const {{input_type}}* /*input*/,
void* /*output*/,
const void* /*input*/,
const {{index_type}} /*dim_len*/,
const {{index_type}}* /*dim_idxs*/,
const {{index_type}} /*dim_idxs_len*/,
Expand Down Expand Up @@ -103,8 +104,8 @@
}
void {{func_name}}(
{{input_type}}* output,
const {{input_type}}* input,
void* output,
const void* input,
const {{index_type}} dim_len,
const {{index_type}}* dim_idxs,
const {{index_type}} dim_idxs_len,
Expand Down Expand Up @@ -176,8 +177,8 @@ def gen_function(func_attrs) -> str:
y = func_attrs["outputs"][0]
dim = func_attrs["dim"]

input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"])
output_type = cuda_common.dtype_to_cuda_type(y._attrs["dtype"])
input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"])
output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"])

if input_type != output_type:
raise TypeError("input type must equal to output type")
Expand Down Expand Up @@ -214,8 +215,10 @@ def gen_function_decl(func_attrs) -> str:
The function declaration string
"""
backend_spec = CUDASpec()

x = func_attrs["inputs"][0]
input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"])
input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"])

return FUNC_DECL_TEMPLATE.render(
func_name=func_attrs["name"],
input_type=input_type,
Expand Down
40 changes: 40 additions & 0 deletions tests/unittest/ops/test_index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,46 @@ def test_fp16(self, shape, benchmark=False):
benchmark=benchmark,
)

@parameterized.expand(
    [
        [(5, 4, 3, 2), False],
        # [(2, 6), False],
        # [(20, 6), False],
        # [(300, 80), False],
        # Uncomment to benchmark
        # [(5, 4, 3, 2), True],
        # [(2, 6), True],
        # [(20, 6), True],
        # [(300, 80), True],
        # [(1024, 128, 256), True],
        # [(1024, 1024, 100), True],
        # [(1, 1), True],
        # [(10, 1), True],
        # [(100, 1), True],
        # [(1000, 1), True],
        # [(10000, 1), True], #revisit
        # [(100000, 1), True],
        # [(1000000, 1), True],
        # [(10000000, 1), True],
        # [(100000000, 1), True],
        # [(10000, 10000), True],
        # [(10, 10, 10, 10, 10, 10, 10, 10), True],
    ]
)
def test_bf16(self, shape, benchmark=False):
    """Exercise index_select with bfloat16 inputs.

    For every axis of *shape*, selects 1, half, and all of that axis's
    entries (a half-length of 0 is clamped up to 1 so the op always has
    at least one index).
    """
    torch.manual_seed(1024)
    random.seed(1024)
    for axis in range(len(shape)):
        axis_len = shape[axis]
        for num_indices in (1, int(axis_len / 2), axis_len):
            self._test_index_select(
                shape=shape,
                dim_idx=axis,
                dim_idx_len=max(num_indices, 1),
                test_name="index_select_bf16",
                dtype="bfloat16",
                benchmark=benchmark,
            )


if __name__ == "__main__":
torch.manual_seed(1024)
Expand Down

0 comments on commit a530636

Please sign in to comment.