Skip to content

Commit

Permalink
use scalar from ws and move some local variables
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
yhmtsai and MarcelKoch committed Sep 12, 2023
1 parent ba43ed3 commit edee59a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 24 deletions.
11 changes: 4 additions & 7 deletions core/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "core/distributed/helpers.hpp"
#include "core/solver/ir_kernels.hpp"
#include "core/solver/residual_update.hpp"
#include "core/solver/solver_base.hpp"
#include "core/solver/solver_boilerplate.hpp"
#include "core/solver/update_residual.hpp"


namespace gko {
Expand Down Expand Up @@ -186,7 +186,6 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
{
using Vector = matrix::Dense<ValueType>;
using ws = workspace_traits<Chebyshev>;
constexpr uint8 relative_stopping_id{1};

auto exec = this->get_executor();
this->setup_workspace();
Expand Down Expand Up @@ -229,7 +228,6 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
auto beta_ref = ValueType{0.5} * (foci_direction_ * alpha_ref) *
(foci_direction_ * alpha_ref);

bool one_changed{};
auto& stop_status = this->template create_workspace_array<stopping_status>(
ws::stop, dense_b->get_size()[1]);
exec->run(chebyshev::make_initialize(&stop_status));
Expand Down Expand Up @@ -257,10 +255,9 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
solver, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
&stop_status, all_stopped);
};
bool all_stopped = residual_update(
this, iter, one_op, neg_one_op, dense_b, dense_x, residual,
residual_ptr, stop_criterion, relative_stopping_id, stop_status,
one_changed, log_func);
bool all_stopped = update_residual(
this, iter, dense_b, dense_x, residual, residual_ptr,
stop_criterion, stop_status, log_func);
if (all_stopped) {
break;
}
Expand Down
11 changes: 4 additions & 7 deletions core/solver/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "core/distributed/helpers.hpp"
#include "core/solver/ir_kernels.hpp"
#include "core/solver/residual_update.hpp"
#include "core/solver/solver_base.hpp"
#include "core/solver/solver_boilerplate.hpp"
#include "core/solver/update_residual.hpp"


namespace gko {
Expand Down Expand Up @@ -193,7 +193,6 @@ void Ir<ValueType>::apply_dense_impl(const VectorType* dense_b,
{
using Vector = matrix::Dense<ValueType>;
using ws = workspace_traits<Ir>;
constexpr uint8 relative_stopping_id{1};

auto exec = this->get_executor();
this->setup_workspace();
Expand All @@ -203,7 +202,6 @@ void Ir<ValueType>::apply_dense_impl(const VectorType* dense_b,

GKO_SOLVER_ONE_MINUS_ONE();

bool one_changed{};
auto& stop_status = this->template create_workspace_array<stopping_status>(
ws::stop, dense_b->get_size()[1]);
exec->run(ir::make_initialize(&stop_status));
Expand Down Expand Up @@ -232,10 +230,9 @@ void Ir<ValueType>::apply_dense_impl(const VectorType* dense_b,
solver, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
&stop_status, all_stopped);
};
bool all_stopped = residual_update(
this, iter, one_op, neg_one_op, dense_b, dense_x, residual,
residual_ptr, stop_criterion, relative_stopping_id, stop_status,
one_changed, log_func);
bool all_stopped = update_residual(
this, iter, dense_b, dense_x, residual, residual_ptr,
stop_criterion, stop_status, log_func);
if (all_stopped) {
break;
}
Expand Down
24 changes: 14 additions & 10 deletions core/solver/residual_update.hpp → core/solver/update_residual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

#ifndef GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_
#define GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_
#ifndef GKO_CORE_SOLVER_UPDATE_RESIDUAL_HPP_
#define GKO_CORE_SOLVER_UPDATE_RESIDUAL_HPP_


#include <ginkgo/core/base/array.hpp>
Expand All @@ -43,17 +43,21 @@ namespace gko {
namespace solver {


template <typename SolverType, typename VectorType, typename ScalarType,
typename LogFunc>
bool residual_update(SolverType* solver, int iter, const ScalarType* one_op,
const ScalarType* neg_one_op, const VectorType* dense_b,
template <typename SolverType, typename VectorType, typename LogFunc>
bool update_residual(SolverType* solver, int iter, const VectorType* dense_b,
VectorType* dense_x, VectorType* residual,
const VectorType*& residual_ptr,
std::unique_ptr<gko::stop::Criterion>& stop_criterion,
uint8 relative_stopping_id,
array<stopping_status>& stop_status, bool& one_changed,
LogFunc log)
array<stopping_status>& stop_status, LogFunc log)
{
using ws = workspace_traits<std::remove_cv_t<SolverType>>;
constexpr uint8 relative_stopping_id{1};

// It's required to be initialized outside.
auto one_op = solver->get_workspace_op(ws::one);
auto neg_one_op = solver->get_workspace_op(ws::minus_one);

bool one_changed{};
if (iter == 0) {
// In iter 0, the iteration and residual are updated.
bool all_stopped =
Expand Down Expand Up @@ -100,4 +104,4 @@ bool residual_update(SolverType* solver, int iter, const ScalarType* one_op,
} // namespace solver
} // namespace gko

#endif // GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_
#endif // GKO_CORE_SOLVER_UPDATE_RESIDUAL_HPP_

0 comments on commit edee59a

Please sign in to comment.