Skip to content

Commit

Permalink
Prototype for options setting for SystemSolver class.
Browse files Browse the repository at this point in the history
  • Loading branch information
pelesh committed Dec 4, 2023
1 parent d5fb0df commit 9cc141a
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 78 deletions.
24 changes: 13 additions & 11 deletions resolve/LinSolverIterativeFGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ namespace ReSolve
memspace_ = memspace;
this->matrix_handler_ = nullptr;
this->vector_handler_ = nullptr;
tol_ = 1e-14; //default
maxit_= 100; //default
restart_ = 10;
conv_cond_ = 0;//default
flexible_ = 1;
// Defaults:
// tol_ = 1e-14;
// maxit_= 100;
// restart_ = 10;
// conv_cond_ = 0;
// flexible_ = true;

d_V_ = nullptr;
d_Z_ = nullptr;
Expand All @@ -35,11 +36,12 @@ namespace ReSolve
this->vector_handler_ = vector_handler;
this->GS_ = gs;

tol_ = 1e-14; //default
maxit_= 100; //default
restart_ = 10;
conv_cond_ = 0;//default
flexible_ = 1;
// Defaults:
// tol_ = 1e-14;
// maxit_= 100;
// restart_ = 10;
// conv_cond_ = 0;
// flexible_ = true;

d_V_ = nullptr;
d_Z_ = nullptr;
Expand All @@ -63,7 +65,7 @@ namespace ReSolve
maxit_= maxit;
restart_ = restart;
conv_cond_ = conv_cond;
flexible_ = 1;
flexible_ = true;

d_V_ = nullptr;
d_Z_ = nullptr;
Expand Down
81 changes: 63 additions & 18 deletions resolve/SystemSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,44 @@ namespace ReSolve
}

#ifdef RESOLVE_USE_CUDA
SystemSolver::SystemSolver(LinAlgWorkspaceCUDA* workspace, std::string ir) : workspaceCuda_(workspace), irMethod_(ir)
SystemSolver::SystemSolver(LinAlgWorkspaceCUDA* workspaceCuda,
std::string factor,
std::string refactor,
std::string solve,
std::string ir)
: workspaceCuda_(workspaceCuda),
factorizationMethod_(factor),
refactorizationMethod_(refactor),
solveMethod_(solve),
irMethod_(ir)
{
// Instantiate handlers
matrixHandler_ = new MatrixHandler(workspaceCuda_);
vectorHandler_ = new VectorHandler(workspaceCuda_);

//set defaults:
memspace_ = "cuda";
factorizationMethod_ = "klu";
refactorizationMethod_ = "glu";
solveMethod_ = "glu";
// irMethod_ = "none";
gsMethod_ = "cgs2";

initialize();
}
#endif

#ifdef RESOLVE_USE_HIP
SystemSolver::SystemSolver(LinAlgWorkspaceHIP* workspace, std::string ir) : workspaceHip_(workspace), irMethod_(ir)
SystemSolver::SystemSolver(LinAlgWorkspaceHIP* workspaceHip,
std::string factor,
std::string refactor,
std::string solve,
std::string ir)
: workspaceHip_(workspaceHip),
factorizationMethod_(factor),
refactorizationMethod_(refactor),
solveMethod_(solve),
irMethod_(ir)
{
// Instantiate handlers
matrixHandler_ = new MatrixHandler(workspaceHip_);
vectorHandler_ = new VectorHandler(workspaceHip_);

//set defaults:
memspace_ = "hip";
factorizationMethod_ = "klu";
refactorizationMethod_ = "rocsolverrf";
solveMethod_ = "rocsolverrf";
// irMethod_ = "none";
gsMethod_ = "cgs2";

initialize();
}
Expand Down Expand Up @@ -165,7 +171,9 @@ namespace ReSolve
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::cgs1);
} else {
out::warning() << "Gram-Schmidt variant " << gsMethod_ << " not recognized.\n";
gs_ = nullptr;
out::warning() << "Using default cgs2 Gram-Schmidt variant.\n";
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::cgs2);
gsMethod_ = "cgs2";
}

iterativeSolver_ = new LinSolverIterativeFGMRES(matrixHandler_,
Expand Down Expand Up @@ -354,10 +362,47 @@ namespace ReSolve
// initialize();
}

void SystemSolver::setRefinementMethod(std::string method)
void SystemSolver::setRefinementMethod(std::string method, std::string gsMethod)
{
irMethod_ = method;
// initialize();
if (iterativeSolver_ != nullptr)
delete iterativeSolver_;

if(gs_ != nullptr)
delete gs_;

if(method == "none")
return;

gsMethod_ = gsMethod;

#if defined(RESOLVE_USE_HIP) || defined(RESOLVE_USE_CUDA)
if (method == "fgmres") {
if (gsMethod == "cgs2") {
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::cgs2);
} else if (gsMethod == "mgs") {
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::mgs);
} else if (gsMethod == "mgs_two_synch") {
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::mgs_two_synch);
} else if (gsMethod == "mgs_pm") {
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::mgs_pm);
} else if (gsMethod == "cgs1") {
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::cgs1);
} else {
out::warning() << "Gram-Schmidt variant " << gsMethod_ << " not recognized.\n";
out::warning() << "Using default cgs2 Gram-Schmidt variant.\n";
gs_ = new GramSchmidt(vectorHandler_, GramSchmidt::cgs2);
gsMethod_ = "cgs2";
}

iterativeSolver_ = new LinSolverIterativeFGMRES(matrixHandler_,
vectorHandler_,
gs_,
memspace_);
irMethod_ = method;
} else {
out::error() << "Iterative refinement method " << method << " not recognized.\n";
}
#endif
}

real_type SystemSolver::getResidualNorm(vector_type* rhs, vector_type* x)
Expand Down
16 changes: 12 additions & 4 deletions resolve/SystemSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,16 @@ namespace ReSolve
using matrix_type = matrix::Sparse;

SystemSolver();
SystemSolver(LinAlgWorkspaceCUDA* workspaceCuda, std::string ir = "none");
SystemSolver(LinAlgWorkspaceHIP* workspaceHip, std::string ir = "none");
SystemSolver(std::string factorizationMethod, std::string refactorizationMethod, std::string solveMethod, std::string IRMethod);
SystemSolver(LinAlgWorkspaceCUDA* workspaceCuda,
std::string factor = "klu",
std::string refactor = "glu",
std::string solve = "glu",
std::string ir = "none");
SystemSolver(LinAlgWorkspaceHIP* workspaceHip,
std::string factor = "klu",
std::string refactor = "rocsolverrf",
std::string solve = "rocsolverrf",
std::string ir = "none");

~SystemSolver();

Expand Down Expand Up @@ -58,12 +65,13 @@ namespace ReSolve
const std::string getRefactorizationMethod() const;
const std::string getSolveMethod() const;
const std::string getRefinementMethod() const;
const std::string getOrthogonalizationMethod() const;

// Set solver parameters
void setFactorizationMethod(std::string method);
void setRefactorizationMethod(std::string method);
void setSolveMethod(std::string method);
void setRefinementMethod(std::string method);
void setRefinementMethod(std::string method, std::string gs = "cgs2");

private:
LinSolverDirect* factorizationSolver_{nullptr};
Expand Down
Loading

0 comments on commit 9cc141a

Please sign in to comment.