From 04ab62eeb1dc91f90166ac6e0ce62b3c4e61fa34 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Fri, 27 Sep 2024 22:37:34 +0000 Subject: [PATCH 01/16] Add CK FA v2 solver --- CMakeLists.txt | 2 +- src/CMakeLists.txt | 3 +- src/include/miopen/mha/solvers.hpp | 16 ++ src/solution.cpp | 45 +--- .../mha/mha_ck_fa_v2_solver_forward.cpp | 241 ++++++++++++++++++ 5 files changed, 265 insertions(+), 42 deletions(-) create mode 100644 src/solver/mha/mha_ck_fa_v2_solver_forward.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6029ed49b9..0fdad458f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,7 +327,7 @@ add_compile_definitions($<$:HIP_COMPILER_FLAGS=${HIP_COMPI # HIP if( MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU") if(MIOPEN_USE_COMPOSABLEKERNEL) - find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) + find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations device_mha_operations) endif() if( MIOPEN_BACKEND STREQUAL "HIPNOGPU") set(MIOPEN_MODE_NOGPU 1) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 78a3a0b9cc..57e286dfed 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -300,6 +300,7 @@ set( MIOpen_Source solver/layernorm/forward_layernorm2d_ck.cpp solver/layernorm/forward_layernorm4d_ck.cpp solver/layernorm/forward_t5layernorm.cpp + solver/mha/mha_ck_fa_v2_solver_forward.cpp solver/mha/mha_solver_backward.cpp solver/mha/mha_solver_forward.cpp solver/pooling/forward2d.cpp @@ -822,7 +823,7 @@ target_include_directories(MIOpen PUBLIC ) if(MIOPEN_USE_COMPOSABLEKERNEL) -set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations hip::host) 
+set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations composable_kernel::device_mha_operations hip::host) endif() if(WIN32) diff --git a/src/include/miopen/mha/solvers.hpp b/src/include/miopen/mha/solvers.hpp index 6bac473a71..e03009e194 100644 --- a/src/include/miopen/mha/solvers.hpp +++ b/src/include/miopen/mha/solvers.hpp @@ -77,6 +77,22 @@ struct MhaBackward final : MhaSolver MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override; }; +struct MhaCKFlashAttentionV2Forward final : MhaSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + MIOPEN_INTERNALS_EXPORT bool IsApplicable(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; + + MIOPEN_INTERNALS_EXPORT ConvSolution GetSolution(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; + + MIOPEN_INTERNALS_EXPORT std::size_t GetWorkspaceSize(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; + + MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override; +}; + } // namespace mha } // namespace solver diff --git a/src/solution.cpp b/src/solution.cpp index 4fe447423a..06842961f6 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -394,47 +394,12 @@ void Solution::RunImpl(Handle& handle, } }(); - if(invoker) - { - (*invoker)(handle, invoke_ctx); - return; - } - - solver::mha::MhaForward mhaForward; - solver::mha::MhaBackward mhaBackward; - - if(!kernels.empty()) - { - const auto ctx = ExecutionContext{&handle}; - const auto mha_solution = GetSolver() == mhaForward.SolverDbId() - ? 
mhaForward.GetSolution(ctx, problem_description) - : mhaBackward.GetSolution(ctx, problem_description); - auto kernel_handles = std::vector{std::begin(kernels), std::end(kernels)}; - - invoker = (*mha_solution.invoker_factory)(kernel_handles); - (*invoker)(handle, invoke_ctx); - return; - } - - const auto net_cfg = problem_description.MakeNetworkConfig(); - invoker = handle.GetInvoker(net_cfg, GetSolver()); - - if(invoker) - { - (*invoker)(handle, invoke_ctx); - return; - } - - auto ctx = ExecutionContext{&handle}; - - const auto mha_solution = GetSolver() == mhaForward.SolverDbId() - ? mhaForward.GetSolution(ctx, problem_description) - : mhaBackward.GetSolution(ctx, problem_description); + const auto algo = AlgorithmName{"MHA"}; + const auto solvers = solver::SolverContainer{}; - invoker = - handle.PrepareInvoker(*mha_solution.invoker_factory, mha_solution.construction_params); - handle.RegisterInvoker(*invoker, net_cfg, GetSolver().ToString()); - (*invoker)(handle, invoke_ctx); + solvers.ExecutePrimitive(handle, problem_description, algo, invoke_ctx); } void Solution::RunImpl(Handle& handle, diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp new file mode 100644 index 0000000000..4e1764cf3b --- /dev/null +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -0,0 +1,241 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include + +#ifdef MIOPEN_USE_COMPOSABLEKERNEL +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/fmha_fwd.hpp" +#include "ck/stream_config.hpp" +#endif + +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_FA_CK_V2_FWD) + +namespace miopen { + +namespace solver { + +namespace mha { + +static std::string Convert(miopenDataType_t dataType) +{ + switch(dataType) + { + case miopenHalf: + { + return "fp16"; + } + case miopenBFloat16: + { + return "bfp16"; + } + default: + { + MIOPEN_THROW("Unsupported datatype provided"); + } + } +} + +bool MhaCKFlashAttentionV2Forward::IsApplicable([[maybe_unused]] const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const +{ + #if MIOPEN_USE_COMPOSABLEKERNEL + if(!problem.IsForward()) + { + return false; + } + + if(!StartsWith(context.GetStream().GetDeviceName(), "gfx94")) + { + return false; + } + + auto& descsFwd = problem.GetDescsForward(); + auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); + auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); + + return !env::disabled(MIOPEN_DEBUG_FA_CK_V2_FWD) // + && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. + && H_q <= 256 // + && H_q % 8 == 0 // No padding support yet which means it needs to be multiple of 8. 
+ && descsFwd.kDesc.IsPacked() // + && descsFwd.qDesc.IsPacked() // + && descsFwd.vDesc.IsPacked() // + && descsFwd.oDesc.IsPacked() // + && descsFwd.biasDesc.IsPacked() // + && descsFwd.biasDesc.GetType() == miopenHalf // + && descsFwd.kDesc.GetType() == miopenHalf // + && descsFwd.qDesc.GetType() == miopenHalf // + && descsFwd.vDesc.GetType() == miopenHalf // + && descsFwd.oDesc.GetType() == miopenHalf; // +#else + return false; +#endif +} + +std::size_t MhaCKFlashAttentionV2Forward::GetWorkspaceSize([[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::mha::ProblemDescription& problem) const +{ + return 0; +} + +ConvSolution MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const +{ + auto result = ConvSolution{miopenStatusSuccess}; + result.workspace_sz = 0; + + const miopen::mha::MhaInputDescsForward& descsFwd = problem.GetDescsForward(); + auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); + auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); + auto [N_v, H_v, S_v, D_v] = miopen::tien<4>(descsFwd.vDesc.GetLengths()); + + ck_tile::index_t batch = N_q; + ck_tile::index_t seqlen_q = S_q; + ck_tile::index_t seqlen_k = S_k; + ck_tile::index_t hdim_q = D_q; + ck_tile::index_t hdim_v = D_v; + ck_tile::index_t nhead = H_q; + ck_tile::index_t nhead_k = H_k; + ck_tile::index_t nhead_q = H_q; + + bool is_group_mode = false; + bool o_perm = true, i_perm = true; // if true, will be batch * nhead * seqlen * hdim + + float scale_s = descsFwd.scale; + float scale_p = 1.0; + float scale_o = 1.0; + + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; + const ck_tile::index_t max_seqlen_q = seqlen_q; + const ck_tile::index_t max_seqlen_k = seqlen_k; + + fmha_fwd_traits fmha_traits; + fmha_traits.hdim_q = hdim_q; + fmha_traits.hdim_v = hdim_v; + 
fmha_traits.data_type = Convert(descsFwd.qDesc.GetType()); + fmha_traits.is_group_mode = is_group_mode; + fmha_traits.is_v_rowmajor = false; + fmha_traits.mask_type = mask_enum::no_mask; + fmha_traits.has_lse = false; + fmha_traits.is_v_rowmajor = false; + fmha_traits.do_fp8_static_quant = false; + + fmha_fwd_args fmha_args; + fmha_args.batch = batch; + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead_q; + fmha_args.nhead_k = nhead_k; + fmha_args.stride_q = (i_perm ? hdim_q : nhead * hdim_q); + fmha_args.stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + fmha_args.stride_v = (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + fmha_args.batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + fmha_args.batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + fmha_args.batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + fmha_args.seqlen_k = shape_seqlen_k; + fmha_args.max_seqlen_q = max_seqlen_q; + + // These are used for group mode, and we are in batch right now. + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + + // Batch does not support padding, and we aren't using kvcache yet. + fmha_args.seqlen_k_ptr = nullptr; + + fmha_args.scale_s = scale_s; + fmha_args.scale_p = scale_p; + fmha_args.scale_o = scale_o; + fmha_args.stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + fmha_args.stride_o = (o_perm ? hdim_v : nhead * hdim_v); + fmha_args.nhead_stride_bias = i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k; + fmha_args.nhead_stride_lse = shape_seqlen_q; + fmha_args.nhead_stride_o = o_perm ? shape_seqlen_q * hdim_v : hdim_v; + fmha_args.window_size_left = 0; + fmha_args.window_size_right = 0; + fmha_args.mask_type = static_cast(fmha_traits.mask_type); + + fmha_args.s_randval = false; + // Since we aren't storing the random values these will be unused for now. 
+ fmha_args.stride_randval = max_seqlen_k; + fmha_args.nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + fmha_args.batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + + result.invoker_factory = [=](const std::vector&) { + return [=] (const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) params = raw_params.CastTo(); + const auto& dataFwd = params.GetDataForward(); + + fmha_fwd_traits fmha_runtime_traits = fmha_traits; + fmha_fwd_args fmha_runtime_args = fmha_args; + + fmha_runtime_traits.bias_type = dataFwd.biasData != nullptr ? bias_enum::elementwise_bias : bias_enum::no_bias; + fmha_runtime_traits.has_dropout = dataFwd.dropoutProbabilityData != nullptr; + + float probability = 0; + uint64_t seed = 0; + uint64_t offset = 0; + if(fmha_runtime_traits.has_dropout) + { + hipMemcpy(&probability, dataFwd.dropoutProbabilityData, sizeof(float), hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&seed, dataFwd.dropoutSeedData, sizeof(uint64_t), hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&offset, dataFwd.dropoutOffsetData, sizeof(uint64_t), hipMemcpyKind::hipMemcpyDeviceToHost); + } + fmha_runtime_args.p_drop = probability; + fmha_runtime_args.drop_seed_offset = {seed, offset}; + + fmha_runtime_args.bias_ptr = dataFwd.biasData; + fmha_runtime_args.q_ptr = dataFwd.qData; + fmha_runtime_args.k_ptr = dataFwd.kData; + fmha_runtime_args.v_ptr = dataFwd.vData; + fmha_runtime_args.rand_val_ptr = nullptr; + fmha_runtime_args.o_ptr = dataFwd.oData; + + // Create stream_config, and set it to not time kernel. 
+ ck_tile::stream_config stream_config; + stream_config.stream_id_ = handle_.GetStream(); + + { + HipEventProfiler profiler(handle_); + fmha_fwd(fmha_runtime_traits, fmha_runtime_args, stream_config); + } + }; + }; + + return result; +} + +bool MhaCKFlashAttentionV2Forward::MayNeedWorkspace() const { return false; } + +} // namespace mha + +} // namespace solver + +} // namespace miopen From 5b4a58bd4111cbda3382ec9d54ee5b63e7135705 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Fri, 27 Sep 2024 22:46:53 +0000 Subject: [PATCH 02/16] Apply formatting --- src/include/miopen/mha/solvers.hpp | 20 ++- .../mha/mha_ck_fa_v2_solver_forward.cpp | 168 ++++++++++-------- 2 files changed, 102 insertions(+), 86 deletions(-) diff --git a/src/include/miopen/mha/solvers.hpp b/src/include/miopen/mha/solvers.hpp index e03009e194..55423f63c7 100644 --- a/src/include/miopen/mha/solvers.hpp +++ b/src/include/miopen/mha/solvers.hpp @@ -79,16 +79,22 @@ struct MhaBackward final : MhaSolver struct MhaCKFlashAttentionV2Forward final : MhaSolver { - const std::string& SolverDbId() const override { return GetSolverDbId(); } + const std::string& SolverDbId() const override + { + return GetSolverDbId(); + } - MIOPEN_INTERNALS_EXPORT bool IsApplicable(const ExecutionContext& context, - const miopen::mha::ProblemDescription& problem) const override; + MIOPEN_INTERNALS_EXPORT bool + IsApplicable(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; - MIOPEN_INTERNALS_EXPORT ConvSolution GetSolution(const ExecutionContext& context, - const miopen::mha::ProblemDescription& problem) const override; + MIOPEN_INTERNALS_EXPORT ConvSolution + GetSolution(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; - MIOPEN_INTERNALS_EXPORT std::size_t GetWorkspaceSize(const ExecutionContext& context, - const miopen::mha::ProblemDescription& problem) const override; + MIOPEN_INTERNALS_EXPORT std::size_t + 
GetWorkspaceSize(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override; }; diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 4e1764cf3b..4f8a2c2339 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -47,25 +47,23 @@ static std::string Convert(miopenDataType_t dataType) { switch(dataType) { - case miopenHalf: - { - return "fp16"; - } - case miopenBFloat16: - { - return "bfp16"; - } - default: - { - MIOPEN_THROW("Unsupported datatype provided"); - } + case miopenHalf: { + return "fp16"; + } + case miopenBFloat16: { + return "bfp16"; + } + default: { + MIOPEN_THROW("Unsupported datatype provided"); + } } } -bool MhaCKFlashAttentionV2Forward::IsApplicable([[maybe_unused]] const ExecutionContext& context, - const miopen::mha::ProblemDescription& problem) const +bool MhaCKFlashAttentionV2Forward::IsApplicable( + [[maybe_unused]] const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const { - #if MIOPEN_USE_COMPOSABLEKERNEL +#if MIOPEN_USE_COMPOSABLEKERNEL if(!problem.IsForward()) { return false; @@ -76,37 +74,39 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable([[maybe_unused]] const Execution return false; } - auto& descsFwd = problem.GetDescsForward(); + auto& descsFwd = problem.GetDescsForward(); auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); - return !env::disabled(MIOPEN_DEBUG_FA_CK_V2_FWD) // - && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. - && H_q <= 256 // - && H_q % 8 == 0 // No padding support yet which means it needs to be multiple of 8. 
- && descsFwd.kDesc.IsPacked() // - && descsFwd.qDesc.IsPacked() // - && descsFwd.vDesc.IsPacked() // - && descsFwd.oDesc.IsPacked() // - && descsFwd.biasDesc.IsPacked() // - && descsFwd.biasDesc.GetType() == miopenHalf // - && descsFwd.kDesc.GetType() == miopenHalf // - && descsFwd.qDesc.GetType() == miopenHalf // - && descsFwd.vDesc.GetType() == miopenHalf // - && descsFwd.oDesc.GetType() == miopenHalf; // + return !env::disabled(MIOPEN_DEBUG_FA_CK_V2_FWD) // + && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. + && H_q <= 256 // + && H_q % 8 == 0 // No padding support yet which means it needs to be multiple of 8. + && descsFwd.kDesc.IsPacked() // + && descsFwd.qDesc.IsPacked() // + && descsFwd.vDesc.IsPacked() // + && descsFwd.oDesc.IsPacked() // + && descsFwd.biasDesc.IsPacked() // + && descsFwd.biasDesc.GetType() == miopenHalf // + && descsFwd.kDesc.GetType() == miopenHalf // + && descsFwd.qDesc.GetType() == miopenHalf // + && descsFwd.vDesc.GetType() == miopenHalf // + && descsFwd.oDesc.GetType() == miopenHalf; // #else return false; #endif } -std::size_t MhaCKFlashAttentionV2Forward::GetWorkspaceSize([[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::mha::ProblemDescription& problem) const +std::size_t MhaCKFlashAttentionV2Forward::GetWorkspaceSize( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::mha::ProblemDescription& problem) const { return 0; } -ConvSolution MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContext& context, - const miopen::mha::ProblemDescription& problem) const +ConvSolution +MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const { auto result = ConvSolution{miopenStatusSuccess}; result.workspace_sz = 0; @@ -138,30 +138,30 @@ ConvSolution MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const Ex const 
ck_tile::index_t max_seqlen_k = seqlen_k; fmha_fwd_traits fmha_traits; - fmha_traits.hdim_q = hdim_q; - fmha_traits.hdim_v = hdim_v; - fmha_traits.data_type = Convert(descsFwd.qDesc.GetType()); - fmha_traits.is_group_mode = is_group_mode; - fmha_traits.is_v_rowmajor = false; - fmha_traits.mask_type = mask_enum::no_mask; - fmha_traits.has_lse = false; - fmha_traits.is_v_rowmajor = false; + fmha_traits.hdim_q = hdim_q; + fmha_traits.hdim_v = hdim_v; + fmha_traits.data_type = Convert(descsFwd.qDesc.GetType()); + fmha_traits.is_group_mode = is_group_mode; + fmha_traits.is_v_rowmajor = false; + fmha_traits.mask_type = mask_enum::no_mask; + fmha_traits.has_lse = false; + fmha_traits.is_v_rowmajor = false; fmha_traits.do_fp8_static_quant = false; fmha_fwd_args fmha_args; - fmha_args.batch = batch; - fmha_args.hdim_q = hdim_q; - fmha_args.hdim_v = hdim_v; - fmha_args.nhead_q = nhead_q; - fmha_args.nhead_k = nhead_k; - fmha_args.stride_q = (i_perm ? hdim_q : nhead * hdim_q); - fmha_args.stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - fmha_args.stride_v = (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + fmha_args.batch = batch; + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead_q; + fmha_args.nhead_k = nhead_k; + fmha_args.stride_q = (i_perm ? hdim_q : nhead * hdim_q); + fmha_args.stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + fmha_args.stride_v = (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); fmha_args.batch_stride_q = (nhead * shape_seqlen_q * hdim_q); fmha_args.batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); fmha_args.batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); - fmha_args.seqlen_k = shape_seqlen_k; - fmha_args.max_seqlen_q = max_seqlen_q; + fmha_args.seqlen_k = shape_seqlen_k; + fmha_args.max_seqlen_q = max_seqlen_q; // These are used for group mode, and we are in batch right now. 
fmha_args.seqstart_q_ptr = nullptr; @@ -170,53 +170,63 @@ ConvSolution MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const Ex // Batch does not support padding, and we aren't using kvcache yet. fmha_args.seqlen_k_ptr = nullptr; - fmha_args.scale_s = scale_s; - fmha_args.scale_p = scale_p; - fmha_args.scale_o = scale_o; - fmha_args.stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); - fmha_args.stride_o = (o_perm ? hdim_v : nhead * hdim_v); + fmha_args.scale_s = scale_s; + fmha_args.scale_p = scale_p; + fmha_args.scale_o = scale_o; + fmha_args.stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + fmha_args.stride_o = (o_perm ? hdim_v : nhead * hdim_v); fmha_args.nhead_stride_bias = i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k; - fmha_args.nhead_stride_lse = shape_seqlen_q; - fmha_args.nhead_stride_o = o_perm ? shape_seqlen_q * hdim_v : hdim_v; - fmha_args.window_size_left = 0; + fmha_args.nhead_stride_lse = shape_seqlen_q; + fmha_args.nhead_stride_o = o_perm ? shape_seqlen_q * hdim_v : hdim_v; + fmha_args.window_size_left = 0; fmha_args.window_size_right = 0; - fmha_args.mask_type = static_cast(fmha_traits.mask_type); + fmha_args.mask_type = static_cast(fmha_traits.mask_type); fmha_args.s_randval = false; // Since we aren't storing the random values these will be unused for now. 
- fmha_args.stride_randval = max_seqlen_k; + fmha_args.stride_randval = max_seqlen_k; fmha_args.nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - fmha_args.batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + fmha_args.batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - result.invoker_factory = [=](const std::vector&) { - return [=] (const Handle& handle_, const AnyInvokeParams& raw_params) { + result.invoker_factory = [=](const std::vector&) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) params = raw_params.CastTo(); const auto& dataFwd = params.GetDataForward(); fmha_fwd_traits fmha_runtime_traits = fmha_traits; - fmha_fwd_args fmha_runtime_args = fmha_args; + fmha_fwd_args fmha_runtime_args = fmha_args; - fmha_runtime_traits.bias_type = dataFwd.biasData != nullptr ? bias_enum::elementwise_bias : bias_enum::no_bias; + fmha_runtime_traits.bias_type = + dataFwd.biasData != nullptr ? bias_enum::elementwise_bias : bias_enum::no_bias; fmha_runtime_traits.has_dropout = dataFwd.dropoutProbabilityData != nullptr; float probability = 0; - uint64_t seed = 0; - uint64_t offset = 0; + uint64_t seed = 0; + uint64_t offset = 0; if(fmha_runtime_traits.has_dropout) { - hipMemcpy(&probability, dataFwd.dropoutProbabilityData, sizeof(float), hipMemcpyKind::hipMemcpyDeviceToHost); - hipMemcpy(&seed, dataFwd.dropoutSeedData, sizeof(uint64_t), hipMemcpyKind::hipMemcpyDeviceToHost); - hipMemcpy(&offset, dataFwd.dropoutOffsetData, sizeof(uint64_t), hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&probability, + dataFwd.dropoutProbabilityData, + sizeof(float), + hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&seed, + dataFwd.dropoutSeedData, + sizeof(uint64_t), + hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&offset, + dataFwd.dropoutOffsetData, + sizeof(uint64_t), + hipMemcpyKind::hipMemcpyDeviceToHost); } - fmha_runtime_args.p_drop = probability; + fmha_runtime_args.p_drop = probability; 
fmha_runtime_args.drop_seed_offset = {seed, offset}; - fmha_runtime_args.bias_ptr = dataFwd.biasData; - fmha_runtime_args.q_ptr = dataFwd.qData; - fmha_runtime_args.k_ptr = dataFwd.kData; - fmha_runtime_args.v_ptr = dataFwd.vData; + fmha_runtime_args.bias_ptr = dataFwd.biasData; + fmha_runtime_args.q_ptr = dataFwd.qData; + fmha_runtime_args.k_ptr = dataFwd.kData; + fmha_runtime_args.v_ptr = dataFwd.vData; fmha_runtime_args.rand_val_ptr = nullptr; - fmha_runtime_args.o_ptr = dataFwd.oData; + fmha_runtime_args.o_ptr = dataFwd.oData; // Create stream_config, and set it to not time kernel. ck_tile::stream_config stream_config; @@ -228,7 +238,7 @@ ConvSolution MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const Ex } }; }; - + return result; } From 367cf33ab99e7ebbfc63ed55efcfad40e79a183b Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Fri, 27 Sep 2024 22:53:54 +0000 Subject: [PATCH 03/16] Add define guard around CK implementation --- src/solver/mha/mha_ck_fa_v2_solver_forward.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 4f8a2c2339..c753726c1d 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -108,6 +108,7 @@ ConvSolution MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContext& context, const miopen::mha::ProblemDescription& problem) const { +#if MIOPEN_USE_COMPOSABLEKERNEL auto result = ConvSolution{miopenStatusSuccess}; result.workspace_sz = 0; @@ -240,6 +241,9 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex }; return result; +#else + return ConvSolution{miopenStatusNotImplemented}; +#endif } bool MhaCKFlashAttentionV2Forward::MayNeedWorkspace() const { return false; } From b213750a9c46602300500f778e4ab4b655513080 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Fri, 27 Sep 2024 23:38:06 +0000 Subject: [PATCH 
04/16] Add support for masking option --- .../mha/mha_ck_fa_v2_solver_forward.cpp | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index c753726c1d..3cd8357708 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -148,6 +148,7 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_traits.has_lse = false; fmha_traits.is_v_rowmajor = false; fmha_traits.do_fp8_static_quant = false; + fmha_traits.has_dropout = false; fmha_fwd_args fmha_args; fmha_args.batch = batch; @@ -197,13 +198,30 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_fwd_traits fmha_runtime_traits = fmha_traits; fmha_fwd_args fmha_runtime_args = fmha_args; + fmha_runtime_args.q_ptr = dataFwd.qData; + fmha_runtime_args.k_ptr = dataFwd.kData; + fmha_runtime_args.v_ptr = dataFwd.vData; + fmha_runtime_args.rand_val_ptr = nullptr; + fmha_runtime_args.o_ptr = dataFwd.oData; + fmha_runtime_traits.bias_type = dataFwd.biasData != nullptr ? 
bias_enum::elementwise_bias : bias_enum::no_bias; - fmha_runtime_traits.has_dropout = dataFwd.dropoutProbabilityData != nullptr; + fmha_runtime_args.bias_ptr = dataFwd.biasData; - float probability = 0; - uint64_t seed = 0; - uint64_t offset = 0; + // Top-left causal mask + if(dataFwd.mask == miopenMhaMask_t::miopenMhaMaskCausal) + { + fmha_runtime_traits.mask_type = mask_enum::mask_top_left; + fmha_runtime_args.mask_type = + static_cast(mask_enum::mask_top_left); + fmha_runtime_args.window_size_left = -1; + fmha_runtime_args.window_size_right = 0; + } + + fmha_runtime_traits.has_dropout = dataFwd.dropoutProbabilityData != nullptr; + float probability = 0; + uint64_t seed = 0; + uint64_t offset = 0; if(fmha_runtime_traits.has_dropout) { hipMemcpy(&probability, @@ -222,13 +240,6 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_runtime_args.p_drop = probability; fmha_runtime_args.drop_seed_offset = {seed, offset}; - fmha_runtime_args.bias_ptr = dataFwd.biasData; - fmha_runtime_args.q_ptr = dataFwd.qData; - fmha_runtime_args.k_ptr = dataFwd.kData; - fmha_runtime_args.v_ptr = dataFwd.vData; - fmha_runtime_args.rand_val_ptr = nullptr; - fmha_runtime_args.o_ptr = dataFwd.oData; - // Create stream_config, and set it to not time kernel. 
ck_tile::stream_config stream_config; stream_config.stream_id_ = handle_.GetStream(); From 9d29e2c7f83eb98bb51c1bbc0a436bd0a8e8e4ee Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Mon, 30 Sep 2024 14:43:55 +0000 Subject: [PATCH 05/16] Fix hip-tidy issues --- src/solver/mha/mha_ck_fa_v2_solver_forward.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 3cd8357708..d759b3c902 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -146,7 +146,6 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_traits.is_v_rowmajor = false; fmha_traits.mask_type = mask_enum::no_mask; fmha_traits.has_lse = false; - fmha_traits.is_v_rowmajor = false; fmha_traits.do_fp8_static_quant = false; fmha_traits.has_dropout = false; @@ -175,9 +174,9 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_args.scale_s = scale_s; fmha_args.scale_p = scale_p; fmha_args.scale_o = scale_o; - fmha_args.stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + fmha_args.stride_bias = shape_seqlen_k; fmha_args.stride_o = (o_perm ? hdim_v : nhead * hdim_v); - fmha_args.nhead_stride_bias = i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k; + fmha_args.nhead_stride_bias = 0; fmha_args.nhead_stride_lse = shape_seqlen_q; fmha_args.nhead_stride_o = o_perm ? 
shape_seqlen_q * hdim_v : hdim_v; fmha_args.window_size_left = 0; From 75caf69fc40c71d5f6bfa86dc43547a1f6c88f3f Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Mon, 30 Sep 2024 21:21:15 +0000 Subject: [PATCH 06/16] Add registering new CK solver --- src/problem.cpp | 3 ++- src/solver.cpp | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/problem.cpp b/src/problem.cpp index 16e199c002..494df3c289 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -610,10 +610,11 @@ Problem::FindSolutionsImpl(Handle& handle, const auto algo = AlgorithmName{"Mha"}; + static solver::mha::MhaCKFlashAttentionV2Forward mhaCKFAForwardSolver; static solver::mha::MhaForward mhaForwardSolver; static solver::mha::MhaBackward mhaBackwardSolver; - std::vector solvers = {&mhaForwardSolver, &mhaBackwardSolver}; + std::vector solvers = {&mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver}; for(auto solver : solvers) { diff --git a/src/solver.cpp b/src/solver.cpp index d96548647b..5d10e30c68 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -679,6 +679,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::ReLU, prelu::MultiWeightsBackward{}.SolverDbId()); Register(registry, ++id, Primitive::ReLU, prelu::SingleWeightBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId()); + // IMPORTANT: New solvers should be added to the end of the function! 
} From dcf863614156bc5465537ddad5e1d4c8668fd189 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Mon, 30 Sep 2024 21:22:02 +0000 Subject: [PATCH 07/16] Apply formatting --- src/problem.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/problem.cpp b/src/problem.cpp index 494df3c289..fed48dfe88 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -614,7 +614,8 @@ Problem::FindSolutionsImpl(Handle& handle, static solver::mha::MhaForward mhaForwardSolver; static solver::mha::MhaBackward mhaBackwardSolver; - std::vector solvers = {&mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver}; + std::vector solvers = { + &mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver}; for(auto solver : solvers) { From bd132b4e9b58fc6cfcf8085bbfca78597b88003d Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Wed, 2 Oct 2024 23:01:26 +0000 Subject: [PATCH 08/16] Code review comments - Update requirements to CK that generates MHA for gfx90a - Fix compilation issue for No CK - Use provided tensor strides instead of calculating --- requirements.txt | 2 +- .../mha/mha_ck_fa_v2_solver_forward.cpp | 87 ++++++++++--------- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/requirements.txt b/requirements.txt index ffe4c3acb5..7a2806ab57 100755 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,5 @@ nlohmann/json@v3.11.2 -DJSON_MultipleHeaders=ON -DJSON_BuildTests=Off ROCm/FunctionalPlus@v0.2.18-p0 ROCm/eigen@3.4.0 ROCm/frugally-deep@9683d557eb672ee2304f80f6682c51242d748a50 -ROCm/composable_kernel@9c0811f39a2262dbe4d71b81898187951c1e11ba -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON +ROCm/composable_kernel@294cb823142a815170cf1faa63e01a431a557a04 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON google/googletest@v1.14.0 diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index d759b3c902..0a5e41787b 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ 
b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -43,6 +43,7 @@ namespace solver { namespace mha { +#ifdef MIOPEN_USE_COMPOSABLEKERNEL static std::string Convert(miopenDataType_t dataType) { switch(dataType) @@ -58,6 +59,7 @@ static std::string Convert(miopenDataType_t dataType) } } } +#endif bool MhaCKFlashAttentionV2Forward::IsApplicable( [[maybe_unused]] const ExecutionContext& context, @@ -69,7 +71,8 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( return false; } - if(!StartsWith(context.GetStream().GetDeviceName(), "gfx94")) + auto deviceName = context.GetStream().GetDeviceName(); + if(!StartsWith(deviceName, "gfx94") || deviceName != "gfx90a") { return false; } @@ -114,55 +117,55 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex const miopen::mha::MhaInputDescsForward& descsFwd = problem.GetDescsForward(); auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); + auto [N_stride_k, H_stride_k, S_stride_k, D_stride_k] = + miopen::tien<4>(descsFwd.kDesc.GetStrides()); + auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); - auto [N_v, H_v, S_v, D_v] = miopen::tien<4>(descsFwd.vDesc.GetLengths()); + auto [N_stride_q, H_stride_q, S_stride_q, D_stride_q] = + miopen::tien<4>(descsFwd.qDesc.GetStrides()); - ck_tile::index_t batch = N_q; - ck_tile::index_t seqlen_q = S_q; - ck_tile::index_t seqlen_k = S_k; - ck_tile::index_t hdim_q = D_q; - ck_tile::index_t hdim_v = D_v; - ck_tile::index_t nhead = H_q; - ck_tile::index_t nhead_k = H_k; - ck_tile::index_t nhead_q = H_q; + auto [N_v, H_v, S_v, D_v] = miopen::tien<4>(descsFwd.vDesc.GetLengths()); + auto [N_stride_v, H_stride_v, S_stride_v, D_stride_v] = + miopen::tien<4>(descsFwd.vDesc.GetStrides()); - bool is_group_mode = false; - bool o_perm = true, i_perm = true; // if true, will be batch * nhead * seqlen * hdim + auto [N_stride_bias, H_stride_bias, S_stride_bias, D_stride_bias] = + miopen::tien<4>(descsFwd.biasDesc.GetStrides()); float 
scale_s = descsFwd.scale; float scale_p = 1.0; float scale_o = 1.0; - const ck_tile::index_t shape_seqlen_q = seqlen_q; - const ck_tile::index_t shape_seqlen_k = seqlen_k; - const ck_tile::index_t max_seqlen_q = seqlen_q; - const ck_tile::index_t max_seqlen_k = seqlen_k; - fmha_fwd_traits fmha_traits; - fmha_traits.hdim_q = hdim_q; - fmha_traits.hdim_v = hdim_v; - fmha_traits.data_type = Convert(descsFwd.qDesc.GetType()); - fmha_traits.is_group_mode = is_group_mode; - fmha_traits.is_v_rowmajor = false; + fmha_traits.hdim_q = H_q; + fmha_traits.hdim_v = H_v; + fmha_traits.data_type = Convert(descsFwd.qDesc.GetType()); + fmha_traits.is_group_mode = false; + // is_v_rowmajor relates to the layout of the V tensor. Row major means NHSD, and Col major + // means NHDS. + fmha_traits.is_v_rowmajor = true; fmha_traits.mask_type = mask_enum::no_mask; fmha_traits.has_lse = false; fmha_traits.do_fp8_static_quant = false; fmha_traits.has_dropout = false; fmha_fwd_args fmha_args; - fmha_args.batch = batch; - fmha_args.hdim_q = hdim_q; - fmha_args.hdim_v = hdim_v; - fmha_args.nhead_q = nhead_q; - fmha_args.nhead_k = nhead_k; - fmha_args.stride_q = (i_perm ? hdim_q : nhead * hdim_q); - fmha_args.stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - fmha_args.stride_v = (i_perm ? 
shape_seqlen_k : nhead_k * shape_seqlen_k); - fmha_args.batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - fmha_args.batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - fmha_args.batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); - fmha_args.seqlen_k = shape_seqlen_k; - fmha_args.max_seqlen_q = max_seqlen_q; + fmha_args.batch = N_q; + fmha_args.hdim_q = D_q; + fmha_args.hdim_v = D_v; + fmha_args.nhead_q = H_q; + fmha_args.nhead_k = H_k; + fmha_args.stride_q = S_stride_q; + fmha_args.stride_k = S_stride_k; + fmha_args.stride_v = S_stride_v; + fmha_args.nhead_stride_q = H_stride_q; + fmha_args.nhead_stride_k = H_stride_k; + fmha_args.nhead_stride_v = H_stride_v; + fmha_args.batch_stride_q = N_stride_q; + fmha_args.batch_stride_k = N_stride_k; + fmha_args.batch_stride_v = N_stride_v; + fmha_args.seqlen_k = S_k; + fmha_args.seqlen_q = S_q; + fmha_args.max_seqlen_q = S_q; // These are used for group mode, and we are in batch right now. fmha_args.seqstart_q_ptr = nullptr; @@ -174,20 +177,20 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_args.scale_s = scale_s; fmha_args.scale_p = scale_p; fmha_args.scale_o = scale_o; - fmha_args.stride_bias = shape_seqlen_k; - fmha_args.stride_o = (o_perm ? hdim_v : nhead * hdim_v); + fmha_args.stride_bias = S_stride_bias; + fmha_args.stride_o = S_stride_v; fmha_args.nhead_stride_bias = 0; - fmha_args.nhead_stride_lse = shape_seqlen_q; - fmha_args.nhead_stride_o = o_perm ? shape_seqlen_q * hdim_v : hdim_v; + fmha_args.nhead_stride_lse = S_q; + fmha_args.nhead_stride_o = S_q * D_v; fmha_args.window_size_left = 0; fmha_args.window_size_right = 0; fmha_args.mask_type = static_cast(fmha_traits.mask_type); fmha_args.s_randval = false; // Since we aren't storing the random values these will be unused for now. 
- fmha_args.stride_randval = max_seqlen_k; - fmha_args.nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - fmha_args.batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + fmha_args.stride_randval = S_q; + fmha_args.nhead_stride_randval = S_q * S_k; + fmha_args.batch_stride_randval = H_q * S_q * S_k; result.invoker_factory = [=](const std::vector&) { return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { From bb3b268629c8e02e5463861bb4014751d59414c3 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 14:07:36 +0000 Subject: [PATCH 09/16] Fix incorrect check for CK define --- src/solver/mha/mha_ck_fa_v2_solver_forward.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 0a5e41787b..28e1fd2ed0 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -43,7 +43,7 @@ namespace solver { namespace mha { -#ifdef MIOPEN_USE_COMPOSABLEKERNEL +#if MIOPEN_USE_COMPOSABLEKERNEL static std::string Convert(miopenDataType_t dataType) { switch(dataType) From f6076c3fbfa678f22e648deb86740f44b9ae25c3 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 14:26:45 +0000 Subject: [PATCH 10/16] Add additional details to the register comment to prevent missed conflicts in the future --- src/solver.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solver.cpp b/src/solver.cpp index 60cd0fa26c..e4722bed1b 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -686,8 +686,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId()); Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId()); - - // IMPORTANT: New solvers should be added to the end of the function! 
+ // IMPORTANT: New solvers should be added to the end of the function, and don't leave a white + // space between this comment and the newly registered solver(s)! } bool ThisSolverIsDeprecatedStatic::IsDisabled(const ExecutionContext& ctx) From 0545995133ae3fd453b505bfdf970596f82d8051 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 19:23:33 +0000 Subject: [PATCH 11/16] Update RunImpl to select proper solver --- src/solution.cpp | 60 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/src/solution.cpp b/src/solution.cpp index 06842961f6..6d8c10c13f 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -394,12 +394,62 @@ void Solution::RunImpl(Handle& handle, } }(); - const auto algo = AlgorithmName{"MHA"}; - const auto solvers = solver::SolverContainer{}; + if(invoker) + { + (*invoker)(handle, invoke_ctx); + return; + } + + auto getSolution = [&](const ExecutionContext& ctx) + { + auto solverId = GetSolver(); + solver::mha::MhaForward mhaForward; + solver::mha::MhaBackward mhaBackward; + solver::mha::MhaCKFlashAttentionV2Forward ckMhaForward; + + if(solverId == ckMhaForward.SolverDbId()) + { + return ckMhaForward.GetSolution(ctx, problem_description); + } + else if(solverId == mhaForward.SolverDbId()) + { + return mhaForward.GetSolution(ctx, problem_description); + } + else if(solverId == mhaBackward.SolverDbId()) + { + return mhaBackward.GetSolution(ctx, problem_description); + } + + MIOPEN_THROW("No MHA solver with matching SolverDbId of " + solverId.ToString()); + }; + + if(!kernels.empty()) + { + const auto ctx = ExecutionContext{&handle}; + const auto mha_solution = getSolution(ctx); + auto kernel_handles = std::vector{std::begin(kernels), std::end(kernels)}; - solvers.ExecutePrimitive(handle, problem_description, algo, invoke_ctx); + invoker = (*mha_solution.invoker_factory)(kernel_handles); + (*invoker)(handle, invoke_ctx); + return; + } + + const auto net_cfg = 
problem_description.MakeNetworkConfig(); + invoker = handle.GetInvoker(net_cfg, GetSolver()); + + if(invoker) + { + (*invoker)(handle, invoke_ctx); + return; + } + + auto ctx = ExecutionContext{&handle}; + const auto mha_solution = getSolution(ctx); + + invoker = + handle.PrepareInvoker(*mha_solution.invoker_factory, mha_solution.construction_params); + handle.RegisterInvoker(*invoker, net_cfg, GetSolver().ToString()); + (*invoker)(handle, invoke_ctx); } void Solution::RunImpl(Handle& handle, From 00e4df6198cb3f814e7f6c492a4d421e1dc1c583 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 19:26:30 +0000 Subject: [PATCH 12/16] Apply formatting --- src/solution.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/solution.cpp b/src/solution.cpp index 6d8c10c13f..3df767c90b 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -394,14 +394,13 @@ void Solution::RunImpl(Handle& handle, } }(); - if(invoker) + if(invoker) { (*invoker)(handle, invoke_ctx); return; } - auto getSolution = [&](const ExecutionContext& ctx) - { + auto getSolution = [&](const ExecutionContext& ctx) { auto solverId = GetSolver(); solver::mha::MhaForward mhaForward; solver::mha::MhaBackward mhaBackward; @@ -443,7 +442,7 @@ void Solution::RunImpl(Handle& handle, return; } - auto ctx = ExecutionContext{&handle}; + auto ctx = ExecutionContext{&handle}; const auto mha_solution = getSolution(ctx); invoker = From abcbef4adbb019c4a7b7cd41ea102e07f2d9af9e Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 19:55:11 +0000 Subject: [PATCH 13/16] Add checking stride conditions since D must be contiguous --- src/solver/mha/mha_ck_fa_v2_solver_forward.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 28e1fd2ed0..242b7e753e 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++
b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -79,7 +79,18 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( auto& descsFwd = problem.GetDescsForward(); auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); + auto [N_stride_k, H_stride_k, S_stride_k, D_stride_k] = + miopen::tien<4>(descsFwd.kDesc.GetStrides()); + auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); + auto [N_stride_q, H_stride_q, S_stride_q, D_stride_q] = + miopen::tien<4>(descsFwd.qDesc.GetStrides()); + + auto [N_stride_v, H_stride_v, S_stride_v, D_stride_v] = + miopen::tien<4>(descsFwd.vDesc.GetStrides()); + + auto [N_stride_o, H_stride_o, S_stride_o, D_stride_o] = + miopen::tien<4>(descsFwd.oDesc.GetStrides()); return !env::disabled(MIOPEN_DEBUG_FA_CK_V2_FWD) // && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. @@ -94,7 +105,9 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( && descsFwd.kDesc.GetType() == miopenHalf // && descsFwd.qDesc.GetType() == miopenHalf // && descsFwd.vDesc.GetType() == miopenHalf // - && descsFwd.oDesc.GetType() == miopenHalf; // + && descsFwd.oDesc.GetType() == miopenHalf // + && D_stride_k == 1 // CK requires D stride as 1. + && D_stride_q == 1 && D_stride_v == 1 && D_stride_o == 1; #else return false; #endif From d3526b625f97fc42aa0ec3e320ab7bf07257d5d1 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 20:26:14 +0000 Subject: [PATCH 14/16] Fix IsApplicable check to properly check for device support. 
--- src/solver/mha/mha_ck_fa_v2_solver_forward.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 242b7e753e..0b285dc062 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -72,7 +72,7 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( } auto deviceName = context.GetStream().GetDeviceName(); - if(!StartsWith(deviceName, "gfx94") || deviceName != "gfx90a") + if(!StartsWith(deviceName, "gfx94") && deviceName != "gfx90a") { return false; } @@ -101,7 +101,7 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( && descsFwd.vDesc.IsPacked() // && descsFwd.oDesc.IsPacked() // && descsFwd.biasDesc.IsPacked() // - && descsFwd.biasDesc.GetType() == miopenHalf // + //&& descsFwd.biasDesc.GetType() == miopenHalf // && descsFwd.kDesc.GetType() == miopenHalf // && descsFwd.qDesc.GetType() == miopenHalf // && descsFwd.vDesc.GetType() == miopenHalf // @@ -219,9 +219,8 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_runtime_args.rand_val_ptr = nullptr; fmha_runtime_args.o_ptr = dataFwd.oData; - fmha_runtime_traits.bias_type = - dataFwd.biasData != nullptr ? 
bias_enum::elementwise_bias : bias_enum::no_bias; - fmha_runtime_args.bias_ptr = dataFwd.biasData; + fmha_runtime_traits.bias_type = bias_enum::no_bias; + fmha_runtime_args.bias_ptr = nullptr; // Top-left causal mask if(dataFwd.mask == miopenMhaMask_t::miopenMhaMaskCausal) From 82198ad50954af494814279c26185ca442b27170 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 21:42:57 +0000 Subject: [PATCH 15/16] Apply formatting --- .../mha/mha_ck_fa_v2_solver_forward.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 0b285dc062..b80a6d3715 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -96,17 +96,17 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. && H_q <= 256 // && H_q % 8 == 0 // No padding support yet which means it needs to be multiple of 8. - && descsFwd.kDesc.IsPacked() // - && descsFwd.qDesc.IsPacked() // - && descsFwd.vDesc.IsPacked() // - && descsFwd.oDesc.IsPacked() // - && descsFwd.biasDesc.IsPacked() // + && descsFwd.kDesc.IsPacked() // + && descsFwd.qDesc.IsPacked() // + && descsFwd.vDesc.IsPacked() // + && descsFwd.oDesc.IsPacked() // + && descsFwd.biasDesc.IsPacked() // //&& descsFwd.biasDesc.GetType() == miopenHalf // - && descsFwd.kDesc.GetType() == miopenHalf // - && descsFwd.qDesc.GetType() == miopenHalf // - && descsFwd.vDesc.GetType() == miopenHalf // - && descsFwd.oDesc.GetType() == miopenHalf // - && D_stride_k == 1 // CK requires D stride as 1. + && descsFwd.kDesc.GetType() == miopenHalf // + && descsFwd.qDesc.GetType() == miopenHalf // + && descsFwd.vDesc.GetType() == miopenHalf // + && descsFwd.oDesc.GetType() == miopenHalf // + && D_stride_k == 1 // CK requires D stride as 1. 
&& D_stride_q == 1 && D_stride_v == 1 && D_stride_o == 1; #else return false; @@ -220,7 +220,7 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_runtime_args.o_ptr = dataFwd.oData; fmha_runtime_traits.bias_type = bias_enum::no_bias; - fmha_runtime_args.bias_ptr = nullptr; + fmha_runtime_args.bias_ptr = nullptr; // Top-left causal mask if(dataFwd.mask == miopenMhaMask_t::miopenMhaMaskCausal) From ff0004fa7a7f4ea9ade9a68ac61b7eabcbd29172 Mon Sep 17 00:00:00 2001 From: Brian Harrison Date: Thu, 3 Oct 2024 21:59:10 +0000 Subject: [PATCH 16/16] Add bias back after accidentally removing for testing --- .../mha/mha_ck_fa_v2_solver_forward.cpp | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index b80a6d3715..cc7b077b56 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -96,17 +96,17 @@ bool MhaCKFlashAttentionV2Forward::IsApplicable( && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. && H_q <= 256 // && H_q % 8 == 0 // No padding support yet which means it needs to be multiple of 8. - && descsFwd.kDesc.IsPacked() // - && descsFwd.qDesc.IsPacked() // - && descsFwd.vDesc.IsPacked() // - && descsFwd.oDesc.IsPacked() // - && descsFwd.biasDesc.IsPacked() // - //&& descsFwd.biasDesc.GetType() == miopenHalf // - && descsFwd.kDesc.GetType() == miopenHalf // - && descsFwd.qDesc.GetType() == miopenHalf // - && descsFwd.vDesc.GetType() == miopenHalf // - && descsFwd.oDesc.GetType() == miopenHalf // - && D_stride_k == 1 // CK requires D stride as 1. 
+ && descsFwd.kDesc.IsPacked() // + && descsFwd.qDesc.IsPacked() // + && descsFwd.vDesc.IsPacked() // + && descsFwd.oDesc.IsPacked() // + && descsFwd.biasDesc.IsPacked() // + && descsFwd.biasDesc.GetType() == miopenHalf // + && descsFwd.kDesc.GetType() == miopenHalf // + && descsFwd.qDesc.GetType() == miopenHalf // + && descsFwd.vDesc.GetType() == miopenHalf // + && descsFwd.oDesc.GetType() == miopenHalf // + && D_stride_k == 1 // CK requires D stride as 1. && D_stride_q == 1 && D_stride_v == 1 && D_stride_o == 1; #else return false; @@ -219,8 +219,9 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex fmha_runtime_args.rand_val_ptr = nullptr; fmha_runtime_args.o_ptr = dataFwd.oData; - fmha_runtime_traits.bias_type = bias_enum::no_bias; - fmha_runtime_args.bias_ptr = nullptr; + fmha_runtime_traits.bias_type = + dataFwd.biasData != nullptr ? bias_enum::elementwise_bias : bias_enum::no_bias; + fmha_runtime_args.bias_ptr = dataFwd.biasData; // Top-left causal mask if(dataFwd.mask == miopenMhaMask_t::miopenMhaMaskCausal)