Skip to content

Commit

Permalink
Add ck mha fp16 solver (#3277)
Browse files Browse the repository at this point in the history
* Add CK FA v2 solver

* Apply formatting

* Add define guard around CK implementation

* Add support for masking option

* Fix hip-tidy issues

* Add registering new CK solver

* Apply formatting

* Code review comments
- Update requirements to CK that generates MHA for gfx90a
- Fix compilation issue for No CK
- Use provided tensor strides instead of calculating

* Fix incorrect check for CK define

* Add additional details to the register comment to prevent missed conflicts in the future

* Update RunImpl to select proper solver

* Apply formatting

* Add checking stride conditions since D must be contigious

* Fix IsApplicable check to properly check for device support.

* Apply formatting

* Add bias back after accidentally removing for testing
  • Loading branch information
BrianHarrisonAMD authored Oct 4, 2024
1 parent fc96df7 commit 95186b8
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:HIP_COMPILER_FLAGS=${HIP_COMPI
# HIP
if( MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
if(MIOPEN_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations)
find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations device_mha_operations)
endif()
if( MIOPEN_BACKEND STREQUAL "HIPNOGPU")
set(MIOPEN_MODE_NOGPU 1)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ nlohmann/[email protected] -DJSON_MultipleHeaders=ON -DJSON_BuildTests=Off
ROCm/[email protected]
ROCm/[email protected]
ROCm/frugally-deep@9683d557eb672ee2304f80f6682c51242d748a50
ROCm/composable_kernel@9c0811f39a2262dbe4d71b81898187951c1e11ba -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
ROCm/composable_kernel@294cb823142a815170cf1faa63e01a431a557a04 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
google/[email protected]
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ set( MIOpen_Source
solver/layernorm/forward_layernorm2d_ck.cpp
solver/layernorm/forward_layernorm4d_ck.cpp
solver/layernorm/forward_t5layernorm.cpp
solver/mha/mha_ck_fa_v2_solver_forward.cpp
solver/mha/mha_solver_backward.cpp
solver/mha/mha_solver_forward.cpp
solver/pooling/forward2d.cpp
Expand Down Expand Up @@ -834,7 +835,7 @@ target_include_directories(MIOpen PUBLIC
)

if(MIOPEN_USE_COMPOSABLEKERNEL)
set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations hip::host)
set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations composable_kernel::device_mha_operations hip::host)
endif()

if(WIN32)
Expand Down
22 changes: 22 additions & 0 deletions src/include/miopen/mha/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,28 @@ struct MhaBackward final : MhaSolver
MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override;
};

struct MhaCKFlashAttentionV2Forward final : MhaSolver
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<MhaCKFlashAttentionV2Forward>();
}

MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

MIOPEN_INTERNALS_EXPORT std::size_t
GetWorkspaceSize(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override;
};

} // namespace mha

} // namespace solver
Expand Down
4 changes: 3 additions & 1 deletion src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,12 @@ Problem::FindSolutionsImpl(Handle& handle,

const auto algo = AlgorithmName{"Mha"};

static solver::mha::MhaCKFlashAttentionV2Forward mhaCKFAForwardSolver;
static solver::mha::MhaForward mhaForwardSolver;
static solver::mha::MhaBackward mhaBackwardSolver;

std::vector<solver::mha::MhaSolver*> solvers = {&mhaForwardSolver, &mhaBackwardSolver};
std::vector<solver::mha::MhaSolver*> solvers = {
&mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver};

for(auto solver : solvers)
{
Expand Down
34 changes: 24 additions & 10 deletions src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,32 @@ void Solution::RunImpl(Handle& handle,
return;
}

solver::mha::MhaForward mhaForward;
solver::mha::MhaBackward mhaBackward;
auto getSolution = [&](const ExecutionContext& ctx) {
auto solverId = GetSolver();
solver::mha::MhaForward mhaForward;
solver::mha::MhaBackward mhaBackward;
solver::mha::MhaCKFlashAttentionV2Forward ckMhaForward;

if(solverId == ckMhaForward.SolverDbId())
{
return ckMhaForward.GetSolution(ctx, problem_description);
}
else if(solverId == mhaForward.SolverDbId())
{
return mhaForward.GetSolution(ctx, problem_description);
}
else if(solverId == mhaBackward.SolverDbId())
{
return mhaBackward.GetSolution(ctx, problem_description);
}

MIOPEN_THROW("No MHA solver with matching SolverDbId of " + solverId.ToString());
};

if(!kernels.empty())
{
const auto ctx = ExecutionContext{&handle};
const auto mha_solution = GetSolver() == mhaForward.SolverDbId()
? mhaForward.GetSolution(ctx, problem_description)
: mhaBackward.GetSolution(ctx, problem_description);
const auto mha_solution = getSolution(ctx);
auto kernel_handles = std::vector<Kernel>{std::begin(kernels), std::end(kernels)};

invoker = (*mha_solution.invoker_factory)(kernel_handles);
Expand All @@ -425,11 +442,8 @@ void Solution::RunImpl(Handle& handle,
return;
}

auto ctx = ExecutionContext{&handle};

const auto mha_solution = GetSolver() == mhaForward.SolverDbId()
? mhaForward.GetSolution(ctx, problem_description)
: mhaBackward.GetSolution(ctx, problem_description);
auto ctx = ExecutionContext{&handle};
const auto mha_solution = getSolution(ctx);

invoker =
handle.PrepareInvoker(*mha_solution.invoker_factory, mha_solution.construction_params);
Expand Down
4 changes: 3 additions & 1 deletion src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
Register(registry, ++id, Primitive::Activation, glu::GLUForward{}.SolverDbId());
Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId());

// IMPORTANT: New solvers should be added to the end of the function!
Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId());
// IMPORTANT: New solvers should be added to the end of the function, and don't leave a white
// space between this comment and the newly registered solver(s)!
}

bool ThisSolverIsDeprecatedStatic::IsDisabled(const ExecutionContext& ctx)
Expand Down
Loading

0 comments on commit 95186b8

Please sign in to comment.