From 555ad077f35465f7fdc5ca2be1a6759cf4ddeafd Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Mon, 26 Jun 2023 15:33:30 -0700 Subject: [PATCH] Add BF16 support in permute_indices_weights_kernel_2 (#1852) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1852 Reviewed By: sryap Differential Revision: D47027925 fbshipit-source-id: 2cd03046d75ad58bf64491dadce6bb77d6be038b --- fbgemm_gpu/src/sparse_ops/sparse_ops.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops/sparse_ops.cu index 8bc583b7b..302c9afca 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops.cu @@ -2144,8 +2144,9 @@ permute_sparse_features_cuda( permuted_weights = at::empty(permuted_lengths_sum, weights_value.options()); AT_DISPATCH_INDEX_TYPES( input_offsets.scalar_type(), "permute_indices_weights_kernel_1", [&] { - AT_DISPATCH_FLOATING_TYPES_AND( + AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Int, + at::ScalarType::BFloat16, weights_value.scalar_type(), "permute_indices_weights_kernel_2", [&] {