Skip to content

Commit

Permalink
Switch hipcub::DeviceRadixSort::SortPairs call to rocprim::device_rad…
Browse files Browse the repository at this point in the history
…ix_sort_pairs (#3059)

Summary:
X-link: facebookresearch/FBGEMM#182

Pull Request resolved: #3059

Reviewed By: sryap

Differential Revision: D62249387

Pulled By: jianyuh

fbshipit-source-id: 806cd47a577be4b35f7e355f96106367e73a32fe
  • Loading branch information
liligwu authored and facebook-github-bot committed Sep 6, 2024
1 parent 7c75838 commit 2841b03
Showing 1 changed file with 38 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual
#include "fbgemm_gpu/utils/ops_utils.h" // @manual
#include "fbgemm_gpu/utils/tensor_accessor.h" // @manual

#ifdef USE_ROCM
#include <rocprim/device/device_radix_sort.hpp>
#endif
// clang-format off
#include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" // @manual
#include <cub/device/device_radix_sort.cuh>
Expand Down Expand Up @@ -307,6 +309,7 @@ transpose_embedding_input(
}
{
size_t temp_storage_bytes = 0;
#ifndef USE_ROCM
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
nullptr,
Expand Down Expand Up @@ -336,6 +339,40 @@ transpose_embedding_input(
total_hash_size_bits,
at::cuda::getCurrentCUDAStream(),
false));
#else
using config = rocprim::radix_sort_config<
rocprim::default_config,
rocprim::default_config,
rocprim::default_config,
400000>;
rocprim::radix_sort_pairs<config>(
nullptr,
temp_storage_bytes,
linear_indices.data_ptr<index_t>(),
linear_indices_sorted.data_ptr<index_t>(),
infos.data_ptr<info_t>(),
infos_sorted.data_ptr<info_t>(),
linear_indices.numel(),
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream(),
false);
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));
rocprim::radix_sort_pairs<config>(
temp_storage.data_ptr(),
temp_storage_bytes,
linear_indices.data_ptr<index_t>(),
linear_indices_sorted.data_ptr<index_t>(),
infos.data_ptr<info_t>(),
infos_sorted.data_ptr<info_t>(),
linear_indices.numel(),
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream(),
false);
#endif
}
if (total_unique_indices != -1) {
TORCH_CHECK(total_unique_indices >= 0);
Expand Down

0 comments on commit 2841b03

Please sign in to comment.