Skip to content

Commit

Permalink
update documentation, rename, complex number and rm stop::, index_type
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Grützmacher <[email protected]>
Co-authored-by: Tobias Ribizel <[email protected]>
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
4 people committed May 27, 2024
1 parent 5412de5 commit 2629167
Show file tree
Hide file tree
Showing 13 changed files with 279 additions and 184 deletions.
42 changes: 33 additions & 9 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,54 @@

#include "core/config/config_helper.hpp"
#include "core/config/registry_accessor.hpp"

#include "core/config/stop_config.hpp"

namespace gko {
namespace config {


template <>
deferred_factory_parameter<const LinOpFactory> get_factory<const LinOpFactory>(
const pnode& config, const registry& context, const type_descriptor& td)
deferred_factory_parameter<const LinOpFactory>
parse_or_get_factory<const LinOpFactory>(const pnode& config,
const registry& context,
const type_descriptor& td)
{
deferred_factory_parameter<const LinOpFactory> ptr;
if (config.get_tag() == pnode::tag_t::string) {
ptr = detail::registry_accessor::get_data<LinOpFactory>(
return detail::registry_accessor::get_data<LinOpFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
ptr = parse(config, context, td);
return parse(config, context, td);
} else {
GKO_INVALID_STATE("The data of config is not valid.");
}
GKO_THROW_IF_INVALID(!ptr.is_empty(), "Parse get nullptr in the end");

return ptr;
}


template <>
deferred_factory_parameter<const stop::CriterionFactory>
parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td)
{
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
static std::map<std::string,
std::function<deferred_factory_parameter<
gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"Time", configure_time},
{"Iteration", configure_iter},
{"ResidualNorm", configure_residual},
{"ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("type").get_string())(config,
context, td);
} else {
GKO_INVALID_STATE("The data of config is not valid.");
}
}

} // namespace config
} // namespace gko
53 changes: 31 additions & 22 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,45 +60,47 @@ inline std::shared_ptr<T> get_stored_obj(const pnode& config,


/**
* get_factory builds the factory from config (map) or searches the pointers in
* Build the factory from config (map) or search the pointers in
* the registry by string.
*/
template <typename T>
deferred_factory_parameter<T> get_factory(const pnode& config,
const registry& context,
const type_descriptor& td);
deferred_factory_parameter<T> parse_or_get_factory(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* specialize for const LinOpFactory
*/
template <>
deferred_factory_parameter<const LinOpFactory> get_factory<const LinOpFactory>(
const pnode& config, const registry& context, const type_descriptor& td);
deferred_factory_parameter<const LinOpFactory>
parse_or_get_factory<const LinOpFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* specialize for const stop::CriterionFactory
*/
template <>
deferred_factory_parameter<const stop::CriterionFactory>
get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);
parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* get_factory_vector will gives a vector of factory by calling get_factory.
* give a vector of factory by calling parse_or_get_factory.
*/
template <typename T>
inline std::vector<deferred_factory_parameter<T>> get_factory_vector(
inline std::vector<deferred_factory_parameter<T>> parse_or_get_factory_vector(
const pnode& config, const registry& context, const type_descriptor& td)
{
std::vector<deferred_factory_parameter<T>> res;
if (config.get_tag() == pnode::tag_t::array) {
for (const auto& it : config.get_array()) {
res.push_back(get_factory<T>(it, context, td));
res.push_back(parse_or_get_factory<T>(it, context, td));
}
} else {
// only one config can be passed without array
res.push_back(get_factory<T>(config, context, td));
res.push_back(parse_or_get_factory<T>(config, context, td));
}

return res;
Expand All @@ -111,9 +113,8 @@ inline std::vector<deferred_factory_parameter<T>> get_factory_vector(
* This is specialization for integral type
*/
template <typename IndexType>
inline
typename std::enable_if<std::is_integral<IndexType>::value, IndexType>::type
get_value(const pnode& config)
inline std::enable_if_t<std::is_integral<IndexType>::value, IndexType>
get_value(const pnode& config)
{
auto val = config.get_integer();
GKO_THROW_IF_INVALID(
Expand All @@ -130,8 +131,7 @@ inline
* This is specialization for floating point type
*/
template <typename ValueType>
inline typename std::enable_if<std::is_floating_point<ValueType>::value,
ValueType>::type
inline std::enable_if_t<std::is_floating_point<ValueType>::value, ValueType>
get_value(const pnode& config)
{
auto val = config.get_real();
Expand All @@ -149,16 +149,25 @@ get_value(const pnode& config)
* This is specialization for complex type
*/
template <typename ValueType>
inline typename std::enable_if<gko::is_complex_s<ValueType>::value,
ValueType>::type
inline std::enable_if_t<gko::is_complex_s<ValueType>::value, ValueType>
get_value(const pnode& config)
{
using real_type = gko::remove_complex<ValueType>;
if (config.get_tag() == pnode::tag_t::real) {
return static_cast<ValueType>(get_value<real_type>(config));
} else if (config.get_tag() == pnode::tag_t::array) {
return ValueType{get_value<real_type>(config.get(0)),
get_value<real_type>(config.get(1))};
real_type real(0);
real_type imag(0);
if (config.get_array().size() >= 1) {
real = get_value<real_type>(config.get(0));
}
if (config.get_array().size() >= 2) {
imag = get_value<real_type>(config.get(1));
}
GKO_THROW_IF_INVALID(
config.get_array().size() <= 2,
"complex value array expression only accept up to two elements");
return ValueType{real, imag};
}
GKO_INVALID_STATE("Can not get complex value");
}
Expand Down
43 changes: 8 additions & 35 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/config/type_descriptor.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
#include <ginkgo/core/stop/criterion.hpp>
#include <ginkgo/core/stop/iteration.hpp>
Expand All @@ -15,14 +16,15 @@
#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/registry_accessor.hpp"
#include "core/config/stop_config.hpp"
#include "core/config/type_descriptor_helper.hpp"


namespace gko {
namespace config {


inline deferred_factory_parameter<stop::CriterionFactory> configure_time(
deferred_factory_parameter<stop::CriterionFactory> configure_time(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Time::build();
Expand All @@ -33,7 +35,7 @@ inline deferred_factory_parameter<stop::CriterionFactory> configure_time(
}


inline deferred_factory_parameter<stop::CriterionFactory> configure_iter(
deferred_factory_parameter<stop::CriterionFactory> configure_iter(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Iteration::build();
Expand All @@ -44,7 +46,7 @@ inline deferred_factory_parameter<stop::CriterionFactory> configure_iter(
}


stop::mode get_mode(const std::string& str)
inline stop::mode get_mode(const std::string& str)
{
if (str == "absolute") {
return stop::mode::absolute;
Expand Down Expand Up @@ -79,7 +81,7 @@ class ResidualNormConfigurer {
};


inline deferred_factory_parameter<stop::CriterionFactory> configure_residual(
deferred_factory_parameter<stop::CriterionFactory> configure_residual(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto updated = update_type(config, td);
Expand Down Expand Up @@ -111,9 +113,8 @@ class ImplicitResidualNormConfigurer {
};


inline deferred_factory_parameter<stop::CriterionFactory>
configure_implicit_residual(const pnode& config, const registry& context,
const type_descriptor& td)
deferred_factory_parameter<stop::CriterionFactory> configure_implicit_residual(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto updated = update_type(config, td);
return dispatch<stop::CriterionFactory, ImplicitResidualNormConfigurer>(
Expand All @@ -122,33 +123,5 @@ configure_implicit_residual(const pnode& config, const registry& context,
}


template <>
deferred_factory_parameter<const stop::CriterionFactory>
get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td)
{
deferred_factory_parameter<const stop::CriterionFactory> ptr;
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
static std::map<std::string,
std::function<deferred_factory_parameter<
gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"stop::Time", configure_time},
{"stop::Iteration", configure_iter},
{"stop::ResidualNorm", configure_residual},
{"stop::ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("type").get_string())(config,
context, td);
}
GKO_THROW_IF_INVALID(!ptr.is_empty(), "Parse get nullptr in the end");
return ptr;
}


} // namespace config
} // namespace gko
35 changes: 35 additions & 0 deletions core/config/stop_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_STOP_CONFIG_HPP_
#define GKO_CORE_CONFIG_STOP_CONFIG_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/config/type_descriptor.hpp>
#include <ginkgo/core/stop/criterion.hpp>


namespace gko {
namespace config {


deferred_factory_parameter<stop::CriterionFactory> configure_time(
const pnode& config, const registry& context, const type_descriptor& td);

deferred_factory_parameter<stop::CriterionFactory> configure_iter(
const pnode& config, const registry& context, const type_descriptor& td);

deferred_factory_parameter<stop::CriterionFactory> configure_residual(
const pnode& config, const registry& context, const type_descriptor& td);

deferred_factory_parameter<stop::CriterionFactory> configure_implicit_residual(
const pnode& config, const registry& context, const type_descriptor& td);

} // namespace config
} // namespace gko


#endif // GKO_CORE_CONFIG_STOP_CONFIG_HPP_
7 changes: 6 additions & 1 deletion core/config/type_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include <ginkgo/core/config/type_descriptor.hpp>


#include <ginkgo/core/base/exception_helpers.hpp>


#include "core/config/type_descriptor_helper.hpp"


Expand All @@ -21,7 +24,9 @@ type_descriptor update_type(const pnode& config, const type_descriptor& td)
value_typestr = obj.get_string();
}
if (auto& obj = config.get("index_type")) {
index_typestr = obj.get_string();
GKO_INVALID_STATE(
"Setting index_type in the config is not allowed. Please set the "
"proper index_type through type_descriptor of parse");
}
return type_descriptor{value_typestr, index_typestr};
}
Expand Down
13 changes: 5 additions & 8 deletions core/config/type_descriptor_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@


#include <string>
#include <type_traits>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
#include <ginkgo/core/stop/criterion.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/config/property_tree.hpp>
#include <ginkgo/core/config/type_descriptor.hpp>


namespace gko {
Expand Down Expand Up @@ -49,6 +44,8 @@ TYPE_STRING_OVERLOAD(std::complex<float>, "complex<float32>");
TYPE_STRING_OVERLOAD(int32, "int32");
TYPE_STRING_OVERLOAD(int64, "int64");

#undef TYPE_STRING_OVERLOAD


} // namespace config
} // namespace gko
Expand Down
9 changes: 5 additions & 4 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ typename Cg<ValueType>::parameters_type Cg<ValueType>::parse(
}
if (auto& obj = config.get("criteria")) {
params.with_criteria(
gko::config::get_factory_vector<const stop::CriterionFactory>(
obj, context, td_for_child));
gko::config::parse_or_get_factory_vector<
const stop::CriterionFactory>(obj, context, td_for_child));
}
if (auto& obj = config.get("preconditioner")) {
params.with_preconditioner(gko::config::get_factory<const LinOpFactory>(
obj, context, td_for_child));
params.with_preconditioner(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
return params;
}
Expand Down
Loading

0 comments on commit 2629167

Please sign in to comment.