diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index f9ef076bc..16d46e2e3 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -10,7 +10,9 @@ #include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual #include "fbgemm_gpu/utils/ops_utils.h" // @manual #include "fbgemm_gpu/utils/tensor_accessor.h" // @manual - +#ifdef USE_ROCM +#include +#endif // clang-format off #include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" // @manual #include @@ -307,6 +309,7 @@ transpose_embedding_input( } { size_t temp_storage_bytes = 0; +#ifndef USE_ROCM AT_CUDA_CHECK( FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( nullptr, @@ -336,6 +339,40 @@ transpose_embedding_input( total_hash_size_bits, at::cuda::getCurrentCUDAStream(), false)); +#else + using config = rocprim::radix_sort_config< + rocprim::default_config, + rocprim::default_config, + rocprim::default_config, + 400000>; + rocprim::radix_sort_pairs( + nullptr, + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + infos.data_ptr(), + infos_sorted.data_ptr(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream(), + false); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + indices.options().dtype(at::kByte)); + rocprim::radix_sort_pairs( + temp_storage.data_ptr(), + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + infos.data_ptr(), + infos_sorted.data_ptr(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream(), + false); +#endif } if (total_unique_indices != -1) { TORCH_CHECK(total_unique_indices >= 0);