Skip to content

Commit

Permalink
Fix BC issues for get_unique_indices (#2602)
Browse files Browse the repository at this point in the history
Summary:

`get_unique_indices` can return a different number of tensors based on
the input arguments.  D55926421 attempted to overload the return type
of `get_unique_indices` by moving its definition and implementation
into Python.  However, that did not guarantee backward compatibility
(BC).  Thus, this diff fixes the BC issue by keeping the original
`get_unique_indices` definition and implementation in the backend
(C++) and define and implement `get_unique_indices_v2` in the frontend
(Python) for future usage.

Reviewed By: q10

Differential Revision: D57501865
  • Loading branch information
sryap authored and facebook-github-bot committed May 17, 2024
1 parent 1c0344f commit 48954a2
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 64 deletions.
29 changes: 20 additions & 9 deletions fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(
"""
get_unique_indices(
get_unique_indices_v2(
Tensor linear_indices,
int max_indices,
bool compute_count=False,
Expand All @@ -21,28 +21,39 @@
)


@torch.library.impl(lib, "get_unique_indices", "CUDA")
def get_unique_indices(
@torch.library.impl(lib, "get_unique_indices_v2", "CUDA")
def get_unique_indices_v2(
linear_indices: torch.Tensor,
max_indices: int,
compute_count: bool = False,
compute_inverse_indices: bool = False,
) -> Union[
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
Tuple[
torch.Tensor,
torch.Tensor,
Optional[torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
],
]:
"""
A wrapper for get_unique_indices for overloading the return type
based on inputs
"""
ret = torch.ops.fbgemm.get_unique_indices_internal(
ret = torch.ops.fbgemm.get_unique_indices_with_inverse(
linear_indices,
max_indices,
compute_count,
compute_inverse_indices,
)
if not compute_inverse_indices:
# Return only 3 tensors
if compute_count and compute_inverse_indices:
# Return all tensors
return ret
if compute_count:
# Return (unique_indices, length, count)
return ret[:-1]
# Return all tensors
return ret
if compute_inverse_indices:
# Return (unique_indices, length, inverse_indices)
return ret[0], ret[1], ret[3]
# Return (unique_indices, length)
return ret[:-2]
16 changes: 12 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,25 @@ enum uvm_cache_stats_index {

} // namespace fbgemm_gpu

///@ingroup table-batched-embed-cuda
/// Deduplicate indices.
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_unique_indices_cuda(
const at::Tensor& linear_indices,
const int64_t max_indices,
const bool compute_count);

///@ingroup table-batched-embed-cuda
/// Deduplicate indices.
std::tuple<
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>>
get_unique_indices_cuda(
at::Tensor linear_indices,
int64_t max_indices,
bool compute_count,
get_unique_indices_with_inverse_cuda(
const at::Tensor& linear_indices,
const int64_t max_indices,
const bool compute_count,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
Expand Down
15 changes: 5 additions & 10 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -271,16 +271,11 @@ DLL_PUBLIC void lfu_cache_populate_cuda(
}
// get unqiue indices
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_cache_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);
auto [unique_indices, unique_indices_length, unique_indices_count] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true);
// update lfu counts
lfu_update_counts_cuda(
Expand Down
15 changes: 5 additions & 10 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,11 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda(
}
// get unqiue indices
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_indices_postions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);
auto [unique_indices, unique_indices_length, unique_indices_count] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true);
// update lfu counts
lfu_update_counts_cuda(
Expand Down
33 changes: 29 additions & 4 deletions fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda(
DLL_PUBLIC
std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
get_unique_indices_cuda(
Tensor linear_indices,
int64_t max_indices,
bool compute_count,
get_unique_indices_cuda_impl(
const Tensor& linear_indices,
const int64_t max_indices,
const bool compute_count,
const bool compute_inverse_indices) {
TENSOR_ON_CUDA_GPU(linear_indices);
Expand Down Expand Up @@ -327,3 +327,28 @@ get_unique_indices_cuda(
#undef INVOKE_CUB_ENCODE
#undef INVOKE_CUB_UNIQUE
}
DLL_PUBLIC
std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(
const Tensor& linear_indices,
const int64_t max_indices,
const bool compute_count) {
const auto ret = get_unique_indices_cuda_impl(
linear_indices,
max_indices,
compute_count,
/*compute_inverse_indices=*/false);
return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)};
}
DLL_PUBLIC
std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
get_unique_indices_with_inverse_cuda(
const Tensor& linear_indices,
const int64_t max_indices,
const bool compute_count,
const bool compute_inverse_indices) {
return get_unique_indices_cuda_impl(
linear_indices, max_indices, compute_count, compute_inverse_indices);
}
15 changes: 5 additions & 10 deletions fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,11 @@ DLL_PUBLIC void lru_cache_populate_cuda(
}
// Get unqiue indices
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_cache_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/false,
/*compute_inverse_indices=*/false);
auto [unique_indices, unique_indices_length, unique_indices_count] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/false);
auto
[sorted_cache_sets,
Expand Down
15 changes: 5 additions & 10 deletions fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
Original file line number Diff line number Diff line change
Expand Up @@ -549,16 +549,11 @@ DLL_PUBLIC void lru_cache_populate_byte_cuda(
}
// Get unqiue indices
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_cache_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/false,
/*compute_inverse_indices=*/false);
auto [unique_indices, unique_indices_length, unique_indices_count] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/false);
// Find uncached indices
Tensor lxu_cache_locking_counter =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new, Tensor? num_uniq_cache_indices=None) -> ()");
m.def(
"get_unique_indices_internal("
"get_unique_indices("
" Tensor linear_indices, "
" int max_indices, "
" bool compute_count"
") -> (Tensor, Tensor, Tensor?)");
m.def(
"get_unique_indices_with_inverse("
" Tensor linear_indices, "
" int max_indices, "
" bool compute_count, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
lxu_cache_locking_counter_decrement_cuda);
DISPATCH_TO_CUDA(
"lxu_cache_locations_update", lxu_cache_locations_update_cuda);
DISPATCH_TO_CUDA("get_unique_indices_internal", get_unique_indices_cuda);
DISPATCH_TO_CUDA("get_unique_indices", get_unique_indices_cuda);
DISPATCH_TO_CUDA(
"get_unique_indices_with_inverse", get_unique_indices_with_inverse_cuda);
}

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ ssd_cache_populate_actions_cuda(
unique_indices_length,
unique_indices_count,
linear_index_inverse_indices] =
get_unique_indices_cuda(
get_unique_indices_with_inverse_cuda(
linear_indices,
total_hash_size,
/*compute_count=*/true,
Expand Down
4 changes: 0 additions & 4 deletions fbgemm_gpu/test/tbe/cache/failures_dict_fast.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
"LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
"comment": "",
"status": "xfail"
},
"LXUCacheTest.test_schema__test_unique_lxu_cache_lookup": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::int_nbit_split_embedding_codegen_lookup_function": {
Expand Down

0 comments on commit 48954a2

Please sign in to comment.