diff --git a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu index 9037b7c09..c899bbf9b 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu @@ -62,18 +62,21 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( CUDA_DEVICE_GUARD(data); + const auto data_contig = data.expect_contiguous(); + Tensor unpacked_tensor; // The output tensor AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "unpack_segments_cuda", [&] { const auto* const lengths_data = lengths.data_ptr(); // Create output tensor of appropriate dimensions - auto shape = data.sizes().vec(); + auto shape = data_contig->sizes().vec(); shape.erase(shape.begin()); shape[0] = total_length; - unpacked_tensor = at::empty(shape, data.options()); + unpacked_tensor = at::empty(shape, data_contig->options()); - if (!(data.size(0) && data.size(1))) { // TODO: What does this mean? + if (!(data_contig->size(0) && + data_contig->size(1))) { // TODO: What does this mean? return; } @@ -82,10 +85,11 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( auto lps_data = lengths_prefix_sum.data_ptr(); FBGEMM_DISPATCH_ALL_TYPES( - data.scalar_type(), "unpack_segments_cuda-unpacking", [&] { + data_contig->scalar_type(), "unpack_segments_cuda-unpacking", [&] { const auto num_seq = lengths.size(0); - const auto cell_size = data.numel() / (data.size(0) * data.size(1)); - const auto* const data_ptr = data.data_ptr(); + const auto cell_size = data_contig->numel() / + (data_contig->size(0) * data_contig->size(1)); + const auto* const data_ptr = data_contig->data_ptr(); auto* const out_data = unpacked_tensor.data_ptr(); unpack_segments_cuda_kernel diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 40cfacc06..fb2fcf85e 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -2,6 +2,16 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { + "fb::pack_segments": { + "PackedSegmentsTest.test_aot_dispatch_dynamic__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + }, + "PackedSegmentsTest.test_faketensor__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + } + }, "fbgemm::asynchronous_complete_cumsum": {}, "fbgemm::asynchronous_exclusive_cumsum": {}, "fbgemm::asynchronous_inclusive_cumsum": {}, diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index c6383017a..26ba30407 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -22,9 +22,9 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available + from test_utils import gpu_available, gpu_unavailable else: - from fbgemm_gpu.test.test_utils import gpu_available + from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable def get_n_rand_num_summing_to_k(n: int, k: int) -> np.ndarray: @@ -46,6 +46,15 @@ def get_n_rand_num_summing_to_k(n: int, k: int) -> np.ndarray: # pyre-fixme[2] # pyre-fixme[24] def torch_compiled(model: Callable, **kwargs) -> Callable: + """A helper function to apply torch.compile if python < 3.12. + + Args: + model: The model to be compiled. + kwargs: The arguments to be passed to torch.compile. + + Returns: + The model. + """ if sys.version_info < (3, 12, 0): return torch.compile(model, **kwargs) else: @@ -59,6 +68,17 @@ def _pack_segments_ref( tensor: torch.Tensor, max_length: Optional[int] = None, ) -> np.ndarray: + """ + This function is a feference implementation of pack_segments. + + Args: + lengths (Tensor): The lengths of tensor. + tensor (Tensor): The tensor to be packed. + max_length (Optional[int]): The maximum length of the packed tensor. + + Returns: + The packed tensor. + """ lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length @@ -104,6 +124,22 @@ def test_pack_segments( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests pack_segments ops compared to the reference implementation. + Both CPU and GPU (if available) are tested. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + dtype - The data type + torch_compile - Whether to use torch.compile + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True) lengths = torch.tensor( @@ -206,6 +242,23 @@ def test_pack_segments_smaller_max_len( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests pack_segments ops with set max_length + Both CPU and GPU (if available) are tested. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + max_length - The maximum length of the packed tensor + dtype - The data type + torch_compile - Whether to use torch.compile + + Returns: + None + """ + input_data = torch.tensor(np.random.rand(batch_size, n, k), dtype=dtype) lengths = torch.tensor( get_n_rand_num_summing_to_k(divisions, batch_size), dtype=torch.int @@ -259,6 +312,20 @@ def test_pack_segments_meta_backend( divisions: int, dtype: torch.dtype, ) -> None: + """ + This function tests pack_segments ops with meta backend. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + dtype - The data type + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor( input_raw, dtype=torch.float32, requires_grad=True @@ -276,6 +343,109 @@ def test_pack_segments_meta_backend( # verify forward assert packed_tensor.size() == torch.Tensor(packed_ref).size() + @unittest.skipIf(*gpu_unavailable) + @given( + n=st.integers(2, 10), + k=st.integers(2, 10), + batch_size=st.integers(1, 30), + divisions=st.integers(1, 10), + dtype=st.sampled_from( + [ + torch.float, + torch.half, + ] + ), + torch_compile=st.booleans(), + use_cpu=st.booleans(), + ) + @settings(deadline=None) + def test_pack_segments_noncontig( + self, + n: int, + k: int, + batch_size: int, + divisions: int, + dtype: torch.dtype, + torch_compile: bool, + use_cpu: bool, + ) -> None: + """ + This function tests pack_segments ops when input gradients to backward are non-contiguous. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + dtype - The data type + torch_compile - Whether to use torch.compile + use_cpu - Whether to use CPU or GPU + + Returns: + None + """ + + input_raw = np.random.rand(batch_size, n, k) + # create input + input_data_ref = torch.tensor(input_raw, dtype=dtype, requires_grad=True) + input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True).cuda() + # retain grad to compare gradients of the inputs later + input_data.retain_grad() + input_data_ref.retain_grad() + + # set lengths + lengths = torch.tensor( + get_n_rand_num_summing_to_k(divisions, batch_size), + dtype=torch.int, + ) + max_length = lengths.max().item() + + packed_ref = torch.ops.fbgemm.pack_segments( + t_in=input_data_ref, lengths=lengths, max_length=max_length + ) + packed_ref.retain_grad() + + # pack segments using fbgemm and fb + packed_tensor = torch.ops.fbgemm.pack_segments( + t_in=input_data, lengths=lengths.cuda(), max_length=max_length + ) + packed_tensor.retain_grad() + + # verify forward + self.assertTrue(torch.equal(packed_tensor.cpu(), packed_ref)) + + # create non-contiguous grad + shape = tuple(x * 2 for x in packed_ref.shape) + grads = torch.tensor( + np.random.uniform(low=0.01, high=0.5, size=shape).astype(np.float32) + ).to(dtype) + grad_noncontig_cpu = grads.as_strided(packed_ref.shape, grads.stride()) + grad_noncontig_cuda = grads.cuda().as_strided(packed_ref.shape, grads.stride()) + + self.assertTrue( + not ( + grad_noncontig_cpu.is_contiguous() + and grad_noncontig_cuda.is_contiguous() + ), + msg="Expected grads to be non-contiguous but they are contiguous", + ) + + # verify backward + packed_ref.backward(grad_noncontig_cpu) + packed_tensor.backward(grad_noncontig_cuda) + self.assertTrue( + torch.equal(packed_tensor.cpu(), packed_ref), + msg="Expected packed tensors to be equal but they are not", + ) + + # verify backward input gradients + self.assertTrue( + # pyre-fixme[16]: Optional type has no attribute `cpu`. + # pyre-fixme[6]: For 2nd param expected `Tensor` but got `Optional[Tensor]`. + torch.equal(input_data.grad.cpu(), input_data_ref.grad.cpu()), + msg="Expected input gradients to be equal but they are not", + ) + extend_test_class(PackedSegmentsTest)