diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 41fd137dda..b6f55b961e 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -21,6 +21,7 @@ #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" +#include "fbgemm_gpu/utils/tensor_utils.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -374,29 +375,37 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder { } Tensor lookup(Tensor indices, Tensor offsets) const { + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets); + int32_t T = maps_.size(); TORCH_CHECK(T > 0); int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); TORCH_CHECK(maps_.size() == T); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - for (const auto t : c10::irange(T)) { - auto& map = maps_[t]; - for (const auto b : c10::irange(B)) { - int32_t indices_start = offsets_acc[t * B + b]; - int32_t indices_end = offsets_acc[t * B + b + 1]; - int32_t L = indices_end - indices_start; - for (const auto l : c10::irange(L)) { - int32_t slot_sparse_index = indices_acc[indices_start + l]; - auto it = map.find(slot_sparse_index); - dense_indices_acc[indices_start + l] = - it != map.end() ? it->second : -1; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "PrunedMapCPU::lookup", [&] { + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + const auto* offsets_acc = offsets.data_ptr(); + + for (const auto t : c10::irange(T)) { + auto& map = maps_[t]; + for (const auto b : c10::irange(B)) { + const auto indices_start = offsets_acc[t * B + b]; + const auto indices_end = offsets_acc[t * B + b + 1]; + const auto L = indices_end - indices_start; + for (const auto l : c10::irange(L)) { + const auto slot_sparse_index = indices_acc[indices_start + l]; + const auto it = map.find(slot_sparse_index); + dense_indices_acc[indices_start + l] = + it != map.end() ? it->second : -1; + } } } - } + }); + return dense_indices; } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu index 08e22baa9e..8d8ee6ab53 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu @@ -233,22 +233,24 @@ void bounds_check_indices_cuda( constexpr size_t kNumThreads = 256; const auto max_B_ = vbe ? max_B : B; - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] { - const auto bounds_check_kernel = - (vbe ? bounds_check_indices_kernel - : bounds_check_indices_kernel); - TORCH_DSA_KERNEL_LAUNCH( - bounds_check_kernel, - div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), - dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), - 0, - at::cuda::getCurrentCUDAStream(), - rows_per_table.packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - vbe ? B_offsets.value().data_ptr() : nullptr, - bounds_check_mode_, - warning.packed_accessor32(), - FixedDivisor(max_B_)); - }); + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "bounds_check_indices_cuda", [&] { + const auto bounds_check_kernel = + (vbe ? bounds_check_indices_kernel + : bounds_check_indices_kernel); + TORCH_DSA_KERNEL_LAUNCH( + bounds_check_kernel, + div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), + dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), + 0, + at::cuda::getCurrentCUDAStream(), + rows_per_table + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + vbe ? B_offsets.value().data_ptr() : nullptr, + bounds_check_mode_, + warning.packed_accessor32(), + FixedDivisor(max_B_)); + }); } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 1098378d04..1d0cd1348b 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -70,7 +70,7 @@ void bounds_check_indices_cpu( const auto rows_per_table_acc = rows_per_table.accessor(); auto warning_acc = warning.data_ptr(); - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] { auto offsets_acc = offsets.accessor(); auto indices_acc = indices.accessor(); auto num_indices = indices.numel();