Skip to content

Commit

Permalink
Improve tests for GS
Browse files Browse the repository at this point in the history
  • Loading branch information
pelesh committed Oct 6, 2023
1 parent cec7893 commit c9ae8de
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 38 deletions.
2 changes: 1 addition & 1 deletion examples/r_KLU_rf_FGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ int main(int argc, char *argv[])
real_type one = 1.0;
real_type minusone = -1.0;

ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::cgs2);
ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::GramSchmidt::cgs2);
ReSolve::LinSolverDirectKLU* KLU = new ReSolve::LinSolverDirectKLU;
ReSolve::LinSolverDirectCuSolverRf* Rf = new ReSolve::LinSolverDirectCuSolverRf;
ReSolve::LinSolverIterativeFGMRES* FGMRES = new ReSolve::LinSolverIterativeFGMRES(matrix_handler, vector_handler, GS);
Expand Down
2 changes: 1 addition & 1 deletion examples/r_KLU_rf_FGMRES_reuse_factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int main(int argc, char *argv[])
real_type one = 1.0;
real_type minusone = -1.0;

ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::cgs2);
ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::GramSchmidt::cgs2);

ReSolve::LinSolverDirectKLU* KLU = new ReSolve::LinSolverDirectKLU;
ReSolve::LinSolverDirectCuSolverRf* Rf = new ReSolve::LinSolverDirectCuSolverRf;
Expand Down
9 changes: 8 additions & 1 deletion resolve/GramSchmidt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace ReSolve
int idxmap(index_type i, index_type j, index_type col_lenght) {
return i * (col_lenght) + j;
}

GramSchmidt::GramSchmidt()
{
variant_ = mgs; //variant is enum now
Expand Down Expand Up @@ -74,7 +75,7 @@ namespace ReSolve
return 0;
}

GSVariant GramSchmidt::getVariant()
GramSchmidt::GSVariant GramSchmidt::getVariant()
{
return variant_;
}
Expand All @@ -84,6 +85,11 @@ namespace ReSolve
return h_L_;
}

bool GramSchmidt::isSetupComplete()
{
return setup_complete_;
}

