Skip to content

Commit

Permalink
SystemSolver now lets its modules manage their own parameter settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
pelesh committed Dec 4, 2023
1 parent 2f483d0 commit d5fb0df
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 52 deletions.
2 changes: 1 addition & 1 deletion examples/r_SysSolverHipRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ int main(int argc, char *argv[])
<< rnrm
<< " final nrm: "
<< solver->getResidualNorm(vec_rhs, vec_x)
<< " iter: " << solver->getNumIter()
<< " iter: " << solver->getIterativeSolver().getNumIter()
<< "\n";
}
}
Expand Down
37 changes: 1 addition & 36 deletions resolve/SystemSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,7 @@ namespace ReSolve
gs_ = nullptr;
}

iterativeSolver_ = new LinSolverIterativeFGMRES(irRestart_,
irTol_,
irMaxit_,
irConvCond_,
matrixHandler_,
iterativeSolver_ = new LinSolverIterativeFGMRES(matrixHandler_,
vectorHandler_,
gs_,
memspace_);
Expand Down Expand Up @@ -401,40 +397,9 @@ namespace ReSolve
return resnorm/norm_b;
}

real_type SystemSolver::getInitResidualNorm()
{
#if defined(RESOLVE_USE_HIP) || defined(RESOLVE_USE_CUDA)
return iterativeSolver_->getInitResidualNorm();
#endif
}

real_type SystemSolver::getFinalResidualNorm()
{
#if defined(RESOLVE_USE_HIP) || defined(RESOLVE_USE_CUDA)
return iterativeSolver_->getFinalResidualNorm();
#endif
}

int SystemSolver::getNumIter()
{
#if defined(RESOLVE_USE_HIP) || defined(RESOLVE_USE_CUDA)
return iterativeSolver_->getNumIter();
#endif
}

const std::string SystemSolver::getFactorizationMethod() const
{
return factorizationMethod_;
}

void SystemSolver::setMaxIterations(int maxIter)
{
iterativeSolver_->setMaxit(maxIter);
}

void SystemSolver::setIterationsRestart(int restart)
{
iterativeSolver_->setRestart(restart);
}

} // namespace ReSolve
12 changes: 0 additions & 12 deletions resolve/SystemSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ namespace ReSolve

real_type getResidualNorm(vector_type* rhs, vector_type* x);

real_type getInitResidualNorm();
real_type getFinalResidualNorm();
int getNumIter();

// Get solver parameters
const std::string getFactorizationMethod() const;
const std::string getRefactorizationMethod() const;
Expand All @@ -69,9 +65,6 @@ namespace ReSolve
void setSolveMethod(std::string method);
void setRefinementMethod(std::string method);

void setMaxIterations(int maxIter);
void setIterationsRestart(int restart);

private:
LinSolverDirect* factorizationSolver_{nullptr};
LinSolverDirect* refactorizationSolver_{nullptr};
Expand Down Expand Up @@ -104,10 +97,5 @@ namespace ReSolve
std::string gsMethod_;

std::string memspace_;

real_type irTol_{1e-14};
int irMaxit_{100};
int irRestart_{10};
int irConvCond_{0};
};
} // namespace ReSolve
6 changes: 3 additions & 3 deletions tests/functionality/testSysHipRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ int main(int argc, char *argv[])
std::cout<<"\t ||x-x_true||_2 : "<<normDiffMatrix2<<" (solution error)"<<std::endl;
std::cout<<"\t ||x-x_true||_2/||x_true||_2 : "<<normDiffMatrix2/normXtrue<<" (scaled solution error)"<<std::endl;
std::cout<<"\t ||b-A*x_exact||_2 : "<<exactSol_normRmatrix2<<" (control; residual norm with exact solution)"<<std::endl;
std::cout<<"\t IR iterations : "<<solver->getNumIter()<<" (max 200, restart 100)"<<std::endl;
std::cout<<"\t IR starting res. norm : "<<solver->getInitResidualNorm()<<" "<<std::endl;
std::cout<<"\t IR final res. norm : "<<solver->getFinalResidualNorm()<<" (tol 1e-14)"<<std::endl<<std::endl;
std::cout<<"\t IR iterations : "<<solver->getIterativeSolver().getNumIter()<<" (max 200, restart 100)"<<std::endl;
std::cout<<"\t IR starting res. norm : "<<solver->getIterativeSolver().getInitResidualNorm() <<" "<<std::endl;
std::cout<<"\t IR final res. norm : "<<solver->getIterativeSolver().getFinalResidualNorm() <<" (tol 1e-14)"<<std::endl<<std::endl;

if ((normRmatrix1/normB1 > 1e-12 ) || (normRmatrix2/normB2 > 1e-9)) {
std::cout << "Result inaccurate!\n";
Expand Down

0 comments on commit d5fb0df

Please sign in to comment.