From 029ba1c613d5b7a6ef8796789fa8cb6c68ecce0e Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Tue, 8 Aug 2023 13:59:29 -0700
Subject: [PATCH] Refactor softmax code for dim=-1 case (#886)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/886

This sets the stage for extending the warp reduction code to the dim=-2
case.

Main changes in this diff:

1. blockReduceMax now uses `fast_max` instead of `max`, which is what the
   dim=-2 reduction code already uses. Judging from the implementation,
   `fast_max` is fast because of type specialization, not because it
   sacrifices accuracy, so I think this is a safe change.
2. The block reduction code had a `NUM` template parameter that was always
   set to 1. I've eliminated it to simplify things and remove a bunch of
   indirection.
3. The block / warp reduction code used shared-memory rows of 33 elements
   to avoid bank conflicts. However, given that `NUM` was always 1, bank
   conflicts should not be possible in practice, so I've changed the shared
   memory to rows of 32 elements instead.

Reviewed By: muchulee8, aakhundov

Differential Revision: D47862053

fbshipit-source-id: 477f16ae3f33e2e5f2a858bf156268aef5a25ee1
---
 .../backend/cuda/softmax/softmax.cuh          | 135 ++++++++----------
 1 file changed, 59 insertions(+), 76 deletions(-)

diff --git a/python/aitemplate/backend/cuda/softmax/softmax.cuh b/python/aitemplate/backend/cuda/softmax/softmax.cuh
index 341071b2e..acad5eb67 100644
--- a/python/aitemplate/backend/cuda/softmax/softmax.cuh
+++ b/python/aitemplate/backend/cuda/softmax/softmax.cuh
@@ -36,6 +36,8 @@ using bfloat16 = nv_bfloat16;
 
 #define SOFTMAX_LAUNCH_CHECK() SOFTMAX_DEVICE_CHECK(cudaGetLastError())
 
+#define WARP_SIZE 32
+
 // unroll directives copied from CUTLASS
 #if defined(__CUDA_ARCH__)
 #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
@@ -127,82 +129,63 @@ struct float8 {
 
 #define FINAL_MASK 0xffffffff
 
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceSum(T* val, int thread_group_width = 32) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
+template <typename T>
+__inline__ __device__ void warpReduceSum(
+    T& val,
+    int thread_group_width = WARP_SIZE) {
 #pragma unroll
-    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
-      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
-    }
+  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, WARP_SIZE);
   }
-  return (T)(0.0f);
 }
 
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceSum(T* val) {
-  __shared__ T shared[NUM][33];
+template <typename T>
+__inline__ __device__ void blockReduceSum(T& val) {
+  __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f; // threadIdx.x % warp_size
   int wid = threadIdx.x >> 5; // threadIdx.x / warp_size
 
-  warpReduceSum<T, NUM>(val);
+  warpReduceSum(val);
 
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
+  if (lane == 0)
+    shared[wid] = val;
 
   __syncthreads();
 
   bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
+  val = is_mask ? shared[lane] : (T)(0.0f);
+
   if (wid == 0)
-    warpReduceSum<T, NUM>(val);
-  return (T)0.0f;
+    warpReduceSum(val);
 }
 
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceMax(T* val, int thread_group_width = 32) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
+template <typename T>
+__inline__ __device__ void warpReduceMax(
+    T& val,
+    int thread_group_width = WARP_SIZE) {
 #pragma unroll
-    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
-      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
-    }
+  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+    val = fast_max(val, __shfl_xor_sync(FINAL_MASK, val, mask, WARP_SIZE));
   }
-  return (T)(0.0f);
 }
 
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceMax(T* val) {
-  __shared__ T shared[NUM][33];
+template <typename T>
+__inline__ __device__ void blockReduceMax(T& val) {
+  __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  warpReduceMax<T, NUM>(val);
+  warpReduceMax(val);
 
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
+  if (lane == 0)
+    shared[wid] = val;
 
   __syncthreads();
 
   bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
+  val = is_mask ? shared[lane] : (T)(0.0f);
+
   if (wid == 0)
-    warpReduceMax<T, NUM>(val);
-  return (T)0.0f;
+    warpReduceMax(val);
 }
 
 } // namespace
@@ -348,32 +331,32 @@ __global__ void softmaxBlockNocache(
   input += offset;
   output += offset;
 
-  float local_max[1] = {-Inf<float>()};
+  float local_max = -Inf<float>();
   for (int i = tid; i < n; i += blockDim.x) {
     float local_val = static_cast<float>(input[i]);
-    local_max[0] = max(local_val, local_max[0]);
+    local_max = max(local_val, local_max);
   }
 
-  if (blockDim.x <= 32) {
-    warpReduceMax<float, 1>(local_max);
+  if (blockDim.x <= WARP_SIZE) {
+    warpReduceMax(local_max);
   } else {
-    blockReduceMax<float, 1>(local_max);
+    blockReduceMax(local_max);
   }
   if (threadIdx.x == 0) {
-    s_max = local_max[0];
+    s_max = local_max;
  }
   __syncthreads();
 
-  float local_sum[1] = {0.0f};
+  float local_sum = 0.0f;
   for (int i = tid; i < n; i += blockDim.x) {
-    local_sum[0] += exp(static_cast<float>(input[i]) - s_max);
+    local_sum += exp(static_cast<float>(input[i]) - s_max);
   }
-  if (blockDim.x <= 32) {
-    warpReduceSum<float, 1>(local_sum);
+  if (blockDim.x <= WARP_SIZE) {
+    warpReduceSum(local_sum);
   } else {
-    blockReduceSum<float, 1>(local_sum);
+    blockReduceSum(local_sum);
   }
   if (threadIdx.x == 0) {
-    s_sum = local_sum[0];
+    s_sum = local_sum;
   }
   __syncthreads();
   for (int i = tid; i < n; i += blockDim.x) {
@@ -417,7 +400,7 @@ __global__ void softmax_stored_locally_multi_dim(
   const int64_t row_offset = row * int((n + pack_size - 1) / pack_size);
   const T* row_x = input + row_offset;
   T* row_y = output + row_offset;
-  float local_max[1] = {-Inf<float>()};
+  float local_max = -Inf<float>();
 #pragma unroll
   for (int i = 0; i < num_packs; ++i) {
     const int col = i * blockDim.x + tid;
@@ -427,7 +410,7 @@ __global__ void softmax_stored_locally_multi_dim(
 #pragma unroll
       for (int j = 0; j < pack_size; j++) {
         buf[i * pack_size + j] = static_cast<float>(pack_x[j]);
-        local_max[0] = max(local_max[0], buf[i * pack_size + j]);
+        local_max = max(local_max, buf[i * pack_size + j]);
       }
     } else {
 #pragma unroll
@@ -436,15 +419,15 @@ __global__ void softmax_stored_locally_multi_dim(
       }
     }
   }
-  warpReduceMax<float, 1>(local_max, blockDim.x);
+  warpReduceMax(local_max, blockDim.x);
 
-  float local_sum[1] = {0.0f};
+  float local_sum = 0.0f;
 #pragma unroll
   for (int i = 0; i < cols_per_thread; ++i) {
-    buf[i] = exp(buf[i] - local_max[0]);
-    local_sum[0] += buf[i];
+    buf[i] = exp(buf[i] - local_max);
+    local_sum += buf[i];
   }
-  warpReduceSum<float, 1>(local_sum, blockDim.x);
+  warpReduceSum(local_sum, blockDim.x);
 
   T tmp_o;
   ACT_T* pack_y = reinterpret_cast<ACT_T*>(&tmp_o);
@@ -453,7 +436,7 @@ __global__ void softmax_stored_locally_multi_dim(
     const int col = i * blockDim.x + tid;
     if (col < n / pack_size) {
       for (int j = 0; j < pack_size; j++) {
-        pack_y[j] = ACT_T(buf[i * pack_size + j] / local_sum[0]);
+        pack_y[j] = ACT_T(buf[i * pack_size + j] / local_sum);
      }
       row_y[col] = tmp_o;
     }
@@ -481,7 +464,7 @@ __global__ void softmax_block_smem(
   const int64_t row_offset = row * int((n + pack_size - 1) / pack_size);
   const T* row_x = input + row_offset;
   T* row_y = output + row_offset;
-  float local_max[1] = {-Inf<float>()};
+  float local_max = -Inf<float>();
 
   for (int pack_id = tid; pack_id < num_packs; pack_id += blockDim.x) {
     T tmp_in = row_x[pack_id];
@@ -490,28 +473,28 @@ __global__ void softmax_block_smem(
     for (int j = 0; j < pack_size; j++) {
       float pack = pack_x[j];
       buf[j * num_packs + pack_id] = pack;
-      local_max[0] = max(local_max[0], pack);
+      local_max = max(local_max, pack);
     }
   }
-  blockReduceMax<float, 1>(local_max); // reduce on a block of #blockDim.x
+  blockReduceMax(local_max); // reduce on a block of #blockDim.x
   __shared__ float s_max;
   if (threadIdx.x == 0) {
-    s_max = local_max[0];
+    s_max = local_max;
   }
   __syncthreads();
 
-  float local_sum[1] = {0.0f};
+  float local_sum = 0.0f;
   for (int i = tid; i < n; i += blockDim.x) {
     float local_val = exp(buf[i] - s_max);
     buf[i] = local_val;
-    local_sum[0] += local_val;
+    local_sum += local_val;
   }
-  blockReduceSum<float, 1>(local_sum);
+  blockReduceSum(local_sum);
   __shared__ float s_sum;
   if (threadIdx.x == 0) {
-    s_sum = local_sum[0];
+    s_sum = local_sum;
   }
   __syncthreads();
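
For reference, here is a minimal, self-contained sketch of the reduction pattern the refactored helpers follow: each thread carries a single value (no `NUM` array), a warp combines values with `__shfl_xor_sync`, and the block combines per-warp partials through a flat `shared[WARP_SIZE]` buffer with one slot per warp, which is why 33-wide rows are no longer needed. This is illustrative code only, not part of the patch; the `*_demo` names and the host harness are made up, and the predicate for loading per-warp partials uses an integer warp count instead of the `blockDim.x / 32.f` comparison used in softmax.cuh.

```cuda
// Standalone sketch (assumed names): warp-shuffle sum + per-warp shared slots.
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff
#define WARP_SIZE 32

__inline__ __device__ float warp_reduce_sum_demo(float val) {
  // Butterfly reduction: after log2(32) steps every lane holds the warp sum.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    val += __shfl_xor_sync(FULL_MASK, val, mask, WARP_SIZE);
  }
  return val;
}

__inline__ __device__ float block_reduce_sum_demo(float val) {
  __shared__ float shared[WARP_SIZE]; // one slot per warp; at most 32 warps
  int lane = threadIdx.x & 0x1f;      // threadIdx.x % 32
  int wid = threadIdx.x >> 5;         // threadIdx.x / 32

  val = warp_reduce_sum_demo(val);    // reduce within each warp
  if (lane == 0)
    shared[wid] = val;                // lane 0 publishes its warp's sum
  __syncthreads();

  // The first warp reduces the per-warp partial sums.
  int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
  val = (threadIdx.x < num_warps) ? shared[lane] : 0.0f;
  if (wid == 0)
    val = warp_reduce_sum_demo(val);
  return val;                         // valid in warp 0
}

__global__ void block_sum_kernel(const float* in, float* out, int n) {
  float local = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    local += in[i];
  }
  float total = block_reduce_sum_demo(local);
  if (threadIdx.x == 0)
    *out = total;
}

int main() {
  const int n = 4096;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i)
    in[i] = 1.0f;
  block_sum_kernel<<<1, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n); // expect 4096
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

The max reduction is the same shape with `+=` replaced by a max (the patch uses `fast_max` there); the shared buffer stays one element per warp either way.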