From 5c636468156260eb8b019ae58793d5d49c2a8b45 Mon Sep 17 00:00:00 2001 From: kswirydo Date: Fri, 15 Dec 2023 17:03:33 -0500 Subject: [PATCH] fixing solve(rhs) vs solve(x, rhs) discrepancy (#119) * fixing solve(rhs) vs solve(x, rhs) discrepancy --------- Co-authored-by: Slaven Peles --- resolve/LinSolverDirectRocSolverRf.cpp | 10 ++++------ tests/functionality/testKLU_RocSolver_FGMRES.cpp | 3 +-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/resolve/LinSolverDirectRocSolverRf.cpp b/resolve/LinSolverDirectRocSolverRf.cpp index 2a7ac3f8..2b28c29c 100644 --- a/resolve/LinSolverDirectRocSolverRf.cpp +++ b/resolve/LinSolverDirectRocSolverRf.cpp @@ -261,7 +261,7 @@ namespace ReSolve L_csr_->getNnz(), &(constants::ONE), descr_L_, - L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getValues(ReSolve::memory::DEVICE), L_csr_->getRowData(ReSolve::memory::DEVICE), L_csr_->getColData(ReSolve::memory::DEVICE), info_L_, @@ -271,24 +271,22 @@ namespace ReSolve L_buffer_); error_sum += status_rocsparse_; - //mem_.deviceSynchronize(); rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(), rocsparse_operation_none, A_->getNumRows(), U_csr_->getNnz(), &(constants::ONE), - descr_L_, - U_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + descr_U_, + U_csr_->getValues(ReSolve::memory::DEVICE), U_csr_->getRowData(ReSolve::memory::DEVICE), U_csr_->getColData(ReSolve::memory::DEVICE), info_U_, d_aux2_, //input - d_aux1_,//result + d_aux1_, //result rocsparse_solve_policy_auto, U_buffer_); error_sum += status_rocsparse_; - //mem_.deviceSynchronize(); permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,rhs->getData(ReSolve::memory::DEVICE)); mem_.deviceSynchronize(); } diff --git a/tests/functionality/testKLU_RocSolver_FGMRES.cpp b/tests/functionality/testKLU_RocSolver_FGMRES.cpp index 3a6b359e..577c4f84 100644 --- a/tests/functionality/testKLU_RocSolver_FGMRES.cpp +++ b/tests/functionality/testKLU_RocSolver_FGMRES.cpp @@ -197,8 +197,7 @@ int main(int argc, char *argv[]) error_sum += status; vec_x->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE); - // TODO: Investigate why results are different when using Rf->solve(vec_x) !! - status = Rf->solve(vec_rhs, vec_x); + status = Rf->solve(vec_x); error_sum += status; FGMRES->resetMatrix(A);