Skip to content

Commit

Permalink
Create Csr constructor that takes Coo as the argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
pelesh committed Dec 5, 2023
1 parent 5129c8c commit c199aa3
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 53 deletions.
8 changes: 5 additions & 3 deletions examples/r_KLU_GLU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ int main(int argc, char *argv[])
rhs_file.close();

//Now convert to CSR.
if (i < 1) {
matrix_handler->coo2csr(A_coo, A, "cpu");
if (i < 1) {
A->updateFromCoo(A_coo, ReSolve::memory::HOST);
// matrix_handler->coo2csr(A_coo, A, "cpu");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
vec_rhs->setDataUpdated(ReSolve::memory::HOST);
} else {
matrix_handler->coo2csr(A_coo, A, "cuda");
A->updateFromCoo(A_coo, ReSolve::memory::DEVICE);
// matrix_handler->coo2csr(A_coo, A, "cuda");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
}
std::cout<<"COO to CSR completed. Expanded NNZ: "<< A->getNnzExpanded()<<std::endl;
Expand Down
188 changes: 187 additions & 1 deletion resolve/matrix/Csr.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
#include <cstring> // <-- includes memcpy
#include <algorithm>
#include <cassert>

#include "Csr.hpp"
#include "Coo.hpp"
#include <resolve/utilities/misc/IndexValuePair.hpp>
#include <resolve/utilities/logger/Logger.hpp>

namespace ReSolve
{
using out = io::Logger;

matrix::Csr::Csr()
{
}
Expand All @@ -20,6 +27,16 @@ namespace ReSolve
{
}

matrix::Csr::Csr(matrix::Coo* A_coo, memory::MemorySpace memspace)
: Sparse(A_coo->getNumRows(),
A_coo->getNumColumns(),
A_coo->getNnz(),
A_coo->symmetric(),
A_coo->expanded())
{
coo2csr(A_coo, memspace);
}

matrix::Csr::~Csr()
{
}
Expand Down Expand Up @@ -234,7 +251,176 @@ namespace ReSolve
default:
return -1;
} // switch
}
}

int matrix::Csr::updateFromCoo(matrix::Coo* A_coo, memory::MemorySpace memspaceOut)
{
assert(n_ == A_coo->getNumRows());
assert(m_ == A_coo->getNumColumns());
assert(nnz_ == A_coo->getNnz());
assert(is_symmetric_ == A_coo->symmetric()); // <- Do we need to check for this?

return coo2csr(A_coo, memspaceOut);
}


