From c9ae8deac248e378e77dc6e4e35ed5245832e159 Mon Sep 17 00:00:00 2001 From: Slaven Peles Date: Thu, 5 Oct 2023 20:13:16 -0400 Subject: [PATCH] Improve tests for GS --- examples/r_KLU_rf_FGMRES.cpp | 2 +- .../r_KLU_rf_FGMRES_reuse_factorization.cpp | 2 +- resolve/GramSchmidt.cpp | 9 +++- resolve/GramSchmidt.hpp | 53 ++++++++++--------- tests/functionality/testKLU_Rf_FGMRES.cpp | 2 +- tests/unit/vector/GramSchmidtTests.hpp | 39 ++++++++++++-- tests/unit/vector/runGramSchmidtTests.cpp | 9 ++-- 7 files changed, 78 insertions(+), 38 deletions(-) diff --git a/examples/r_KLU_rf_FGMRES.cpp b/examples/r_KLU_rf_FGMRES.cpp index 50888b31..afe01e34 100644 --- a/examples/r_KLU_rf_FGMRES.cpp +++ b/examples/r_KLU_rf_FGMRES.cpp @@ -48,7 +48,7 @@ int main(int argc, char *argv[]) real_type one = 1.0; real_type minusone = -1.0; - ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::cgs2); + ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::GramSchmidt::cgs2); ReSolve::LinSolverDirectKLU* KLU = new ReSolve::LinSolverDirectKLU; ReSolve::LinSolverDirectCuSolverRf* Rf = new ReSolve::LinSolverDirectCuSolverRf; ReSolve::LinSolverIterativeFGMRES* FGMRES = new ReSolve::LinSolverIterativeFGMRES(matrix_handler, vector_handler, GS); diff --git a/examples/r_KLU_rf_FGMRES_reuse_factorization.cpp b/examples/r_KLU_rf_FGMRES_reuse_factorization.cpp index eab01c58..3061eb56 100644 --- a/examples/r_KLU_rf_FGMRES_reuse_factorization.cpp +++ b/examples/r_KLU_rf_FGMRES_reuse_factorization.cpp @@ -49,7 +49,7 @@ int main(int argc, char *argv[]) real_type one = 1.0; real_type minusone = -1.0; - ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::cgs2); + ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::GramSchmidt::cgs2); ReSolve::LinSolverDirectKLU* KLU = new ReSolve::LinSolverDirectKLU; ReSolve::LinSolverDirectCuSolverRf* Rf = new ReSolve::LinSolverDirectCuSolverRf; diff --git a/resolve/GramSchmidt.cpp b/resolve/GramSchmidt.cpp index c54e41e3..b6a27b04 100644 --- a/resolve/GramSchmidt.cpp +++ b/resolve/GramSchmidt.cpp @@ -13,6 +13,7 @@ namespace ReSolve int idxmap(index_type i, index_type j, index_type col_lenght) { return i * (col_lenght) + j; } + GramSchmidt::GramSchmidt() { variant_ = mgs; //variant is enum now @@ -74,7 +75,7 @@ namespace ReSolve return 0; } - GSVariant GramSchmidt::getVariant() + GramSchmidt::GSVariant GramSchmidt::getVariant() { return variant_; } @@ -84,6 +85,11 @@ namespace ReSolve return h_L_; } + bool GramSchmidt::isSetupComplete() + { + return setup_complete_; + } + int GramSchmidt::setup(index_type n, index_type restart) { if (setup_complete_) { @@ -115,6 +121,7 @@ namespace ReSolve return 0; } + //this always happen on the GPU int GramSchmidt::orthogonalize(index_type n, vector::Vector* V, real_type* H, index_type i, std::string memspace) { diff --git a/resolve/GramSchmidt.hpp b/resolve/GramSchmidt.hpp index e6a41933..7d7b93be 100644 --- a/resolve/GramSchmidt.hpp +++ b/resolve/GramSchmidt.hpp @@ -5,42 +5,43 @@ #include namespace ReSolve { - enum GSVariant { mgs = 0, - cgs2 = 1, - mgs_two_synch = 2, - mgs_pm = 3, - cgs1 = 4 }; class GramSchmidt { - using vector_type = vector::Vector; + using vector_type = vector::Vector; public: - GramSchmidt(); - GramSchmidt(VectorHandler* vh, GSVariant variant); - ~GramSchmidt(); - int setVariant(GSVariant variant); - GSVariant getVariant(); - real_type* getL(); //only for low synch, returns null ptr otherwise + enum GSVariant { mgs = 0, + cgs2 = 1, + mgs_two_synch = 2, + mgs_pm = 3, + cgs1 = 4 }; - int setup(index_type n, index_type restart); - int orthogonalize(index_type n, vector_type* V, real_type* H, index_type i, std::string memspace); + GramSchmidt(); + GramSchmidt(VectorHandler* vh, GSVariant variant); + ~GramSchmidt(); + int setVariant(GramSchmidt::GSVariant variant); + GSVariant getVariant(); + real_type* getL(); //only for low synch, returns null ptr otherwise + + int setup(index_type n, index_type restart); + int orthogonalize(index_type n, vector_type* V, real_type* H, index_type i, std::string memspace); + bool isSetupComplete(); private: - GSVariant variant_; - bool setup_complete_; //to avoid double allocations and stuff + GSVariant variant_; + bool setup_complete_; //to avoid double allocations and stuff - index_type num_vecs_; //the same as restart - vector_type* vec_rv_; - vector_type* vec_Hcolumn_; -// vector_type* d_H_col_; + index_type num_vecs_; //the same as restart + vector_type* vec_rv_{nullptr}; + vector_type* vec_Hcolumn_{nullptr}; - real_type* h_L_; - real_type* h_rv_; - real_type* h_aux_; - VectorHandler* vector_handler_; + real_type* h_L_{nullptr}; + real_type* h_rv_{nullptr}; + real_type* h_aux_{nullptr}; + VectorHandler* vector_handler_{nullptr}; - vector_type* vec_v_; // aux variable - vector_type* vec_w_; // aux variable + vector_type* vec_v_{nullptr}; // aux variable + vector_type* vec_w_{nullptr}; // aux variable }; }//namespace diff --git a/tests/functionality/testKLU_Rf_FGMRES.cpp b/tests/functionality/testKLU_Rf_FGMRES.cpp index 9882d35b..1575f2af 100644 --- a/tests/functionality/testKLU_Rf_FGMRES.cpp +++ b/tests/functionality/testKLU_Rf_FGMRES.cpp @@ -42,7 +42,7 @@ int main(int argc, char *argv[]) KLU->setupParameters(1, 0.1, false); ReSolve::LinSolverDirectCuSolverRf* Rf = new ReSolve::LinSolverDirectCuSolverRf; - ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::mgs_pm); + ReSolve::GramSchmidt* GS = new ReSolve::GramSchmidt(vector_handler, ReSolve::GramSchmidt::mgs_pm); ReSolve::LinSolverIterativeFGMRES* FGMRES = new ReSolve::LinSolverIterativeFGMRES(matrix_handler, vector_handler, GS); // Input to this code is location of `data` directory where matrix files are stored const std::string data_path = (argc == 2) ? argv[1] : "./"; diff --git a/tests/unit/vector/GramSchmidtTests.hpp b/tests/unit/vector/GramSchmidtTests.hpp index 8c46d6df..2ed889db 100644 --- a/tests/unit/vector/GramSchmidtTests.hpp +++ b/tests/unit/vector/GramSchmidtTests.hpp @@ -25,15 +25,46 @@ namespace ReSolve { TestOutcome GramSchmidtConstructor() { TestStatus status; - status.skipTest(); + // status.skipTest(); + + GramSchmidt gs1; + status *= (gs1.getVariant() == GramSchmidt::mgs); + status *= (gs1.getL() == nullptr); + status *= !gs1.isSetupComplete(); + + VectorHandler vh; + GramSchmidt gs2(&vh, GramSchmidt::mgs_pm); + status *= (gs2.getVariant() == GramSchmidt::mgs_pm); + status *= (gs1.getL() == nullptr); + status *= !gs1.isSetupComplete(); return status.report(__func__); } - TestOutcome orthogonalize(index_type N, GSVariant var) + TestOutcome orthogonalize(index_type N, GramSchmidt::GSVariant var) { TestStatus status; + std::string testname(__func__); + switch(var) + { + case GramSchmidt::mgs: + testname += " (Modified Gram-Schmidt)"; + break; + case GramSchmidt::mgs_two_synch: + testname += " (Modified Gram-Schmidt 2-Sync)"; + break; + case GramSchmidt::mgs_pm: + testname += " (Post-Modern Modified Gram-Schmidt)"; + break; + case GramSchmidt::cgs1: + testname += " (Classical Gram-Schmidt)"; + break; + case GramSchmidt::cgs2: + testname += " (Reorthogonalized Classical Gram-Schmidt)"; + break; + } + ReSolve::LinAlgWorkspace* workspace = createLinAlgWorkspace(memspace_); ReSolve::VectorHandler* handler = new ReSolve::VectorHandler(workspace); @@ -42,7 +73,7 @@ namespace ReSolve { real_type* aux_data; // needed for setup V->allocate(memspace_); - if (memspace_ != "cpu") { + if (memspace_ != "cpu") { V->allocate("cpu"); } @@ -87,7 +118,7 @@ namespace ReSolve { delete V; delete GS; - return status.report(__func__); + return status.report(testname.c_str()); } private: diff --git a/tests/unit/vector/runGramSchmidtTests.cpp b/tests/unit/vector/runGramSchmidtTests.cpp index 71d93c58..eeec0583 100644 --- a/tests/unit/vector/runGramSchmidtTests.cpp +++ b/tests/unit/vector/runGramSchmidtTests.cpp @@ -11,10 +11,11 @@ int main(int argc, char* argv[]) std::cout << "Running tests with CUDA backend:\n"; ReSolve::tests::GramSchmidtTests test("cuda"); - result += test.orthogonalize(5000, ReSolve::mgs); - result += test.orthogonalize(5000, ReSolve::cgs2); - result += test.orthogonalize(5000, ReSolve::mgs_two_synch); - result += test.orthogonalize(5000, ReSolve::mgs_pm); + result += test.GramSchmidtConstructor(); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm); std::cout << "\n"; }