Replace AT_DISPATCH with FBGEMM_DISPATCH, pt 3 (pytorch#2370)
Summary:
X-link: pytorch/torchrec#1746

Pull Request resolved: pytorch#2370

- Replace AT_DISPATCH with FBGEMM_DISPATCH, pt 3

Reviewed By: sryap

Differential Revision: D54396473

fbshipit-source-id: 036ed75bb733ba27681f9510f24ea0bc9d8698fd
q10 authored and facebook-github-bot committed Mar 4, 2024
1 parent 42753de commit f072386
Showing 18 changed files with 131 additions and 208 deletions.
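The change applied across these files is a mechanical macro swap: call sites that spelled out ATen dispatch macros with extra scalar types (for example AT_DISPATCH_ALL_TYPES_AND2 with Half and BFloat16) now use the bundled FBGEMM_DISPATCH_* macros defined in fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h. Below is a minimal before/after sketch of the pattern; example_kernel and example_wrapper are hypothetical names used only for illustration and are not code from this commit.

    // Hedged sketch (not from the diff): a hypothetical wrapper showing the swap.
    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include "fbgemm_gpu/dispatch_macros.h"

    template <typename scalar_t>
    void example_kernel(const at::Tensor& values) {
      // The real call sites in this commit launch CPU/CUDA kernels here.
    }

    void example_wrapper(const at::Tensor& values) {
      // Before: the extra scalar types were listed at every call site.
      //   AT_DISPATCH_ALL_TYPES_AND2(
      //       at::ScalarType::Half,
      //       at::ScalarType::BFloat16,
      //       values.scalar_type(), "example_wrapper", [&] {
      //         example_kernel<scalar_t>(values);
      //       });
      // After: the bundled FBGEMM macro expands to the same kind of switch,
      // covering the Float/Half/BFloat16/Int/Long cases these kernels use,
      // and binds scalar_t inside the lambda.
      FBGEMM_DISPATCH_ALL_TYPES(
          values.scalar_type(), "example_wrapper", [&] {
            example_kernel<scalar_t>(values);
          });
    }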
8 changes: 2 additions & 6 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
@@ -205,12 +205,8 @@ Tensor split_embedding_codegen_forward_cpu(
   FBGEMM_DISPATCH_FLOAT_AND_HALF(
       output.scalar_type(), "split_embedding_cpu_forward", [&]() {
         using output_t = scalar_t;
-        AT_DISPATCH_FLOATING_TYPES_AND2(
-            at::ScalarType::Half,
-            at::ScalarType::Byte,
-            weights.scalar_type(),
-            "split_embedding_cpu_forward",
-            [&] {
+        FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
+            weights.scalar_type(), "split_embedding_cpu_forward", [&] {
               using ind_weights_t = std::conditional<
                   std::is_same<scalar_t, double>::value,
                   double,
27 changes: 23 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
@@ -184,7 +184,11 @@
 /// These macros cover bundled dispatch cases, similar to AT_DISPATCH_*_CASE
 ////////////////////////////////////////////////////////////////////////////////
 
-#define FBGEMM_DISPATCH_CASE_FLOATING_TYPES(...) \
+#define FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define FBGEMM_DISPATCH_FLOATING_TYPES_CASE(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
@@ -193,11 +197,15 @@
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
 
-// AT_DISPATCH_CASE_FLOATING_TYPES_AND BFloat16
 #define FBGEMM_DISPATCH_FLOAT_AND_BFLOAT16_CASE(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
 
+#define FBGEMM_DISPATCH_ALL_TYPES_BUT_HALF_CASE(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+  FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(__VA_ARGS__)
+
 ////////////////////////////////////////////////////////////////////////////////
 /// Dispatch Macros
 ///
@@ -222,15 +230,26 @@
 
 #define FBGEMM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH( \
-      TYPE, NAME, FBGEMM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+      TYPE, NAME, FBGEMM_DISPATCH_FLOATING_TYPES_CASE(__VA_ARGS__))
 
 #define FBGEMM_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH( \
       TYPE, \
       NAME, \
-      FBGEMM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
+      FBGEMM_DISPATCH_FLOATING_TYPES_CASE(__VA_ARGS__) \
       AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__))
 
+#define FBGEMM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH( \
+      TYPE, NAME, FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(__VA_ARGS__))
+
+#define FBGEMM_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH( \
+      TYPE, \
+      NAME, \
+      FBGEMM_DISPATCH_FLOATING_TYPES_CASE(__VA_ARGS__) \
+      FBGEMM_DISPATCH_INTEGRAL_TYPES_CASE(__VA_ARGS__))
+
 // We can cleanup the following once fbgemm uses PyTorch 2.2 in January 2024.
 #ifndef PT2_COMPLIANT_TAG
 #ifdef HAS_PT2_COMPLIANT_TAG
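For context, a hedged usage sketch of the macros added above (not part of the diff): row_sums and its loop body are illustrative, and the sketch assumes a 2-D integral CPU tensor. FBGEMM_DISPATCH_INTEGRAL_TYPES expands to an AT_DISPATCH_SWITCH over Int and Long and binds scalar_t inside the lambda; the *_CASE variants can likewise be combined under a single AT_DISPATCH_SWITCH, as the updated call sites later in this diff do.

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include "fbgemm_gpu/dispatch_macros.h"

    // Sums each row of a 2-D Int/Long tensor; the dtype is resolved at runtime.
    at::Tensor row_sums(const at::Tensor& indices) {
      auto out = at::zeros({indices.size(0)}, indices.options());
      FBGEMM_DISPATCH_INTEGRAL_TYPES(
          indices.scalar_type(), "row_sums", [&] {
            // scalar_t is int32_t or int64_t in this branch.
            const auto src = indices.accessor<scalar_t, 2>();
            auto dst = out.accessor<scalar_t, 1>();
            for (int64_t i = 0; i < src.size(0); ++i) {
              scalar_t sum = 0;
              for (int64_t j = 0; j < src.size(1); ++j) {
                sum += src[i][j];
              }
              dst[i] = sum;
            }
          });
      return out;
    }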
9 changes: 0 additions & 9 deletions fbgemm_gpu/src/jagged_tensor_ops/common.cuh
@@ -35,15 +35,6 @@ namespace fbgemm_gpu {
 
 using Tensor = at::Tensor;
 
-#define DISPATCH_JAGGED_TYPES_CASE(...) \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
-
-#define DISPATCH_JAGGED_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_JAGGED_TYPES_CASE(__VA_ARGS__))
-
 namespace {
 
 template <typename T>
51 changes: 23 additions & 28 deletions fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu
@@ -31,37 +31,32 @@ Tensor dense_to_jagged_forward(
 
   CUDA_DEVICE_GUARD(dense);
 
-#define DISPATCH_DENSE_TO_JAGGED_OPT_CASE(TYPE) \
-  AT_DISPATCH_CASE(TYPE, [&] { \
-    jagged_dense_elementwise_jagged_output_opt_<scalar_t>( \
-        values, \
-        offsets, \
-        dense, \
-        output, \
-        [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \
-          return y; \
-        }); \
-  })
-
-  // clang-format off
   AT_DISPATCH_SWITCH(
       values.scalar_type(),
       "dense_to_jagged_gpu_op_forward",
-      DISPATCH_DENSE_TO_JAGGED_OPT_CASE(at::ScalarType::Half)
-      DISPATCH_JAGGED_TYPES_CASE(
-        [&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-            values,
-            offsets,
-            dense,
-            output,
-            [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
-              return y;
-            }); // device lambda
-        } // lambda
-      ) // DISPATCH_JAGGED_TYPES_CASE
-  ); // SWITCH
-  // clang-format on
+      AT_DISPATCH_CASE(
+          at::ScalarType::Half,
+          [&] {
+            jagged_dense_elementwise_jagged_output_opt_<scalar_t>(
+                values,
+                offsets,
+                dense,
+                output,
+                [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
+                  return y;
+                });
+          })
+
+      FBGEMM_DISPATCH_ALL_TYPES_BUT_HALF_CASE([&] {
+        jagged_dense_elementwise_jagged_output_<scalar_t>(
+            values,
+            offsets,
+            dense,
+            output,
+            [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
+              return y;
+            });
+      }));
 
   return output;
 }
[changed file: name not shown]
@@ -256,20 +256,17 @@ Tensor jagged_dense_dense_elementwise_add_jagged_output_forward(
                 -> scalar_t { return x + y_0 + y_1; });
           } // lambda
       ) // CASE
-      AT_DISPATCH_CASE_FLOATING_TYPES_AND(
-          at::ScalarType::BFloat16,
-          [&] {
-            jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
-                x_values,
-                offsets,
-                dense_0,
-                dense_1,
-                output,
-                [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-                    -> scalar_t { return x + y_0 + y_1; });
-          } // lambda
-      ) // CASE_FLOATING_TYPES_AND
-      ); // SWITCH
+      FBGEMM_DISPATCH_FLOAT_AND_BFLOAT16_CASE([&] {
+        jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
+            x_values,
+            offsets,
+            dense_0,
+            dense_1,
+            output,
+            [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
+                -> scalar_t { return x + y_0 + y_1; });
+      } // lambda
+      )); // SWITCH
   }
   return output;
[changed file: name not shown]
@@ -92,12 +92,8 @@ Tensor jagged_index_add_2d_forward_cuda(
   // input_offsets has to be contiguous since it is passed to
   // binary_search_range which accepts raw pointers
   const auto input_offsets_contig = input_offsets.expect_contiguous();
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      values.scalar_type(),
-      "jagged_index_add_2d_kernel_wrapper_1",
-      [&] {
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "jagged_index_add_2d_kernel_wrapper_1", [&] {
         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(),
             "jagged_index_add_2d_kernel_wrapper_2",
[changed file: name not shown]
@@ -89,12 +89,8 @@ Tensor jagged_index_select_2d_forward_cuda(
   // output_offsets has to be contiguous since it is passed to
   // binary_search_range which accepts raw pointers
   const auto output_offsets_contig = output_offsets.expect_contiguous();
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      values.scalar_type(),
-      "jagged_index_select_2d_kernel_wrapper_1",
-      [&] {
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "jagged_index_select_2d_kernel_wrapper_1", [&] {
         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(),
             "jagged_index_select_2d_kernel_wrapper_2",
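The two CUDA call sites above share a shape that recurs throughout this commit: an outer FBGEMM_DISPATCH_ALL_TYPES over the value tensor's dtype wrapping an inner AT_DISPATCH_INDEX_TYPES over the index tensor's dtype, so the kernel is instantiated for every (scalar_t, index_t) pair. A hedged CPU-side sketch of that shape follows; gather_rows and gather_rows_kernel are illustrative names, not code from this commit.

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include "fbgemm_gpu/dispatch_macros.h"

    template <typename scalar_t, typename index_t>
    void gather_rows_kernel(
        const at::Tensor& values,
        const at::Tensor& indices,
        at::Tensor& output) {
      // Copy the selected rows of a 2-D values tensor into output.
      const auto src = values.accessor<scalar_t, 2>();
      const auto idx = indices.accessor<index_t, 1>();
      auto dst = output.accessor<scalar_t, 2>();
      for (int64_t i = 0; i < idx.size(0); ++i) {
        for (int64_t j = 0; j < src.size(1); ++j) {
          dst[i][j] = src[idx[i]][j];
        }
      }
    }

    at::Tensor gather_rows(const at::Tensor& values, const at::Tensor& indices) {
      auto output =
          at::empty({indices.size(0), values.size(1)}, values.options());
      // Outer dispatch binds scalar_t (Float/Half/BFloat16/Int/Long);
      // the inner dispatch binds index_t (int32_t or int64_t).
      FBGEMM_DISPATCH_ALL_TYPES(
          values.scalar_type(), "gather_rows_wrapper_1", [&] {
            AT_DISPATCH_INDEX_TYPES(
                indices.scalar_type(), "gather_rows_wrapper_2", [&] {
                  gather_rows_kernel<scalar_t, index_t>(values, indices, output);
                });
          });
      return output;
    }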
73 changes: 21 additions & 52 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
@@ -404,12 +404,8 @@ at::Tensor jagged_to_padded_dense_forward(
   Tensor padded_values_view =
       D_folded ? padded_values.unsqueeze(-1) : padded_values;
 
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      values.scalar_type(),
-      "jagged_to_padded_dense",
-      [&] {
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "jagged_to_padded_dense", [&] {
        jagged_dense_elementwise_dense_output_<scalar_t>(
            values_canonicalized,
            offsets,
@@ -440,9 +436,7 @@ at::Tensor jagged_to_padded_dense_backward(
   auto grad_values =
       at::zeros_symint({total_L, D}, grad_padded_values.options());
 
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
+  FBGEMM_DISPATCH_ALL_TYPES(
       grad_padded_values.scalar_type(),
       "jagged_2d_to_dense_backward_kernel",
       [&] {
@@ -474,19 +468,14 @@ Tensor dense_to_jagged_forward(
   auto values = at::empty_symint({total_L_computed, D}, dense.options());
   auto output = at::zeros_symint({total_L_computed, D}, dense.options());
 
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      values.scalar_type(),
-      "jagged_scalars",
-      [&] {
-        jagged_dense_elementwise_jagged_output_<scalar_t>(
-            values,
-            offsets,
-            dense,
-            output,
-            [](scalar_t /*unused*/, scalar_t y) -> scalar_t { return y; });
-      });
+  FBGEMM_DISPATCH_ALL_TYPES(values.scalar_type(), "jagged_scalars", [&] {
+    jagged_dense_elementwise_jagged_output_<scalar_t>(
+        values,
+        offsets,
+        dense,
+        output,
+        [](scalar_t /*unused*/, scalar_t y) -> scalar_t { return y; });
+  });
 
   return output;
 }
@@ -884,12 +873,8 @@ Tensor jagged_1d_to_truncated_values_cpu(
   Tensor truncated_values;
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "jagged_1d_to_truncated_values_cpu_kernel", [&] {
-        AT_DISPATCH_ALL_TYPES_AND2(
-            at::ScalarType::Half,
-            at::ScalarType::BFloat16,
-            values.scalar_type(),
-            "copy_values_and_truncate_cpu_kernel",
-            [&] {
+        FBGEMM_DISPATCH_ALL_TYPES(
+            values.scalar_type(), "copy_values_and_truncate_cpu_kernel", [&] {
              const index_t max_length_int =
                  static_cast<index_t>(max_truncated_length);
              const auto lengths_accessor = lengths.accessor<index_t, 1>();
@@ -936,12 +921,8 @@ std::tuple<Tensor, Tensor> masked_select_jagged_1d(
 
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "mask_select_jagged_1d_kernel1", [&] {
-        AT_DISPATCH_ALL_TYPES_AND2(
-            at::ScalarType::Half,
-            at::ScalarType::BFloat16,
-            values.scalar_type(),
-            "mask_select_jagged_1d_kernel2",
-            [&] {
+        FBGEMM_DISPATCH_ALL_TYPES(
+            values.scalar_type(), "mask_select_jagged_1d_kernel2", [&] {
              const int32_t num_outputs = mask.sum().item<int32_t>();
              masked_values = at::empty({num_outputs}, values.options());
 
@@ -1121,12 +1102,8 @@ Tensor jagged_index_select_2d_forward_cpu(
       at::empty({num_dense_output_rows, num_cols}, values.options());
 
   if (num_dense_output_rows > 0) {
-    AT_DISPATCH_ALL_TYPES_AND2(
-        at::ScalarType::Half,
-        at::ScalarType::BFloat16,
-        values.scalar_type(),
-        "jagged_index_select_2d_kernel_wrapper_1",
-        [&] {
+    FBGEMM_DISPATCH_ALL_TYPES(
+        values.scalar_type(), "jagged_index_select_2d_kernel_wrapper_1", [&] {
          AT_DISPATCH_INDEX_TYPES(
              indices.scalar_type(),
              "jagged_index_select_2d_kernel_wrapper_2",
@@ -1286,12 +1263,8 @@ Tensor jagged_index_add_2d_forward_cpu(
       "jagged_index_add_2d_forward_cpu supports only 2D inputs");
   auto num_cols = values.size(1);
   Tensor output = at::zeros({num_output_rows, num_cols}, values.options());
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      values.scalar_type(),
-      "jagged_index_add_2d_kernel_wrapper_1",
-      [&] {
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "jagged_index_add_2d_kernel_wrapper_1", [&] {
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "jagged_index_add_2d_kernel_wrapper_2", [&] {
              jagged_index_add_2d_kernel(
@@ -1605,12 +1578,8 @@ Tensor jagged_slice_forward_cpu(
   auto output_offsets = asynchronous_exclusive_cumsum_cpu(output_lengths);
   auto input_offsets = asynchronous_exclusive_cumsum_cpu(x_lengths);
 
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      x_values.scalar_type(),
-      "jagged_slice_wrapper_1",
-      [&] {
+  FBGEMM_DISPATCH_ALL_TYPES(
+      x_values.scalar_type(), "jagged_slice_wrapper_1", [&] {
        jagged_slice_forward_cpu_kernel<scalar_t>(
            output_values.accessor<scalar_t, 1>(),
            output_lengths.accessor<int64_t, 1>(),
[changed file: name not shown]
@@ -54,12 +54,8 @@ at::Tensor jagged_to_padded_dense_forward(
   Tensor padded_values_view =
       D_folded ? padded_values.unsqueeze(-1) : padded_values;
 
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      values.scalar_type(),
-      "jagged_to_padded_dense",
-      [&] {
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "jagged_to_padded_dense", [&] {
        jagged_dense_elementwise_dense_output_<scalar_t>(
            values_canonicalized,
            offsets,
[changed file: name not shown]
@@ -304,9 +304,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
               num_batches, \
               batch_size); \
   }
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
+  FBGEMM_DISPATCH_ALL_TYPES(
       values.scalar_type(),
       "keyed_jagged_index_select_dim1_warpper_1",
       [&] {
@@ -390,12 +388,8 @@ class KeyedJaggedIndexSelectDim1GPUOp
     const auto grad_offsets_contig = grad_offsets.expect_contiguous();
     if (grid_size != 0) {
-      AT_DISPATCH_ALL_TYPES_AND2(
-          at::ScalarType::Half,
-          at::ScalarType::BFloat16,
-          grad.scalar_type(),
-          "keyed_jagged_index_add_dim1_wrapper_1",
-          [&] {
+      FBGEMM_DISPATCH_ALL_TYPES(
+          grad.scalar_type(), "keyed_jagged_index_add_dim1_wrapper_1", [&] {
            AT_DISPATCH_INDEX_TYPES(
                grad_offsets.scalar_type(),
                "keyed_jagged_index_add_dim1_wrapper_2",
[Diffs for the remaining changed files are not shown.]
