Skip to content

Commit

Permalink
add config for chebyshev
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jun 28, 2024
1 parent ac66085 commit 24fc2df
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 5 deletions.
8 changes: 4 additions & 4 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ namespace gko {
namespace config {


#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \
GKO_INVALID_STATE(std::string("The value >" + _value + \
"< is invalid for the entry >" + _entry + \
"<"))
#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \
GKO_INVALID_STATE(std::string("The value >") + _value + \
"< is invalid for the entry >" + _entry + "<")


#define GKO_MISSING_CONFIG_ENTRY(_entry) \
Expand All @@ -52,6 +51,7 @@ enum class LinOpFactoryType : int {
Direct,
LowerTrs,
UpperTrs,
Chebyshev,
Factorization_Ic,
Factorization_Ilu,
Cholesky,
Expand Down
1 change: 1 addition & 0 deletions core/config/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ configuration_map generate_config_map()
{"solver::Direct", parse<LinOpFactoryType::Direct>},
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>},
{"solver::Chebyshev", parse<LinOpFactoryType::Chebyshev>},
{"factorization::Ic", parse<LinOpFactoryType::Factorization_Ic>},
{"factorization::Ilu", parse<LinOpFactoryType::Factorization_Ilu>},
{"factorization::Cholesky", parse<LinOpFactoryType::Cholesky>},
Expand Down
2 changes: 2 additions & 0 deletions core/config/solver_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/chebyshev.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
Expand Down Expand Up @@ -43,6 +44,7 @@ GKO_PARSE_VALUE_TYPE(CbGmres, gko::solver::CbGmres);
GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct);
GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs);
GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs);
GKO_PARSE_VALUE_TYPE(Chebyshev, gko::solver::Chebyshev);


template <>
Expand Down
23 changes: 23 additions & 0 deletions core/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/solver/solver_base.hpp>

#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/ir_kernels.hpp"
#include "core/solver/solver_base.hpp"
Expand All @@ -27,6 +28,28 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize);
} // anonymous namespace
} // namespace chebyshev

template <typename ValueType>
typename Chebyshev<ValueType>::parameters_type Chebyshev<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto params = solver::Chebyshev<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
if (auto& obj = config.get("foci")) {
auto arr = obj.get_array();
if (arr.size() != 2) {
GKO_INVALID_CONFIG_VALUE("foci", "must contain two elements");
}
params.with_foci(gko::config::get_value<ValueType>(arr.at(0)),
gko::config::get_value<ValueType>(arr.at(1)));
}
if (auto& obj = config.get("default_initial_guess")) {
params.with_default_initial_guess(
gko::config::get_value<solver::initial_guess_mode>(obj));
}
return params;
}


template <typename ValueType>
Chebyshev<ValueType>::Chebyshev(const Factory* factory,
Expand Down
36 changes: 36 additions & 0 deletions core/test/config/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/chebyshev.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
Expand Down Expand Up @@ -436,6 +437,41 @@ struct UpperTrs : TrsHelper<gko::solver::UpperTrs> {
};


struct Chebyshev : SolverConfigTest<gko::solver::Chebyshev<float>,
gko::solver::Chebyshev<double>> {
static pnode::map_type setup_base()
{
return {{"type", pnode{"solver::Chebyshev"}}};
}

template <bool from_reg, typename ParamType>
static void set(pnode::map_type& config_map, ParamType& param, registry reg,
std::shared_ptr<const gko::Executor> exec)
{
solver_config_test::template set<from_reg>(config_map, param, reg,
exec);
using fvt = typename decltype(param.foci)::first_type;
config_map["foci"] =
pnode::array_type{pnode{fvt{0.5}}, pnode{fvt{1.5}}};
param.with_foci(fvt{0.5}, fvt{1.5});
config_map["default_initial_guess"] = pnode{"zero"};
param.with_default_initial_guess(gko::solver::initial_guess_mode::zero);
}

template <bool from_reg, typename AnswerType>
static void validate(gko::LinOpFactory* result, AnswerType* answer)
{
auto res_param = gko::as<AnswerType>(result)->get_parameters();
auto ans_param = answer->get_parameters();

solver_config_test::template validate<from_reg>(result, answer);
ASSERT_EQ(res_param.foci, ans_param.foci);
ASSERT_EQ(res_param.default_initial_guess,
ans_param.default_initial_guess);
}
};


template <typename T>
class Solver : public ::testing::Test {
protected:
Expand Down
22 changes: 21 additions & 1 deletion include/ginkgo/core/solver/chebyshev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/matrix/identity.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
Expand Down Expand Up @@ -49,7 +51,7 @@ namespace solver {
* @ingroup LinOp
*/
template <typename ValueType = default_precision>
class Chebyshev
class Chebyshev final
: public EnableLinOp<Chebyshev<ValueType>>,
public EnablePreconditionedIterativeSolver<ValueType,
Chebyshev<ValueType>>,
Expand Down Expand Up @@ -133,6 +135,24 @@ class Chebyshev
GKO_ENABLE_LIN_OP_FACTORY(Chebyshev, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

/**
* Create the parameters from the property_tree.
* Because this is directly tied to the specific type, the value/index type
* settings within config are ignored and type_descriptor is only used
* for children configs.
*
* @param config the property tree for setting
* @param context the registry
* @param td_for_child the type descriptor for children configs. The
* default uses the value type of this class.
*
* @return parameters
*/
static parameters_type parse(const config::pnode& config,
const config::registry& context,
const config::type_descriptor& td_for_child =
config::make_type_descriptor<ValueType>());

protected:
void apply_impl(const LinOp* b, LinOp* x) const override;

Expand Down

0 comments on commit 24fc2df

Please sign in to comment.