diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 0bbd8a40f..73fa70dd4 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -595,7 +595,7 @@ if(NOT FBGEMM_CPU_ONLY)
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
     src/quantize_ops/quantize_ops_gpu.cpp
     src/sparse_ops/sparse_ops_gpu.cpp
-    src/split_embeddings_utils.cpp
+    src/split_embeddings_utils/split_embeddings_utils.cpp
     src/split_embeddings_cache/split_embeddings_cache_ops.cu
     src/metric_ops/metric_ops_host.cpp
     src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
@@ -692,7 +692,10 @@ if(NOT FBGEMM_CPU_ONLY)
     src/split_embeddings_cache/lxu_cache.cu
     src/split_embeddings_cache/linearize_cache_indices.cu
     src/split_embeddings_cache/reset_weight_momentum.cu
-    src/split_embeddings_utils.cu)
+    src/split_embeddings_utils/generate_vbe_metadata.cu
+    src/split_embeddings_utils/get_infos_metadata.cu
+    src/split_embeddings_utils/radix_sort_pairs.cu
+    src/split_embeddings_utils/transpose_embedding_input.cu)
 
   set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
                               PROPERTIES COMPILE_OPTIONS
diff --git a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu
new file mode 100644
index 000000000..fe65a8d95
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu
@@ -0,0 +1,156 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+#include "fbgemm_gpu/split_embeddings_utils.cuh"
+
+using Tensor = at::Tensor;
+using namespace fbgemm_gpu;
+
+__global__
+__launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel(
+    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        row_output_offsets,
+    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        B_offsets,
+    const at::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits>
+        B_offsets_rank_per_feature,
+    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        output_offsets_feature_rank,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        D_offsets,
+    const int32_t D,
+    const bool nobag,
+    FixedDivisor fd_max_B,
+    FixedDivisor fd_max_B_T,
+    const int32_t info_B_num_bits) {
+  const auto r_b_t = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto T = B_offsets.size(0) - 1; // Num tables
+  const auto R = B_offsets_rank_per_feature.size(1) - 1; // Num ranks
+
+  int32_t b_t;
+  int32_t r; // Rank ID
+  int32_t t; // Table ID
+  int32_t b; // Relative sample ID in the rank-table matrix
+
+  fd_max_B_T.DivMod(r_b_t, &r, &b_t);
+  if (r >= R) {
+    return;
+  }
+
+  fd_max_B.DivMod(b_t, &t, &b);
+  if (t >= T) {
+    return;
+  }
+
+  const auto B_start_r_t = B_offsets_rank_per_feature[t][r];
+  const auto B_r_t = B_offsets_rank_per_feature[t][r + 1] - B_start_r_t;
+  if (b >= B_r_t) {
+    return;
+  }
+
+  const auto B_start_t = B_offsets[t];
+  // Update b_t
+  b_t = B_start_t + B_start_r_t + b;
+  const auto D_ = nobag ? D : D_offsets[t + 1] - D_offsets[t];
+  row_output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_;
+
+  // Relative sample ID in the table
+  const auto b_ = B_start_r_t + b;
+  // b_t is always positive.
+  *reinterpret_cast<uint32_t*>(&b_t_map[b_t]) =
+      (reinterpret_cast<uint32_t*>(&t)[0] << info_B_num_bits) |
+      reinterpret_cast<uint32_t*>(&b_)[0];
+}
+
+/// Generate VBE metadata namely output_offsets and b_t_map
+///
+/// row_output_offsets A 1D tensor that contains the output offset of each b
+///                    (sample) and t (feature/table) pair. The output
+///                    serializes O_r_t where O_r_t is the local output of rank
+///                    r and feature/table t (t is the fastest moving index).
+/// b_t_map            A 1D tensor that contains the b and t information of the
+///                    linearized b and t (b is the fastest moving index).
+///
+/// @param B_offsets                   Batch size offsets for all features.
+/// @param B_offsets_rank_per_feature  Batch size offsets for all ranks (GPUs)
+///                                    for each feature.
+/// @param output_offsets_feature_rank Output offsets for all features and ranks
+///                                    and features.
+/// @param D_offsets                   Embedding dimension offsets. Required if
+///                                    nobag is false.
+/// @param D                           The embedding dimension. Required if
+///                                    nobag is true.
+/// @param nobag                       A boolean to indicate if TBE is pooled
+///                                    (false) or sequence (true).
+/// @param info_B_num_bits             The number of bits used to encode a
+///                                    sample ID. (Used for populating b_t_map).
+/// @param total_B                     The total number of samples (i.e., the
+///                                    total number of b and t pairs).
+DLL_PUBLIC std::tuple<Tensor, Tensor>
+generate_vbe_metadata(
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& output_offsets_feature_rank,
+    const Tensor& D_offsets,
+    const int64_t D,
+    const bool nobag,
+    const int64_t max_B_feature_rank,
+    const int64_t info_B_num_bits,
+    const int64_t total_B) {
+  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
+      B_offsets, B_offsets_rank_per_feature, output_offsets_feature_rank);
+
+  TENSOR_NDIM_EQUALS(B_offsets, 1);
+  TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2);
+  TENSOR_NDIM_EQUALS(output_offsets_feature_rank, 1);
+
+  const int32_t T = B_offsets.numel() - 1;
+  if (!nobag) {
+    TENSOR_ON_CUDA_GPU(D_offsets);
+    TENSORS_ON_SAME_DEVICE(B_offsets, D_offsets);
+    TORCH_CHECK(D_offsets.numel() == T + 1)
+  }
+
+  const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1;
+  TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T);
+  TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1);
+
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(B_offsets.get_device());
+
+  Tensor row_output_offsets =
+      at::empty({total_B}, output_offsets_feature_rank.options());
+  Tensor b_t_map = at::empty({total_B}, B_offsets.options());
+
+  // Over allocate total number of threads to avoid using binary search
+  generate_vbe_metadata_foreach_sample_kernel<<<
+      div_round_up(max_B_feature_rank * T * num_ranks, kMaxThreads),
+      kMaxThreads,
+      0,
+      at::cuda::getCurrentCUDAStream()>>>(
+      row_output_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      b_t_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      B_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      B_offsets_rank_per_feature
+          .packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
+      output_offsets_feature_rank
+          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      D,
+      nobag,
+      FixedDivisor(max_B_feature_rank),
+      FixedDivisor(max_B_feature_rank * T),
+      info_B_num_bits);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return {row_output_offsets, b_t_map};
+}
diff --git a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
new file mode 100644
index 000000000..e8fdcc141
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
+#include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/split_embeddings_utils.cuh"
+
+using Tensor = at::Tensor;
+using namespace fbgemm_gpu;
+
+DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
+    int32_t B,
+    int32_t T) {
+  int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
+  uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
+  uint32_t max_T = MAX_T;
+  uint32_t max_B = MAX_B;
+  bool invalid_T = T > max_T;
+  bool invalid_B = B > max_B;
+
+  TORCH_CHECK(
+      !(invalid_T && invalid_B),
+      "Not enough infos bits to accommodate T and B. Default num bits = ",
+      DEFAULT_INFO_NUM_BITS);
+
+  if (invalid_T) {
+    // Reduce info_B_num_bits
+    while (invalid_T && !invalid_B && info_B_num_bits > 0) {
+      info_B_num_bits--;
+      max_T = ((max_T + 1) << 1) - 1;
+      max_B = ((max_B + 1) >> 1) - 1;
+      invalid_T = T > max_T;
+      invalid_B = B > max_B;
+    }
+  } else if (invalid_B) {
+    // Increase info_B_num_bits
+    while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
+      info_B_num_bits++;
+      max_T = ((max_T + 1) >> 1) - 1;
+      max_B = ((max_B + 1) << 1) - 1;
+      invalid_T = T > max_T;
+      invalid_B = B > max_B;
+    }
+  }
+
+  TORCH_CHECK(
+      !invalid_T && !invalid_B,
+      "Not enough infos bits to accommodate T and B. Default num bits = ",
+      DEFAULT_INFO_NUM_BITS);
+
+  // Recompute info_B_mask using new info_B_num_bits
+  info_B_mask = (1u << info_B_num_bits) - 1;
+
+  return {info_B_num_bits, info_B_mask};
+}
+
+std::tuple<int64_t, int64_t> DLL_PUBLIC
+get_infos_metadata(Tensor unused, int64_t B, int64_t T) {
+  return adjust_info_B_num_bits(B, T);
+}
diff --git a/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu b/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu
new file mode 100644
index 000000000..121f66625
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include "fbgemm_gpu/split_embeddings_utils.cuh" + +#include +#include +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" +#include "fbgemm_gpu/ops_utils.h" + +// clang-format off +#include "fbgemm_gpu/cub_namespace_prefix.cuh" +#include +#include +#include +#include "fbgemm_gpu/cub_namespace_postfix.cuh" +// clang-format on + +using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \ + DLL_PUBLIC cudaError_t radix_sort_pairs( \ + void* d_temp_storage, \ + size_t& temp_storage_bytes, \ + const KeyT* d_keys_in, \ + KeyT* d_keys_out, \ + const ValueT* d_values_in, \ + ValueT* d_values_out, \ + const int num_items, \ + const int begin_bit, \ + const int end_bit, \ + cudaStream_t stream) { \ + return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ + d_temp_storage, \ + temp_storage_bytes, \ + d_keys_in, \ + d_keys_out, \ + d_values_in, \ + d_values_out, \ + num_items, \ + begin_bit, \ + end_bit, \ + stream); \ + } +#else +#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \ + DLL_PUBLIC cudaError_t radix_sort_pairs( \ + void* d_temp_storage, \ + size_t& temp_storage_bytes, \ + const KeyT* d_keys_in, \ + KeyT* d_keys_out, \ + const ValueT* d_values_in, \ + ValueT* d_values_out, \ + const int num_items, \ + const int begin_bit, \ + const int end_bit, \ + cudaStream_t stream) { \ + return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ + d_temp_storage, \ + temp_storage_bytes, \ + d_keys_in, \ + d_keys_out, \ + d_values_in, \ + d_values_out, \ + num_items, \ + begin_bit, \ + end_bit, \ + stream, \ + false); \ + } +#endif + +DEF_RADIX_SORT_PAIRS_FN(int64_t, float); +DEF_RADIX_SORT_PAIRS_FN(int64_t, double); +DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t); +DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t); diff --git a/fbgemm_gpu/src/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp similarity index 100% rename from fbgemm_gpu/src/split_embeddings_utils.cpp rename to fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp diff --git a/fbgemm_gpu/src/split_embeddings_utils.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu similarity index 60% rename from fbgemm_gpu/src/split_embeddings_utils.cu rename to fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index dd5c0ec70..d1eb5e00a 100644 --- a/fbgemm_gpu/src/split_embeddings_utils.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -6,12 +6,9 @@ * LICENSE file in the root directory of this source tree. 
 */
 
-#include "fbgemm_gpu/split_embeddings_utils.cuh"
-
-#include
-#include
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/split_embeddings_utils.cuh"
 
 // clang-format off
 #include "fbgemm_gpu/cub_namespace_prefix.cuh"
@@ -21,9 +18,8 @@
 #include "fbgemm_gpu/cub_namespace_postfix.cuh"
 // clang-format on
 
-#ifdef USE_ROCM
-#include
-#endif
+using Tensor = at::Tensor;
+using namespace fbgemm_gpu;
 
 inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
   at::cuda::OptionalCUDAGuard device_guard;
@@ -62,10 +58,6 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
   return t_out;
 }
 
-using Tensor = at::Tensor;
-
-using namespace fbgemm_gpu;
-
 template
 __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
     const at::PackedTensorAccessor32
@@ -394,253 +386,3 @@ transpose_embedding_input(
       sorted_linear_indices_num_runs,
       sorted_linear_indices_cumulative_run_lengths};
 }
-
-std::tuple<int64_t, int64_t>
-get_infos_metadata(Tensor unused, int64_t B, int64_t T) {
-  return adjust_info_B_num_bits(B, T);
-}
-
-DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
-    int32_t B,
-    int32_t T) {
-  int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
-  uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
-  uint32_t max_T = MAX_T;
-  uint32_t max_B = MAX_B;
-  bool invalid_T = T > max_T;
-  bool invalid_B = B > max_B;
-
-  TORCH_CHECK(
-      !(invalid_T && invalid_B),
-      "Not enough infos bits to accommodate T and B. Default num bits = ",
-      DEFAULT_INFO_NUM_BITS);
-
-  if (invalid_T) {
-    // Reduce info_B_num_bits
-    while (invalid_T && !invalid_B && info_B_num_bits > 0) {
-      info_B_num_bits--;
-      max_T = ((max_T + 1) << 1) - 1;
-      max_B = ((max_B + 1) >> 1) - 1;
-      invalid_T = T > max_T;
-      invalid_B = B > max_B;
-    }
-  } else if (invalid_B) {
-    // Increase info_B_num_bits
-    while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
-      info_B_num_bits++;
-      max_T = ((max_T + 1) >> 1) - 1;
-      max_B = ((max_B + 1) << 1) - 1;
-      invalid_T = T > max_T;
-      invalid_B = B > max_B;
-    }
-  }
-
-  TORCH_CHECK(
-      !invalid_T && !invalid_B,
-      "Not enough infos bits to accommodate T and B. Default num bits = ",
-      DEFAULT_INFO_NUM_BITS);
-
-  // Recompute info_B_mask using new info_B_num_bits
-  info_B_mask = (1u << info_B_num_bits) - 1;
-
-  return {info_B_num_bits, info_B_mask};
-}
-
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
-#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT)                        \
-  DLL_PUBLIC cudaError_t radix_sort_pairs(                           \
-      void* d_temp_storage,                                          \
-      size_t& temp_storage_bytes,                                    \
-      const KeyT* d_keys_in,                                         \
-      KeyT* d_keys_out,                                              \
-      const ValueT* d_values_in,                                     \
-      ValueT* d_values_out,                                          \
-      const int num_items,                                           \
-      const int begin_bit,                                           \
-      const int end_bit,                                             \
-      cudaStream_t stream) {                                         \
-    return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
-        d_temp_storage,                                              \
-        temp_storage_bytes,                                          \
-        d_keys_in,                                                   \
-        d_keys_out,                                                  \
-        d_values_in,                                                 \
-        d_values_out,                                                \
-        num_items,                                                   \
-        begin_bit,                                                   \
-        end_bit,                                                     \
-        stream);                                                     \
-  }
-#else
-#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT)                        \
-  DLL_PUBLIC cudaError_t radix_sort_pairs(                           \
-      void* d_temp_storage,                                          \
-      size_t& temp_storage_bytes,                                    \
-      const KeyT* d_keys_in,                                         \
-      KeyT* d_keys_out,                                              \
-      const ValueT* d_values_in,                                     \
-      ValueT* d_values_out,                                          \
-      const int num_items,                                           \
-      const int begin_bit,                                           \
-      const int end_bit,                                             \
-      cudaStream_t stream) {                                         \
-    return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
-        d_temp_storage,                                              \
-        temp_storage_bytes,                                          \
-        d_keys_in,                                                   \
-        d_keys_out,                                                  \
-        d_values_in,                                                 \
-        d_values_out,                                                \
-        num_items,                                                   \
-        begin_bit,                                                   \
-        end_bit,                                                     \
-        stream,                                                      \
-        false);                                                      \
-  }
-#endif
-
-DEF_RADIX_SORT_PAIRS_FN(int64_t, float);
-DEF_RADIX_SORT_PAIRS_FN(int64_t, double);
-DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
-DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
-
-__global__
-__launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel(
-    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
-        row_output_offsets,
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
-        B_offsets,
-    const at::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits>
-        B_offsets_rank_per_feature,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
-        output_offsets_feature_rank,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
-        D_offsets,
-    const int32_t D,
-    const bool nobag,
-    FixedDivisor fd_max_B,
-    FixedDivisor fd_max_B_T,
-    const int32_t info_B_num_bits) {
-  const auto r_b_t = blockIdx.x * blockDim.x + threadIdx.x;
-  const auto T = B_offsets.size(0) - 1; // Num tables
-  const auto R = B_offsets_rank_per_feature.size(1) - 1; // Num ranks
-
-  int32_t b_t;
-  int32_t r; // Rank ID
-  int32_t t; // Table ID
-  int32_t b; // Relative sample ID in the rank-table matrix
-
-  fd_max_B_T.DivMod(r_b_t, &r, &b_t);
-  if (r >= R) {
-    return;
-  }
-
-  fd_max_B.DivMod(b_t, &t, &b);
-  if (t >= T) {
-    return;
-  }
-
-  const auto B_start_r_t = B_offsets_rank_per_feature[t][r];
-  const auto B_r_t = B_offsets_rank_per_feature[t][r + 1] - B_start_r_t;
-  if (b >= B_r_t) {
-    return;
-  }
-
-  const auto B_start_t = B_offsets[t];
-  // Update b_t
-  b_t = B_start_t + B_start_r_t + b;
-  const auto D_ = nobag ? D : D_offsets[t + 1] - D_offsets[t];
-  row_output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_;
-
-  // Relative sample ID in the table
-  const auto b_ = B_start_r_t + b;
-  // b_t is always positive.
-  *reinterpret_cast<uint32_t*>(&b_t_map[b_t]) =
-      (reinterpret_cast<uint32_t*>(&t)[0] << info_B_num_bits) |
-      reinterpret_cast<uint32_t*>(&b_)[0];
-}
-
-/// Generate VBE metadata namely output_offsets and b_t_map
-///
-/// row_output_offsets A 1D tensor that contains the output offset of each b
-///                    (sample) and t (feature/table) pair. The output
-///                    serializes O_r_t where O_r_t is the local output of rank
-///                    r and feature/table t (t is the fastest moving index).
-/// b_t_map            A 1D tensor that contains the b and t information of the
-///                    linearized b and t (b is the fastest moving index).
-///
-/// @param B_offsets                   Batch size offsets for all features.
-/// @param B_offsets_rank_per_feature  Batch size offsets for all ranks (GPUs)
-///                                    for each feature.
-/// @param output_offsets_feature_rank Output offsets for all features and ranks
-///                                    and features.
-/// @param D_offsets                   Embedding dimension offsets. Required if
-///                                    nobag is false.
-/// @param D                           The embedding dimension. Required if
-///                                    nobag is true.
-/// @param nobag                       A boolean to indicate if TBE is pooled
-///                                    (false) or sequence (true).
-/// @param info_B_num_bits             The number of bits used to encode a
-///                                    sample ID. (Used for populating b_t_map).
-/// @param total_B                     The total number of samples (i.e., the
-///                                    total number of b and t pairs).
-DLL_PUBLIC std::tuple<Tensor, Tensor>
-generate_vbe_metadata(
-    const Tensor& B_offsets,
-    const Tensor& B_offsets_rank_per_feature,
-    const Tensor& output_offsets_feature_rank,
-    const Tensor& D_offsets,
-    const int64_t D,
-    const bool nobag,
-    const int64_t max_B_feature_rank,
-    const int64_t info_B_num_bits,
-    const int64_t total_B) {
-  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      B_offsets, B_offsets_rank_per_feature, output_offsets_feature_rank);
-
-  TENSOR_NDIM_EQUALS(B_offsets, 1);
-  TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2);
-  TENSOR_NDIM_EQUALS(output_offsets_feature_rank, 1);
-
-  const int32_t T = B_offsets.numel() - 1;
-  if (!nobag) {
-    TENSOR_ON_CUDA_GPU(D_offsets);
-    TENSORS_ON_SAME_DEVICE(B_offsets, D_offsets);
-    TORCH_CHECK(D_offsets.numel() == T + 1)
-  }
-
-  const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1;
-  TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T);
-  TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1);
-
-  at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(B_offsets.get_device());
-
-  Tensor row_output_offsets =
-      at::empty({total_B}, output_offsets_feature_rank.options());
-  Tensor b_t_map = at::empty({total_B}, B_offsets.options());
-
-  // Over allocate total number of threads to avoid using binary search
-  generate_vbe_metadata_foreach_sample_kernel<<<
-      div_round_up(max_B_feature_rank * T * num_ranks, kMaxThreads),
-      kMaxThreads,
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      row_output_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-      b_t_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      B_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      B_offsets_rank_per_feature
-          .packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
-      output_offsets_feature_rank
-          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-      D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      D,
-      nobag,
-      FixedDivisor(max_B_feature_rank),
-      FixedDivisor(max_B_feature_rank * T),
-      info_B_num_bits);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-
-  return {row_output_offsets, b_t_map};
-}
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
similarity index 100%
rename from fbgemm_gpu/src/ssd_split_embeddings_cache_cuda.cu
rename to fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
diff --git a/fbgemm_gpu/src/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
similarity index 100%
rename from fbgemm_gpu/src/ssd_split_table_batched_embeddings.cpp
rename to fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
diff --git a/fbgemm_gpu/src/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
similarity index 100%
rename from fbgemm_gpu/src/ssd_table_batched_embeddings.h
rename to fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
diff --git a/fbgemm_gpu/test/split_embeddings_utils_test.py b/fbgemm_gpu/test/split_embeddings_utils_test.py
index 19bb14fff..bbf0960e5 100644
--- a/fbgemm_gpu/test/split_embeddings_utils_test.py
+++ b/fbgemm_gpu/test/split_embeddings_utils_test.py
@@ -13,15 +13,12 @@
 import hypothesis.strategies as st
 
 import torch
-
+from fbgemm_gpu import sparse_ops  # noqa: F401
 from hypothesis import given, HealthCheck, settings
 
 try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
     from test_utils import gpu_unavailable  # pyre-ignore[21]
 except Exception:
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    from fbgemm_gpu.test.test_utils import gpu_unavailable
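Reference sketch (not part of the patch): a minimal Python example of how the relocated utilities are reached after this reorganization, assuming the open-source fbgemm_gpu package is installed (importing it loads the torch.ops.fbgemm registrations kept in split_embeddings_utils.cpp) and that the op schemas are unchanged by the file move. The dummy tensor only supplies the dispatch device; B=128 and T=4 are arbitrary example values.

import torch
import fbgemm_gpu  # noqa: F401  # loads the torch.ops.fbgemm.* operator registrations

# get_infos_metadata (backed by adjust_info_B_num_bits in get_infos_metadata.cu)
# decides how many low-order bits of the packed 32-bit info word encode the
# sample id b, leaving the remaining bits for the table id t, given B and T.
dummy = torch.empty(0, device="cuda")
info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(dummy, 128, 4)
print(info_B_num_bits, hex(info_B_mask))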