diff --git a/python/aitemplate/backend/cuda/softmax/softmax.cuh b/python/aitemplate/backend/cuda/softmax/softmax.cuh
index acad5eb67..8b10e520b 100644
--- a/python/aitemplate/backend/cuda/softmax/softmax.cuh
+++ b/python/aitemplate/backend/cuda/softmax/softmax.cuh
@@ -139,21 +139,24 @@ __inline__ __device__ void warpReduceSum(
   }
 }
 
-template <typename T>
+// Note that it's not a complete block-wide reduction.
+// Only threads that share threadIdx.y reduce values.
+template <typename T, int ROWS>
 __inline__ __device__ void blockReduceSum(T& val) {
-  __shared__ T shared[WARP_SIZE];
-  int lane = threadIdx.x & 0x1f; // threadIdx.x % warp_size
-  int wid = threadIdx.x >> 5; // threadIdx.x / warp_size
+  // NOTE: if ROWS > 1, we must have blockDim.x % WARP_SIZE == 0
+  __shared__ T shared[ROWS][WARP_SIZE];
+  int lane = threadIdx.x & (WARP_SIZE - 1); // threadIdx.x % WARP_SIZE
+  int wid = threadIdx.x >> 5; // threadIdx.x / WARP_SIZE
 
   warpReduceSum(val);
 
   if (lane == 0)
-    shared[wid] = val;
+    shared[threadIdx.y][wid] = val;
 
   __syncthreads();
 
   bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-  val = is_mask ? shared[lane] : (T)(0.0f);
+  val = is_mask ? shared[threadIdx.y][lane] : (T)(0.0f);
   if (wid == 0)
     warpReduceSum(val);
 }
@@ -168,21 +171,21 @@ __inline__ __device__ void warpReduceMax(
   }
 }
 
-template <typename T>
+template <typename T, int ROWS>
 __inline__ __device__ void blockReduceMax(T& val) {
-  __shared__ T shared[WARP_SIZE];
-  int lane = threadIdx.x & 0x1f;
+  __shared__ T shared[ROWS][WARP_SIZE];
+  int lane = threadIdx.x & (WARP_SIZE - 1);
   int wid = threadIdx.x >> 5;
 
   warpReduceMax(val);
 
   if (lane == 0)
-    shared[wid] = val;
+    shared[threadIdx.y][wid] = val;
 
   __syncthreads();
 
   bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-  val = is_mask ? shared[lane] : (T)(0.0f);
+  val = is_mask ? shared[threadIdx.y][lane] : (T)(0.0f);
 
   if (wid == 0)
     warpReduceMax(val);
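The new blockReduceSum/blockReduceMax follow the standard two-stage pattern: a shuffle-based reduction inside each warp, then a second warp-level pass over the per-warp partials staged in shared memory, with one shared-memory row per threadIdx.y slice so that ROWS independent reductions can run in the same block. A minimal, self-contained sketch of the pattern (illustrative only, not part of the patch; FINAL_MASK and the `*Sketch` names are stand-ins for the helpers in softmax.cuh):

```cuda
#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32

// Butterfly reduction: after log2(WARP_SIZE) rounds every lane holds the sum.
template <typename T>
__inline__ __device__ void warpReduceSumSketch(T& val) {
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, WARP_SIZE);
}

// Same contract as the patched blockReduceSum<T, ROWS>: only threads sharing
// threadIdx.y reduce together, and blockDim.x must be a multiple of WARP_SIZE
// whenever ROWS > 1.
template <typename T, int ROWS>
__inline__ __device__ void blockReduceSumSketch(T& val) {
  __shared__ T shared[ROWS][WARP_SIZE];
  int lane = threadIdx.x & (WARP_SIZE - 1);
  int wid = threadIdx.x >> 5;

  warpReduceSumSketch(val); // stage 1: reduce within each warp
  if (lane == 0)
    shared[threadIdx.y][wid] = val; // one partial per warp, per row
  __syncthreads();

  // Stage 2: the first blockDim.x / WARP_SIZE lanes reload the partials and
  // warp 0 reduces them; thread (0, y) ends up with the total for row y.
  val = threadIdx.x < blockDim.x / WARP_SIZE ? shared[threadIdx.y][lane]
                                             : (T)(0.0f);
  if (wid == 0)
    warpReduceSumSketch(val);
}

// Example use: launched with blockDim = (64, 4), this sums four 64-element
// rows concurrently, one per threadIdx.y slice.
__global__ void rowSumSketch(const float* in, float* out) {
  float v = in[threadIdx.y * blockDim.x + threadIdx.x];
  blockReduceSumSketch<float, 4>(v);
  if (threadIdx.x == 0)
    out[threadIdx.y] = v;
}
```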
@@ -844,36 +847,14 @@ void LaunchSoftmaxK1Middle(
   SOFTMAX_LAUNCH_CHECK();
 }
 
-// Note that it's not a complete block-wide reduction.
-// Only threads that share threadIdx.y reduce values.
-template <typename T, template <typename> class ReduceOp>
-__forceinline__ __device__ T softmax_general_block_reduce_x(T* shared, T val) {
-  ReduceOp<T> r;
-  shared += threadIdx.y * blockDim.x;
-
-  __syncthreads();
-
-  shared[threadIdx.x] = val;
-
-  // NOTE: loop starts with __syncthreads()
-  int offset = blockDim.x / 2;
-  while (offset > 0) {
-    __syncthreads();
-    if (threadIdx.x < offset)
-      shared[threadIdx.x] =
-          r(shared[threadIdx.x], shared[threadIdx.x + offset]);
-    offset /= 2;
-  }
-
-  __syncthreads();
-
-  return shared[0];
-}
-
-template <typename T, size_t DimSize, size_t InnerSize, size_t DimThreads>
+template <
+    typename T,
+    size_t DimSize,
+    size_t InnerSize,
+    size_t DimThreads /* blockDim.x */,
+    size_t InnerThreads /* blockDim.y */>
 __global__ void softmax_general(const T* input, T* output, size_t outer_size) {
-  extern __shared__ unsigned char smem[];
-  auto sdata = reinterpret_cast<T*>(smem);
+  __shared__ T reduced_values[InnerThreads];
   const uint32_t outer_stride = InnerSize * DimSize;
   const uint32_t dim_stride = InnerSize;
 
@@ -884,26 +865,49 @@ __global__ void softmax_general(const T* input, T* output, size_t outer_size) {
          inner_index < InnerSize;
          inner_index += blockDim.y * gridDim.y) {
       const uint32_t data_offset = outer_offset + inner_index;
-      T max_input = std::numeric_limits<T>::lowest();
-      // DimThreads == blockDim.x, but using DimThreads here is actually a perf
-      // regression
+      T local_max = std::numeric_limits<T>::lowest();
+      // First we reduce locally on a per-thread basis. We reduce #InnerThreads
+      // consecutive rows of the tensor at once, so we read the #input values in
+      // contiguous chunks of size #InnerThreads. For small values of InnerSize,
+      // we have InnerThreads == InnerSize, and so we will read in one big
+      // contiguous range.
      for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x) {
        const T value = input[data_offset + d * dim_stride];
-        max_input = fast_max(max_input, value);
+        local_max = fast_max(local_max, value);
+      }
+      // If reduction uses more than one thread, get the max of the thread-local
+      // values for each row and broadcast it.
+      if constexpr (DimThreads > 1) {
+        if constexpr (DimThreads > WARP_SIZE)
+          blockReduceMax<T, InnerThreads>(local_max);
+        else
+          warpReduceMax(local_max);
+        if (threadIdx.x == 0)
+          reduced_values[threadIdx.y] = local_max;
+        __syncthreads();
+        local_max = reduced_values[threadIdx.y];
       }
-      if constexpr (DimThreads > 1)
-        max_input =
-            softmax_general_block_reduce_x<T, Max>(sdata, max_input);
 
-      T sum = 0;
+      T local_sum = 0;
+      // NOTE: DimThreads == blockDim.x, but using DimThreads here is actually a
+      // perf regression.
       for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x)
-        sum += fast_exp(input[data_offset + d * dim_stride] - max_input);
-      if constexpr (DimThreads > 1)
-        sum = softmax_general_block_reduce_x<T, Add>(sdata, sum);
+        local_sum += fast_exp(input[data_offset + d * dim_stride] - local_max);
+      if constexpr (DimThreads > 1) {
+        if constexpr (DimThreads > WARP_SIZE)
+          blockReduceSum<T, InnerThreads>(local_sum);
+        else
+          warpReduceSum(local_sum);
+        if (threadIdx.x == 0)
+          reduced_values[threadIdx.y] = local_sum;
+        __syncthreads();
+        local_sum = reduced_values[threadIdx.y];
+      }
 
       for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x)
         output[data_offset + d * dim_stride] =
-            fast_exp(input[data_offset + d * dim_stride] - max_input) / sum;
+            fast_exp(input[data_offset + d * dim_stride] - local_max) /
+            local_sum;
     }
   }
 }
@@ -940,19 +944,18 @@ void LaunchSoftmaxGeneral(
     int multiprocessorCount,
     cudaStream_t stream) {
   int block_size = DimThreads * InnerThreads;
-  size_t smem_size = DimThreads == 1 ? 0 : block_size * sizeof(T);
   int max_active_blocks;
   cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &max_active_blocks,
-      softmax_general<T, DimSize, InnerSize, DimThreads>,
+      softmax_general<T, DimSize, InnerSize, DimThreads, InnerThreads>,
       block_size,
-      smem_size);
+      /*smem_size=*/0);
   max_active_blocks *= multiprocessorCount;
   dim3 grid = softmax_general_get_grid_size<InnerSize, InnerThreads>(
       max_active_blocks, outer_size);
   dim3 block(DimThreads, InnerThreads);
-  softmax_general<T, DimSize, InnerSize, DimThreads>
-      <<<grid, block, smem_size, stream>>>(input, output, outer_size);
+  softmax_general<T, DimSize, InnerSize, DimThreads, InnerThreads>
+      <<<grid, block, 0, stream>>>(input, output, outer_size);
   SOFTMAX_LAUNCH_CHECK();
 }
 
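To make the coupling with the Python heuristic below concrete, here is a hypothetical instantiation of the launcher. The values match the general_dim_threads_64_bf16 test case added at the end of this patch (with float for brevity); `input`, `output`, `outer_size`, `sm_count`, and `stream` are placeholders for caller-provided values, and the template parameter order is assumed to mirror the kernel's:

```cuda
// Reducing a dim of size 64 with inner_size = 3 * 3 = 9: the heuristic picks
// DimThreads = 64 (two warps per row) and InnerThreads = 9, so each block
// handles 9 softmax rows at once, and blockDim.x = 64 is a multiple of
// WARP_SIZE, as the warp reduction requires.
LaunchSoftmaxGeneral<
    float,
    /*DimSize=*/64,
    /*InnerSize=*/9,
    /*DimThreads=*/64,
    /*InnerThreads=*/9>(input, output, outer_size, sm_count, stream);
```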
diff --git a/python/aitemplate/backend/cuda/softmax/softmax.py b/python/aitemplate/backend/cuda/softmax/softmax.py
index 6605de1ae..c08939752 100644
--- a/python/aitemplate/backend/cuda/softmax/softmax.py
+++ b/python/aitemplate/backend/cuda/softmax/softmax.py
@@ -199,15 +199,22 @@ def find_tile_size(k: int) -> int:
 
 def _softmax_general_block_size(dim_size: int, inner_size: int) -> tuple[int, int]:
     MAX_THREADS_PER_BLOCK = 1024
+    WARP_SIZE = 32
+    assert inner_size != 0
     inner_threads = min(inner_size, MAX_THREADS_PER_BLOCK)
     dim_threads = 1
-    if inner_threads <= 64 and dim_size >= 64:
-        while (
-            inner_threads * dim_threads <= MAX_THREADS_PER_BLOCK
-            and dim_threads <= dim_size
-        ):
-            dim_threads *= 2
-        dim_threads //= 2
+    if inner_threads <= 32 and dim_size >= WARP_SIZE:
+        dim_threads = (
+            min(
+                MAX_THREADS_PER_BLOCK // inner_threads // WARP_SIZE,
+                dim_size // WARP_SIZE,
+            )
+            * WARP_SIZE
+        )
+    # When dim_threads > 1, warp reduction is done, and for our warp reduction
+    # impl to work, dim_threads needs to be a multiple of the warp size.
+    assert dim_threads == 1 or dim_threads % WARP_SIZE == 0
+    assert dim_threads != 0
     return dim_threads, inner_threads
 
 
diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py
index b81e6b7a3..e477e5e12 100644
--- a/tests/unittest/ops/test_softmax.py
+++ b/tests/unittest/ops/test_softmax.py
@@ -124,7 +124,7 @@ def _test_softmax(
         ("k8_large_fp32", "float32", (1, 2), (3848,)),
         ("no_smem_fp32", "float32", (1, 2), (12500,)),
         (
-            "general_no_smem_fp32",
+            "general_dim_threads_1_fp32",
             "float32",
             (1, 2),
             (6, 8, 3, 3),
@@ -148,21 +148,28 @@ def _test_softmax(
         ("k8_large_bf16", "bfloat16", (1, 2), (3848,)),
         ("no_smem_bf16", "bfloat16", (1, 2), (12500,)),
         (
-            "general_no_smem_bf16",
+            "general_dim_threads_1_bf16",
             "bfloat16",
             (1, 2),
             (6, 8, 3, 3),
             2,
         ),
         (
-            "general_smem_bf16",
+            "general_dim_threads_32_bf16",
+            "bfloat16",
+            (1, 2),
+            (6, 32, 3, 3),
+            2,
+        ),
+        (
+            "general_dim_threads_64_bf16",
             "bfloat16",
             (1, 2),
             (6, 64, 3, 3),
             2,
         ),
         (
-            "general_reduce_dim_1_bf16",
+            "general_reduce_dim_size_1_bf16",
             "bfloat16",
             (1, 2),
             (1, 3),
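Sanity check for the new test names: a standalone re-run of the block-size heuristic for the "general" cases above (assuming, per the test parameters, that dim=2 of the batched 5-D input selects the second trailing dimension, so dim_size is that dimension and inner_size = 3 * 3 = 9):

```python
MAX_THREADS_PER_BLOCK = 1024
WARP_SIZE = 32


def block_size(dim_size: int, inner_size: int) -> tuple[int, int]:
    # Mirrors _softmax_general_block_size from the patch above.
    inner_threads = min(inner_size, MAX_THREADS_PER_BLOCK)
    dim_threads = 1
    if inner_threads <= 32 and dim_size >= WARP_SIZE:
        dim_threads = (
            min(
                MAX_THREADS_PER_BLOCK // inner_threads // WARP_SIZE,
                dim_size // WARP_SIZE,
            )
            * WARP_SIZE
        )
    return dim_threads, inner_threads


assert block_size(8, 9) == (1, 9)  # (6, 8, 3, 3) -> general_dim_threads_1_*
assert block_size(32, 9) == (32, 9)  # (6, 32, 3, 3) -> general_dim_threads_32_bf16
assert block_size(64, 9) == (64, 9)  # (6, 64, 3, 3) -> general_dim_threads_64_bf16
```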