Skip to content

Commit

Permalink
Add BF16 support in permute_indices_weights_kernel_2 (pytorch#1852)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1852

Reviewed By: sryap

Differential Revision: D47027925

fbshipit-source-id: 2cd03046d75ad58bf64491dadce6bb77d6be038b
  • Loading branch information
jianyuh authored and facebook-github-bot committed Jun 26, 2023
1 parent bbcac8b commit 555ad07
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/sparse_ops/sparse_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2144,8 +2144,9 @@ permute_sparse_features_cuda(
permuted_weights = at::empty(permuted_lengths_sum, weights_value.options());
AT_DISPATCH_INDEX_TYPES(
input_offsets.scalar_type(), "permute_indices_weights_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Int,
at::ScalarType::BFloat16,
weights_value.scalar_type(),
"permute_indices_weights_kernel_2",
[&] {
Expand Down

0 comments on commit 555ad07

Please sign in to comment.