int matrix::Csr::coo2csr(matrix::Coo* A_coo, memory::MemorySpace memspace)
{
//count nnzs first
index_type nnz_unpacked = 0;
index_type nnz = A_coo->getNnz();
index_type n = A_coo->getNumRows();
bool symmetric = A_coo->symmetric();
bool expanded = A_coo->expanded();

index_type* nnz_counts = new index_type[n];
std::fill_n(nnz_counts, n, 0);
index_type* coo_rows = A_coo->getRowData(memory::HOST);
index_type* coo_cols = A_coo->getColData(memory::HOST);
real_type* coo_vals = A_coo->getValues( memory::HOST);

index_type* diag_control = new index_type[n]; //for DEDUPLICATION of the diagonal
std::fill_n(diag_control, n, 0);
index_type nnz_unpacked_no_duplicates = 0;
index_type nnz_no_duplicates = nnz;


//maybe check if they exist?
for (index_type i = 0; i < nnz; ++i)
{
nnz_counts[coo_rows[i]]++;
nnz_unpacked++;
nnz_unpacked_no_duplicates++;
if ((coo_rows[i] != coo_cols[i])&& (symmetric) && (!expanded))
{
nnz_counts[coo_cols[i]]++;
nnz_unpacked++;
nnz_unpacked_no_duplicates++;
}
if (coo_rows[i] == coo_cols[i]){
if (diag_control[coo_rows[i]] > 0) {
//duplicate
nnz_unpacked_no_duplicates--;
nnz_no_duplicates--;
}
diag_control[coo_rows[i]]++;
}
}
this->setExpanded(true);
this->setNnzExpanded(nnz_unpacked_no_duplicates);
index_type* csr_ia = new index_type[n+1];
std::fill_n(csr_ia, n + 1, 0);
index_type* csr_ja = new index_type[nnz_unpacked];
real_type* csr_a = new real_type[nnz_unpacked];
index_type* nnz_shifts = new index_type[n];
std::fill_n(nnz_shifts, n , 0);

IndexValuePair* tmp = new IndexValuePair[nnz_unpacked];

csr_ia[0] = 0;

for (index_type i = 1; i < n + 1; ++i){
csr_ia[i] = csr_ia[i - 1] + nnz_counts[i - 1] - (diag_control[i-1] - 1);
}

int r, start;


for (index_type i = 0; i < nnz; ++i){
//which row
r = coo_rows[i];
start = csr_ia[r];

if ((start + nnz_shifts[r]) > nnz_unpacked) {
out::warning() << "index out of bounds (case 1) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
}
if ((r == coo_cols[i]) && (diag_control[r] > 1)) {//diagonal, and there are duplicates
bool already_there = false;
for (index_type j = start; j < start + nnz_shifts[r]; ++j)
{
index_type c = tmp[j].getIdx();
if (c == r) {
real_type val = tmp[j].getValue();
val += coo_vals[i];
tmp[j].setValue(val);
already_there = true;
out::warning() << " duplicate found, row " << c << " adding in place " << j << " current value: " << val << std::endl;
}
}
if (!already_there){ // first time this duplicates appears

tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);

nnz_shifts[r]++;
}
} else {//not diagonal
tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
nnz_shifts[r]++;

if ((coo_rows[i] != coo_cols[i]) && (symmetric == 1))
{
r = coo_cols[i];
start = csr_ia[r];

if ((start + nnz_shifts[r]) > nnz_unpacked)
out::warning() << "index out of bounds (case 2) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
tmp[start + nnz_shifts[r]].setIdx(coo_rows[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
nnz_shifts[r]++;
}
}
}
//now sort whatever is inside rows

for (int i = 0; i < n; ++i)
{

//now sorting (and adding 1)
int colStart = csr_ia[i];
int colEnd = csr_ia[i + 1];
int length = colEnd - colStart;
std::sort(&tmp[colStart],&tmp[colStart] + length);
}

for (index_type i = 0; i < nnz_unpacked; ++i)
{
csr_ja[i] = tmp[i].getIdx();
csr_a[i] = tmp[i].getValue();
}
#if 0
for (int i = 0; i<n; ++i){
printf("Row: %d \n", i);
for (int j = csr_ia[i]; j<csr_ia[i+1]; ++j){
printf("(%d %16.16f) ", csr_ja[j], csr_a[j]);
}
printf("\n");
}
#endif
this->setNnz(nnz_no_duplicates);
this->updateData(csr_ia, csr_ja, csr_a, memory::HOST, memspace);
// if (memspace == "cpu"){
// this->updateData(csr_ia, csr_ja, csr_a, memory::HOST, memory::HOST);
// } else {
// if (memspace == "cuda"){
// this->updateData(csr_ia, csr_ja, csr_a, memory::HOST, memory::DEVICE);
// } else if (memspace == "hip"){
// this->updateData(csr_ia, csr_ja, csr_a, memory::HOST, memory::DEVICE);
// } else {
// //display error
// }
// }
delete [] nnz_counts;
delete [] tmp;
delete [] nnz_shifts;
delete [] csr_ia;
delete [] csr_ja;
delete [] csr_a;
delete [] diag_control;

return 0;
}

} // namespace ReSolve

12 changes: 11 additions & 1 deletion resolve/matrix/Csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

