Skip to content

Commit

Permalink
working cgs1 ortho
Browse files Browse the repository at this point in the history
  • Loading branch information
kswirydo committed Feb 21, 2024
1 parent fbd006c commit cb707be
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 9 deletions.
39 changes: 36 additions & 3 deletions resolve/GramSchmidt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ namespace ReSolve
delete vec_Hcolumn_;;
}

if(variant_ == cgs2) {
if (variant_ == cgs2) {
delete h_aux_;
delete vec_Hcolumn_;
}
if(variant_ == mgs_pm) {
if (variant_ == cgs1) {
delete vec_Hcolumn_;
}
if (variant_ == mgs_pm) {
delete h_aux_;
}

Expand Down Expand Up @@ -102,7 +105,10 @@ namespace ReSolve
vec_Hcolumn_ = new vector_type(num_vecs_ + 1);
vec_Hcolumn_->allocate(memspace_);
}

if(variant_ == cgs1) {
vec_Hcolumn_ = new vector_type(num_vecs_ + 1);
vec_Hcolumn_->allocate(memspace_);
}
if(variant_ == mgs_pm) {
h_aux_ = new real_type[num_vecs_ + 1];
}
Expand Down Expand Up @@ -312,6 +318,33 @@ namespace ReSolve
return 0;
break;

case cgs1:
vec_v_->setData(V->getVectorData(i + 1, memspace), memspace);
//Hcol = V(:,1:i)^T*V(:,i+1);
vector_handler_->gemv('T', n, i + 1, &ONE, &ZERO, V, vec_v_, vec_Hcolumn_, memspace);
// V(:,i+1) = V(:, i+1) - V(:,1:i)*Hcol
vector_handler_->gemv('N', n, i + 1, &ONE, &MINUSONE, V, vec_Hcolumn_, vec_v_, memspace );
mem_.deviceSynchronize();

// copy H_col to H
vec_Hcolumn_->setDataUpdated(memspace);
vec_Hcolumn_->setCurrentSize(i + 1);
vec_Hcolumn_->deepCopyVectorData(&H[ idxmap(i, 0, num_vecs_ + 1)], 0, memory::HOST);
mem_.deviceSynchronize();

t = vector_handler_->dot(vec_v_, vec_v_, memspace);
//set the last entry in Hessenberg matrix
t = sqrt(t);
H[ idxmap(i, i + 1, num_vecs_ + 1) ] = t;
if(fabs(t) > EPSILON) {
t = 1.0/t;
vector_handler_->scal(&t, vec_v_, memspace);
} else {
assert(0 && "Gram-Schmidt failed, vector with ZERO norm\n");
return -1;
}
return 0;
break;
default:
assert(0 && "Iterative refinement failed, wrong orthogonalization.\n");
return -1;
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/TestBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,14 @@ class TestBase
{
return (std::abs(a - b)/(1.0 + std::abs(b)) < eps);
}

bool isEqual(const real_type a, const real_type b, const real_type tol)
{
return (std::abs(a - b)/(1.0 + std::abs(b)) < tol);
}

protected:
std::string mem_space_;
};

}} // namespace ReSolve::tests
}} // namespace ReSolve::tests
51 changes: 46 additions & 5 deletions tests/unit/vector/GramSchmidtTests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ namespace ReSolve {

GS->orthogonalize(N, V, H, 0);
GS->orthogonalize(N, V, H, 1);

status *= verifyAnswer(V, 3, handler, memspace_);

if (var == GramSchmidt::cgs1) {
status *= verifyAnswer(V, 3, handler, memspace_, 1e-13);
} else {
status *= verifyAnswer(V, 3, handler, memspace_);
}
delete handler;
delete [] H;
delete V;
Expand Down Expand Up @@ -168,7 +170,7 @@ namespace ReSolve {
if ( (i != j) && !isEqual(ip, 0.0)) {
status = false;
std::cout << "Vectors " << i << " and " << j << " are not orthogonal!"
<< " Inner product computed: " << ip << ", expected: " << 0.0 << "\n";
<< " Inner product computed: " << ip << ", expected: " << 0.0 << "\n";
break;
}
if ( (i == j) && !isEqual(sqrt(ip), 1.0)) {
Expand All @@ -183,6 +185,45 @@ namespace ReSolve {
delete b;
return status;
}
}; // class

bool verifyAnswer(vector::Vector* x, index_type K, ReSolve::VectorHandler* handler, std::string memspace, real_type const tol)
{
ReSolve::memory::MemorySpace ms;
if (memspace == "cpu")
ms = memory::HOST;
else
ms = memory::DEVICE;

vector::Vector* a = new vector::Vector(x->getSize());
vector::Vector* b = new vector::Vector(x->getSize());

real_type ip;
bool status = true;

for (index_type i = 0; i < K; ++i) {
for (index_type j = 0; j < K; ++j) {
a->update(x->getVectorData(i, ms), ms, memory::HOST);
b->update(x->getVectorData(j, ms), ms, memory::HOST);
ip = handler->dot(a, b, memory::HOST);
if ( (i != j) && !isEqual(ip, 0.0, tol)) {
status = false;
std::cout << "Vectors " << i << " and " << j << " are not orthogonal!"
<< " Inner product computed: " << ip << ", expected: " << 0.0 << "\n";
break;
}
if ( (i == j) && !isEqual(sqrt(ip), 1.0, tol)) {
status = false;
std::cout << std::setprecision(16);
std::cout << "Vector " << i << " has norm: " << sqrt(ip) << " expected: "<< 1.0 <<"\n";
break;
}
}
}
delete a;
delete b;
return status;
}

}; // class
}
}
3 changes: 3 additions & 0 deletions tests/unit/vector/runGramSchmidtTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ int main(int, char**)
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs1);
std::cout << "\n";
}
#endif
Expand All @@ -31,6 +32,7 @@ int main(int, char**)
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs1);
std::cout << "\n";
}
#endif
Expand All @@ -44,6 +46,7 @@ int main(int, char**)
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs1);
std::cout << "\n";
}
return result.summary();
Expand Down

0 comments on commit cb707be

Please sign in to comment.