Skip to content

Commit

Permalink
A working alternative triangular solver (faster) for rocsolverrf (#56)
Browse files Browse the repository at this point in the history
* a WORKING alternative triangular solver (faster) for rocsolverrf

---------

Co-authored-by: kswirydo <[email protected]>
  • Loading branch information
pelesh and kswirydo authored Nov 2, 2023
1 parent 219e645 commit 05a5b2e
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/r_KLU_rocSolverRf_FGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ int main(int argc, char *argv[])
std::cout<<"KLU analysis status: "<<status<<std::endl;
status = KLU->factorize();
std::cout<<"KLU factorization status: "<<status<<std::endl;

status = KLU->solve(vec_rhs, vec_x);
std::cout<<"KLU solve status: "<<status<<std::endl;
vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
Expand All @@ -149,6 +150,7 @@ int main(int argc, char *argv[])
if (L == nullptr) {printf("ERROR");}
index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();
Rf->setSolveMode(1);
Rf->setup(A, L, U, P, Q, vec_rhs);
Rf->refactorize();
std::cout<<"about to set FGMRES" <<std::endl;
Expand All @@ -162,7 +164,6 @@ int main(int argc, char *argv[])
std::cout<<"ROCSOLVER RF refactorization status: "<<status<<std::endl;
status = Rf->solve(vec_rhs, vec_x);
std::cout<<"ROCSOLVER RF solve status: "<<status<<std::endl;

vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
norm_b = vector_handler->dot(vec_r, vec_r, "hip");
norm_b = sqrt(norm_b);
Expand Down
213 changes: 210 additions & 3 deletions resolve/LinSolverDirectRocSolverRf.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <resolve/vector/Vector.hpp>
#include <resolve/matrix/Csr.hpp>
#include "LinSolverDirectRocSolverRf.hpp"
#include <resolve/hip/hipKernels.h>

namespace ReSolve
{
Expand All @@ -15,6 +16,12 @@ namespace ReSolve
{
// Release the device copies of the permutation vectors.
mem_.deleteOnDevice(d_P_);
mem_.deleteOnDevice(d_Q_);

// Scratch vectors used by the fast triangular solve; setup() allocates
// them only when solve_mode_ == 1.
mem_.deleteOnDevice(d_aux1_);
mem_.deleteOnDevice(d_aux2_);

// Separate L and U factors, created with new in setup() only in fast mode.
// NOTE(review): deleting them here is only safe if they are null when
// fast mode was never set up - confirm they are null-initialized.
// NOTE(review): L_buffer_, U_buffer_, descr_L_/descr_U_ and info_L_/info_U_
// are also created in setup() but are not released in the code visible
// here - confirm they are freed elsewhere or add the cleanup.
delete L_csr_;
delete U_csr_;
}

int LinSolverDirectRocSolverRf::setup(matrix::Sparse* A, matrix::Sparse* L, matrix::Sparse* U, index_type* P, index_type* Q, vector_type* rhs)
Expand Down Expand Up @@ -56,7 +63,109 @@ namespace ReSolve
mem_.deviceSynchronize();
error_sum += status_rocblas_;

// Triangular-solve setup, performed only in fast mode (solve_mode_ == 1):
//  1. allocate device CSR storage for separate L and U factors,
//  2. split the combined factor M_ into L and U with rocSOLVER,
//  3. query and allocate the rocSPARSE csrsv work buffers,
//  4. run the csrsv analysis for both factors,
//  5. allocate two length-n scratch vectors used by solve().
// Every returned status is accumulated into error_sum.
if (solve_mode_ == 1) { // fast mode
  L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz());
  U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz());

  L_csr_->allocateMatrixData(ReSolve::memory::DEVICE);
  U_csr_->allocateMatrixData(ReSolve::memory::DEVICE);

  // Matrix descriptors: zero-based indexing, lower fill for L, upper for U.
  rocsparse_create_mat_descr(&(descr_L_));
  rocsparse_set_mat_fill_mode(descr_L_, rocsparse_fill_mode_lower);
  rocsparse_set_mat_index_base(descr_L_, rocsparse_index_base_zero);

  rocsparse_create_mat_descr(&(descr_U_));
  rocsparse_set_mat_index_base(descr_U_, rocsparse_index_base_zero);
  rocsparse_set_mat_fill_mode(descr_U_, rocsparse_fill_mode_upper);

  rocsparse_create_mat_info(&info_L_);
  rocsparse_create_mat_info(&info_U_);

  // Work-buffer sizes reported by rocSPARSE (in bytes).
  size_t L_buffer_size = 0;
  size_t U_buffer_size = 0;

  // Fill L_csr_ and U_csr_ from the combined factor M_.
  status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(),
                                             n,
                                             M_->getNnzExpanded(),
                                             M_->getRowData(ReSolve::memory::DEVICE),
                                             M_->getColData(ReSolve::memory::DEVICE),
                                             M_->getValues(ReSolve::memory::DEVICE),
                                             L_csr_->getRowData(ReSolve::memory::DEVICE),
                                             L_csr_->getColData(ReSolve::memory::DEVICE),
                                             L_csr_->getValues(ReSolve::memory::DEVICE),
                                             U_csr_->getRowData(ReSolve::memory::DEVICE),
                                             U_csr_->getColData(ReSolve::memory::DEVICE),
                                             U_csr_->getValues(ReSolve::memory::DEVICE));

  error_sum += status_rocblas_;

  // Buffer for the L solve.
  status_rocsparse_ = rocsparse_dcsrsv_buffer_size(workspace_->getRocsparseHandle(),
                                                   rocsparse_operation_none,
                                                   n,
                                                   L_csr_->getNnz(),
                                                   descr_L_,
                                                   L_csr_->getValues(ReSolve::memory::DEVICE),
                                                   L_csr_->getRowData(ReSolve::memory::DEVICE),
                                                   L_csr_->getColData(ReSolve::memory::DEVICE),
                                                   info_L_,
                                                   &L_buffer_size);
  error_sum += status_rocsparse_;

  // size_t must be printed with %zu; %d is a format/argument mismatch.
  printf("buffer size for L %zu status %d \n", L_buffer_size, status_rocsparse_);

  mem_.allocateBufferOnDevice(&L_buffer_, L_buffer_size);

  // Buffer for the U solve.
  status_rocsparse_ = rocsparse_dcsrsv_buffer_size(workspace_->getRocsparseHandle(),
                                                   rocsparse_operation_none,
                                                   n,
                                                   U_csr_->getNnz(),
                                                   descr_U_,
                                                   U_csr_->getValues(ReSolve::memory::DEVICE),
                                                   U_csr_->getRowData(ReSolve::memory::DEVICE),
                                                   U_csr_->getColData(ReSolve::memory::DEVICE),
                                                   info_U_,
                                                   &U_buffer_size);
  error_sum += status_rocsparse_;
  mem_.allocateBufferOnDevice(&U_buffer_, U_buffer_size);
  printf("buffer size for U %zu status %d \n", U_buffer_size, status_rocsparse_);

  // Analysis of L.
  status_rocsparse_ = rocsparse_dcsrsv_analysis(workspace_->getRocsparseHandle(),
                                                rocsparse_operation_none,
                                                n,
                                                L_csr_->getNnz(),
                                                descr_L_,
                                                L_csr_->getValues(ReSolve::memory::DEVICE),
                                                L_csr_->getRowData(ReSolve::memory::DEVICE),
                                                L_csr_->getColData(ReSolve::memory::DEVICE),
                                                info_L_,
                                                rocsparse_analysis_policy_force,
                                                rocsparse_solve_policy_auto,
                                                L_buffer_);
  error_sum += status_rocsparse_;
  if (status_rocsparse_ != 0) {
    printf("status after analysis 1 %d \n", status_rocsparse_);
  }

  // Analysis of U.
  status_rocsparse_ = rocsparse_dcsrsv_analysis(workspace_->getRocsparseHandle(),
                                                rocsparse_operation_none,
                                                n,
                                                U_csr_->getNnz(),
                                                descr_U_,
                                                U_csr_->getValues(ReSolve::memory::DEVICE),
                                                U_csr_->getRowData(ReSolve::memory::DEVICE),
                                                U_csr_->getColData(ReSolve::memory::DEVICE),
                                                info_U_,
                                                rocsparse_analysis_policy_force,
                                                rocsparse_solve_policy_auto,
                                                U_buffer_);
  error_sum += status_rocsparse_;
  if (status_rocsparse_ != 0) {
    printf("status after analysis 2 %d \n", status_rocsparse_);
  }

  // Scratch vectors that hold intermediate results between the two solves.
  mem_.allocateArrayOnDevice(&d_aux1_, n);
  mem_.allocateArrayOnDevice(&d_aux2_, n);
}
return error_sum;
}

