From 2fdf63d6416743861e03911f7990047388988c59 Mon Sep 17 00:00:00 2001 From: Vasilii Filippov Date: Wed, 7 Aug 2024 12:50:58 +0200 Subject: [PATCH 1/3] Made fusion perfdb creation lazy --- src/fusion.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/fusion.cpp b/src/fusion.cpp index d60f025c77..8fc234e246 100644 --- a/src/fusion.cpp +++ b/src/fusion.cpp @@ -775,18 +775,31 @@ static auto GetAllFusionSolvers() GetFusedWinogradSolvers(); } +namespace { +auto MakeDbGetter(const FusionContext& ctx) +{ + return [&, db_container = std::optional()]() mutable -> PerformanceDb& { + if(!db_container) + db_container.emplace(GetDb(ctx)); + + return *db_container; + }; +} +} // namespace + solver::ConvSolution MakeFusedSolution(const FusionContext& ctx, solver::Id id, const std::optional& perf_cfg_override, const FusionDescription& problem, const AnyInvokeParams& invoke_params) { - decltype(auto) db = GetDb(ctx); + auto db_getter = MakeDbGetter(ctx); + solver::ConvSolution solution{miopenStatusInternalError}; GetAllFusionSolvers().FindById(id, [&](auto solver) { solution = miopen::solver::FindSolution( - solver, ctx, problem, db, invoke_params, perf_cfg_override.value_or("")); + solver, ctx, problem, db_getter, invoke_params, perf_cfg_override.value_or("")); }); return solution; @@ -1034,9 +1047,9 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle) GetAllFusionSolvers().FindById(id, [&](auto solver) { const auto ctx = FusionContext{handle}; - auto db = GetDb(ctx); + auto db_getter = MakeDbGetter(ctx); const auto solution = solver::FindSolution( - solver, ctx, fusion_problem, db, {}); // auto tune is not expected here + solver, ctx, fusion_problem, db_getter, {}); // auto tune is not expected here auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params); // We register the invoker below From ecfed448a330c300c6334de05c3a34523a0bf3cb Mon Sep 17 00:00:00 2001 From: Vasilii Filippov Date: Wed, 7 Aug 2024 13:45:11 +0200 Subject: [PATCH 2/3] Changed conv perf-dbs to be loaded lazily --- src/fusion.cpp | 18 +++------------- src/include/miopen/any_solver.hpp | 31 ++++++++++++++++++++++++++-- src/include/miopen/mlo_internal.hpp | 3 +++ src/mlo_dir_conv.cpp | 32 +++++++++++++++++++++-------- src/ocl/convolutionocl.cpp | 2 +- src/problem.cpp | 2 +- src/solution.cpp | 2 +- 7 files changed, 62 insertions(+), 28 deletions(-) diff --git a/src/fusion.cpp b/src/fusion.cpp index 8fc234e246..4f14286fcf 100644 --- a/src/fusion.cpp +++ b/src/fusion.cpp @@ -775,25 +775,13 @@ static auto GetAllFusionSolvers() GetFusedWinogradSolvers(); } -namespace { -auto MakeDbGetter(const FusionContext& ctx) -{ - return [&, db_container = std::optional()]() mutable -> PerformanceDb& { - if(!db_container) - db_container.emplace(GetDb(ctx)); - - return *db_container; - }; -} -} // namespace - solver::ConvSolution MakeFusedSolution(const FusionContext& ctx, solver::Id id, const std::optional& perf_cfg_override, const FusionDescription& problem, const AnyInvokeParams& invoke_params) { - auto db_getter = MakeDbGetter(ctx); + auto db_getter = MakeConvDbGetter(ctx); solver::ConvSolution solution{miopenStatusInternalError}; @@ -837,7 +825,7 @@ class FusionSolverFinder : public SolversFinderMixin(ctx), problem, - miopen::GetDb(ctx), + MakeConvDbGetter(ctx), invoke_ctx, std::numeric_limits::max(), options); @@ -1047,7 +1035,7 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle) GetAllFusionSolvers().FindById(id, [&](auto solver) { const auto ctx = FusionContext{handle}; - auto db_getter = MakeDbGetter(ctx); + auto db_getter = MakeConvDbGetter(ctx); const auto solution = solver::FindSolution( solver, ctx, fusion_problem, db_getter, {}); // auto tune is not expected here auto invoker = diff --git a/src/include/miopen/any_solver.hpp b/src/include/miopen/any_solver.hpp index 0908c2823f..ee5f1dd148 100644 --- a/src/include/miopen/any_solver.hpp +++ b/src/include/miopen/any_solver.hpp @@ -86,6 +86,7 @@ struct AnySolver return ptr_value->Type(); }; bool IsEmpty() const { return ptr_value == nullptr; }; + ConvSolution FindSolution(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, PerformanceDb& db, @@ -94,7 +95,18 @@ struct AnySolver { assert(ptr_value != nullptr); return ptr_value->FindSolution(ctx, problem, db, invoke_ctx, perf_cfg); - }; + } + + ConvSolution FindSolution(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + std::function db_getter, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg = "") const + { + assert(ptr_value != nullptr); + return ptr_value->FindSolution(ctx, problem, db_getter, invoke_ctx, perf_cfg); + } + InvokerFactory GetInvokeFactory(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const std::string& perf_cfg = "") const @@ -153,6 +165,11 @@ struct AnySolver PerformanceDb& db, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg) const = 0; + virtual ConvSolution FindSolution(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + std::function db_getter, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const = 0; virtual InvokerFactory GetInvokeFactory(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const std::string& perf_cfg) const = 0; @@ -307,7 +324,17 @@ struct AnySolver const std::string& perf_cfg) const override { return miopen::solver::FindSolution(value, ctx, problem, db, invoke_ctx, perf_cfg); - }; + } + + ConvSolution FindSolution(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + std::function db_getter, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const override + { + return miopen::solver::FindSolution( + value, ctx, problem, db_getter, invoke_ctx, perf_cfg); + } InvokerFactory GetInvokeFactory(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, diff --git a/src/include/miopen/mlo_internal.hpp b/src/include/miopen/mlo_internal.hpp index 6252df3087..5c2073a408 100644 --- a/src/include/miopen/mlo_internal.hpp +++ b/src/include/miopen/mlo_internal.hpp @@ -167,6 +167,9 @@ size_t setBotDfDescFromMLDesc(int spatial_dims, TTo& to, const TensorDescriptor& return SetDescFromMLDesc(spatial_dims, to, tensor, &TTo::setBotDfDescr); } +MIOPEN_INTERNALS_EXPORT auto MakeConvDbGetter(const ExecutionContext& ctx) + -> std::function; + namespace solver { struct ConvSolution; diff --git a/src/mlo_dir_conv.cpp b/src/mlo_dir_conv.cpp index 1b80d0d729..ba1df1afa9 100644 --- a/src/mlo_dir_conv.cpp +++ b/src/mlo_dir_conv.cpp @@ -204,12 +204,22 @@ static auto GetBwdWrW2DSolvers() static auto GetFFTSolvers() { return miopen::solver::SolverContainer{}; } +auto miopen::MakeConvDbGetter(const ExecutionContext& ctx) -> std::function +{ + return [&, db_container = std::optional()]() mutable -> PerformanceDb& { + if(!db_container) + db_container.emplace(GetDb(ctx)); + + return *db_container; + }; +} + std::vector FindAllGemmSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetGemmSolvers().SearchForAllSolutions(ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> @@ -224,7 +234,8 @@ FindAllDirectSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetDirectSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetDirectSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> @@ -260,7 +271,8 @@ FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetImplicitGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetImplicitGemmSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -268,7 +280,8 @@ FindAllWinogradSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetWindogradSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetWindogradSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -276,7 +289,8 @@ FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetWindogradWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetWindogradWrWSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> @@ -298,7 +312,8 @@ FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetImplicitGemmWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetImplicitGemmWrWSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -306,7 +321,8 @@ FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetBwdWrW2DSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetBwdWrW2DSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -314,7 +330,7 @@ FindAllFFTSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetFFTSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetFFTSolvers().SearchForAllSolutions(ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp index aeb38a9575..13abdc246b 100644 --- a/src/ocl/convolutionocl.cpp +++ b/src/ocl/convolutionocl.cpp @@ -90,7 +90,7 @@ static Invoker PrepareInvoker(ExecutionContext ctx, ctx.disable_search_enforce = true; const auto solver = solver_id.GetSolver(); - auto db = GetDb(ctx); + auto db = MakeConvDbGetter(ctx); auto solution = solver.FindSolution(ctx, problem, db, {}); // auto tune is not expected here auto& handle = ctx.GetStream(); auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params); diff --git a/src/problem.cpp b/src/problem.cpp index 75ae8d1790..26a8e7a124 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -514,6 +514,7 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, auto results = FindConvolution(ctx, conv_problem, invoke_ctx, max_solutions, options.attach_binaries); + auto db = MakeConvDbGetter(ctx); for(auto& result : results) { @@ -525,7 +526,6 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, // This would make binaries not serialized and invoker not cached. // So we prepare them here. - auto db = GetDb(ctx); const auto conv_solution = result.GetSolver().GetSolver().FindSolution(ctx, conv_problem, db, invoke_ctx); diff --git a/src/solution.cpp b/src/solution.cpp index 438dae7685..f599365fe9 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -246,7 +246,7 @@ void Solution::RunImpl(Handle& handle, auto conv_ctx = ExecutionContext{&handle}; conv_problem.SetupFloats(conv_ctx); - decltype(auto) db = GetDb(conv_ctx); + decltype(auto) db = MakeConvDbGetter(conv_ctx); const auto conv_solution = GetSolver().GetSolver().FindSolution( conv_ctx, conv_problem, db, invoke_ctx, perf_cfg.value_or("")); From 4bbd6ea9296528476a6e770d1aa877ce9fedebe1 Mon Sep 17 00:00:00 2001 From: Vasilii Filippov Date: Wed, 7 Aug 2024 13:45:42 +0200 Subject: [PATCH 3/3] format --- src/include/miopen/any_solver.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/miopen/any_solver.hpp b/src/include/miopen/any_solver.hpp index ee5f1dd148..cd5cdd528f 100644 --- a/src/include/miopen/any_solver.hpp +++ b/src/include/miopen/any_solver.hpp @@ -99,7 +99,7 @@ struct AnySolver ConvSolution FindSolution(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, - std::function db_getter, + std::function db_getter, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg = "") const { @@ -328,7 +328,7 @@ struct AnySolver ConvSolution FindSolution(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, - std::function db_getter, + std::function db_getter, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg) const override {