diff --git a/python/aitemplate/backend/cuda/tensor/index_select.py b/python/aitemplate/backend/cuda/tensor/index_select.py
index 6be6e0134..29c026d97 100644
--- a/python/aitemplate/backend/cuda/tensor/index_select.py
+++ b/python/aitemplate/backend/cuda/tensor/index_select.py
@@ -43,26 +43,27 @@
 import jinja2
 
 from aitemplate.backend import registry
-
 from aitemplate.backend.backend_spec import CUDASpec
-from aitemplate.backend.cuda import cuda_common
 from aitemplate.utils import shape_utils
 
 
 header_files = """
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include <cuda_runtime.h>
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/fast_math.h"
 #include <math_constants.h>
+
+using bfloat16 = nv_bfloat16;
 """
 
 FUNC_DECL_TEMPLATE = jinja2.Template(
     """
 void {{func_name}}(
-    {{input_type}}* /*output*/,
-    const {{input_type}}* /*input*/,
+    void* /*output*/,
+    const void* /*input*/,
     const {{index_type}} /*dim_len*/,
     const {{index_type}}* /*dim_idxs*/,
     const {{index_type}} /*dim_idxs_len*/,
@@ -103,8 +104,8 @@
 }
 
 void {{func_name}}(
-    {{input_type}}* output,
-    const {{input_type}}* input,
+    void* output,
+    const void* input,
     const {{index_type}} dim_len,
     const {{index_type}}* dim_idxs,
     const {{index_type}} dim_idxs_len,
@@ -176,8 +177,8 @@ def gen_function(func_attrs) -> str:
     y = func_attrs["outputs"][0]
     dim = func_attrs["dim"]
 
-    input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"])
-    output_type = cuda_common.dtype_to_cuda_type(y._attrs["dtype"])
+    input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"])
+    output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"])
 
     if input_type != output_type:
         raise TypeError("input type must equal to output type")
@@ -214,8 +215,10 @@ def gen_function_decl(func_attrs) -> str:
         The function declaration string
     """
     backend_spec = CUDASpec()
+
     x = func_attrs["inputs"][0]
-    input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"])
+    input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"])
+
     return FUNC_DECL_TEMPLATE.render(
         func_name=func_attrs["name"],
         input_type=input_type,
diff --git a/tests/unittest/ops/test_index_select.py b/tests/unittest/ops/test_index_select.py
index a219ab311..c404ab9af 100644
--- a/tests/unittest/ops/test_index_select.py
+++ b/tests/unittest/ops/test_index_select.py
@@ -280,6 +280,46 @@ def test_fp16(self, shape, benchmark=False):
             benchmark=benchmark,
         )
 
+    @parameterized.expand(
+        [
+            [(5, 4, 3, 2), False],
+            # [(2, 6), False],
+            # [(20, 6), False],
+            # [(300, 80), False],
+            # Uncomment to benchmark
+            # [(5, 4, 3, 2), True],
+            # [(2, 6), True],
+            # [(20, 6), True],
+            # [(300, 80), True],
+            # [(1024, 128, 256), True],
+            # [(1024, 1024, 100), True],
+            # [(1, 1), True],
+            # [(10, 1), True],
+            # [(100, 1), True],
+            # [(1000, 1), True],
+            # [(10000, 1), True],  # revisit
+            # [(100000, 1), True],
+            # [(1000000, 1), True],
+            # [(10000000, 1), True],
+            # [(100000000, 1), True],
+            # [(10000, 10000), True],
+            # [(10, 10, 10, 10, 10, 10, 10, 10), True],
+        ]
+    )
+    def test_bf16(self, shape, benchmark=False):
+        torch.manual_seed(1024)
+        random.seed(1024)
+        for idx, _ in enumerate(shape):
+            for dim_idx_len in [1, int(shape[idx] / 2), shape[idx]]:
+                self._test_index_select(
+                    shape=shape,
+                    dim_idx=idx,
+                    dim_idx_len=dim_idx_len if dim_idx_len > 0 else 1,
+                    test_name="index_select_bf16",
+                    dtype="bfloat16",
+                    benchmark=benchmark,
+                )
+
 
 if __name__ == "__main__":
     torch.manual_seed(1024)
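
Note on the mechanics (illustrative context, not part of the patch): `BackendSpec.dtype_to_backend_type` maps AITemplate dtype strings to the C++ type names emitted into generated CUDA source, and the new `using bfloat16 = nv_bfloat16;` alias is what makes the "bfloat16" mapping compile; switching the host-facing signature to `void*` keeps the declaration dtype-agnostic so one wrapper serves both fp16 and bf16. A minimal sketch of that mapping, assuming an AITemplate installation (the printed values reflect my reading of the default CUDASpec table, not output verified against this patch):

    from aitemplate.backend.backend_spec import CUDASpec

    spec = CUDASpec()
    # dtype string -> C++ type name used inside generated kernels
    print(spec.dtype_to_backend_type("float16"))   # expected: "half"
    print(spec.dtype_to_backend_type("bfloat16"))  # expected: "bfloat16" (the alias above)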