diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
index 002418a4a0..58591d3c9d 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
@@ -11,7 +11,7 @@
 lib = torch.library.Library("fbgemm", "FRAGMENT")
 lib.define(
     """
-    get_unique_indices(
+    get_unique_indices_v2(
         Tensor linear_indices,
         int max_indices,
         bool compute_count=False,
@@ -21,28 +21,39 @@
 )
 
 
-@torch.library.impl(lib, "get_unique_indices", "CUDA")
-def get_unique_indices(
+@torch.library.impl(lib, "get_unique_indices_v2", "CUDA")
+def get_unique_indices_v2(
     linear_indices: torch.Tensor,
     max_indices: int,
     compute_count: bool = False,
     compute_inverse_indices: bool = False,
 ) -> Union[
     Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
-    Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
+    Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor],
+    ],
 ]:
     """
     A wrapper for get_unique_indices for overloading the return type
     based on inputs
     """
-    ret = torch.ops.fbgemm.get_unique_indices_internal(
+    ret = torch.ops.fbgemm.get_unique_indices_with_inverse(
         linear_indices,
         max_indices,
         compute_count,
         compute_inverse_indices,
     )
-    if not compute_inverse_indices:
-        # Return only 3 tensors
+    if compute_count and compute_inverse_indices:
+        # Return all tensors
+        return ret
+    if compute_count:
+        # Return (unique_indices, length, count)
         return ret[:-1]
-    # Return all tensors
-    return ret
+    if compute_inverse_indices:
+        # Return (unique_indices, length, inverse_indices)
+        return ret[0], ret[1], ret[3]
+    # Return (unique_indices, length)
+    return ret[:-2]
diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
index 537b33725d..d3d3d404aa 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -26,6 +26,14 @@ enum uvm_cache_stats_index {
 
 } // namespace fbgemm_gpu
 
+///@ingroup table-batched-embed-cuda
+/// Deduplicate indices.
+std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
+get_unique_indices_cuda(
+    const at::Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count);
+
 ///@ingroup table-batched-embed-cuda
 /// Deduplicate indices.
 std::tuple<
@@ -33,10 +41,10 @@ std::tuple<
     at::Tensor,
     c10::optional<at::Tensor>,
     c10::optional<at::Tensor>>
-get_unique_indices_cuda(
-    at::Tensor linear_indices,
-    int64_t max_indices,
-    bool compute_count,
+get_unique_indices_with_inverse_cuda(
+    const at::Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
     const bool compute_inverse_indices);
 
 ///@ingroup table-batched-embed-cuda
diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
index cb1675bc32..585702da18 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
@@ -271,16 +271,11 @@ DLL_PUBLIC void lfu_cache_populate_cuda(
   }
 
   // get unqiue indices
-  auto
-      [unique_indices,
-       unique_indices_length,
-       unique_indices_count,
-       linear_cache_indices_positions_sorted] =
-          get_unique_indices_cuda(
-              linear_cache_indices,
-              total_cache_hash_size,
-              /*compute_count=*/true,
-              /*compute_inverse_indices=*/false);
+  auto [unique_indices, unique_indices_length, unique_indices_count] =
+      get_unique_indices_cuda(
+          linear_cache_indices,
+          total_cache_hash_size,
+          /*compute_count=*/true);
 
   // update lfu counts
   lfu_update_counts_cuda(
diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
index 251dd1bf8a..1fa91c519e 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -240,16 +240,11 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda(
   }
 
   // get unqiue indices
-  auto
-      [unique_indices,
-       unique_indices_length,
-       unique_indices_count,
-       linear_indices_postions_sorted] =
-          get_unique_indices_cuda(
-              linear_cache_indices,
-              total_cache_hash_size,
-              /*compute_count=*/true,
-              /*compute_inverse_indices=*/false);
+  auto [unique_indices, unique_indices_length, unique_indices_count] =
+      get_unique_indices_cuda(
+          linear_cache_indices,
+          total_cache_hash_size,
+          /*compute_count=*/true);
 
   // update lfu counts
   lfu_update_counts_cuda(
diff --git a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
index e09af1462e..b391d44726 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
@@ -202,10 +202,10 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda(
 
 DLL_PUBLIC
 std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
-get_unique_indices_cuda(
-    Tensor linear_indices,
-    int64_t max_indices,
-    bool compute_count,
+get_unique_indices_cuda_impl(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
     const bool compute_inverse_indices) {
   TENSOR_ON_CUDA_GPU(linear_indices);
 
@@ -327,3 +327,28 @@ get_unique_indices_cuda(
 #undef INVOKE_CUB_ENCODE
 #undef INVOKE_CUB_UNIQUE
 }
+
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count) {
+  const auto ret = get_unique_indices_cuda_impl(
+      linear_indices,
+      max_indices,
+      compute_count,
+      /*compute_inverse_indices=*/false);
+
+  return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)};
+}
+
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
+get_unique_indices_with_inverse_cuda(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
+    const bool compute_inverse_indices) {
+  return get_unique_indices_cuda_impl(
+      linear_indices, max_indices, compute_count, compute_inverse_indices);
+}
diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu
index 72f586f9d1..0cf8586f28 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu
@@ -325,16 +325,11 @@ DLL_PUBLIC void lru_cache_populate_cuda(
   }
 
   // Get unqiue indices
-  auto
-      [unique_indices,
-       unique_indices_length,
-       unique_indices_count,
-       linear_cache_indices_positions_sorted] =
-          get_unique_indices_cuda(
-              linear_cache_indices,
-              total_cache_hash_size,
-              /*compute_count=*/false,
-              /*compute_inverse_indices=*/false);
+  auto [unique_indices, unique_indices_length, unique_indices_count] =
+      get_unique_indices_cuda(
+          linear_cache_indices,
+          total_cache_hash_size,
+          /*compute_count=*/false);
 
   auto
       [sorted_cache_sets,
diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
index cc90385961..e884aa4892 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
@@ -549,16 +549,11 @@ DLL_PUBLIC void lru_cache_populate_byte_cuda(
   }
 
   // Get unqiue indices
-  auto
-      [unique_indices,
-       unique_indices_length,
-       unique_indices_count,
-       linear_cache_indices_positions_sorted] =
-          get_unique_indices_cuda(
-              linear_cache_indices,
-              total_cache_hash_size,
-              /*compute_count=*/false,
-              /*compute_inverse_indices=*/false);
+  auto [unique_indices, unique_indices_length, unique_indices_count] =
+      get_unique_indices_cuda(
+          linear_cache_indices,
+          total_cache_hash_size,
+          /*compute_count=*/false);
 
   // Find uncached indices
   Tensor lxu_cache_locking_counter =
diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
index e049c2efd2..96e1380251 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
@@ -39,7 +39,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new, Tensor? num_uniq_cache_indices=None) -> ()");
   m.def(
-      "get_unique_indices_internal("
+      "get_unique_indices("
+      "    Tensor linear_indices, "
+      "    int max_indices, "
+      "    bool compute_count"
+      ") -> (Tensor, Tensor, Tensor?)");
+  m.def(
+      "get_unique_indices_with_inverse("
       "    Tensor linear_indices, "
       "    int max_indices, "
       "    bool compute_count, "
diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu
index b6f3504bff..71bdafd39f 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu
@@ -33,7 +33,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       lxu_cache_locking_counter_decrement_cuda);
   DISPATCH_TO_CUDA(
       "lxu_cache_locations_update", lxu_cache_locations_update_cuda);
-  DISPATCH_TO_CUDA("get_unique_indices_internal", get_unique_indices_cuda);
+  DISPATCH_TO_CUDA("get_unique_indices", get_unique_indices_cuda);
+  DISPATCH_TO_CUDA(
+      "get_unique_indices_with_inverse", get_unique_indices_with_inverse_cuda);
 }
 
 } // namespace
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
index 0ab8f6c0a2..3817aa7e16 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
@@ -273,7 +273,7 @@ ssd_cache_populate_actions_cuda(
        unique_indices_length,
        unique_indices_count,
        linear_index_inverse_indices] =
-          get_unique_indices_cuda(
+          get_unique_indices_with_inverse_cuda(
               linear_indices,
               total_hash_size,
               /*compute_count=*/true,
diff --git a/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json b/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json
index a18ec6b9e9..6e59307187 100644
--- a/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json
+++ b/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json
@@ -26,10 +26,6 @@
             "LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
                 "comment": "",
                 "status": "xfail"
-            },
-            "LXUCacheTest.test_schema__test_unique_lxu_cache_lookup": {
-                "comment": "",
-                "status": "xfail"
             }
         },
         "fbgemm::int_nbit_split_embedding_codegen_lookup_function": {