Skip to content

Commit

Permalink
address review comments and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
superwhiskers committed Jul 15, 2024
1 parent 79cae17 commit 553592f
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 92 deletions.
177 changes: 90 additions & 87 deletions resolve/matrix/Utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,40 @@ namespace ReSolve
/**
* @brief Creates a CSR from a COO matrix.
*
* @param[in] A
* @param[out] B
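* @param[in] A_coo - input matrix in COO format
* @param[out] A_csr - output matrix, filled in CSR format
* @param[in] memspace - memory space forwarded to `A_csr->updateData` as its
* final argument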
* @return int - Error code, 0 if successful.
*
* @pre `A` is a valid sparse matrix in unordered COO format. Duplicates are allowed.
* Up-to-date values and indices must be on the host.
* @pre `A_coo` is a valid sparse matrix in unordered COO format. Duplicates are
* allowed. Up-to-date values and indices must be on the host.
*
* @post `B` represents the same matrix as `A` but is in the CSR format. `B` is
* allocated and stored on the host.
* @post `A_csr` represents the same matrix as `A_coo` but is in the CSR format.
* `A_csr` is allocated and stored on the host.
*
* @invariant `A` is not changed.
* @invariant `A_coo` is not changed.
*/
int coo2csr(matrix::Coo* A, matrix::Csr* B, memory::MemorySpace memspace)
int coo2csr(matrix::Coo* A_coo, matrix::Csr* A_csr, memory::MemorySpace memspace)
{
index_type* rows = A->getRowData(memory::HOST);
index_type* columns = A->getColData(memory::HOST);
real_type* values = A->getValues(memory::HOST);
index_type* coo_rows = A_coo->getRowData(memory::HOST);
index_type* coo_columns = A_coo->getColData(memory::HOST);
real_type* coo_values = A_coo->getValues(memory::HOST);

if (rows == nullptr || columns == nullptr || values == nullptr) {
if (coo_rows == nullptr || coo_columns == nullptr || coo_values == nullptr) {
return 0;
}

index_type nnz = A->getNnz();
index_type n_rows = A->getNumRows();
index_type* new_rows = new index_type[n_rows + 1];
std::fill_n(new_rows, n_rows + 1, 0);
index_type nnz_with_duplicates = A_coo->getNnz();
index_type n_rows = A_coo->getNumRows();
index_type* csr_rows = new index_type[n_rows + 1];
std::fill_n(csr_rows, n_rows + 1, 0);

// NOTE: this is the only auxiliary storage buffer used by this conversion
// function. it is first used to count the off-diagonal values mirrored
// into each row (if the matrix is symmetric and unexpanded), then it is
// used to track the amount of spaces used in each row's value and
// column data
// used to track the number of elements present within each row while
// the rows are being filled. this is later used during the backshifting
// step, in which the excess space is compacted so that there is no
// leftover space between rows
std::unique_ptr<index_type[]> used(new index_type[n_rows]);
std::fill_n(used.get(), n_rows, 0);

Expand All @@ -62,88 +64,88 @@ namespace ReSolve
// and validates the triangularity. the branch is done to avoid the extra work if
// it's not necessary

if (!A->symmetric() || A->expanded()) {
for (index_type i = 0; i < nnz; i++) {
new_rows[rows[i] + 1]++;
if (!A_coo->symmetric() || A_coo->expanded()) {
for (index_type i = 0; i < nnz_with_duplicates; i++) {
csr_rows[coo_rows[i] + 1]++;
}
} else {
bool upper_triangular = false;
for (index_type i = 0; i < nnz; i++) {
new_rows[rows[i] + 1]++;
if (rows[i] != columns[i]) {
used[columns[i]]++;
bool is_upper_triangular = false;
for (index_type i = 0; i < nnz_with_duplicates; i++) {
csr_rows[coo_rows[i] + 1]++;
if (coo_rows[i] != coo_columns[i]) {
used[coo_columns[i]]++;

if (rows[i] > columns[i] && upper_triangular) {
if (coo_rows[i] > coo_columns[i] && is_upper_triangular) {
assert(false && "a matrix indicated to be symmetric triangular was not actually symmetric triangular");
return -1;
}
upper_triangular = rows[i] < columns[i];
is_upper_triangular = coo_rows[i] < coo_columns[i];
}
}
}

for (index_type row = 0; row < n_rows; row++) {
new_rows[row + 1] += new_rows[row] + used[row];
csr_rows[row + 1] += csr_rows[row] + used[row];
used[row] = 0;
}

index_type* new_columns = new index_type[new_rows[n_rows]];
std::fill_n(new_columns, new_rows[n_rows], -1);
real_type* new_values = new real_type[new_rows[n_rows]];
index_type nnz_expanded_with_duplicates = csr_rows[n_rows];
index_type* csr_columns = new index_type[nnz_expanded_with_duplicates];
std::fill_n(csr_columns, nnz_expanded_with_duplicates, -1);
real_type* csr_values = new real_type[nnz_expanded_with_duplicates];

// fill stage, approximately O(nnz * m) in the worst case
//
// all this does is iterate over the nonzeroes in the coo matrix,
// check to see if a value at that column already exists using binary search,
// and if it does, then insert the new value at that position (deduplicating
// and if it does, then add to the value at that position ("deduplicating"
// the matrix), otherwise, it allocates a new spot in the row (where you see
// used[rows[i]]++) and shifts everything over, performing what is effectively
// insertion sort. the lower half is conditioned on the matrix being symmetric
// and stored as either upper-triangular or lower-triangular, and just
// performs the same as what is described above, but with the indices swapped.

for (index_type i = 0; i < nnz; i++) {
index_type insertion_pos =
static_cast<index_type>(
std::lower_bound(&new_columns[new_rows[rows[i]]],
&new_columns[new_rows[rows[i]] + used[rows[i]]],
columns[i])
- new_columns);

if (new_columns[insertion_pos] == columns[i]) {
new_values[insertion_pos] = values[i];
// used[coo_rows[i]]++) and shifts everything over, performing what is
// effectively insertion sort. the lower half is conditioned on the matrix
// being symmetric and stored as either upper-triangular or lower-triangular,
// and just performs the same as what is described above, but with the
// indices swapped.

for (index_type i = 0; i < nnz_with_duplicates; i++) {
index_type* closest_position =
std::lower_bound(&csr_columns[csr_rows[coo_rows[i]]],
&csr_columns[csr_rows[coo_rows[i]] + used[coo_rows[i]]],
coo_columns[i]);
index_type insertion_offset = static_cast<index_type>(closest_position - csr_columns);

if (csr_columns[insertion_offset] == coo_columns[i]) {
csr_values[insertion_offset] += coo_values[i];
} else {
for (index_type offset = new_rows[rows[i]] + used[rows[i]]++;
offset > insertion_pos;
for (index_type offset = csr_rows[coo_rows[i]] + used[coo_rows[i]]++;
offset > insertion_offset;
offset--) {
std::swap(new_columns[offset], new_columns[offset - 1]);
std::swap(new_values[offset], new_values[offset - 1]);
std::swap(csr_columns[offset], csr_columns[offset - 1]);
std::swap(csr_values[offset], csr_values[offset - 1]);
}

new_columns[insertion_pos] = columns[i];
new_values[insertion_pos] = values[i];
csr_columns[insertion_offset] = coo_columns[i];
csr_values[insertion_offset] = coo_values[i];
}

if ((A->symmetric() && !A->expanded()) && (columns[i] != rows[i])) {
index_type mirrored_insertion_pos =
static_cast<index_type>(
std::lower_bound(&new_columns[new_rows[columns[i]]],
&new_columns[new_rows[columns[i]] + used[columns[i]]],
rows[i])
- new_columns);
if ((A_coo->symmetric() && !A_coo->expanded()) && (coo_columns[i] != coo_rows[i])) {
index_type* mirrored_closest_position =
std::lower_bound(&csr_columns[csr_rows[coo_columns[i]]],
&csr_columns[csr_rows[coo_columns[i]] + used[coo_columns[i]]],
coo_rows[i]);
index_type mirrored_insertion_offset = static_cast<index_type>(mirrored_closest_position - csr_columns);

if (new_columns[mirrored_insertion_pos] == rows[i]) {
new_values[mirrored_insertion_pos] = values[i];
if (csr_columns[mirrored_insertion_offset] == coo_rows[i]) {
csr_values[mirrored_insertion_offset] += coo_values[i];
} else {
for (index_type offset = new_rows[columns[i]] + used[columns[i]]++;
offset > mirrored_insertion_pos;
for (index_type offset = csr_rows[coo_columns[i]] + used[coo_columns[i]]++;
offset > mirrored_insertion_offset;
offset--) {
std::swap(new_columns[offset], new_columns[offset - 1]);
std::swap(new_values[offset], new_values[offset - 1]);
std::swap(csr_columns[offset], csr_columns[offset - 1]);
std::swap(csr_values[offset], csr_values[offset - 1]);
}

new_columns[mirrored_insertion_pos] = rows[i];
new_values[mirrored_insertion_pos] = values[i];
csr_columns[mirrored_insertion_offset] = coo_rows[i];
csr_values[mirrored_insertion_offset] = coo_values[i];
}
}
}
Expand All @@ -156,46 +158,47 @@ namespace ReSolve
// shifts every subsequent row back the amount of unused spaces

for (index_type row = 0; row < n_rows - 1; row++) {
index_type row_nnz = new_rows[row + 1] - new_rows[row];
index_type row_nnz = csr_rows[row + 1] - csr_rows[row];
if (used[row] != row_nnz) {
index_type correction = row_nnz - used[row];

for (index_type corrected_row = row + 1;
corrected_row < n_rows;
corrected_row++) {
for (index_type offset = new_rows[corrected_row];
offset < new_rows[corrected_row + 1];
for (index_type offset = csr_rows[corrected_row];
offset < csr_rows[corrected_row + 1];
offset++) {
new_columns[offset - correction] = new_columns[offset];
new_values[offset - correction] = new_values[offset];
csr_columns[offset - correction] = csr_columns[offset];
csr_values[offset - correction] = csr_values[offset];
}

new_rows[corrected_row] -= correction;
csr_rows[corrected_row] -= correction;
}

new_rows[n_rows] -= correction;
csr_rows[n_rows] -= correction;
}
}

B->setSymmetric(A->symmetric());
B->setNnz(new_rows[n_rows]);
index_type nnz_expanded_without_duplicates = csr_rows[n_rows];
A_csr->setSymmetric(A_coo->symmetric());
A_csr->setNnz(nnz_expanded_without_duplicates);
// NOTE: this is necessary because updateData always reads the current nnz from
// this field. see #176
B->setNnzExpanded(new_rows[n_rows]);
B->setExpanded(true);
A_csr->setNnzExpanded(nnz_expanded_without_duplicates);
A_csr->setExpanded(true);

if (B->updateData(new_rows, new_columns, new_values, memory::HOST, memspace) != 0) {
delete[] new_rows;
delete[] new_columns;
delete[] new_values;
if (A_csr->updateData(csr_rows, csr_columns, csr_values, memory::HOST, memspace) != 0) {
delete[] csr_rows;
delete[] csr_columns;
delete[] csr_values;

assert(false && "invalid state after coo -> csr conversion");
return -1;
}

delete[] new_rows;
delete[] new_columns;
delete[] new_values;
delete[] csr_rows;
delete[] csr_columns;
delete[] csr_values;

return 0;
}
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/matrix/MatrixConversionTests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ namespace ReSolve
const index_type simple_symmetric_expected_nnz_ = 8;
index_type simple_symmetric_expected_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
index_type simple_symmetric_expected_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
real_type simple_symmetric_expected_a_[8] = {2.0, 4.0, 6.0, 7.0, 6.0, 7.0, 8.0, 8.0};
real_type simple_symmetric_expected_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0};

const index_type simple_upper_unexpanded_symmetric_n_ = 5;
const index_type simple_upper_unexpanded_symmetric_m_ = 5;
Expand Down Expand Up @@ -176,7 +176,7 @@ namespace ReSolve
const index_type simple_asymmetric_expected_nnz_ = 8;
index_type simple_asymmetric_expected_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4};
index_type simple_asymmetric_expected_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3};
real_type simple_asymmetric_expected_a_[8] = {2.0, 4.0, 6.0, 7.0, 9.0, 6.0, 8.0, 8.0};
real_type simple_asymmetric_expected_a_[8] = {2.0, 4.0, 11.0, 7.0, 9.0, 6.0, 15.0, 8.0};

bool verifyAnswer(matrix::Csr* A,
const index_type& n,
Expand All @@ -197,9 +197,7 @@ namespace ReSolve
index_type answer_offset = 0;
for (index_type i = 0; i < A->getNumRows(); i++) {
for (index_type offset = rows[i]; offset < rows[i + 1]; offset++) {
if (i != is[answer_offset] ||
columns[offset] != js[answer_offset] ||
!isEqual(values[offset], as[answer_offset])) {
if (i != is[answer_offset] || columns[offset] != js[answer_offset] || !isEqual(values[offset], as[answer_offset])) {
return false;
}
answer_offset++;
Expand Down

0 comments on commit 553592f

Please sign in to comment.