diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
index 55f37eb162..92eff015f4 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -41,6 +41,16 @@ inline uint32_t pruned_hash_function(uint32_t h) {
   return h;
 }
 
+inline uint64_t pruned_hash_function(uint64_t k) {
+  // MurmurHash3 64-bit mixing function.
+  k ^= k >> 33;
+  k *= (0xff51afd7ed558ccd);
+  k ^= k >> 33;
+  k *= (0xc4ceb9fe1a85ec53);
+  k ^= k >> 33;
+  return k;
+}
+
 } // namespace
 
 void pruned_hashmap_insert_{{ wdesc }}_cpu(
@@ -404,54 +414,67 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
 
   int32_t T = hash_table_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
+
   auto dense_indices = empty_like(indices);
 
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  const auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-  const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-  for (const auto t : c10::irange(T)) {
-    int64_t table_start = hash_table_offsets_acc[t];
-    int64_t table_end = hash_table_offsets_acc[t + 1];
-    int64_t capacity = table_end - table_start;
-    for (const auto b : c10::irange(B)) {
-      int32_t indices_start = offsets_acc[t * B + b];
-      int32_t indices_end = offsets_acc[t * B + b + 1];
-      int32_t L = indices_end - indices_start;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] {
+    using hash_t =
+        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
 
-      if (table_start == table_end) {
-        for (const auto l : c10::irange(L)) {
-          dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
-        }
-      } else {
-        for (const auto l : c10::irange(L)) {
-          int32_t idx = indices_acc[indices_start + l];
-          uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
-          while (true) {
-            int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-
-            // empty slot
-            if (slot_sparse_idx == -1) {
-              dense_indices_acc[indices_start + l] = -1;
-              break;
-            }
-            // already exists
-            if (slot_sparse_idx == idx) {
-              dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
-              break;
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+    const auto hash_table_acc = hash_table.accessor<index_t, 2>();
+    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+    for (const auto t : c10::irange(T)) {
+      const auto table_start = hash_table_offsets_acc[t];
+      const auto table_end = hash_table_offsets_acc[t + 1];
+      const auto capacity = table_end - table_start;
+
+      for (const auto b : c10::irange(B)) {
+        const auto indices_start = offsets_acc[t * B + b];
+        const auto indices_end = offsets_acc[t * B + b + 1];
+        const auto L = indices_end - indices_start;
+
+        if (table_start == table_end) {
+          for (const auto l : c10::irange(L)) {
+            dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
+          }
+
+        } else {
+          for (const auto l : c10::irange(L)) {
+            const auto idx = indices_acc[indices_start + l];
+            auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+
+            while (true) {
+              const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
+
+              // empty slot
+              if (slot_sparse_idx == -1) {
+                dense_indices_acc[indices_start + l] = -1;
+                break;
+              }
+              // already exists
+              if (slot_sparse_idx == idx) {
+                dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
+                break;
+              }
+              // linear probe
+              slot = (slot + 1) % capacity;
+            }
-            }
-            // linear probe
-            slot = (slot + 1) % capacity;
-          }
-        }
-      }
-    }
-  }
+          }
+        }
+      }
+    }
+  });
+
   return dense_indices;
 }
 
@@ -466,6 +489,7 @@ Tensor pruned_array_lookup_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(index_remappings);
   TENSOR_ON_CPU(index_remappings_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
 
   int32_t T = index_remappings_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 86165bb39f..846cd47636 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -160,7 +160,7 @@ Tensor pruned_hashmap_lookup_cuda(
   TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
 
   constexpr size_t kForwardMaxThreads = 256;
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_cuda", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name =
         "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
@@ -218,7 +218,7 @@ Tensor pruned_array_lookup_cuda(
   TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
 
   constexpr size_t kForwardMaxThreads = 256;
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cuda", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name =
         "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
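For reference, below is a minimal standalone sketch (not part of the patch) of the lookup scheme the CPU path above implements: the MurmurHash3 64-bit finalizer picks a starting slot, linear probing resolves collisions, and a sparse index of -1 marks an empty slot, i.e. a pruned index. The names Slot, insert, and lookup_dense_index are illustrative only and do not exist in fbgemm_gpu.

#include <cstdint>
#include <iostream>
#include <vector>

inline uint64_t murmur_mix64(uint64_t k) {
  // MurmurHash3 64-bit mixing (finalizer) function, same constants as the patch.
  k ^= k >> 33;
  k *= 0xff51afd7ed558ccd;
  k ^= k >> 33;
  k *= 0xc4ceb9fe1a85ec53;
  k ^= k >> 33;
  return k;
}

// One hash-table row: {sparse_index, dense_index}; -1 marks an empty slot.
struct Slot {
  int64_t sparse_idx = -1;
  int64_t dense_idx = -1;
};

// Insert with linear probing (mirrors the insert side of the pruned hashmap).
void insert(std::vector<Slot>& table, int64_t sparse_idx, int64_t dense_idx) {
  const uint64_t capacity = table.size();
  uint64_t slot = murmur_mix64(static_cast<uint64_t>(sparse_idx)) % capacity;
  while (table[slot].sparse_idx != -1 && table[slot].sparse_idx != sparse_idx) {
    slot = (slot + 1) % capacity;  // linear probe
  }
  table[slot] = {sparse_idx, dense_idx};
}

// Lookup with linear probing; returns -1 for pruned (absent) indices,
// matching the behavior of the lookup loop in the patch.
int64_t lookup_dense_index(const std::vector<Slot>& table, int64_t sparse_idx) {
  const uint64_t capacity = table.size();
  uint64_t slot = murmur_mix64(static_cast<uint64_t>(sparse_idx)) % capacity;
  while (true) {
    if (table[slot].sparse_idx == -1) {
      return -1;  // empty slot: index was pruned
    }
    if (table[slot].sparse_idx == sparse_idx) {
      return table[slot].dense_idx;  // found the remapped dense index
    }
    slot = (slot + 1) % capacity;  // linear probe
  }
}

int main() {
  std::vector<Slot> table(8);  // capacity 8, all slots empty (-1)
  insert(table, /*sparse_idx=*/42, /*dense_idx=*/0);
  insert(table, /*sparse_idx=*/7, /*dense_idx=*/1);

  std::cout << lookup_dense_index(table, 42) << "\n";  // 0
  std::cout << lookup_dense_index(table, 7) << "\n";   // 1
  std::cout << lookup_dense_index(table, 99) << "\n";  // -1 (pruned)
  return 0;
}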