Skip to content

Commit

Permalink
Complete PIMPL implementation for MatrixHandler.
Browse files Browse the repository at this point in the history
  • Loading branch information
pelesh committed Oct 18, 2023
1 parent edc7e4c commit 678a324
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 202 deletions.
6 changes: 3 additions & 3 deletions examples/r_KLU_KLU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ int main(int argc, char *argv[])

ReSolve::matrix::Coo* A_coo;
ReSolve::matrix::Csr* A;
ReSolve::LinAlgWorkspace* workspace = ReSolve::createLinAlgWorkspace("cpu");
ReSolve::MatrixHandler* matrix_handler = new ReSolve::MatrixHandler(workspace);
ReSolve::VectorHandler* vector_handler = new ReSolve::VectorHandler(workspace);
ReSolve::LinAlgWorkspace* workspace = new ReSolve::LinAlgWorkspace();
ReSolve::MatrixHandler* matrix_handler = new ReSolve::MatrixHandler(workspace);
ReSolve::VectorHandler* vector_handler = new ReSolve::VectorHandler(workspace);
real_type* rhs;
real_type* x;

Expand Down
6 changes: 3 additions & 3 deletions examples/r_KLU_KLU_standalone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ int main(int argc, char *argv[])

ReSolve::matrix::Coo* A_coo;
ReSolve::matrix::Csr* A;
ReSolve::LinAlgWorkspace* workspace = ReSolve::createLinAlgWorkspace("cpu");
ReSolve::MatrixHandler* matrix_handler = new ReSolve::MatrixHandler(workspace);
ReSolve::VectorHandler* vector_handler = new ReSolve::VectorHandler(workspace);
ReSolve::LinAlgWorkspace* workspace = new ReSolve::LinAlgWorkspace();
ReSolve::MatrixHandler* matrix_handler = new ReSolve::MatrixHandler(workspace);
ReSolve::VectorHandler* vector_handler = new ReSolve::VectorHandler(workspace);
real_type* rhs;
real_type* x;

Expand Down
4 changes: 2 additions & 2 deletions resolve/LinSolverIterativeFGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ namespace ReSolve
{
if (d_V_ != nullptr) {
// cudaFree(d_V_);
delete [] d_V_;
delete d_V_;
}

if (d_Z_ != nullptr) {
// cudaFree(d_Z_);
delete [] d_Z_;
delete d_Z_;
}

}
Expand Down
39 changes: 38 additions & 1 deletion resolve/matrix/MatrixHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,59 @@ namespace ReSolve {
// Create a shortcut name for Logger static class
using out = io::Logger;

/**
 * @brief Default constructor
 *
 * @post CPU and CUDA matrix handler implementations are instantiated,
 *       but no workspace is created or attached.
 *
 * @note isCpuEnabled_ and isCudaEnabled_ are not assigned here, so they
 *       keep their in-class default values.
 *
 * @todo There is little utility for the default constructor. Rethink its
 *       purpose. Consider making it a private method.
 */
MatrixHandler::MatrixHandler()
{
  new_matrix_ = true;
  cpuImpl_    = new MatrixHandlerCpu();
  cudaImpl_   = new MatrixHandlerCuda();
}

/**
 * @brief Destructor
 *
 * Releases the CPU and CUDA implementation objects owned by this handler.
 *
 * Both pointers are deleted unconditionally rather than being gated on
 * isCpuEnabled_ / isCudaEnabled_: the default constructor allocates both
 * implementations without setting either flag, so a flag-gated delete
 * leaked them. Deleting a null pointer is a no-op, which covers the
 * CPU-only constructor that leaves cudaImpl_ at nullptr.
 *
 * @note NOTE(review): deletion goes through MatrixHandlerImpl*, which is
 *       only well defined if MatrixHandlerImpl declares a virtual
 *       destructor - confirm in its header.
 */
MatrixHandler::~MatrixHandler()
{
  delete cpuImpl_;
  delete cudaImpl_;
}

/**
 * @brief Constructor taking a pointer to the workspace as its parameter.
 *
 * @param[in] new_workspace - workspace pointer; it is stored in workspace_
 *                            and forwarded to the CPU implementation.
 *
 * @post Only the CPU implementation is instantiated; the handler is marked
 *       as CPU-enabled and CUDA-disabled.
 *
 * @note The CPU implementation currently does not require a workspace.
 *       The workspace pointer parameter is provided for forward compatibility.
 */
MatrixHandler::MatrixHandler(LinAlgWorkspace* new_workspace)
{
  isCudaEnabled_ = false;
  isCpuEnabled_  = true;
  workspace_     = new_workspace;
  cpuImpl_       = new MatrixHandlerCpu(new_workspace);
}

/**
 * @brief Constructor taking a pointer to the CUDA workspace as its parameter.
 *
 * @param[in] new_workspace - CUDA workspace pointer; it is passed only to
 *                            the CUDA implementation.
 *
 * @post A CPU implementation instance is created because it is cheap and
 *       it does not require a workspace.
 *
 * @post A CUDA implementation instance is created with the supplied
 *       workspace, and both back ends are marked as enabled.
 */
MatrixHandler::MatrixHandler(LinAlgWorkspaceCUDA* new_workspace)
{
  isCpuEnabled_  = true;
  isCudaEnabled_ = true;
  cpuImpl_       = new MatrixHandlerCpu();
  cudaImpl_      = new MatrixHandlerCuda(new_workspace);
}

void MatrixHandler::setValuesChanged(bool isValuesChanged, std::string memspace)
Expand Down
23 changes: 17 additions & 6 deletions resolve/matrix/MatrixHandler.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
// this class encapsulates various matrix manipulation operations, commonly required by linear solvers:
// this includes
// (1) Matrix format conversion: coo2csr, csr2csc
// (2) Matrix vector product (SpMV)
// (3) Matrix 1-norm
#pragma once
#include <resolve/Common.hpp>
#include <resolve/MemoryUtils.hpp>
Expand All @@ -22,19 +17,33 @@ namespace ReSolve
class Csr;
}
class LinAlgWorkspace;
class LinAlgWorkspaceCUDA;
class MatrixHandlerImpl;
}


