Skip to content

Commit

Permalink
Have softmax use warp reduction in dim=-2 case (#887)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #887

We extend the warp reduction code to handle a 2D block, which the dim=-2 case uses.

Reviewed By: aakhundov

Differential Revision: D47862056

fbshipit-source-id: e1bbf1e8830bd7346c2bf33721bebf72e18f751d
  • Loading branch information
int3 authored and facebook-github-bot committed Aug 9, 2023
1 parent 38aebd9 commit b5e2959
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 68 deletions.
117 changes: 60 additions & 57 deletions python/aitemplate/backend/cuda/softmax/softmax.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,24 @@ __inline__ __device__ void warpReduceSum(
}
}

template <typename T>
// Note that it's not a complete block-wide reduction.
// Only threads that share threadIdx.y reduce values.
template <typename T, size_t ROWS = 1>
__inline__ __device__ void blockReduceSum(T& val) {
__shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f; // threadIdx.x % warp_size
int wid = threadIdx.x >> 5; // threadIdx.x / warp_size
// NOTE: if ROWS > 1, we must have blockDim.x % WARP_SIZE == 0
__shared__ T shared[ROWS][WARP_SIZE];
int lane = threadIdx.x & (WARP_SIZE - 1); // threadIdx.x % WARP_SIZE
int wid = threadIdx.x >> 5; // threadIdx.x / WARP_SIZE

warpReduceSum<T>(val);

if (lane == 0)
shared[wid] = val;
shared[threadIdx.y][wid] = val;

__syncthreads();

bool is_mask = threadIdx.x < (blockDim.x / 32.f);
val = is_mask ? shared[lane] : (T)(0.0f);
val = is_mask ? shared[threadIdx.y][lane] : (T)(0.0f);
if (wid == 0)
warpReduceSum<T>(val);
}
Expand All @@ -168,21 +171,21 @@ __inline__ __device__ void warpReduceMax(
}
}

template <typename T>
template <typename T, size_t ROWS = 1>
__inline__ __device__ void blockReduceMax(T& val) {
__shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
__shared__ T shared[ROWS][WARP_SIZE];
int lane = threadIdx.x & (WARP_SIZE - 1);
int wid = threadIdx.x >> 5;

warpReduceMax<T>(val);

if (lane == 0)
shared[wid] = val;
shared[threadIdx.y][wid] = val;

__syncthreads();

bool is_mask = threadIdx.x < (blockDim.x / 32.f);
val = is_mask ? shared[lane] : (T)(0.0f);
val = is_mask ? shared[threadIdx.y][lane] : (T)(0.0f);

if (wid == 0)
warpReduceMax<T>(val);
Expand Down Expand Up @@ -844,36 +847,14 @@ void LaunchSoftmaxK1Middle(
SOFTMAX_LAUNCH_CHECK();
}

// Note that it's not a complete block-wide reduction.
// Only threads that share threadIdx.y reduce values.
template <typename T, template <typename> class ReduceOp>
__forceinline__ __device__ T softmax_general_block_reduce_x(T* shared, T val) {
ReduceOp<T> r;
shared += threadIdx.y * blockDim.x;

__syncthreads();

shared[threadIdx.x] = val;

// NOTE: loop starts with __syncthreads()
int offset = blockDim.x / 2;
while (offset > 0) {
__syncthreads();
if (threadIdx.x < offset)
shared[threadIdx.x] =
r(shared[threadIdx.x], shared[threadIdx.x + offset]);
offset /= 2;
}

__syncthreads();

return shared[0];
}

template <typename T, size_t DimSize, size_t InnerSize, size_t DimThreads>
template <
typename T,
size_t DimSize,
size_t InnerSize,
size_t DimThreads /* blockDim.x */,
size_t InnerThreads /* blockDim.y */>
__global__ void softmax_general(const T* input, T* output, size_t outer_size) {
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<T*>(smem);
__shared__ T reduced_values[InnerThreads];
const uint32_t outer_stride = InnerSize * DimSize;
const uint32_t dim_stride = InnerSize;

Expand All @@ -884,26 +865,49 @@ __global__ void softmax_general(const T* input, T* output, size_t outer_size) {
inner_index < InnerSize;
inner_index += blockDim.y * gridDim.y) {
const uint32_t data_offset = outer_offset + inner_index;
T max_input = std::numeric_limits<T>::lowest();
// DimThreads == blockDim.x, but using DimThreads here is actually a perf
// regression
T local_max = std::numeric_limits<T>::lowest();
// First we reduce locally on a per-thread basis. We reduce #InnerThreads
// consecutive rows of the tensor at once, so we read the #input values in
// contiguous chunks of size #InnerThreads. For small values of InnerSize,
// we have InnerThreads == InnerSize, and so we will read in one big
// contiguous range.
for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x) {
const T value = input[data_offset + d * dim_stride];
max_input = fast_max(max_input, value);
local_max = fast_max(local_max, value);
}
// If reduction uses more than one thread, get the max of the thread-local
// values for each row and broadcast it.
if constexpr (DimThreads > 1) {
if constexpr (DimThreads > WARP_SIZE)
blockReduceMax<T, InnerThreads>(local_max);
else
warpReduceMax<T>(local_max);
if (threadIdx.x == 0)
reduced_values[threadIdx.y] = local_max;
__syncthreads();
local_max = reduced_values[threadIdx.y];
}
if constexpr (DimThreads > 1)
max_input =
softmax_general_block_reduce_x<T, FastMax>(sdata, max_input);

T sum = 0;
T local_sum = 0;
// NOTE: DimThreads == blockDim.x, but using DimThreads here is actually a
// perf regression.
for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x)
sum += fast_exp(input[data_offset + d * dim_stride] - max_input);
if constexpr (DimThreads > 1)
sum = softmax_general_block_reduce_x<T, std::plus>(sdata, sum);
local_sum += fast_exp(input[data_offset + d * dim_stride] - local_max);
if constexpr (DimThreads > 1) {
if constexpr (DimThreads > WARP_SIZE)
blockReduceSum<T, InnerThreads>(local_sum);
else
warpReduceSum<T>(local_sum);
if (threadIdx.x == 0)
reduced_values[threadIdx.y] = local_sum;
__syncthreads();
local_sum = reduced_values[threadIdx.y];
}

for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x)
output[data_offset + d * dim_stride] =
fast_exp(input[data_offset + d * dim_stride] - max_input) / sum;
fast_exp(input[data_offset + d * dim_stride] - local_max) /
local_sum;
}
}
}
Expand Down Expand Up @@ -940,19 +944,18 @@ void LaunchSoftmaxGeneral(
int multiprocessorCount,
cudaStream_t stream) {
int block_size = DimThreads * InnerThreads;
size_t smem_size = DimThreads == 1 ? 0 : block_size * sizeof(T);
int max_active_blocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
softmax_general<T, DimSize, InnerSize, DimThreads>,
softmax_general<T, DimSize, InnerSize, DimThreads, InnerThreads>,
block_size,
smem_size);
/*smem_size=*/0);
max_active_blocks *= multiprocessorCount;
dim3 grid = softmax_general_get_grid_size<InnerThreads, InnerSize>(
max_active_blocks, outer_size);
dim3 block(DimThreads, InnerThreads);
softmax_general<T, DimSize, InnerSize, DimThreads>
<<<grid, block, smem_size, stream>>>(input, output, outer_size);
softmax_general<T, DimSize, InnerSize, DimThreads, InnerThreads>
<<<grid, block, 0, stream>>>(input, output, outer_size);
SOFTMAX_LAUNCH_CHECK();
}

