Skip to content

Commit

Permalink
Remove unsupported type dispatch from FBGEMM ops, pt. 1 (pytorch#1989)
Browse files: browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1989

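- Remove unsupported type dispatch from FBGEMM ops, pt. 1: replace `AT_DISPATCH_FLOATING_TYPES_AND_HALF` with `FBGEMM_DISPATCH_FLOAT_AND_HALF` (from `fbgemm_gpu/dispatch_macros.h`) at the call sites shown in the diff below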

Reviewed By: sryap

Differential Revision: D48895311

fbshipit-source-id: 05a0ee7e3db10bbe720b01457a08b29de6f8afdc
  • Loading branch information
q10 authored and facebook-github-bot committed Sep 7, 2023
1 parent ce3943c commit f664fd9
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/dispatch_macros.h"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
Expand Down Expand Up @@ -193,10 +194,10 @@ for (const auto t : c10::irange(t_begin,t_end)) {

{% endif %}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_cpu", [&] {
using grad_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(),
"split_embedding_backward_cpu_inner",
[&] {
Expand Down
7 changes: 4 additions & 3 deletions fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm/Types.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

Expand Down Expand Up @@ -344,11 +345,11 @@ for (const auto d : c10::irange(D)) {
grad_output = grad_output.contiguous();


AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu_outer", [&]() {
using grad_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
grad_output,
Expand Down Expand Up @@ -379,7 +380,7 @@ for (const auto d : c10::irange(D)) {

// When input is dense enough, avoid sorting and just treat as dense.
auto grad = zeros_like(host_weights, grad_output.dtype());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {

split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
Expand Down
7 changes: 4 additions & 3 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#ifdef FBCODE_CAFFE2
Expand Down Expand Up @@ -201,7 +202,7 @@ Tensor split_embedding_codegen_forward_cpu(
// It is assumed that the indice_weights will always be float
TORCH_CHECK(
!indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(), "split_embedding_cpu_forward", [&]() {
using output_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND2(
Expand Down Expand Up @@ -298,12 +299,12 @@ Tensor split_embedding_codegen_grad_indice_weights_cpu(
indices,
indices.options().dtype(
at::toAccumulateType(grad_output.scalar_type(), true)));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_grad_indice_weights_cpu_outer",
[&] {
using grad_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
weights.scalar_type(),
"split_embedding_grad_indice_weights_cpu",
[&] {
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "fbgemm_gpu/ops_utils.h"
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <torch/library.h>
#include "ATen/Parallel.h"

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
"keyed_jagged_index_select_dim1_warpper_3",
[&] {
if (weights.has_value()) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
weights.value().scalar_type(),
"keyed_jagged_index_select_dim1_warpper_4",
[&] {
Expand Down
7 changes: 4 additions & 3 deletions fbgemm_gpu/src/layout_transform_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// clang-format off
#include "fbgemm_gpu/cub_namespace_prefix.cuh"
#include <cub/device/device_scan.cuh>
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on

Expand Down Expand Up @@ -49,7 +50,7 @@ Tensor recat_embedding_grad_output_cuda(

Tensor sharded_grad_output =
at::empty({grad_output.numel()}, grad_output.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "recat_embedding_gradients", [&] {
const auto go = grad_output.accessor<scalar_t, 3>();
auto sgo = sharded_grad_output.accessor<scalar_t, 1>();
Expand Down Expand Up @@ -93,7 +94,7 @@ Tensor recat_embedding_grad_output_mixed_D_cuda(
Tensor sharded_grad_output =
at::empty({grad_output.numel()}, grad_output.options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "recat_embedding_gradients", [&] {
const auto go = grad_output.accessor<scalar_t, 2>();
auto sgo = sharded_grad_output.accessor<scalar_t, 1>();
Expand Down Expand Up @@ -145,7 +146,7 @@ Tensor recat_embedding_grad_output_mixed_D_batch_cuda(
const dim3 blocks(fbgemm_gpu::div_round_up(
(B_local * dim_num), fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize));

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "recat_embedding_gradients", [&] {
recat_copy_async_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
Expand Down
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/layout_transform_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include "ATen/Parallel.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;
Expand Down Expand Up @@ -37,7 +38,7 @@ Tensor recat_embedding_grad_output_mixed_D_cpu(
const auto global_dim_sum = accum_dim_sum[n];
TORCH_CHECK(B_local * global_dim_sum == grad_output.numel());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "recat_embedding_gradients", [&] {
const auto go = grad_output.accessor<scalar_t, 2>();
auto sgo = sharded_grad_output.accessor<scalar_t, 1>();
Expand Down
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/metric_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ATen/cuda/Atomic.cuh>
#include <algorithm>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "metric_ops.h"

Expand Down Expand Up @@ -276,7 +277,7 @@ at::Tensor batch_auc(
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, labels.scalar_type(), "auc_wrapper_2", [&] {
using label_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
weights.scalar_type(), "auc_wrapper_3", [&] {
using acc_t = at::acc_type<scalar_t, true>;
if (padded_section_size == 1) {
Expand Down

0 comments on commit f664fd9

Please sign in to comment.