diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
index 2b126c96d..92eff015f 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -41,6 +41,16 @@ inline uint32_t pruned_hash_function(uint32_t h) {
   return h;
 }
 
+inline uint64_t pruned_hash_function(uint64_t k) {
+  // MurmurHash3 64-bit mixing function.
+  k ^= k >> 33;
+  k *= (0xff51afd7ed558ccd);
+  k ^= k >> 33;
+  k *= (0xc4ceb9fe1a85ec53);
+  k ^= k >> 33;
+  return k;
+}
+
 } // namespace
 
 void pruned_hashmap_insert_{{ wdesc }}_cpu(
@@ -404,58 +414,72 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
 
   int32_t T = hash_table_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
+
   auto dense_indices = empty_like(indices);
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  const auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-  const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-for (const auto t : c10::irange(T)) {
-  int64_t table_start = hash_table_offsets_acc[t];
-  int64_t table_end = hash_table_offsets_acc[t + 1];
-  int64_t capacity = table_end - table_start;
-for (const auto b : c10::irange(B)) {
-  int32_t indices_start = offsets_acc[t * B + b];
-  int32_t indices_end = offsets_acc[t * B + b + 1];
-  int32_t L = indices_end - indices_start;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] {
+    using hash_t =
+        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
 
-  if (table_start == table_end) {
-for (const auto l : c10::irange(L)) {
-    dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
-  }
-  } else {
-for (const auto l : c10::irange(L)) {
-    int32_t idx = indices_acc[indices_start + l];
-    uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
-    while (true) {
-      int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-
-      // empty slot
-      if (slot_sparse_idx == -1) {
-        dense_indices_acc[indices_start + l] = -1;
-        break;
-      }
-      // already exists
-      if (slot_sparse_idx == idx) {
-        dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
-        break;
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+    const auto hash_table_acc = hash_table.accessor<index_t, 2>();
+    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+    for (const auto t : c10::irange(T)) {
+      const auto table_start = hash_table_offsets_acc[t];
+      const auto table_end = hash_table_offsets_acc[t + 1];
+      const auto capacity = table_end - table_start;
+
+      for (const auto b : c10::irange(B)) {
+        const auto indices_start = offsets_acc[t * B + b];
+        const auto indices_end = offsets_acc[t * B + b + 1];
+        const auto L = indices_end - indices_start;
+
+        if (table_start == table_end) {
+          for (const auto l : c10::irange(L)) {
+            dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
+          }
+
+        } else {
+          for (const auto l : c10::irange(L)) {
+            const auto idx = indices_acc[indices_start + l];
+            auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+
+            while (true) {
+              const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
+
+              // empty slot
+              if (slot_sparse_idx == -1) {
+                dense_indices_acc[indices_start + l] = -1;
+                break;
+              }
+              // already exists
+              if (slot_sparse_idx == idx) {
+                dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
+                break;
+              }
+              // linear probe
+              slot = (slot + 1) % capacity;
            }
-      // linear probe
-      slot = (slot + 1) % capacity;
          }
        }
      }
    }
-  }
+  });
+
   return dense_indices;
 }
 
 {% if not weighted %}
+
 Tensor pruned_array_lookup_cpu(
     Tensor indices,
     Tensor offsets,
@@ -465,37 +489,46 @@ Tensor pruned_array_lookup_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(index_remappings);
   TENSOR_ON_CPU(index_remappings_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
 
   int32_t T = index_remappings_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
+
   auto dense_indices = empty_like(indices);
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  const auto index_remappings_acc = index_remappings.data_ptr<int32_t>();
-  const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
-  at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
-    for (const auto t : c10::irange(begin, end)) {
-      int64_t index_remappings_start = index_remappings_offsets_acc[t];
-      int64_t index_remappings_end = index_remappings_offsets_acc[t + 1];
-      int64_t capacity = index_remappings_end - index_remappings_start;
-      int32_t indices_start = offsets_acc[t * B];
-      int32_t indices_end = offsets_acc[(t + 1) * B];
-      if (capacity > 0) {
-        for (const auto i : c10::irange(indices_start,indices_end)) {
-          int32_t idx = indices_acc[i];
-          dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
-        }
-      } else {
-        std::memcpy(
-            dense_indices_acc + indices_start,
-            indices_acc + indices_start,
-            (indices_end - indices_start) * sizeof(int32_t));
-      }
-    }
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] {
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+    const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
+    const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
+
+    at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
+      for (const auto t : c10::irange(begin, end)) {
+        const auto index_remappings_start = index_remappings_offsets_acc[t];
+        const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
+        const auto capacity = index_remappings_end - index_remappings_start;
+
+        const auto indices_start = offsets_acc[t * B];
+        const auto indices_end = offsets_acc[(t + 1) * B];
+
+        if (capacity > 0) {
+          for (const auto i : c10::irange(indices_start, indices_end)) {
+            auto idx = indices_acc[i];
+            dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
+          }
+        } else {
+          std::memcpy(
+              dense_indices_acc + indices_start,
+              indices_acc + indices_start,
+              (indices_end - indices_start) * sizeof(index_t));
+        }
+      }
+    });
   });
+
   return dense_indices;
 }
 
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp
index 41fd137dd..b6f55b961 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp
@@ -21,6 +21,7 @@
 #include "fbgemm_gpu/embedding_common.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/utils/ops_utils.h"
+#include "fbgemm_gpu/utils/tensor_utils.h"
 
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
@@ -374,29 +375,37 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder {
   }
 
   Tensor lookup(Tensor indices, Tensor offsets) const {
+    TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
+
     int32_t T = maps_.size();
     TORCH_CHECK(T > 0);
     int32_t B = (offsets.size(0) - 1) / T;
     TORCH_CHECK(B > 0);
     TORCH_CHECK(maps_.size() == T);
+
     auto dense_indices = empty_like(indices);
-    const auto* indices_acc = indices.data_ptr<int32_t>();
-    auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-    const auto* offsets_acc = offsets.data_ptr<int32_t>();
-    for (const auto t : c10::irange(T)) {
-      auto& map = maps_[t];
-      for (const auto b : c10::irange(B)) {
-        int32_t indices_start = offsets_acc[t * B + b];
-        int32_t indices_end = offsets_acc[t * B + b + 1];
-        int32_t L = indices_end - indices_start;
-        for (const auto l : c10::irange(L)) {
-          int32_t slot_sparse_index = indices_acc[indices_start + l];
-          auto it = map.find(slot_sparse_index);
-          dense_indices_acc[indices_start + l] =
-              it != map.end() ? it->second : -1;
+
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "PrunedMapCPU::lookup", [&] {
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+      for (const auto t : c10::irange(T)) {
+        auto& map = maps_[t];
+        for (const auto b : c10::irange(B)) {
+          const auto indices_start = offsets_acc[t * B + b];
+          const auto indices_end = offsets_acc[t * B + b + 1];
+          const auto L = indices_end - indices_start;
+          for (const auto l : c10::irange(L)) {
+            const auto slot_sparse_index = indices_acc[indices_start + l];
+            const auto it = map.find(slot_sparse_index);
+            dense_indices_acc[indices_start + l] =
+                it != map.end() ? it->second : -1;
+          }
         }
       }
-    }
+    });
+
     return dense_indices;
   }
 
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 7d4eebcce..846cd4763 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -14,19 +14,20 @@ using Tensor = at::Tensor;
 
 namespace nbit {
 
+template <typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        offsets,
-    const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
        hash_table,
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        hash_table_offsets,
    const int32_t B,
    const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        dense_indices) {
   // uint32_t capacity = hash_table.size(0);
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
   const int32_t t = b_t / B;
@@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
 
   const int64_t table_start = hash_table_offsets[t];
   const int64_t table_end = hash_table_offsets[t + 1];
@@ -51,6 +52,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
     return;
   }
 
+  using hash_t =
+      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
   const uint32_t subwarp_id = threadIdx.x / 4;
   const uint32_t subwarp_tid = threadIdx.x % 4;
 #ifdef USE_ROCM
@@ -58,13 +62,15 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 #else
   const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
 #endif
+
   for (int32_t l_start = 0; l_start + subwarp_id < L;
        l_start += kWarpSize / 4) {
-    const int32_t idx = indices[indices_start + l_start + subwarp_id];
-    uint32_t slot_start =
-        pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
+    const index_t idx = indices[indices_start + l_start + subwarp_id];
+    hash_t slot_start =
+        pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+
     while (true) {
-      const uint32_t slot = (slot_start + subwarp_tid) % capacity;
+      const hash_t slot = (slot_start + subwarp_tid) % capacity;
       const int2 val = *reinterpret_cast<const int2*>(
           &hash_table[table_start + static_cast<int64_t>(slot)][0]);
       const int32_t slot_sparse_idx = val.x;
@@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
       found = true;
       dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
     }
+
     if (__any_sync(subwarp_mask, found)) {
       break;
     } else if (__any_sync(subwarp_mask, empty)) {
@@ -89,19 +96,20 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }
 
+template <typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        offsets,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        index_remappings,
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        index_remappings_offsets,
    const int32_t B,
    const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        dense_indices) {
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
   const int32_t t = b_t / B;
@@ -109,22 +117,22 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
 
   const int64_t index_remappings_start = index_remappings_offsets[t];
   const int64_t index_remappings_end = index_remappings_offsets[t + 1];
   const int64_t capacity = index_remappings_end - index_remappings_start;
 
   if (capacity > 0) {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
-      int32_t idx = indices[indices_start + l];
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
+      index_t idx = indices[indices_start + l];
       dense_indices[indices_start + l] =
           index_remappings[index_remappings_start + idx];
     }
   } else {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
       dense_indices[indices_start + l] = indices[indices_start + l];
     }
   }
@@ -132,6 +140,8 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 
 } // namespace nbit
 
+using namespace nbit;
+
 Tensor pruned_hashmap_lookup_cuda(
     Tensor indices,
     Tensor offsets,
@@ -139,6 +149,7 @@ Tensor pruned_hashmap_lookup_cuda(
     Tensor hash_table_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, hash_table, hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -149,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda(
   TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
 
   constexpr size_t kForwardMaxThreads = 256;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_cuda", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
-      nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64),
-      MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
+    int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
+        nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
+        MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+  });
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return dense_indices;
@@ -178,6 +191,7 @@ Tensor pruned_array_lookup_cuda(
     Tensor index_remappings_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, index_remappings, index_remappings_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -204,23 +218,26 @@ Tensor pruned_array_lookup_cuda(
   TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
 
   constexpr size_t kForwardMaxThreads = 256;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cuda", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
 #endif
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
-      nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
+    int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+        nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+  });
+
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
 }
 
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu
index bc4e7ba74..5dd5c30b1 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu
@@ -7,7 +7,7 @@
  */
 
 // clang-format off
-{% set wdesc = "weighted" if weighted else "unweighted" %}
+{%- set wdesc = "weighted" if weighted else "unweighted" %}
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor.h"
 
@@ -22,7 +22,7 @@ namespace nbit {
   `Tensor int_nbit_split_embedding*_codegen_forward_*_cuda(...)`
   later in the same generated source file.
 */
-{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %}
+{%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %}
 template
 __launch_bounds__(WarpsPerBlock * kWarpSize)
 __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L(
@@ -31,30 +31,30 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
   const pta::PackedTensorAccessor32 weights_placements,
   const pta::PackedTensorAccessor32 weights_offsets,
   const pta::PackedTensorAccessor32 weights_tys,
-  {% if not nobag %}
+  {%- if not nobag %}
   const pta::PackedTensorAccessor32 D_offsets,
-  {% else %}
+  {%- else %}
   const int64_t D,
-  {% endif %}
   FixedDivisor fd_B, // FixedDivisor(div_round_up(B, OutputRowsPerThread))
+  {%- endif %}
   const pta::PackedTensorAccessor32 indices,
   const pta::PackedTensorAccessor32 offsets,
-  {% if not nobag %}
+  {%- if not nobag %}
   const int64_t pooling_mode,
-  {% endif %}
+  {%- endif %}
   const int64_t row_alignment,
-  {% if weighted %}
+  {%- if weighted %}
   pta::PackedTensorAccessor32 indice_weights,
-  {% endif %}
-  {% if type_map[emb_weight_type].enum_name == "FP8" %}
+  {%- endif %}
+  {%- if type_map[emb_weight_type].enum_name == "FP8" %}
   const int fp8_exponent_bits,
   const int fp8_exponent_bias,
-  {% endif %}
+  {%- endif %}
   pta::PackedTensorAccessor32 output, // [B][total_D],
   const pta::PackedTensorAccessor64 lxu_cache_weights,
   const pta::PackedTensorAccessor32 lxu_cache_locations
   );
-{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"]
+{%- endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"]
 
 }
 
@@ -107,58 +107,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
       C10_CUDA_KERNEL_LAUNCH_CHECK();                                        \
   {%- endmacro %}
 
-
-Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
-    Tensor dev_weights,
-    Tensor uvm_weights,
-    Tensor weights_placements,
-    Tensor weights_offsets,
-    Tensor weights_tys,
-    {% if not nobag %}
-    Tensor D_offsets,
-    const int64_t total_D,
-    {% else %}
-    const int64_t D,
-    {% endif %}
-    const int64_t max_int2_D,
-    const int64_t max_int4_D,
-    const int64_t max_int8_D,
-    const int64_t max_float16_D,
-    const int64_t max_float32_D,
-    Tensor indices,
-    Tensor offsets,
-    {% if not nobag %}
-    const int64_t pooling_mode,
-    {% endif %}
-    const int64_t row_alignment,
-    {% if weighted %}
-    Tensor indice_weights,
-    {% endif %}
-    const int64_t output_dtype,
-    Tensor lxu_cache_weights,
-    Tensor lxu_cache_locations,
-    const int64_t max_float8_D,
-    const int64_t fp8_exponent_bits,
-    const int64_t fp8_exponent_bias
-) {
-  TENSOR_ON_CUDA_GPU(dev_weights);
-  TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
-  TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
-  TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
-  TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
-  {% if not nobag %}
-  TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
-  {% endif %}
-  TENSORS_ON_SAME_DEVICE(indices, dev_weights);
-  TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
-  {% if weighted %}
-  TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
-  {% endif %}
-  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
-  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
-
-  CUDA_DEVICE_GUARD(dev_weights);
-
+{%- macro construct_and_return_output_tensor() %}
   // kernels assume indices are contiguous.
   indices = indices.contiguous();
 
@@ -180,8 +129,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   TORCH_CHECK(D > 0);
   {%- endif %}
 
+  // Construct output tensor
   Tensor output;
   const int kINT8QparamsBytes = 8;
+
   SparseType o_dtype = static_cast<SparseType>(output_dtype);
   TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
       o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
@@ -216,11 +167,63 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   if (B == 0 || indices.numel() == 0) {
     return output;
   }
+{%- endmacro %}
 
-  using index_t = int32_t;
+template <typename index_t>
+Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl(
+    Tensor dev_weights,
+    Tensor uvm_weights,
+    Tensor weights_placements,
+    Tensor weights_offsets,
+    Tensor weights_tys,
+    {%- if not nobag %}
+    Tensor D_offsets,
+    const int64_t total_D,
+    {%- else %}
+    const int64_t D,
+    {%- endif %}
+    const int64_t max_int2_D,
+    const int64_t max_int4_D,
+    const int64_t max_int8_D,
+    const int64_t max_float16_D,
+    const int64_t max_float32_D,
+    Tensor indices,
+    Tensor offsets,
+    {%- if not nobag %}
+    const int64_t pooling_mode,
+    {%- endif %}
+    const int64_t row_alignment,
+    {%- if weighted %}
+    Tensor indice_weights,
+    {%- endif %}
+    const int64_t output_dtype,
+    Tensor lxu_cache_weights,
+    Tensor lxu_cache_locations,
+    const int64_t max_float8_D,
+    const int64_t fp8_exponent_bits,
+    const int64_t fp8_exponent_bias
+) {
+  TENSOR_ON_CUDA_GPU(dev_weights);
+  TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
+  {%- if not nobag %}
+  TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
+  {%- endif %}
+  TENSORS_ON_SAME_DEVICE(indices, dev_weights);
+  TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
+  {%- if weighted %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
+  {%- endif %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
 
-  constexpr int32_t kWarpsPerBlock = 4;
+  CUDA_DEVICE_GUARD(dev_weights);
+
+  {{- construct_and_return_output_tensor() }}
+  constexpr int32_t kWarpsPerBlock = 4;
 
   const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0;
 #define Y(...)                                                                \
   if (device_only) {                                                          \
@@ -397,6 +400,104 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
   }));
 #undef X
 
+  return output;
+}
+
+Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
+    Tensor dev_weights,
+    Tensor uvm_weights,
+    Tensor weights_placements,
+    Tensor weights_offsets,
+    Tensor weights_tys,
+    {%- if not nobag %}
+    Tensor D_offsets,
+    const int64_t total_D,
+    {%- else %}
+    const int64_t D,
+    {%- endif %}
+    const int64_t max_int2_D,
+    const int64_t max_int4_D,
+    const int64_t max_int8_D,
+    const int64_t max_float16_D,
+    const int64_t max_float32_D,
+    Tensor indices,
+    Tensor offsets,
+    {%- if not nobag %}
+    const int64_t pooling_mode,
+    {%- endif %}
+    const int64_t row_alignment,
+    {%- if weighted %}
+    Tensor indice_weights,
+    {%- endif %}
+    const int64_t output_dtype,
+    Tensor lxu_cache_weights,
+    Tensor lxu_cache_locations,
+    const int64_t max_float8_D,
+    const int64_t fp8_exponent_bits,
+    const int64_t fp8_exponent_bias
+) {
+  // All argument tensors need to be on the same CUDA device
+  TENSOR_ON_CUDA_GPU(dev_weights);
+  TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
+  TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
+  {%- if not nobag %}
+  TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
+  {%- endif %}
+  TENSORS_ON_SAME_DEVICE(indices, dev_weights);
+  TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
+  {%- if weighted %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
+  {%- endif %}
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
+  TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
+
+  // indices and offsets need to have the same scalar type
+  TENSORS_HAVE_SAME_TYPE(indices, offsets);
+  // Only int32_t and int64_t indices are supported at the moment
+  TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int);
+
+  CUDA_DEVICE_GUARD(dev_weights);
+
+  // Create output tensor ref
+  Tensor output;
+
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ 'int_nbit_split_embedding' + ('_nobag' if nobag else '') + '_codegen_forward_' + wdesc + '_cuda' }}", [&] {
+    output = int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl<index_t>(
+        dev_weights,
+        uvm_weights,
+        weights_placements,
+        weights_offsets,
+        weights_tys,
+        {%- if not nobag %}
+        D_offsets,
+        total_D,
+        {%- else %}
+        D,
+        {%- endif %}
+        max_int2_D,
+        max_int4_D,
+        max_int8_D,
+        max_float16_D,
+        max_float32_D,
+        indices,
+        offsets,
+        {%- if not nobag %}
+        pooling_mode,
+        {%- endif %}
+        row_alignment,
+        {%- if weighted %}
+        indice_weights,
+        {%- endif %}
+        output_dtype,
+        lxu_cache_weights,
+        lxu_cache_locations,
+        max_float8_D,
+        fp8_exponent_bits,
+        fp8_exponent_bias);
+  });
+
   return output;
 }
 
diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu
index 08e22baa9..8d8ee6ab5 100644
--- a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu
+++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu
@@ -233,22 +233,24 @@ void bounds_check_indices_cuda(
   constexpr size_t kNumThreads = 256;
   const auto max_B_ = vbe ? max_B : B;
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
-    const auto bounds_check_kernel =
-        (vbe ? bounds_check_indices_kernel<index_t, true>
-             : bounds_check_indices_kernel<index_t, false>);
-    TORCH_DSA_KERNEL_LAUNCH(
-        bounds_check_kernel,
-        div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize),
-        dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
-        0,
-        at::cuda::getCurrentCUDAStream(),
-        rows_per_table.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-        indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-        offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-        vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr,
-        bounds_check_mode_,
-        warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-        FixedDivisor(max_B_));
-  });
+  AT_DISPATCH_INDEX_TYPES(
+      indices.scalar_type(), "bounds_check_indices_cuda", [&] {
+        const auto bounds_check_kernel =
+            (vbe ? bounds_check_indices_kernel<index_t, true>
+                 : bounds_check_indices_kernel<index_t, false>);
+        TORCH_DSA_KERNEL_LAUNCH(
+            bounds_check_kernel,
+            div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize),
+            dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            rows_per_table
+                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+            indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
+            offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
+            vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr,
+            bounds_check_mode_,
+            warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+            FixedDivisor(max_B_));
+      });
 }
 
diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp
index 1098378d0..1d0cd1348 100644
--- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp
+++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp
@@ -70,7 +70,7 @@ void bounds_check_indices_cpu(
   const auto rows_per_table_acc = rows_per_table.accessor<int64_t, 1>();
   auto warning_acc = warning.data_ptr<int64_t>();
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] {
     auto offsets_acc = offsets.accessor<index_t, 1>();
     auto indices_acc = indices.accessor<index_t, 1>();
     auto num_indices = indices.numel();
diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
index 97353e03c..2164afd3e 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
@@ -88,6 +88,7 @@ __device__ inline int32_t padded_D(
 
 __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   // MurmorHash3 32-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
   h ^= h >> 16;
   h *= 0x85ebca6b;
   h ^= h >> 13;
@@ -96,6 +97,17 @@ __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   return h;
 }
 
+__device__ inline uint64_t pruned_hash_function(uint64_t k) {
+  // MurmurHash3 64-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
+  k ^= k >> 33;
+  k *= (0xff51afd7ed558ccd);
+  k ^= k >> 33;
+  k *= (0xc4ceb9fe1a85ec53);
+  k ^= k >> 33;
+  return k;
+}
+
 // ---------------------- START cp.async helpers, copied from CUTLASS
 
 /// CUTLASS helper to get SMEM pointer
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
index b1ab0306c..60cca19ef 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -299,3 +299,77 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards(
   }
   return aligned_grad_output;
 }
+
+template <typename... ScalarTypes>
+std::string tensor_scalar_type_is_one_of(
+    const at::Tensor& ten,
+    const ScalarTypes&... ttypes) {
+  auto has_match = false;
+
+  (
+      [&](const auto& ttype) {
+        if (ten.scalar_type() == ttype) {
+          has_match = true;
+        }
+      }(ttypes),
+      ...);
+
+  if (has_match) {
+    return "";
+  }
+
+  std::string msg = "Tensor's scalar type (";
+  msg.append(toString(ten.scalar_type()));
+  msg.append(") did not match any one of the following types: [");
+  (
+      [&](const auto& ttype) {
+        msg.append(toString(ttype));
+        msg.append(", ");
+      }(ttypes),
+      ...);
+
+  msg.append("]");
+  return msg;
+}
+
+#define TENSOR_SCALAR_TYPE_IS_ONE_OF(...)                                \
+  do {                                                                   \
+    const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__);    \
+    TORCH_CHECK(has_match.empty(), has_match);                           \
+  } while (false)
+
+template <typename... Tensors>
+std::string tensors_have_same_scalar_type(const Tensors&... tensors) {
+  std::optional<at::ScalarType> dtype;
+  bool have_same_type = true;
+
+  (
+      [&](const auto& tensor) {
+        if (!dtype) {
+          dtype = tensor.scalar_type();
+        } else if (*dtype != tensor.scalar_type()) {
+          have_same_type = false;
+        }
+      }(tensors),
+      ...);
+
+  if (have_same_type) {
+    return "";
+  }
+
+  std::string msg = "Tensors' scalar types (";
+  (
+      [&](const auto& tensor) {
+        msg.append(toString(tensor.scalar_type()));
+        msg.append(", ");
+      }(tensors),
+      ...);
+  msg.append(") are not one and the same!");
+  return msg;
+}
+
+#define TENSORS_HAVE_SAME_SCALAR_TYPE(...)                                   \
+  do {                                                                       \
+    const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__);  \
+    TORCH_CHECK(have_same_type.empty(), have_same_type);                     \
+  } while (false)
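Note for reviewers (not part of the patch): the two helpers added to tensor_utils.h rely on C++17 comma-fold expressions that expand a lambda once per pack element, return an empty string when the check passes, and let the macros feed the message into TORCH_CHECK. Below is a minimal standalone sketch of that pattern; the ScalarType enum is a stand-in for at::ScalarType and the all_same name is illustrative only.

// Illustrative sketch only -- mirrors the fold-expression pattern used by
// tensors_have_same_scalar_type() / tensor_scalar_type_is_one_of() above.
#include <iostream>
#include <optional>
#include <string>

enum class ScalarType { Int, Long, Float };  // stand-in for at::ScalarType

template <typename... Types>
std::string all_same(const Types&... types) {
  std::optional<ScalarType> first;
  bool same = true;
  // Expand the lambda once per pack element via a comma fold.
  (
      [&](const auto& t) {
        if (!first) {
          first = t;
        } else if (*first != t) {
          same = false;
        }
      }(types),
      ...);
  // Empty string means the check passed, mirroring TORCH_CHECK(msg.empty(), msg).
  return same ? "" : "scalar types differ";
}

int main() {
  std::cout << all_same(ScalarType::Int, ScalarType::Int).empty() << '\n';   // prints 1
  std::cout << all_same(ScalarType::Int, ScalarType::Long).empty() << '\n';  // prints 0
}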