Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ck mha fp16 solver #3277

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:HIP_COMPILER_FLAGS=${HIP_COMPI
# HIP
if( MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
if(MIOPEN_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations)
find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations device_mha_operations)
endif()
if( MIOPEN_BACKEND STREQUAL "HIPNOGPU")
set(MIOPEN_MODE_NOGPU 1)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ nlohmann/[email protected] -DJSON_MultipleHeaders=ON -DJSON_BuildTests=Off
ROCm/[email protected]
ROCm/[email protected]
ROCm/frugally-deep@9683d557eb672ee2304f80f6682c51242d748a50
ROCm/composable_kernel@9c0811f39a2262dbe4d71b81898187951c1e11ba -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
ROCm/composable_kernel@294cb823142a815170cf1faa63e01a431a557a04 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
Copy link
Collaborator Author

@BrianHarrisonAMD BrianHarrisonAMD Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, this is pointing to a dev commit where I added generating the device_mha_operations lib for gfx90a as well.
This will need to be updated to point to an amd-develop commit instead once it's been promoted.

Since this is living in a integration branch it should be fine for now, but we shouldn't merge the integration into develop until this is resolved.

google/[email protected]
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ set( MIOpen_Source
solver/layernorm/forward_layernorm2d_ck.cpp
solver/layernorm/forward_layernorm4d_ck.cpp
solver/layernorm/forward_t5layernorm.cpp
solver/mha/mha_ck_fa_v2_solver_forward.cpp
solver/mha/mha_solver_backward.cpp
solver/mha/mha_solver_forward.cpp
solver/pooling/forward2d.cpp
Expand Down Expand Up @@ -834,7 +835,7 @@ target_include_directories(MIOpen PUBLIC
)

if(MIOPEN_USE_COMPOSABLEKERNEL)
set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations hip::host)
set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations composable_kernel::device_mha_operations hip::host)
endif()

if(WIN32)
Expand Down
22 changes: 22 additions & 0 deletions src/include/miopen/mha/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,28 @@ struct MhaBackward final : MhaSolver
MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override;
};

struct MhaCKFlashAttentionV2Forward final : MhaSolver
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<MhaCKFlashAttentionV2Forward>();
}

MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

MIOPEN_INTERNALS_EXPORT std::size_t
GetWorkspaceSize(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override;
};

} // namespace mha

} // namespace solver
Expand Down
4 changes: 3 additions & 1 deletion src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,12 @@ Problem::FindSolutionsImpl(Handle& handle,

const auto algo = AlgorithmName{"Mha"};

static solver::mha::MhaCKFlashAttentionV2Forward mhaCKFAForwardSolver;
static solver::mha::MhaForward mhaForwardSolver;
static solver::mha::MhaBackward mhaBackwardSolver;

std::vector<solver::mha::MhaSolver*> solvers = {&mhaForwardSolver, &mhaBackwardSolver};
std::vector<solver::mha::MhaSolver*> solvers = {
&mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver};

for(auto solver : solvers)
{
Expand Down
34 changes: 24 additions & 10 deletions src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,32 @@ void Solution::RunImpl(Handle& handle,
return;
}

solver::mha::MhaForward mhaForward;
solver::mha::MhaBackward mhaBackward;
auto getSolution = [&](const ExecutionContext& ctx) {
auto solverId = GetSolver();
solver::mha::MhaForward mhaForward;
solver::mha::MhaBackward mhaBackward;
solver::mha::MhaCKFlashAttentionV2Forward ckMhaForward;

if(solverId == ckMhaForward.SolverDbId())
{
return ckMhaForward.GetSolution(ctx, problem_description);
}
else if(solverId == mhaForward.SolverDbId())
{
return mhaForward.GetSolution(ctx, problem_description);
}
else if(solverId == mhaBackward.SolverDbId())
{
return mhaBackward.GetSolution(ctx, problem_description);
}

MIOPEN_THROW("No MHA solver with matching SolverDbId of " + solverId.ToString());
};

if(!kernels.empty())
{
const auto ctx = ExecutionContext{&handle};
const auto mha_solution = GetSolver() == mhaForward.SolverDbId()
? mhaForward.GetSolution(ctx, problem_description)
: mhaBackward.GetSolution(ctx, problem_description);
const auto mha_solution = getSolution(ctx);
auto kernel_handles = std::vector<Kernel>{std::begin(kernels), std::end(kernels)};

invoker = (*mha_solution.invoker_factory)(kernel_handles);
Expand All @@ -425,11 +442,8 @@ void Solution::RunImpl(Handle& handle,
return;
}

auto ctx = ExecutionContext{&handle};

const auto mha_solution = GetSolver() == mhaForward.SolverDbId()
? mhaForward.GetSolution(ctx, problem_description)
: mhaBackward.GetSolution(ctx, problem_description);
auto ctx = ExecutionContext{&handle};
const auto mha_solution = getSolution(ctx);

invoker =
handle.PrepareInvoker(*mha_solution.invoker_factory, mha_solution.construction_params);
Expand Down
4 changes: 3 additions & 1 deletion src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
Register(registry, ++id, Primitive::Activation, glu::GLUForward{}.SolverDbId());
Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId());

// IMPORTANT: New solvers should be added to the end of the function!
Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId());
// IMPORTANT: New solvers should be added to the end of the function, and don't leave a white
// space between this comment and the newly registered solver(s)!
}

bool ThisSolverIsDeprecatedStatic::IsDisabled(const ExecutionContext& ctx)
Expand Down
Loading