Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal stopping criteria config #1613

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
}

if (config.get_tag() == pnode::tag_t::map) {
static std::map<std::string,
std::function<deferred_factory_parameter<
gko::stop::CriterionFactory>(
Expand All @@ -56,9 +58,89 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
{"ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("type").get_string())(config,
context, td);
} else {
GKO_INVALID_STATE("The data of config is not valid.");
}

GKO_INVALID_STATE(
"Criteria must either be defined as a string or an array.");
}


std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_minimal_criteria(const pnode& config, const registry& context,
const type_descriptor& td)
{
auto map_time = [](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode time_config{{{"time_limit", config.get("time")}}};
return configure_time(time_config, context, td);
};
auto map_iteration = [](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode iter_config{{{"max_iters", config.get("iteration")}}};
return configure_iter(iter_config, context, td);
};
auto create_residual_mapping = [](const std::string& key,
const std::string& baseline,
auto configure_fn) {
return std::make_pair(
key, [=](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode res_config{{{"baseline", pnode{baseline}},
{"reduction_factor", config.get(key)}}};
return configure_fn(res_config, context, td);
});
};
std::map<
std::string,
std::function<deferred_factory_parameter<gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"time", map_time},
{"iteration", map_iteration},
create_residual_mapping("relative_residual_norm", "rhs_norm",
configure_residual),
create_residual_mapping("initial_residual_norm", "initial_resnorm",
configure_residual),
create_residual_mapping("absolute_residual_norm", "absolute",
configure_residual),
create_residual_mapping("relative_implicit_residual_norm",
"rhs_norm", configure_implicit_residual),
create_residual_mapping("initial_implicit_residual_norm",
"initial_resnorm",
configure_implicit_residual),
create_residual_mapping("absolute_implicit_residual_norm",
"absolute", configure_implicit_residual)}};

std::vector<deferred_factory_parameter<const stop::CriterionFactory>> res;
for (const auto& it : config.get_map()) {
res.emplace_back(criterion_map.at(it.first)(config, context, td));
}
return res;
}


std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_or_get_criteria(const pnode& config, const registry& context,
const type_descriptor& td)
{
if (config.get_tag() == pnode::tag_t::array ||
(config.get_tag() == pnode::tag_t::map && config.get("type"))) {
return parse_or_get_factory_vector<const stop::CriterionFactory>(
config, context, td);
}

if (config.get_tag() == pnode::tag_t::map) {
return parse_minimal_criteria(config, context, td);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return parse_minimal_criteria(config, context, td);
auto updated = config::update_type(td);
return parse_minimal_criteria(config, context, updated);

It is to support no valuetype available outside

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean here. I can specify the value type of the residual nom criterion, as you can see in the tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's based on the type_descriptor from the outer loop.
I mean something like

"stop": {
  "value_type": "float64",
  "residual_norm": ...
}

in case no precision information from outside or want to specify certain precision for stop.

}

if (config.get_tag() == pnode::tag_t::string) {
return {detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string())};
}

GKO_INVALID_STATE(
"Criteria must either be defined as a string, an array,"
"or an map.");
}

} // namespace config
Expand Down
9 changes: 9 additions & 0 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* parse or get an std::vector of criteria.
* A stored single criterion will be converted to an std::vector.
*/
std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_or_get_criteria(const pnode& config, const registry& context,
const type_descriptor& td);


/**
* give a vector of factory by calling parse_or_get_factory.
*/
Expand Down
55 changes: 55 additions & 0 deletions core/test/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <ginkgo/core/stop/combined.hpp>
#include <ginkgo/core/stop/iteration.hpp>
#include <ginkgo/core/stop/residual_norm.hpp>
#include <ginkgo/core/stop/time.hpp>


#include "core/config/config_helper.hpp"
Expand Down Expand Up @@ -126,6 +127,60 @@ TEST_F(Config, GenerateObjectWithCustomBuild)
}


TEST_F(Config, GenerateCriteriaFromMinimalConfig)
{
auto reg = registry();
reg.emplace("precond", this->mtx);
pnode minimal_stop{{
{"iteration", pnode{10}},
{"relative_implicit_residual_norm", pnode{0.01}},
{"relative_residual_norm", pnode{0.01}},
{"time", pnode{100}},
}};

pnode p{{{"criteria", minimal_stop}}};
auto obj = std::dynamic_pointer_cast<gko::solver::Cg<float>::Factory>(
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float32", "void"})
.on(this->exec));

ASSERT_NE(obj, nullptr);
auto criteria = obj->get_parameters().criteria;
ASSERT_EQ(criteria.size(), minimal_stop.get_map().size());
{
SCOPED_TRACE("Iteration Criterion");
auto it =
std::dynamic_pointer_cast<const gko::stop::Iteration::Factory>(
criteria[0]);
ASSERT_NE(it, nullptr);
EXPECT_EQ(it->get_parameters().max_iters, 10);
}
{
SCOPED_TRACE("Implicit Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ImplicitResidualNorm<float>::Factory>(criteria[1]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::rhs_norm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ResidualNorm<float>::Factory>(criteria[2]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::rhs_norm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Time Criterion");
using namespace std::chrono_literals;
auto time = std::dynamic_pointer_cast<const gko::stop::Time::Factory>(
criteria[3]);
ASSERT_NE(time, nullptr);
EXPECT_EQ(time->get_parameters().time_limit, 100ns);
}
}


TEST(GetValue, IndexType)
{
long long int value = 123;
Expand Down
Loading