namespace ReSolve { namespace matrix {

// Forward declaration of Coo
class Coo;

class Csr : public Sparse
{
public:
Expand All @@ -15,6 +18,8 @@ namespace ReSolve { namespace matrix {
index_type nnz,
bool symmetric,
bool expanded);

Csr(matrix::Coo* mat, memory::MemorySpace memspace);

~Csr();

Expand All @@ -23,13 +28,18 @@ namespace ReSolve { namespace matrix {
virtual real_type* getValues( memory::MemorySpace memspace);

virtual int updateData(index_type* row_data, index_type* col_data, real_type* val_data, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut);
virtual int updateData(index_type* row_data, index_type* col_data, real_type* val_data, index_type new_nnz, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut);
virtual int updateData(index_type* row_data, index_type* col_data, real_type* val_data, index_type new_nnz, memory::MemorySpace memspaceIn, memory::MemorySpace memspaceOut);

virtual int allocateMatrixData(memory::MemorySpace memspace);

virtual void print() {return;}

virtual int copyData(memory::MemorySpace memspaceOut);

int updateFromCoo(matrix::Coo* mat, memory::MemorySpace memspaceOut);

private:
int coo2csr(matrix::Coo* mat, memory::MemorySpace memspace);
};

}} // namespace ReSolve::matrix
13 changes: 7 additions & 6 deletions tests/functionality/testKLU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ int main(int argc, char *argv[])
return -1;
}
ReSolve::matrix::Coo* A_coo = ReSolve::io::readMatrixFromFile(mat1);
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
A_coo->getNumColumns(),
A_coo->getNnz(),
A_coo->symmetric(),
A_coo->expanded());
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo, ReSolve::memory::HOST);
// ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
// A_coo->getNumColumns(),
// A_coo->getNnz(),
// A_coo->symmetric(),
// A_coo->expanded());
mat1.close();

// Read first rhs vector
Expand All @@ -72,7 +73,7 @@ int main(int argc, char *argv[])
rhs1_file.close();

// Convert first matrix to CSR format
matrix_handler->coo2csr(A_coo, A, "cpu");
// matrix_handler->coo2csr(A_coo, A, "cpu");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
vec_rhs->setDataUpdated(ReSolve::memory::HOST);

Expand Down
13 changes: 7 additions & 6 deletions tests/functionality/testKLU_GLU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ int main(int argc, char *argv[])
return -1;
}
ReSolve::matrix::Coo* A_coo = ReSolve::io::readMatrixFromFile(mat1);
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
A_coo->getNumColumns(),
A_coo->getNnz(),
A_coo->symmetric(),
A_coo->expanded());
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo, ReSolve::memory::HOST);
// ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
// A_coo->getNumColumns(),
// A_coo->getNnz(),
// A_coo->symmetric(),
// A_coo->expanded());
mat1.close();

// Read first rhs vector
Expand All @@ -80,7 +81,7 @@ int main(int argc, char *argv[])
rhs1_file.close();

// Convert first matrix to CSR format
matrix_handler->coo2csr(A_coo, A, "cpu");
// matrix_handler->coo2csr(A_coo, A, "cpu");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
vec_rhs->setDataUpdated(ReSolve::memory::HOST);

Expand Down
13 changes: 7 additions & 6 deletions tests/functionality/testKLU_Rf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ int main(int argc, char *argv[])
return -1;
}
ReSolve::matrix::Coo* A_coo = ReSolve::io::readMatrixFromFile(mat1);
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
A_coo->getNumColumns(),
A_coo->getNnz(),
A_coo->symmetric(),
A_coo->expanded());
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo, ReSolve::memory::HOST);
// ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
// A_coo->getNumColumns(),
// A_coo->getNnz(),
// A_coo->symmetric(),
// A_coo->expanded());
mat1.close();

// Read first rhs vector
Expand All @@ -78,7 +79,7 @@ int main(int argc, char *argv[])
rhs1_file.close();

// Convert first matrix to CSR format
matrix_handler->coo2csr(A_coo, A, "cpu");
// matrix_handler->coo2csr(A_coo, A, "cpu");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
vec_rhs->setDataUpdated(ReSolve::memory::HOST);

Expand Down
13 changes: 7 additions & 6 deletions tests/functionality/testKLU_Rf_FGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ int main(int argc, char *argv[])
return -1;
}
ReSolve::matrix::Coo* A_coo = ReSolve::io::readMatrixFromFile(mat1);
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
A_coo->getNumColumns(),
A_coo->getNnz(),
A_coo->symmetric(),
A_coo->expanded());
ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo, ReSolve::memory::HOST);
// ReSolve::matrix::Csr* A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
// A_coo->getNumColumns(),
// A_coo->getNnz(),
// A_coo->symmetric(),
// A_coo->expanded());
mat1.close();

// Read first rhs vector
Expand All @@ -83,7 +84,7 @@ int main(int argc, char *argv[])
rhs1_file.close();

// Convert first matrix to CSR format
matrix_handler->coo2csr(A_coo, A, "cpu");
// matrix_handler->coo2csr(A_coo, A, "cpu");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
vec_rhs->setDataUpdated(ReSolve::memory::HOST);

Expand Down
Loading

0 comments on commit c199aa3

Please sign in to comment.