diff --git a/CMakeLists.txt b/CMakeLists.txt index 6029ed49b9..0fdad458f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,7 +327,7 @@ add_compile_definitions($<$:HIP_COMPILER_FLAGS=${HIP_COMPI # HIP if( MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU") if(MIOPEN_USE_COMPOSABLEKERNEL) - find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) + find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations device_mha_operations) endif() if( MIOPEN_BACKEND STREQUAL "HIPNOGPU") set(MIOPEN_MODE_NOGPU 1) diff --git a/requirements.txt b/requirements.txt index ffe4c3acb5..7a2806ab57 100755 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,5 @@ nlohmann/json@v3.11.2 -DJSON_MultipleHeaders=ON -DJSON_BuildTests=Off ROCm/FunctionalPlus@v0.2.18-p0 ROCm/eigen@3.4.0 ROCm/frugally-deep@9683d557eb672ee2304f80f6682c51242d748a50 -ROCm/composable_kernel@9c0811f39a2262dbe4d71b81898187951c1e11ba -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON +ROCm/composable_kernel@294cb823142a815170cf1faa63e01a431a557a04 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON google/googletest@v1.14.0 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 870afab0f5..c5b739fa5c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -307,6 +307,7 @@ set( MIOpen_Source solver/layernorm/forward_layernorm2d_ck.cpp solver/layernorm/forward_layernorm4d_ck.cpp solver/layernorm/forward_t5layernorm.cpp + solver/mha/mha_ck_fa_v2_solver_forward.cpp solver/mha/mha_solver_backward.cpp solver/mha/mha_solver_forward.cpp solver/pooling/forward2d.cpp @@ -834,7 +835,7 @@ target_include_directories(MIOpen PUBLIC ) if(MIOPEN_USE_COMPOSABLEKERNEL) -set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations 
composable_kernel::device_conv_operations composable_kernel::device_reduction_operations hip::host) +set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations composable_kernel::device_mha_operations hip::host) endif() if(WIN32) diff --git a/src/include/miopen/mha/solvers.hpp b/src/include/miopen/mha/solvers.hpp index 6bac473a71..55423f63c7 100644 --- a/src/include/miopen/mha/solvers.hpp +++ b/src/include/miopen/mha/solvers.hpp @@ -77,6 +77,28 @@ struct MhaBackward final : MhaSolver MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override; }; +struct MhaCKFlashAttentionV2Forward final : MhaSolver +{ + const std::string& SolverDbId() const override + { + return GetSolverDbId(); + } + + MIOPEN_INTERNALS_EXPORT bool + IsApplicable(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; + + MIOPEN_INTERNALS_EXPORT ConvSolution + GetSolution(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; + + MIOPEN_INTERNALS_EXPORT std::size_t + GetWorkspaceSize(const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const override; + + MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override; +}; + } // namespace mha } // namespace solver diff --git a/src/problem.cpp b/src/problem.cpp index 8b406516ef..ba84856850 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -610,10 +610,12 @@ Problem::FindSolutionsImpl(Handle& handle, const auto algo = AlgorithmName{"Mha"}; + static solver::mha::MhaCKFlashAttentionV2Forward mhaCKFAForwardSolver; static solver::mha::MhaForward mhaForwardSolver; static solver::mha::MhaBackward mhaBackwardSolver; - std::vector solvers = {&mhaForwardSolver, &mhaBackwardSolver}; + std::vector solvers = { + &mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver}; for(auto solver : solvers) 
{ diff --git a/src/solution.cpp b/src/solution.cpp index 4fe447423a..3df767c90b 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -400,15 +400,32 @@ void Solution::RunImpl(Handle& handle, return; } - solver::mha::MhaForward mhaForward; - solver::mha::MhaBackward mhaBackward; + auto getSolution = [&](const ExecutionContext& ctx) { + auto solverId = GetSolver(); + solver::mha::MhaForward mhaForward; + solver::mha::MhaBackward mhaBackward; + solver::mha::MhaCKFlashAttentionV2Forward ckMhaForward; + + if(solverId == ckMhaForward.SolverDbId()) + { + return ckMhaForward.GetSolution(ctx, problem_description); + } + else if(solverId == mhaForward.SolverDbId()) + { + return mhaForward.GetSolution(ctx, problem_description); + } + else if(solverId == mhaBackward.SolverDbId()) + { + return mhaBackward.GetSolution(ctx, problem_description); + } + + MIOPEN_THROW("No MHA solver with matching SolverDbId of " + solverId.ToString()); + }; if(!kernels.empty()) { const auto ctx = ExecutionContext{&handle}; - const auto mha_solution = GetSolver() == mhaForward.SolverDbId() - ? mhaForward.GetSolution(ctx, problem_description) - : mhaBackward.GetSolution(ctx, problem_description); + const auto mha_solution = getSolution(ctx); auto kernel_handles = std::vector{std::begin(kernels), std::end(kernels)}; invoker = (*mha_solution.invoker_factory)(kernel_handles); @@ -425,11 +442,8 @@ void Solution::RunImpl(Handle& handle, return; } - auto ctx = ExecutionContext{&handle}; - - const auto mha_solution = GetSolver() == mhaForward.SolverDbId() - ? 
mhaForward.GetSolution(ctx, problem_description) - : mhaBackward.GetSolution(ctx, problem_description); + auto ctx = ExecutionContext{&handle}; + const auto mha_solution = getSolution(ctx); invoker = handle.PrepareInvoker(*mha_solution.invoker_factory, mha_solution.construction_params); diff --git a/src/solver.cpp b/src/solver.cpp index 07723008fc..e4722bed1b 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -685,7 +685,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Activation, glu::GLUForward{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId()); - // IMPORTANT: New solvers should be added to the end of the function! + Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId()); + // IMPORTANT: New solvers should be added to the end of the function, and don't leave a white + // space between this comment and the newly registered solver(s)! } bool ThisSolverIsDeprecatedStatic::IsDisabled(const ExecutionContext& ctx) diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp new file mode 100644 index 0000000000..cc7b077b56 --- /dev/null +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -0,0 +1,281 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include + +#if MIOPEN_USE_COMPOSABLEKERNEL +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/fmha_fwd.hpp" +#include "ck/stream_config.hpp" +#endif + +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_FA_CK_V2_FWD) + +namespace miopen { + +namespace solver { + +namespace mha { + +#if MIOPEN_USE_COMPOSABLEKERNEL +static std::string Convert(miopenDataType_t dataType) +{ + switch(dataType) + { + case miopenHalf: { + return "fp16"; + } + case miopenBFloat16: { + return "bf16"; + } + default: { + MIOPEN_THROW("Unsupported datatype provided"); + } + } +} +#endif + +bool MhaCKFlashAttentionV2Forward::IsApplicable( + [[maybe_unused]] const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const +{ +#if MIOPEN_USE_COMPOSABLEKERNEL + if(!problem.IsForward()) + { + return false; + } + + auto deviceName = context.GetStream().GetDeviceName(); + if(!StartsWith(deviceName, "gfx94") && deviceName != "gfx90a") + { + return false; + } + + auto& descsFwd = problem.GetDescsForward(); + auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); + auto [N_stride_k, H_stride_k, S_stride_k, D_stride_k] = + miopen::tien<4>(descsFwd.kDesc.GetStrides()); + + auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); + auto [N_stride_q, H_stride_q, S_stride_q, D_stride_q] = + miopen::tien<4>(descsFwd.qDesc.GetStrides()); + + auto [N_stride_v, H_stride_v, S_stride_v, D_stride_v] = + miopen::tien<4>(descsFwd.vDesc.GetStrides()); + + auto [N_stride_o, H_stride_o, S_stride_o, D_stride_o] = + miopen::tien<4>(descsFwd.oDesc.GetStrides()); + + return !env::disabled(MIOPEN_DEBUG_FA_CK_V2_FWD) // + && H_q == H_k // Replace with H_q % H_k == 0 once we add support for MQA & GQA. + && D_q <= 256 // + && D_q % 8 == 0 // No padding support yet which means it needs to be multiple of 8.
+ && descsFwd.kDesc.IsPacked() // + && descsFwd.qDesc.IsPacked() // + && descsFwd.vDesc.IsPacked() // + && descsFwd.oDesc.IsPacked() // + && descsFwd.biasDesc.IsPacked() // + && descsFwd.biasDesc.GetType() == miopenHalf // + && descsFwd.kDesc.GetType() == miopenHalf // + && descsFwd.qDesc.GetType() == miopenHalf // + && descsFwd.vDesc.GetType() == miopenHalf // + && descsFwd.oDesc.GetType() == miopenHalf // + && D_stride_k == 1 // CK requires D stride as 1. + && D_stride_q == 1 && D_stride_v == 1 && D_stride_o == 1; +#else + return false; +#endif +} + +std::size_t MhaCKFlashAttentionV2Forward::GetWorkspaceSize( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::mha::ProblemDescription& problem) const +{ + return 0; +} + +ConvSolution +MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::mha::ProblemDescription& problem) const +{ +#if MIOPEN_USE_COMPOSABLEKERNEL + auto result = ConvSolution{miopenStatusSuccess}; + result.workspace_sz = 0; + + const miopen::mha::MhaInputDescsForward& descsFwd = problem.GetDescsForward(); + auto [N_k, H_k, S_k, D_k] = miopen::tien<4>(descsFwd.kDesc.GetLengths()); + auto [N_stride_k, H_stride_k, S_stride_k, D_stride_k] = + miopen::tien<4>(descsFwd.kDesc.GetStrides()); + + auto [N_q, H_q, S_q, D_q] = miopen::tien<4>(descsFwd.qDesc.GetLengths()); + auto [N_stride_q, H_stride_q, S_stride_q, D_stride_q] = + miopen::tien<4>(descsFwd.qDesc.GetStrides()); + + auto [N_v, H_v, S_v, D_v] = miopen::tien<4>(descsFwd.vDesc.GetLengths()); + auto [N_stride_v, H_stride_v, S_stride_v, D_stride_v] = + miopen::tien<4>(descsFwd.vDesc.GetStrides()); + + auto [N_stride_bias, H_stride_bias, S_stride_bias, D_stride_bias] = + miopen::tien<4>(descsFwd.biasDesc.GetStrides()); + + float scale_s = descsFwd.scale; + float scale_p = 1.0; + float scale_o = 1.0; + + fmha_fwd_traits fmha_traits; + fmha_traits.hdim_q = D_q; + fmha_traits.hdim_v = D_v; + fmha_traits.data_type =
Convert(descsFwd.qDesc.GetType()); + fmha_traits.is_group_mode = false; + // is_v_rowmajor relates to the layout of the V tensor. Row major means NHSD, and Col major + // means NHDS. + fmha_traits.is_v_rowmajor = true; + fmha_traits.mask_type = mask_enum::no_mask; + fmha_traits.has_lse = false; + fmha_traits.do_fp8_static_quant = false; + fmha_traits.has_dropout = false; + + fmha_fwd_args fmha_args; + fmha_args.batch = N_q; + fmha_args.hdim_q = D_q; + fmha_args.hdim_v = D_v; + fmha_args.nhead_q = H_q; + fmha_args.nhead_k = H_k; + fmha_args.stride_q = S_stride_q; + fmha_args.stride_k = S_stride_k; + fmha_args.stride_v = S_stride_v; + fmha_args.nhead_stride_q = H_stride_q; + fmha_args.nhead_stride_k = H_stride_k; + fmha_args.nhead_stride_v = H_stride_v; + fmha_args.batch_stride_q = N_stride_q; + fmha_args.batch_stride_k = N_stride_k; + fmha_args.batch_stride_v = N_stride_v; + fmha_args.seqlen_k = S_k; + fmha_args.seqlen_q = S_q; + fmha_args.max_seqlen_q = S_q; + + // These are used for group mode, and we are in batch right now. + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + + // Batch does not support padding, and we aren't using kvcache yet. + fmha_args.seqlen_k_ptr = nullptr; + + fmha_args.scale_s = scale_s; + fmha_args.scale_p = scale_p; + fmha_args.scale_o = scale_o; + fmha_args.stride_bias = S_stride_bias; + fmha_args.stride_o = S_stride_v; + fmha_args.nhead_stride_bias = 0; fmha_args.batch_stride_bias = 0; + fmha_args.nhead_stride_lse = S_q; + fmha_args.nhead_stride_o = S_q * D_v; fmha_args.batch_stride_o = H_q * S_q * D_v; + fmha_args.window_size_left = 0; + fmha_args.window_size_right = 0; + fmha_args.mask_type = static_cast(fmha_traits.mask_type); + + fmha_args.s_randval = false; + // Since we aren't storing the random values these will be unused for now.
+ fmha_args.stride_randval = S_q; + fmha_args.nhead_stride_randval = S_q * S_k; + fmha_args.batch_stride_randval = H_q * S_q * S_k; + + result.invoker_factory = [=](const std::vector&) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) params = raw_params.CastTo(); + const auto& dataFwd = params.GetDataForward(); + + fmha_fwd_traits fmha_runtime_traits = fmha_traits; + fmha_fwd_args fmha_runtime_args = fmha_args; + + fmha_runtime_args.q_ptr = dataFwd.qData; + fmha_runtime_args.k_ptr = dataFwd.kData; + fmha_runtime_args.v_ptr = dataFwd.vData; + fmha_runtime_args.rand_val_ptr = nullptr; + fmha_runtime_args.o_ptr = dataFwd.oData; + + fmha_runtime_traits.bias_type = + dataFwd.biasData != nullptr ? bias_enum::elementwise_bias : bias_enum::no_bias; + fmha_runtime_args.bias_ptr = dataFwd.biasData; + + // Top-left causal mask + if(dataFwd.mask == miopenMhaMask_t::miopenMhaMaskCausal) + { + fmha_runtime_traits.mask_type = mask_enum::mask_top_left; + fmha_runtime_args.mask_type = + static_cast(mask_enum::mask_top_left); + fmha_runtime_args.window_size_left = -1; + fmha_runtime_args.window_size_right = 0; + } + + fmha_runtime_traits.has_dropout = dataFwd.dropoutProbabilityData != nullptr; + float probability = 0; + uint64_t seed = 0; + uint64_t offset = 0; + if(fmha_runtime_traits.has_dropout) + { + hipMemcpy(&probability, + dataFwd.dropoutProbabilityData, + sizeof(float), + hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&seed, + dataFwd.dropoutSeedData, + sizeof(uint64_t), + hipMemcpyKind::hipMemcpyDeviceToHost); + hipMemcpy(&offset, + dataFwd.dropoutOffsetData, + sizeof(uint64_t), + hipMemcpyKind::hipMemcpyDeviceToHost); + } + fmha_runtime_args.p_drop = probability; + fmha_runtime_args.drop_seed_offset = {seed, offset}; + + // Create stream_config, and set it to not time kernel. 
+ ck_tile::stream_config stream_config; + stream_config.stream_id_ = handle_.GetStream(); + + { + HipEventProfiler profiler(handle_); + fmha_fwd(fmha_runtime_traits, fmha_runtime_args, stream_config); + } + }; + }; + + return result; +#else + return ConvSolution{miopenStatusNotImplemented}; +#endif +} + +bool MhaCKFlashAttentionV2Forward::MayNeedWorkspace() const { return false; } + +} // namespace mha + +} // namespace solver + +} // namespace miopen