diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
index 3450d2cad..b076e281e 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
@@ -72,8 +72,14 @@ static const std::unordered_map<
      fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
     {{128, 7168, 8192},
      fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    {{1024, 7168, 8192},
+     fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5},
     {{2048, 7168, 8192},
      fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    {{4096, 7168, 8192},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    {{8192, 7168, 8192},
+     fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     // Support for decode across batch sizes for [8192, 3584]
     {{16, 8192, 3584},
      fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
@@ -83,6 +89,14 @@ static const std::unordered_map<
      fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
     {{128, 8192, 3584},
      fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    {{1024, 8192, 3584},
+     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    {{2048, 8192, 3584},
+     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    {{4096, 8192, 3584},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    {{8192, 8192, 3584},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     // Llama 405B Decode Shapes.
     // Support for decode across batch sizes for [13312, 6656].
     {{16, 13312, 6656},
@@ -102,8 +116,14 @@ static const std::unordered_map<
      fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
     {{128, 13312, 16384},
      fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    {{2048, 13312, 16384},
+    {{1024, 13312, 16384},
      fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    {{2048, 13312, 16384},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    {{4096, 13312, 16384},
+     fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    {{8192, 13312, 16384},
+     fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     // Support for decode across batch sizes for [16384, 6656].
     {{16, 16384, 6656},
      fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
@@ -113,8 +133,14 @@ static const std::unordered_map<
      fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
     {{128, 16384, 6656},
      fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    {{1024, 16384, 6656},
+     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
     {{2048, 16384, 6656},
      fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    {{4096, 16384, 6656},
+     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    {{8192, 16384, 6656},
+     fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     // Support for decode across batch sizes for [16384, 16384].
     {{16, 16384, 16384},
      fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2},
diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp b/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
index b2d082906..b9f1324f8 100644
--- a/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
+++ b/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
@@ -29,31 +29,65 @@ namespace fbgemm_gpu {
 void _cat_int_tensors_out(
     Tensor& combined_tensors,
     const std::vector<Tensor>& tensor_list,
-    int64_t total_num) {
+    int64_t total_num,
+    bool to_trim_padding = false,
+    const std::vector<int64_t>& indices_terminating_idx =
+        std::vector<int64_t>()) {
+  if (to_trim_padding) {
+    // We need to define the terminating idx for each indices tensor
+    TORCH_CHECK(indices_terminating_idx.size() == tensor_list.size());
+  }
   at::native::resize_(combined_tensors, {total_num});
   auto* combined_tensors_data_ptr =
       combined_tensors.mutable_data_ptr<int32_t>();
   size_t idx = 0;
 
-  for (const auto& tensor : tensor_list) {
+  // Keep the original paddings and append them at the end later
+  std::vector<int64_t> paddings;
+  paddings.reserve(total_num);
+
+  for (size_t i = 0; i < tensor_list.size(); ++i) {
+    const auto& tensor = tensor_list[i];
     AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "tbe_cat_inputs_", [&] {
       // Necessary to use data_ptr. Checked in caller, but let's
       // be safe in case somebody changes that.
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.is_contiguous());
       auto* indices_data_ptr = tensor.const_data_ptr<index_t>();
-      const auto numel = tensor.numel();
-      for (auto j = 0; j < numel; j++) {
+      auto numel = tensor.numel();
+      if (to_trim_padding) {
+        const auto terminating_idx = indices_terminating_idx.at(i);
+        numel = terminating_idx > 0 && terminating_idx < numel
+            ? terminating_idx
+            : numel;
+      }
+      size_t j = 0;
+      for (; j < numel; j++) {
         combined_tensors_data_ptr[idx++] =
             static_cast<int32_t>(indices_data_ptr[j]);
       }
+      for (; j < tensor.numel(); j++) {
+        paddings.push_back(indices_data_ptr[j]);
+      }
     });
   }
+
+  // Append the original paddings at the end
+  int i = 0;
+  while (idx < total_num) {
+    if (i < paddings.size()) [[likely]] {
+      combined_tensors_data_ptr[idx++] = paddings[i++];
+    } else {
+      combined_tensors_data_ptr[idx++] = 0;
+    }
+  }
 }
 
 Tensor _cat_int_tensors(
     const std::vector<Tensor>& tensor_list,
     int64_t total_num,
-    bool use_pin_memory) {
+    bool use_pin_memory,
+    bool to_trim_padding = false,
+    const std::vector<int64_t>& indices_terminating_idx =
+        std::vector<int64_t>()) {
   // Using int type to maintain original behavior
   // in https://fburl.com/code/h2lwews2
   auto combined_tensors = at::empty(
@@ -63,7 +97,12 @@ Tensor _cat_int_tensors(
           .device(tensor_list[0].device())
           .pinned_memory(use_pin_memory));
 
-  _cat_int_tensors_out(combined_tensors, tensor_list, total_num);
+  _cat_int_tensors_out(
+      combined_tensors,
+      tensor_list,
+      total_num,
+      to_trim_padding,
+      indices_terminating_idx);
   return combined_tensors;
 }
 
@@ -104,7 +143,14 @@ void _cat_per_sample_weights_list_out(
     Tensor& out,
     const std::vector<Tensor>& per_sample_weights,
     const std::vector<Tensor>& indices_list,
-    int64_t total_num) {
+    int64_t total_num,
+    bool to_trim_padding = false,
+    const std::vector<int64_t>& indices_terminating_idx =
+        std::vector<int64_t>()) {
+  if (to_trim_padding) {
+    // We need to define the terminating idx for each indices tensor
+    TORCH_CHECK(indices_terminating_idx.size() == indices_list.size());
+  }
   at::native::resize_(out, {total_num});
   out.fill_(1.);
 
@@ -112,13 +158,22 @@ void _cat_per_sample_weights_list_out(
 
   for (size_t i = 0; i < per_sample_weights.size(); i++) {
     auto element_size = per_sample_weights[i].numel();
+    auto actual_indices_size = indices_list[i].numel();
+    if (to_trim_padding) {
+      element_size = element_size > indices_terminating_idx.at(i)
+          ? indices_terminating_idx.at(i)
+          : element_size;
+      actual_indices_size = actual_indices_size > indices_terminating_idx.at(i)
+          ? indices_terminating_idx.at(i)
+          : actual_indices_size;
+    }
     if (element_size != 0) {
       memcpy(
           out_weights_ptr,
           per_sample_weights[i].data_ptr<float>(),
           element_size * sizeof(float));
     }
-    out_weights_ptr += indices_list[i].numel();
+    out_weights_ptr += actual_indices_size;
   }
 }
 
@@ -126,7 +181,10 @@ Tensor _cat_per_sample_weights_list(
     const std::vector<Tensor>& per_sample_weights,
     const std::vector<Tensor>& indices_list,
     int64_t total_num,
-    bool use_pin_memory) {
+    bool use_pin_memory,
+    bool to_trim_padding = false,
+    const std::vector<int64_t>& indices_terminating_idx =
+        std::vector<int64_t>()) {
   auto combined_weights = at::empty(
       {0},
       at::TensorOptions()
@@ -134,7 +192,12 @@ Tensor _cat_per_sample_weights_list(
           .device(per_sample_weights[0].device())
          .pinned_memory(use_pin_memory));
   _cat_per_sample_weights_list_out(
-      combined_weights, per_sample_weights, indices_list, total_num);
+      combined_weights,
+      per_sample_weights,
+      indices_list,
+      total_num,
+      to_trim_padding,
+      indices_terminating_idx);
   return combined_weights;
 }
 
@@ -155,6 +218,24 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
   bool need_weights = false;
   bool pin_memory = false;
 
+  // We only enable this feature when all the elements of `include_last_offsets`
+  // are True and at least one indices tensor has padding
+  bool indices_tensor_has_padding = false;
+  for (size_t i = 0; i < indices_list.size(); i++) {
+    if (indices_list[i].numel() > offsets_list[i][-1].item().toLong()) {
+      indices_tensor_has_padding = true;
+      break;
+    }
+  }
+  auto to_trim_padding =
+      indices_tensor_has_padding && include_last_offsets.all().item<bool>();
+  // If the indices tensors have padding, we need to determine the boundary,
+  // i.e. the terminating idx, to properly combine the TBE inputs.
+  // `indices_terminating_idx` is a list of the terminating idx for each
+  // indices tensor
+  std::vector<int64_t> indices_terminating_idx;
+  indices_terminating_idx.reserve(indices_list.size());
+
   for (size_t i = 0; i < indices_list.size(); i++) {
     TORCH_CHECK(
         indices_list[i].dtype() == c10::kInt ||
@@ -167,6 +248,14 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
     TORCH_CHECK(indices_list[i].is_contiguous());
     TORCH_CHECK(offsets_list[i].is_contiguous());
     total_indices += indices_list[i].numel();
+    if (to_trim_padding) {
+      // When the offsets tensor has the last offset, we respect that value.
+      // The last offset value should be less than (in case there is padding)
+      // or equal to the number of elements in the indices tensor.
+      TORCH_CHECK_LE(
+          offsets_list[i][-1].item().toLong(), indices_list[i].numel());
+      indices_terminating_idx.push_back(offsets_list[i][-1].item().toLong());
+    }
     auto num_offset =
         offsets_list[i].numel() - (include_last_offsets_acc[i] ? 1 : 0);
     total_offsets += num_offset == 0 ? 1 : num_offset;
@@ -179,8 +268,12 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
     }
   }
 
-  auto combined_indices =
-      _cat_int_tensors(indices_list, total_indices, pin_memory);
+  auto combined_indices = _cat_int_tensors(
+      indices_list,
+      total_indices,
+      pin_memory,
+      to_trim_padding,
+      indices_terminating_idx);
 
   auto combined_offsets = at::empty(
       {total_offsets},
@@ -208,17 +301,26 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
             offset + static_cast<int32_t>(offsets_data_ptr[j]);
       }
 
-      offset += static_cast<int32_t>(indices_list[i].numel());
+      if (to_trim_padding) {
+        offset += static_cast<int32_t>(offsets_list[i][-1].item().toInt());
+      } else {
+        offset += static_cast<int32_t>(indices_list[i].numel());
+      }
       combined_offsets_data_ptr[offsets_acc_idx++] = offset;
     });
   }
 
   if (need_weights) {
     return {
-        std::move(combined_indices),
-        std::move(combined_offsets),
+        combined_indices,
+        combined_offsets,
         _cat_per_sample_weights_list(
-            per_sample_weights, indices_list, total_indices, pin_memory)};
+            per_sample_weights,
+            indices_list,
+            total_indices,
+            pin_memory,
+            to_trim_padding,
+            indices_terminating_idx)};
   }
   return {combined_indices, combined_offsets, at::empty({0})};
 }
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 6fcb9206b..a80eea05e 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2843,8 +2843,9 @@ Tensor pack_segments_forward_cpu(
   TORCH_CHECK(
       t_in.dtype() == at::ScalarType::Float ||
           t_in.dtype() == at::ScalarType::Double ||
-          t_in.dtype() == at::ScalarType::Half,
-      "t_in must be of type float or double or half");
+          t_in.dtype() == at::ScalarType::Half ||
+          t_in.dtype() == at::ScalarType::BFloat16,
+      "t_in must be of type float, double, half, or bfloat16");
   TORCH_CHECK_GT(max_length, 0);
 
   const auto t_in_cont = t_in.expect_contiguous();
@@ -2911,8 +2912,9 @@ Tensor pack_segments_backward_cpu(
   TORCH_CHECK(
       data.dtype() == at::ScalarType::Float ||
           data.dtype() == at::ScalarType::Double ||
-          data.dtype() == at::ScalarType::Half,
-      "data must be of type float or double or half");
+          data.dtype() == at::ScalarType::Half ||
+          data.dtype() == at::ScalarType::BFloat16,
+      "data must be of type float, double, half, or bfloat16");
   TORCH_CHECK(
       max_length == data.sizes()[1],
       "max_length should be equal to the second dimension of the packed segments");
diff --git a/fbgemm_gpu/test/combine/input_combine_test.py b/fbgemm_gpu/test/combine/input_combine_test.py
index c465fa191..14569232e 100644
--- a/fbgemm_gpu/test/combine/input_combine_test.py
+++ b/fbgemm_gpu/test/combine/input_combine_test.py
@@ -8,6 +8,8 @@
 # pyre-strict
 
 import unittest
+from typing import List, Tuple
+
 import torch
 from fbgemm_gpu import sparse_ops  # noqa: F401
 from hypothesis import given, settings
@@ -53,6 +55,44 @@ def _get_inputs(self, dtypes, device=DEFAULT_DEVICE):
             include_last_offsets,
         )
 
+    def _get_prepadded_inputs(
+        self,
+        dtypes: List[torch.dtype],
+        device: torch._C.device = DEFAULT_DEVICE,
+        include_last: bool = True,
+    ) -> Tuple[
+        List[torch.Tensor],
+        List[torch.Tensor],
+        List[torch.Tensor],
+        List[torch.Tensor],
+        List[bool],
+    ]:
+        indices_list = [
+            torch.tensor([1, 2, 3, 123, 123, 123], dtype=dtypes[0], device=device),
+            torch.tensor([1, 2, 3, 4, 456, 456, 456], dtype=dtypes[1], device=device),
+        ]
+        offsets_list = [
+            torch.tensor([0, 2, 3], dtype=dtypes[0], device=device),
+            torch.tensor([0, 1, 4], dtype=dtypes[1], device=device),
+        ]
+        # One of the offsets tensors always includes the last offset
+        include_last_offsets = [True, include_last]
+        per_sample_weights = [
+            torch.tensor([1, 2, 1, 0, 0, 0], dtype=torch.float, device=device),
+            torch.tensor([1, 2, 1, 3, 0, 0, 0], dtype=torch.float, device=device),
+        ]
+        empty_per_sample_weights = [
+            torch.tensor([], dtype=torch.float, device=device),
+            torch.tensor([], dtype=torch.float, device=device),
+        ]
+        return (
+            indices_list,
+            offsets_list,
+            per_sample_weights,
+            empty_per_sample_weights,
+            include_last_offsets,
+        )
+
     # pyre-fixme[2]: Parameter must be annotated.
     def _run_test(self, dtypes) -> None:
         (
@@ -91,6 +131,101 @@ def _run_test(self, dtypes) -> None:
         self.assertTrue(outputs[1].dtype == torch.int32)
         self.assertTrue(outputs[-1].size(0) == 0)
 
+    def _run_test_with_prepadded_indices_weights(self) -> None:
+        """
+        When the indices tensors have padding and all the offsets tensors
+        include the last offset, the outputs should have the combined values
+        up front and the paddings at the end.
+        """
+        dtypes = [torch.int64, torch.int64]
+        (
+            indices_list,
+            offsets_list,
+            per_sample_weights,
+            empty_per_sample_weights,
+            include_last_offsets,
+        ) = self._get_prepadded_inputs(dtypes, include_last=True)
+
+        outputs = torch.ops.fbgemm.tbe_input_combine(
+            indices_list,
+            offsets_list,
+            per_sample_weights,
+            torch.BoolTensor(include_last_offsets),
+        )
+        expected_outputs = [
+            torch.tensor(
+                [1, 2, 3, 1, 2, 3, 4, 123, 123, 123, 456, 456, 456], dtype=torch.int32
+            ),
+            torch.tensor([0, 2, 3, 4, 7], dtype=torch.int32),
+            torch.tensor(
+                [1.0, 2.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+            ),
+        ]
+        for i, j in zip(outputs, expected_outputs):
+            torch.testing.assert_close(i, j)
+        self.assertTrue(outputs[0].dtype == torch.int32)
+        self.assertTrue(outputs[1].dtype == torch.int32)
+
+        outputs = torch.ops.fbgemm.tbe_input_combine(
+            indices_list,
+            offsets_list,
+            empty_per_sample_weights,
+            torch.BoolTensor(include_last_offsets),
+        )
+        expected_outputs = [
+            torch.tensor(
+                [1, 2, 3, 1, 2, 3, 4, 123, 123, 123, 456, 456, 456], dtype=torch.int32
+            ),
+            torch.tensor([0, 2, 3, 4, 7], dtype=torch.int32),
+            torch.empty(0),
+        ]
+        for i, j in zip(outputs, expected_outputs):
+            torch.testing.assert_close(i, j)
+        self.assertTrue(outputs[0].dtype == torch.int32)
+        self.assertTrue(outputs[1].dtype == torch.int32)
+        self.assertTrue(outputs[2].size(0) == 0)
+
+    def _run_test_with_prepadded_indices_weights_without_last_offsets(self) -> None:
+        """
+        When the indices tensors have padding and at least one offsets tensor
+        does not include the last offset, the outputs should match the
+        previous (reference) behavior.
+ """ + dtypes = [torch.int64, torch.int64] + ( + indices_list, + offsets_list, + per_sample_weights, + empty_per_sample_weights, + include_last_offsets, + ) = self._get_prepadded_inputs(dtypes, include_last=False) + ref_mod = TBEInputPrepareReference(include_last_offsets) + + outputs = torch.ops.fbgemm.tbe_input_combine( + indices_list, + offsets_list, + per_sample_weights, + torch.BoolTensor(include_last_offsets), + ) + ref_outputs = ref_mod(indices_list, offsets_list, per_sample_weights) + for i, j in zip(outputs, ref_outputs): + torch.testing.assert_close(i, j) + self.assertTrue(outputs[0].dtype == torch.int32) + self.assertTrue(outputs[1].dtype == torch.int32) + + outputs = torch.ops.fbgemm.tbe_input_combine( + indices_list, + offsets_list, + empty_per_sample_weights, + torch.BoolTensor(include_last_offsets), + ) + ref_outputs = ref_mod(indices_list, offsets_list, per_sample_weights) + for i, j in zip(outputs[:-1], ref_outputs[:-1]): + torch.testing.assert_close(i, j) + self.assertTrue(outputs[0].dtype == torch.int32) + self.assertTrue(outputs[1].dtype == torch.int32) + self.assertTrue(outputs[2].size(0) == 0) + # pyre-fixme[2]: Parameter must be annotated. def _run_padding_fused_test(self, dtypes, batch_size) -> None: ( @@ -234,6 +369,14 @@ def test_input_combine_int32(self) -> None: def test_input_combined_mix(self) -> None: self._run_test((torch.int64, torch.int32)) + def test_tbe_input_combine_cpu_with_padded_indices(self) -> None: + self._run_test_with_prepadded_indices_weights() + + def test_tbe_input_combine_cpu_with_padded_indices_without_last_offsets( + self, + ) -> None: + self._run_test_with_prepadded_indices_weights_without_last_offsets() + # pyre-fixme[56]: Pyre was not able to infer the type of argument # `test_utils.cpu_and_maybe_gpu()` to decorator factory `hypothesis.given`. @given(device=cpu_and_maybe_gpu()) diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index d6b40328e..dd5319277 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -91,6 +91,7 @@ def _pack_segments_ref( [ torch.float, torch.half, + torch.bfloat16, ] ), torch_compile=st.booleans(), @@ -192,6 +193,7 @@ def test_pack_segments( [ torch.float, torch.half, + torch.bfloat16, ] ), torch_compile=st.booleans(), @@ -207,7 +209,8 @@ def test_pack_segments_smaller_max_len( dtype: torch.dtype, torch_compile: bool, ) -> None: - input_data = torch.tensor(np.random.rand(batch_size, n, k), dtype=dtype) + input_raw = np.random.rand(batch_size, n, k) + input_data = torch.tensor(input_raw, dtype=dtype) lengths = torch.tensor( get_n_rand_num_summing_to_k(divisions, batch_size), dtype=torch.int ) @@ -221,10 +224,10 @@ def test_pack_segments_smaller_max_len( packed_ref = self._pack_segments_ref( lengths, - input_data, + input_raw, max_length=max_length, ) - # pyre-fixme[6]: For 2nd param expected `Tensor` but got `ndarray`. + packed_ref = torch.Tensor(packed_ref).to(dtype) self.assertTrue(torch.equal(packed_tensor, packed_ref)) if gpu_available: @@ -248,6 +251,7 @@ def test_pack_segments_smaller_max_len( [ torch.float, torch.half, + torch.bfloat16, ] ), )