From 0aecd178a0560e24a10936783a5eea98d5a9d351 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Tue, 30 Apr 2024 22:52:03 -0700 Subject: [PATCH] Dedup GQA splitk kernel Summary: We want to keep use_tensor_cores = False option for gqa_attn_splitk function for backward compatibility (GPUs before Hopper, AMD). Reviewed By: sryap Differential Revision: D56687037 fbshipit-source-id: 0c98fe6327fd063b62d59aaaacd238cacbfb20c5 --- .../gen_ai/src/attention/attention.cpp | 11 +- .../gen_ai/src/attention/gqa_attn_splitk.cu | 1010 ++++++++++++++++- 2 files changed, 983 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp b/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp index ecf9f88f9..c02c1b4b8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp @@ -12,15 +12,15 @@ namespace fbgemm_gpu::gen_ai::attention { -std::tuple gqa_attn_splitk_cuda( +std::tuple gqa_attn_splitk( const at::Tensor& XQ, const at::Tensor& cache_K, const at::Tensor& cache_V, const at::Tensor& seq_positions, const double qk_scale, const int64_t num_split_ks, - const int64_t num_groups); - + const int64_t num_int4_kv_groups, + const bool use_tensor_cores); } // namespace fbgemm_gpu::gen_ai::attention TORCH_LIBRARY_FRAGMENT(fbgemm, m) { @@ -32,7 +32,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor seq_positions, " " float qk_scale, " " int num_split_ks, " - " int num_int4_kv_groups=1" + " int num_int4_kv_groups=1, " + " bool use_tensor_cores=True" ") -> (Tensor, Tensor, Tensor)"); } @@ -41,5 +42,5 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { "gqa_attn_splitk", torch::dispatch( c10::DispatchKey::CUDA, - TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk_cuda))); + TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk))); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu index 2c4b6f5c8..0ec084182 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu @@ -141,11 +141,32 @@ __floats2bfloat162_rn(float x, float y) { #endif +struct __align__(16) fx4 { + float x; + float y; + float z; + float w; + __host__ __device__ fx4() { + x = 0; + y = 0; + z = 0; + w = 0; + } +}; + +struct __align__(8) bfx4 { + __nv_bfloat162 vals[2]; +}; + // TODO: Include the following code from fbgemm_gpu header struct __align__(16) bfx8 { __nv_bfloat162 vals[4]; }; +struct __align__(8) halfx4 { + __half2 vals[2]; +}; + struct __align__(16) halfx8 { __half2 vals[4]; }; @@ -234,6 +255,202 @@ dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { return result; } +// struct __align__(16) bfx8 { +// __nv_bfloat162 vals[4]; +// }; + +// DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 +// shift_scale_0); DEVICE_INLINE bfx8 dequantize_packed_int4( +// uint32_t v, +// __half2 shift_scale_0, +// __half2 shift_scale_1); +// DEVICE_INLINE bfx8 +// dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale); + +DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 shift_scale_0) { + uint32_t v = vs; + // move 2nd byte to 3rd byte, so our bits are in 0x00FF00FF positions. 
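+  // For example (illustrative values, not from a real trace): vs = 0xBA98
+  // becomes v = 0x00BA0098, i.e. the two packed bytes now sit in the low
+  // halves of two 16-bit lanes. Each nibble n is then isolated as a tiny fp16
+  // denormal (n * 2^-24 for the low nibble, 16 * n * 2^-24 for the high one);
+  // the multiplies by 32768 (2^15) here and by 512 / 32 below rescale those
+  // denormals back to the integer values 0..15 before the group's scale and
+  // shift are applied.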
+ v = (v & 0xFF) | ((v & 0xFF00) << 8); + + halfx4 res; + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + // ~5% perf gain is observed with the explicit type conversions using + // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using + // NVCC 11.0. Additionally, HIP compiler requires these explicit type + // conversions. + half shift_scale_0_x = __low2half(shift_scale_0); + half shift_scale_0_y = __high2half(shift_scale_0); + + // now, dequantize + auto shifts = __half2(shift_scale_0_y, shift_scale_0_y); + auto scales_lower = __half2( + __hmul(shift_scale_0_x, __float2half(512)), + __hmul(shift_scale_0_x, __float2half(512))); + auto scales_upper = __half2( + __hmul(shift_scale_0_x, __float2half(32)), + __hmul(shift_scale_0_x, __float2half(32))); + + auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); + auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); + + bfx4 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); + result.vals[1] = __floats2bfloat162_rn(r0.y, r1.y); + return result; +} + +DEVICE_INLINE bfx8 dequantize_packed_int4( + uint32_t v, + __half2 shift_scale_0, + __half2 shift_scale_1) { + halfx8 res; + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + v >>= 8; + res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + half shift_scale_0_x = __low2half(shift_scale_0); + half shift_scale_0_y = __high2half(shift_scale_0); + half shift_scale_1_x = __low2half(shift_scale_1); + half shift_scale_1_y = __high2half(shift_scale_1); + + // now, dequantize + auto shifts = __half2(shift_scale_0_y, shift_scale_1_y); + auto scales_lower = __half2( + __hmul(shift_scale_0_x, __float2half(512)), + __hmul(shift_scale_1_x, __float2half(512))); + auto scales_upper = __half2( + __hmul(shift_scale_0_x, __float2half(32)), + __hmul(shift_scale_1_x, __float2half(32))); + + auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); + auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); + auto r2 = __half22float2(__hfma2(res.vals[2], scales_lower, shifts)); + auto r3 = __half22float2(__hfma2(res.vals[3], scales_upper, shifts)); + + bfx8 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); + result.vals[1] = __floats2bfloat162_rn(r2.x, r3.x); + result.vals[2] = __floats2bfloat162_rn(r0.y, r1.y); + result.vals[3] = __floats2bfloat162_rn(r2.y, r3.y); + return result; +} + +DEVICE_INLINE float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#elif defined(USE_ROCM) + float2 f_val; + f_val.x = __bfloat162float(val.x); + f_val.y = __bfloat162float(val.y); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +#define CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(NAME, NUM_GROUPS, ...) \ + switch (NUM_GROUPS) { \ + case 1: \ + NAME(1, __VA_ARGS__); \ + break; \ + case 2: \ + NAME(2, __VA_ARGS__); \ + break; \ + case 4: \ + NAME(4, __VA_ARGS__); \ + break; \ + case 8: \ + NAME(8, __VA_ARGS__); \ + break; \ + case 16: \ + TORCH_CHECK( \ + false, \ + "With head dim = 128 we're almost even with int8 at this point. Are you sure about this? 
Num groups:", \ + NUM_GROUPS); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported number of groups: ", NUM_GROUPS); \ + } + +DEVICE_INLINE float bfx4_dot(bfx4 a, bfx4 b) { + // float2 acc = {0, 0}; + // __nv_bfloat162 acc; + // acc.x = static_cast(0); + // acc.y = static_cast(0); + // TODO: need to be performed in float32? + auto a0 = bf1622float2(a.vals[0]); + auto a1 = bf1622float2(a.vals[1]); + auto b0 = bf1622float2(b.vals[0]); + auto b1 = bf1622float2(b.vals[1]); + return a0.x * b0.x + a0.y * b0.y + a1.x * b1.x + a1.y * b1.y; + + // acc = __hfma2(a.vals[0], b.vals[0], acc); + // acc = __hfma2(a.vals[1], b.vals[1], acc); + // auto r = bf1622float2(acc); + // return r.x + r.y; +} + +DEVICE_INLINE fx4 bfx4_scale_acc(fx4 acc, bfx4 a, float b) { + auto axy = bf1622float2(a.vals[0]); + auto azw = bf1622float2(a.vals[1]); + acc.x += axy.x * b; + acc.y += axy.y * b; + acc.z += azw.x * b; + acc.w += azw.y * b; + return acc; +} + +DEVICE_INLINE fx4 fx4_acc(fx4 a, fx4 b) { + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; + return a; +} + +DEVICE_INLINE bfx4 fx4_to_bfx4(fx4 a) { + bfx4 r; + r.vals[0] = __floats2bfloat162_rn(a.x, a.y); + r.vals[1] = __floats2bfloat162_rn(a.z, a.w); + return r; +} + +template +DEVICE_INLINE T shfl_xor( + unsigned shfl_sync_mask, + const T val, + int laneMask, + int width = kThreadsPerWarp) { +#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION < 9000 + return __shfl_xor(val, laneMask, width); +#else + return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width); +#endif +} + +template +DEVICE_INLINE T warpReduceSum(T val, uint32_t warp_mask = FINAL_MASK) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += shfl_xor(warp_mask, val, mask, 32); + return val; +} + +template +DEVICE_INLINE T warpReduceMax(T val, uint32_t warp_mask = FINAL_MASK) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, shfl_xor(warp_mask, val, mask, 32)); + return val; +} + template < typename kv_t, int KVQuantNumGroups = 1, @@ -962,40 +1179,534 @@ __global__ void gqa_attn_splitk_reduce_wmma_kernel( O[b][0][h][d] = acc / l_sum; } + +__global__ void gqa_attn_splitk_qk_kernel( + const at::PackedTensorAccessor32 XQ, + const at::PackedTensorAccessor64 + cache_K, + const at::PackedTensorAccessor32 + seq_positions, + at::PackedTensorAccessor32 QK_out) { + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + // Each block handles a single batch and head + int32_t b = blockIdx.x; + int32_t h = blockIdx.y; + int32_t split_k = gridDim.z; + int32_t z = blockIdx.z; + + // Note: this is decoding case where we attent to current and all previous + // tokens. + int32_t max_t = seq_positions[b] + 1; + + int32_t warp_idx = threadIdx.y; + // need kWarpsPerBlock == blockDim.y; + // Need D_H == 128 + auto* q_ = &(XQ[b][0][h][0]); + + // assume cache_K/cache_V is contiguous + auto* cache_K_base = &cache_K[b][0][0][0]; + + // Load Q into registers in all warps. + // Each thread handles 4 D dimensions + bfx4 q_thread; + *reinterpret_cast(&q_thread) = + *(reinterpret_cast(q_) + threadIdx.x); + + // Compute S[MAX_T] = for i in range(T): S[t] = sum(Q[d] * K[t, d]) + // Split T across warps in a block, unroll loads to expose more + // parallelism. 
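+  // Each of the split_k blocks along gridDim.z owns a contiguous slice of the
+  // context: t_per_block = round_up(max_t, split_k) / split_k timesteps,
+  // clamped to max_t for the last block. As an illustrative example (values
+  // not from the original code), with max_t = 1000 and split_k = 8 each block
+  // covers 125 timesteps, so block z scans t in [125 * z, min(125 * (z + 1),
+  // 1000)).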
+ + constexpr int32_t kTimeUnroll = 4; + bfx4 k_loads[kTimeUnroll]; + float qk_accs[kTimeUnroll]; + + const int32_t t_total = round_up(max_t, split_k); + const int32_t t_per_block = t_total / split_k; + const int32_t t_per_block_start = t_per_block * z; + const int32_t t_per_block_end = min(t_per_block * (z + 1), max_t); + + int32_t t_per_block_unroll = t_per_block_start + + ((t_per_block_end - t_per_block_start) / (kWarpsPerBlock * kTimeUnroll)) * + (kWarpsPerBlock * kTimeUnroll); + for (auto tt = t_per_block_start + warp_idx * kTimeUnroll; + tt < t_per_block_unroll; + tt += kWarpsPerBlock * kTimeUnroll) { +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + int32_t t = tt + ttt; + auto* k_ = cache_K_base + t * D_H; // &(cache_K[b][t][0][0]); + // bfx4 k_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(k_) + threadIdx.x); + } +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + float qk_acc = 0; + // int32_t t = tt + ttt; + qk_acc += bfx4_dot(q_thread, k_loads[ttt]); + qk_acc = warpReduceSum(qk_acc); + qk_accs[ttt] = qk_acc; + } + + if (threadIdx.x < kTimeUnroll) { + int32_t t = tt + threadIdx.x; + QK_out[b][h][t] = qk_accs[threadIdx.x]; + } + } + + constexpr int32_t kTimeUnroll1 = 1; + for (auto tt = t_per_block_unroll + warp_idx; tt < t_per_block_end; + tt += kWarpsPerBlock * kTimeUnroll1) { +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + int32_t t = tt + ttt; + auto* k_ = cache_K_base + t * D_H; // &(cache_K[b][t][0][0]); + // bfx4 k_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(k_) + threadIdx.x); + } +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + float qk_acc = 0; + int32_t t = tt + ttt; + qk_acc += bfx4_dot(q_thread, k_loads[ttt]); + + qk_acc = warpReduceSum(qk_acc); + QK_out[b][h][t] = qk_acc; + // // write accumulated sums to smem. + // if (threadIdx.x == 0) { + // smem[t] = qk_acc; + // } + } + } +} + +template +__global__ void gqa_attn_splitk_qk_int4_kernel( + const at::PackedTensorAccessor32 XQ, + const at::PackedTensorAccessor64 cache_K, + const at::PackedTensorAccessor32 + seq_positions, + at::PackedTensorAccessor32 QK_out) { + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + // Each block handles a single batch and head + int32_t b = blockIdx.x; + int32_t h = blockIdx.y; + int32_t split_k = gridDim.z; + int32_t z = blockIdx.z; + + // Note: this is decoding case where we attent to current and all previous + // tokens. + int32_t max_t = seq_positions[b] + 1; + + int32_t warp_idx = threadIdx.y; + // need kWarpsPerBlock == blockDim.y; + // Need D_H == 128 + auto* q_ = &(XQ[b][0][h][0]); + + // assume cache_K/cache_V is contiguous + auto* cache_K_base = &cache_K[b][0][0][0]; + + int32_t int4_qparam_offset = 4; + int32_t qparam_offset = 0; + if (KVQuantNumGroups > 1) { + int4_qparam_offset = 4 * KVQuantNumGroups; + int32_t group_size = D_H / KVQuantNumGroups; + int32_t group_idx = threadIdx.x * 2 / group_size; + qparam_offset = 4 * group_idx; + } + int32_t D_H_bytes = D_H / 2 + int4_qparam_offset; + // Load Q into registers in all warps. + // Each thread handles 4 D dimensions + bfx4 q_thread; + *reinterpret_cast(&q_thread) = + *(reinterpret_cast(q_) + threadIdx.x); + + // Compute S[MAX_T] = for i in range(T): S[t] = sum(Q[d] * K[t, d]) + // Split T across warps in a block, unroll loads to expose more + // parallelism. 
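+  // Layout of one quantized K row (uint8_t): the group-wise qparams come
+  // first (4 bytes per group, an fp16 scale and an fp16 shift packed as a
+  // __half2), followed by D_H / 2 bytes of packed int4 values, so a row takes
+  // D_H_bytes = D_H / 2 + 4 * KVQuantNumGroups bytes (e.g. 68 bytes for one
+  // group and 80 bytes for four groups with D_H = 128). Each thread reads one
+  // uint16_t (4 nibbles) for its 4 elements plus the qparams of its group.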
+ + constexpr int32_t kTimeUnroll = 4; + uint16_t k_qvals[kTimeUnroll]; + __half2 k_scales[kTimeUnroll]; + float qk_accs[kTimeUnroll]; + + const int32_t t_total = round_up(max_t, split_k); + const int32_t t_per_block = t_total / split_k; + const int32_t t_per_block_start = t_per_block * z; + const int32_t t_per_block_end = min(t_per_block * (z + 1), max_t); + + int32_t t_per_block_unroll = t_per_block_start + + ((t_per_block_end - t_per_block_start) / (kWarpsPerBlock * kTimeUnroll)) * + (kWarpsPerBlock * kTimeUnroll); + for (auto tt = t_per_block_start + warp_idx * kTimeUnroll; + tt < t_per_block_unroll; + tt += kWarpsPerBlock * kTimeUnroll) { +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + int32_t t = tt + ttt; + auto* k_ = cache_K_base + t * D_H; // &(cache_K[b][t][0][0]); + // bfx4 k_thread; + *reinterpret_cast(&k_qvals[ttt]) = + *(reinterpret_cast( + &k_[threadIdx.x * 2 + int4_qparam_offset])); + *reinterpret_cast(&k_scales[ttt]) = + *(reinterpret_cast(&k_[qparam_offset])); + } +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + float qk_acc = 0; + // int32_t t = tt + ttt; + qk_acc += bfx4_dot( + q_thread, dequantize_packed_int4(k_qvals[ttt], k_scales[ttt])); + qk_acc = warpReduceSum(qk_acc); + qk_accs[ttt] = qk_acc; + } + + if (threadIdx.x < kTimeUnroll) { + int32_t t = tt + threadIdx.x; + QK_out[b][h][t] = qk_accs[threadIdx.x]; + } + } + + constexpr int32_t kTimeUnroll1 = 1; + for (auto tt = t_per_block_unroll + warp_idx; tt < t_per_block_end; + tt += kWarpsPerBlock * kTimeUnroll1) { +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + int32_t t = tt + ttt; + auto* k_ = cache_K_base + t * D_H_bytes; // &(cache_K[b][t][0][0]); + // bfx4 k_thread; + *reinterpret_cast(&k_qvals[ttt]) = + *(reinterpret_cast( + &k_[threadIdx.x * 2 + int4_qparam_offset])); + *reinterpret_cast(&k_scales[ttt]) = + *(reinterpret_cast(&k_[qparam_offset])); + } +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + float qk_acc = 0; + int32_t t = tt + ttt; + qk_acc += bfx4_dot( + q_thread, dequantize_packed_int4(k_qvals[ttt], k_scales[ttt])); + + qk_acc = warpReduceSum(qk_acc); + QK_out[b][h][t] = qk_acc; + // // write accumulated sums to smem. + // if (threadIdx.x == 0) { + // smem[t] = qk_acc; + // } + } + } +} + +// TODO: can also fuse RoPe into this kernel. Doesn't seem worth it. +__global__ void gqa_attn_splitk_attn_kernel( + at::PackedTensorAccessor32 XQ_out, + at::PackedTensorAccessor32 attn_out, + at::PackedTensorAccessor32 seq_positions, + float qk_scale) { + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + extern __shared__ __align__(16) float smem[]; + + // Each block handles a single batch and head + int32_t b = blockIdx.x; + int32_t h = blockIdx.y; + int32_t split_k = XQ_out.size(0); + + // Note: this is decoding case where we attent to current and all previous + // tokens. + int32_t max_t = seq_positions[b] + 1; + + int32_t warp_idx = threadIdx.y; + + // Each block handles single batch and head + // Accumulate over split-k inputs and write into smem + float max_qk_acc = std::numeric_limits::lowest(); + // each thread handles one T timestep. + // now, compute the normalization across all threads. 
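+  // The softmax over smem proceeds in three passes: (1) the loop below scales
+  // each QK^T entry by qk_scale and tracks the running max, (2) the max is
+  // reduced across warps and sum(exp(x - max)) is accumulated as the
+  // denominator, and (3) the normalized probabilities exp(x - max) / denom
+  // are written to attn_out. Subtracting the block-wide max keeps the
+  // exponentials numerically stable.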
+ for (int32_t t = threadIdx.x + warp_idx * kThreadsPerWarp; t < max_t; + t += kWarpsPerBlock * kThreadsPerWarp) { + float qk_acc = XQ_out[b][h][t]; + qk_acc *= qk_scale; + max_qk_acc = max(max_qk_acc, qk_acc); + smem[t] = qk_acc; + } + + // each warp computes XQ^T and writes to gmem + + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc + max_qk_acc = warpReduceMax(max_qk_acc); + if (threadIdx.x == 0) { + smem[MAX_T + warp_idx] = max_qk_acc; + } + __syncthreads(); + if (threadIdx.x < kWarpsPerBlock) { + max_qk_acc = max(max_qk_acc, smem[MAX_T + threadIdx.x]); + } + + // shared across all threads in block + max_qk_acc = warpReduceMax(max_qk_acc); + // each warp computes partial sum of exp. + float softmax_denominator = 0.0f; + for (int32_t t = threadIdx.x + warp_idx * kThreadsPerWarp; t < max_t; + t += kWarpsPerBlock * kThreadsPerWarp) { + softmax_denominator += __expf(smem[t] - max_qk_acc); + } + softmax_denominator = warpReduceSum(softmax_denominator); + + __syncthreads(); + if (threadIdx.x == 0) { + smem[MAX_T + warp_idx] = softmax_denominator; + } + __syncthreads(); + // now, compute sum of exp(x - max(x)) over all intermediate results. + softmax_denominator = 0.0; + if (threadIdx.x < kWarpsPerBlock) { + softmax_denominator = smem[MAX_T + threadIdx.x]; + } + softmax_denominator = warpReduceSum(softmax_denominator); + + // now, compute the normalization across all threads. + for (int32_t t = threadIdx.x + warp_idx * kThreadsPerWarp; t < max_t; + t += kWarpsPerBlock * kThreadsPerWarp) { + attn_out[b][h][t] = __expf(smem[t] - max_qk_acc) / softmax_denominator; + } +} + +// TODO: can also fuse RoPe into this kernel. Doesn't seem worth it. +__global__ void gqa_attn_splitk_v_kernel( + at::PackedTensorAccessor32 attn_out, + at::PackedTensorAccessor64 cache_V, + at::PackedTensorAccessor32 O, + at::PackedTensorAccessor32 + seq_positions) { + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + // Each block handles a single batch and head + int32_t b = blockIdx.x; + int32_t h = blockIdx.y; + int32_t split_k = gridDim.z; + int32_t z = blockIdx.z; + + // Note: this is decoding case where we attent to current and all previous + // tokens. 
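+  // Each thread owns 4 consecutive elements of the D_H = 128 output row
+  // (32 threads x 4 = 128). Every warp accumulates sum_t P[t] * V[t, :] for
+  // its slice of timesteps in an fx4 register accumulator; the per-warp
+  // partials are then combined through shared memory and written to O[z],
+  // and the final reduction across the split_k dimension is done by the
+  // caller via O.sum(0).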
+ int32_t max_t = seq_positions[b] + 1; + + int32_t warp_idx = threadIdx.y; + + // need kWarpsPerBlock == blockDim.y; + // Need D_H == 128 + // auto* q_ = &(XQ[b][0][h][0]); + + // assume cache_K/cache_V is contiguous + // auto* cache_K_base = &cache_K[b][0][0][0]; + auto* cache_V_base = &cache_V[b][0][0][0]; + + constexpr int32_t kTimeUnroll = 4; + + // Split T across warps in a block + // each warp compute sum(t_subset) P[t] * V[t_subset, d] + // outputs are of size float[D] + + float ps[kTimeUnroll]; + bfx4 k_loads[kTimeUnroll]; + + const int32_t t_total = round_up(max_t, split_k); + const int32_t t_per_block = t_total / split_k; + const int32_t t_per_block_start = t_per_block * z; + const int32_t t_per_block_end = min(t_per_block * (z + 1), max_t); + + fx4 o_acc; + int32_t t_per_block_unroll = t_per_block_start + + ((t_per_block_end - t_per_block_start) / (kWarpsPerBlock * kTimeUnroll)) * + (kWarpsPerBlock * kTimeUnroll); + for (auto tt = t_per_block_start + warp_idx * kTimeUnroll; + tt < t_per_block_unroll; + tt += kWarpsPerBlock * kTimeUnroll) { +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + int32_t t = tt + ttt; + auto* v_ = cache_V_base + t * D_H; // &(cache_V[b][t][0][0]); + // bfx4 v_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(v_) + threadIdx.x); + ps[ttt] = attn_out[b][h][t]; + } + +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + o_acc = bfx4_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + + constexpr int32_t kTimeUnroll1 = 1; + for (auto tt = t_per_block_unroll + warp_idx; tt < t_per_block_end; + tt += kWarpsPerBlock * kTimeUnroll1) { +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + int32_t t = tt + ttt; + auto* v_ = cache_V_base + t * D_H; // &(cache_V[b][t][0][0]); + // bfx4 v_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(v_) + threadIdx.x); + ps[ttt] = attn_out[b][h][t]; + } + +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + o_acc = bfx4_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + extern __shared__ __align__(16) float smem[]; + // accumulate in shared memory + *(reinterpret_cast(&smem[0]) + warp_idx * kThreadsPerWarp + + threadIdx.x) = o_acc; + __syncthreads(); + // accumulate partial sums + // note: seemed marginally faster than smem reduction in benchmarks. + if (warp_idx == 0) { + fx4 r; + for (int32_t w = 0; w < kWarpsPerBlock; ++w) { + auto partial_r = *( + reinterpret_cast(&smem[0]) + w * kThreadsPerWarp + threadIdx.x); + r = fx4_acc(r, partial_r); + } + // write output D row + *(reinterpret_cast(&O[z][b][0][h][0]) + threadIdx.x) = + *reinterpret_cast(&r); + } +} + +// TODO: can also fuse RoPe into this kernel. Doesn't seem worth it. +template +__global__ void gqa_attn_splitk_v_int4_kernel( + at::PackedTensorAccessor32 attn_out, + at::PackedTensorAccessor64 cache_V, + at::PackedTensorAccessor32 O, + at::PackedTensorAccessor32 + seq_positions) { + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + // Each block handles a single batch and head + int32_t b = blockIdx.x; + int32_t h = blockIdx.y; + int32_t split_k = gridDim.z; + int32_t z = blockIdx.z; + + // Note: this is decoding case where we attent to current and all previous + // tokens. 
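+  // Same accumulation scheme as the BF16 kernel above, except that each
+  // thread loads 2 packed bytes (4 int4 values) of V plus the __half2
+  // (scale, shift) qparams of its group, dequantizes them to bfloat16 with
+  // dequantize_packed_int4, and scales the result by the attention
+  // probability before accumulating.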
+ int32_t max_t = seq_positions[b] + 1; + + int32_t warp_idx = threadIdx.y; + + // need kWarpsPerBlock == blockDim.y; + // Need D_H == 128 + // auto* q_ = &(XQ[b][0][h][0]); + + // assume cache_K/cache_V is contiguous + // auto* cache_K_base = &cache_K[b][0][0][0]; + auto* cache_V_base = &cache_V[b][0][0][0]; + int32_t int4_qparam_offset = 4; + int32_t qparam_idx = 0; + if (KVQuantNumGroups > 1) { + int4_qparam_offset = 4 * KVQuantNumGroups; + int32_t group_size = D_H / KVQuantNumGroups; + int32_t group_idx = threadIdx.x * 2 / group_size; + qparam_idx = 4 * group_idx; + } + int32_t D_H_bytes = D_H / 2 + int4_qparam_offset; + constexpr int32_t kTimeUnroll = 4; + + // Split T across warps in a block + // each warp compute sum(t_subset) P[t] * V[t_subset, d] + // outputs are of size float[D] + + float ps[kTimeUnroll]; + uint16_t k_qvals[kTimeUnroll]; + __half2 k_scales[kTimeUnroll]; + + const int32_t t_total = round_up(max_t, split_k); + const int32_t t_per_block = t_total / split_k; + const int32_t t_per_block_start = t_per_block * z; + const int32_t t_per_block_end = min(t_per_block * (z + 1), max_t); + + fx4 o_acc; + int32_t t_per_block_unroll = t_per_block_start + + ((t_per_block_end - t_per_block_start) / (kWarpsPerBlock * kTimeUnroll)) * + (kWarpsPerBlock * kTimeUnroll); + for (auto tt = t_per_block_start + warp_idx * kTimeUnroll; + tt < t_per_block_unroll; + tt += kWarpsPerBlock * kTimeUnroll) { +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + int32_t t = tt + ttt; + auto* v_ = cache_V_base + t * D_H_bytes; // &(cache_V[b][t][0][0]); + // bfx4 v_thread; + *reinterpret_cast(&k_qvals[ttt]) = + *(reinterpret_cast( + &v_[threadIdx.x * 2 + int4_qparam_offset])); + *reinterpret_cast(&k_scales[ttt]) = + *(reinterpret_cast(&v_[qparam_idx])); + ps[ttt] = attn_out[b][h][t]; + } + +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + o_acc = bfx4_scale_acc( + o_acc, dequantize_packed_int4(k_qvals[ttt], k_scales[ttt]), ps[ttt]); + } + } + + constexpr int32_t kTimeUnroll1 = 1; + for (auto tt = t_per_block_unroll + warp_idx; tt < t_per_block_end; + tt += kWarpsPerBlock * kTimeUnroll1) { +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + int32_t t = tt + ttt; + auto* v_ = cache_V_base + t * D_H_bytes; // &(cache_V[b][t][0][0]); + // bfx4 v_thread; + *reinterpret_cast(&k_qvals[ttt]) = + *(reinterpret_cast( + &v_[threadIdx.x * 2 + int4_qparam_offset])); + *reinterpret_cast(&k_scales[ttt]) = + *(reinterpret_cast(&v_[qparam_idx])); + ps[ttt] = attn_out[b][h][t]; + } + +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + o_acc = bfx4_scale_acc( + o_acc, dequantize_packed_int4(k_qvals[ttt], k_scales[ttt]), ps[ttt]); + } + } + extern __shared__ __align__(16) float smem[]; + // accumulate in shared memory + *(reinterpret_cast(&smem[0]) + warp_idx * kThreadsPerWarp + + threadIdx.x) = o_acc; + __syncthreads(); + // accumulate partial sums + // note: seemed marginally faster than smem reduction in benchmarks. 
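+  // Warp 0 walks the kWarpsPerBlock fx4 partials staged in shared memory,
+  // sums them element-wise, and writes its 4-float result to this block's
+  // slice of O[z][b][0][h][:].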
+ if (warp_idx == 0) { + fx4 r; + for (int32_t w = 0; w < kWarpsPerBlock; ++w) { + auto partial_r = *( + reinterpret_cast(&smem[0]) + w * kThreadsPerWarp + threadIdx.x); + r = fx4_acc(r, partial_r); + } + // write output D row + *(reinterpret_cast(&O[z][b][0][h][0]) + threadIdx.x) = + *reinterpret_cast(&r); + } +} + } // namespace -/// @ingroup experimental-gen-ai-attention -/// -/// @brief Decoding Grouped Query Attention Split-K w/ BF16/INT4 KV -/// -/// The CUDA implementation of decoding Grouped Query Attention (GQA) -/// that supports BF16 and INT4 KV cache and BF16 input query. It -/// currently only supports the max context length of 16384, the fixed -/// head dimension of 128, and only one KV cache head. It supports an -/// arbitrary number of query heads. -/// -/// @param XQ Input query; shape = (B, 1, H_Q, D), where B = batch -/// size, H_Q = num query heads, D = head dimension (fixed -/// to 128) -/// @param cache_K K cache; shape = (B, MAX_T, H_KV, D), where MAX_T = -/// max context length (fixed to 16384), and H_KV = num -/// KV cache heads (fixed to 1) -/// @param cache_V V cache; shape = (B, MAX_T, H_KV, D) -/// @param seq_positions Sequence position (contains the actual -/// length of each token); shape = (B) -/// @param qk_scale The scale that is applied after QK^T -/// @param num_split_ks The number of split Ks (controlling the -/// amount of parallelism in the context length -/// dimension (MAX_T)) -/// @param num_int4_kv_groups The number of groups for group-wise INT4 -/// quantization for each KV token (each -/// group uses the same scale and bias for -/// quantization) -/// -/// @return A tuple of the combined split-K output, the -/// non-combined split-K output, and the split-K metadata -/// (containing max QK^T, and softmax(QK^T) head sum) -std::tuple gqa_attn_splitk_cuda( +std::tuple gqa_attn_splitk_wmma_impl( const at::Tensor& XQ, const at::Tensor& cache_K, const at::Tensor& cache_V, @@ -1004,10 +1715,16 @@ std::tuple gqa_attn_splitk_cuda( const int64_t num_split_ks, const int64_t num_int4_kv_groups) { auto dprops = at::cuda::getCurrentDeviceProperties(); +#ifdef USE_ROCM + TORCH_CHECK( + false, + "gqa_attn_splitk with use_tensor_cores=True is not supported on ROCm"); +#else TORCH_CHECK( dprops->major >= 8, - "Too old compute capability major version to run gqa_attn_splitk_wmma ", + "Too old compute capability major version to run gqa_attn_splitk_wmma (use_tensor_cores=True)", dprops->major); +#endif at::OptionalDeviceGuard guard(XQ.device()); TORCH_CHECK(XQ.is_cuda()); @@ -1110,4 +1827,231 @@ std::tuple gqa_attn_splitk_cuda( return {O, out_splitK, metadata}; } +std::tuple gqa_attn_splitk_impl( + const at::Tensor& XQ, // [B, 1, H, D] + const at::Tensor& cache_K, // [B, MAX_T, 1, D] + const at::Tensor& cache_V, // [B, MAX_T, 1, D] + const at::Tensor& seq_positions, // [B] + const double qk_scale, + const int64_t split_k, + const c10::optional& num_groups) { + at::OptionalDeviceGuard guard(XQ.device()); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + TORCH_CHECK(cache_K.is_contiguous()); + TORCH_CHECK(cache_V.is_contiguous()); + + TORCH_CHECK(seq_positions.is_cuda()); + + TORCH_CHECK(cache_K.size(1) <= MAX_T); + TORCH_CHECK( + cache_K.size(2) == 1, + "gqa_attn_splitk only supports for number of K heads 1"); + TORCH_CHECK( + cache_V.size(2) == 1, + "gqa_attn_splitk only supports for number of V heads 1"); + if (cache_K.dtype() == at::kBFloat16) { + TORCH_CHECK(cache_K.size(3) == D_H); + } else { + auto num_groups_ = 
num_groups ? num_groups.value() : 1; + auto qparam_offset = 4 * num_groups_; + TORCH_CHECK(cache_K.size(3) == D_H / 2 + qparam_offset); + } + + auto B = XQ.size(0); + auto H = XQ.size(2); + auto QK_out = + at::empty({B, H, cache_K.size(1)}, XQ.options().dtype(at::kFloat)); + + if (B == 0) { + return {at::empty_like(XQ), at::empty_like(QK_out), QK_out}; + } + + { + dim3 blocks(B, H, split_k); + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + + if (cache_K.dtype() == at::kBFloat16) { + gqa_attn_splitk_qk_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + XQ.packed_accessor32(), + cache_K.packed_accessor64(), + seq_positions.packed_accessor32(), + QK_out.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { +#define CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \ + gqa_attn_splitk_qk_int4_kernel \ + <<>>( \ + XQ.packed_accessor32(), \ + cache_K.packed_accessor64(), \ + seq_positions \ + .packed_accessor32(), \ + QK_out.packed_accessor32()); + + auto num_groups_ = num_groups ? num_groups.value() : 1; + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL, num_groups_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#undef CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL + } + } + + const auto device = XQ.get_device(); + auto attn_out = at::empty_like(QK_out); + { + dim3 blocks(B, H); + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + + int32_t smem_softmax = + MAX_T * sizeof(float) + kWarpsPerBlock * sizeof(float); + int32_t smem = smem_softmax; + + if (smem > SMEM_ADJUST_THRESHOLD) { + set_gpu_max_dynamic_shared_memory( + gqa_attn_splitk_attn_kernel, smem, device); + } + + gqa_attn_splitk_attn_kernel<<< + blocks, + threads, + smem, + at::cuda::getCurrentCUDAStream()>>>( + QK_out.packed_accessor32(), + attn_out.packed_accessor32(), + seq_positions.packed_accessor32(), + qk_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + auto O = at::empty({split_k, B, 1, H, D_H}, XQ.options().dtype(at::kFloat)); + { + dim3 blocks(B, H, split_k); + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + + int32_t smem_output = D_H * sizeof(float) * kWarpsPerBlock; + int32_t smem = smem_output; + const bool set_max_dynamic_smem = smem > SMEM_ADJUST_THRESHOLD; + + if (cache_K.dtype() == at::kBFloat16) { + if (set_max_dynamic_smem) { + set_gpu_max_dynamic_shared_memory( + gqa_attn_splitk_v_kernel, smem, device); + } + gqa_attn_splitk_v_kernel<<< + blocks, + threads, + smem, + at::cuda::getCurrentCUDAStream()>>>( + attn_out.packed_accessor32(), + cache_V.packed_accessor64(), + O.packed_accessor32(), + seq_positions.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { +#define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \ + if (set_max_dynamic_smem) { \ + set_gpu_max_dynamic_shared_memory( \ + gqa_attn_splitk_v_int4_kernel, smem, device); \ + } \ + gqa_attn_splitk_v_int4_kernel \ + <<>>( \ + attn_out.packed_accessor32(), \ + cache_V.packed_accessor64(), \ + O.packed_accessor32(), \ + seq_positions \ + .packed_accessor32()); + + auto num_groups_ = num_groups ? 
num_groups.value() : 1; + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL, num_groups_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#undef CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL + } + } + + return {O.sum(0, false, at::kBFloat16), attn_out, QK_out}; +} + +/// @ingroup experimental-gen-ai-attention +/// +/// @brief Decoding Grouped Query Attention Split-K w/ BF16/INT4 KV +/// +/// The CUDA implementation of decoding Grouped Query Attention (GQA) +/// that supports BF16 and INT4 KV cache and BF16 input query. It +/// currently only supports the max context length of 16384, the fixed +/// head dimension of 128, and only one KV cache head. It supports an +/// arbitrary number of query heads. +/// +/// @param XQ Input query; shape = (B, 1, H_Q, D), where B = batch +/// size, H_Q = num query heads, D = head dimension (fixed +/// to 128) +/// @param cache_K K cache; shape = (B, MAX_T, H_KV, D), where MAX_T = +/// max context length (fixed to 16384), and H_KV = num +/// KV cache heads (fixed to 1) +/// @param cache_V V cache; shape = (B, MAX_T, H_KV, D) +/// @param seq_positions Sequence position (contains the actual +/// length of each token); shape = (B) +/// @param qk_scale The scale that is applied after QK^T +/// @param num_split_ks The number of split Ks (controlling the +/// amount of parallelism in the context length +/// dimension (MAX_T)) +/// @param num_int4_kv_groups The number of groups for group-wise INT4 +/// quantization for each KV token (each +/// group uses the same scale and bias for +/// quantization) +/// +/// @param use_tensor_cores Whether to use tensor core wmma instructions +/// for fast implementations +/// +/// @return A tuple of the combined split-K output, the +/// non-combined split-K output, and the split-K metadata +/// (containing max QK^T, and softmax(QK^T) head sum) +std::tuple gqa_attn_splitk( + const at::Tensor& XQ, + const at::Tensor& cache_K, + const at::Tensor& cache_V, + const at::Tensor& seq_positions, + const double qk_scale, + const int64_t num_split_ks, + const int64_t num_int4_kv_groups, + const bool use_tensor_cores) { + if (use_tensor_cores) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); +#ifdef USE_ROCM + TORCH_CHECK( + false, + "gqa_attn_splitk with use_tensor_cores=True is not supported on ROCm"); +#else + TORCH_CHECK( + dprops->major >= 8, + "Too old compute capability major version to run gqa_attn_splitk with ", + "use_tensor_cores=True (", + dprops->major, + ")"); +#endif + return gqa_attn_splitk_wmma_impl( + XQ, + cache_K, + cache_V, + seq_positions, + qk_scale, + num_split_ks, + num_int4_kv_groups); + } + return gqa_attn_splitk_impl( + XQ, + cache_K, + cache_V, + seq_positions, + qk_scale, + num_split_ks, + num_int4_kv_groups); +} + } // namespace fbgemm_gpu::gen_ai::attention
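A minimal call-site sketch for the deduplicated entry point (not part of the patch; shapes and values are illustrative, the run_gqa_attn_splitk_example wrapper is made up for the example, and it assumes the gqa_attn_splitk declaration from attention.cpp is visible to the caller). Passing use_tensor_cores = false selects the non-WMMA split-K kernels kept for backward compatibility (pre-Hopper NVIDIA GPUs and AMD/ROCm), while use_tensor_cores = true dispatches to the WMMA implementation, which requires compute capability 8.0+ and is rejected on ROCm:

#include <ATen/ATen.h>
#include <cmath>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor> run_gqa_attn_splitk_example() {
  // Illustrative sizes: 2 sequences, 8 query heads, head dim fixed at 128,
  // KV cache padded to 4096 timesteps, 16 valid tokens per sequence.
  const int64_t B = 2, H_Q = 8, D_H = 128, MAX_T = 4096;
  const auto bf16 = at::TensorOptions().device(at::kCUDA).dtype(at::kBFloat16);

  auto XQ = at::randn({B, 1, H_Q, D_H}, bf16);
  auto cache_K = at::randn({B, MAX_T, 1, D_H}, bf16); // BF16 (non-INT4) KV path
  auto cache_V = at::randn({B, MAX_T, 1, D_H}, bf16);
  auto seq_positions = at::full(
      {B}, 15, at::TensorOptions().device(at::kCUDA).dtype(at::kInt));

  // Returns (combined split-K output, non-combined split-K output, metadata).
  return fbgemm_gpu::gen_ai::attention::gqa_attn_splitk(
      XQ,
      cache_K,
      cache_V,
      seq_positions,
      /*qk_scale=*/1.0 / std::sqrt(static_cast<double>(D_H)),
      /*num_split_ks=*/8,
      /*num_int4_kv_groups=*/1,
      /*use_tensor_cores=*/false);
}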