From acf1e7484336817f8580ecb516fc3bd699b76746 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Fri, 27 Sep 2024 17:45:44 -0700
Subject: [PATCH] Add support for int64_t indices and offsets in TBE inference
 [7B/N]

Summary:

- Fix `index_remapping` in `pruned_array_lookup` to always be `int64_t`, since
  it is set up before the `indices` and `offsets` tensors (and thus their
  dtype) are provided

Differential Revision: D63567553
---
 ...bedding_forward_quantized_cpu_template.cpp | 106 ++++++++++--------
 ...mbedding_forward_quantized_split_lookup.cu |  14 +--
 ..._table_batched_embeddings_ops_inference.py |  15 +--
 .../inference/nbit_forward_autovec_test.py    |   2 +-
 .../test/tbe/inference/nbit_forward_test.py   |  38 ++++++-
 5 files changed, 107 insertions(+), 68 deletions(-)

diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
index 92eff015f4..17fcc9560e 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -65,54 +65,64 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
 
-  int32_t T = hash_table_offsets.size(0) - 1;
-  int32_t B = (offsets.size(0) - 1) / T;
+  const int32_t T = hash_table_offsets.size(0) - 1;
+  const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  const auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-  const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-for (const auto t : c10::irange(T)) {
-    int64_t table_start = hash_table_offsets_acc[t];
-    int64_t table_end = hash_table_offsets_acc[t + 1];
-    if (table_start == table_end) {
-      continue;
-    }
-    int64_t capacity = table_end - table_start;
-for (const auto b : c10::irange(B)) {
-      int32_t indices_start = offsets_acc[t * B + b];
-      int32_t indices_end = offsets_acc[t * B + b + 1];
-      int32_t L = indices_end - indices_start;
-for (const auto l : c10::irange(L)) {
-      int32_t idx = indices_acc[indices_start + l];
-      int32_t dense_idx = dense_indices_acc[indices_start + l];
-      if (dense_idx == -1) {
-        // -1 means this row has been pruned, do not insert it.
-        continue;
-      }
-      uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
-      while (true) {
-        int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-        // empty slot
-        if (slot_sparse_idx == -1) {
-          hash_table_acc[table_start + static_cast<int64_t>(slot)][0] = idx;
-          hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-          break;
-        }
-        // already exists (shouldn't happen in practice)
-        if (slot_sparse_idx == idx) {
-          hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-          break;
-        }
-        // linear probe
-        slot = (slot + 1) % capacity;
-      }
-    }
-  }
-}
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using hash_t =
+        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+    auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+    for (const auto t : c10::irange(T)) {
+      const auto table_start = hash_table_offsets_acc[t];
+      const auto table_end = hash_table_offsets_acc[t + 1];
+      if (table_start == table_end) {
+        continue;
+      }
+      const auto capacity = table_end - table_start;
+
+      for (const auto b : c10::irange(B)) {
+        const auto indices_start = offsets_acc[t * B + b];
+        const auto indices_end = offsets_acc[t * B + b + 1];
+        const auto L = indices_end - indices_start;
+
+        for (const auto l : c10::irange(L)) {
+          const auto idx = indices_acc[indices_start + l];
+          const auto dense_idx = dense_indices_acc[indices_start + l];
+          if (dense_idx == -1) {
+            // -1 means this row has been pruned, do not insert it.
+            continue;
+          }
+
+          auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+          while (true) {
+            const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
+            // empty slot
+            if (slot_sparse_idx == -1) {
+              hash_table_acc[table_start + static_cast<int64_t>(slot)][0] = idx;
+              hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
+              break;
+            }
+            // already exists (shouldn't happen in practice)
+            if (slot_sparse_idx == idx) {
+              hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
+              break;
+            }
+            // linear probe
+            slot = (slot + 1) % capacity;
+          }
+        }
+      }
+    }
+  });
+
+  return;
 }
@@ -414,7 +424,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = hash_table_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
@@ -428,9 +438,9 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
     const auto* indices_acc = indices.data_ptr<index_t>();
     auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-
     const auto* offsets_acc = offsets.data_ptr<index_t>();
-    const auto hash_table_acc = hash_table.accessor<index_t, 2>();
+
+    const auto hash_table_acc = hash_table.accessor<int64_t, 2>();
     const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
 
     for (const auto t : c10::irange(T)) {
@@ -463,7 +473,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
           }
           // already exists
           if (slot_sparse_idx == idx) {
-            dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
+            dense_indices_acc[indices_start + l] = static_cast<index_t>(hash_table_acc[table_start + static_cast<int64_t>(slot)][1]);
             break;
           }
           // linear probe
@@ -489,7 +499,7 @@ Tensor pruned_array_lookup_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(index_remappings);
   TENSOR_ON_CPU(index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = index_remappings_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
@@ -502,7 +512,7 @@ Tensor pruned_array_lookup_cpu(
     auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
     const auto* offsets_acc = offsets.data_ptr<index_t>();
 
-    const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
+    const auto index_remappings_acc = index_remappings.data_ptr<int64_t>();
     const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
 
     at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
@@ -517,7 +527,7 @@ Tensor pruned_array_lookup_cpu(
       if (capacity > 0) {
         for (const auto i : c10::irange(indices_start, indices_end)) {
           auto idx = indices_acc[i];
-          dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
+          dense_indices_acc[i] = static_cast<index_t>(index_remappings_acc[index_remappings_start + idx]);
         }
       } else {
         std::memcpy(
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 846cd47636..54cfcb3e69 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -21,7 +21,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
         indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits>
         hash_table,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_table_offsets,
@@ -103,7 +103,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
         indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
@@ -129,7 +129,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
     for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
       index_t idx = indices[indices_start + l];
       dense_indices[indices_start + l] =
-          index_remappings[index_remappings_start + idx];
+          static_cast<index_t>(index_remappings[index_remappings_start + idx]);
     }
   } else {
     for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
@@ -149,7 +149,7 @@ Tensor pruned_hashmap_lookup_cuda(
     Tensor hash_table_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, hash_table, hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -173,7 +173,7 @@ Tensor pruned_hashmap_lookup_cuda(
         at::cuda::getCurrentCUDAStream()>>>(
         MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
+        MAKE_PTA_WITH_NAME(func_name, hash_table, int64_t, 2, 64),
         MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
         B,
         T,
@@ -191,7 +191,7 @@ Tensor pruned_array_lookup_cuda(
     Tensor index_remappings_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, index_remappings, index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(
         at::cuda::getCurrentCUDAStream()>>>(
         MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings, int64_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
         B,
         T,
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index f1671d29d1..ccb13987ae 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -403,7 +403,7 @@ def max_ty_D(ty: SparseType) -> int:
         )
         self.register_buffer(
             "index_remappings_array",
-            torch.empty(0, device=self.current_device, dtype=torch.int32),
+            torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remapping_hash_table_offsets",
@@ -411,7 +411,7 @@ def max_ty_D(ty: SparseType) -> int:
         )
         self.register_buffer(
             "index_remapping_hash_table",
-            torch.empty(0, device=self.current_device, dtype=torch.int32),
+            torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "original_rows_per_table",
@@ -946,8 +946,9 @@ def reset_embedding_spec_location(
     @torch.jit.export
     def recompute_module_buffers(self) -> None:
         """
-        Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets().
-        Currently those buffers are `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
+        Compute module buffers that're on meta device and are not materialized
+        in reset_weights_placements_and_offsets(). Currently those buffers are
+        `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
         Pruning related or uvm related buffers are not computed right now.
         """
         if (
@@ -1527,11 +1528,11 @@ def set_index_remappings_array(
                 index_remappings_filter_nones.append(mapping)
         if len(index_remappings_filter_nones) == 0:
             self.index_remappings_array = torch.empty(
-                0, dtype=torch.int32, device=self.current_device
+                0, dtype=torch.int64, device=self.current_device
             )
         else:
             self.index_remappings_array = torch.cat(index_remappings_filter_nones).to(
-                self.current_device
+                dtype=torch.int64, device=self.current_device
             )
 
     def set_index_remappings(
@@ -1554,7 +1555,7 @@ def set_index_remappings(
             ]
             hash_table = torch.empty(
                 (sum(capacities), 2),
-                dtype=torch.int32,
+                dtype=torch.int64,
             )
             hash_table[:, :] = -1
             hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
index 1aee221a2e..920a86cbde 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
@@ -87,7 +87,7 @@ def get_nbit_weights_ty(draw) -> Optional[SparseType]:
 
 
 # @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
-class NBitFowardTest(unittest.TestCase):
+class NBitFowardAutovecTest(unittest.TestCase):
     def execute_nbit_forward_(  # noqa C901
         self,
         T: int,
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index f2872bb4ec..c9d75b53a2 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -122,6 +122,11 @@ def get_nbit_weights_ty(draw) -> Optional[SparseType]:
             "Operator outputs int4 tensors which do not support opcheck tests"
         ),
     },
+    "test_pt2_compliant_tag_fbgemm_int_nbit_split_embedding_codegen_lookup_function": [
+        unittest.skip(
+            "Operator outputs int4 tensors which do not support opcheck tests"
+        ),
+    ],
 }
@@ -327,6 +332,7 @@ def execute_nbit_forward_(  # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.
@@ -533,19 +539,22 @@ def execute_nbit_forward_(  # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )
 
+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )
 
         if do_pooling and B == 0:
@@ -589,6 +598,9 @@ def execute_nbit_forward_(  # noqa C901
             )
         else:
             fc2_float = fc2.float()
+
+        print(fc2_float.cpu())
+        print(f.float().cpu())
         torch.testing.assert_close(
             fc2_float.cpu(),
             f.float().cpu(),
@@ -603,6 +615,7 @@ def execute_nbit_forward_(  # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.NONE, PoolingMode.MEAN]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),
@@ -618,6 +631,7 @@ def test_nbit_forward_cpu(
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True
@@ -661,11 +675,18 @@ def test_nbit_forward_cpu(
             use_array_for_index_remapping,
             do_pruning,
             mixed_weights_ty,
+            indices_dtype,
            output_dtype,
         )
 
+    @given(
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
+    )
     @unittest.skipIf(*gpu_unavailable)
-    def test_nbit_forward_gpu_no_cache_fp8_2048(self) -> None:
+    @settings(deadline=None)
+    def test_nbit_forward_gpu_no_cache_fp8_2048(
+        self, indices_dtype: torch.dtype
+    ) -> None:
         # Test the case of FB8 table with 128B*8 < D <= 128B*16
         self.execute_nbit_forward_(
             T=1,
@@ -683,6 +704,7 @@ def test_nbit_forward_gpu_no_cache_fp8_2048(self) -> None:
             use_array_for_index_remapping=True,
             do_pruning=False,
             mixed_weights_ty=False,
+            indices_dtype=indices_dtype,
             output_dtype=SparseType.FP16,
         )
 
@@ -691,6 +713,7 @@ def test_nbit_forward_gpu_no_cache_fp8_2048(self) -> None:
         nbit_weights_ty=get_nbit_weights_ty(),
         use_array_for_index_remapping=st.booleans(),
         do_pruning=st.booleans(),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -702,6 +725,7 @@ def test_nbit_forward_gpu_no_cache(
         nbit_weights_ty: Optional[SparseType],
         use_array_for_index_remapping: bool,
         do_pruning: bool,
+        indices_dtype: torch.dtype,
     ) -> None:
         use_cpu = False
         T = random.randint(1, 50)
@@ -756,6 +780,7 @@ def test_nbit_forward_gpu_no_cache(
             use_array_for_index_remapping,
             do_pruning,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )
 
@@ -978,6 +1003,7 @@ def test_nbit_forward_cpu_seq_int8(
         T=st.integers(min_value=10, max_value=20),
         L=st.integers(min_value=10, max_value=100),
         MAXH=st.integers(min_value=50, max_value=100),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -991,6 +1017,7 @@ def test_nbit_forward_cpu_seq_int4(
         T: int,
         L: int,
         MAXH: int,
+        indices_dtype: torch.dtype,
     ) -> None:
         """
         we init a quant table split embedding bag with int4 weights and scale of 1 and 0 bias
@@ -1012,6 +1039,7 @@ def test_nbit_forward_cpu_seq_int4(
             use_array_for_index_remapping=True,
             do_pruning=False,
             mixed_weights_ty=False,
+            indices_dtype=indices_dtype,
             output_dtype=SparseType.INT4,
         )
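
Appendix: the pattern used throughout this patch is to keep the remapping
structures (`index_remappings_array`, `index_remapping_hash_table`) at
`int64_t` while the kernels dispatch over the runtime dtype of
`indices`/`offsets` via `AT_DISPATCH_INDEX_TYPES`, narrowing table values back
to `index_t` on output. The following is a minimal standalone sketch of that
dispatch pattern, not part of the patch; `remap_indices`, the example tensors,
and `main` are illustrative assumptions only, and the real lookup logic lives
in the generated kernels above.

// Standalone sketch (assumption: built against libtorch/ATen), showing how a
// CPU op can accept either int32 or int64 indices while the remapping table
// stays int64_t, as in pruned_array_lookup.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <cstdint>
#include <iostream>
#include <type_traits>

// Hypothetical helper: remap each index through an int64_t remapping table,
// returning dense indices in the same dtype as the input indices.
at::Tensor remap_indices(const at::Tensor& indices, const at::Tensor& remappings) {
  TORCH_CHECK(remappings.scalar_type() == at::kLong);
  auto dense_indices = at::empty_like(indices);

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "remap_indices", [&] {
    // index_t is int32_t or int64_t, chosen from indices.scalar_type().
    // The hash_t alias mirrors the std::conditional_t trick in
    // pruned_hashmap_insert: the hash width follows the index width.
    using hash_t =
        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
    static_assert(sizeof(hash_t) == sizeof(index_t), "hash width tracks index width");

    const auto* idx_ptr = indices.data_ptr<index_t>();
    auto* out_ptr = dense_indices.data_ptr<index_t>();
    // The table is always read as int64_t and narrowed to index_t on the way
    // out, mirroring the static_cast<index_t>(...) added in the patch.
    const auto* remap_ptr = remappings.data_ptr<int64_t>();

    for (int64_t i = 0; i < indices.numel(); ++i) {
      out_ptr[i] = static_cast<index_t>(remap_ptr[idx_ptr[i]]);
    }
  });
  return dense_indices;
}

int main() {
  const auto remappings = at::arange(0, 8, at::kLong) * 10;  // [0, 10, ..., 70]
  // Works for both int32 and int64 indices without touching the table dtype.
  std::cout << remap_indices(at::arange(0, 4, at::kInt), remappings) << std::endl;
  std::cout << remap_indices(at::arange(4, 8, at::kLong), remappings) << std::endl;
  return 0;
}

The table dtype is pinned to int64_t because the buffers are allocated in
__init__ / set_index_remappings_array before any `indices` tensor (and hence
its dtype) exists, which is exactly the point made in the Summary; the
per-element static_cast<index_t> keeps kernel outputs in the caller's index
dtype.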