Expand All @@ -78,15 +187,38 @@ namespace ReSolve
d_Q_,
infoM_);


mem_.deviceSynchronize();
error_sum += status_rocblas_;

// Fast mode keeps separate copies of L and U, so after M_ has been
// refactorized they have to be re-extracted from it.
if (solve_mode_ == 1) {
  printf("solve mode 1, splitting the factors again \n");

  // Combined factor (input).
  auto m_row_ptr = M_->getRowData(ReSolve::memory::DEVICE);
  auto m_col_idx = M_->getColData(ReSolve::memory::DEVICE);
  auto m_values  = M_->getValues(ReSolve::memory::DEVICE);

  // Lower factor (output).
  auto l_row_ptr = L_csr_->getRowData(ReSolve::memory::DEVICE);
  auto l_col_idx = L_csr_->getColData(ReSolve::memory::DEVICE);
  auto l_values  = L_csr_->getValues(ReSolve::memory::DEVICE);

  // Upper factor (output).
  auto u_row_ptr = U_csr_->getRowData(ReSolve::memory::DEVICE);
  auto u_col_idx = U_csr_->getColData(ReSolve::memory::DEVICE);
  auto u_values  = U_csr_->getValues(ReSolve::memory::DEVICE);

  status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(),
                                             A_->getNumRows(),
                                             M_->getNnzExpanded(),
                                             m_row_ptr, m_col_idx, m_values,
                                             l_row_ptr, l_col_idx, l_values,
                                             u_row_ptr, u_col_idx, u_values);

  mem_.deviceSynchronize();
  error_sum += status_rocblas_;
}