int GramSchmidt::setup(index_type n, index_type restart)
{
if (setup_complete_) {
Expand Down Expand Up @@ -115,6 +121,7 @@ namespace ReSolve

return 0;
}

//this always happen on the GPU
int GramSchmidt::orthogonalize(index_type n, vector::Vector* V, real_type* H, index_type i, std::string memspace)
{
Expand Down
53 changes: 27 additions & 26 deletions resolve/GramSchmidt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,43 @@
#include <cassert>
namespace ReSolve
{
enum GSVariant { mgs = 0,
cgs2 = 1,
mgs_two_synch = 2,
mgs_pm = 3,
cgs1 = 4 };
class GramSchmidt
{
using vector_type = vector::Vector;
using vector_type = vector::Vector;
public:
GramSchmidt();
GramSchmidt(VectorHandler* vh, GSVariant variant);
~GramSchmidt();
int setVariant(GSVariant variant);
GSVariant getVariant();
real_type* getL(); //only for low synch, returns null ptr otherwise
enum GSVariant { mgs = 0,
cgs2 = 1,
mgs_two_synch = 2,
mgs_pm = 3,
cgs1 = 4 };

int setup(index_type n, index_type restart);
int orthogonalize(index_type n, vector_type* V, real_type* H, index_type i, std::string memspace);
GramSchmidt();
GramSchmidt(VectorHandler* vh, GSVariant variant);
~GramSchmidt();
int setVariant(GramSchmidt::GSVariant variant);
GSVariant getVariant();
real_type* getL(); //only for low synch, returns null ptr otherwise

int setup(index_type n, index_type restart);
int orthogonalize(index_type n, vector_type* V, real_type* H, index_type i, std::string memspace);
bool isSetupComplete();

private:

GSVariant variant_;
bool setup_complete_; //to avoid double allocations and stuff
GSVariant variant_;
bool setup_complete_; //to avoid double allocations and stuff

index_type num_vecs_; //the same as restart
vector_type* vec_rv_;
vector_type* vec_Hcolumn_;
// vector_type* d_H_col_;
index_type num_vecs_; //the same as restart
vector_type* vec_rv_{nullptr};
vector_type* vec_Hcolumn_{nullptr};

real_type* h_L_;
real_type* h_rv_;
real_type* h_aux_;
VectorHandler* vector_handler_;
real_type* h_L_{nullptr};
real_type* h_rv_{nullptr};
real_type* h_aux_{nullptr};
VectorHandler* vector_handler_{nullptr};

vector_type* vec_v_; // aux variable
vector_type* vec_w_; // aux variable
vector_type* vec_v_{nullptr}; // aux variable
vector_type* vec_w_{nullptr}; // aux variable
};

}//namespace
2 changes: 1 addition & 1 deletion tests/functionality/testKLU_Rf_FGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int main(int argc, char *argv[])
KLU->setupParameters(1, 0.1, false);

ReSolve::LinSolverDirectCuSolverRf* Rf = new ReSolve::LinSolverDirectCuSolverRf;
ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::mgs_pm);
ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::GramSchmidt::mgs_pm);
ReSolve::LinSolverIterativeFGMRES* FGMRES = new ReSolve::LinSolverIterativeFGMRES(matrix_handler, vector_handler, GS);
// Input to this code is location of `data` directory where matrix files are stored
const std::string data_path = (argc == 2) ? argv[1] : "./";
Expand Down
39 changes: 35 additions & 4 deletions tests/unit/vector/GramSchmidtTests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,46 @@ namespace ReSolve {
TestOutcome GramSchmidtConstructor()
{
TestStatus status;
status.skipTest();
// status.skipTest();

GramSchmidt gs1;
status *= (gs1.getVariant() == GramSchmidt::mgs);
status *= (gs1.getL() == nullptr);
status *= !gs1.isSetupComplete();

VectorHandler vh;
GramSchmidt gs2(&vh, GramSchmidt::mgs_pm);
status *= (gs2.getVariant() == GramSchmidt::mgs_pm);
status *= (gs1.getL() == nullptr);
status *= !gs1.isSetupComplete();

return status.report(__func__);
}

TestOutcome orthogonalize(index_type N, GSVariant var)
TestOutcome orthogonalize(index_type N, GramSchmidt::GSVariant var)
{
TestStatus status;

std::string testname(__func__);
switch(var)
{
case GramSchmidt::mgs:
testname += " (Modified Gram-Schmidt)";
break;
case GramSchmidt::mgs_two_synch:
testname += " (Modified Gram-Schmidt 2-Sync)";
break;
case GramSchmidt::mgs_pm:
testname += " (Post-Modern Modified Gram-Schmidt)";
break;
case GramSchmidt::cgs1:
testname += " (Classical Gram-Schmidt)";
break;
case GramSchmidt::cgs2:
testname += " (Reorthogonalized Classical Gram-Schmidt)";
break;
}

ReSolve::LinAlgWorkspace* workspace = createLinAlgWorkspace(memspace_);
ReSolve::VectorHandler* handler = new ReSolve::VectorHandler(workspace);

Expand All @@ -42,7 +73,7 @@ namespace ReSolve {
real_type* aux_data; // needed for setup

V->allocate(memspace_);
if (memspace_ != "cpu") {
if (memspace_ != "cpu") {
V->allocate("cpu");
}

Expand Down Expand Up @@ -87,7 +118,7 @@ namespace ReSolve {
delete V;
delete GS;

return status.report(__func__);
return status.report(testname.c_str());
}

private:
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/vector/runGramSchmidtTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ int main(int argc, char* argv[])
std::cout << "Running tests with CUDA backend:\n";
ReSolve::tests::GramSchmidtTests test("cuda");

result += test.orthogonalize(5000, ReSolve::mgs);
result += test.orthogonalize(5000, ReSolve::cgs2);
result += test.orthogonalize(5000, ReSolve::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::mgs_pm);
result += test.GramSchmidtConstructor();
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm);
std::cout << "\n";
}

Expand Down

0 comments on commit c9ae8de

Please sign in to comment.