diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index 2b126c96d..55f37eb16 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -456,6 +456,7 @@ for (const auto l : c10::irange(L)) { } {% if not weighted %} + Tensor pruned_array_lookup_cpu( Tensor indices, Tensor offsets, @@ -469,33 +470,41 @@ Tensor pruned_array_lookup_cpu( int32_t T = index_remappings_offsets.size(0) - 1; int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - const auto index_remappings_acc = index_remappings.data_ptr(); - const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr(); - at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) { - for (const auto t : c10::irange(begin, end)) { - int64_t index_remappings_start = index_remappings_offsets_acc[t]; - int64_t index_remappings_end = index_remappings_offsets_acc[t + 1]; - int64_t capacity = index_remappings_end - index_remappings_start; - int32_t indices_start = offsets_acc[t * B]; - int32_t indices_end = offsets_acc[(t + 1) * B]; - if (capacity > 0) { - for (const auto i : c10::irange(indices_start,indices_end)) { - int32_t idx = indices_acc[i]; - dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx]; - } - } else { - std::memcpy( - dense_indices_acc + indices_start, - indices_acc + indices_start, - (indices_end - indices_start) * sizeof(int32_t)); - } - } + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] { + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + const auto* offsets_acc = offsets.data_ptr(); + + const auto index_remappings_acc = index_remappings.data_ptr(); + const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr(); + + at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) { + for (const auto t : c10::irange(begin, end)) { + const auto index_remappings_start = index_remappings_offsets_acc[t]; + const auto index_remappings_end = index_remappings_offsets_acc[t + 1]; + const auto capacity = index_remappings_end - index_remappings_start; + + const auto indices_start = offsets_acc[t * B]; + const auto indices_end = offsets_acc[(t + 1) * B]; + + if (capacity > 0) { + for (const auto i : c10::irange(indices_start, indices_end)) { + auto idx = indices_acc[i]; + dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx]; + } + } else { + std::memcpy( + dense_indices_acc + indices_start, + indices_acc + indices_start, + (indices_end - indices_start) * sizeof(index_t)); + } + } + }); }); + return dense_indices; }