Expand Down
21 changes: 14 additions & 7 deletions python/aitemplate/backend/cuda/softmax/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,22 @@ def find_tile_size(k: int) -> int:

def _softmax_general_block_size(dim_size: int, inner_size: int) -> tuple[int, int]:
MAX_THREADS_PER_BLOCK = 1024
WARP_SIZE = 32
assert inner_size != 0
inner_threads = min(inner_size, MAX_THREADS_PER_BLOCK)
dim_threads = 1
if inner_threads <= 64 and dim_size >= 64:
while (
inner_threads * dim_threads <= MAX_THREADS_PER_BLOCK
and dim_threads <= dim_size
):
dim_threads *= 2
dim_threads //= 2
if inner_threads <= 32 and dim_size >= WARP_SIZE:
dim_threads = (
min(
MAX_THREADS_PER_BLOCK // inner_threads // WARP_SIZE,
dim_size // WARP_SIZE,
)
* WARP_SIZE
)
# When dim_threads > 1, warp reduction is done, and for our warp reduction
# impl to work, dim_threads needs to be a multiple of the warp size.
assert dim_threads == 1 or dim_threads % WARP_SIZE == 0
assert dim_threads != 0
return dim_threads, inner_threads


Expand Down
15 changes: 11 additions & 4 deletions tests/unittest/ops/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _test_softmax(
("k8_large_fp32", "float32", (1, 2), (3848,)),
("no_smem_fp32", "float32", (1, 2), (12500,)),
(
"general_no_smem_fp32",
"general_dim_threads_1_fp32",
"float32",
(1, 2),
(6, 8, 3, 3),
Expand All @@ -148,21 +148,28 @@ def _test_softmax(
("k8_large_bf16", "bfloat16", (1, 2), (3848,)),
("no_smem_bf16", "bfloat16", (1, 2), (12500,)),
(
"general_no_smem_bf16",
"general_dim_threads_1_bf16",
"bfloat16",
(1, 2),
(6, 8, 3, 3),
2,
),
(
"general_smem_bf16",
"general_dim_threads_32_bf16",
"bfloat16",
(1, 2),
(6, 32, 3, 3),
2,
),
(
"general_dim_threads_64_bf16",
"bfloat16",
(1, 2),
(6, 64, 3, 3),
2,
),
(
"general_reduce_dim_1_bf16",
"general_reduce_dim_size_1_bf16",
"bfloat16",
(1, 2),
(1, 3),
Expand Down

0 comments on commit b5e2959

Please sign in to comment.