Merge up to v2.5.5
nshepperd committed Feb 26, 2024
1 parent 9140339 commit 40c9981
Showing 9 changed files with 117 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -33,7 +33,7 @@ FlashAttention-2 currently supports:
GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
- 3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+ 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
2 changes: 1 addition & 1 deletion csrc/cutlass
Submodule cutlass updated 61 files
+7 −2 CHANGELOG.md
+24 −3 CMakeLists.txt
+7 −0 PUBLICATIONS.md
+9 −4 README.md
+0 −38 cmake/version.h.in
+34 −0 cmake/version_extended.h.in
+1 −0 examples/02_dump_reg_shmem/CMakeLists.txt
+2 −2 examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu
+7 −7 examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu
+10 −8 examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt
+96 −49 examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu
+10 −0 examples/57_hopper_grouped_gemm/CMakeLists.txt
+1 −1 include/cute/arch/copy_sm90_desc.hpp
+2 −0 include/cute/atom/mma_atom.hpp
+2 −2 include/cute/util/print.hpp
+3 −0 include/cute/util/type_traits.hpp
+4 −0 include/cutlass/arch/mma_sm90.h
+1 −0 include/cutlass/bfloat16.h
+35 −1 include/cutlass/detail/layout.hpp
+12 −7 include/cutlass/epilogue/collective/builders/sm90_builder.inl
+1 −0 include/cutlass/epilogue/collective/default_epilogue.hpp
+32 −18 include/cutlass/epilogue/collective/default_epilogue_array.hpp
+76 −38 include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp
+1 −2 include/cutlass/epilogue/dispatch_policy.hpp
+28 −0 include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
+1 −0 include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp
+57 −12 include/cutlass/epilogue/thread/linear_combination.h
+0 −183 include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_row_broadcast.h
+0 −519 include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h
+4 −8 include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
+45 −29 include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp
+0 −514 include/cutlass/gemm/device/gemm_sparse_row_broadcast.h
+4 −7 include/cutlass/gemm/dispatch_policy.hpp
+12 −0 include/cutlass/gemm/group_array_problem_shape.hpp
+0 −191 include/cutlass/gemm/kernel/default_gemm_sparse_row_broadcast.h
+30 −35 include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_tma.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp
+5 −7 include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp
+140 −86 include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp
+0 −400 include/cutlass/gemm/kernel/sparse_gemm_row_broadcast.h
+14 −6 include/cutlass/gemm/kernel/tile_scheduler_params.h
+80 −0 include/cutlass/version.h
+2 −2 pyproject.toml
+3 −3 python/cutlass/__init__.py
+6 −2 python/cutlass/backend/c_types.py
+23 −1 python/cutlass/backend/epilogue.py
+2 −2 python/cutlass/backend/evt/frontend/frontend_base.py
+0 −16 python/cutlass/backend/evt/passes/graph_drawer.py
+28 −18 python/cutlass/backend/gemm_operation.py
+1 −1 python/setup_library.py
+1 −1 python/setup_pycute.py
+1 −0 test/unit/gemm/device/CMakeLists.txt
+0 −19 test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu
+685 −0 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu
+7 −20 test/unit/gemm/device/testbed_sparse.h
+1 −1 tools/util/include/cutlass/util/packed_stride.hpp
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_kernel.h
@@ -521,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
- if (Is_dropout) {
+ if constexpr (Is_dropout) {
int warp_id = tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
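Note on the change above: `Is_dropout` is a compile-time template parameter, and with a plain `if` the dropout branch still has to be instantiated even when the parameter is false. `if constexpr` discards the untaken branch entirely, which is what lets the dropout-free configurations (including the new hdim-256 backward on consumer GPUs) avoid compiling the dropout path. A minimal, self-contained C++17 sketch of the language feature — `apply_dropout_scale` is an illustrative stand-in, not the kernel:

```cpp
#include <cstdio>

// Illustrative only: Is_dropout mirrors the kernel's compile-time flag.
template <bool Is_dropout>
float apply_dropout_scale(float score, float keep_prob) {
    if constexpr (Is_dropout) {
        // Instantiated only when Is_dropout == true; the false instantiation
        // never contains this code.
        score *= 1.0f / keep_prob;
    }
    return score;
}

int main() {
    std::printf("%f\n", apply_dropout_scale</*Is_dropout=*/true>(0.5f, 0.9f));
    std::printf("%f\n", apply_dropout_scale</*Is_dropout=*/false>(0.5f, 0.9f));
    return 0;
}
```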
30 changes: 18 additions & 12 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -70,9 +70,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
- BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
- BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+ LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
+ ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
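Note on the switch renames in the hunk above (and the DROPOUT_SWITCH renames below): the dedicated EVENK_SWITCH / LOCAL_SWITCH / ALIBI_SWITCH / DROPOUT_SWITCH macros presumably come from a header such as static_switch.h, which is not shown in this extract. The idea would be that each one behaves like BOOL_SWITCH by default and collapses to a hard-coded `false` when a corresponding FLASHATTENTION_DISABLE_* define is set at build time, cutting the number of kernel template instantiations. A hedged sketch of how such a macro could look — the macro bodies and the FLASHATTENTION_DISABLE_DROPOUT flag are assumptions, not taken from this diff:

```cpp
#include <cstdio>

// Sketch only; the real definitions are not part of this commit excerpt.
#define BOOL_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                          \
    if (COND) {                                  \
      constexpr static bool CONST_NAME = true;   \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr static bool CONST_NAME = false;  \
      return __VA_ARGS__();                      \
    }                                            \
  }()

#ifdef FLASHATTENTION_DISABLE_DROPOUT
  // Dropout support compiled out: the flag is pinned to false at build time.
  #define DROPOUT_SWITCH(COND, CONST_NAME, ...)  \
    [&] {                                        \
      constexpr static bool CONST_NAME = false;  \
      return __VA_ARGS__();                      \
    }()
#else
  #define DROPOUT_SWITCH BOOL_SWITCH
#endif

int main() {
    float p_dropout = 1.f;
    // Mirrors the launcher's condition: dropout is enabled only when p_dropout < 1.f.
    DROPOUT_SWITCH(p_dropout < 1.f, Is_dropout, [&] {
        std::printf("Is_dropout = %d\n", int(Is_dropout));
    });
    return 0;
}
```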
@@ -101,7 +101,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+ #ifndef FLASHATTENTION_DISABLE_BACKWARD
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+ #endif
}

template<typename T>
@@ -115,7 +117,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
@@ -140,7 +142,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
@@ -185,7 +187,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
if constexpr(!Is_dropout) { // 92KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
@@ -211,7 +213,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
@@ -244,7 +246,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
} else {
@@ -264,7 +266,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 136 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else {
@@ -276,7 +278,7 @@
template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
});
}
@@ -292,11 +294,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
- } else { // A100, we don't do double buffering to save smem
+ } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+ } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
+ if constexpr (!Is_dropout) {
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+ }
}
});
}
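Note on run_mha_bwd_hdim256 above: the dispatch now has three tiers keyed off the per-block shared-memory budget — H100-class (≥176 KB, double buffering), A100-class (≥144 KB, no double buffering), and the new sm86/sm89 tier (~99 KB), which uses 64 x 32 blocks, keeps V in registers, and is only instantiated without dropout. A self-contained sketch of probing that limit and picking a tier; using cudaDevAttrMaxSharedMemoryPerBlockOptin for max_smem_per_block is an assumption here, since the query itself is outside the shown hunks:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hedged sketch of the tier selection in run_mha_bwd_hdim256 (not the real launcher).
int main() {
    int device = 0, max_smem_per_block = 0;
    if (cudaGetDevice(&device) != cudaSuccess) { return 1; }
    // Assumption: max_smem_per_block comes from the opt-in shared-memory attribute.
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (max_smem_per_block >= 176 * 1024) {
        std::printf("H100-class: 64x64 blocks with double buffering\n");
    } else if (max_smem_per_block >= 144 * 1024) {
        std::printf("A100-class: 64x64 blocks, no double buffering\n");
    } else {
        std::printf("sm86/sm89 (~99 KB smem): 64x32 blocks, V in registers, "
                    "backward supported only without dropout\n");
    }
    return 0;
}
```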
31 changes: 16 additions & 15 deletions csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -43,10 +43,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
- BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+ LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
- BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+ ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
@@ -84,11 +84,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
- BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+ LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
- BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+ ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
@@ -114,7 +114,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
- BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 4) {
@@ -148,7 +148,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
@@ -158,7 +158,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
@@ -180,13 +180,14 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;

// auto dprops = at::cuda::getCurrentDeviceProperties();
int device, major, minor;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
C10_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
bool is_sm8x = major == 8 && minor > 0;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
@@ -217,7 +218,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
bool is_sm8x = major == 8 && minor > 0;
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -259,7 +260,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
bool is_sm8x = major == 8 && minor > 0;
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -287,7 +288,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -315,7 +316,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -346,7 +347,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
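Note on the forward launchers: several of them (hdim96/128/160) branch on `is_sm8x`, derived from the device's compute capability; `major == 8 && minor > 0` selects sm86/sm89-class parts (consumer Ampere and Ada) while excluding sm80 (A100), which has a larger shared-memory budget. A standalone version of that probe using only the CUDA runtime — the `C10_CUDA_CHECK` wrapper from the real code is replaced by a plain error check here:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0, major = 0, minor = 0;
    if (cudaGetDevice(&device) != cudaSuccess) { return 1; }
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    // true for sm86/sm87/sm89, false for sm80 (A100).
    bool is_sm8x = major == 8 && minor > 0;
    std::printf("sm%d%d, is_sm8x = %d\n", major, minor, int(is_sm8x));
    return 0;
}
```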
6 changes: 4 additions & 2 deletions csrc/flash_attn/src/kernel_traits.h
@@ -231,9 +231,11 @@ struct Flash_bwd_kernel_traits : public Base {
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
- static_assert(kBlockN >= 64);
+ // Temporarily disabling this for hdim 256 on sm86 and sm89
+ // static_assert(kBlockN >= 64);
+ static_assert(kBlockN >= 32);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
- static constexpr int kPBlockN = 64;
+ static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3;
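Note on the kernel_traits change above: relaxing the assert to `kBlockN >= 32` and deriving `kPBlockN` from `kBlockN` is what allows the 64 x 32 hdim-256 backward configuration used for sm86/sm89. As a rough, hedged estimate (2-byte fp16/bf16 elements, counting only the K and V tiles and ignoring the kernel's other shared-memory buffers), halving kBlockN halves those tiles' footprint:

$$\text{smem}_{K+V} \approx 2 \cdot \text{kBlockN} \cdot \text{Headdim} \cdot 2\,\mathrm{B} =
\begin{cases}
64\ \mathrm{KB}, & \text{kBlockN} = 64,\ \text{Headdim} = 256 \\
32\ \mathrm{KB}, & \text{kBlockN} = 32,\ \text{Headdim} = 256
\end{cases}$$

which is the kind of saving needed to fit under the ~99 KB per-block limit of sm86/sm89.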
15 changes: 8 additions & 7 deletions csrc/flash_attn/src/softmax.h
@@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tenso
reduce_<zero_init>(tensor, max, max_op);
}

- template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+ template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
- reduce_(tensor, sum, sum_op);
+ thread_reduce_<zero_init>(tensor, sum, sum_op);
}

// Apply the exp to all the elements.
@@ -133,7 +133,7 @@ struct Softmax {
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
- flash::reduce_sum(scores, row_sum);
+ flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
@@ -152,15 +152,16 @@
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
- Tensor scores_sum_cur = make_fragment_like(row_sum);
- flash::reduce_sum(scores, scores_sum_cur);
- #pragma unroll
- for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+ // We don't do the reduce across threads here since we don't need to use the row_sum.
+ // We do that reduce at the end when we need to normalize the softmax.
+ flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};

template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+ SumOp<float> sum_op;
+ quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
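Note on the softmax.h changes: giving `reduce_sum` a `zero_init` template flag lets the running `row_sum` be accumulated in place across key blocks (`zero_init=false`), instead of summing into a temporary `scores_sum_cur` and adding it back, and the across-thread `quad_allreduce_` is deferred to `normalize_softmax_lse`, where the sum is first needed. The underlying recurrence is the standard online softmax from the FlashAttention papers (the kernel works with base-2 exponentials via `scale_apply_exp2` and `softmax_scale_log2`, but the algebra is the same); per row, with running max $m$ and running sum $\ell$:

$$m^{\text{new}} = \max\big(m^{\text{old}},\ \operatorname{rowmax}(S)\big), \qquad
\ell^{\text{new}} = e^{m^{\text{old}} - m^{\text{new}}}\,\ell^{\text{old}} + \operatorname{rowsum}\big(e^{S - m^{\text{new}}}\big), \qquad
O^{\text{new}} = e^{m^{\text{old}} - m^{\text{new}}}\,O^{\text{old}} + e^{S - m^{\text{new}}}\,V$$

In the diff, `scores_scale` is the $e^{m^{\text{old}} - m^{\text{new}}}$ rescaling factor applied to the output accumulator, and `reduce_sum</*zero_init=*/false>` adds $\operatorname{rowsum}(e^{S - m^{\text{new}}})$ into `row_sum`.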