diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 6fcb9206bd..a80eea05e4 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2843,8 +2843,9 @@ Tensor pack_segments_forward_cpu(
   TORCH_CHECK(
       t_in.dtype() == at::ScalarType::Float ||
           t_in.dtype() == at::ScalarType::Double ||
-          t_in.dtype() == at::ScalarType::Half,
-      "t_in must be of type float or double or half");
+          t_in.dtype() == at::ScalarType::Half ||
+          t_in.dtype() == at::ScalarType::BFloat16,
+      "t_in must be of type float, double, half, or bfloat16");
   TORCH_CHECK_GT(max_length, 0);
 
   const auto t_in_cont = t_in.expect_contiguous();
@@ -2911,8 +2912,9 @@ Tensor pack_segments_backward_cpu(
   TORCH_CHECK(
       data.dtype() == at::ScalarType::Float ||
           data.dtype() == at::ScalarType::Double ||
-          data.dtype() == at::ScalarType::Half,
-      "data must be of type float or double or half");
+          data.dtype() == at::ScalarType::Half ||
+          data.dtype() == at::ScalarType::BFloat16,
+      "data must be of type float, double, half, or bfloat16");
   TORCH_CHECK(
       max_length == data.sizes()[1],
       "max_length should be equal to the second dimension of the packed segments");
diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py
index d6b40328e8..dd53192779 100644
--- a/fbgemm_gpu/test/sparse/pack_segments_test.py
+++ b/fbgemm_gpu/test/sparse/pack_segments_test.py
@@ -91,6 +91,7 @@ def _pack_segments_ref(
             [
                 torch.float,
                 torch.half,
+                torch.bfloat16,
             ]
         ),
         torch_compile=st.booleans(),
@@ -192,6 +193,7 @@ def test_pack_segments(
             [
                 torch.float,
                 torch.half,
+                torch.bfloat16,
             ]
         ),
         torch_compile=st.booleans(),
@@ -207,7 +209,8 @@ def test_pack_segments_smaller_max_len(
         dtype: torch.dtype,
         torch_compile: bool,
     ) -> None:
-        input_data = torch.tensor(np.random.rand(batch_size, n, k), dtype=dtype)
+        input_raw = np.random.rand(batch_size, n, k)
+        input_data = torch.tensor(input_raw, dtype=dtype)
         lengths = torch.tensor(
             get_n_rand_num_summing_to_k(divisions, batch_size), dtype=torch.int
         )
@@ -221,10 +224,10 @@ def test_pack_segments_smaller_max_len(
         packed_ref = self._pack_segments_ref(
             lengths,
-            input_data,
+            input_raw,
             max_length=max_length,
         )
-        # pyre-fixme[6]: For 2nd param expected `Tensor` but got `ndarray`.
+        packed_ref = torch.Tensor(packed_ref).to(dtype)
 
         self.assertTrue(torch.equal(packed_tensor, packed_ref))
 
         if gpu_available:
@@ -248,6 +251,7 @@ def test_pack_segments_smaller_max_len(
             [
                 torch.float,
                 torch.half,
+                torch.bfloat16,
             ]
         ),
     )