Add support for where operator (#791)
Summary:
Pull Request resolved: #791

As titled, this diff adds support for the where operation. The expected behavior matches the equivalent operator in PyTorch: https://pytorch.org/docs/stable/generated/torch.where.html

Reviewed By: aakhundov

Differential Revision: D46957405

fbshipit-source-id: db4bdf4f2d91d154fb0c9ee092bf6429679b63db
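A minimal usage sketch of the new op (hypothetical tensor names and shapes; assumes the frontend Tensor API and the ops.where export added in this diff):

from aitemplate.compiler import ops
from aitemplate.frontend import Tensor

# Hypothetical inputs: a bool mask and two fp16 tensors of the same shape.
cond = Tensor(shape=[8, 16], dtype="bool", name="cond", is_input=True)
x = Tensor(shape=[8, 16], dtype="float16", name="x", is_input=True)
y = Tensor(shape=[8, 16], dtype="float16", name="y", is_input=True)

# Elementwise select, mirroring torch.where(cond, x, y):
# out[i] = x[i] if cond[i] else y[i]
out = ops.where()(cond, x, y)

# Scalars are accepted in place of either tensor; dtype is required
# only when both are scalars.
out2 = ops.where()(cond, 1.0, 0.0, dtype="float16")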
AlbertDachiChen authored and facebook-github-bot committed Jun 27, 2023
1 parent 5e30494 commit 0cf1c2e
Showing 5 changed files with 554 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/aitemplate/backend/cuda/tensor/__init__.py
@@ -39,6 +39,7 @@
slice_scatter,
split,
topk,
where,
)

__all__ = [
@@ -65,4 +66,5 @@
"slice_scatter",
"split",
"topk",
"where",
]
228 changes: 228 additions & 0 deletions python/aitemplate/backend/cuda/tensor/where.py
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict

import jinja2

from aitemplate.backend import registry
from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.common.elementwise_common import gen_int_var_product_str
from aitemplate.utils import shape_utils


CUDA_HEADER_FILES = """
#include <cuda_fp16.h>
#include <cuda_runtime.h>
"""


CONSTANT_TEMPLATE = jinja2.Template(
"""
#define N_THREADS_PER_BLOCK 256
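// Elements handled per thread == bools packed into one vectorized condition load.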
#define N_READS_PER_THREAD (sizeof({{condition_read_t}}) / sizeof(bool))
"""
)


FUNC_DECL = jinja2.Template(
"""
void invoke_{{func_name}}(
void*, /* output */
const void*, /* condition */
{% if not input_tensor_is_a_const_num %}
const void*, /* input tensor */
{% endif %}
{% if not other_tensor_is_a_const_num %}
const void*, /* other tensor */
{% endif %}
{{index_type}}, /* number of elements */
{{prefix}}Stream_t /* stream */
);
"""
)


FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{
{{indent}}{{index_type}} n_elements = {{calculate_n}};
{{indent}} invoke_{{func_name}}(
{{indent}} {{output}},
{{indent}} {{condition}},
{% if not input_tensor_is_a_const_num %}
{{indent}} {{input_tensor}},
{% endif %}
{% if not other_tensor_is_a_const_num %}
{{indent}} {{other_tensor}},
{% endif %}
{{indent}} n_elements,
{{indent}} stream
{{indent}});
{{indent}}}
"""
)

FUNC_TEMPLATE = jinja2.Template(
"""
{{header_files}}
namespace {
{{constant}}
__global__ void where(
{{read_t}}* output,
const {{condition_read_t}}* condition,
{% if not input_tensor_is_a_const_num %}
const {{read_t}}* input_tensor,
{% endif %}
{% if not other_tensor_is_a_const_num %}
const {{read_t}}* other_tensor,
{% endif %}
{{index_type}} num_elements) {
const {{index_type}} idx = (blockIdx.x * blockDim.x + threadIdx.x);
if (idx * N_READS_PER_THREAD >= num_elements) {
return;
}
{{read_t}} tmp_output;
{{data_t}}* tmp_output_ptr = reinterpret_cast<{{data_t}}*>(&tmp_output);
{{condition_read_t}} tmp_condition = condition[idx];
bool* tmp_condition_ptr = reinterpret_cast<bool*>(&tmp_condition);
{% if not input_tensor_is_a_const_num %}
{{read_t}} tmp_input_tensor = input_tensor[idx];
{{data_t}}* tmp_input_tensor_ptr = reinterpret_cast<{{data_t}}*>(&tmp_input_tensor);
{% endif %}
{% if not other_tensor_is_a_const_num %}
{{read_t}} tmp_other_tensor = other_tensor[idx];
{{data_t}}* tmp_other_tensor_ptr = reinterpret_cast<{{data_t}}*>(&tmp_other_tensor);
{% endif %}
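// Branch-free select: each condition byte is 0 or 1, so
// cond * input + (1 - cond) * other picks input exactly where cond is true.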
#pragma unroll
for (int i = 0; i < N_READS_PER_THREAD; i++) {
tmp_output_ptr[i] =
({{data_t}})(tmp_condition_ptr[i]) * ({{data_t}})({{ input_tensor_val if input_tensor_is_a_const_num else "tmp_input_tensor_ptr[i]" }}) +
({{data_t}})(1 - tmp_condition_ptr[i]) * ({{data_t}})({{ other_tensor_val if other_tensor_is_a_const_num else "tmp_other_tensor_ptr[i]" }});
}
output[idx] = tmp_output;
}
} // namespace
void invoke_{{func_name}}(
void* output,
const void* condition,
{% if not input_tensor_is_a_const_num %}
const void* input_tensor,
{% endif %}
{% if not other_tensor_is_a_const_num %}
const void* other_tensor,
{% endif %}
{{index_type}} num_elements,
{{prefix}}Stream_t stream) {
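// Each thread consumes N_READS_PER_THREAD elements; round the block count up
// so the tail of the buffer is covered.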
int grid_size = static_cast<int>(
std::ceil(static_cast<double>(num_elements) / N_THREADS_PER_BLOCK / N_READS_PER_THREAD));
where<<<grid_size, N_THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<{{read_t}}*>(output),
reinterpret_cast<const {{condition_read_t}}*>(condition),
{% if not input_tensor_is_a_const_num %}
reinterpret_cast<const {{read_t}}*>(input_tensor),
{% endif %}
{% if not other_tensor_is_a_const_num %}
reinterpret_cast<const {{read_t}}*>(other_tensor),
{% endif %}
num_elements);
}
"""
)


@registry.reg("cuda.where.gen_function")
def gen_function(func_attrs: Dict[str, Any]) -> str:
condition, input_tensor, other_tensor = func_attrs["args"]
output = func_attrs["outputs"][0]
dtype = output.dtype()
backend_spec = CUDASpec()
read_t = backend_spec.get_elementwise_read_backend_type(
shape_utils.get_num_rightmost_static_elements(output.shape()), dtype
)
data_t = backend_spec.dtype_to_backend_type(dtype)
read_vector_length = (
backend_spec.sizeof_types[read_t] // backend_spec.sizeof_types[data_t]
)
# condition data type is bool, which is 1 byte
condition_read_t = {
1: "bool",
2: "half",
4: "float",
8: "int2",
16: "int4",
}[read_vector_length]

return FUNC_TEMPLATE.render(
header_files=backend_spec.header_src_template.render(
extra_header=CUDA_HEADER_FILES
),
constant=CONSTANT_TEMPLATE.render(condition_read_t=condition_read_t),
func_name=func_attrs["name"],
data_t=data_t,
read_t=read_t,
condition_read_t=condition_read_t,
index_type=backend_spec.index_type,
prefix=backend_spec.prefix,
input_tensor_is_a_const_num=input_tensor.is_a_const_num(),
other_tensor_is_a_const_num=other_tensor.is_a_const_num(),
input_tensor_val=str(input_tensor._attrs["value"]),
other_tensor_val=str(other_tensor._attrs["value"]),
)
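To make the vector-width bookkeeping concrete, a small sketch of the selection above (the byte sizes are assumed illustrative values; the actual read_t depends on the number of rightmost static elements in the output):

# Sketch, not part of the diff: assumed sizes for two sample backend types.
sizeof_types = {"half": 2, "uint4": 16}

read_t, data_t = "uint4", "half"  # e.g. fp16 data read 16 bytes at a time
read_vector_length = sizeof_types[read_t] // sizeof_types[data_t]  # 8 elements
# Eight one-byte bools fit in one 8-byte int2, so condition_read_t == "int2".
condition_read_t = {1: "bool", 2: "half", 4: "float", 8: "int2", 16: "int4"}[
    read_vector_length
]

Each thread then performs one vectorized load per operand and covers eight elements per iteration.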


@registry.reg("cuda.where.func_decl")
def gen_function_decl(func_attrs: Dict[str, Any]) -> str:
_, input_tensor, other_tensor = func_attrs["args"]
backend_spec = CUDASpec()
return FUNC_DECL.render(
func_name=func_attrs["name"],
prefix=backend_spec.prefix,
index_type=backend_spec.index_type,
input_tensor_is_a_const_num=input_tensor.is_a_const_num(),
other_tensor_is_a_const_num=other_tensor.is_a_const_num(),
)


@registry.reg("cuda.where.func_call")
def gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
condition, input_tensor, other_tensor = func_attrs["args"]
output = func_attrs["outputs"][0]
backend_spec = CUDASpec()
return FUNC_CALL_TEMPLATE.render(
func_name=func_attrs["name"],
output=output._attrs["name"],
condition=condition._attrs["name"],
input_tensor=input_tensor._attrs["name"],
other_tensor=other_tensor._attrs["name"],
calculate_n=gen_int_var_product_str(condition.shape()),
indent=indent,
index_type=backend_spec.index_type,
input_tensor_is_a_const_num=input_tensor.is_a_const_num(),
other_tensor_is_a_const_num=other_tensor.is_a_const_num(),
)
1 change: 1 addition & 0 deletions python/aitemplate/compiler/ops/tensor/__init__.py
@@ -42,3 +42,4 @@
from aitemplate.compiler.ops.tensor.split import split
from aitemplate.compiler.ops.tensor.topk import topk
from aitemplate.compiler.ops.tensor.transpose import transpose
from aitemplate.compiler.ops.tensor.where import where
102 changes: 102 additions & 0 deletions python/aitemplate/compiler/ops/tensor/where.py
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.dtype import normalize_dtype


class where(Operator):
"""
Return a tensor of elements selected from either input or other, depending on condition.
Parameters:
condition (bool Tensor): where True, yield input; otherwise, yield other
input_tensor (Tensor or Scalar): value (if a scalar) or values selected at indices where condition is True
other_tensor (Tensor or Scalar): value (if a scalar) or values selected at indices where condition is False
dtype: output dtype; required when both input_tensor and other_tensor are scalars
Returns:
Tensor: a tensor with the same shape as condition
"""

def __init__(self) -> None:
super().__init__()
self._attrs["op"] = "where"

def __call__(
self,
condition: Tensor,
input_tensor: Tensor,
other_tensor: Tensor,
dtype: str = "",
) -> Tensor:
assert isinstance(
condition, Tensor
), f"condition needs to be a tensor, but got {type(condition)}"
assert (
condition.dtype() == "bool"
), f"condition needs to be a bool tensor, but got {condition.dtype()}"

output_shape = condition.shape()
args = []
inputs = []
common_dtype = None
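# Promote Python scalars to constant 0-dim Tensors; real tensors must match
# condition's shape and agree on a single dtype.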
for tensor in [input_tensor, other_tensor]:
if isinstance(tensor, (int, float)):
tensor = Tensor(shape=[], value=tensor, dtype=common_dtype)
else:
assert isinstance(
tensor, Tensor
), f"Unsupported data type: {type(tensor)}"
assert (
tensor.shape() == output_shape
), f"Tensor shape should be the same, {tensor.shape()} != {output_shape}"
if common_dtype is None:
common_dtype = normalize_dtype(tensor.dtype())
else:
assert common_dtype == normalize_dtype(
tensor.dtype()
), f"Expect tensor of the same dtype, got {common_dtype} and {normalize_dtype(tensor.dtype())}"
inputs.append(tensor)

args.append(tensor)

# If both inputs are scalars, the output dtype must be provided by the caller.
if len(inputs) == 0:
assert dtype != "", "dtype needs to be provided for scalars"
common_dtype = normalize_dtype(dtype)
for arg in args:
arg._attrs["dtype"] = common_dtype
self._attrs["args"] = [condition, *args]
self._attrs["inputs"] = [condition, *inputs]
self._set_depth()
output = Tensor(
shape=output_shape,
src_ops={self},
dtype=common_dtype,
)
self._attrs["outputs"] = [output]
return output

def gen_function(self) -> str:
target = backend.target.Target.current()
func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
func = registry.get(func_key)
return func(self._attrs)
