Skip to content

Commit

Permalink
Remove duplicated code and consolidate coo2csr functionality in a sin…
Browse files Browse the repository at this point in the history
…gle function. (#173)

* Make coo2csr a standalone function.

* Remove duplicated coo2csr code.

* Fix missing header for sort function.
  • Loading branch information
pelesh authored Jul 2, 2024
1 parent b77ca06 commit f4b7f88
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 332 deletions.
2 changes: 2 additions & 0 deletions resolve/matrix/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(Matrix_SRC
Coo.cpp
MatrixHandler.cpp
MatrixHandlerCpu.cpp
Utilities.cpp
)

# C++ code that depends on CUDA SDK libraries
Expand All @@ -35,6 +36,7 @@ set(Matrix_HEADER_INSTALL
Csr.hpp
Csc.hpp
MatrixHandler.hpp
Utilities.hpp
)

# Build shared library ReSolve::matrix
Expand Down
145 changes: 3 additions & 142 deletions resolve/matrix/Csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

#include "Csr.hpp"
#include "Coo.hpp"
#include <resolve/utilities/misc/IndexValuePair.hpp>
#include <resolve/utilities/logger/Logger.hpp>
#include <resolve/matrix/Utilities.hpp>

namespace ReSolve
{
Expand Down Expand Up @@ -35,7 +35,7 @@ namespace ReSolve
A_coo->symmetric(),
A_coo->expanded())
{
coo2csr(A_coo, memspace);
matrix::coo2csr(A_coo, this, memspace);
}

/**
Expand Down Expand Up @@ -374,149 +374,10 @@ namespace ReSolve
assert(nnz_ == A_coo->getNnz());
assert(is_symmetric_ == A_coo->symmetric()); // <- Do we need to check for this?

return coo2csr(A_coo, memspaceOut);
return matrix::coo2csr(A_coo, this, memspaceOut);
}


int matrix::Csr::coo2csr(matrix::Coo* A_coo, memory::MemorySpace memspace)
{
//count nnzs first
index_type nnz_unpacked = 0;
index_type nnz = A_coo->getNnz();
index_type n = A_coo->getNumRows();
bool symmetric = A_coo->symmetric();
bool expanded = A_coo->expanded();

index_type* nnz_counts = new index_type[n];
std::fill_n(nnz_counts, n, 0);
index_type* coo_rows = A_coo->getRowData(memory::HOST);
index_type* coo_cols = A_coo->getColData(memory::HOST);
real_type* coo_vals = A_coo->getValues( memory::HOST);

index_type* diag_control = new index_type[n]; //for DEDUPLICATION of the diagonal
std::fill_n(diag_control, n, 0);
index_type nnz_unpacked_no_duplicates = 0;
index_type nnz_no_duplicates = nnz;


//maybe check if they exist?
for (index_type i = 0; i < nnz; ++i)
{
nnz_counts[coo_rows[i]]++;
nnz_unpacked++;
nnz_unpacked_no_duplicates++;
if ((coo_rows[i] != coo_cols[i])&& (symmetric) && (!expanded))
{
nnz_counts[coo_cols[i]]++;
nnz_unpacked++;
nnz_unpacked_no_duplicates++;
}
if (coo_rows[i] == coo_cols[i]){
if (diag_control[coo_rows[i]] > 0) {
//duplicate
nnz_unpacked_no_duplicates--;
nnz_no_duplicates--;
}
diag_control[coo_rows[i]]++;
}
}
this->setExpanded(true);
this->setNnzExpanded(nnz_unpacked_no_duplicates);
index_type* csr_ia = new index_type[n+1];
std::fill_n(csr_ia, n + 1, 0);
index_type* csr_ja = new index_type[nnz_unpacked];
real_type* csr_a = new real_type[nnz_unpacked];
index_type* nnz_shifts = new index_type[n];
std::fill_n(nnz_shifts, n , 0);

IndexValuePair* tmp = new IndexValuePair[nnz_unpacked];

csr_ia[0] = 0;

for (index_type i = 1; i < n + 1; ++i){
csr_ia[i] = csr_ia[i - 1] + nnz_counts[i - 1] - (diag_control[i-1] - 1);
}

int r, start;


for (index_type i = 0; i < nnz; ++i){
//which row
r = coo_rows[i];
start = csr_ia[r];

if ((start + nnz_shifts[r]) > nnz_unpacked) {
out::warning() << "index out of bounds (case 1) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
}
if ((r == coo_cols[i]) && (diag_control[r] > 1)) {//diagonal, and there are duplicates
bool already_there = false;
for (index_type j = start; j < start + nnz_shifts[r]; ++j)
{
index_type c = tmp[j].getIdx();
if (c == r) {
real_type val = tmp[j].getValue();
val += coo_vals[i];
tmp[j].setValue(val);
already_there = true;
out::warning() << " duplicate found, row " << c << " adding in place " << j << " current value: " << val << std::endl;
}
}
if (!already_there){ // first time this duplicates appears

tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);

nnz_shifts[r]++;
}
} else {//not diagonal
tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
nnz_shifts[r]++;

if ((coo_rows[i] != coo_cols[i]) && (symmetric == 1))
{
r = coo_cols[i];
start = csr_ia[r];

if ((start + nnz_shifts[r]) > nnz_unpacked)
out::warning() << "index out of bounds (case 2) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
tmp[start + nnz_shifts[r]].setIdx(coo_rows[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
nnz_shifts[r]++;
}
}
}
//now sort whatever is inside rows

for (int i = 0; i < n; ++i)
{
//now sorting (and adding 1)
int colStart = csr_ia[i];
int colEnd = csr_ia[i + 1];
int length = colEnd - colStart;
std::sort(&tmp[colStart],&tmp[colStart] + length);
}

for (index_type i = 0; i < nnz_unpacked; ++i)
{
csr_ja[i] = tmp[i].getIdx();
csr_a[i] = tmp[i].getValue();
}

this->setNnz(nnz_no_duplicates);
this->updateData(csr_ia, csr_ja, csr_a, memory::HOST, memspace);

delete [] nnz_counts;
delete [] tmp;
delete [] nnz_shifts;
delete [] csr_ia;
delete [] csr_ja;
delete [] csr_a;
delete [] diag_control;

return 0;
}

/**
* @brief Prints matrix data.
*
Expand Down
2 changes: 0 additions & 2 deletions resolve/matrix/Csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ namespace ReSolve { namespace matrix {

int updateFromCoo(matrix::Coo* mat, memory::MemorySpace memspaceOut);

private:
int coo2csr(matrix::Coo* mat, memory::MemorySpace memspace);
};

}} // namespace ReSolve::matrix
147 changes: 2 additions & 145 deletions resolve/matrix/MatrixHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <resolve/matrix/Csc.hpp>
#include <resolve/matrix/Csr.hpp>
#include <resolve/workspace/LinAlgWorkspace.hpp>
#include <resolve/utilities/misc/IndexValuePair.hpp>
#include <resolve/matrix/Utilities.hpp>
#include "MatrixHandler.hpp"
#include "MatrixHandlerCpu.hpp"

Expand Down Expand Up @@ -116,150 +116,7 @@ namespace ReSolve {
*/
int MatrixHandler::coo2csr(matrix::Coo* A_coo, matrix::Csr* A_csr, memory::MemorySpace memspace)
{
//count nnzs first
index_type nnz_unpacked = 0;
index_type nnz = A_coo->getNnz();
index_type n = A_coo->getNumRows();
bool symmetric = A_coo->symmetric();
bool expanded = A_coo->expanded();

index_type* nnz_counts = new index_type[n];
std::fill_n(nnz_counts, n, 0);
index_type* coo_rows = A_coo->getRowData(memory::HOST);
index_type* coo_cols = A_coo->getColData(memory::HOST);
real_type* coo_vals = A_coo->getValues( memory::HOST);

index_type* diag_control = new index_type[n]; //for DEDUPLICATION of the diagonal
std::fill_n(diag_control, n, 0);
index_type nnz_unpacked_no_duplicates = 0;
index_type nnz_no_duplicates = nnz;


//maybe check if they exist?
for (index_type i = 0; i < nnz; ++i)
{
nnz_counts[coo_rows[i]]++;
nnz_unpacked++;
nnz_unpacked_no_duplicates++;
if ((coo_rows[i] != coo_cols[i])&& (symmetric) && (!expanded))
{
nnz_counts[coo_cols[i]]++;
nnz_unpacked++;
nnz_unpacked_no_duplicates++;
}
if (coo_rows[i] == coo_cols[i]){
if (diag_control[coo_rows[i]] > 0) {
//duplicate
nnz_unpacked_no_duplicates--;
nnz_no_duplicates--;
}
diag_control[coo_rows[i]]++;
}
}
A_csr->setExpanded(true);
A_csr->setNnzExpanded(nnz_unpacked_no_duplicates);
index_type* csr_ia = new index_type[n+1];
std::fill_n(csr_ia, n + 1, 0);
index_type* csr_ja = new index_type[nnz_unpacked];
real_type* csr_a = new real_type[nnz_unpacked];
index_type* nnz_shifts = new index_type[n];
std::fill_n(nnz_shifts, n , 0);

IndexValuePair* tmp = new IndexValuePair[nnz_unpacked];

csr_ia[0] = 0;

for (index_type i = 1; i < n + 1; ++i){
csr_ia[i] = csr_ia[i - 1] + nnz_counts[i - 1] - (diag_control[i-1] - 1);
}

int r, start;


for (index_type i = 0; i < nnz; ++i){
//which row
r = coo_rows[i];
start = csr_ia[r];

if ((start + nnz_shifts[r]) > nnz_unpacked) {
out::warning() << "index out of bounds (case 1) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
}
if ((r == coo_cols[i]) && (diag_control[r] > 1)) {//diagonal, and there are duplicates
bool already_there = false;
for (index_type j = start; j < start + nnz_shifts[r]; ++j)
{
index_type c = tmp[j].getIdx();
if (c == r) {
real_type val = tmp[j].getValue();
val += coo_vals[i];
tmp[j].setValue(val);
already_there = true;
out::warning() << " duplicate found, row " << c << " adding in place " << j << " current value: " << val << std::endl;
}
}
if (!already_there){ // first time this duplicates appears

tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);

nnz_shifts[r]++;
}
} else {//not diagonal
tmp[start + nnz_shifts[r]].setIdx(coo_cols[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
nnz_shifts[r]++;

if ((coo_rows[i] != coo_cols[i]) && (symmetric == 1))
{
r = coo_cols[i];
start = csr_ia[r];

if ((start + nnz_shifts[r]) > nnz_unpacked)
out::warning() << "index out of bounds (case 2) start: " << start << "nnz_shifts[" << r << "] = " << nnz_shifts[r] << std::endl;
tmp[start + nnz_shifts[r]].setIdx(coo_rows[i]);
tmp[start + nnz_shifts[r]].setValue(coo_vals[i]);
nnz_shifts[r]++;
}
}
}
//now sort whatever is inside rows

for (int i = 0; i < n; ++i)
{

//now sorting (and adding 1)
int colStart = csr_ia[i];
int colEnd = csr_ia[i + 1];
int length = colEnd - colStart;
std::sort(&tmp[colStart],&tmp[colStart] + length);
}

for (index_type i = 0; i < nnz_unpacked; ++i)
{
csr_ja[i] = tmp[i].getIdx();
csr_a[i] = tmp[i].getValue();
}
#if 0
for (int i = 0; i<n; ++i){
printf("Row: %d \n", i);
for (int j = csr_ia[i]; j<csr_ia[i+1]; ++j){
printf("(%d %16.16f) ", csr_ja[j], csr_a[j]);
}
printf("\n");
}
#endif
A_csr->setNnz(nnz_no_duplicates);
A_csr->updateData(csr_ia, csr_ja, csr_a, memory::HOST, memspace);

delete [] nnz_counts;
delete [] tmp;
delete [] nnz_shifts;
delete [] csr_ia;
delete [] csr_ja;
delete [] csr_a;
delete [] diag_control;

return 0;
return matrix::coo2csr(A_coo, A_csr, memspace);
}

/**
Expand Down
Loading

0 comments on commit f4b7f88

Please sign in to comment.