From 7b5b6ecd8570889526bd2a490a595b18256dc5f9 Mon Sep 17 00:00:00 2001 From: kswirydo Date: Sat, 28 Oct 2023 02:06:40 -0400 Subject: [PATCH] deadlock instead of segfault --- resolve/LinSolverDirectRocSolverRf.cpp | 7 +++++++ tests/functionality/testKLU_RocSolver.cpp | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/resolve/LinSolverDirectRocSolverRf.cpp b/resolve/LinSolverDirectRocSolverRf.cpp index c1ba3adcd..916672060 100644 --- a/resolve/LinSolverDirectRocSolverRf.cpp +++ b/resolve/LinSolverDirectRocSolverRf.cpp @@ -38,6 +38,7 @@ namespace ReSolve mem_.copyArrayHostToDevice(d_Q_, Q, n); + mem_.deviceSynchronize(); status_rocblas_ = rocsolver_dcsrrf_analysis(workspace_->getRocblasHandle(), n, 1, @@ -68,6 +69,7 @@ namespace ReSolve int LinSolverDirectRocSolverRf::refactorize() { int error_sum = 0; + mem_.deviceSynchronize(); status_rocblas_ = rocsolver_dcsrrf_refactlu(workspace_->getRocblasHandle(), A_->getNumRows(), A_->getNnzExpanded(), @@ -83,6 +85,7 @@ namespace ReSolve infoM_); + mem_.deviceSynchronize(); error_sum += status_rocblas_; return error_sum; @@ -92,6 +95,7 @@ namespace ReSolve int LinSolverDirectRocSolverRf::solve(vector_type* rhs) { if (solve_mode_ == 0) { + mem_.deviceSynchronize(); status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(), A_->getNumRows(), 1, @@ -104,6 +108,7 @@ namespace ReSolve rhs->getData("hip"), A_->getNumRows(), infoM_); + mem_.deviceSynchronize(); } else { // not implemented yet } @@ -116,6 +121,7 @@ namespace ReSolve x->setDataUpdated("hip"); if (solve_mode_ == 0) { + mem_.deviceSynchronize(); status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(), A_->getNumRows(), 1, @@ -128,6 +134,7 @@ namespace ReSolve x->getData("hip"), A_->getNumRows(), infoM_); + mem_.deviceSynchronize(); } else { // not implemented yet } diff --git a/tests/functionality/testKLU_RocSolver.cpp b/tests/functionality/testKLU_RocSolver.cpp index f8f2efb0e..5bac53bc6 100644 --- a/tests/functionality/testKLU_RocSolver.cpp +++ b/tests/functionality/testKLU_RocSolver.cpp @@ -192,7 +192,7 @@ printf("ERROR in sol %16.16f \n", normDiffMatrix1); status = Rf->refactorize(); error_sum += status; -#if 0 +#if 1 std::cout<<"rocSolverRf refactorization status: "<solve(vec_rhs, vec_x); error_sum += status;