diff --git a/python/aitemplate/backend/cuda/tensor/masked_select.py b/python/aitemplate/backend/cuda/tensor/masked_select.py index abf053431..14effddc0 100644 --- a/python/aitemplate/backend/cuda/tensor/masked_select.py +++ b/python/aitemplate/backend/cuda/tensor/masked_select.py @@ -15,13 +15,20 @@ """ Define masked_select codegen and CUDA kernel """ -import jinja2 +from typing import List +import jinja2 from aitemplate.backend import registry from aitemplate.backend.backend_spec import CUDASpec -from aitemplate.backend.cuda import cuda_common +from aitemplate.backend.common.elementwise_common import ( + gen_dynamic_dim_str, + get_dynamic_dims, + get_stride_expressions, +) +from aitemplate.backend.cuda import cuda_common +from aitemplate.compiler.base import IntImm, IntVar header_files = """ #include @@ -36,6 +43,7 @@ {{input_type}}* /*output*/, const {{input_type}}* /*input*/, const bool* /*mask*/, + {% if need_broadcast %} {{dynamic_dims_decl}} {% endif %} {{index_type}} /*num_elems*/, {{index_type}}* /*output size*/, void* workspace /*workspace*/, @@ -60,10 +68,65 @@ } while (0) #endif // CUDA_CHECK_MASKED_SELECT +{% if need_broadcast_input or need_broadcast_mask %} +__global__ void expand_input_mask_kernel( + {% if need_broadcast_input %} + {{input_type}}* expanded_input, + const {{input_type}}* input, + {% endif %} + {% if need_broadcast_mask %} + bool* expanded_mask, + const bool* mask, + {% endif %} + {{dynamic_dims_decl}} + const {{index_type}} num_elems +) { + for(auto idx = blockIdx.x*blockDim.x + threadIdx.x; idx <= num_elems; idx+=gridDim.x*blockDim.x) { + + if (idx < num_elems) { + + {% if need_broadcast_input %} + {{index_type}} input_idx = 0; + {% endif %} + {% if need_broadcast_mask %} + {{index_type}} mask_idx = 0; + {% endif %} + {{index_type}} cur; + auto tmp = idx; + + {% for i in range(max_rank) %} + cur = tmp % {{max_dims[max_rank-i-1]}}; + tmp = tmp / {{max_dims[max_rank-i-1]}}; + {% if need_broadcast_input and (i < input_rank) %} + if ({{input_dims[input_rank-i-1]}} > 1) { + input_idx += cur * {{input_strides[input_rank-i-1]}}; + } + {% endif %} + {% if need_broadcast_mask and (i < mask_rank) %} + if ({{mask_dims[mask_rank-i-1]}} > 1) { + mask_idx += cur * {{mask_strides[mask_rank-i-1]}}; + } + {% endif %} + {% endfor %} + + {% if need_broadcast_input %} + expanded_input[idx] = input[input_idx]; + {% endif %} + {% if need_broadcast_mask %} + expanded_mask[idx] = mask[mask_idx]; + {% endif %} + } + } +} +{% endif %} + void {{func_name}}( {{input_type}}* output, const {{input_type}}* input, const bool* mask, + {% if need_broadcast_input or need_broadcast_mask %} + {{dynamic_dims_decl}} + {% endif %} {{index_type}} num_elems, {{index_type}}* num_nonmasked, void* workspace, @@ -84,42 +147,113 @@ throw std::runtime_error("workspace is NULL!"); } size_t allocated_storage = {{workspace_size}}; + constexpr size_t INPUT_TYPE_SIZE = sizeof({{input_type}}); + constexpr size_t BOOL_SIZE = sizeof(bool); + constexpr size_t INDEX_TYPE_SIZE = sizeof({{index_type}}); - // Keep the number of nonmasked elements at the beginning of the workspace - const size_t NUM_NONMASKED_SIZE = sizeof({{index_type}}); - {{index_type}}* num_nonmasked_device = static_cast<{{index_type}}*>(workspace); + {{index_type}} workspace_offset = 0; + {{index_type}}* num_nonmasked_device = static_cast<{{index_type}}*>(workspace+workspace_offset); + workspace_offset += INDEX_TYPE_SIZE; + {% if need_broadcast_input %} + {{input_type}}* expanded_input = static_cast<{{input_type}}*>(workspace+workspace_offset); + workspace_offset += INPUT_TYPE_SIZE * num_elems; + {% endif %} + {% if need_broadcast_mask %} + bool* expanded_mask = static_cast(workspace+workspace_offset); + workspace_offset += BOOL_SIZE * num_elems; + {% endif %} // Get needed temporary storage size and reallocate if necessary void* d_temp_storage = nullptr; size_t temp_storage_bytes = 0; - CUDA_CHECK_MASKED_SELECT(cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, input, mask, output, num_nonmasked_device, num_elems, stream), - "Error when checking the required buffer size!"); - CUDA_CHECK_MASKED_SELECT(cudaStreamSynchronize(stream), "Error when synchronizing the stream!"); + CUDA_CHECK_MASKED_SELECT( + cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, + {% if need_broadcast_input %} + expanded_input, + {% else %} + input, + {% endif %} + {% if need_broadcast_mask %} + expanded_mask, + {% else %} + mask, + {% endif %} + output, num_nonmasked_device, num_elems, stream), + "Error when checking the required buffer size!" + ); + CUDA_CHECK_MASKED_SELECT( + cudaStreamSynchronize(stream), + "Error when synchronizing the stream!" + ); - if (allocated_storage < temp_storage_bytes + NUM_NONMASKED_SIZE) { - auto msg = "Got pre-allocated buffer of size " + std::to_string(allocated_storage) + ", but need " + std::to_string(temp_storage_bytes) - + ". Allocating a new buffer, expect performance degradation."; + if (allocated_storage < temp_storage_bytes + workspace_offset) { + auto msg = "Got pre-allocated buffer of size " + std::to_string(allocated_storage) + + ", but need " + std::to_string(temp_storage_bytes+workspace_offset) + + ". Allocating a new buffer, expect performance degradation."; std::cerr << msg << std::endl; // Allocate temporary storage - temp_storage_bytes += NUM_NONMASKED_SIZE; - CUDA_CHECK_MASKED_SELECT(cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream), "Error when trying to allocate a new buffer!"); - CUDA_CHECK_MASKED_SELECT(cudaStreamSynchronize(stream), "Error when synchronizing the stream!"); + temp_storage_bytes += workspace_offset; + CUDA_CHECK_MASKED_SELECT( + cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream), + "Error when trying to allocate a new buffer!" + ); + CUDA_CHECK_MASKED_SELECT( + cudaStreamSynchronize(stream), + "Error when synchronizing the stream!" + ); workspace = d_temp_storage; allocated_storage = temp_storage_bytes; } - allocated_storage -= NUM_NONMASKED_SIZE; // First NUM_NONMASKED_SIZE bytes are reserved + allocated_storage -= workspace_offset; - // Select nonmasked elements. First NUM_NONMASKED_SIZE bytes of workspace are reserved for num_nonmasked_device - CUDA_CHECK_MASKED_SELECT(cub::DeviceSelect::Flagged(workspace + NUM_NONMASKED_SIZE, allocated_storage, input, mask, output, - num_nonmasked_device, num_elems, stream), "Error when selecting nonmasked elements!"); + {% if need_broadcast_input or need_broadcast_mask %} + const {{index_type}} THREADS_PER_BLOCK = 256; + const {{index_type}} ELEMS_PER_THREAD = 128; + auto blocks = (num_elems + THREADS_PER_BLOCK * ELEMS_PER_THREAD) / (THREADS_PER_BLOCK * ELEMS_PER_THREAD); + expand_input_mask_kernel<<>>( + {% if need_broadcast_input %} + expanded_input, + input, + {% endif %} + {% if need_broadcast_mask %} + expanded_mask, + mask, + {% endif %} + {{dynamic_dims_call}} num_elems); + {% endif %} + + // Select nonmasked elements + CUDA_CHECK_MASKED_SELECT( + cub::DeviceSelect::Flagged(workspace+workspace_offset, allocated_storage, + {% if need_broadcast_input %} + expanded_input, + {% else %} + input, + {% endif %} + {% if need_broadcast_mask %} + expanded_mask, + {% else %} + mask, + {% endif %} + output, num_nonmasked_device, num_elems, stream), + "Error when selecting nonmasked elements!" + ); // Extract number of nonmasked elements (size of the output) - CUDA_CHECK_MASKED_SELECT(cudaMemcpyAsync(num_nonmasked, num_nonmasked_device, NUM_NONMASKED_SIZE, cudaMemcpyDeviceToHost, stream), - "Error when copying the number of nonmasked elements from device to host!"); - CUDA_CHECK_MASKED_SELECT(cudaStreamSynchronize(stream), "Error when synchronizing the stream!"); + CUDA_CHECK_MASKED_SELECT( + cudaMemcpyAsync(num_nonmasked, num_nonmasked_device, INDEX_TYPE_SIZE, cudaMemcpyDeviceToHost, stream), + "Error when copying the number of nonmasked elements from device to host!" + ); + CUDA_CHECK_MASKED_SELECT( + cudaStreamSynchronize(stream), + "Error when synchronizing the stream!" + ); if (d_temp_storage != nullptr) { - CUDA_CHECK_MASKED_SELECT(cudaFreeAsync(d_temp_storage, stream), "Error when freeing GPU memory allocated by masked_select!"); + CUDA_CHECK_MASKED_SELECT( + cudaFreeAsync(d_temp_storage, stream), + "Error when freeing GPU memory allocated by masked_select!" + ); } } """ @@ -129,16 +263,16 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{ -{{indent}} -{{indent}} const {{index_type}} input_dims[] = {{input_dims}}; +{{indent}} const {{index_type}} max_dims[] = {{max_dims}}; {{indent}} int64_t num_elems = 1; -{{indent}} for ({{index_type}} i = 0; i < {{rank}}; i++) { -{{indent}} num_elems *= input_dims[i]; +{{indent}} for ({{index_type}} i = 0; i < {{max_rank}}; i++) { +{{indent}} num_elems *= max_dims[i]; {{indent}} } {{indent}} {{func_name}}( {{indent}} {{output_ptr}}, {{indent}} {{input_ptr}}, {{indent}} {{mask_ptr}}, +{{indent}} {% if need_broadcast %} {{dynamic_dims_call}} {% endif %} {{indent}} num_elems, {{indent}} {{num_nonmasked}}, {{indent}} global_workspace_, @@ -149,6 +283,13 @@ ) +def _get_dims(shape: List[IntVar]) -> List[str]: + return [ + str(dim.value()) if isinstance(dim, IntImm) else dim._attrs["name"] + for dim in shape + ] + + @registry.reg("cuda.masked_select.gen_function") def gen_function(func_attrs) -> str: """ @@ -160,21 +301,44 @@ def gen_function(func_attrs) -> str: The function body string """ backend_spec = CUDASpec() - x = func_attrs["inputs"][0] - y = func_attrs["outputs"][0] + x, mask = func_attrs["inputs"] + output = func_attrs["outputs"][0] + max_shape = func_attrs["max_shape"] input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"]) - output_type = cuda_common.dtype_to_cuda_type(y._attrs["dtype"]) + output_type = cuda_common.dtype_to_cuda_type(output._attrs["dtype"]) if input_type != output_type: raise TypeError("input type must equal to output type") + dynamic_dims = get_dynamic_dims(x.shape(), mask.shape()) + return SRC_TEMPLATE.render( input_type=input_type, index_type=backend_spec.index_type, func_name=func_attrs["name"], header_files=header_files, workspace_size=func_attrs["workspace"], + input_dims=_get_dims(x.shape()), + input_rank=len(x.shape()), + input_strides=get_stride_expressions(x.shape()) + ["1"], + need_broadcast_input=x._attrs["shape"] != max_shape, + mask_dims=_get_dims(mask.shape()), + mask_rank=len(mask.shape()), + mask_strides=get_stride_expressions(mask.shape()) + ["1"], + need_broadcast_mask=mask._attrs["shape"] != max_shape, + max_dims=_get_dims(max_shape), + max_rank=len(max_shape), + dynamic_dims_decl=gen_dynamic_dim_str( + index_type=backend_spec.index_type, + dynamic_dims=dynamic_dims, + has_type=True, + ), + dynamic_dims_call=gen_dynamic_dim_str( + index_type=backend_spec.index_type, + dynamic_dims=dynamic_dims, + has_type=False, + ), ) @@ -189,12 +353,19 @@ def gen_function_decl(func_attrs) -> str: The function declaration string """ backend_spec = CUDASpec() - x = func_attrs["inputs"][0] + x, mask = func_attrs["inputs"] input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( func_name=func_attrs["name"], input_type=input_type, index_type=backend_spec.index_type, + need_broadcast=x._attrs["shape"] != mask._attrs["shape"], + dynamic_dims_decl=gen_dynamic_dim_str( + index_type=backend_spec.index_type, + dynamic_dims=get_dynamic_dims(x.shape(), mask.shape()), + has_type=True, + ), ) @@ -226,16 +397,23 @@ def gen_function_call(func_attrs, indent=" ") -> str: name=mask._attrs["name"], dtype="bool", ) + max_shape = func_attrs["max_shape"] # Number of nonmasked elements, i.e. size of the output num_nonmasked_ptr = "&" + y._attrs["shape"][0]._attrs["name"] - input_dims = "{" + ",".join([dim._attrs["name"] for dim in x._attrs["shape"]]) + "}" + return FUNC_CALL_TEMPLATE.render( indent=indent, func_name=func_attrs["name"], input_name=x._attrs["name"], num_nonmasked=num_nonmasked_ptr, - input_dims=input_dims, - rank=len(x._attrs["shape"]), + max_dims="{" + ",".join([dim._attrs["name"] for dim in max_shape]) + "}", + max_rank=len(max_shape), + need_broadcast=x._attrs["shape"] != mask._attrs["shape"], + dynamic_dims_call=gen_dynamic_dim_str( + index_type=backend_spec.index_type, + dynamic_dims=get_dynamic_dims(x.shape(), mask.shape()), + has_type=False, + ), output_ptr=output_ptr, input_ptr=input_ptr, mask_ptr=mask_ptr, diff --git a/python/aitemplate/compiler/ops/tensor/masked_select.py b/python/aitemplate/compiler/ops/tensor/masked_select.py index bc994f4fb..4eac1764b 100644 --- a/python/aitemplate/compiler/ops/tensor/masked_select.py +++ b/python/aitemplate/compiler/ops/tensor/masked_select.py @@ -24,6 +24,10 @@ from aitemplate.compiler.base import IntVar, Operator, Tensor +from aitemplate.compiler.dtype import get_dtype_size + +from aitemplate.utils import shape_utils + _LOGGER = logging.getLogger(__name__) @@ -35,13 +39,14 @@ class masked_select(Operator): Args: input (Tensor): the source tensor. - mask (Tensor, boolean): has to be of same shape as input. + mask (Tensor, boolean): the shapes of the mask tensor and the input tensor do not need + to match, but they must be broadcastable. Returns: - output: 1D tensor of length equal to the total number of elements in `input`. The result - is contained in the first `num_nonmasked` elements of output. The rest of the output - tensor is not meaningful. - num_nonmasked: number of the non-masked elements in the input, i.e. the length of the + output: 1D tensor of length equal to the total number of elements in broadcast shape + deduced from input and mask. The result is contained in the first `num_nonmasked` + elements of output. The rest of the output tensor is not meaningful. + num_nonmasked: number of the non-masked elements from the input, i.e. the length of the significant part of output. """ @@ -49,20 +54,39 @@ def __init__(self): super().__init__() self._attrs["op"] = "masked_select" self._attrs["workspace"] = 0 + self._attrs["max_shape"] = None def _infer_shape(self, x: Tensor, mask: Tensor) -> List[IntVar]: input_shape = x._attrs["shape"] mask_shape = mask._attrs["shape"] - if input_shape != mask_shape: + broadcastable, max_shape = shape_utils.get_broadcast_max_shape( + input_shape, mask_shape + ) + if not broadcastable: raise RuntimeError( - "Tensor shapes of input and mask are not equal! Shape1: {}, shape2: {}".format( + "Tensor shapes of input and mask are not broadcastable! Shape1: {}, shape2: {}".format( input_shape, mask_shape ) ) - + self._attrs["max_shape"] = max_shape numel = 1 - for dim in input_shape: + for dim in max_shape: numel *= dim.upper_bound() + + # Allocate temporary buffer. This empirical formula for size is deduced by looking at necessary + # memory to expand input/mask and the buffer sizes requested by cub::DeviceSelect::Flagged for + # different input sizes. + input_workspace_size = (input_shape != max_shape) * get_dtype_size( + x._attrs["dtype"] + ) + mask_workspace_size = (mask_shape != max_shape) * 1 # bool + self._attrs["workspace"] = ( + numel * (input_workspace_size + mask_workspace_size) + numel // 128 + 1024 + ) + _LOGGER.debug( + f'Allocating {self._attrs["workspace"]} bytes for temporary buffer of masked_select op' + ) + # Output size can range from 0 (when all mask elements are False) to the total number of # elements in the input (when all mask elements are True). return [IntVar(values=(0, numel))] @@ -81,13 +105,6 @@ def __call__( output = Tensor(output_shape, src_ops={self}, dtype=x._attrs["dtype"]) self._attrs["outputs"] = [output] - # Allocate temporary buffer. This empirical formula for size is deduced by looking at buffer sizes - # requested by cub::DeviceSelect::Flagged for differen input sizes. Required buffer size depends on - # the number of input elements and on the GPU architecture, but not on the input data type. - self._attrs["workspace"] = output_shape[0].upper_bound() // 128 + 1024 - _LOGGER.debug( - f'Allocating {self._attrs["workspace"]} bytes for temporary buffer of masked_select op' - ) return output def gen_function(self) -> str: diff --git a/tests/unittest/ops/test_masked_select.py b/tests/unittest/ops/test_masked_select.py index 22192328a..1e2fe0f71 100644 --- a/tests/unittest/ops/test_masked_select.py +++ b/tests/unittest/ops/test_masked_select.py @@ -30,10 +30,12 @@ detect_target().name() == "rocm", "masked_select is not implemented for ROCm" ) class maskedSelectTestCase(unittest.TestCase): - def _test_masked_select( + def _test_broadcastable_dynamic_masked_select( self, - batch_size=1, - shape=(2, 6), + input_dynamic_shape=(2, 6), + mask_dynamic_shape=(2, 6), + input_shape=(2, 6), + mask_shape=(2, 6), test_name="masked_select", copy_op=False, dtype="float16", @@ -41,13 +43,13 @@ def _test_masked_select( benchmark=False, ): X1 = Tensor( - shape=shape, + shape=input_dynamic_shape, dtype=dtype, name="x", is_input=True, ) X2 = Tensor( - shape=shape, + shape=mask_dynamic_shape, dtype="bool", name="mask", is_input=True, @@ -62,13 +64,18 @@ def _test_masked_select( target = detect_target() module = compile_model([X4], target, "./tmp", test_name) - x = get_random_torch_tensor(shape, dtype=dtype) + + x = get_random_torch_tensor(input_shape, dtype=dtype) if zero_mask: - mask = torch.zeros_like(x) + mask = torch.zeros(mask_shape) else: - mask = get_random_torch_tensor(shape, dtype="float16") > 0 + mask = get_random_torch_tensor(mask_shape, dtype="float16") > 0 y_pt = torch.masked_select(x, mask) - y = torch.empty((x.numel(),), dtype=x.dtype, device=x.device) + y = torch.empty( + (torch.broadcast_shapes(input_shape, mask_shape).numel(),), + dtype=x.dtype, + device=x.device, + ) y_ait = module.run_with_tensors([x, mask], [y])["output_values"] # y_ait contains the correct result. It points to the same memory blob as y, but has the correct shape self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-10, rtol=0)) @@ -76,7 +83,9 @@ def _test_masked_select( self.assertTrue(torch.allclose(y_pt, y[: y_ait.shape[0]], atol=1e-10, rtol=0)) if benchmark: - print(f"Benchmarking with shape={shape}, dtype={dtype}") + print( + f"Benchmarking with input_shape={input_shape}, mask_shape={mask_shape}, dtype={dtype}" + ) # Warm up. for _ in range(5): module.run_with_tensors([x, mask], [y]) @@ -101,6 +110,47 @@ def _test_masked_select( print(f"Speedup: {torch_time_per_iter_ms / time_per_iter_ms:.2f}x") + def _test_broadcastable_masked_select( + self, + input_shape=(2, 6), + mask_shape=(2, 6), + test_name="masked_select", + copy_op=False, + dtype="float16", + zero_mask=False, + benchmark=False, + ): + self._test_broadcastable_dynamic_masked_select( + input_dynamic_shape=input_shape, + mask_dynamic_shape=mask_shape, + input_shape=input_shape, + mask_shape=mask_shape, + test_name=test_name, + copy_op=copy_op, + dtype=dtype, + zero_mask=zero_mask, + benchmark=benchmark, + ) + + def _test_masked_select( + self, + shape=(2, 6), + test_name="masked_select", + copy_op=False, + dtype="float16", + zero_mask=False, + benchmark=False, + ): + self._test_broadcastable_masked_select( + input_shape=shape, + mask_shape=shape, + test_name=test_name, + copy_op=copy_op, + dtype=dtype, + zero_mask=zero_mask, + benchmark=benchmark, + ) + @parameterized.expand( [ [(2, 6), False], @@ -180,7 +230,6 @@ def test_fp32(self, shape, benchmark): def test_input_dynamic_shape( self, - batch_size=1, shape=(2, 6), test_name="masked_select_dynamic", dtype="float16", @@ -190,35 +239,15 @@ def test_input_dynamic_shape( Check that dynamic input shape is handled correctly. """ dyn_shape = (IntVar(values=(1, 10)), IntVar(values=(1, 10))) - X1 = Tensor( - shape=dyn_shape, + self._test_broadcastable_dynamic_masked_select( + input_dynamic_shape=dyn_shape, + mask_dynamic_shape=dyn_shape, + input_shape=shape, + mask_shape=shape, + test_name=test_name, dtype=dtype, - name="x", - is_input=True, - ) - X2 = Tensor( - shape=dyn_shape, - dtype="bool", - name="mask", - is_input=True, + benchmark=benchmark, ) - X4_op = ops.masked_select() - X4 = X4_op(X1, X2) - X4._attrs["is_output"] = True - X4._attrs["name"] = "output_values" - - target = detect_target() - module = compile_model([X4], target, "./tmp", test_name) - - x = get_random_torch_tensor(shape, dtype=dtype) - mask = get_random_torch_tensor(shape, dtype="float16") > 0 - y_pt = torch.masked_select(x, mask) - y = torch.empty((x.numel(),), dtype=x.dtype, device=x.device) - y_ait = module.run_with_tensors([x, mask], [y])["output_values"] - # y_ait contains the correct result. It points to the same memory blob as y, but has the correct shape - self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-10, rtol=0)) - # y retained the original shape (x.numel(),), so needs to be cut before comparison - self.assertTrue(torch.allclose(y_pt, y[: y_ait.shape[0]], atol=1e-10, rtol=0)) def test_empty_output(self, shape=(2, 6)): """ @@ -229,6 +258,154 @@ def test_empty_output(self, shape=(2, 6)): test_name="masked_select_zero_mask", ) + @parameterized.expand( + [ + [(32, 16), (1, 16), False], + [(32, 1), (64, 1, 16), False], + [(64, 32, 64), (64,), False], + [(128, 256), (1024, 128, 256), False], + [(10, 1, 1, 256), (10, 10, 10, 10, 128, 256), False], + # Uncomment to benchmark + # [(32, 16), (1, 16), True], + # [(32, 1), (64, 1, 16), True], + # [(64, 32, 64), (64,), True], + # [(64,), (64, 32, 64), True], + # [(128, 256), (1024, 1, 256), True], + # [(1024, 1, 256), (128, 256), True], + # [(128, 256), (1024, 128, 256), True], + # [(1024, 128, 256), (128, 256), True], + # [(1024, 1, 256), (1024, 128, 256), True], + # [(1024, 128, 256), (1024, 1, 256), True], + # [(10, 1, 1, 256), (10, 10, 10, 10, 128, 256), True], + # [(10, 10, 10, 10, 128, 256), (10, 1, 1, 256), True], + # [(10, 10, 1, 1), (10, 10, 10, 1, 1, 10, 256), True], + # [(10, 10, 10, 1, 1, 10, 256), (10, 10, 1, 1), True], + ] + ) + def test_fp16_input_broadcast_shape( + self, + input_shape, + mask_shape, + benchmark=False, + test_name="masked_select_broadcast_fp16", + dtype="float16", + ): + """ + Check the support for broadcastable input and mask. + """ + self._test_broadcastable_masked_select( + input_shape=input_shape, + mask_shape=mask_shape, + test_name=test_name, + dtype=dtype, + benchmark=benchmark, + ) + + @parameterized.expand( + [ + [(32, 16), (1, 16), False], + [(32, 1), (64, 1, 16), False], + [(64, 32, 64), (64,), False], + [(128, 256), (1024, 128, 256), False], + [(10, 1, 1, 256), (10, 10, 10, 10, 128, 256), False], + # Uncomment to benchmark + # [(32, 16), (1, 16), True], + # [(32, 1), (64, 1, 16), True], + # [(64, 32, 64), (64,), True], + # [(64,), (64, 32, 64), True], + # [(128, 256), (1024, 1, 256), True], + # [(1024, 1, 256), (128, 256), True], + # [(128, 256), (1024, 128, 256), True], + # [(1024, 128, 256), (128, 256), True], + # [(1024, 1, 256), (1024, 128, 256), True], + # [(1024, 128, 256), (1024, 1, 256), True], + # [(10, 1, 1, 256), (10, 10, 10, 10, 128, 256), True], + # [(10, 10, 10, 10, 128, 256), (10, 1, 1, 256), True], + # [(10, 10, 1, 1), (10, 10, 10, 1, 1, 10, 256), True], + # [(10, 10, 10, 1, 1, 10, 256), (10, 10, 1, 1), True], + ] + ) + def test_fp32_input_broadcast_shape( + self, + input_shape, + mask_shape, + benchmark=False, + test_name="masked_select_broadcast_fp32", + dtype="float32", + ): + """ + Check the support for broadcastable input and mask. + """ + self._test_broadcastable_masked_select( + input_shape=input_shape, + mask_shape=mask_shape, + test_name=test_name, + dtype=dtype, + benchmark=benchmark, + ) + + @parameterized.expand( + [ + [ + {0, 1}, + {1}, + {0: IntVar(values=(1, 10)), 1: IntVar(values=(1, 10))}, + (6, 8), + (8,), + ], + [{0}, {0}, {0: IntVar(values=(1, 10))}, (7, 6, 8), (7, 1, 8)], + [{0}, set(), {0: IntVar(values=(1, 10))}, (7, 6, 8), (1, 1, 8)], + [ + {0}, + {2}, + {0: IntVar(values=(1, 10)), 2: IntVar(values=(1, 10))}, + (7, 6, 1), + (1, 6, 9), + ], + ] + ) + def test_input_broadcast_dynamic_shape( + self, + input_dynamic_dim_idx, + mask_dynamic_dim_idx, + dynamic_dim_dict, + input_shape, + mask_shape, + test_name="masked_select_broadcast_dynamic_shape", + dtype="float16", + benchmark=False, + ): + """ + Check that broadcast dynamic shape is handled correctly. + """ + + input_rank = len(input_shape) + mask_rank = len(mask_shape) + max_rank = max(input_rank, mask_rank) + input_dynamic_shape = [] + mask_dynamic_shape = [] + for idx in range(max_rank): + if idx >= max_rank - input_rank: + if idx in input_dynamic_dim_idx and idx in dynamic_dim_dict: + input_dynamic_shape.append(dynamic_dim_dict[idx]) + else: + input_dynamic_shape.append(input_shape[idx - max_rank + input_rank]) + if idx >= max_rank - mask_rank: + if idx in mask_dynamic_dim_idx and idx in dynamic_dim_dict: + mask_dynamic_shape.append(dynamic_dim_dict[idx]) + else: + mask_dynamic_shape.append(mask_shape[idx - max_rank + mask_rank]) + + self._test_broadcastable_dynamic_masked_select( + input_dynamic_shape=input_dynamic_shape, + mask_dynamic_shape=mask_dynamic_shape, + input_shape=input_shape, + mask_shape=mask_shape, + test_name=test_name, + dtype=dtype, + benchmark=benchmark, + ) + if __name__ == "__main__": torch.manual_seed(1024)