diff --git a/src/fusion.cpp b/src/fusion.cpp index a9ef5e27a..ebc563d73 100644 --- a/src/fusion.cpp +++ b/src/fusion.cpp @@ -781,12 +781,13 @@ solver::ConvSolution MakeFusedSolution(const FusionContext& ctx, const FusionDescription& problem, const AnyInvokeParams& invoke_params) { - decltype(auto) db = GetDb(ctx); + auto db_getter = MakeConvDbGetter(ctx); + solver::ConvSolution solution{miopenStatusInternalError}; GetAllFusionSolvers().FindById(id, [&](auto solver) { solution = miopen::solver::FindSolution( - solver, ctx, problem, db, invoke_params, perf_cfg_override.value_or("")); + solver, ctx, problem, db_getter, invoke_params, perf_cfg_override.value_or("")); }); return solution; @@ -825,7 +826,7 @@ class FusionSolverFinder : public SolversFinderMixin::max(), options); @@ -1034,9 +1035,9 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle) GetAllFusionSolvers().FindById(id, [&](auto solver) { const auto ctx = FusionContext{handle}; - auto db = GetDb(ctx); + auto db_getter = MakeConvDbGetter(ctx); const auto solution = solver::FindSolution( - solver, ctx, fusion_problem, db, {}); // auto tune is not expected here + solver, ctx, fusion_problem, db_getter, {}); // auto tune is not expected here auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params); // We register the invoker below diff --git a/src/include/miopen/any_solver.hpp b/src/include/miopen/any_solver.hpp index 0908c2823..cd5cdd528 100644 --- a/src/include/miopen/any_solver.hpp +++ b/src/include/miopen/any_solver.hpp @@ -86,6 +86,7 @@ struct AnySolver return ptr_value->Type(); }; bool IsEmpty() const { return ptr_value == nullptr; }; + ConvSolution FindSolution(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, PerformanceDb& db, @@ -94,7 +95,18 @@ struct AnySolver { assert(ptr_value != nullptr); return ptr_value->FindSolution(ctx, problem, db, invoke_ctx, perf_cfg); - }; + } + + ConvSolution FindSolution(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + std::function db_getter, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg = "") const + { + assert(ptr_value != nullptr); + return ptr_value->FindSolution(ctx, problem, db_getter, invoke_ctx, perf_cfg); + } + InvokerFactory GetInvokeFactory(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const std::string& perf_cfg = "") const @@ -153,6 +165,11 @@ struct AnySolver PerformanceDb& db, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg) const = 0; + virtual ConvSolution FindSolution(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + std::function db_getter, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const = 0; virtual InvokerFactory GetInvokeFactory(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const std::string& perf_cfg) const = 0; @@ -307,7 +324,17 @@ struct AnySolver const std::string& perf_cfg) const override { return miopen::solver::FindSolution(value, ctx, problem, db, invoke_ctx, perf_cfg); - }; + } + + ConvSolution FindSolution(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + std::function db_getter, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const override + { + return miopen::solver::FindSolution( + value, ctx, problem, db_getter, invoke_ctx, perf_cfg); + } InvokerFactory GetInvokeFactory(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, diff --git a/src/include/miopen/mlo_internal.hpp b/src/include/miopen/mlo_internal.hpp index 6252df308..5c2073a40 100644 --- a/src/include/miopen/mlo_internal.hpp +++ b/src/include/miopen/mlo_internal.hpp @@ -167,6 +167,9 @@ size_t setBotDfDescFromMLDesc(int spatial_dims, TTo& to, const TensorDescriptor& return SetDescFromMLDesc(spatial_dims, to, tensor, &TTo::setBotDfDescr); } +MIOPEN_INTERNALS_EXPORT auto MakeConvDbGetter(const ExecutionContext& ctx) + -> std::function; + namespace solver { struct ConvSolution; diff --git a/src/mlo_dir_conv.cpp b/src/mlo_dir_conv.cpp index 1b80d0d72..ba1df1afa 100644 --- a/src/mlo_dir_conv.cpp +++ b/src/mlo_dir_conv.cpp @@ -204,12 +204,22 @@ static auto GetBwdWrW2DSolvers() static auto GetFFTSolvers() { return miopen::solver::SolverContainer{}; } +auto miopen::MakeConvDbGetter(const ExecutionContext& ctx) -> std::function +{ + return [&, db_container = std::optional()]() mutable -> PerformanceDb& { + if(!db_container) + db_container.emplace(GetDb(ctx)); + + return *db_container; + }; +} + std::vector FindAllGemmSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetGemmSolvers().SearchForAllSolutions(ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> @@ -224,7 +234,8 @@ FindAllDirectSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetDirectSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetDirectSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> @@ -260,7 +271,8 @@ FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetImplicitGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetImplicitGemmSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -268,7 +280,8 @@ FindAllWinogradSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetWindogradSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetWindogradSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -276,7 +289,8 @@ FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetWindogradWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetWindogradWrWSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> @@ -298,7 +312,8 @@ FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetImplicitGemmWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetImplicitGemmWrWSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -306,7 +321,8 @@ FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetBwdWrW2DSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetBwdWrW2DSolvers().SearchForAllSolutions( + ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector @@ -314,7 +330,7 @@ FindAllFFTSolutions(const miopen::ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { - return GetFFTSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); + return GetFFTSolvers().SearchForAllSolutions(ctx, problem, MakeConvDbGetter(ctx), invoke_ctx); } std::vector> diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp index 379c1029c..8025a5d32 100644 --- a/src/ocl/convolutionocl.cpp +++ b/src/ocl/convolutionocl.cpp @@ -90,7 +90,7 @@ static Invoker PrepareInvoker(ExecutionContext ctx, ctx.disable_search_enforce = true; const auto solver = solver_id.GetSolver(); - auto db = GetDb(ctx); + auto db = MakeConvDbGetter(ctx); auto solution = solver.FindSolution(ctx, problem, db, {}); // auto tune is not expected here auto& handle = ctx.GetStream(); auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params); diff --git a/src/problem.cpp b/src/problem.cpp index 16e199c00..69c853845 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -515,6 +515,7 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, auto results = FindConvolution(ctx, conv_problem, invoke_ctx, max_solutions, options.attach_binaries); + auto db = MakeConvDbGetter(ctx); for(auto& result : results) { @@ -526,7 +527,6 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, // This would make binaries not serialized and invoker not cached. // So we prepare them here. - auto db = GetDb(ctx); const auto conv_solution = result.GetSolver().GetSolver().FindSolution(ctx, conv_problem, db, invoke_ctx); diff --git a/src/solution.cpp b/src/solution.cpp index aaf2ffde2..c588a1d36 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -246,7 +246,7 @@ void Solution::RunImpl(Handle& handle, auto conv_ctx = ExecutionContext{&handle}; conv_problem.SetupFloats(conv_ctx); - decltype(auto) db = GetDb(conv_ctx); + decltype(auto) db = MakeConvDbGetter(conv_ctx); const auto conv_solution = GetSolver().GetSolver().FindSolution( conv_ctx, conv_problem, db, invoke_ctx, perf_cfg.value_or(""));