Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed all perf-db creations to be lazy #3212

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,12 +781,13 @@ solver::ConvSolution MakeFusedSolution(const FusionContext& ctx,
const FusionDescription& problem,
const AnyInvokeParams& invoke_params)
{
decltype(auto) db = GetDb(ctx);
auto db_getter = MakeConvDbGetter(ctx);

solver::ConvSolution solution{miopenStatusInternalError};

GetAllFusionSolvers().FindById(id, [&](auto solver) {
solution = miopen::solver::FindSolution(
solver, ctx, problem, db, invoke_params, perf_cfg_override.value_or(""));
solver, ctx, problem, db_getter, invoke_params, perf_cfg_override.value_or(""));
});

return solution;
Expand Down Expand Up @@ -825,7 +826,7 @@ class FusionSolverFinder : public SolversFinderMixin<FusionDescription, FusionFi
const auto fusion_ctx = FusionContext(ctx);
return solvers.SearchForAllSolutions(fusion_ctx,
problem,
miopen::GetDb(ctx),
MakeConvDbGetter(ctx),
invoke_ctx,
std::numeric_limits<std::size_t>::max(),
options);
Expand Down Expand Up @@ -1034,9 +1035,9 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)

GetAllFusionSolvers().FindById(id, [&](auto solver) {
const auto ctx = FusionContext{handle};
auto db = GetDb(ctx);
auto db_getter = MakeConvDbGetter(ctx);
const auto solution = solver::FindSolution(
solver, ctx, fusion_problem, db, {}); // auto tune is not expected here
solver, ctx, fusion_problem, db_getter, {}); // auto tune is not expected here
auto invoker =
handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params);
// We register the invoker below
Expand Down
31 changes: 29 additions & 2 deletions src/include/miopen/any_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ struct AnySolver
return ptr_value->Type();
};
bool IsEmpty() const { return ptr_value == nullptr; };

ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
PerformanceDb& db,
Expand All @@ -94,7 +95,18 @@ struct AnySolver
{
assert(ptr_value != nullptr);
return ptr_value->FindSolution(ctx, problem, db, invoke_ctx, perf_cfg);
};
}

ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
std::function<PerformanceDb&()> db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg = "") const
{
assert(ptr_value != nullptr);
return ptr_value->FindSolution(ctx, problem, db_getter, invoke_ctx, perf_cfg);
}

InvokerFactory GetInvokeFactory(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const std::string& perf_cfg = "") const
Expand Down Expand Up @@ -153,6 +165,11 @@ struct AnySolver
PerformanceDb& db,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const = 0;
virtual ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
std::function<PerformanceDb&()> db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const = 0;
virtual InvokerFactory GetInvokeFactory(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const std::string& perf_cfg) const = 0;
Expand Down Expand Up @@ -307,7 +324,17 @@ struct AnySolver
const std::string& perf_cfg) const override
{
return miopen::solver::FindSolution(value, ctx, problem, db, invoke_ctx, perf_cfg);
};
}

ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
std::function<PerformanceDb&()> db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const override
{
return miopen::solver::FindSolution(
value, ctx, problem, db_getter, invoke_ctx, perf_cfg);
}

InvokerFactory GetInvokeFactory(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
Expand Down
3 changes: 3 additions & 0 deletions src/include/miopen/mlo_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ size_t setBotDfDescFromMLDesc(int spatial_dims, TTo& to, const TensorDescriptor&
return SetDescFromMLDesc(spatial_dims, to, tensor, &TTo::setBotDfDescr);
}

MIOPEN_INTERNALS_EXPORT auto MakeConvDbGetter(const ExecutionContext& ctx)
-> std::function<PerformanceDb&()>;

namespace solver {
struct ConvSolution;

Expand Down
32 changes: 24 additions & 8 deletions src/mlo_dir_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,22 @@ static auto GetBwdWrW2DSolvers()

static auto GetFFTSolvers() { return miopen::solver::SolverContainer<miopen::solver::conv::fft>{}; }

auto miopen::MakeConvDbGetter(const ExecutionContext& ctx) -> std::function<PerformanceDb&()>
{
return [&, db_container = std::optional<PerformanceDb>()]() mutable -> PerformanceDb& {
if(!db_container)
db_container.emplace(GetDb(ctx));

return *db_container;
};
}

std::vector<miopen::solver::ConvSolution>
FindAllGemmSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetGemmSolvers().SearchForAllSolutions(ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand All @@ -224,7 +234,8 @@ FindAllDirectSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetDirectSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetDirectSolvers().SearchForAllSolutions(
ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand Down Expand Up @@ -260,23 +271,26 @@ FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetImplicitGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetImplicitGemmSolvers().SearchForAllSolutions(
ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindAllWinogradSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetWindogradSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetWindogradSolvers().SearchForAllSolutions(
ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetWindogradWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetWindogradWrWSolvers().SearchForAllSolutions(
ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand All @@ -298,23 +312,25 @@ FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetImplicitGemmWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetImplicitGemmWrWSolvers().SearchForAllSolutions(
ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetBwdWrW2DSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetBwdWrW2DSolvers().SearchForAllSolutions(
ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindAllFFTSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetFFTSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetFFTSolvers().SearchForAllSolutions(ctx, problem, MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand Down
2 changes: 1 addition & 1 deletion src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static Invoker PrepareInvoker(ExecutionContext ctx,
ctx.disable_search_enforce = true;

const auto solver = solver_id.GetSolver();
auto db = GetDb(ctx);
auto db = MakeConvDbGetter(ctx);
auto solution = solver.FindSolution(ctx, problem, db, {}); // auto tune is not expected here
auto& handle = ctx.GetStream();
auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params);
Expand Down
2 changes: 1 addition & 1 deletion src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,

auto results =
FindConvolution(ctx, conv_problem, invoke_ctx, max_solutions, options.attach_binaries);
auto db = MakeConvDbGetter(ctx);

for(auto& result : results)
{
Expand All @@ -526,7 +527,6 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
// This would make binaries not serialized and invoker not cached.
// So we prepare them here.

auto db = GetDb(ctx);
const auto conv_solution =
result.GetSolver().GetSolver().FindSolution(ctx, conv_problem, db, invoke_ctx);

Expand Down
2 changes: 1 addition & 1 deletion src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ void Solution::RunImpl(Handle& handle,
auto conv_ctx = ExecutionContext{&handle};
conv_problem.SetupFloats(conv_ctx);

decltype(auto) db = GetDb(conv_ctx);
decltype(auto) db = MakeConvDbGetter(conv_ctx);
const auto conv_solution = GetSolver().GetSolver().FindSolution(
conv_ctx, conv_problem, db, invoke_ctx, perf_cfg.value_or(""));

Expand Down