Refactor softmax code for dim=-1 case (#886)
Summary:
Pull Request resolved: #886

This sets the stage for extending the warp reduction code to the dim=-2 case.

Main changes in this diff:

1. blockReduceMax now uses `fast_max` instead of `max`. This is what the dim=-2
reduction code already uses. From looking at the implementation, fast_max appears to be
fast because of type specialization rather than because it sacrifices accuracy, so I
think this is a safe change. (A hedged sketch of the idea follows this list.)

2. The block reduction code had a `NUM` parameter that was always set to 1.
I've eliminated that to simplify things and remove a bunch of indirection.

3. The block / warp reduction code padded its shared memory to rows of 33 elements
in order to avoid bank conflicts. However, given that NUM was always 1, I don't think
bank conflicts are possible in practice. I've therefore changed the shared memory to a
single row of 32 elements instead. (See the second sketch after this list.)
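
To make point 1 concrete, here is a minimal sketch of what a type-specialized max can look like. The name `fast_max_sketch`, the overload set, and the architecture guard are illustrative assumptions for this note, not the actual `fast_max` defined in softmax.cuh:

#include <cuda_fp16.h>

// Hypothetical sketch: dispatch on the element type so each type uses its
// native max instruction. The speed comes from avoiding generic-template
// overhead, not from approximating the result.
__device__ __forceinline__ float fast_max_sketch(float a, float b) {
  return fmaxf(a, b); // single-precision hardware max
}

__device__ __forceinline__ __half fast_max_sketch(__half a, __half b) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hmax(a, b); // native fp16 max on sm_80 and newer
#else
  return __half2float(a) > __half2float(b) ? a : b; // fp32 compare fallback
#endif
}

For ordinary (non-NaN) inputs both overloads return the same value `max` would; only the instruction selection differs, which is why the swap should not affect accuracy.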
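
For point 3, a hedged illustration of why the 33-element padding is unnecessary once NUM is gone. The kernel below is hypothetical and only meant to show the shared-memory bank layout; it is not code from this diff and assumes a single warp of 32 threads:

// Hypothetical kernel, for illustration only (launch with one warp).
__global__ void bank_conflict_sketch(float* out) {
  __shared__ float row[32];      // NUM == 1 case: row[lane] puts each lane
                                 // in its own 4-byte bank, so the access is
                                 // conflict-free without any padding
  __shared__ float tile[32][33]; // 2D case: the +1 pad makes a column access
                                 // tile[lane][0] hit bank (lane * 33) % 32
                                 // == lane; with 32-wide rows every lane
                                 // would hit bank 0 instead
  int lane = threadIdx.x & 0x1f;
  row[lane] = static_cast<float>(lane);
  tile[lane][0] = static_cast<float>(lane);
  __syncwarp();
  out[threadIdx.x] = row[lane] + tile[lane][0]; // both reads conflict-free
}

With NUM fixed at 1, the reduction only ever indexes a single 32-element row with the lane id, which is the conflict-free case above, so the padding bought nothing.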

Reviewed By: muchulee8, aakhundov

Differential Revision: D47862053

fbshipit-source-id: 477f16ae3f33e2e5f2a858bf156268aef5a25ee1
int3 authored and facebook-github-bot committed Aug 8, 2023
1 parent 318111f commit 029ba1c
Showing 1 changed file with 59 additions and 76 deletions.
python/aitemplate/backend/cuda/softmax/softmax.cuh
@@ -36,6 +36,8 @@ using bfloat16 = nv_bfloat16;
 
 #define SOFTMAX_LAUNCH_CHECK() SOFTMAX_DEVICE_CHECK(cudaGetLastError())
 
+#define WARP_SIZE 32
+
 // unroll directives copied from CUTLASS
 #if defined(__CUDA_ARCH__)
 #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
@@ -127,82 +129,63 @@ struct float8 {
 
 #define FINAL_MASK 0xffffffff
 
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceSum(T* val, int thread_group_width = 32) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
-      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
-    }
-  }
-  return (T)(0.0f);
-}
+template <typename T>
+__inline__ __device__ void warpReduceSum(
+    T& val,
+    int thread_group_width = WARP_SIZE) {
+#pragma unroll
+  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, WARP_SIZE);
+  }
+}
 
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceSum(T* val) {
-  __shared__ T shared[NUM][33];
+template <typename T>
+__inline__ __device__ void blockReduceSum(T& val) {
+  __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f; // threadIdx.x % warp_size
   int wid = threadIdx.x >> 5; // threadIdx.x / warp_size
 
-  warpReduceSum<T, NUM>(val);
+  warpReduceSum<T>(val);
 
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
+  if (lane == 0)
+    shared[wid] = val;
 
   __syncthreads();
 
   bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
+  val = is_mask ? shared[lane] : (T)(0.0f);
   if (wid == 0)
-    warpReduceSum<T, NUM>(val);
-  return (T)0.0f;
+    warpReduceSum<T>(val);
 }
 
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceMax(T* val, int thread_group_width = 32) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
-      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
-    }
-  }
-  return (T)(0.0f);
-}
+template <typename T>
+__inline__ __device__ void warpReduceMax(
+    T& val,
+    int thread_group_width = WARP_SIZE) {
+#pragma unroll
+  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
+    val = fast_max(val, __shfl_xor_sync(FINAL_MASK, val, mask, WARP_SIZE));
+  }
+}
 
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceMax(T* val) {
-  __shared__ T shared[NUM][33];
+template <typename T>
+__inline__ __device__ void blockReduceMax(T& val) {
+  __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  warpReduceMax<T, NUM>(val);
+  warpReduceMax<T>(val);
 
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
+  if (lane == 0)
+    shared[wid] = val;
 
   __syncthreads();
 
   bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
+  val = is_mask ? shared[lane] : (T)(0.0f);
+
   if (wid == 0)
-    warpReduceMax<T, NUM>(val);
-  return (T)0.0f;
+    warpReduceMax<T>(val);
 }
 
 } // namespace
@@ -348,32 +331,32 @@ __global__ void softmaxBlockNocache(
   input += offset;
   output += offset;
 
-  float local_max[1] = {-Inf<float>()};
+  float local_max = -Inf<float>();
   for (int i = tid; i < n; i += blockDim.x) {
     float local_val = static_cast<float>(input[i]);
-    local_max[0] = max(local_val, local_max[0]);
+    local_max = max(local_val, local_max);
   }
 
-  if (blockDim.x <= 32) {
-    warpReduceMax<float, 1>(local_max);
+  if (blockDim.x <= WARP_SIZE) {
+    warpReduceMax<float>(local_max);
   } else {
-    blockReduceMax<float, 1>(local_max);
+    blockReduceMax<float>(local_max);
   }
   if (threadIdx.x == 0) {
-    s_max = local_max[0];
+    s_max = local_max;
   }
   __syncthreads();
-  float local_sum[1] = {0.0f};
+  float local_sum = 0.0f;
   for (int i = tid; i < n; i += blockDim.x) {
-    local_sum[0] += exp(static_cast<float>(input[i]) - s_max);
+    local_sum += exp(static_cast<float>(input[i]) - s_max);
   }
-  if (blockDim.x <= 32) {
-    warpReduceSum<float, 1>(local_sum);
+  if (blockDim.x <= WARP_SIZE) {
+    warpReduceSum<float>(local_sum);
   } else {
-    blockReduceSum<float, 1>(local_sum);
+    blockReduceSum<float>(local_sum);
   }
   if (threadIdx.x == 0) {
-    s_sum = local_sum[0];
+    s_sum = local_sum;
   }
   __syncthreads();
   for (int i = tid; i < n; i += blockDim.x) {
@@ -417,7 +400,7 @@ __global__ void softmax_stored_locally_multi_dim(
   const int64_t row_offset = row * int((n + pack_size - 1) / pack_size);
   const T* row_x = input + row_offset;
   T* row_y = output + row_offset;
-  float local_max[1] = {-Inf<float>()};
+  float local_max = -Inf<float>();
 #pragma unroll
   for (int i = 0; i < num_packs; ++i) {
     const int col = i * blockDim.x + tid;
@@ -427,7 +410,7 @@
 #pragma unroll
       for (int j = 0; j < pack_size; j++) {
         buf[i * pack_size + j] = static_cast<float>(pack_x[j]);
-        local_max[0] = max(local_max[0], buf[i * pack_size + j]);
+        local_max = max(local_max, buf[i * pack_size + j]);
       }
     } else {
 #pragma unroll
@@ -436,15 +419,15 @@
       }
     }
   }
-  warpReduceMax<float, 1>(local_max, blockDim.x);
+  warpReduceMax<float>(local_max, blockDim.x);
 
-  float local_sum[1] = {0.0f};
+  float local_sum = 0.0f;
 #pragma unroll
   for (int i = 0; i < cols_per_thread; ++i) {
-    buf[i] = exp(buf[i] - local_max[0]);
-    local_sum[0] += buf[i];
+    buf[i] = exp(buf[i] - local_max);
+    local_sum += buf[i];
   }
-  warpReduceSum<float, 1>(local_sum, blockDim.x);
+  warpReduceSum<float>(local_sum, blockDim.x);
 
   T tmp_o;
   ACT_T* pack_y = reinterpret_cast<ACT_T*>(&tmp_o);
@@ -453,7 +436,7 @@
     const int col = i * blockDim.x + tid;
     if (col < n / pack_size) {
       for (int j = 0; j < pack_size; j++) {
-        pack_y[j] = ACT_T(buf[i * pack_size + j] / local_sum[0]);
+        pack_y[j] = ACT_T(buf[i * pack_size + j] / local_sum);
       }
       row_y[col] = tmp_o;
     }
@@ -481,7 +464,7 @@ __global__ void softmax_block_smem(
   const int64_t row_offset = row * int((n + pack_size - 1) / pack_size);
   const T* row_x = input + row_offset;
   T* row_y = output + row_offset;
-  float local_max[1] = {-Inf<float>()};
+  float local_max = -Inf<float>();
 
   for (int pack_id = tid; pack_id < num_packs; pack_id += blockDim.x) {
     T tmp_in = row_x[pack_id];
@@ -490,28 +473,28 @@
     for (int j = 0; j < pack_size; j++) {
       float pack = pack_x[j];
       buf[j * num_packs + pack_id] = pack;
-      local_max[0] = max(local_max[0], pack);
+      local_max = max(local_max, pack);
     }
   }
-  blockReduceMax<float, 1>(local_max); // reduce on a block of #blockDim.x
+  blockReduceMax<float>(local_max); // reduce on a block of #blockDim.x
 
   __shared__ float s_max;
   if (threadIdx.x == 0) {
-    s_max = local_max[0];
+    s_max = local_max;
   }
   __syncthreads();
 
-  float local_sum[1] = {0.0f};
+  float local_sum = 0.0f;
   for (int i = tid; i < n; i += blockDim.x) {
     float local_val = exp(buf[i] - s_max);
     buf[i] = local_val;
-    local_sum[0] += local_val;
+    local_sum += local_val;
   }
-  blockReduceSum<float, 1>(local_sum);
+  blockReduceSum<float>(local_sum);
 
   __shared__ float s_sum;
   if (threadIdx.x == 0) {
-    s_sum = local_sum[0];
+    s_sum = local_sum;
   }
   __syncthreads();
 