namespace ReSolve {

/**
* @brief This class encapsulates various matrix manipulation operations,
* commonly required by linear solvers.
*
* This includes:
* - Matrix format conversion: coo2csr, csr2csc
* - Matrix vector product (SpMV)
* - Matrix 1-norm
*
* @author Kasia Swirydowicz <[email protected]>
* @author Slaven Peles <[email protected]>
*/
class MatrixHandler
{
using vector_type = vector::Vector;

public:
MatrixHandler();
MatrixHandler(LinAlgWorkspace* workspace);
MatrixHandler(LinAlgWorkspaceCUDA* workspace);
~MatrixHandler();

int csc2csr(matrix::Csc* A_csc, matrix::Csr* A_csr, std::string memspace); //memspace decides on what is returned (cpu or cuda pointer)
Expand All @@ -52,12 +61,14 @@ namespace ReSolve {
void setValuesChanged(bool toWhat, std::string memspace);

private:
LinAlgWorkspace* workspace_{nullptr};
bool new_matrix_{true}; ///< if the structure changed, you need a new handler.

MemoryHandler mem_; ///< Device memory manager object
MatrixHandlerImpl* cpuImpl_{nullptr};
MatrixHandlerImpl* cudaImpl_{nullptr};

bool isCpuEnabled_{false};
bool isCudaEnabled_{false};
};

} // namespace ReSolve
Expand Down
186 changes: 12 additions & 174 deletions resolve/matrix/MatrixHandlerCuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,206 +14,44 @@ namespace ReSolve {

/**
 * @brief Default constructor.
 *
 * Sets values_changed_ to true. No workspace is assigned by this
 * constructor.
 */
MatrixHandlerCuda::MatrixHandlerCuda()
{
values_changed_ = true;
}

/**
 * @brief Destructor.
 *
 * Performs no explicit cleanup; in particular, the workspace pointer held
 * by this object is not released here.
 */
MatrixHandlerCuda::~MatrixHandlerCuda() = default;

MatrixHandlerCuda::MatrixHandlerCuda(LinAlgWorkspace* new_workspace)
/**
 * @brief Constructor taking a pointer to the CUDA workspace.
 *
 * @param[in] new_workspace - pointer stored in workspace_; the constructor
 *                            takes no other action on it.
 */
MatrixHandlerCuda::MatrixHandlerCuda(LinAlgWorkspaceCUDA* new_workspace)
{
workspace_ = new_workspace;
}

// bool MatrixHandlerCuda::getValuesChanged()
// {
// return this->values_changed_;
// }
/**
 * @brief Constructor taking a pointer to a generic workspace.
 *
 * @param[in] new_workspace - pointer that is cast to LinAlgWorkspaceCUDA*
 *                            and stored in workspace_.
 *
 * @note NOTE(review): the C-style cast is unchecked. Presumably callers
 *       must pass a pointer to an actual LinAlgWorkspaceCUDA object -
 *       confirm against the workspace class hierarchy and consider a
 *       named cast once that relationship is verified.
 */
MatrixHandlerCuda::MatrixHandlerCuda(LinAlgWorkspace* new_workspace)
{
workspace_ = (LinAlgWorkspaceCUDA*) new_workspace;
}

/**
 * @brief Sets the values_changed_ flag.
 *
 * matvec tests this flag before calling cusparseCreateCsr for the matrix.
 *
 * @param[in] values_changed - new value of the flag.
 */
void MatrixHandlerCuda::setValuesChanged(bool values_changed)
{
values_changed_ = values_changed;
}

// int MatrixHandlerCuda::coo2csr(matrix::Coo* A_coo, matrix::Csr* A_csr, std::string memspace)
// {
// //this happens on the CPU not on the GPU
// //but will return whatever memspace requested.

// //count nnzs first

// index_type nnz_unpacked = 0;
// index_type nnz = A_coo->getNnz();
// index_type n = A_coo->getNumRows();
// bool symmetric = A_coo->symmetric();
// bool expanded = A_coo->expanded();

// index_type* nnz_counts = new index_type[n];
// std::fill_n(nnz_counts, n, 0);
// index_type* coo_rows = A_coo->getRowData("cpu");
// index_type* coo_cols = A_coo->getColData("cpu");
// real_type* coo_vals = A_coo->getValues("cpu");

// index_type* diag_control = new index_type[n]; //for DEDUPLICATION of the diagonal
// std::fill_n(diag_control, n, 0);
// index_type nnz_unpacked_no_duplicates = 0;
// index_type nnz_no_duplicates = nnz;


// //maybe check if they exist?
// for (index_type i = 0; i < nnz; ++i)
// {
// nnz_counts[coo_rows[i]]++;
// nnz_unpacked++;
// nnz_unpacked_no_duplicates++;
// if ((coo_rows[i] != coo_cols[i])&& (symmetric) && (!expanded))
// {
// nnz_counts[coo_cols[i]]++;
// nnz_unpacked++;
// nnz_unpacked_no_duplicates++;
// }
// if (coo_rows[i] == coo_cols[i]){
// if (diag_control[coo_rows[i]] > 0) {
// //duplicate
// nnz_unpacked_no_duplicates--;
// nnz_no_duplicates--;
// }
// diag_control[coo_rows[i]]++;
// }
// }
// A_csr->setExpanded(true);
// A_csr->setNnzExpanded(nnz_unpacked_no_duplicates);
// index_type* csr_ia = new index_type[n+1];
// std::fill_n(csr_ia, n + 1, 0);
// index_type* csr_ja = new index_type[nnz_unpacked];
// real_type* csr_a = new real_type[nnz_unpacked];
// index_type* nnz_shifts = new index_type[n];
// std::fill_n(nnz_shifts, n , 0);

// IndexValuePair* tmp = new IndexValuePair[nnz_unpacked];

// csr_ia[0] = 0;

// for (index_type i = 1; i < n + 1; ++i){
// csr_ia[i] = csr_ia[i - 1] + nnz_counts[i - 1] - (diag_control[i-1] - 1);
// }

// int r, start;


// for (index_type i = 0; i < nnz; ++i){
// //which row
// r = coo_rows[i];
// start = csr_ia[r];

// if ((start + nnz_shifts[r]) > nnz_unpacked) {
// out::warning() << "index out of bounds (case 1) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
// }
// if ((r == coo_cols[i]) && (diag_control[r] > 1)) {//diagonal, and there are duplicates
// bool already_there = false;
// for (index_type j = start; j < start + nnz_shifts[r]; ++j)
// {
// index_type c = tmp[j].getIdx();
// if (c == r) {
// real_type val = tmp[j].getValue();
// val += coo_vals[i];
// tmp[j].setValue(val);
// already_there = true;
// out::warning() << " duplicate found, row " << c << " adding in place " << j << " current value: " << val << std::endl;
// }
// }
// if (!already_there){ // first time this duplicates appears

// tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
// tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);

// nnz_shifts[r]++;
// }
// } else {//not diagonal
// tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
// tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
// nnz_shifts[r]++;

// if ((coo_rows[i] != coo_cols[i]) && (symmetric == 1))
// {
// r = coo_cols[i];
// start = csr_ia[r];

// if ((start + nnz_shifts[r]) > nnz_unpacked)
// out::warning() << "index out of bounds (case 2) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
// tmp[start + nnz_shifts[r]].setIdx(coo_rows[i]);
// tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
// nnz_shifts[r]++;
// }
// }
// }
// //now sort whatever is inside rows

// for (int i = 0; i < n; ++i)
// {

// //now sorting (and adding 1)
// int colStart = csr_ia[i];
// int colEnd = csr_ia[i + 1];
// int length = colEnd - colStart;
// std::sort(&tmp[colStart],&tmp[colStart] + length);
// }

// for (index_type i = 0; i < nnz_unpacked; ++i)
// {
// csr_ja[i] = tmp[i].getIdx();
// csr_a[i] = tmp[i].getValue();
// }
// #if 0
// for (int i = 0; i<n; ++i){
// printf("Row: %d \n", i);
// for (int j = csr_ia[i]; j<csr_ia[i+1]; ++j){
// printf("(%d %16.16f) ", csr_ja[j], csr_a[j]);
// }
// printf("\n");
// }
// #endif
// A_csr->setNnz(nnz_no_duplicates);
// if (memspace == "cpu"){
// A_csr->updateData(csr_ia, csr_ja, csr_a, "cpu", "cpu");
// } else {
// if (memspace == "cuda"){
// A_csr->updateData(csr_ia, csr_ja, csr_a, "cpu", "cuda");
// } else {
// //display error
// }
// }
// delete [] nnz_counts;
// delete [] tmp;
// delete [] nnz_shifts;
// delete [] csr_ia;
// delete [] csr_ja;
// delete [] csr_a;
// delete [] diag_control;

// return 0;
// }

int MatrixHandlerCuda::matvec(matrix::Sparse* Ageneric,
vector_type* vec_x,
vector_type* vec_result,
const real_type* alpha,
const real_type* beta,
std::string matrixFormat)
vector_type* vec_x,
vector_type* vec_result,
const real_type* alpha,
const real_type* beta,
std::string matrixFormat)
{
using namespace constants;
int error_sum = 0;
if (matrixFormat == "csr") {
matrix::Csr* A = dynamic_cast<matrix::Csr*>(Ageneric);
//result = alpha *A*x + beta * result
cusparseStatus_t status;
// std::cout << "Matvec on NVIDIA GPU ...\n";
LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
LinAlgWorkspaceCUDA* workspaceCUDA = workspace_;
cusparseDnVecDescr_t vecx = workspaceCUDA->getVecX();
//printf("is vec_x NULL? %d\n", vec_x->getData("cuda") == nullptr);
//printf("is vec_result NULL? %d\n", vec_result->getData("cuda") == nullptr);
cusparseCreateDnVec(&vecx, A->getNumRows(), vec_x->getData("cuda"), CUDA_R_64F);


Expand All @@ -224,7 +62,7 @@ namespace ReSolve {

void* buffer_spmv = workspaceCUDA->getSpmvBuffer();
cusparseHandle_t handle_cusparse = workspaceCUDA->getCusparseHandle();
if (values_changed_) {
if (values_changed_) {
status = cusparseCreateCsr(&matA,
A->getNumRows(),
A->getNumColumns(),
Expand Down
Loading

0 comments on commit 678a324

Please sign in to comment.