diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index 7db295f64..05d4112ba 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -251,3 +251,34 @@ def bounds_check_indices(
     max_B: int = -1,
 ) -> None:
     pass
+
+
+@impl_abstract("fbgemm::permute_sparse_features")
+def permute_sparse_features_abstract(
+    permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    """Abstract implementation registered for ``fbgemm::permute_sparse_features``.
+
+    Checks dtype/device agreement of the inputs and returns empty outputs:
+    lengths of shape ``(permute.numel(), lengths.size(1))``, and indices
+    (plus weights, when given) sharing one dynamic size obtained from
+    ``new_dynamic_size()``. No values are computed.
+    """
+    torch._check(lengths.dtype == indices.dtype)
+    torch._check(permute.device == lengths.device)
+    torch._check(permute.device == indices.device)
+    if weights is not None:
+        torch._check(permute.device == weights.device)
+    num_output_features = permute.numel()
+    B = lengths.size(1)
+    permuted_lengths = lengths.new_empty(num_output_features, B)
+    output_size = torch.library.get_ctx().new_dynamic_size()
+    # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
+    # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
+    permuted_indices = indices.new_empty(output_size)
+    permuted_weights = None
+    if weights is not None:
+        # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
+        # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
+        permuted_weights = weights.new_empty(output_size)
+    return (permuted_lengths, permuted_indices, permuted_weights)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 889082884..efdac1385 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2738,7 +2738,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "lengths_range_out(Tensor output, Tensor t_in, SymInt[]? shape=None) -> Tensor");
   m.def(
-      "permute_sparse_features(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None) -> (Tensor, Tensor, Tensor?)");
+      "permute_sparse_features(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None) -> (Tensor, Tensor, Tensor?)",
+      {PT2_COMPLIANT_TAG});
   m.def("Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
   m.def("FloatToBfloat16Quantized(Tensor input) -> Tensor");
   m.def(
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 5e0fed050..d1f810bc6 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -15,7 +15,7 @@
 import random
 import unittest
 from itertools import accumulate
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import fbgemm_gpu
 
@@ -2402,6 +2402,181 @@ def validate(
             "grad",
         )
 
+    def permute_sparse_features_ref_(
+        self,
+        lengths: torch.Tensor,
+        indices: torch.Tensor,
+        weights: Optional[torch.Tensor],
+        permute: torch.LongTensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Pure-PyTorch reference for permuting sparse features.
+
+        ``lengths`` is viewed as ``(T, B)``. Each row sum is treated as that
+        feature's segment length in ``indices``/``weights``, and segment
+        starts come from ``asynchronous_exclusive_cumsum`` over those sums.
+        Output row ``i`` is input row ``permute[i]``, with the matching
+        segments concatenated in that order. The returned weights are
+        ``None`` when ``weights`` is ``None``.
+
+        NOTE(review): the tests in this change call ``permute_indices_ref_``,
+        not this helper -- confirm which reference is intended.
+        """
+        T = lengths.size(0)
+        B = lengths.size(1)
+        permuted_lengths = torch.index_select(lengths.view(T, B), 0, permute)
+
+        original_segment_lengths = lengths.view(T, B).sum(dim=1, dtype=torch.int32)
+        original_segment_start = torch.ops.fbgemm.asynchronous_exclusive_cumsum(
+            original_segment_lengths.view(-1)
+        )
+
+        permuted_indices = []
+        permuted_weights = []
+        for i in range(permute.size(0)):
+            start = original_segment_start[permute[i]]
+            end = start + original_segment_lengths[permute[i]]
+            permuted_indices.append(indices[start:end])
+            if weights is not None:
+                permuted_weights.append(weights[start:end])
+
+        permuted_indices = torch.cat(permuted_indices, dim=0).flatten()
+
+        if weights is None:
+            permuted_weights = None
+        else:
+            permuted_weights = torch.cat(permuted_weights, dim=0).flatten()
+
+        return permuted_lengths, permuted_indices, permuted_weights
+
+    @given(
+        B=st.integers(min_value=1, max_value=20),
+        T=st.integers(min_value=1, max_value=20),
+        L=st.integers(min_value=2, max_value=20),
+        long_index=st.booleans(),
+        has_weight=st.booleans(),
+    )
+    @settings(max_examples=20, deadline=None)
+    def test_permute_sparse_features(
+        self, B: int, T: int, L: int, long_index: bool, has_weight: bool
+    ) -> None:
+        """Check the op against the reference on CPU, and GPU against CPU."""
+        index_dtype = torch.int64 if long_index else torch.int32
+        lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
+        weights = torch.rand(int(lengths.sum().item())).float() if has_weight else None
+        indices = torch.randint(
+            low=1,
+            high=int(1e5),
+            size=cast(Tuple[int, ...], (lengths.sum().item(),)),
+        ).type(index_dtype)
+        permute_list = list(range(T))
+        random.shuffle(permute_list)
+        permute = torch.IntTensor(permute_list)
+
+        (
+            permuted_lengths_cpu,
+            permuted_indices_cpu,
+            permuted_weights_cpu,
+        ) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, weights)
+        (
+            permuted_lengths_ref,
+            permuted_indices_ref,
+            permuted_weights_ref,
+            # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`.
+        ) = self.permute_indices_ref_(lengths, indices, weights, permute.long())
+        torch.testing.assert_close(permuted_indices_cpu, permuted_indices_ref)
+        torch.testing.assert_close(permuted_lengths_cpu, permuted_lengths_ref)
+        if has_weight:
+            torch.testing.assert_close(permuted_weights_cpu, permuted_weights_ref)
+        else:
+            assert permuted_weights_cpu is None and permuted_weights_ref is None
+
+        if gpu_available:
+            (
+                permuted_lengths_gpu,
+                permuted_indices_gpu,
+                permuted_weights_gpu,
+            ) = torch.ops.fbgemm.permute_sparse_features(
+                permute.cuda(),
+                lengths.cuda(),
+                indices.cuda(),
+                weights.cuda() if has_weight and weights is not None else None,
+            )
+            torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
+            torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
+            if has_weight:
+                torch.testing.assert_close(
+                    permuted_weights_gpu.cpu(), permuted_weights_cpu
+                )
+            else:
+                assert permuted_weights_gpu is None
+
+    @given(
+        B=st.integers(min_value=1, max_value=20),
+        T=st.integers(min_value=1, max_value=20),
+        L=st.integers(min_value=2, max_value=20),
+        long_index=st.booleans(),
+        has_weight=st.booleans(),
+    )
+    @settings(max_examples=20, deadline=None)
+    def test_permute_sparse_features_with_repeats(
+        self, B: int, T: int, L: int, long_index: bool, has_weight: bool
+    ) -> None:
+        """Check op vs. reference/GPU when ``permute`` may repeat features."""
+        index_dtype = torch.int64 if long_index else torch.int32
+        lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
+        weights = torch.rand(int(lengths.sum().item())).float() if has_weight else None
+        indices = torch.randint(
+            low=1,
+            high=int(1e5),
+            size=cast(Tuple[int, ...], (lengths.sum().item(),)),
+        ).type(index_dtype)
+        permute_list = list(range(T))
+
+        num_repeats = random.randint(0, T)
+        for _ in range(num_repeats):
+            permute_list.append(random.randint(0, T - 1))
+
+        random.shuffle(permute_list)
+        permute = torch.IntTensor(permute_list)
+
+        (
+            permuted_lengths_cpu,
+            permuted_indices_cpu,
+            permuted_weights_cpu,
+        ) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, weights)
+        (
+            permuted_lengths_ref,
+            permuted_indices_ref,
+            permuted_weights_ref,
+            # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`.
+        ) = self.permute_indices_ref_(lengths, indices, weights, permute.long())
+        torch.testing.assert_close(permuted_indices_cpu, permuted_indices_ref)
+        torch.testing.assert_close(permuted_lengths_cpu, permuted_lengths_ref)
+        if has_weight:
+            torch.testing.assert_close(permuted_weights_cpu, permuted_weights_ref)
+        else:
+            assert permuted_weights_cpu is None and permuted_weights_ref is None
+
+        if gpu_available:
+            (
+                permuted_lengths_gpu,
+                permuted_indices_gpu,
+                permuted_weights_gpu,
+            ) = torch.ops.fbgemm.permute_sparse_features(
+                permute.cuda(),
+                lengths.cuda(),
+                indices.cuda(),
+                weights.cuda() if has_weight and weights is not None else None,
+            )
+            torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
+            torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
+            if has_weight:
+                torch.testing.assert_close(
+                    permuted_weights_gpu.cpu(), permuted_weights_cpu
+                )
+            else:
+                assert permuted_weights_gpu is None
+
     failures_dict_path: str = get_file_path_2(
         "", os.path.dirname(__file__), "failures_dict.json"