Skip to content

Commit

Permalink
Support broadcast for masked_select (#892)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #892

Similar to pytorch version, add broadcast support for masked_select.

- For no broadcast case, it will follow the previous logic.
- Check whether we need to broadcast input or mask, only generate the necessary code and allocate needed memory.
- Tried different solutions, I feel it's more memory efficient and faster to expand input or mask in device memory. (Better than the 3 kernels version we discussed before)
- Refactored a bit the test.

Reviewed By: chenyang78

Differential Revision: D48054898

fbshipit-source-id: e3e76e56389e1fe38f052408dcefda0dfe49e056
  • Loading branch information
silverguo authored and facebook-github-bot committed Aug 14, 2023
1 parent 664b25d commit 19f07d3
Show file tree
Hide file tree
Showing 3 changed files with 459 additions and 87 deletions.
244 changes: 211 additions & 33 deletions python/aitemplate/backend/cuda/tensor/masked_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@
"""
Define masked_select codegen and CUDA kernel
"""
import jinja2
from typing import List

import jinja2
from aitemplate.backend import registry

from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.cuda import cuda_common

from aitemplate.backend.common.elementwise_common import (
gen_dynamic_dim_str,
get_dynamic_dims,
get_stride_expressions,
)
from aitemplate.backend.cuda import cuda_common
from aitemplate.compiler.base import IntImm, IntVar

header_files = """
#include <cuda_fp16.h>
Expand All @@ -36,6 +43,7 @@
{{input_type}}* /*output*/,
const {{input_type}}* /*input*/,
const bool* /*mask*/,
{% if need_broadcast %} {{dynamic_dims_decl}} {% endif %}
{{index_type}} /*num_elems*/,
{{index_type}}* /*output size*/,
void* workspace /*workspace*/,
Expand All @@ -60,10 +68,65 @@
} while (0)
#endif // CUDA_CHECK_MASKED_SELECT
{% if need_broadcast_input or need_broadcast_mask %}
__global__ void expand_input_mask_kernel(
{% if need_broadcast_input %}
{{input_type}}* expanded_input,
const {{input_type}}* input,
{% endif %}
{% if need_broadcast_mask %}
bool* expanded_mask,
const bool* mask,
{% endif %}
{{dynamic_dims_decl}}
const {{index_type}} num_elems
) {
for(auto idx = blockIdx.x*blockDim.x + threadIdx.x; idx <= num_elems; idx+=gridDim.x*blockDim.x) {
if (idx < num_elems) {
{% if need_broadcast_input %}
{{index_type}} input_idx = 0;
{% endif %}
{% if need_broadcast_mask %}
{{index_type}} mask_idx = 0;
{% endif %}
{{index_type}} cur;
auto tmp = idx;
{% for i in range(max_rank) %}
cur = tmp % {{max_dims[max_rank-i-1]}};
tmp = tmp / {{max_dims[max_rank-i-1]}};
{% if need_broadcast_input and (i < input_rank) %}
if ({{input_dims[input_rank-i-1]}} > 1) {
input_idx += cur * {{input_strides[input_rank-i-1]}};
}
{% endif %}
{% if need_broadcast_mask and (i < mask_rank) %}
if ({{mask_dims[mask_rank-i-1]}} > 1) {
mask_idx += cur * {{mask_strides[mask_rank-i-1]}};
}
{% endif %}
{% endfor %}
{% if need_broadcast_input %}
expanded_input[idx] = input[input_idx];
{% endif %}
{% if need_broadcast_mask %}
expanded_mask[idx] = mask[mask_idx];
{% endif %}
}
}
}
{% endif %}
void {{func_name}}(
{{input_type}}* output,
const {{input_type}}* input,
const bool* mask,
{% if need_broadcast_input or need_broadcast_mask %}
{{dynamic_dims_decl}}
{% endif %}
{{index_type}} num_elems,
{{index_type}}* num_nonmasked,
void* workspace,
Expand All @@ -84,42 +147,113 @@
throw std::runtime_error("workspace is NULL!");
}
size_t allocated_storage = {{workspace_size}};
constexpr size_t INPUT_TYPE_SIZE = sizeof({{input_type}});
constexpr size_t BOOL_SIZE = sizeof(bool);
constexpr size_t INDEX_TYPE_SIZE = sizeof({{index_type}});
// Keep the number of nonmasked elements at the beginning of the workspace
const size_t NUM_NONMASKED_SIZE = sizeof({{index_type}});
{{index_type}}* num_nonmasked_device = static_cast<{{index_type}}*>(workspace);
{{index_type}} workspace_offset = 0;
{{index_type}}* num_nonmasked_device = static_cast<{{index_type}}*>(workspace+workspace_offset);
workspace_offset += INDEX_TYPE_SIZE;
{% if need_broadcast_input %}
{{input_type}}* expanded_input = static_cast<{{input_type}}*>(workspace+workspace_offset);
workspace_offset += INPUT_TYPE_SIZE * num_elems;
{% endif %}
{% if need_broadcast_mask %}
bool* expanded_mask = static_cast<bool*>(workspace+workspace_offset);
workspace_offset += BOOL_SIZE * num_elems;
{% endif %}
// Get needed temporary storage size and reallocate if necessary
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
CUDA_CHECK_MASKED_SELECT(cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, input, mask, output, num_nonmasked_device, num_elems, stream),
"Error when checking the required buffer size!");
CUDA_CHECK_MASKED_SELECT(cudaStreamSynchronize(stream), "Error when synchronizing the stream!");
CUDA_CHECK_MASKED_SELECT(
cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes,
{% if need_broadcast_input %}
expanded_input,
{% else %}
input,
{% endif %}
{% if need_broadcast_mask %}
expanded_mask,
{% else %}
mask,
{% endif %}
output, num_nonmasked_device, num_elems, stream),
"Error when checking the required buffer size!"
);
CUDA_CHECK_MASKED_SELECT(
cudaStreamSynchronize(stream),
"Error when synchronizing the stream!"
);
if (allocated_storage < temp_storage_bytes + NUM_NONMASKED_SIZE) {
auto msg = "Got pre-allocated buffer of size " + std::to_string(allocated_storage) + ", but need " + std::to_string(temp_storage_bytes)
+ ". Allocating a new buffer, expect performance degradation.";
if (allocated_storage < temp_storage_bytes + workspace_offset) {
auto msg = "Got pre-allocated buffer of size " + std::to_string(allocated_storage)
+ ", but need " + std::to_string(temp_storage_bytes+workspace_offset)
+ ". Allocating a new buffer, expect performance degradation.";
std::cerr << msg << std::endl;
// Allocate temporary storage
temp_storage_bytes += NUM_NONMASKED_SIZE;
CUDA_CHECK_MASKED_SELECT(cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream), "Error when trying to allocate a new buffer!");
CUDA_CHECK_MASKED_SELECT(cudaStreamSynchronize(stream), "Error when synchronizing the stream!");
temp_storage_bytes += workspace_offset;
CUDA_CHECK_MASKED_SELECT(
cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream),
"Error when trying to allocate a new buffer!"
);
CUDA_CHECK_MASKED_SELECT(
cudaStreamSynchronize(stream),
"Error when synchronizing the stream!"
);
workspace = d_temp_storage;
allocated_storage = temp_storage_bytes;
}
allocated_storage -= NUM_NONMASKED_SIZE; // First NUM_NONMASKED_SIZE bytes are reserved
allocated_storage -= workspace_offset;
// Select nonmasked elements. First NUM_NONMASKED_SIZE bytes of workspace are reserved for num_nonmasked_device
CUDA_CHECK_MASKED_SELECT(cub::DeviceSelect::Flagged(workspace + NUM_NONMASKED_SIZE, allocated_storage, input, mask, output,
num_nonmasked_device, num_elems, stream), "Error when selecting nonmasked elements!");
{% if need_broadcast_input or need_broadcast_mask %}
const {{index_type}} THREADS_PER_BLOCK = 256;
const {{index_type}} ELEMS_PER_THREAD = 128;
auto blocks = (num_elems + THREADS_PER_BLOCK * ELEMS_PER_THREAD) / (THREADS_PER_BLOCK * ELEMS_PER_THREAD);
expand_input_mask_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
{% if need_broadcast_input %}
expanded_input,
input,
{% endif %}
{% if need_broadcast_mask %}
expanded_mask,
mask,
{% endif %}
{{dynamic_dims_call}} num_elems);
{% endif %}
// Select nonmasked elements
CUDA_CHECK_MASKED_SELECT(
cub::DeviceSelect::Flagged(workspace+workspace_offset, allocated_storage,
{% if need_broadcast_input %}
expanded_input,
{% else %}
input,
{% endif %}
{% if need_broadcast_mask %}
expanded_mask,
{% else %}
mask,
{% endif %}
output, num_nonmasked_device, num_elems, stream),
"Error when selecting nonmasked elements!"
);
// Extract number of nonmasked elements (size of the output)
CUDA_CHECK_MASKED_SELECT(cudaMemcpyAsync(num_nonmasked, num_nonmasked_device, NUM_NONMASKED_SIZE, cudaMemcpyDeviceToHost, stream),
"Error when copying the number of nonmasked elements from device to host!");
CUDA_CHECK_MASKED_SELECT(cudaStreamSynchronize(stream), "Error when synchronizing the stream!");
CUDA_CHECK_MASKED_SELECT(
cudaMemcpyAsync(num_nonmasked, num_nonmasked_device, INDEX_TYPE_SIZE, cudaMemcpyDeviceToHost, stream),
"Error when copying the number of nonmasked elements from device to host!"
);
CUDA_CHECK_MASKED_SELECT(
cudaStreamSynchronize(stream),
"Error when synchronizing the stream!"
);
if (d_temp_storage != nullptr) {
CUDA_CHECK_MASKED_SELECT(cudaFreeAsync(d_temp_storage, stream), "Error when freeing GPU memory allocated by masked_select!");
CUDA_CHECK_MASKED_SELECT(
cudaFreeAsync(d_temp_storage, stream),
"Error when freeing GPU memory allocated by masked_select!"
);
}
}
"""
Expand All @@ -129,16 +263,16 @@
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{
{{indent}}
{{indent}} const {{index_type}} input_dims[] = {{input_dims}};
{{indent}} const {{index_type}} max_dims[] = {{max_dims}};
{{indent}} int64_t num_elems = 1;
{{indent}} for ({{index_type}} i = 0; i < {{rank}}; i++) {
{{indent}} num_elems *= input_dims[i];
{{indent}} for ({{index_type}} i = 0; i < {{max_rank}}; i++) {
{{indent}} num_elems *= max_dims[i];
{{indent}} }
{{indent}} {{func_name}}(
{{indent}} {{output_ptr}},
{{indent}} {{input_ptr}},
{{indent}} {{mask_ptr}},
{{indent}} {% if need_broadcast %} {{dynamic_dims_call}} {% endif %}
{{indent}} num_elems,
{{indent}} {{num_nonmasked}},
{{indent}} global_workspace_,
Expand All @@ -149,6 +283,13 @@
)


def _get_dims(shape: List[IntVar]) -> List[str]:
return [
str(dim.value()) if isinstance(dim, IntImm) else dim._attrs["name"]
for dim in shape
]


@registry.reg("cuda.masked_select.gen_function")
def gen_function(func_attrs) -> str:
"""
Expand All @@ -160,21 +301,44 @@ def gen_function(func_attrs) -> str:
The function body string
"""
backend_spec = CUDASpec()
x = func_attrs["inputs"][0]
y = func_attrs["outputs"][0]
x, mask = func_attrs["inputs"]
output = func_attrs["outputs"][0]
max_shape = func_attrs["max_shape"]

input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"])
output_type = cuda_common.dtype_to_cuda_type(y._attrs["dtype"])
output_type = cuda_common.dtype_to_cuda_type(output._attrs["dtype"])

if input_type != output_type:
raise TypeError("input type must equal to output type")

dynamic_dims = get_dynamic_dims(x.shape(), mask.shape())

return SRC_TEMPLATE.render(
input_type=input_type,
index_type=backend_spec.index_type,
func_name=func_attrs["name"],
header_files=header_files,
workspace_size=func_attrs["workspace"],
input_dims=_get_dims(x.shape()),
input_rank=len(x.shape()),
input_strides=get_stride_expressions(x.shape()) + ["1"],
need_broadcast_input=x._attrs["shape"] != max_shape,
mask_dims=_get_dims(mask.shape()),
mask_rank=len(mask.shape()),
mask_strides=get_stride_expressions(mask.shape()) + ["1"],
need_broadcast_mask=mask._attrs["shape"] != max_shape,
max_dims=_get_dims(max_shape),
max_rank=len(max_shape),
dynamic_dims_decl=gen_dynamic_dim_str(
index_type=backend_spec.index_type,
dynamic_dims=dynamic_dims,
has_type=True,
),
dynamic_dims_call=gen_dynamic_dim_str(
index_type=backend_spec.index_type,
dynamic_dims=dynamic_dims,
has_type=False,
),
)


Expand All @@ -189,12 +353,19 @@ def gen_function_decl(func_attrs) -> str:
The function declaration string
"""
backend_spec = CUDASpec()
x = func_attrs["inputs"][0]
x, mask = func_attrs["inputs"]
input_type = cuda_common.dtype_to_cuda_type(x._attrs["dtype"])

return FUNC_DECL_TEMPLATE.render(
func_name=func_attrs["name"],
input_type=input_type,
index_type=backend_spec.index_type,
need_broadcast=x._attrs["shape"] != mask._attrs["shape"],
dynamic_dims_decl=gen_dynamic_dim_str(
index_type=backend_spec.index_type,
dynamic_dims=get_dynamic_dims(x.shape(), mask.shape()),
has_type=True,
),
)


Expand Down Expand Up @@ -226,16 +397,23 @@ def gen_function_call(func_attrs, indent=" ") -> str:
name=mask._attrs["name"],
dtype="bool",
)
max_shape = func_attrs["max_shape"]
# Number of nonmasked elements, i.e. size of the output
num_nonmasked_ptr = "&" + y._attrs["shape"][0]._attrs["name"]
input_dims = "{" + ",".join([dim._attrs["name"] for dim in x._attrs["shape"]]) + "}"

return FUNC_CALL_TEMPLATE.render(
indent=indent,
func_name=func_attrs["name"],
input_name=x._attrs["name"],
num_nonmasked=num_nonmasked_ptr,
input_dims=input_dims,
rank=len(x._attrs["shape"]),
max_dims="{" + ",".join([dim._attrs["name"] for dim in max_shape]) + "}",
max_rank=len(max_shape),
need_broadcast=x._attrs["shape"] != mask._attrs["shape"],
dynamic_dims_call=gen_dynamic_dim_str(
index_type=backend_spec.index_type,
dynamic_dims=get_dynamic_dims(x.shape(), mask.shape()),
has_type=False,
),
output_ptr=output_ptr,
input_ptr=input_ptr,
mask_ptr=mask_ptr,
Expand Down
Loading

0 comments on commit 19f07d3

Please sign in to comment.