From 79ef0642bf27b6a6f0d0143e47f9b64083141e51 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 18 Sep 2024 11:35:05 +0000 Subject: [PATCH] 2024-09-18 nightly release (e27c5e166f69ac138e839f7f62b189a731923da1) --- .../embedding_forward_quantized_host_cpu.cpp | 3 + ...ward_quantized_split_nbit_host_template.cu | 2 +- .../utils/embedding_bounds_check_host_cpu.cpp | 3 + .../experimental/gen_ai/src/comm/car.cu | 161 ++++++++++++------ .../fbgemm_gpu/utils/tensor_accessor.h | 74 +++++++- .../transpose_embedding_input.cu | 4 +- 6 files changed, 195 insertions(+), 52 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index e799120a6..41fd137dd 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -103,6 +103,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( std::optional max_float8_D, std::optional fp8_exponent_bits, std::optional fp8_exponent_bias) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } if (static_cast(pooling_mode) == PoolingMode::NONE) { std::vector max_D_list{ max_int2_D, diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index e7b908cdd..bc4e7ba74 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -62,7 +62,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no {%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name_{{ emb_weight_type }} = func_name_{{ emb_weight_type }}; + const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}"; #endif #ifdef X diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 85d23cc94..1098378d0 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -49,6 +49,9 @@ void bounds_check_indices_cpu( const std::optional& weights, const std::optional& B_offsets, const int64_t max_B) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } const auto vbe = B_offsets.has_value(); if (vbe) { TENSOR_NDIM_EQUALS(B_offsets.value(), 1); diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu index f0c689570..4712f620f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu @@ -67,17 +67,26 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) { return c; } -template -__global__ void one_shot_all_reduce( - int32_t rank, - int32_t world_size, - int32_t flag, - std::array barriers, - std::array inputs, +template +#if defined(USE_ROCM) +__launch_bounds__(512) +#endif + __global__ void one_shot_all_reduce( + int32_t rank, + int32_t world_size, + int32_t flag, + std::array barriers, + std::array inputs, +#if defined(USE_ROCM) + at::BFloat16* __restrict__ ar_input, + 
at::BFloat16* __restrict__ acc, + at::BFloat16* __restrict__ output, +#else at::BFloat16* ar_input, at::BFloat16* acc, at::BFloat16* output, - int32_t N) { +#endif + int32_t N) { // It is expensive to launch hipMemcpyAsync on ROCm // Move data copy here. Each block copies part of input data at::BFloat16* input = inputs[rank]; @@ -143,11 +152,11 @@ __global__ void one_shot_all_reduce( // Sum the values from the different ranks. bf16x8 sums; - if (acc) { + if constexpr (has_acc) { *reinterpret_cast(&sums) = *reinterpret_cast(&acc[i]); } else { - memset(reinterpret_cast(&sums), 0, sizeof(sums)); + *reinterpret_cast(&sums) = uint4{0}; } #pragma unroll kWorldSize @@ -336,15 +345,24 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) { #endif } -template +template +#if defined(USE_ROCM) +__launch_bounds__(512) __global__ void two_shot_all_reduce( +#else __launch_bounds__(1024) __global__ void two_shot_all_reduce( +#endif int32_t rank, int32_t world_size, int32_t flag, std::array barriers, std::array inputs, +#if defined(USE_ROCM) + at::BFloat16* __restrict__ acc, + at::BFloat16* __restrict__ output, +#else at::BFloat16* acc, at::BFloat16* output, +#endif int32_t N) { int32_t N_per_rank = N / kWorldSize; int32_t N_start = N_per_rank * rank; @@ -374,13 +392,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( __syncthreads(); at::BFloat16* src_d[kWorldSize]; - int dst_rank[kWorldSize]; #pragma unroll kWorldSize for (int ii = 0; ii < kWorldSize; ++ii) { int d_rank = (rank + ii) % kWorldSize; src_d[ii] = inputs[d_rank]; - dst_rank[ii] = d_rank; } // Each block accumulates the values from the different GPUs on the same @@ -395,11 +411,12 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( } bf16x8 sums; - if (acc) { + + if constexpr (has_acc) { *reinterpret_cast(&sums) = *reinterpret_cast(&acc[i + N_start]); } else { - memset(reinterpret_cast(&sums), 0, sizeof(sums)); + *reinterpret_cast(&sums) = uint4{0}; } #pragma unroll kWorldSize @@ -433,11 +450,19 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( // Gather all needed elts from other intra-node ranks for (size_t i = threadIdx.x * 8 + blockIdx.x * blockDim.x * 8; i < N_per_rank; i += gridDim.x * blockDim.x * 8) { + uint4 temp[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int d_rank = (rank + ii) % kWorldSize; + int i_r = N_start + i + (d_rank - rank) * N_per_rank; + temp[ii] = reinterpret_cast(&src_d[ii][i_r])[0]; + } + #pragma unroll kWorldSize for (int ii = 0; ii < kWorldSize; ++ii) { - int i_r = N_start + i + (dst_rank[ii] - rank) * N_per_rank; - *reinterpret_cast(&output[i_r]) = - reinterpret_cast(&src_d[ii][i_r])[0]; + int d_rank = (rank + ii) % kWorldSize; + int i_r = N_start + i + (d_rank - rank) * N_per_rank; + *reinterpret_cast(&output[i_r]) = temp[ii]; } } } @@ -474,7 +499,11 @@ void one_shot_car_allreduce( constexpr int32_t N_per_thread = 8; constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp; TORCH_CHECK(N % N_per_warp == 0); +#if defined(USE_ROCM) + constexpr int32_t kThreadsPerBlock = 512; +#else constexpr int32_t kThreadsPerBlock = 1024; +#endif constexpr int32_t kMaxBlocks = 24; dim3 threads(0, 1, 1); @@ -494,21 +523,37 @@ void one_shot_car_allreduce( threads.x = threads_per_block; } -#define X(kWorldSize) \ - if (state->world_size_ == kWorldSize) { \ - one_shot_all_reduce \ - <<>>( \ - state->rank_, \ - state->world_size_, \ - state->flag_ * state->world_size_, \ - barriers, \ - inputs, \ - y.data_ptr(), \ - z ? 
z->data_ptr() : nullptr, \ - y_allreduce.data_ptr(), \ - N); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - return; \ +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + if (z) { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + y.data_ptr(), \ + z->data_ptr(), \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } else { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + y.data_ptr(), \ + nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } \ } TORCH_CHECK( @@ -520,7 +565,7 @@ void one_shot_car_allreduce( #undef X return; -} +} // namespace fbgemm_gpu void two_shot_car_allreduce( at::Tensor y_allreduce, @@ -565,26 +610,46 @@ void two_shot_car_allreduce( TORCH_CHECK(N_per_rank % N_per_thread == 0); auto threads_per_rank = N_per_rank / N_per_thread; +#if defined(USE_ROCM) + constexpr int32_t kThreadsPerBlock = 512; +#else constexpr int32_t kThreadsPerBlock = 1024; +#endif + constexpr int32_t kMaxBlocks = 24; auto blocks = std::min( cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks); -#define X(kWorldSize) \ - if (state->world_size_ == kWorldSize) { \ - two_shot_all_reduce \ - <<>>( \ - state->rank_, \ - state->world_size_, \ - state->flag_ * state->world_size_, \ - barriers, \ - inputs, \ - z ? z->data_ptr() : nullptr, \ - y_allreduce.data_ptr(), \ - N); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - return; \ +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + if (z) { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + z->data_ptr(), \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } else { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } \ } TORCH_CHECK( diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h index 3f5ed08f2..24ad14125 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -472,6 +473,53 @@ template < using PackedTensorAccessor64 = GenericPackedTensorAccessor; +template +inline at::ScalarType scalar_type_for() { +#define TYPE_CASE(U, name) \ + if constexpr (std::is_same_v) { \ + return at::ScalarType::name; \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_CASE) + +#undef TYPE_CASE + + return at::ScalarType::Undefined; +} + +template +inline void check_scalar_type( + const at::TensorBase& tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + const char* const func_name, + const char* const tensor_name +#endif +) { + const auto expected_type = scalar_type_for(); + + TORCH_CHECK( + tensor.scalar_type() == expected_type || + (isQIntType(tensor.scalar_type()) && + toUnderlying(tensor.scalar_type()) == expected_type), +#ifdef FBGEMM_GPU_MEMCHECK + "[ ", + func_name, + " ]: ", +#endif + "Expected tensor ", +#ifdef FBGEMM_GPU_MEMCHECK + "'", + tensor_name, + "' ", +#endif + "to have scalar type ", + expected_type, + ", but found ", + 
tensor.scalar_type(), + " instead!"); +} + } // namespace fbgemm_gpu #ifdef FBGEMM_GPU_MEMCHECK @@ -521,10 +569,21 @@ pta::PackedTensorAccessor32 make_packed_tensor_accessor32( #else const at::Tensor& tensor) { #endif + TORCH_CHECK( tensor.numel() <= static_cast(std::numeric_limits::max()), "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); + + fbgemm_gpu::check_scalar_type( + tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + func_name, + ptr_name +#endif + ); + #ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( tensor, ptr_name, func_name); @@ -542,10 +601,23 @@ pta::PackedTensorAccessor64 make_packed_tensor_accessor64( const at::Tensor& tensor, const char* const ptr_name, const char* const func_name) { +#else + const at::Tensor& tensor) { +#endif + + fbgemm_gpu::check_scalar_type( + tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + func_name, + ptr_name +#endif + ); + +#ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( tensor, ptr_name, func_name); #else - const at::Tensor& tensor) { return tensor.packed_accessor64(); #endif } diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index 16d46e2e3..83a06d78a 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -274,10 +274,10 @@ transpose_embedding_input( } AT_DISPATCH_INDEX_TYPES( - infos.scalar_type(), "transpose_embedding_input1", [&] { + infos.scalar_type(), "transpose_embedding_input_1", [&] { using info_t = index_t; AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), "transpose_embedding_input2", [&] { + indices.scalar_type(), "transpose_embedding_input_2", [&] { if (!is_index_select) { if (!nobag) { INVOKE_LINEARIZE_INDEX_KERNEL(int32_t, false);
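
Note on the car.cu changes above: the runtime null-check on the optional accumulator inside one_shot_all_reduce and two_shot_all_reduce is hoisted into a compile-time has_acc template flag, selected by the if (z) / else split that now wraps each kernel launch, so the no-accumulator instantiation carries no branch or accumulator load at all. Below is a minimal, standalone C++ sketch of that pattern; the names (reduce_row, kHasAcc, maybe_acc) are illustrative stand-ins, not FBGEMM code.

    #include <cstdio>

    template <bool kHasAcc>
    void reduce_row(const float* in, const float* acc, float* out, int n) {
      for (int i = 0; i < n; ++i) {
        float sum = 0.0f;
        // Resolved at compile time: the kHasAcc == false instantiation
        // contains no accumulator load and no branch.
        if constexpr (kHasAcc) {
          sum = acc[i];
        }
        out[i] = sum + in[i];
      }
    }

    int main() {
      float in[4] = {1.f, 2.f, 3.f, 4.f};
      float acc[4] = {10.f, 10.f, 10.f, 10.f};
      float out[4];

      // The call site picks the instantiation, mirroring the if (z) / else
      // split the patch adds around the kernel launches in car.cu.
      const float* maybe_acc = acc;  // could also be nullptr
      if (maybe_acc) {
        reduce_row<true>(in, maybe_acc, out, 4);
      } else {
        reduce_row<false>(in, nullptr, out, 4);
      }

      std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
      return 0;
    }

The same call-site split is why the X launch macros in one_shot_car_allreduce and two_shot_car_allreduce now dispatch to two instantiations instead of passing a possibly-null accumulator pointer to a single kernel.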
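
Note on the tensor_accessor.h changes above: scalar_type_for<T>() maps the requested element type to an at::ScalarType so that make_packed_tensor_accessor32/64 can reject tensors whose dtype does not match the accessor's element type (with an escape hatch for quantized int types via toUnderlying). Below is a minimal sketch of the same type-to-tag mapping, using a simplified stand-in enum and type list rather than the ATen API; Dtype, dtype_for, and FOR_EACH_DTYPE are hypothetical names for illustration only.

    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    enum class Dtype { Float, Double, Int, Long, Undefined };

    // Simplified stand-in for AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.
    #define FOR_EACH_DTYPE(_) \
      _(float, Float)         \
      _(double, Double)       \
      _(int32_t, Int)         \
      _(int64_t, Long)

    template <typename T>
    constexpr Dtype dtype_for() {
    #define TYPE_CASE(U, name)              \
      if constexpr (std::is_same_v<T, U>) { \
        return Dtype::name;                 \
      }
      FOR_EACH_DTYPE(TYPE_CASE)
    #undef TYPE_CASE
      // Mirrors the ScalarType::Undefined fallback in the patch.
      return Dtype::Undefined;
    }

    int main() {
      static_assert(dtype_for<float>() == Dtype::Float, "float maps to Float");
      static_assert(dtype_for<int64_t>() == Dtype::Long, "int64_t maps to Long");
      std::printf("dtype checks passed\n");
      return 0;
    }

In the patch, the resulting tag is compared against tensor.scalar_type() inside check_scalar_type before the packed accessor is constructed, turning a silent reinterpretation of mismatched dtypes into a TORCH_CHECK failure.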