diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops/sparse_ops.cu index 8bc583b7b..302c9afca 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops.cu @@ -2144,8 +2144,9 @@ permute_sparse_features_cuda( permuted_weights = at::empty(permuted_lengths_sum, weights_value.options()); AT_DISPATCH_INDEX_TYPES( input_offsets.scalar_type(), "permute_indices_weights_kernel_1", [&] { - AT_DISPATCH_FLOATING_TYPES_AND( + AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Int, + at::ScalarType::BFloat16, weights_value.scalar_type(), "permute_indices_weights_kernel_2", [&] {