From bf1cc6d3ae7a4dc5c2fc9f11ba57d02588d0de1b Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:43:52 +0200 Subject: [PATCH 01/26] Add interface for Fin --- src/CMakeLists.txt | 1 + src/fin/fin_interface.cpp | 224 +++++++++++ src/include/miopen/batchnorm/solvers.hpp | 4 + src/include/miopen/conv/solvers.hpp | 1 - src/include/miopen/fin/fin_interface.hpp | 241 ++++++++++++ src/include/miopen/find_solution.hpp | 35 -- src/include/miopen/solver.hpp | 7 + src/include/miopen/solver_id.hpp | 2 + src/solver.cpp | 57 +-- src/solver/conv/conv_ocl_dir2D_bwdWrW_2.cpp | 6 +- test/gtest/unit_FinInterface.cpp | 396 ++++++++++++++++++++ 11 files changed, 913 insertions(+), 61 deletions(-) create mode 100644 src/fin/fin_interface.cpp create mode 100644 src/include/miopen/fin/fin_interface.hpp create mode 100644 test/gtest/unit_FinInterface.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 92e4f4264a..74868858c7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -120,6 +120,7 @@ set( MIOpen_Source env.cpp execution_context.cpp expanduser.cpp + fin/fin_interface.cpp find_controls.cpp find_db.cpp fused_api.cpp diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp new file mode 100644 index 0000000000..1ff6081374 --- /dev/null +++ b/src/fin/fin_interface.cpp @@ -0,0 +1,224 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include +#include + +namespace miopen { +namespace fin { + +// ================== Solver ================== +Solver::Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) : sbase(solver_base), id(solver_id) +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusInternalError); +} + +Solver::Solver(const std::string& requested_name) : rname(requested_name) +{ +} + +bool Solver::IsValid() const +{ + return sbase != nullptr; +} + +uint64_t Solver::GetId() const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + return id; +} + +const std::string& Solver::GetName() const +{ + if(sbase != nullptr) + return sbase->SolverDbId(); + else + return rname; +} + +bool Solver::IsTunable() const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + return sbase->IsTunable(); +} + +bool Solver::IsDynamic() const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + return sbase->IsDynamic(); +} + +// ================== SolverMixin ================== +template +bool SolverMixin::IsApplicable(const Context& ctx, const Problem& problem) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + return static_cast*>(sbase)->IsApplicable(ctx, problem); +} + +template +size_t SolverMixin::GetWorkspaceSize(const Context& ctx, const Problem& problem) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + return static_cast*>(sbase)->GetWorkspaceSize(ctx, problem); +} + +template +miopen::solver::ConvSolution SolverMixin::FindSolution(const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + /// \todo + MIOPEN_THROW(miopenStatusNotImplemented); +} + +template +std::vector SolverMixin::GetAllSolutions(const Context& ctx, const Problem& problem) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + /// \todo + MIOPEN_THROW(miopenStatusNotImplemented); +} + +template +std::string SolverMixin::GetPerfCfgParams(const Context& ctx, + const Problem& problem, + const PerformanceDb& db) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + /// \todo + MIOPEN_THROW(miopenStatusNotImplemented); +} + +template +bool SolverMixin::TestPerfCfgParams(const Context& ctx, + const Problem& problem, + const std::string& params) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + /// \todo + MIOPEN_THROW(miopenStatusNotImplemented); +} + +// ================== ConvSolver ================== +ConvSolver::ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id, miopenConvAlgorithm_t algo_) : SolverMixin(solver_base, solver_id), algo(algo_) {} + +std::string ConvSolver::GetAlgo(miopen::conv::Direction dir) const +{ + if(sbase == nullptr) + MIOPEN_THROW(miopenStatusNotInitialized); + + return ConvolutionAlgoToDirectionalString(algo, dir); +} + +// ================== FinInterface ================== +template +const std::vector& FinInterface::GetAllSolvers(miopen::solver::Primitive primitive) +{ + static const auto solvers = [primitive] { + const auto& ids = GetSolversByPrimitive(primitive); + std::vector solvers; + for(const auto& id : ids) + { + if(!id.IsValid()) + MIOPEN_THROW(miopenStatusInternalError); + solvers.emplace_back(Solver{id.GetSolverBase(), id.Value()}); + } + return solvers; + }(); + return solvers; +} + +template +Solver FinInterface::GetSolver(const std::string& name) +{ + const auto id = miopen::solver::Id{name}; + if(!id.IsValid()) + return {name}; + return {id.GetSolverBase(), id.Value()}; +} + +const std::vector& FinInterface::GetAllConvSolvers() +{ + static const auto solvers = [] { + const auto& ids = GetSolversByPrimitive(miopen::solver::Primitive::Convolution); + std::vector solvers; + for(const auto& id : ids) + { + if(!id.IsValid()) + MIOPEN_THROW(miopenStatusInternalError); + solvers.emplace_back(ConvSolver{id.GetSolverBase(), id.Value(), id.GetAlgo()}); + } + return solvers; + }(); + return solvers; +} + +ConvSolver FinInterface::GetConvSolver(const std::string& name) +{ + const auto id = miopen::solver::Id{name}; + if(!id.IsValid()) + return {name}; + return {id.GetSolverBase(), id.Value(), id.GetAlgo()}; +} + +const std::vector& FinInterface::GetAllBatchNormSolvers() +{ + return GetAllSolvers(miopen::solver::Primitive::Batchnorm); +} + +BatchNormSolver FinInterface::GetBatchNormSolver(const std::string& name) +{ + return GetSolver(name); +} + +} // namespace fin +} // namespace miopen diff --git a/src/include/miopen/batchnorm/solvers.hpp b/src/include/miopen/batchnorm/solvers.hpp index 7edba36b49..92cf11703e 100644 --- a/src/include/miopen/batchnorm/solvers.hpp +++ b/src/include/miopen/batchnorm/solvers.hpp @@ -128,6 +128,7 @@ struct BnFwdInference final : BatchnormSolver bool IsApplicable(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; + bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; }; @@ -138,6 +139,7 @@ struct BnCKFwdInference final : BatchnormSolver bool IsApplicable(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; + bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; }; @@ -148,6 +150,7 @@ struct BnCKBwdBackward final : BatchnormSolver bool IsApplicable(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; + bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; }; @@ -158,6 +161,7 @@ struct BnCKFwdTraining final : BatchnormSolver bool IsApplicable(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; + bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const override; }; diff --git a/src/include/miopen/conv/solvers.hpp b/src/include/miopen/conv/solvers.hpp index 9713d27135..fc5792b71c 100644 --- a/src/include/miopen/conv/solvers.hpp +++ b/src/include/miopen/conv/solvers.hpp @@ -2807,7 +2807,6 @@ struct ConvCkIgemmFwdV6r1DlopsNchw final : ConvTunableSolver +#include +#include + +#include +#include +#include +#include + +namespace miopen { + +struct AnyInvokeParams; +struct ExecutionContext; + +namespace conv { +struct ProblemDescription; +} // namespace conv + +namespace batchnorm { +struct ProblemDescription; +} // namespace conv + +namespace solver { +struct SolverBase; +} // namespace solver + +namespace fin { + +// Base classes for solvers. +class Solver +{ +public: + // GetId(), IsDynamic() and IsTunable() throw miopenStatusNotInitialized if the solver is not valid. + + // Returns false if the solver could not be found by its name. + bool IsValid() const; + + uint64_t GetId() const; + // Returns the name even if the solver is not valid (returns the requested name). + const std::string& GetName() const; + + bool IsTunable() const; + bool IsDynamic() const; + +protected: + Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id); + Solver(const std::string& requested_name); + + const miopen::solver::SolverBase* const sbase = nullptr; + const std::string rname; + uint64_t id; + + friend class FinInterface; +}; + +template +class SolverMixin : public Solver +{ +public: + // All the methods throw miopenStatusNotInitialized if the solver is not valid. + + bool IsApplicable(const Context& ctx, const Problem& problem) const; + + size_t GetWorkspaceSize(const Context& ctx, const Problem& problem) const; + + miopen::solver::ConvSolution FindSolution(const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg = "") const; + + std::vector GetAllSolutions(const Context& ctx, const Problem& problem) const; + + std::string GetPerfCfgParams(const Context& ctx, + const Problem& problem, + const PerformanceDb& db) const; + + bool TestPerfCfgParams(const Context& ctx, + const Problem& problem, + const std::string& params) const; + +protected: + using Solver::Solver; +}; + +// Convolution solver +class ConvSolver: public SolverMixin +{ +public: + std::string GetAlgo(miopen::conv::Direction dir) const; + +protected: + ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id, miopenConvAlgorithm_t algo_); + using SolverMixin::SolverMixin; + + miopenConvAlgorithm_t algo; + + friend class FinInterface; +}; + +// Batch normalization solver +class BatchNormSolver: public SolverMixin +{ +protected: + using SolverMixin::SolverMixin; +}; + +// Interface for Fin +class FinInterface +{ +public: + // GetAll*Solvers() - returns all solvers for a particular primitive. All solvers are always valid. + // + // Get*Solver(name) - returns single solver by its name for a particular primitive. May return a dummy if a solver with specified name does not exist. + + // Convolution + static const std::vector& GetAllConvSolvers(); + static ConvSolver GetConvSolver(const std::string& name); + + // Batch normalization + static const std::vector& GetAllBatchNormSolvers(); + static BatchNormSolver GetBatchNormSolver(const std::string& name); + +private: + template + static const std::vector& GetAllSolvers(miopen::solver::Primitive primitive); + + template + static Solver GetSolver(const std::string& name); +}; + +// Examples: +// +// Convolution solvers: +// +// 1a (Old version): +// +// const auto& solver_id_list = miopen::solver::GetSolversByPrimitive(miopen::solver::Primitive::Convolution); +// for(const auto& id : solver_id_list) +// { +// std::unordered_map solver_info; +// solver_info["name"] = id.ToString(); +// if(!id.IsValid()) +// continue; +// solver_info["id"] = std::to_string(id.Value()); +// solver_info["algo"] = id.GetAlgo(miopen::conv::Direction::Forward); +// const auto solver = id.GetSolver(); +// if(solver.IsEmpty()) +// continue; +// solver_info["tunable"] = solver.IsTunable() ? "1" : "0"; +// solver_info["dynamic"] = solver.IsDynamic() ? "1" : "0"; +// } +// +// 1b (Nev version): +// +// const auto solver_list = miopen::fin::FinInterface::GetAllConvSolvers(); +// for(const auto& solver : solver_list) +// { +// std::unordered_map solver_info; +// solver_info["name"] = solver.GetName(); +// if(!solver.IsValid()) +// continue; +// solver_info["id"] = std::to_string(solver.GetId()); +// solver_info["algo"] = solver.GetAlgo(miopen::conv::Direction::Forward); +// solver_info["tunable"] = solver.IsTunable() ? "1" : "0"; +// solver_info["dynamic"] = solver.IsDynamic() ? "1" : "0"; +// } +// +// 2a (Old version): +// +// std::string solver_name = "ConvBiasActivAsm1x1U"; +// const auto id = miopen::solver::Id{solver_name}; +// std::unordered_map solver_info; +// solver_info["name"] = id.ToString(); +// if(id.IsValid()) +// { +// solver_info["id"] = std::to_string(id.Value()); +// solver_info["algo"] = id.GetAlgo(miopen::conv::Direction::Forward); +// const auto solver = id.GetSolver(); +// if(!solver.IsEmpty()) +// { +// solver_info["tunable"] = solver.IsTunable() ? "1" : "0"; +// solver_info["dynamic"] = solver.IsDynamic() ? "1" : "0"; +// } +// } +// +// 2b (Nev version): +// +// std::string solver_name = "ConvBiasActivAsm1x1U"; +// const auto solver = miopen::fin::FinInterface::GetConvSolver(solver_name); +// std::unordered_map solver_info; +// solver_info["name"] = solver.GetName(); +// if(solver.IsValid()) +// { +// solver_info["id"] = std::to_string(solver.GetId()); +// solver_info["algo"] = solver.GetAlgo(miopen::conv::Direction::Forward); +// solver_info["tunable"] = solver.IsTunable() ? "1" : "0"; +// solver_info["dynamic"] = solver.IsDynamic() ? "1" : "0"; +// } +// +// Batch normalization solvers: +// +// ... +// const auto solver_list = miopen::fin::FinInterface::GetAllBatchNormSolvers(); +// ... +// +// ... +// const auto solver = miopen::fin::FinInterface::GetBatchNormSolver(solver_name); +// ... +// + +} // namespace fin +} // namespace miopen diff --git a/src/include/miopen/find_solution.hpp b/src/include/miopen/find_solution.hpp index 30735cd38a..9527e50aae 100644 --- a/src/include/miopen/find_solution.hpp +++ b/src/include/miopen/find_solution.hpp @@ -438,41 +438,6 @@ struct SolverContainer return res; } - // Search for all applicable solutions among many solvers - template - bool IsAnySolverApplicable(const Context& ctx, const Problem& problem) const - { - const auto find_only = GetEnvFindOnlySolver(); - auto found = false; - - miopen::each_args( - [&](auto solver) { - if(found || (find_only && (std::find(find_only->begin(), - find_only->end(), - Id{solver.SolverDbId()}) == find_only->end()))) - return; - - // For better performance, check IsDynamic() first, because - // it is much faster than IsApplicable(). - if(ctx.use_dynamic_solutions_only && !solver.IsDynamic()) - { - MIOPEN_LOG_I2(solver.SolverDbId() << ": Skipped (non-dynamic)"); - return; - } - - if(solver.IsApplicable(ctx, problem)) - { - found = true; - return; - } - - MIOPEN_LOG_I2(solver.SolverDbId() << ": Not applicable"); - }, - Solvers{}...); - - return found; - } - template void ExecutePrimitive(const ExecutionContext& ctx, const Problem& problem, diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp index 32a29ead6e..88d2989e22 100644 --- a/src/include/miopen/solver.hpp +++ b/src/include/miopen/solver.hpp @@ -79,6 +79,9 @@ struct SolverBase return null_id; } + /// Returns true for tunable solvers + virtual bool IsTunable() const = 0; + /// [Informative as of Sep 2020] The minimum requirement for Dynamic Solvers: /// Batch size and input picture size (N, W, H) must NOT be compiled into the /// kernel(s) that consist a Solution. These must go into the kernel as a @@ -159,6 +162,8 @@ struct SolverInterfaceTunable : SolverInterface template struct SolverBaseNonTunable : SolverInterfaceNonTunable { + bool IsTunable() const final { return false; }; + InvokerFactory GetInvokerFactory(const Context& ctx, const Problem& problem) const { const auto solution = this->GetSolution(ctx, problem); @@ -174,6 +179,8 @@ struct TunableSolverTrait template struct SolverBaseTunable : SolverInterfaceTunable, TunableSolverTrait { + bool IsTunable() const final { return true; }; + /// Initializes performance config to the default values. /// The function may involve some heuristic to guess the best solution /// configuration. It is assumed that the function takes constant time diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index f79a5f5a54..30571ac144 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -42,6 +42,7 @@ struct ForceInit namespace solver { struct AnySolver; +struct SolverBase; enum class Primitive { @@ -78,6 +79,7 @@ struct MIOPEN_INTERNALS_EXPORT Id std::string ToString() const; AnySolver GetSolver() const; + const SolverBase* const GetSolverBase() const; std::string GetAlgo(conv::Direction dir) const; miopenConvAlgorithm_t GetAlgo() const; Primitive GetPrimitive() const; diff --git a/src/solver.cpp b/src/solver.cpp index 1f6873d5f7..850c47b054 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -141,6 +141,7 @@ struct IdRegistryEntry Primitive primitive = Primitive::Convolution; miopenConvAlgorithm_t convAlgo = miopenConvolutionAlgoDirect; AnySolver solver; + std::unique_ptr solver_base; }; struct IdRegistryData @@ -198,6 +199,16 @@ AnySolver Id::GetSolver() const return it != IdRegistry().value_to_entry.end() ? it->second.solver : AnySolver{}; } +const SolverBase* const Id::GetSolverBase() const +{ + if(!IsValid()) + return nullptr; + const auto it = IdRegistry().value_to_entry.find(value); + if(it == IdRegistry().value_to_entry.end()) + return nullptr; + return it->second.solver_base.get(); +} + std::string Id::GetAlgo(miopen::conv::Direction dir) const { return ConvolutionAlgoToDirectionalString(GetAlgo(), dir); @@ -283,9 +294,21 @@ template inline void RegisterWithSolver(IdRegistryData& registry, uint64_t value, TSolver, miopenConvAlgorithm_t algo) { - if(!Register(registry, value, TSolver{}.SolverDbId(), algo)) + auto solver_base = std::make_unique(); + if(!Register(registry, value, solver_base->SolverDbId(), algo)) + return; + auto& entry = registry.value_to_entry.at(value); + entry.solver = TSolver{}; + entry.solver_base = std::move(solver_base); +} + +template +void RegisterWithSolver(IdRegistryData& registry, uint64_t value, Primitive primitive) +{ + auto solver_base = std::make_unique(); + if(!Register(registry, value, primitive, solver_base->SolverDbId())) return; - registry.value_to_entry.at(value).solver = TSolver{}; + registry.value_to_entry.at(value).solver_base = std::move(solver_base); } inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) @@ -547,30 +570,20 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Activation, activ::ActivBwdSolver0{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, activ::ActivBwdSolver1{}.SolverDbId()); - Register( - registry, ++id, Primitive::Batchnorm, batchnorm::BnFwdTrainingSpatialSingle{}.SolverDbId()); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); RegisterWithSolver( registry, ++id, conv::ConvCkIgemmFwdV6r1DlopsNchw{}, miopenConvolutionAlgoImplicitGEMM); - Register(registry, - ++id, - Primitive::Batchnorm, - batchnorm::BnFwdTrainingSpatialMultiple{}.SolverDbId()); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); - Register( - registry, ++id, Primitive::Batchnorm, batchnorm::BnFwdTrainingPerActivation{}.SolverDbId()); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); - Register( - registry, ++id, Primitive::Batchnorm, batchnorm::BnBwdTrainingSpatialSingle{}.SolverDbId()); - Register(registry, - ++id, - Primitive::Batchnorm, - batchnorm::BnBwdTrainingSpatialMultiple{}.SolverDbId()); - Register( - registry, ++id, Primitive::Batchnorm, batchnorm::BnBwdTrainingPerActivation{}.SolverDbId()); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); - Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnFwdInference{}.SolverDbId()); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); Register(registry, ++id, Primitive::Pooling, pooling::PoolingForward2d{}.SolverDbId()); Register(registry, ++id, Primitive::Pooling, pooling::PoolingForwardNd{}.SolverDbId()); @@ -626,9 +639,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) ++id, conv::ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); - Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId()); - Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId()); - Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId()); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); + RegisterWithSolver(registry, ++id, Primitive::Batchnorm); Register( registry, ++id, Primitive::Normalization, layernorm::Layernorm2DCKForward{}.SolverDbId()); Register( diff --git a/src/solver/conv/conv_ocl_dir2D_bwdWrW_2.cpp b/src/solver/conv/conv_ocl_dir2D_bwdWrW_2.cpp index 744bc5bf09..ed5b8905ae 100644 --- a/src/solver/conv/conv_ocl_dir2D_bwdWrW_2.cpp +++ b/src/solver/conv/conv_ocl_dir2D_bwdWrW_2.cpp @@ -135,7 +135,7 @@ inline static bool Inc_2_to_11_optimized(int& v) // in issue 1289 #define WORKAROUND_ISSUE_1185 1 -static bool IsTunable(const ProblemDescription& problem) +static bool IsTunableBase(const ProblemDescription& problem) { return !(problem.GetGroupCount() == 1 && ((problem.GetWeightsWidth() == 3 && problem.GetWeightsHeight() == 3) || @@ -149,7 +149,7 @@ bool ConvOclBwdWrW2NonTunable::IsApplicable(const ExecutionContext& ctx, // reasons: after tuning ocl kernel for 3x3 and 1x1 filters, assembly kernel still // dominates. Thus, this solver is used for non-group 3x3 and 1x1 filters only. const auto tunable = ConvOclBwdWrW2<1>{}; - return tunable.IsApplicableBase(ctx, problem) && !IsTunable(problem); + return tunable.IsApplicableBase(ctx, problem) && !IsTunableBase(problem); } size_t ConvOclBwdWrW2NonTunable::GetWorkspaceSize(const ExecutionContext& ctx, @@ -520,7 +520,7 @@ template bool ConvOclBwdWrW2::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { - return IsApplicableBase(ctx, problem) && IsTunable(problem); + return IsApplicableBase(ctx, problem) && IsTunableBase(problem); } template diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp new file mode 100644 index 0000000000..cf5013b91e --- /dev/null +++ b/test/gtest/unit_FinInterface.cpp @@ -0,0 +1,396 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include + +#include + +#include + +namespace { + +struct TestParams +{ + friend std::ostream& operator<<(std::ostream& os, const TestParams& tp) + { + os << "none"; + return os; + } +}; + +struct SolverInfo +{ + SolverInfo() = default; + SolverInfo(uint64_t id_, bool dynamic_, bool tunable_) : id(id_), dynamic(dynamic_), tunable(tunable_) {} + + friend std::ostream& operator<<(std::ostream& os, const SolverInfo& info) + { + os << "("; + os << "id:" << info.id; + os << ", dynamic:" << info.dynamic; + os << ", tunable:" << info.tunable; + os << ")"; + return os; + } + + uint64_t id; + bool dynamic; + bool tunable; +}; + +struct ConvSolverInfo : SolverInfo +{ + using SolverInfo::SolverInfo; + ConvSolverInfo(uint64_t id_, bool dynamic_, bool tunable_, std::string algo_) : SolverInfo(id_, dynamic_, tunable_), algo(std::move(algo_)) {} + + friend std::ostream& operator<<(std::ostream& os, const ConvSolverInfo& info) + { + os << "("; + os << static_cast(info); + os << ", algo:" << info.algo; + os << ")"; + return os; + } + + std::string algo; +}; + +using BatchNormSolverInfo = SolverInfo; + +template +struct TestCase +{ + friend std::ostream& operator<<(std::ostream& os, const TestCase& tc) + { + os << "("; + os << "name:" << tc.name; + os << ", info:" << tc.info; + os << ")"; + return os; + } + + std::string name; + Info info; +}; + +using ConvTestCase = TestCase; +using BatchNormTestCase = TestCase; + +const auto& GetTestParams() +{ + static const auto params = TestParams{}; + return params; +} + +const auto& GetConvSolversInfo() +{ + static const std::unordered_map solver_info = { + // clang-format off + {"ConvAsm3x3U", {1, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvAsm1x1U", {2, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvAsm1x1UV2", {3, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvAsm5x10u2v2f1", {5, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvAsm5x10u2v2b1", {6, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvAsm7x7c3h224w224k64u2v2p3q3f1", {7, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclDirectFwd11x11", {8, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclDirectFwdGen", {9, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclDirectFwd", {11, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclDirectFwd1x1", {13, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvBinWinograd3x3U", {14, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvBinWinogradRxS", {15, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvAsmBwdWrW3x3", {16, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvAsmBwdWrW1x1", {17, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW2<1>", {18, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW2<2>", {19, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW2<4>", {20, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW2<8>", {21, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW2<16>", {22, false, true, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW2NonTunable", {23, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW53", {24, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvOclBwdWrW1x1", {25, false, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvHipImplicitGemmV4R1Fwd", {26, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmV4R1WrW", {31, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"fft", {34, false, false, "miopenConvolutionFwdAlgoFFT"}}, + {"ConvWinograd3x3MultipassWrW<3-4>", {35, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvBinWinogradRxSf3x2", {37, true, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<3-5>", {38, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<3-6>", {39, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<3-2>", {40, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<3-3>", {41, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<7-2>", {42, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<7-3>", {43, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<7-2-1-1>", {44, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<7-3-1-1>", {45, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<1-1-7-2>", {46, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<1-1-7-3>", {47, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<5-3>", {48, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvWinograd3x3MultipassWrW<5-4>", {49, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvBinWinogradRxSf2x3", {53, true, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvHipImplicitGemmV4R4Fwd", {54, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmBwdDataV1R1", {55, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmBwdDataV4R1", {56, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmBwdDataV1R1Xdlops", {57, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmBwdDataV4R1Xdlops", {60, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmV4R4WrW", {61, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmV4R1DynamicFwd", {62, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmV4R1DynamicFwd_1x1", {63, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmForwardV4R4Xdlops", {64, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmV4R1DynamicBwd", {65, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmV4R1DynamicWrw", {66, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvMPBidirectWinograd<2-3>", {67, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd<3-3>", {68, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd<4-3>", {69, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd<5-3>", {70, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd<6-3>", {71, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvAsmImplicitGemmGTCDynamicWrwXdlops", {72, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmWrwV4R4Xdlops", {73, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmGTCDynamicFwdXdlops", {74, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvMPBidirectWinograd_xdlops<2-3>", {75, false, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd_xdlops<3-3>", {76, false, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd_xdlops<4-3>", {77, false, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd_xdlops<5-3>", {78, false, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvMPBidirectWinograd_xdlops<6-3>", {79, false, true, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvHipImplicitGemmForwardV4R5Xdlops", {80, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm", {81, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmGTCDynamicBwdXdlops", {82, true, false, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm", {83, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvBinWinogradRxSf2x3g1", {84, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvDirectNaiveConvFwd", {85, true, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvDirectNaiveConvBwd", {86, true, false, "miopenConvolutionFwdAlgoDirect"}}, + {"ConvDirectNaiveConvWrw", {87, true, false, "miopenConvolutionFwdAlgoDirect"}}, + {"GemmFwd1x1_0_1", {88, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmFwd1x1_0_1_int8", {89, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmFwd1x1_0_2", {90, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmFwdRest", {91, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmBwd1x1_stride2", {95, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmBwd1x1_stride1", {96, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmBwdRest", {97, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"ConvMlirIgemmFwd", {98, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvMlirIgemmBwd", {99, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvMlirIgemmWrW", {100, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"GemmWrw1x1_stride1", {101, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"GemmWrwUniversal", {102, true, false, "miopenConvolutionFwdAlgoGEMM"}}, + {"ConvMlirIgemmFwdXdlops", {103, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvMlirIgemmBwdXdlops", {104, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvMlirIgemmWrWXdlops", {105, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC", {107, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC", {108, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC", {110, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvCkIgemmFwdV6r1DlopsNchw", {114, false, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC", {127, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmFwdXdlops", {128, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmBwdXdlops", {129, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmGroupFwdXdlops", {137, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemm3DGroupFwdXdlops", {138, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvWinoFuryRxS<2-3>", {139, true, false, "miopenConvolutionFwdAlgoWinograd"}}, + {"ConvHipImplicitGemm3DGroupWrwXdlops", {140, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemm3DGroupBwdXdlops", {141, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmF16F8F16FwdXdlops", {149, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmF16F8F16BwdXdlops", {150, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmF16F8F16WrwXdlops", {151, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmGroupBwdXdlops", {155, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + {"ConvHipImplicitGemmGroupWrwXdlops", {156, true, true, "miopenConvolutionFwdAlgoImplicitGEMM"}}, + // clang-format on + }; + + return solver_info; +} + +const auto& GetBatchNormSolversInfo() +{ + static const std::unordered_map solver_info = { + // clang-format off + {"BnFwdTrainingSpatialSingle", {113, false, false}}, + {"BnFwdTrainingSpatialMultiple", {115, false, false}}, + {"BnFwdTrainingPerActivation", {116, false, false}}, + {"BnBwdTrainingSpatialSingle", {117, false, false}}, + {"BnBwdTrainingSpatialMultiple", {118, false, false}}, + {"BnBwdTrainingPerActivation", {119, false, false}}, + {"BnFwdInference", {120, true, false}}, + {"BnCKFwdInference", {142, true, false}}, + {"BnCKBwdBackward", {143, true, false}}, + {"BnCKFwdTraining", {144, true, false}}, + // clang-format on + }; + + return solver_info; +} + +const auto& GetConvTestCases() +{ + static const auto test_cases = [] { + std::vector test_cases; + const auto& sinfo = GetConvSolversInfo(); + for(const auto& s : sinfo) + test_cases.emplace_back(ConvTestCase{s.first, s.second}); + return test_cases; + }(); + return test_cases; +} + +const auto& GetBatchNormTestCases() +{ + static const auto test_cases = [] { + std::vector test_cases; + const auto& sinfo = GetBatchNormSolversInfo(); + for(const auto& s : sinfo) + test_cases.emplace_back(BatchNormTestCase{s.first, s.second}); + return test_cases; + }(); + return test_cases; +} + +template +void CheckSolverInfo(const Solver& solver, const Info& info) +{ + ASSERT_EQ(solver.GetId(), info.id); + ASSERT_EQ(solver.IsDynamic(), info.dynamic); + ASSERT_EQ(solver.IsTunable(), info.tunable); +} + +void CheckConvSolverInfo(const miopen::fin::ConvSolver& solver, const ConvSolverInfo& info) +{ + ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, info)); + ASSERT_EQ(solver.GetAlgo(miopen::conv::Direction::Forward), info.algo); +} + +template +void CheckSolver(const Solver& solver, const TestCase& test_case) +{ + ASSERT_EQ(solver.GetName(), test_case.name); + ASSERT_EQ(solver.IsValid(), true); + ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, test_case.info)); +} + +void CheckConvSolver(const miopen::fin::ConvSolver& solver, const ConvTestCase& test_case) +{ + ASSERT_EQ(solver.GetName(), test_case.name); + ASSERT_EQ(solver.IsValid(), true); + ASSERT_NO_FATAL_FAILURE(CheckConvSolverInfo(solver, test_case.info)); +} + +class TestGetAllConvSolvers : public ::testing::TestWithParam +{ +public: + void RunTest() + { + const auto& solvers = miopen::fin::FinInterface::GetAllConvSolvers(); + const auto& solvers_info = GetConvSolversInfo(); + + ASSERT_EQ(solvers.size(), solvers_info.size()); + for(const auto& solver : solvers) + { + const auto& name = solver.GetName(); + const auto& solver_info = solvers_info.find(name); + if(solver_info == solvers_info.end()) + { + const std::string error = name + " not found"; + GTEST_FAIL() << error; + } + ASSERT_NO_FATAL_FAILURE(CheckConvSolverInfo(solver, solver_info->second)); + } + } +}; + +class TestGetConvSolver : public ::testing::TestWithParam> +{ +public: + void RunTest() + { + ConvTestCase test_case; + std::tie(std::ignore, test_case) = GetParam(); + const auto solver = miopen::fin::FinInterface::GetConvSolver(test_case.name); + CheckConvSolver(solver, test_case); + } +}; + +class TestGetAllBatchNormSolvers : public ::testing::TestWithParam +{ +public: + void RunTest() + { + const auto& solvers = miopen::fin::FinInterface::GetAllBatchNormSolvers(); + const auto& solvers_info = GetBatchNormSolversInfo(); + + ASSERT_EQ(solvers.size(), solvers_info.size()); + for(const auto& solver : solvers) + { + const auto& name = solver.GetName(); + const auto& solver_info = solvers_info.find(name); + if(solver_info == solvers_info.end()) + { + const std::string error = name + " not found"; + GTEST_FAIL() << error; + } + ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, solver_info->second)); + } + } +}; + +class TestGetBatchNormSolver : public ::testing::TestWithParam> +{ +public: + void RunTest() + { + BatchNormTestCase test_case; + std::tie(std::ignore, test_case) = GetParam(); + const auto solver = miopen::fin::FinInterface::GetBatchNormSolver(test_case.name); + CheckSolver(solver, test_case); + } +}; + +} // namespace + +using CPU_FinInterfaceTestGetAllConvSolvers_NONE = TestGetAllConvSolvers; +using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetConvSolver; +using CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE = TestGetAllBatchNormSolvers; +using CPU_FinInterfaceTestGetBatchNormSolver_NONE = TestGetBatchNormSolver; + +TEST_P(CPU_FinInterfaceTestGetAllConvSolvers_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetConvSolver_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetBatchNormSolver_NONE, FinInterface) { this->RunTest(); }; + +INSTANTIATE_TEST_SUITE_P(Full, + CPU_FinInterfaceTestGetAllConvSolvers_NONE, + testing::Values(GetTestParams())); + +INSTANTIATE_TEST_SUITE_P(Full, + CPU_FinInterfaceTestGetConvSolver_NONE, + testing::Combine(testing::Values(GetTestParams()), + testing::ValuesIn(GetConvTestCases()))); + +INSTANTIATE_TEST_SUITE_P(Full, + CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, + testing::Values(GetTestParams())); + +INSTANTIATE_TEST_SUITE_P(Full, + CPU_FinInterfaceTestGetBatchNormSolver_NONE, + testing::Combine(testing::Values(GetTestParams()), + testing::ValuesIn(GetBatchNormTestCases()))); From ec2ef181b83c266f3c9bdf3e748e4be20f3bf963 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:51:34 +0200 Subject: [PATCH 02/26] Fix formatting --- src/fin/fin_interface.cpp | 56 +++++++++++++----------- src/include/miopen/fin/fin_interface.hpp | 44 +++++++++++-------- src/solver.cpp | 10 +++-- test/gtest/unit_FinInterface.cpp | 33 ++++++++------ 4 files changed, 82 insertions(+), 61 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index 1ff6081374..9a1dd86a6a 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -34,20 +34,16 @@ namespace miopen { namespace fin { // ================== Solver ================== -Solver::Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) : sbase(solver_base), id(solver_id) +Solver::Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) + : sbase(solver_base), id(solver_id) { if(sbase == nullptr) MIOPEN_THROW(miopenStatusInternalError); } -Solver::Solver(const std::string& requested_name) : rname(requested_name) -{ -} +Solver::Solver(const std::string& requested_name) : rname(requested_name) {} -bool Solver::IsValid() const -{ - return sbase != nullptr; -} +bool Solver::IsValid() const { return sbase != nullptr; } uint64_t Solver::GetId() const { @@ -77,7 +73,7 @@ bool Solver::IsDynamic() const { if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - + return sbase->IsDynamic(); } @@ -87,35 +83,40 @@ bool SolverMixin::IsApplicable(const Context& ctx, const Probl { if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - - return static_cast*>(sbase)->IsApplicable(ctx, problem); + + return static_cast*>(sbase)->IsApplicable( + ctx, problem); } template -size_t SolverMixin::GetWorkspaceSize(const Context& ctx, const Problem& problem) const +size_t SolverMixin::GetWorkspaceSize(const Context& ctx, + const Problem& problem) const { if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - return static_cast*>(sbase)->GetWorkspaceSize(ctx, problem); + return static_cast*>(sbase)->GetWorkspaceSize( + ctx, problem); } template -miopen::solver::ConvSolution SolverMixin::FindSolution(const Context& ctx, - const Problem& problem, - miopen::PerformanceDb& db, - const miopen::AnyInvokeParams& invoke_ctx, - const std::string& perf_cfg) const +miopen::solver::ConvSolution +SolverMixin::FindSolution(const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const { if(sbase == nullptr) - MIOPEN_THROW(miopenStatusNotInitialized); + MIOPEN_THROW(miopenStatusNotInitialized); /// \todo MIOPEN_THROW(miopenStatusNotImplemented); } template -std::vector SolverMixin::GetAllSolutions(const Context& ctx, const Problem& problem) const +std::vector +SolverMixin::GetAllSolutions(const Context& ctx, const Problem& problem) const { if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); @@ -126,8 +127,8 @@ std::vector SolverMixin::GetAllS template std::string SolverMixin::GetPerfCfgParams(const Context& ctx, - const Problem& problem, - const PerformanceDb& db) const + const Problem& problem, + const PerformanceDb& db) const { if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); @@ -138,8 +139,8 @@ std::string SolverMixin::GetPerfCfgParams(const Context& ctx, template bool SolverMixin::TestPerfCfgParams(const Context& ctx, - const Problem& problem, - const std::string& params) const + const Problem& problem, + const std::string& params) const { if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); @@ -149,7 +150,12 @@ bool SolverMixin::TestPerfCfgParams(const Context& ctx, } // ================== ConvSolver ================== -ConvSolver::ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id, miopenConvAlgorithm_t algo_) : SolverMixin(solver_base, solver_id), algo(algo_) {} +ConvSolver::ConvSolver(const miopen::solver::SolverBase* solver_base, + uint64_t solver_id, + miopenConvAlgorithm_t algo_) + : SolverMixin(solver_base, solver_id), algo(algo_) +{ +} std::string ConvSolver::GetAlgo(miopen::conv::Direction dir) const { diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index ca38c430cf..2f5f73de35 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -46,7 +46,7 @@ struct ProblemDescription; namespace batchnorm { struct ProblemDescription; -} // namespace conv +} // namespace batchnorm namespace solver { struct SolverBase; @@ -58,7 +58,8 @@ namespace fin { class Solver { public: - // GetId(), IsDynamic() and IsTunable() throw miopenStatusNotInitialized if the solver is not valid. + // GetId(), IsDynamic() and IsTunable() throw miopenStatusNotInitialized if the solver is not + // valid. // Returns false if the solver could not be found by its name. bool IsValid() const; @@ -92,33 +93,34 @@ class SolverMixin : public Solver size_t GetWorkspaceSize(const Context& ctx, const Problem& problem) const; miopen::solver::ConvSolution FindSolution(const Context& ctx, - const Problem& problem, - miopen::PerformanceDb& db, - const miopen::AnyInvokeParams& invoke_ctx, - const std::string& perf_cfg = "") const; + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg = "") const; - std::vector GetAllSolutions(const Context& ctx, const Problem& problem) const; + std::vector GetAllSolutions(const Context& ctx, + const Problem& problem) const; - std::string GetPerfCfgParams(const Context& ctx, - const Problem& problem, - const PerformanceDb& db) const; + std::string + GetPerfCfgParams(const Context& ctx, const Problem& problem, const PerformanceDb& db) const; - bool TestPerfCfgParams(const Context& ctx, - const Problem& problem, - const std::string& params) const; + bool + TestPerfCfgParams(const Context& ctx, const Problem& problem, const std::string& params) const; protected: using Solver::Solver; }; // Convolution solver -class ConvSolver: public SolverMixin +class ConvSolver : public SolverMixin { public: std::string GetAlgo(miopen::conv::Direction dir) const; protected: - ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id, miopenConvAlgorithm_t algo_); + ConvSolver(const miopen::solver::SolverBase* solver_base, + uint64_t solver_id, + miopenConvAlgorithm_t algo_); using SolverMixin::SolverMixin; miopenConvAlgorithm_t algo; @@ -127,7 +129,8 @@ class ConvSolver: public SolverMixin +class BatchNormSolver + : public SolverMixin { protected: using SolverMixin::SolverMixin; @@ -137,9 +140,11 @@ class BatchNormSolver: public SolverMixin& GetAllConvSolvers(); @@ -163,7 +168,8 @@ class FinInterface // // 1a (Old version): // -// const auto& solver_id_list = miopen::solver::GetSolversByPrimitive(miopen::solver::Primitive::Convolution); +// const auto& solver_id_list = +// miopen::solver::GetSolversByPrimitive(miopen::solver::Primitive::Convolution); // for(const auto& id : solver_id_list) // { // std::unordered_map solver_info; diff --git a/src/solver.cpp b/src/solver.cpp index 850c47b054..3a5effbd51 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -297,8 +297,8 @@ RegisterWithSolver(IdRegistryData& registry, uint64_t value, TSolver, miopenConv auto solver_base = std::make_unique(); if(!Register(registry, value, solver_base->SolverDbId(), algo)) return; - auto& entry = registry.value_to_entry.at(value); - entry.solver = TSolver{}; + auto& entry = registry.value_to_entry.at(value); + entry.solver = TSolver{}; entry.solver_base = std::move(solver_base); } @@ -575,12 +575,14 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) RegisterWithSolver( registry, ++id, conv::ConvCkIgemmFwdV6r1DlopsNchw{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, Primitive::Batchnorm); + RegisterWithSolver( + registry, ++id, Primitive::Batchnorm); RegisterWithSolver(registry, ++id, Primitive::Batchnorm); RegisterWithSolver(registry, ++id, Primitive::Batchnorm); - RegisterWithSolver(registry, ++id, Primitive::Batchnorm); + RegisterWithSolver( + registry, ++id, Primitive::Batchnorm); RegisterWithSolver(registry, ++id, Primitive::Batchnorm); RegisterWithSolver(registry, ++id, Primitive::Batchnorm); diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index cf5013b91e..fb1d60a5c0 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -47,7 +47,10 @@ struct TestParams struct SolverInfo { SolverInfo() = default; - SolverInfo(uint64_t id_, bool dynamic_, bool tunable_) : id(id_), dynamic(dynamic_), tunable(tunable_) {} + SolverInfo(uint64_t id_, bool dynamic_, bool tunable_) + : id(id_), dynamic(dynamic_), tunable(tunable_) + { + } friend std::ostream& operator<<(std::ostream& os, const SolverInfo& info) { @@ -67,7 +70,10 @@ struct SolverInfo struct ConvSolverInfo : SolverInfo { using SolverInfo::SolverInfo; - ConvSolverInfo(uint64_t id_, bool dynamic_, bool tunable_, std::string algo_) : SolverInfo(id_, dynamic_, tunable_), algo(std::move(algo_)) {} + ConvSolverInfo(uint64_t id_, bool dynamic_, bool tunable_, std::string algo_) + : SolverInfo(id_, dynamic_, tunable_), algo(std::move(algo_)) + { + } friend std::ostream& operator<<(std::ostream& os, const ConvSolverInfo& info) { @@ -99,7 +105,7 @@ struct TestCase Info info; }; -using ConvTestCase = TestCase; +using ConvTestCase = TestCase; using BatchNormTestCase = TestCase; const auto& GetTestParams() @@ -300,13 +306,13 @@ class TestGetAllConvSolvers : public ::testing::TestWithParam public: void RunTest() { - const auto& solvers = miopen::fin::FinInterface::GetAllConvSolvers(); + const auto& solvers = miopen::fin::FinInterface::GetAllConvSolvers(); const auto& solvers_info = GetConvSolversInfo(); ASSERT_EQ(solvers.size(), solvers_info.size()); for(const auto& solver : solvers) { - const auto& name = solver.GetName(); + const auto& name = solver.GetName(); const auto& solver_info = solvers_info.find(name); if(solver_info == solvers_info.end()) { @@ -325,7 +331,7 @@ class TestGetConvSolver : public ::testing::TestWithParam public: void RunTest() { - const auto& solvers = miopen::fin::FinInterface::GetAllBatchNormSolvers(); + const auto& solvers = miopen::fin::FinInterface::GetAllBatchNormSolvers(); const auto& solvers_info = GetBatchNormSolversInfo(); ASSERT_EQ(solvers.size(), solvers_info.size()); for(const auto& solver : solvers) { - const auto& name = solver.GetName(); + const auto& name = solver.GetName(); const auto& solver_info = solvers_info.find(name); if(solver_info == solvers_info.end()) { @@ -353,7 +359,8 @@ class TestGetAllBatchNormSolvers : public ::testing::TestWithParam } }; -class TestGetBatchNormSolver : public ::testing::TestWithParam> +class TestGetBatchNormSolver + : public ::testing::TestWithParam> { public: void RunTest() @@ -367,10 +374,10 @@ class TestGetBatchNormSolver : public ::testing::TestWithParamRunTest(); }; TEST_P(CPU_FinInterfaceTestGetConvSolver_NONE, FinInterface) { this->RunTest(); }; From be07add9486dbc903e788e72a71b149d6f88e43f Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:32:53 +0200 Subject: [PATCH 03/26] Fix the description --- src/include/miopen/fin/fin_interface.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 2f5f73de35..7901abbe0b 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -187,7 +187,7 @@ class FinInterface // // 1b (Nev version): // -// const auto solver_list = miopen::fin::FinInterface::GetAllConvSolvers(); +// const auto& solver_list = miopen::fin::FinInterface::GetAllConvSolvers(); // for(const auto& solver : solver_list) // { // std::unordered_map solver_info; @@ -235,7 +235,7 @@ class FinInterface // Batch normalization solvers: // // ... -// const auto solver_list = miopen::fin::FinInterface::GetAllBatchNormSolvers(); +// const auto& solver_list = miopen::fin::FinInterface::GetAllBatchNormSolvers(); // ... // // ... From 3652303210091aacffef473342be3f5988ebb1d8 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:56:22 +0200 Subject: [PATCH 04/26] Fix formatting --- src/include/miopen/batchnorm/solvers.hpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/include/miopen/batchnorm/solvers.hpp b/src/include/miopen/batchnorm/solvers.hpp index b138322c6e..551164f611 100644 --- a/src/include/miopen/batchnorm/solvers.hpp +++ b/src/include/miopen/batchnorm/solvers.hpp @@ -151,12 +151,10 @@ struct PerformanceConfigBnCKFwdInference : PerfConfigBase static void Visit(Self&& s, F f) @@ -220,12 +218,10 @@ struct PerformanceConfigBnCKBwdBackward : PerfConfigBase static void Visit(Self&& s, F f) @@ -291,12 +287,10 @@ struct PerformanceConfigBnCKFwdTraining : PerfConfigBase static void Visit(Self&& s, F f) From d89c0cedd05e10e72fc4935ff9b1bb47975bbfd5 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:05:46 +0200 Subject: [PATCH 05/26] Fix tidy --- src/include/miopen/solver_id.hpp | 2 +- src/solver.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index 30571ac144..76a13b051c 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -79,7 +79,7 @@ struct MIOPEN_INTERNALS_EXPORT Id std::string ToString() const; AnySolver GetSolver() const; - const SolverBase* const GetSolverBase() const; + const SolverBase* GetSolverBase() const; std::string GetAlgo(conv::Direction dir) const; miopenConvAlgorithm_t GetAlgo() const; Primitive GetPrimitive() const; diff --git a/src/solver.cpp b/src/solver.cpp index 3a5effbd51..c50e840a42 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -199,7 +199,7 @@ AnySolver Id::GetSolver() const return it != IdRegistry().value_to_entry.end() ? it->second.solver : AnySolver{}; } -const SolverBase* const Id::GetSolverBase() const +const SolverBase* Id::GetSolverBase() const { if(!IsValid()) return nullptr; From ee46f04b7a143aee4ae0cba80a6551d9d2eef4fc Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:17:39 +0200 Subject: [PATCH 06/26] Fix windows build --- src/include/miopen/fin/fin_interface.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 7901abbe0b..4ae1728eaf 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -112,7 +112,7 @@ class SolverMixin : public Solver }; // Convolution solver -class ConvSolver : public SolverMixin +class MIOPEN_INTERNALS_EXPORT ConvSolver : public SolverMixin { public: std::string GetAlgo(miopen::conv::Direction dir) const; @@ -129,7 +129,7 @@ class ConvSolver : public SolverMixin { protected: @@ -147,12 +147,12 @@ class FinInterface // dummy if a solver with specified name does not exist. // Convolution - static const std::vector& GetAllConvSolvers(); - static ConvSolver GetConvSolver(const std::string& name); + MIOPEN_INTERNALS_EXPORT static const std::vector& GetAllConvSolvers(); + MIOPEN_INTERNALS_EXPORT static ConvSolver GetConvSolver(const std::string& name); // Batch normalization - static const std::vector& GetAllBatchNormSolvers(); - static BatchNormSolver GetBatchNormSolver(const std::string& name); + MIOPEN_INTERNALS_EXPORT static const std::vector& GetAllBatchNormSolvers(); + MIOPEN_INTERNALS_EXPORT static BatchNormSolver GetBatchNormSolver(const std::string& name); private: template From a56b39137c562d569107beef200ac3f717408886 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:49:15 +0200 Subject: [PATCH 07/26] Refactor --- src/fin/fin_interface.cpp | 56 ++++++++++++++---------- src/include/miopen/fin/fin_interface.hpp | 3 +- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index 9a1dd86a6a..c9ab57c318 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -24,6 +24,7 @@ * *******************************************************************************/ +#include #include #include @@ -167,17 +168,38 @@ std::string ConvSolver::GetAlgo(miopen::conv::Direction dir) const // ================== FinInterface ================== template -const std::vector& FinInterface::GetAllSolvers(miopen::solver::Primitive primitive) +struct SolverToPrimitive; + +template <> +struct SolverToPrimitive +{ + static auto GetPrimitive() { return miopen::solver::Primitive::Convolution; } +}; + +template <> +struct SolverToPrimitive +{ + static auto GetPrimitive() { return miopen::solver::Primitive::Batchnorm; } +}; + +template +const std::vector& FinInterface::GetAllSolvers() { - static const auto solvers = [primitive] { - const auto& ids = GetSolversByPrimitive(primitive); + static const auto solvers = [] { + const auto& ids = GetSolversByPrimitive(SolverToPrimitive::GetPrimitive()); std::vector solvers; + for(const auto& id : ids) { if(!id.IsValid()) MIOPEN_THROW(miopenStatusInternalError); - solvers.emplace_back(Solver{id.GetSolverBase(), id.Value()}); + + if constexpr(std::is_same_v) + solvers.emplace_back(Solver{id.GetSolverBase(), id.Value(), id.GetAlgo()}); + else + solvers.emplace_back(Solver{id.GetSolverBase(), id.Value()}); } + return solvers; }(); return solvers; @@ -189,36 +211,26 @@ Solver FinInterface::GetSolver(const std::string& name) const auto id = miopen::solver::Id{name}; if(!id.IsValid()) return {name}; - return {id.GetSolverBase(), id.Value()}; + + if constexpr(std::is_same_v) + return {id.GetSolverBase(), id.Value(), id.GetAlgo()}; + else + return {id.GetSolverBase(), id.Value()}; } const std::vector& FinInterface::GetAllConvSolvers() { - static const auto solvers = [] { - const auto& ids = GetSolversByPrimitive(miopen::solver::Primitive::Convolution); - std::vector solvers; - for(const auto& id : ids) - { - if(!id.IsValid()) - MIOPEN_THROW(miopenStatusInternalError); - solvers.emplace_back(ConvSolver{id.GetSolverBase(), id.Value(), id.GetAlgo()}); - } - return solvers; - }(); - return solvers; + return GetAllSolvers(); } ConvSolver FinInterface::GetConvSolver(const std::string& name) { - const auto id = miopen::solver::Id{name}; - if(!id.IsValid()) - return {name}; - return {id.GetSolverBase(), id.Value(), id.GetAlgo()}; + return GetSolver(name); } const std::vector& FinInterface::GetAllBatchNormSolvers() { - return GetAllSolvers(miopen::solver::Primitive::Batchnorm); + return GetAllSolvers(); } BatchNormSolver FinInterface::GetBatchNormSolver(const std::string& name) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 4ae1728eaf..c55ddcc9a2 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -121,6 +121,7 @@ class MIOPEN_INTERNALS_EXPORT ConvSolver : public SolverMixin - static const std::vector& GetAllSolvers(miopen::solver::Primitive primitive); + static const std::vector& GetAllSolvers(); template static Solver GetSolver(const std::string& name); From 3a89c666326328f314743ac6ac3a29cd6f6f2e8f Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:53:05 +0200 Subject: [PATCH 08/26] Fix formatting --- src/include/miopen/fin/fin_interface.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index c55ddcc9a2..d6d9633926 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -112,7 +112,8 @@ class SolverMixin : public Solver }; // Convolution solver -class MIOPEN_INTERNALS_EXPORT ConvSolver : public SolverMixin +class MIOPEN_INTERNALS_EXPORT ConvSolver + : public SolverMixin { public: std::string GetAlgo(miopen::conv::Direction dir) const; From ee2196f3dd09604840ad8d50743ebedcd69d5822 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 21:41:52 +0200 Subject: [PATCH 09/26] Fix tidy --- src/fin/fin_interface.cpp | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index c9ab57c318..b0756d104f 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -82,10 +82,13 @@ bool Solver::IsDynamic() const template bool SolverMixin::IsApplicable(const Context& ctx, const Problem& problem) const { + std::ignore = ctx; + std::ignore = problem; + if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - return static_cast*>(sbase)->IsApplicable( + return static_cast*>(sbase)->IsApplicable( ctx, problem); } @@ -93,10 +96,13 @@ template size_t SolverMixin::GetWorkspaceSize(const Context& ctx, const Problem& problem) const { + std::ignore = ctx; + std::ignore = problem; + if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - return static_cast*>(sbase)->GetWorkspaceSize( + return static_cast*>(sbase)->GetWorkspaceSize( ctx, problem); } @@ -108,6 +114,12 @@ SolverMixin::FindSolution(const Context& ctx, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg) const { + std::ignore = ctx; + std::ignore = problem; + std::ignore = db; + std::ignore = invoke_ctx; + std::ignore = perf_cfg; + if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); @@ -119,6 +131,9 @@ template std::vector SolverMixin::GetAllSolutions(const Context& ctx, const Problem& problem) const { + std::ignore = ctx; + std::ignore = problem; + if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); @@ -131,6 +146,10 @@ std::string SolverMixin::GetPerfCfgParams(const Context& ctx, const Problem& problem, const PerformanceDb& db) const { + std::ignore = ctx; + std::ignore = problem; + std::ignore = db; + if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); @@ -143,6 +162,10 @@ bool SolverMixin::TestPerfCfgParams(const Context& ctx, const Problem& problem, const std::string& params) const { + std::ignore = ctx; + std::ignore = problem; + std::ignore = params; + if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); From 2a9b9a70ecde2dfd0805ff489ccc8a69c1fc0f37 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Fri, 25 Oct 2024 21:46:06 +0200 Subject: [PATCH 10/26] Fix windows build --- src/fin/fin_interface.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index b0756d104f..ddad293961 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -82,28 +82,22 @@ bool Solver::IsDynamic() const template bool SolverMixin::IsApplicable(const Context& ctx, const Problem& problem) const { - std::ignore = ctx; - std::ignore = problem; - if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - return static_cast*>(sbase)->IsApplicable( - ctx, problem); + using SolverInterface = miopen::solver::SolverInterface; + return static_cast(sbase)->IsApplicable(ctx, problem); } template size_t SolverMixin::GetWorkspaceSize(const Context& ctx, const Problem& problem) const { - std::ignore = ctx; - std::ignore = problem; - if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - return static_cast*>(sbase)->GetWorkspaceSize( - ctx, problem); + using SolverInterface = miopen::solver::SolverInterface; + return static_cast(sbase)->GetWorkspaceSize(ctx, problem); } template From 7b47b41a5bcda79fb432dd47eca71dde84cea3ff Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:44:47 +0100 Subject: [PATCH 11/26] Fix tidy --- src/include/miopen/fin/fin_interface.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index d6d9633926..130e69d434 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -116,13 +116,13 @@ class MIOPEN_INTERNALS_EXPORT ConvSolver : public SolverMixin { public: + ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) = delete; std::string GetAlgo(miopen::conv::Direction dir) const; protected: ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id, miopenConvAlgorithm_t algo_); - ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) = delete; using SolverMixin::SolverMixin; miopenConvAlgorithm_t algo; From 80261ed55b11700cdb841cc24742fc685b383c73 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:47:39 +0100 Subject: [PATCH 12/26] Fix windows build --- src/include/miopen/fin/fin_interface.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 130e69d434..fa3fd8472a 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -55,7 +55,7 @@ struct SolverBase; namespace fin { // Base classes for solvers. -class Solver +class MIOPEN_INTERNALS_EXPORT Solver { public: // GetId(), IsDynamic() and IsTunable() throw miopenStatusNotInitialized if the solver is not From a8a432eaf41c04557664d753a51bd7178adb02ae Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Mon, 28 Oct 2024 20:15:07 +0100 Subject: [PATCH 13/26] Refactor, add GetConvSolvers & GetBatchNormSolvers --- src/fin/fin_interface.cpp | 43 ++++++-- src/include/miopen/fin/fin_interface.hpp | 60 +++++------ test/gtest/unit_FinInterface.cpp | 129 ++++++++++------------- 3 files changed, 119 insertions(+), 113 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index ddad293961..f6ddfc7b29 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -32,7 +32,7 @@ #include namespace miopen { -namespace fin { +namespace fin_interface { // ================== Solver ================== Solver::Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) @@ -184,6 +184,8 @@ std::string ConvSolver::GetAlgo(miopen::conv::Direction dir) const } // ================== FinInterface ================== +namespace { + template struct SolverToPrimitive; @@ -199,8 +201,10 @@ struct SolverToPrimitive static auto GetPrimitive() { return miopen::solver::Primitive::Batchnorm; } }; +} // namespace + template -const std::vector& FinInterface::GetAllSolvers() +const std::vector& GetAllSolvers() { static const auto solvers = [] { const auto& ids = GetSolversByPrimitive(SolverToPrimitive::GetPrimitive()); @@ -223,7 +227,7 @@ const std::vector& FinInterface::GetAllSolvers() } template -Solver FinInterface::GetSolver(const std::string& name) +Solver GetSolver(const std::string& name) { const auto id = miopen::solver::Id{name}; if(!id.IsValid()) @@ -235,25 +239,48 @@ Solver FinInterface::GetSolver(const std::string& name) return {id.GetSolverBase(), id.Value()}; } -const std::vector& FinInterface::GetAllConvSolvers() +namespace { + +template +std::vector GetSolvers(const std::vector& names) +{ + std::vector solvers; + for(const auto& name : names) + solvers.emplace_back(GetSolver(name)); + return solvers; +} + +} // namespace + +const std::vector& GetAllConvSolvers() { return GetAllSolvers(); } -ConvSolver FinInterface::GetConvSolver(const std::string& name) +std::vector GetConvSolvers(const std::vector& names) +{ + return GetSolvers(names); +} + +ConvSolver GetConvSolver(const std::string& name) { return GetSolver(name); } -const std::vector& FinInterface::GetAllBatchNormSolvers() +const std::vector& GetAllBatchNormSolvers() { return GetAllSolvers(); } -BatchNormSolver FinInterface::GetBatchNormSolver(const std::string& name) +std::vector GetBatchNormSolvers(const std::vector& names) +{ + return GetSolvers(names); +} + +BatchNormSolver GetBatchNormSolver(const std::string& name) { return GetSolver(name); } -} // namespace fin +} // namespace fin_interface } // namespace miopen diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index fa3fd8472a..a098605646 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -52,7 +52,9 @@ namespace solver { struct SolverBase; } // namespace solver -namespace fin { +namespace fin_interface { + +// ================== Interface for Fin ================== // Base classes for solvers. class MIOPEN_INTERNALS_EXPORT Solver @@ -79,7 +81,10 @@ class MIOPEN_INTERNALS_EXPORT Solver const std::string rname; uint64_t id; - friend class FinInterface; + template + friend const std::vector& GetAllSolvers(); + template + friend Solver GetSolver(const std::string&); }; template @@ -127,7 +132,10 @@ class MIOPEN_INTERNALS_EXPORT ConvSolver miopenConvAlgorithm_t algo; - friend class FinInterface; + template + friend const std::vector& GetAllSolvers(); + template + friend Solver GetSolver(const std::string&); }; // Batch normalization solver @@ -138,31 +146,21 @@ class MIOPEN_INTERNALS_EXPORT BatchNormSolver using SolverMixin::SolverMixin; }; -// Interface for Fin -class FinInterface -{ -public: - // GetAll*Solvers() - returns all solvers for a particular primitive. All solvers are always - // valid. - // - // Get*Solver(name) - returns single solver by its name for a particular primitive. May return a - // dummy if a solver with specified name does not exist. - - // Convolution - MIOPEN_INTERNALS_EXPORT static const std::vector& GetAllConvSolvers(); - MIOPEN_INTERNALS_EXPORT static ConvSolver GetConvSolver(const std::string& name); +// GetAll*Solvers() - returns all solvers for a particular primitive. All solvers are always valid. +// +// Get*Solvers(names) - returns all solvers with the names specified in the list for a particular primitive. May return dummies ​​for solvers that do not exist. +// +// Get*Solver(name) - returns single solver by its name for a particular primitive. May return a dummy if a solver with specified name does not exist. - // Batch normalization - MIOPEN_INTERNALS_EXPORT static const std::vector& GetAllBatchNormSolvers(); - MIOPEN_INTERNALS_EXPORT static BatchNormSolver GetBatchNormSolver(const std::string& name); +// Convolution +MIOPEN_INTERNALS_EXPORT const std::vector& GetAllConvSolvers(); +MIOPEN_INTERNALS_EXPORT std::vector GetConvSolvers(const std::vector& names); +MIOPEN_INTERNALS_EXPORT ConvSolver GetConvSolver(const std::string& name); -private: - template - static const std::vector& GetAllSolvers(); - - template - static Solver GetSolver(const std::string& name); -}; +// Batch normalization +MIOPEN_INTERNALS_EXPORT const std::vector& GetAllBatchNormSolvers(); +MIOPEN_INTERNALS_EXPORT std::vector GetBatchNormSolvers(const std::vector& names); +MIOPEN_INTERNALS_EXPORT BatchNormSolver GetBatchNormSolver(const std::string& name); // Examples: // @@ -189,7 +187,7 @@ class FinInterface // // 1b (Nev version): // -// const auto& solver_list = miopen::fin::FinInterface::GetAllConvSolvers(); +// const auto& solver_list = miopen::fin_interface::GetAllConvSolvers(); // for(const auto& solver : solver_list) // { // std::unordered_map solver_info; @@ -223,7 +221,7 @@ class FinInterface // 2b (Nev version): // // std::string solver_name = "ConvBiasActivAsm1x1U"; -// const auto solver = miopen::fin::FinInterface::GetConvSolver(solver_name); +// const auto solver = miopen::fin_interface::GetConvSolver(solver_name); // std::unordered_map solver_info; // solver_info["name"] = solver.GetName(); // if(solver.IsValid()) @@ -237,13 +235,13 @@ class FinInterface // Batch normalization solvers: // // ... -// const auto& solver_list = miopen::fin::FinInterface::GetAllBatchNormSolvers(); +// const auto& solver_list = miopen::fin_interface::GetAllBatchNormSolvers(); // ... // // ... -// const auto solver = miopen::fin::FinInterface::GetBatchNormSolver(solver_name); +// const auto solver = miopen::fin_interface::GetBatchNormSolver(solver_name); // ... // -} // namespace fin +} // namespace fin_interface } // namespace miopen diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index 0f242ac32c..ab6acdfeaf 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -114,7 +114,11 @@ const auto& GetTestParams() return params; } -const auto& GetConvSolversInfo() +template +const auto& GetSolversInfo(); + +template <> +const auto& GetSolversInfo() { static const std::unordered_map solver_info = { // clang-format off @@ -228,7 +232,8 @@ const auto& GetConvSolversInfo() return solver_info; } -const auto& GetBatchNormSolversInfo() +template <> +const auto& GetSolversInfo() { static const std::unordered_map solver_info = { // clang-format off @@ -248,42 +253,27 @@ const auto& GetBatchNormSolversInfo() return solver_info; } -const auto& GetConvTestCases() +template +const auto& GetTestCases() { static const auto test_cases = [] { - std::vector test_cases; - const auto& sinfo = GetConvSolversInfo(); + std::vector test_cases; + const auto& sinfo = GetSolversInfo(); for(const auto& s : sinfo) - test_cases.emplace_back(ConvTestCase{s.first, s.second}); + test_cases.emplace_back(TestCase{s.first, s.second}); return test_cases; }(); return test_cases; } -const auto& GetBatchNormTestCases() -{ - static const auto test_cases = [] { - std::vector test_cases; - const auto& sinfo = GetBatchNormSolversInfo(); - for(const auto& s : sinfo) - test_cases.emplace_back(BatchNormTestCase{s.first, s.second}); - return test_cases; - }(); - return test_cases; -} - -template -void CheckSolverInfo(const Solver& solver, const Info& info) +template +void CheckSolverInfo(const Solver& solver, const SolverInfo& info) { ASSERT_EQ(solver.GetId(), info.id); ASSERT_EQ(solver.IsDynamic(), info.dynamic); ASSERT_EQ(solver.IsTunable(), info.tunable); -} - -void CheckConvSolverInfo(const miopen::fin::ConvSolver& solver, const ConvSolverInfo& info) -{ - ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, info)); - ASSERT_EQ(solver.GetAlgo(miopen::conv::Direction::Forward), info.algo); + if constexpr(std::is_same_v) + ASSERT_EQ(solver.GetAlgo(miopen::conv::Direction::Forward), info.algo); } template @@ -294,55 +284,44 @@ void CheckSolver(const Solver& solver, const TestCase& test_case) ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, test_case.info)); } -void CheckConvSolver(const miopen::fin::ConvSolver& solver, const ConvTestCase& test_case) +template +const auto& InterfaceGetAllSolvers(); + +template <> +const auto& InterfaceGetAllSolvers() { - ASSERT_EQ(solver.GetName(), test_case.name); - ASSERT_EQ(solver.IsValid(), true); - ASSERT_NO_FATAL_FAILURE(CheckConvSolverInfo(solver, test_case.info)); + return miopen::fin_interface::GetAllConvSolvers(); } -class TestGetAllConvSolvers : public ::testing::TestWithParam +template <> +const auto& InterfaceGetAllSolvers() { -public: - void RunTest() - { - const auto& solvers = miopen::fin::FinInterface::GetAllConvSolvers(); - const auto& solvers_info = GetConvSolversInfo(); + return miopen::fin_interface::GetAllBatchNormSolvers(); +} - ASSERT_EQ(solvers.size(), solvers_info.size()); - for(const auto& solver : solvers) - { - const auto& name = solver.GetName(); - const auto& solver_info = solvers_info.find(name); - if(solver_info == solvers_info.end()) - { - const std::string error = name + " not found"; - GTEST_FAIL() << error; - } - ASSERT_NO_FATAL_FAILURE(CheckConvSolverInfo(solver, solver_info->second)); - } - } -}; +template +auto InterfaceGetSolver(const std::string& name); -class TestGetConvSolver : public ::testing::TestWithParam> +template <> +auto InterfaceGetSolver(const std::string& name) { -public: - void RunTest() - { - ConvTestCase test_case; - std::tie(std::ignore, test_case) = GetParam(); - const auto solver = miopen::fin::FinInterface::GetConvSolver(test_case.name); - CheckConvSolver(solver, test_case); - } -}; + return miopen::fin_interface::GetConvSolver(name); +} + +template <> +auto InterfaceGetSolver(const std::string& name) +{ + return miopen::fin_interface::GetBatchNormSolver(name); +} -class TestGetAllBatchNormSolvers : public ::testing::TestWithParam +template +class TestGetAllSolvers : public ::testing::TestWithParam { public: void RunTest() { - const auto& solvers = miopen::fin::FinInterface::GetAllBatchNormSolvers(); - const auto& solvers_info = GetBatchNormSolversInfo(); + const auto& solvers = InterfaceGetAllSolvers(); + const auto& solvers_info = GetSolversInfo(); ASSERT_EQ(solvers.size(), solvers_info.size()); for(const auto& solver : solvers) @@ -359,25 +338,25 @@ class TestGetAllBatchNormSolvers : public ::testing::TestWithParam } }; -class TestGetBatchNormSolver - : public ::testing::TestWithParam> +template +class TestGetSolver : public ::testing::TestWithParam> { public: void RunTest() { - BatchNormTestCase test_case; - std::tie(std::ignore, test_case) = GetParam(); - const auto solver = miopen::fin::FinInterface::GetBatchNormSolver(test_case.name); + TestCase test_case; + std::tie(std::ignore, test_case) = this->GetParam(); + const auto solver = InterfaceGetSolver(test_case.name); CheckSolver(solver, test_case); } }; } // namespace -using CPU_FinInterfaceTestGetAllConvSolvers_NONE = TestGetAllConvSolvers; -using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetConvSolver; -using CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE = TestGetAllBatchNormSolvers; -using CPU_FinInterfaceTestGetBatchNormSolver_NONE = TestGetBatchNormSolver; +using CPU_FinInterfaceTestGetAllConvSolvers_NONE = TestGetAllSolvers; +using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetSolver; +using CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE = TestGetAllSolvers; +using CPU_FinInterfaceTestGetBatchNormSolver_NONE = TestGetSolver; TEST_P(CPU_FinInterfaceTestGetAllConvSolvers_NONE, FinInterface) { this->RunTest(); }; TEST_P(CPU_FinInterfaceTestGetConvSolver_NONE, FinInterface) { this->RunTest(); }; @@ -391,13 +370,15 @@ INSTANTIATE_TEST_SUITE_P(Full, INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetConvSolver_NONE, testing::Combine(testing::Values(GetTestParams()), - testing::ValuesIn(GetConvTestCases()))); + testing::ValuesIn(GetTestCases()))); INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, testing::Values(GetTestParams())); +// clang-format off INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetBatchNormSolver_NONE, testing::Combine(testing::Values(GetTestParams()), - testing::ValuesIn(GetBatchNormTestCases()))); + testing::ValuesIn(GetTestCases()))); +// clang-format on From f57bbac187670f6c3ae72942f3da6d2c239f13ef Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Mon, 28 Oct 2024 20:17:01 +0100 Subject: [PATCH 14/26] Fix formatting --- src/fin/fin_interface.cpp | 10 ++-------- src/include/miopen/fin/fin_interface.hpp | 12 ++++++++---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index f6ddfc7b29..93689b27c0 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -252,20 +252,14 @@ std::vector GetSolvers(const std::vector& names) } // namespace -const std::vector& GetAllConvSolvers() -{ - return GetAllSolvers(); -} +const std::vector& GetAllConvSolvers() { return GetAllSolvers(); } std::vector GetConvSolvers(const std::vector& names) { return GetSolvers(names); } -ConvSolver GetConvSolver(const std::string& name) -{ - return GetSolver(name); -} +ConvSolver GetConvSolver(const std::string& name) { return GetSolver(name); } const std::vector& GetAllBatchNormSolvers() { diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index a098605646..737c2ea11b 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -148,18 +148,22 @@ class MIOPEN_INTERNALS_EXPORT BatchNormSolver // GetAll*Solvers() - returns all solvers for a particular primitive. All solvers are always valid. // -// Get*Solvers(names) - returns all solvers with the names specified in the list for a particular primitive. May return dummies ​​for solvers that do not exist. +// Get*Solvers(names) - returns all solvers with the names specified in the list for a particular +// primitive. May return dummies ​​for solvers that do not exist. // -// Get*Solver(name) - returns single solver by its name for a particular primitive. May return a dummy if a solver with specified name does not exist. +// Get*Solver(name) - returns single solver by its name for a particular primitive. May return a +// dummy if a solver with specified name does not exist. // Convolution MIOPEN_INTERNALS_EXPORT const std::vector& GetAllConvSolvers(); -MIOPEN_INTERNALS_EXPORT std::vector GetConvSolvers(const std::vector& names); +MIOPEN_INTERNALS_EXPORT std::vector +GetConvSolvers(const std::vector& names); MIOPEN_INTERNALS_EXPORT ConvSolver GetConvSolver(const std::string& name); // Batch normalization MIOPEN_INTERNALS_EXPORT const std::vector& GetAllBatchNormSolvers(); -MIOPEN_INTERNALS_EXPORT std::vector GetBatchNormSolvers(const std::vector& names); +MIOPEN_INTERNALS_EXPORT std::vector +GetBatchNormSolvers(const std::vector& names); MIOPEN_INTERNALS_EXPORT BatchNormSolver GetBatchNormSolver(const std::string& name); // Examples: From 81ffac5f73dfa97ad37b1cc3ca0ca49eb3bf81e8 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:18:12 +0100 Subject: [PATCH 15/26] Fix tidy --- src/fin/fin_interface.cpp | 2 ++ test/gtest/unit_FinInterface.cpp | 1 + 2 files changed, 3 insertions(+) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index 93689b27c0..6cb739ab50 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -209,6 +209,7 @@ const std::vector& GetAllSolvers() static const auto solvers = [] { const auto& ids = GetSolversByPrimitive(SolverToPrimitive::GetPrimitive()); std::vector solvers; + solvers.reserve(ids.size()); for(const auto& id : ids) { @@ -245,6 +246,7 @@ template std::vector GetSolvers(const std::vector& names) { std::vector solvers; + solvers.reserve(names.size()); for(const auto& name : names) solvers.emplace_back(GetSolver(name)); return solvers; diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index ab6acdfeaf..1c432e8a30 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -259,6 +259,7 @@ const auto& GetTestCases() static const auto test_cases = [] { std::vector test_cases; const auto& sinfo = GetSolversInfo(); + test_cases.reserve(sinfo.size()); for(const auto& s : sinfo) test_cases.emplace_back(TestCase{s.first, s.second}); return test_cases; From 230acd57b408c46d12fb0d9c0b2fb1beab447566 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:08:16 +0100 Subject: [PATCH 16/26] Add test for GetSolvers() --- test/gtest/unit_FinInterface.cpp | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index 1c432e8a30..9c84b2a81a 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -253,6 +253,20 @@ const auto& GetSolversInfo() return solver_info; } +template +const auto& GetSolverNames() +{ + static const auto names = [] { + std::vector names; + const auto& sinfo = GetSolversInfo(); + names.reserve(sinfo.size()); + for(const auto& s : sinfo) + names.push_back(s.first); + return names; + }(); + return names; +} + template const auto& GetTestCases() { @@ -267,6 +281,7 @@ const auto& GetTestCases() return test_cases; } +// Checks template void CheckSolverInfo(const Solver& solver, const SolverInfo& info) { @@ -285,6 +300,7 @@ void CheckSolver(const Solver& solver, const TestCase& test_case) ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, test_case.info)); } +// GetAll*Solvers() template const auto& InterfaceGetAllSolvers(); @@ -300,6 +316,23 @@ const auto& InterfaceGetAllSolvers() return miopen::fin_interface::GetAllBatchNormSolvers(); } +// Get*Solvers(names) +template +const auto InterfaceGetSolvers(const std::vector& names); + +template <> +const auto InterfaceGetSolvers(const std::vector& names) +{ + return miopen::fin_interface::GetConvSolvers(names); +} + +template <> +const auto InterfaceGetSolvers(const std::vector& names) +{ + return miopen::fin_interface::GetBatchNormSolvers(names); +} + +// Get*Solver(name) template auto InterfaceGetSolver(const std::string& name); @@ -315,6 +348,7 @@ auto InterfaceGetSolver(const std::string& name) return miopen::fin_interface::GetBatchNormSolver(name); } +// Tests template class TestGetAllSolvers : public ::testing::TestWithParam { @@ -339,6 +373,31 @@ class TestGetAllSolvers : public ::testing::TestWithParam } }; +template +class TestGetSolvers : public ::testing::TestWithParam +{ +public: + void RunTest() + { + const auto& solvers_info = GetSolversInfo(); + const auto& names = GetSolverNames(); + const auto solvers = InterfaceGetSolvers(names); + + ASSERT_EQ(solvers.size(), names.size()); + for(const auto& solver : solvers) + { + const auto& name = solver.GetName(); + const auto& solver_info = solvers_info.find(name); + if(solver_info == solvers_info.end()) + { + const std::string error = name + " not found"; + GTEST_FAIL() << error; + } + ASSERT_NO_FATAL_FAILURE(CheckSolver(solver, TestCase{name, solver_info->second})); + } + } +}; + template class TestGetSolver : public ::testing::TestWithParam> { @@ -355,19 +414,29 @@ class TestGetSolver : public ::testing::TestWithParam; +using CPU_FinInterfaceTestGetConvSolvers_NONE = TestGetSolvers; using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetSolver; + using CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE = TestGetAllSolvers; +using CPU_FinInterfaceTestGetBatchNormSolvers_NONE = TestGetSolvers; using CPU_FinInterfaceTestGetBatchNormSolver_NONE = TestGetSolver; TEST_P(CPU_FinInterfaceTestGetAllConvSolvers_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetConvSolvers_NONE, FinInterface) { this->RunTest(); }; TEST_P(CPU_FinInterfaceTestGetConvSolver_NONE, FinInterface) { this->RunTest(); }; + TEST_P(CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; TEST_P(CPU_FinInterfaceTestGetBatchNormSolver_NONE, FinInterface) { this->RunTest(); }; INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetAllConvSolvers_NONE, testing::Values(GetTestParams())); +INSTANTIATE_TEST_SUITE_P(Full, + CPU_FinInterfaceTestGetConvSolvers_NONE, + testing::Values(GetTestParams())); + INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetConvSolver_NONE, testing::Combine(testing::Values(GetTestParams()), @@ -377,6 +446,10 @@ INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, testing::Values(GetTestParams())); +INSTANTIATE_TEST_SUITE_P(Full, + CPU_FinInterfaceTestGetBatchNormSolvers_NONE, + testing::Values(GetTestParams())); + // clang-format off INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetBatchNormSolver_NONE, From 0f9bdf226079367a6ea29d3590080a1f010e02a4 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:11:30 +0100 Subject: [PATCH 17/26] Refactor --- test/gtest/unit_FinInterface.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index 9c84b2a81a..1c5e76a374 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -413,22 +413,15 @@ class TestGetSolver : public ::testing::TestWithParam; using CPU_FinInterfaceTestGetConvSolvers_NONE = TestGetSolvers; using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetSolver; -using CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE = TestGetAllSolvers; -using CPU_FinInterfaceTestGetBatchNormSolvers_NONE = TestGetSolvers; -using CPU_FinInterfaceTestGetBatchNormSolver_NONE = TestGetSolver; - TEST_P(CPU_FinInterfaceTestGetAllConvSolvers_NONE, FinInterface) { this->RunTest(); }; TEST_P(CPU_FinInterfaceTestGetConvSolvers_NONE, FinInterface) { this->RunTest(); }; TEST_P(CPU_FinInterfaceTestGetConvSolver_NONE, FinInterface) { this->RunTest(); }; -TEST_P(CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; -TEST_P(CPU_FinInterfaceTestGetBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; -TEST_P(CPU_FinInterfaceTestGetBatchNormSolver_NONE, FinInterface) { this->RunTest(); }; - INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetAllConvSolvers_NONE, testing::Values(GetTestParams())); @@ -442,6 +435,15 @@ INSTANTIATE_TEST_SUITE_P(Full, testing::Combine(testing::Values(GetTestParams()), testing::ValuesIn(GetTestCases()))); +// Batch normalization +using CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE = TestGetAllSolvers; +using CPU_FinInterfaceTestGetBatchNormSolvers_NONE = TestGetSolvers; +using CPU_FinInterfaceTestGetBatchNormSolver_NONE = TestGetSolver; + +TEST_P(CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetBatchNormSolvers_NONE, FinInterface) { this->RunTest(); }; +TEST_P(CPU_FinInterfaceTestGetBatchNormSolver_NONE, FinInterface) { this->RunTest(); }; + INSTANTIATE_TEST_SUITE_P(Full, CPU_FinInterfaceTestGetAllBatchNormSolvers_NONE, testing::Values(GetTestParams())); From 61bfaea4ae416a3358b3e77801682e266b98ac8a Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:28:04 +0100 Subject: [PATCH 18/26] Fix linker error --- src/fin/fin_interface.cpp | 4 ++++ src/include/miopen/fin/fin_interface.hpp | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index 6cb739ab50..c961b5c274 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -167,6 +167,10 @@ bool SolverMixin::TestPerfCfgParams(const Context& ctx, MIOPEN_THROW(miopenStatusNotImplemented); } +// Explicit instantiation +template class SolverMixin; +template class SolverMixin; + // ================== ConvSolver ================== ConvSolver::ConvSolver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id, diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 737c2ea11b..16dbb3f2dc 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -57,7 +57,7 @@ namespace fin_interface { // ================== Interface for Fin ================== // Base classes for solvers. -class MIOPEN_INTERNALS_EXPORT Solver +class Solver { public: // GetId(), IsDynamic() and IsTunable() throw miopenStatusNotInitialized if the solver is not @@ -116,8 +116,11 @@ class SolverMixin : public Solver using Solver::Solver; }; +extern template class SolverMixin; +extern template class SolverMixin; + // Convolution solver -class MIOPEN_INTERNALS_EXPORT ConvSolver +class MIOPEN_INTERNALS_EXPORT ConvSolver final : public SolverMixin { public: @@ -139,7 +142,7 @@ class MIOPEN_INTERNALS_EXPORT ConvSolver }; // Batch normalization solver -class MIOPEN_INTERNALS_EXPORT BatchNormSolver +class MIOPEN_INTERNALS_EXPORT BatchNormSolver final : public SolverMixin { protected: From 8fa4325f7e8ab4eb427e9b9d68c3023b1c47a888 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:29:16 +0100 Subject: [PATCH 19/26] Fix formatting --- test/gtest/unit_FinInterface.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index 1c5e76a374..ef3e6fee2a 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -380,8 +380,8 @@ class TestGetSolvers : public ::testing::TestWithParam void RunTest() { const auto& solvers_info = GetSolversInfo(); - const auto& names = GetSolverNames(); - const auto solvers = InterfaceGetSolvers(names); + const auto& names = GetSolverNames(); + const auto solvers = InterfaceGetSolvers(names); ASSERT_EQ(solvers.size(), names.size()); for(const auto& solver : solvers) @@ -414,9 +414,9 @@ class TestGetSolver : public ::testing::TestWithParam; -using CPU_FinInterfaceTestGetConvSolvers_NONE = TestGetSolvers; -using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetSolver; +using CPU_FinInterfaceTestGetAllConvSolvers_NONE = TestGetAllSolvers; +using CPU_FinInterfaceTestGetConvSolvers_NONE = TestGetSolvers; +using CPU_FinInterfaceTestGetConvSolver_NONE = TestGetSolver; TEST_P(CPU_FinInterfaceTestGetAllConvSolvers_NONE, FinInterface) { this->RunTest(); }; TEST_P(CPU_FinInterfaceTestGetConvSolvers_NONE, FinInterface) { this->RunTest(); }; From c7525a9fc1f2babaa18daad4848c9dcef6676d8b Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Wed, 30 Oct 2024 19:59:59 +0100 Subject: [PATCH 20/26] Add test for IsApplicable() & GetWorkspaceSize() --- test/gtest/unit_FinInterface.cpp | 162 ++++++++++++++++++++++++++++++- test/gtest/unit_conv_solver.cpp | 41 ++++---- test/gtest/unit_conv_solver.hpp | 2 + 3 files changed, 179 insertions(+), 26 deletions(-) diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index ef3e6fee2a..88a745b1dd 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -31,8 +31,13 @@ #include +// batchnorm::ProblemDescription +#include #include +#include "get_handle.hpp" +#include "unit_conv_solver.hpp" + namespace { struct TestParams @@ -89,7 +94,83 @@ struct ConvSolverInfo : SolverInfo using BatchNormSolverInfo = SolverInfo; -template +struct SolverConfig +{ + SolverConfig() : empty(true) + { + } + SolverConfig(bool empty_) : empty(empty_) + { + } + + bool empty; +}; + +struct ConvSolverConfig : SolverConfig, private miopen::unit_tests::ConvTestCase +{ + ConvSolverConfig(miopen::conv::Direction direction_, + miopen::unit_tests::TensorDescriptorParams&& x, + miopen::unit_tests::TensorDescriptorParams&& w, + miopenDataType_t type_y, + miopen::unit_tests::ConvolutionDescriptorParams&& conv) : SolverConfig(false), miopen::unit_tests::ConvTestCase(std::move(x), std::move(w), type_y, std::move(conv)), direction(direction_) + { + } + + auto GetProblemDescription() const + { + return GetProblemDescription(direction); + } + + friend std::ostream& operator<<(std::ostream& os, const ConvSolverConfig& config) + { + os << "("; + if(config.empty) + { + os << "empty"; + } + else + { + os << "direction:" << static_cast(config.direction); + os << ", " << static_cast(config); + } + os << ")"; + return os; + } + +private: + miopen::conv::Direction direction; + + using SolverConfig::SolverConfig; + using miopen::unit_tests::ConvTestCase::GetProblemDescription; +}; + +struct BatchNormSolverConfig : SolverConfig +{ + BatchNormSolverConfig(int dummy) : SolverConfig(false) + { + std::ignore = dummy; + } + + auto GetProblemDescription() const + { + return miopen::batchnorm::ProblemDescription{{}, {}, {}, {}, {}, {}, {}, {}}; + } + + friend std::ostream& operator<<(std::ostream& os, const BatchNormSolverConfig& config) + { + os << "("; + if(config.empty) + os << "empty"; + else + os << "none"; + os << ")"; + return os; + } + + using SolverConfig::SolverConfig; +}; + +template struct TestCase { friend std::ostream& operator<<(std::ostream& os, const TestCase& tc) @@ -97,16 +178,18 @@ struct TestCase os << "("; os << "name:" << tc.name; os << ", info:" << tc.info; + os << ", config:" << tc.config; os << ")"; return os; } std::string name; Info info; + SolverConfig config; }; -using ConvTestCase = TestCase; -using BatchNormTestCase = TestCase; +using ConvTestCase = TestCase; +using BatchNormTestCase = TestCase; const auto& GetTestParams() { @@ -253,6 +336,36 @@ const auto& GetSolversInfo() return solver_info; } +template +const auto& GetSolverConfigs(); + +template <> +const auto& GetSolverConfigs() +{ + static const std::unordered_map configs = { + // clang-format off + {"ConvDirectNaiveConvFwd", {miopen::conv::Direction::Forward, {miopenFloat, {1, 16, 14, 14}}, {miopenFloat, {48, 16, 5, 5}}, miopenFloat, {{2, 2}, {1, 1}, {1, 1}}}}, + {"ConvDirectNaiveConvBwd", {miopen::conv::Direction::BackwardData, {miopenFloat, {1, 16, 14, 14}}, {miopenFloat, {48, 16, 5, 5}}, miopenFloat, {{2, 2}, {1, 1}, {1, 1}}}}, + {"ConvDirectNaiveConvWrw", {miopen::conv::Direction::BackwardWeights, {miopenFloat, {1, 16, 14, 14}}, {miopenFloat, {48, 16, 5, 5}}, miopenFloat, {{2, 2}, {1, 1}, {1, 1}}}}, + // clang-format on + }; + + return configs; +} + +template <> +const auto& GetSolverConfigs() +{ + static const std::unordered_map configs = { + // clang-format off + /// \todo add configs + {"DummySolver", {42}}, + // clang-format on + }; + + return configs; +} + template const auto& GetSolverNames() { @@ -273,14 +386,40 @@ const auto& GetTestCases() static const auto test_cases = [] { std::vector test_cases; const auto& sinfo = GetSolversInfo(); + const auto& configs = GetSolverConfigs(); test_cases.reserve(sinfo.size()); for(const auto& s : sinfo) - test_cases.emplace_back(TestCase{s.first, s.second}); + { + const auto& config = configs.find(s.first); + if(config == configs.end()) + test_cases.emplace_back(TestCase{s.first, s.second, {}}); + else + test_cases.emplace_back(TestCase{s.first, s.second, config->second}); + } return test_cases; }(); return test_cases; } +// Context +template +auto GetContext(miopen::Handle* handle, const Problem& problem); + +template <> +auto GetContext(miopen::Handle* handle, const miopen::conv::ProblemDescription& problem) +{ + auto tmp = miopen::ExecutionContext{handle}; + problem.SetupFloats(tmp); + return tmp; +} + +template <> +auto GetContext(miopen::Handle* handle, const miopen::batchnorm::ProblemDescription&) +{ + auto tmp = miopen::ExecutionContext{handle}; + return tmp; +} + // Checks template void CheckSolverInfo(const Solver& solver, const SolverInfo& info) @@ -292,12 +431,25 @@ void CheckSolverInfo(const Solver& solver, const SolverInfo& info) ASSERT_EQ(solver.GetAlgo(miopen::conv::Direction::Forward), info.algo); } +template +void CheckSolverConfig(const Solver& solver, const SolverConfig& config) +{ + auto&& handle = get_handle(); + const auto problem = config.GetProblemDescription(); + const auto ctx = GetContext(&handle, problem); + + ASSERT_EQ(solver.IsApplicable(ctx, problem), true); + std::ignore = solver.GetWorkspaceSize(ctx, problem); +} + template void CheckSolver(const Solver& solver, const TestCase& test_case) { ASSERT_EQ(solver.GetName(), test_case.name); ASSERT_EQ(solver.IsValid(), true); ASSERT_NO_FATAL_FAILURE(CheckSolverInfo(solver, test_case.info)); + if(!test_case.config.empty) + ASSERT_NO_FATAL_FAILURE(CheckSolverConfig(solver, test_case.config)); } // GetAll*Solvers() @@ -393,7 +545,7 @@ class TestGetSolvers : public ::testing::TestWithParam const std::string error = name + " not found"; GTEST_FAIL() << error; } - ASSERT_NO_FATAL_FAILURE(CheckSolver(solver, TestCase{name, solver_info->second})); + ASSERT_NO_FATAL_FAILURE(CheckSolver(solver, TestCase{name, solver_info->second, {}})); } } }; diff --git a/test/gtest/unit_conv_solver.cpp b/test/gtest/unit_conv_solver.cpp index ffc1f99025..87166632ee 100644 --- a/test/gtest/unit_conv_solver.cpp +++ b/test/gtest/unit_conv_solver.cpp @@ -73,26 +73,6 @@ bool IsDeviceSupported(Gpu supported_devs, Gpu dev) return false; } -miopen::conv::ProblemDescription GetProblemDescription(miopen::conv::Direction direction, - const ConvTestCase& conv_config) -{ - const auto x_desc = conv_config.GetXTensorDescriptor(); - const auto w_desc = conv_config.GetWTensorDescriptor(); - const auto conv_desc = conv_config.GetConv(); - const auto y_desc = - conv_desc.GetForwardOutputTensor(x_desc, w_desc, conv_config.GetYDataType()); - - switch(direction) - { - case miopen::conv::Direction::Forward: - case miopen::conv::Direction::BackwardData: - return miopen::conv::ProblemDescription(x_desc, w_desc, y_desc, conv_desc, direction); - case miopen::conv::Direction::BackwardWeights: - return miopen::conv::ProblemDescription(y_desc, w_desc, x_desc, conv_desc, direction); - default: throw std::runtime_error("unknown direction"); - } -} - } // namespace //************************************************************************************ @@ -171,6 +151,25 @@ miopen::ConvolutionDescriptor ConvTestCase::GetConv() const return conv.GetConvolutionDescriptor(); } +miopen::conv::ProblemDescription ConvTestCase::GetProblemDescription(miopen::conv::Direction direction) const +{ + const auto x_desc = GetXTensorDescriptor(); + const auto w_desc = GetWTensorDescriptor(); + const auto conv_desc = GetConv(); + const auto y_desc = + conv_desc.GetForwardOutputTensor(x_desc, w_desc, GetYDataType()); + + switch(direction) + { + case miopen::conv::Direction::Forward: + case miopen::conv::Direction::BackwardData: + return miopen::conv::ProblemDescription(x_desc, w_desc, y_desc, conv_desc, direction); + case miopen::conv::Direction::BackwardWeights: + return miopen::conv::ProblemDescription(y_desc, w_desc, x_desc, conv_desc, direction); + default: throw std::runtime_error("unknown direction"); + } +} + std::ostream& operator<<(std::ostream& os, const ConvTestCase& tc) { os << "("; @@ -732,7 +731,7 @@ void UnitTestConvSolverDevApplicabilityBase::RunTestImpl( deprecated_solv_enabler.Enable(); } - const auto problem = GetProblemDescription(direction, conv_config); + const auto problem = conv_config.GetProblemDescription(direction); const auto all_known_devs = GetAllKnownDevices(); for(const auto& [dev, dev_descr] : all_known_devs) diff --git a/test/gtest/unit_conv_solver.hpp b/test/gtest/unit_conv_solver.hpp index 8ab181906e..e048692f9b 100644 --- a/test/gtest/unit_conv_solver.hpp +++ b/test/gtest/unit_conv_solver.hpp @@ -26,6 +26,7 @@ #pragma once #include +#include #include "gtest_common.hpp" #include "unit_conv_ConvolutionDescriptor.hpp" @@ -71,6 +72,7 @@ struct ConvTestCase miopenDataType_t GetYDataType() const; miopen::ConvolutionDescriptor GetConv() const; + miopen::conv::ProblemDescription GetProblemDescription(miopen::conv::Direction direction) const; friend std::ostream& operator<<(std::ostream& os, const ConvTestCase& tc); From 6da9a98a90821d5a6cfac66e0fc875723b71fcf7 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Wed, 30 Oct 2024 20:02:53 +0100 Subject: [PATCH 21/26] Fix formatting --- test/gtest/unit_FinInterface.cpp | 31 ++++++++++++------------------- test/gtest/unit_conv_solver.cpp | 6 +++--- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index 88a745b1dd..4854162a82 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -96,12 +96,8 @@ using BatchNormSolverInfo = SolverInfo; struct SolverConfig { - SolverConfig() : empty(true) - { - } - SolverConfig(bool empty_) : empty(empty_) - { - } + SolverConfig() : empty(true) {} + SolverConfig(bool empty_) : empty(empty_) {} bool empty; }; @@ -112,14 +108,14 @@ struct ConvSolverConfig : SolverConfig, private miopen::unit_tests::ConvTestCase miopen::unit_tests::TensorDescriptorParams&& x, miopen::unit_tests::TensorDescriptorParams&& w, miopenDataType_t type_y, - miopen::unit_tests::ConvolutionDescriptorParams&& conv) : SolverConfig(false), miopen::unit_tests::ConvTestCase(std::move(x), std::move(w), type_y, std::move(conv)), direction(direction_) + miopen::unit_tests::ConvolutionDescriptorParams&& conv) + : SolverConfig(false), + miopen::unit_tests::ConvTestCase(std::move(x), std::move(w), type_y, std::move(conv)), + direction(direction_) { } - auto GetProblemDescription() const - { - return GetProblemDescription(direction); - } + auto GetProblemDescription() const { return GetProblemDescription(direction); } friend std::ostream& operator<<(std::ostream& os, const ConvSolverConfig& config) { @@ -140,16 +136,13 @@ struct ConvSolverConfig : SolverConfig, private miopen::unit_tests::ConvTestCase private: miopen::conv::Direction direction; - using SolverConfig::SolverConfig; using miopen::unit_tests::ConvTestCase::GetProblemDescription; + using SolverConfig::SolverConfig; }; struct BatchNormSolverConfig : SolverConfig { - BatchNormSolverConfig(int dummy) : SolverConfig(false) - { - std::ignore = dummy; - } + BatchNormSolverConfig(int dummy) : SolverConfig(false) { std::ignore = dummy; } auto GetProblemDescription() const { @@ -385,7 +378,7 @@ const auto& GetTestCases() { static const auto test_cases = [] { std::vector test_cases; - const auto& sinfo = GetSolversInfo(); + const auto& sinfo = GetSolversInfo(); const auto& configs = GetSolverConfigs(); test_cases.reserve(sinfo.size()); for(const auto& s : sinfo) @@ -434,9 +427,9 @@ void CheckSolverInfo(const Solver& solver, const SolverInfo& info) template void CheckSolverConfig(const Solver& solver, const SolverConfig& config) { - auto&& handle = get_handle(); + auto&& handle = get_handle(); const auto problem = config.GetProblemDescription(); - const auto ctx = GetContext(&handle, problem); + const auto ctx = GetContext(&handle, problem); ASSERT_EQ(solver.IsApplicable(ctx, problem), true); std::ignore = solver.GetWorkspaceSize(ctx, problem); diff --git a/test/gtest/unit_conv_solver.cpp b/test/gtest/unit_conv_solver.cpp index 87166632ee..5ba36e6ad2 100644 --- a/test/gtest/unit_conv_solver.cpp +++ b/test/gtest/unit_conv_solver.cpp @@ -151,13 +151,13 @@ miopen::ConvolutionDescriptor ConvTestCase::GetConv() const return conv.GetConvolutionDescriptor(); } -miopen::conv::ProblemDescription ConvTestCase::GetProblemDescription(miopen::conv::Direction direction) const +miopen::conv::ProblemDescription +ConvTestCase::GetProblemDescription(miopen::conv::Direction direction) const { const auto x_desc = GetXTensorDescriptor(); const auto w_desc = GetWTensorDescriptor(); const auto conv_desc = GetConv(); - const auto y_desc = - conv_desc.GetForwardOutputTensor(x_desc, w_desc, GetYDataType()); + const auto y_desc = conv_desc.GetForwardOutputTensor(x_desc, w_desc, GetYDataType()); switch(direction) { From 9a3591e1ec99077566131ab917704ebd1a17617a Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:44:11 +0100 Subject: [PATCH 22/26] Fix windows build --- src/include/miopen/fin/fin_interface.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 16dbb3f2dc..569a249e0b 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -57,7 +57,7 @@ namespace fin_interface { // ================== Interface for Fin ================== // Base classes for solvers. -class Solver +class MIOPEN_INTERNALS_EXPORT Solver { public: // GetId(), IsDynamic() and IsTunable() throw miopenStatusNotInitialized if the solver is not From 279d4591f939a39d28bf467212b374b1699456af Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Wed, 13 Nov 2024 17:08:03 +0100 Subject: [PATCH 23/26] Implement methods --- src/fin/fin_interface.cpp | 453 +++++++++++++++++++++-- src/include/miopen/fin/fin_interface.hpp | 2 +- src/include/miopen/find_solution.hpp | 6 +- 3 files changed, 431 insertions(+), 30 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index c961b5c274..ea981487ef 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -24,16 +24,434 @@ * *******************************************************************************/ +#include #include +#include #include +#include #include +#include +#include #include #include namespace miopen { namespace fin_interface { +// ================== AnySolver ================== +// This class is an ugly trick. The only reason for it is that each tunable solver has its own PerformanceConfig. Needs to be refactored in the future. Although, it might be worth leaving it as is to be able to add things needed for tuning infrastructure. +template +class AnySolver +{ +public: + AnySolver(const miopen::solver::SolverBase* solver_base, uint64_t id); + + /// \todo Move to SolverInterface + miopen::solver::ConvSolution FindSolution(const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const + { + assert(obj != nullptr); + return obj->FindSolution(sbase, ctx, problem, db, invoke_ctx, perf_cfg); + } + + std::vector GetAllSolutions(const Context& ctx, const Problem& problem) const + { + assert(obj != nullptr); + return obj->GetAllSolutions(sbase, ctx, problem); + } + + std::string GetPerfCfgParams(const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db) const + { + assert(obj != nullptr); + return obj->GetPerfCfgParams(sbase, ctx, problem, db); + } + + bool TestPerfCfgParams(const Context& ctx, + const Problem& problem, + const std::string& params) const + { + assert(obj != nullptr); + return obj->TestPerfCfgParams(sbase, ctx, problem, params); + } + +private: + using NonTunableSolver = miopen::solver::SolverInterfaceNonTunable; + + // Virtual base class + class AnySolver_base + { + public: + virtual ~AnySolver_base() = default; + virtual miopen::solver::ConvSolution FindSolution(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const = 0; + virtual std::vector GetAllSolutions(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem) const = 0; + virtual std::string GetPerfCfgParams(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db) const = 0; + virtual bool TestPerfCfgParams(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + const std::string& params) const = 0; + }; + + // Templated derived class + template + class AnySolver_impl final : public AnySolver_base + { + miopen::solver::ConvSolution FindSolution(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const override + { + const auto& solver = GetSolver(solver_base); + return miopen::solver::FindSolution(solver, ctx, problem, db, invoke_ctx, perf_cfg); + } + + std::vector GetAllSolutions(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem) const override + { + const auto& solver = GetSolver(solver_base); + if constexpr(std::is_same_v) + { + // Non-tunable solver + return {solver.GetSolution(ctx, problem)}; + } + else + { + // Tunable solver + using PerformanceConfig = decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); + if constexpr(std::is_same_v) + { + // Legacy tunable solver + MIOPEN_THROW("No solutions returned for Legacy Solvers."); + } + else + { + return miopen::solver::GetAllSolutions(solver, ctx, problem); + } + } + } + + std::string GetPerfCfgParams(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db) const override + { + const auto& solver = GetSolver(solver_base); + + if constexpr(std::is_same_v) + { + // Non-tunable solver + MIOPEN_LOG_I2("PerformanceDb: No Config: " << solver.SolverDbId()); + return {}; + } + else + { + // Tunable solver + using PerformanceConfig = decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); + PerformanceConfig config; + + if(db.Load(problem, solver.SolverDbId(), config)) + { + MIOPEN_LOG_I2("PerformanceDb: Record Loaded: " << solver.SolverDbId()); + if(solver.IsValidPerformanceConfig(ctx, problem, config)) + { + return config.ToString(); + } + MIOPEN_LOG_I2("PerformanceDb: Invalid Config: " << solver.SolverDbId()); + } + else if(!solver.AltSolverDbId().empty() && + db.Load(problem, solver.AltSolverDbId(), config)) + { + MIOPEN_LOG_I("PerformanceDb: alternate record loaded: " << solver.AltSolverDbId()); + if(solver.IsValidPerformanceConfig(ctx, problem, config)) + { + return config.ToString(); + } + MIOPEN_LOG_I2("PerformanceDb: Invalid alternate record: " << solver.AltSolverDbId() + << ": " << config); + } + + MIOPEN_LOG_I2("PerformanceDb: Failed Loading, Using Default: " << solver.SolverDbId()); + config = solver.GetDefaultPerformanceConfig(ctx, problem); + return config.ToString(); + } + } + + bool TestPerfCfgParams(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + const std::string& params) const override + { + if constexpr(std::is_same_v) + { + // Non-tunable solver + return false; + } + else + { + // Tunable solver + const auto& solver = GetSolver(solver_base); + using PerformanceConfig = decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); + PerformanceConfig config; + + if(!config.Deserialize(params)) + { + MIOPEN_LOG_WE("Perf params are obsolete or corrupt: " + << params << ". Performance may degrade."); + return false; + } + + return solver.IsValidPerformanceConfig(ctx, problem, config); + } + } + + const auto& GetSolver(const miopen::solver::SolverBase* solver_base) const + { + return *static_cast(solver_base); + } + }; + + template + void SetObject() + { + // Test cast + const T* ptr = dynamic_cast(sbase); + if(ptr == nullptr) + MIOPEN_THROW(miopenStatusInternalError); + + static const AnySolver_impl impl; + obj = &impl; + } + + void SetObjectNonTunable() + { + SetObject(); + } + + const AnySolver_base* obj = nullptr; + const miopen::solver::SolverBase* const sbase; +}; + +template <> +AnySolver::AnySolver(const miopen::solver::SolverBase* solver_base, uint64_t id) : sbase(solver_base) +{ + if(!sbase->IsTunable()) + { + SetObjectNonTunable(); + return; + } + + switch(id) + { + case 1: + SetObject(); + break; + case 2: + SetObject(); + break; + case 3: + SetObject(); + break; + case 11: + SetObject(); + break; + case 13: + SetObject(); + break; + case 16: + SetObject(); + break; + case 17: + SetObject(); + break; + case 18: + SetObject>(); + break; + case 19: + SetObject>(); + break; + case 20: + SetObject>(); + break; + case 21: + SetObject>(); + break; + case 22: + SetObject>(); + break; + case 26: + SetObject(); + break; + case 31: + SetObject(); + break; + case 37: + SetObject>(); + break; + case 53: + SetObject>(); + break; + case 54: + SetObject(); + break; + case 55: + SetObject(); + break; + case 56: + SetObject(); + break; + case 57: + SetObject(); + break; + case 60: + SetObject(); + break; + case 61: + SetObject(); + break; + case 64: + SetObject(); + break; + case 73: + SetObject(); + break; + case 75: + SetObject>(); + break; + case 76: + SetObject>(); + break; + case 77: + SetObject>(); + break; + case 78: + SetObject>(); + break; + case 79: + SetObject>(); + break; + case 80: + SetObject(); + break; + case 81: + SetObject(); + break; + case 83: + SetObject(); + break; + case 98: + SetObject(); + break; + case 99: + SetObject(); + break; + case 100: + SetObject(); + break; + case 103: + SetObject(); + break; + case 104: + SetObject(); + break; + case 105: + SetObject(); + break; + case 107: + SetObject(); + break; + case 108: + SetObject(); + break; + case 110: + SetObject(); + break; + case 114: + SetObject(); + break; + case 127: + SetObject(); + break; + case 128: + SetObject(); + break; + case 129: + SetObject(); + break; + case 137: + SetObject(); + break; + case 138: + SetObject(); + break; + case 140: + SetObject(); + break; + case 141: + SetObject(); + break; + case 149: + SetObject(); + break; + case 150: + SetObject(); + break; + case 151: + SetObject(); + break; + case 155: + SetObject(); + break; + case 156: + SetObject(); + break; + default: + MIOPEN_THROW(miopenStatusInternalError, "Unknown solver ID"); + } +} + +template <> +AnySolver::AnySolver(const miopen::solver::SolverBase* solver_base, uint64_t id) : sbase(solver_base) +{ + if(!sbase->IsTunable()) + { + SetObjectNonTunable(); + return; + } + + switch(id) + { + case 142: + SetObject(); + break; + case 143: + SetObject(); + break; + case 144: + SetObject(); + break; + default: + MIOPEN_THROW(miopenStatusInternalError, "Unknown solver ID"); + } +} + // ================== Solver ================== Solver::Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id) : sbase(solver_base), id(solver_id) @@ -108,47 +526,34 @@ SolverMixin::FindSolution(const Context& ctx, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg) const { - std::ignore = ctx; - std::ignore = problem; - std::ignore = db; - std::ignore = invoke_ctx; - std::ignore = perf_cfg; - if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - /// \todo - MIOPEN_THROW(miopenStatusNotImplemented); + const auto solver = AnySolver(sbase, id); + return solver.FindSolution(ctx, problem, db, invoke_ctx, perf_cfg); } template std::vector SolverMixin::GetAllSolutions(const Context& ctx, const Problem& problem) const { - std::ignore = ctx; - std::ignore = problem; - if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - /// \todo - MIOPEN_THROW(miopenStatusNotImplemented); + const auto solver = AnySolver(sbase, id); + return solver.GetAllSolutions(ctx, problem); } template std::string SolverMixin::GetPerfCfgParams(const Context& ctx, const Problem& problem, - const PerformanceDb& db) const + miopen::PerformanceDb& db) const { - std::ignore = ctx; - std::ignore = problem; - std::ignore = db; - if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - /// \todo - MIOPEN_THROW(miopenStatusNotImplemented); + const auto solver = AnySolver(sbase, id); + return solver.GetPerfCfgParams(ctx, problem, db); } template @@ -156,15 +561,11 @@ bool SolverMixin::TestPerfCfgParams(const Context& ctx, const Problem& problem, const std::string& params) const { - std::ignore = ctx; - std::ignore = problem; - std::ignore = params; - if(sbase == nullptr) MIOPEN_THROW(miopenStatusNotInitialized); - /// \todo - MIOPEN_THROW(miopenStatusNotImplemented); + const auto solver = AnySolver(sbase, id); + return solver.TestPerfCfgParams(ctx, problem, params); } // Explicit instantiation diff --git a/src/include/miopen/fin/fin_interface.hpp b/src/include/miopen/fin/fin_interface.hpp index 569a249e0b..6e4a93c27e 100644 --- a/src/include/miopen/fin/fin_interface.hpp +++ b/src/include/miopen/fin/fin_interface.hpp @@ -107,7 +107,7 @@ class SolverMixin : public Solver const Problem& problem) const; std::string - GetPerfCfgParams(const Context& ctx, const Problem& problem, const PerformanceDb& db) const; + GetPerfCfgParams(const Context& ctx, const Problem& problem, miopen::PerformanceDb& db) const; bool TestPerfCfgParams(const Context& ctx, const Problem& problem, const std::string& params) const; diff --git a/src/include/miopen/find_solution.hpp b/src/include/miopen/find_solution.hpp index 9527e50aae..fa04b46d20 100644 --- a/src/include/miopen/find_solution.hpp +++ b/src/include/miopen/find_solution.hpp @@ -52,7 +52,7 @@ namespace solver { template auto FindSolutionImpl(rank<1>, - Solver s, + const Solver& s, const Context& context, const Problem& problem, Db&& db, @@ -148,7 +148,7 @@ auto FindSolutionImpl(rank<1>, template auto FindSolutionImpl(rank<0>, - Solver s, + const Solver& s, const Context& context, const Problem& problem, Db&&, @@ -200,7 +200,7 @@ auto GetInvokeFactoryImpl( /// Could take long if an exhaustive search is requested/performed. /// May read/write perfDb. template -ConvSolution FindSolution(Solver s, +ConvSolution FindSolution(const Solver& s, const Context& context, const Problem& problem, Db&& db, From 0ff5960a9faa39ed0e695daafbc12bb4194d494b Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Wed, 13 Nov 2024 17:11:18 +0100 Subject: [PATCH 24/26] Fix formatting --- src/fin/fin_interface.cpp | 313 +++++++++++++------------------------- 1 file changed, 105 insertions(+), 208 deletions(-) diff --git a/src/fin/fin_interface.cpp b/src/fin/fin_interface.cpp index ea981487ef..e665b36ea1 100644 --- a/src/fin/fin_interface.cpp +++ b/src/fin/fin_interface.cpp @@ -40,7 +40,9 @@ namespace miopen { namespace fin_interface { // ================== AnySolver ================== -// This class is an ugly trick. The only reason for it is that each tunable solver has its own PerformanceConfig. Needs to be refactored in the future. Although, it might be worth leaving it as is to be able to add things needed for tuning infrastructure. +// This class is an ugly trick. The only reason for it is that each tunable solver has its own +// PerformanceConfig. Needs to be refactored in the future. Although, it might be worth leaving it +// as is to be able to add things needed for tuning infrastructure. template class AnySolver { @@ -58,23 +60,22 @@ class AnySolver return obj->FindSolution(sbase, ctx, problem, db, invoke_ctx, perf_cfg); } - std::vector GetAllSolutions(const Context& ctx, const Problem& problem) const + std::vector GetAllSolutions(const Context& ctx, + const Problem& problem) const { assert(obj != nullptr); return obj->GetAllSolutions(sbase, ctx, problem); } - std::string GetPerfCfgParams(const Context& ctx, - const Problem& problem, - miopen::PerformanceDb& db) const + std::string + GetPerfCfgParams(const Context& ctx, const Problem& problem, miopen::PerformanceDb& db) const { assert(obj != nullptr); return obj->GetPerfCfgParams(sbase, ctx, problem, db); } - bool TestPerfCfgParams(const Context& ctx, - const Problem& problem, - const std::string& params) const + bool + TestPerfCfgParams(const Context& ctx, const Problem& problem, const std::string& params) const { assert(obj != nullptr); return obj->TestPerfCfgParams(sbase, ctx, problem, params); @@ -88,15 +89,17 @@ class AnySolver { public: virtual ~AnySolver_base() = default; - virtual miopen::solver::ConvSolution FindSolution(const miopen::solver::SolverBase* solver_base, - const Context& ctx, - const Problem& problem, - miopen::PerformanceDb& db, - const miopen::AnyInvokeParams& invoke_ctx, - const std::string& perf_cfg) const = 0; - virtual std::vector GetAllSolutions(const miopen::solver::SolverBase* solver_base, - const Context& ctx, - const Problem& problem) const = 0; + virtual miopen::solver::ConvSolution + FindSolution(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem, + miopen::PerformanceDb& db, + const miopen::AnyInvokeParams& invoke_ctx, + const std::string& perf_cfg) const = 0; + virtual std::vector + GetAllSolutions(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem) const = 0; virtual std::string GetPerfCfgParams(const miopen::solver::SolverBase* solver_base, const Context& ctx, const Problem& problem, @@ -104,7 +107,7 @@ class AnySolver virtual bool TestPerfCfgParams(const miopen::solver::SolverBase* solver_base, const Context& ctx, const Problem& problem, - const std::string& params) const = 0; + const std::string& params) const = 0; }; // Templated derived class @@ -122,9 +125,10 @@ class AnySolver return miopen::solver::FindSolution(solver, ctx, problem, db, invoke_ctx, perf_cfg); } - std::vector GetAllSolutions(const miopen::solver::SolverBase* solver_base, - const Context& ctx, - const Problem& problem) const override + std::vector + GetAllSolutions(const miopen::solver::SolverBase* solver_base, + const Context& ctx, + const Problem& problem) const override { const auto& solver = GetSolver(solver_base); if constexpr(std::is_same_v) @@ -135,8 +139,10 @@ class AnySolver else { // Tunable solver - using PerformanceConfig = decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); - if constexpr(std::is_same_v) + using PerformanceConfig = + decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); + if constexpr(std::is_same_v) { // Legacy tunable solver MIOPEN_THROW("No solutions returned for Legacy Solvers."); @@ -164,7 +170,8 @@ class AnySolver else { // Tunable solver - using PerformanceConfig = decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); + using PerformanceConfig = + decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); PerformanceConfig config; if(db.Load(problem, solver.SolverDbId(), config)) @@ -179,16 +186,18 @@ class AnySolver else if(!solver.AltSolverDbId().empty() && db.Load(problem, solver.AltSolverDbId(), config)) { - MIOPEN_LOG_I("PerformanceDb: alternate record loaded: " << solver.AltSolverDbId()); + MIOPEN_LOG_I( + "PerformanceDb: alternate record loaded: " << solver.AltSolverDbId()); if(solver.IsValidPerformanceConfig(ctx, problem, config)) { return config.ToString(); } - MIOPEN_LOG_I2("PerformanceDb: Invalid alternate record: " << solver.AltSolverDbId() - << ": " << config); + MIOPEN_LOG_I2("PerformanceDb: Invalid alternate record: " + << solver.AltSolverDbId() << ": " << config); } - MIOPEN_LOG_I2("PerformanceDb: Failed Loading, Using Default: " << solver.SolverDbId()); + MIOPEN_LOG_I2( + "PerformanceDb: Failed Loading, Using Default: " << solver.SolverDbId()); config = solver.GetDefaultPerformanceConfig(ctx, problem); return config.ToString(); } @@ -208,13 +217,14 @@ class AnySolver { // Tunable solver const auto& solver = GetSolver(solver_base); - using PerformanceConfig = decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); + using PerformanceConfig = + decltype(solver.GetDefaultPerformanceConfig(ctx, problem)); PerformanceConfig config; if(!config.Deserialize(params)) { MIOPEN_LOG_WE("Perf params are obsolete or corrupt: " - << params << ". Performance may degrade."); + << params << ". Performance may degrade."); return false; } @@ -240,17 +250,16 @@ class AnySolver obj = &impl; } - void SetObjectNonTunable() - { - SetObject(); - } + void SetObjectNonTunable() { SetObject(); } const AnySolver_base* obj = nullptr; const miopen::solver::SolverBase* const sbase; }; template <> -AnySolver::AnySolver(const miopen::solver::SolverBase* solver_base, uint64_t id) : sbase(solver_base) +AnySolver::AnySolver( + const miopen::solver::SolverBase* solver_base, uint64_t id) + : sbase(solver_base) { if(!sbase->IsTunable()) { @@ -260,175 +269,70 @@ AnySolver::AnySolver switch(id) { - case 1: - SetObject(); - break; - case 2: - SetObject(); - break; - case 3: - SetObject(); - break; - case 11: - SetObject(); - break; - case 13: - SetObject(); - break; - case 16: - SetObject(); - break; - case 17: - SetObject(); - break; - case 18: - SetObject>(); - break; - case 19: - SetObject>(); - break; - case 20: - SetObject>(); - break; - case 21: - SetObject>(); - break; - case 22: - SetObject>(); - break; - case 26: - SetObject(); - break; - case 31: - SetObject(); - break; - case 37: - SetObject>(); - break; - case 53: - SetObject>(); - break; - case 54: - SetObject(); - break; - case 55: - SetObject(); - break; - case 56: - SetObject(); - break; - case 57: - SetObject(); - break; - case 60: - SetObject(); - break; - case 61: - SetObject(); - break; - case 64: - SetObject(); - break; - case 73: - SetObject(); - break; - case 75: - SetObject>(); - break; - case 76: - SetObject>(); - break; - case 77: - SetObject>(); - break; - case 78: - SetObject>(); - break; - case 79: - SetObject>(); - break; - case 80: - SetObject(); - break; + case 1: SetObject(); break; + case 2: SetObject(); break; + case 3: SetObject(); break; + case 11: SetObject(); break; + case 13: SetObject(); break; + case 16: SetObject(); break; + case 17: SetObject(); break; + case 18: SetObject>(); break; + case 19: SetObject>(); break; + case 20: SetObject>(); break; + case 21: SetObject>(); break; + case 22: SetObject>(); break; + case 26: SetObject(); break; + case 31: SetObject(); break; + case 37: SetObject>(); break; + case 53: SetObject>(); break; + case 54: SetObject(); break; + case 55: SetObject(); break; + case 56: SetObject(); break; + case 57: SetObject(); break; + case 60: SetObject(); break; + case 61: SetObject(); break; + case 64: SetObject(); break; + case 73: SetObject(); break; + case 75: SetObject>(); break; + case 76: SetObject>(); break; + case 77: SetObject>(); break; + case 78: SetObject>(); break; + case 79: SetObject>(); break; + case 80: SetObject(); break; case 81: SetObject(); break; - case 83: - SetObject(); - break; - case 98: - SetObject(); - break; - case 99: - SetObject(); - break; - case 100: - SetObject(); - break; - case 103: - SetObject(); - break; - case 104: - SetObject(); - break; - case 105: - SetObject(); - break; - case 107: - SetObject(); - break; - case 108: - SetObject(); - break; - case 110: - SetObject(); - break; - case 114: - SetObject(); - break; - case 127: - SetObject(); - break; - case 128: - SetObject(); - break; - case 129: - SetObject(); - break; - case 137: - SetObject(); - break; - case 138: - SetObject(); - break; - case 140: - SetObject(); - break; - case 141: - SetObject(); - break; - case 149: - SetObject(); - break; - case 150: - SetObject(); - break; - case 151: - SetObject(); - break; - case 155: - SetObject(); - break; - case 156: - SetObject(); - break; - default: - MIOPEN_THROW(miopenStatusInternalError, "Unknown solver ID"); + case 83: SetObject(); break; + case 98: SetObject(); break; + case 99: SetObject(); break; + case 100: SetObject(); break; + case 103: SetObject(); break; + case 104: SetObject(); break; + case 105: SetObject(); break; + case 107: SetObject(); break; + case 108: SetObject(); break; + case 110: SetObject(); break; + case 114: SetObject(); break; + case 127: SetObject(); break; + case 128: SetObject(); break; + case 129: SetObject(); break; + case 137: SetObject(); break; + case 138: SetObject(); break; + case 140: SetObject(); break; + case 141: SetObject(); break; + case 149: SetObject(); break; + case 150: SetObject(); break; + case 151: SetObject(); break; + case 155: SetObject(); break; + case 156: SetObject(); break; + default: MIOPEN_THROW(miopenStatusInternalError, "Unknown solver ID"); } } template <> -AnySolver::AnySolver(const miopen::solver::SolverBase* solver_base, uint64_t id) : sbase(solver_base) +AnySolver::AnySolver( + const miopen::solver::SolverBase* solver_base, uint64_t id) + : sbase(solver_base) { if(!sbase->IsTunable()) { @@ -438,17 +342,10 @@ AnySolver::AnyS switch(id) { - case 142: - SetObject(); - break; - case 143: - SetObject(); - break; - case 144: - SetObject(); - break; - default: - MIOPEN_THROW(miopenStatusInternalError, "Unknown solver ID"); + case 142: SetObject(); break; + case 143: SetObject(); break; + case 144: SetObject(); break; + default: MIOPEN_THROW(miopenStatusInternalError, "Unknown solver ID"); } } From f67b5b544f491a694ba8445ee86c4493116362c7 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:33:31 +0100 Subject: [PATCH 25/26] Add tests for GetAllSolutions(), GetPerfCfgParams() & TestPerfCfgParams() --- test/gtest/unit_FinInterface.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/gtest/unit_FinInterface.cpp b/test/gtest/unit_FinInterface.cpp index 4854162a82..4e5fe55429 100644 --- a/test/gtest/unit_FinInterface.cpp +++ b/test/gtest/unit_FinInterface.cpp @@ -337,9 +337,13 @@ const auto& GetSolverConfigs() { static const std::unordered_map configs = { // clang-format off + // Non-tunable solvers {"ConvDirectNaiveConvFwd", {miopen::conv::Direction::Forward, {miopenFloat, {1, 16, 14, 14}}, {miopenFloat, {48, 16, 5, 5}}, miopenFloat, {{2, 2}, {1, 1}, {1, 1}}}}, {"ConvDirectNaiveConvBwd", {miopen::conv::Direction::BackwardData, {miopenFloat, {1, 16, 14, 14}}, {miopenFloat, {48, 16, 5, 5}}, miopenFloat, {{2, 2}, {1, 1}, {1, 1}}}}, {"ConvDirectNaiveConvWrw", {miopen::conv::Direction::BackwardWeights, {miopenFloat, {1, 16, 14, 14}}, {miopenFloat, {48, 16, 5, 5}}, miopenFloat, {{2, 2}, {1, 1}, {1, 1}}}}, + // Tunable solvers + {"ConvBinWinogradRxSf3x2", {miopen::conv::Direction::Forward, {miopenFloat, {1, 20, 20, 20}}, {miopenFloat, {20, 20, 3, 3}}, miopenFloat, {{1, 1}, {1, 1}, {1, 1}}}}, + {"ConvBinWinogradRxSf2x3", {miopen::conv::Direction::BackwardWeights, {miopenFloat, {1, 20, 20, 20}}, {miopenFloat, {20, 20, 3, 3}}, miopenFloat, {{1, 1}, {1, 1}, {1, 1}}}}, // clang-format on }; @@ -430,9 +434,20 @@ void CheckSolverConfig(const Solver& solver, const SolverConfig& config) auto&& handle = get_handle(); const auto problem = config.GetProblemDescription(); const auto ctx = GetContext(&handle, problem); + auto db = miopen::GetDb(ctx); - ASSERT_EQ(solver.IsApplicable(ctx, problem), true); + ASSERT_TRUE(solver.IsApplicable(ctx, problem)); std::ignore = solver.GetWorkspaceSize(ctx, problem); + + /// \todo test FindSolution() + + const auto solutions = solver.GetAllSolutions(ctx, problem); + ASSERT_GT(solutions.size(), 0); + + const auto pcfg_params = solver.GetPerfCfgParams(ctx, problem, db); + ASSERT_NE(pcfg_params.empty(), solver.IsTunable()); + + ASSERT_EQ(solver.TestPerfCfgParams(ctx, problem, pcfg_params), solver.IsTunable()); } template From b315df954bd1915483c42b493d9b79a7c32e594b Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:43:09 +0100 Subject: [PATCH 26/26] Registry: remove smart pointer --- src/solver.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/solver.cpp b/src/solver.cpp index c50e840a42..7b80aa99d9 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -141,7 +141,7 @@ struct IdRegistryEntry Primitive primitive = Primitive::Convolution; miopenConvAlgorithm_t convAlgo = miopenConvolutionAlgoDirect; AnySolver solver; - std::unique_ptr solver_base; + const SolverBase* solver_base = nullptr; }; struct IdRegistryData @@ -206,7 +206,7 @@ const SolverBase* Id::GetSolverBase() const const auto it = IdRegistry().value_to_entry.find(value); if(it == IdRegistry().value_to_entry.end()) return nullptr; - return it->second.solver_base.get(); + return it->second.solver_base; } std::string Id::GetAlgo(miopen::conv::Direction dir) const @@ -294,21 +294,21 @@ template inline void RegisterWithSolver(IdRegistryData& registry, uint64_t value, TSolver, miopenConvAlgorithm_t algo) { - auto solver_base = std::make_unique(); - if(!Register(registry, value, solver_base->SolverDbId(), algo)) + static const TSolver solver_base; + if(!Register(registry, value, solver_base.SolverDbId(), algo)) return; auto& entry = registry.value_to_entry.at(value); entry.solver = TSolver{}; - entry.solver_base = std::move(solver_base); + entry.solver_base = &solver_base; } template void RegisterWithSolver(IdRegistryData& registry, uint64_t value, Primitive primitive) { - auto solver_base = std::make_unique(); - if(!Register(registry, value, primitive, solver_base->SolverDbId())) + static const Solver solver_base; + if(!Register(registry, value, primitive, solver_base.SolverDbId())) return; - registry.value_to_entry.at(value).solver_base = std::move(solver_base); + registry.value_to_entry.at(value).solver_base = &solver_base; } inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)