return error_sum;
}

// solution is returned in RHS
int LinSolverDirectRocSolverRf::solve(vector_type* rhs)
{
int error_sum = 0;
if (solve_mode_ == 0) {
mem_.deviceSynchronize();
status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(),
Expand All @@ -104,15 +236,51 @@ namespace ReSolve
mem_.deviceSynchronize();
} else {
// not implemented yet
permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_);
mem_.deviceSynchronize();
rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
L_csr_->getNnz(),
&(constants::ONE),
descr_L_,
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
info_L_,
d_aux1_,
d_aux2_, //result
rocsparse_solve_policy_auto,
L_buffer_);
error_sum += status_rocsparse_;

rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
U_csr_->getNnz(),
&(constants::ONE),
descr_L_,
U_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
info_U_,
d_aux2_, //input
d_aux1_,//result
rocsparse_solve_policy_auto,
U_buffer_);
error_sum += status_rocsparse_;

permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,rhs->getData(ReSolve::memory::DEVICE));
mem_.deviceSynchronize();
}
return status_rocblas_;
return error_sum;
}

int LinSolverDirectRocSolverRf::solve(vector_type* rhs, vector_type* x)
{
x->update(rhs->getData(ReSolve::memory::DEVICE), ReSolve::memory::DEVICE, ReSolve::memory::DEVICE);
x->setDataUpdated(ReSolve::memory::DEVICE);

int error_sum = 0;
if (solve_mode_ == 0) {
mem_.deviceSynchronize();
status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(),
Expand All @@ -127,11 +295,50 @@ namespace ReSolve
x->getData(ReSolve::memory::DEVICE),
A_->getNumRows(),
infoM_);
error_sum += status_rocblas_;
mem_.deviceSynchronize();
} else {
// not implemented yet

permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_);
mem_.deviceSynchronize();

rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
L_csr_->getNnz(),
&(constants::ONE),
descr_L_,
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
info_L_,
d_aux1_,
d_aux2_, //result
rocsparse_solve_policy_auto,
L_buffer_);
error_sum += status_rocsparse_;

rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
U_csr_->getNnz(),
&(constants::ONE),
descr_U_,
U_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
info_U_,
d_aux2_, //input
d_aux1_,//result
rocsparse_solve_policy_auto,
U_buffer_);
error_sum += status_rocsparse_;

permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,x->getData(ReSolve::memory::DEVICE));
mem_.deviceSynchronize();
}
return status_rocblas_;
return error_sum;
}

int LinSolverDirectRocSolverRf::setSolveMode(int mode)
Expand Down
22 changes: 19 additions & 3 deletions resolve/LinSolverDirectRocSolverRf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ namespace ReSolve
int getSolveMode(); //should be enum too

private:
rocblas_status status_rocblas_;

rocblas_status status_rocblas_;
rocsparse_status status_rocsparse_;
index_type* d_P_;
index_type* d_Q_;

Expand All @@ -54,6 +54,22 @@ namespace ReSolve
void addFactors(matrix::Sparse* L, matrix::Sparse* U); // create L+U from separate L, U factors
rocsolver_rfinfo infoM_;
matrix::Sparse* M_;//the matrix that contains added factors
int solve_mode_;
int solve_mode_; // 0 is default and 1 is fast

// not used by default - for fast solve
rocsparse_mat_descr descr_L_{nullptr};
rocsparse_mat_descr descr_U_{nullptr};

rocsparse_mat_info info_L_{nullptr};
rocsparse_mat_info info_U_{nullptr};

void* L_buffer_{nullptr};
void* U_buffer_{nullptr};

// Separate L and U factors used by the fast solve. Null-initialized like
// the other fast-solve members, so that deleting them is well-defined when
// fast mode was never set up.
ReSolve::matrix::Csr* L_csr_{nullptr};
ReSolve::matrix::Csr* U_csr_{nullptr};

real_type* d_aux1_{nullptr};
real_type* d_aux2_{nullptr};
};
}
11 changes: 11 additions & 0 deletions resolve/hip/hipKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,14 @@ void matrix_row_sums(int n,
int* a_ia,
double* a_val,
double* result);

// Vector permutation helpers needed for the triangular solve.
// NOTE(review): the implementations are not visible here; the direction of
// each permutation (gather vs. scatter through perm_vector) and the memory
// space of the pointers should be confirmed against the kernel source.

// Writes a permuted copy of vec_in (length n) into vec_out using perm_vector.
// Presumably applies the row permutation P - TODO confirm.
void permuteVectorP(int n,
int* perm_vector,
double* vec_in,
double* vec_out);
// Writes a permuted copy of vec_in (length n) into vec_out using perm_vector.
// Presumably applies the column permutation Q - TODO confirm.
void permuteVectorQ(int n,
int* perm_vector,
double* vec_in,
double* vec_out);
Loading

0 comments on commit 05a5b2e

Please sign in to comment.