From cb707be4b3f9c04765762a2cb579f154407c3bb0 Mon Sep 17 00:00:00 2001 From: kswirydo Date: Tue, 20 Feb 2024 19:40:21 -0800 Subject: [PATCH] working cgs1 ortho --- resolve/GramSchmidt.cpp | 39 +++++++++++++++-- tests/unit/TestBase.hpp | 7 +++- tests/unit/vector/GramSchmidtTests.hpp | 51 ++++++++++++++++++++--- tests/unit/vector/runGramSchmidtTests.cpp | 3 ++ 4 files changed, 91 insertions(+), 9 deletions(-) diff --git a/resolve/GramSchmidt.cpp b/resolve/GramSchmidt.cpp index c99f4a8d..8903e1b2 100644 --- a/resolve/GramSchmidt.cpp +++ b/resolve/GramSchmidt.cpp @@ -41,11 +41,14 @@ namespace ReSolve delete vec_Hcolumn_;; } - if(variant_ == cgs2) { + if (variant_ == cgs2) { delete h_aux_; delete vec_Hcolumn_; } - if(variant_ == mgs_pm) { + if (variant_ == cgs1) { + delete vec_Hcolumn_; + } + if (variant_ == mgs_pm) { delete h_aux_; } @@ -102,7 +105,10 @@ namespace ReSolve vec_Hcolumn_ = new vector_type(num_vecs_ + 1); vec_Hcolumn_->allocate(memspace_); } - + if(variant_ == cgs1) { + vec_Hcolumn_ = new vector_type(num_vecs_ + 1); + vec_Hcolumn_->allocate(memspace_); + } if(variant_ == mgs_pm) { h_aux_ = new real_type[num_vecs_ + 1]; } @@ -312,6 +318,33 @@ namespace ReSolve return 0; break; + case cgs1: + vec_v_->setData(V->getVectorData(i + 1, memspace), memspace); + //Hcol = V(:,1:i)^T*V(:,i+1); + vector_handler_->gemv('T', n, i + 1, &ONE, &ZERO, V, vec_v_, vec_Hcolumn_, memspace); + // V(:,i+1) = V(:, i+1) - V(:,1:i)*Hcol + vector_handler_->gemv('N', n, i + 1, &ONE, &MINUSONE, V, vec_Hcolumn_, vec_v_, memspace ); + mem_.deviceSynchronize(); + + // copy H_col to H + vec_Hcolumn_->setDataUpdated(memspace); + vec_Hcolumn_->setCurrentSize(i + 1); + vec_Hcolumn_->deepCopyVectorData(&H[ idxmap(i, 0, num_vecs_ + 1)], 0, memory::HOST); + mem_.deviceSynchronize(); + + t = vector_handler_->dot(vec_v_, vec_v_, memspace); + //set the last entry in Hessenberg matrix + t = sqrt(t); + H[ idxmap(i, i + 1, num_vecs_ + 1) ] = t; + if(fabs(t) > EPSILON) { + t = 1.0/t; + vector_handler_->scal(&t, vec_v_, memspace); + } else { + assert(0 && "Gram-Schmidt failed, vector with ZERO norm\n"); + return -1; + } + return 0; + break; default: assert(0 && "Iterative refinement failed, wrong orthogonalization.\n"); return -1; diff --git a/tests/unit/TestBase.hpp b/tests/unit/TestBase.hpp index 996b3799..6a3648cc 100644 --- a/tests/unit/TestBase.hpp +++ b/tests/unit/TestBase.hpp @@ -235,9 +235,14 @@ class TestBase { return (std::abs(a - b)/(1.0 + std::abs(b)) < eps); } + + bool isEqual(const real_type a, const real_type b, const real_type tol) + { + return (std::abs(a - b)/(1.0 + std::abs(b)) < tol); + } protected: std::string mem_space_; }; -}} // namespace ReSolve::tests \ No newline at end of file +}} // namespace ReSolve::tests diff --git a/tests/unit/vector/GramSchmidtTests.hpp b/tests/unit/vector/GramSchmidtTests.hpp index ddf8a406..9378ed6c 100644 --- a/tests/unit/vector/GramSchmidtTests.hpp +++ b/tests/unit/vector/GramSchmidtTests.hpp @@ -108,9 +108,11 @@ namespace ReSolve { GS->orthogonalize(N, V, H, 0); GS->orthogonalize(N, V, H, 1); - - status *= verifyAnswer(V, 3, handler, memspace_); - + if (var == GramSchmidt::cgs1) { + status *= verifyAnswer(V, 3, handler, memspace_, 1e-13); + } else { + status *= verifyAnswer(V, 3, handler, memspace_); + } delete handler; delete [] H; delete V; @@ -168,7 +170,7 @@ namespace ReSolve { if ( (i != j) && !isEqual(ip, 0.0)) { status = false; std::cout << "Vectors " << i << " and " << j << " are not orthogonal!" - << " Inner product computed: " << ip << ", expected: " << 0.0 << "\n"; + << " Inner product computed: " << ip << ", expected: " << 0.0 << "\n"; break; } if ( (i == j) && !isEqual(sqrt(ip), 1.0)) { @@ -183,6 +185,45 @@ namespace ReSolve { delete b; return status; } - }; // class + + bool verifyAnswer(vector::Vector* x, index_type K, ReSolve::VectorHandler* handler, std::string memspace, real_type const tol) + { + ReSolve::memory::MemorySpace ms; + if (memspace == "cpu") + ms = memory::HOST; + else + ms = memory::DEVICE; + + vector::Vector* a = new vector::Vector(x->getSize()); + vector::Vector* b = new vector::Vector(x->getSize()); + + real_type ip; + bool status = true; + + for (index_type i = 0; i < K; ++i) { + for (index_type j = 0; j < K; ++j) { + a->update(x->getVectorData(i, ms), ms, memory::HOST); + b->update(x->getVectorData(j, ms), ms, memory::HOST); + ip = handler->dot(a, b, memory::HOST); + if ( (i != j) && !isEqual(ip, 0.0, tol)) { + status = false; + std::cout << "Vectors " << i << " and " << j << " are not orthogonal!" + << " Inner product computed: " << ip << ", expected: " << 0.0 << "\n"; + break; + } + if ( (i == j) && !isEqual(sqrt(ip), 1.0, tol)) { + status = false; + std::cout << std::setprecision(16); + std::cout << "Vector " << i << " has norm: " << sqrt(ip) << " expected: "<< 1.0 <<"\n"; + break; + } + } + } + delete a; + delete b; + return status; + } + + }; // class } } diff --git a/tests/unit/vector/runGramSchmidtTests.cpp b/tests/unit/vector/runGramSchmidtTests.cpp index 56404518..889be202 100644 --- a/tests/unit/vector/runGramSchmidtTests.cpp +++ b/tests/unit/vector/runGramSchmidtTests.cpp @@ -17,6 +17,7 @@ int main(int, char**) result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2); result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch); result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs1); std::cout << "\n"; } #endif @@ -31,6 +32,7 @@ int main(int, char**) result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2); result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch); result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs1); std::cout << "\n"; } #endif @@ -44,6 +46,7 @@ int main(int, char**) result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2); result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch); result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm); + result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs1); std::cout << "\n"; } return result.summary();