diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index 8ff5270f0..68f2b1fa3 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -725,8 +725,7 @@ endif()
 # Silence warnings in asmjit
 target_compile_options(fbgemm_gpu_py PRIVATE
-  -Wno-deprecated-anon-enum-enum-conversion)
-target_compile_options(fbgemm_gpu_py PRIVATE
+  -Wno-deprecated-anon-enum-enum-conversion
   -Wno-deprecated-declarations)
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu
index f20b1b97b..032ef7e86 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu
@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel
 {% for vbe in [True, False] %}
 {% set vdesc = "_vbe" if vbe else "" %}
-template <typename grad_t>
+template <typename grad_t, typename offset_t>
 __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
     pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
     {% if vbe %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -212,15 +212,16 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
 ////////////////////////////////////////////////////////////////////////////////
 
 {% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
+{% for offset_type in ['int32_t', 'int64_t'] %}
 template __global__ __launch_bounds__(kMaxThreads)
 void grad_mean{{ vdesc }}_kernel
-<{{ grad_type }}> (
+<{{ grad_type }}, {{ offset_type }}> (
     pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets,
     {% if vbe %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
     FixedDivisor fd_B
     {% endif %}
 );
+{% endfor %} // for offset_type in ['int32_t', 'int64_t']
 {% endfor %} // for grad_type in ['at::Half', 'float']
 {% endfor %} // for vbe in [True, False]
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
index 3fb49ed5e..1cfeb66c9 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
@@ -77,6 +77,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -105,7 +106,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     ph_type_combo,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
@@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
 < {{ emb_type }},
   {{ grad_type }},
   {{ cache_type }},
+  {{ index_type }},
   {%- for ph_name in args.placeholder_tensor_names %}
   {{ ph_type_combo[ph_name].primitive_type }},
   {%- endfor %}
@@ -470,7 +473,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
     {%- for emb_type in ['float', 'at::Half'] %}
     {%- for cache_type in ['float', 'at::Half'] %}
+    {%- for index_type in ['int32_t', 'int64_t'] %}
     {%- for ph_type_combo in args.placeholder_type_combos %}
     {{ template_instantiation(
         emb_type,
         grad_type,
        cache_type,
+        index_type,
         ph_type_combo,
         kFixedMaxVecsPerThread,
         kThreadGroupSize,
@@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     {%- endfor %}
     {%- endfor %}
     {%- endfor %}
+    {%- endfor %}
 {%- endmacro %}
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index fdd9c0f79..c731308e4 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -45,6 +45,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -73,7 +74,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -232,13 +233,13 @@ split_embedding_backward_codegen_find_long_segments(
     const bool use_deterministic_algorithms);
-template <typename grad_t>
+template <typename grad_t, typename offset_t>
 __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
     pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
     {%- if vbe %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -742,6 +743,7 @@ Tensor {{ embedding_cuda_op }}(
     else {
       {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }});
       size_t temp_storage_bytes = 0;
+      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] {
       AT_CUDA_CHECK(radix_sort_pairs(
           nullptr,
           temp_storage_bytes,
@@ -753,9 +755,11 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           indices.options().dtype(at::kByte));
+
       AT_CUDA_CHECK(radix_sort_pairs(
           temp_storage.data_ptr(),
           temp_storage_bytes,
@@ -767,6 +771,7 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+      });
     }
   }
@@ -775,6 +780,7 @@ Tensor {{ embedding_cuda_op }}(
   }
   {%- endif %}
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] {
   DISPATCH_EMB_GRAD_CACHE_TYPES(
       dev_weights.scalar_type(),
       aligned_grad_output.scalar_type(),
@@ -800,9 +806,11 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           indices.options().dtype(at::kByte));
+
       AT_CUDA_CHECK(radix_sort_pairs(
           temp_storage.data_ptr(),
           temp_storage_bytes,
@@ -853,7 +861,7 @@ Tensor {{ embedding_cuda_op }}(
             MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
             MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
             MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
             {%- if vbe %}
             MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
             MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
@@ -955,6 +963,7 @@ Tensor {{ embedding_cuda_op }}(
 __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_size_cumsum,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
@@ -79,7 +79,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
     // Use a raw pointer to avoid creating dummy PackedTensorAccessor
     const uint32_t* const __restrict__ vbe_b_t_map,
     FixedDivisor fd) {
-  const int32_t T = hash_size_cumsum.size(0) - 1;
+  const auto T = hash_size_cumsum.size(0) - 1;
   auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t b;
   int32_t t;
@@ -97,17 +97,16 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
   }
 
   const index_t hash_offset = valid ? hash_size_cumsum[t] : -1;
-  const index_t indices_start = valid ? offsets[b_t] : -1;
-  const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0;
+  const auto indices_start = valid ? offsets[b_t] : -1;
+  const auto L = valid ? offsets[b_t + 1] - indices_start : 0;
   const int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;
 
   // Compile-time conditional
   if (nobag) {
     for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
-      const index_t indices_start_warp =
-          fbgemm_gpu::shfl_sync(indices_start, j);
-      const int32_t t_warp = fbgemm_gpu::shfl_sync(t, j);
-      const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
+      const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
+      const auto t_warp = fbgemm_gpu::shfl_sync(t, j);
+      const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
       const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
       for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
         const index_t idx = __ldg(&indices[indices_start_warp + i]);
@@ -124,10 +123,9 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
           reinterpret_cast<uint32_t*>(&b)[0];
     }
     for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
-      const index_t indices_start_warp =
-          fbgemm_gpu::shfl_sync(indices_start, j);
-      const uint32_t info_warp = fbgemm_gpu::shfl_sync(info, j);
-      const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
+      const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
+      const auto info_warp = fbgemm_gpu::shfl_sync(info, j);
+      const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
       const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
       for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
         const index_t idx = __ldg(&indices[indices_start_warp + i]);
@@ -142,7 +140,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
 template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel(
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_size_cumsum,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
@@ -153,7 +151,7 @@ __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel(
         linear_indices,
     FixedDivisor fd,
     int32_t fixed_L_per_warp) {
-  const int32_t T = hash_size_cumsum.size(0) - 1;
+  const auto T = hash_size_cumsum.size(0) - 1;
   auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t b;
   int32_t t;
@@ -258,7 +256,7 @@ transpose_embedding_input(
           kMaxThreads, \
           0, \
           at::cuda::getCurrentCUDAStream()>>>( \
-          MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, index_t, 1, 32), \
+          MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, int64_t, 1, 32), \
          MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), \
          MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), \
          MAKE_PTA_WITH_NAME(func_name, infos, INFO_ACC_T, 1, 32), \
@@ -296,7 +294,7 @@ transpose_embedding_input(
         0,
         at::cuda::getCurrentCUDAStream()>>>(
         MAKE_PTA_WITH_NAME(
-            func_name, hash_size_cumsum, index_t, 1, 32),
+            func_name, hash_size_cumsum, int64_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
         MAKE_PTA_WITH_NAME(
             func_name, total_L_offsets.value(), index_t, 1, 32),
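Note on the dispatch pattern used above: the host-side hunks rely on AT_DISPATCH_INDEX_TYPES to bind `index_t` to either `int32_t` or `int64_t` from `indices.scalar_type()`, and that type is then forwarded into the newly templated kernels through their `PackedTensorAccessor32` arguments. The sketch below is illustrative only and not part of the patch: it uses plain ATen accessors instead of FBGEMM's `pta::`/`MAKE_PTA_WITH_NAME` helpers, and the `bag_lengths` kernel/function names are hypothetical.

```cpp
// Minimal sketch of the AT_DISPATCH_INDEX_TYPES pattern (not part of the patch).
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Kernel templated on the element type of `offsets`, mirroring how the TBE
// backward kernels above gain an index_t/offset_t template parameter.
template <typename index_t>
__global__ void bag_lengths_kernel(
    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> lengths) {
  const int32_t b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b < lengths.size(0)) {
    // offsets has B + 1 entries; bag b spans [offsets[b], offsets[b + 1]).
    lengths[b] = offsets[b + 1] - offsets[b];
  }
}

// Host launcher: AT_DISPATCH_INDEX_TYPES binds index_t to int32_t or int64_t
// based on offsets.scalar_type(), so one code path serves both index widths.
at::Tensor bag_lengths(const at::Tensor& offsets) {
  auto lengths = at::empty({offsets.size(0) - 1}, offsets.options());
  if (lengths.numel() == 0) {
    return lengths;
  }
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "bag_lengths", [&] {
    constexpr int kThreads = 256;
    const auto blocks =
        static_cast<int>((lengths.size(0) + kThreads - 1) / kThreads);
    bag_lengths_kernel<index_t>
        <<<blocks, kThreads, 0, at::cuda::getCurrentCUDAStream()>>>(
            offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
            lengths.packed_accessor32<index_t, 1, at::RestrictPtrTraits>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
  return lengths;
}
```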