From 617404e75e62a1cc29ac13ac2d6ffe44465972d2 Mon Sep 17 00:00:00 2001 From: superwhiskers Date: Tue, 23 Jul 2024 15:53:46 -0400 Subject: [PATCH] tests for coo2coo and a fix for a bug in coo2csr and coo2coo --- resolve/matrix/Utilities.cpp | 4 +- tests/unit/matrix/MatrixConversionTests.hpp | 145 ++++++++++++++++---- 2 files changed, 117 insertions(+), 32 deletions(-) diff --git a/resolve/matrix/Utilities.cpp b/resolve/matrix/Utilities.cpp index aef25435..9e60034a 100644 --- a/resolve/matrix/Utilities.cpp +++ b/resolve/matrix/Utilities.cpp @@ -170,7 +170,7 @@ namespace ReSolve // spaces is equivalent to the amount of nonzeroes in the row, and if not, // shifts every subsequent row back the amount of unused spaces - for (index_type column = 0; column < n_columns - 1; column++) { + for (index_type column = 0; column < n_columns; column++) { index_type column_nnz = partitions[column + 1] - partitions[column]; if (used[column] != column_nnz) { index_type correction = column_nnz - used[column]; @@ -363,7 +363,7 @@ namespace ReSolve // spaces is equivalent to the amount of nonzeroes in the row, and if not, // shifts every subsequent row back the amount of unused spaces - for (index_type row = 0; row < n_rows - 1; row++) { + for (index_type row = 0; row < n_rows; row++) { index_type row_nnz = csr_rows[row + 1] - csr_rows[row]; if (used[row] != row_nnz) { index_type correction = row_nnz - used[row]; diff --git a/tests/unit/matrix/MatrixConversionTests.hpp b/tests/unit/matrix/MatrixConversionTests.hpp index d7f69a69..cc66cf6a 100644 --- a/tests/unit/matrix/MatrixConversionTests.hpp +++ b/tests/unit/matrix/MatrixConversionTests.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -44,12 +46,23 @@ namespace ReSolve status *= ReSolve::matrix::coo2csr(&A, &B, memory::HOST) == 0; status *= this->verifyAnswer(&B, - simple_symmetric_expected_n_, - simple_symmetric_expected_m_, - simple_symmetric_expected_nnz_, - simple_symmetric_expected_i_, - simple_symmetric_expected_j_, - simple_symmetric_expected_a_); + simple_symmetric_expected_csr_n_, + simple_symmetric_expected_csr_m_, + simple_symmetric_expected_csr_nnz_, + simple_symmetric_expected_csr_i_, + simple_symmetric_expected_csr_j_, + simple_symmetric_expected_csr_a_); + + ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0); + + status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0; + status *= this->verifyAnswer(&C, + simple_symmetric_expected_coo_col_n_, + simple_symmetric_expected_coo_col_m_, + simple_symmetric_expected_coo_col_nnz_, + simple_symmetric_expected_coo_col_i_, + simple_symmetric_expected_coo_col_j_, + simple_symmetric_expected_coo_col_a_); return status.report(__func__); } @@ -72,12 +85,23 @@ namespace ReSolve status *= ReSolve::matrix::coo2csr(&A, &B, memory::HOST) == 0; status *= this->verifyAnswer(&B, - simple_symmetric_expected_n_, - simple_symmetric_expected_m_, - simple_symmetric_expected_nnz_, - simple_symmetric_expected_i_, - simple_symmetric_expected_j_, - simple_symmetric_expected_a_); + simple_symmetric_expected_csr_n_, + simple_symmetric_expected_csr_m_, + simple_symmetric_expected_csr_nnz_, + simple_symmetric_expected_csr_i_, + simple_symmetric_expected_csr_j_, + simple_symmetric_expected_csr_a_); + + ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0); + + status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0; + status *= this->verifyAnswer(&C, + simple_symmetric_expected_coo_col_n_, + simple_symmetric_expected_coo_col_m_, + simple_symmetric_expected_coo_col_nnz_, + simple_symmetric_expected_coo_col_i_, + simple_symmetric_expected_coo_col_j_, + simple_symmetric_expected_coo_col_a_); return status.report(__func__); } @@ -107,6 +131,17 @@ namespace ReSolve simple_main_diagonal_only_i_j_, simple_main_diagonal_only_a_); + ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0); + + status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0; + status *= this->verifyAnswer(&C, + simple_main_diagonal_only_n_, + simple_main_diagonal_only_m_, + simple_main_diagonal_only_nnz_, + simple_main_diagonal_only_i_j_, + simple_main_diagonal_only_i_j_, + simple_main_diagonal_only_a_); + return status.report(__func__); } @@ -126,23 +161,41 @@ namespace ReSolve status *= ReSolve::matrix::coo2csr(&A, &B, memory::HOST) == 0; status *= this->verifyAnswer(&B, - simple_asymmetric_expected_n_, - simple_asymmetric_expected_m_, - simple_asymmetric_expected_nnz_, - simple_asymmetric_expected_i_, - simple_asymmetric_expected_j_, - simple_asymmetric_expected_a_); + simple_asymmetric_expected_csr_n_, + simple_asymmetric_expected_csr_m_, + simple_asymmetric_expected_csr_nnz_, + simple_asymmetric_expected_csr_i_, + simple_asymmetric_expected_csr_j_, + simple_asymmetric_expected_csr_a_); + + ReSolve::matrix::Coo C(A.getNumRows(), A.getNumColumns(), 0); + + status *= ReSolve::matrix::coo2coo(&A, &C, memory::HOST) == 0; + status *= this->verifyAnswer(&C, + simple_asymmetric_expected_coo_col_n_, + simple_asymmetric_expected_coo_col_m_, + simple_asymmetric_expected_coo_col_nnz_, + simple_asymmetric_expected_coo_col_i_, + simple_asymmetric_expected_coo_col_j_, + simple_asymmetric_expected_coo_col_a_); return status.report(__func__); } private: - const index_type simple_symmetric_expected_n_ = 5; - const index_type simple_symmetric_expected_m_ = 5; - const index_type simple_symmetric_expected_nnz_ = 8; - index_type simple_symmetric_expected_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4}; - index_type simple_symmetric_expected_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3}; - real_type simple_symmetric_expected_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0}; + const index_type simple_symmetric_expected_csr_n_ = 5; + const index_type simple_symmetric_expected_csr_m_ = 5; + const index_type simple_symmetric_expected_csr_nnz_ = 8; + index_type simple_symmetric_expected_csr_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4}; + index_type simple_symmetric_expected_csr_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3}; + real_type simple_symmetric_expected_csr_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0}; + + const index_type simple_symmetric_expected_coo_col_n_ = 5; + const index_type simple_symmetric_expected_coo_col_m_ = 5; + const index_type simple_symmetric_expected_coo_col_nnz_ = 8; + index_type simple_symmetric_expected_coo_col_i_[8] = {0, 1, 2, 3, 1, 1, 4, 3}; + index_type simple_symmetric_expected_coo_col_j_[8] = {0, 1, 1, 1, 2, 3, 3, 4}; + real_type simple_symmetric_expected_coo_col_a_[8] = {3.0, 7.0, 11.0, 7.0, 11.0, 7.0, 8.0, 8.0}; const index_type simple_upper_unexpanded_symmetric_n_ = 5; const index_type simple_upper_unexpanded_symmetric_m_ = 5; @@ -171,12 +224,19 @@ namespace ReSolve index_type simple_asymmetric_j_[10] = {0, 1, 3, 1, 1, 4, 4, 3, 2, 2}; real_type simple_asymmetric_a_[10] = {2.0, 4.0, 7.0, 9.0, 6.0, 7.0, 8.0, 8.0, 5.0, 6.0}; - const index_type simple_asymmetric_expected_n_ = 5; - const index_type simple_asymmetric_expected_m_ = 5; - const index_type simple_asymmetric_expected_nnz_ = 8; - index_type simple_asymmetric_expected_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4}; - index_type simple_asymmetric_expected_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3}; - real_type simple_asymmetric_expected_a_[8] = {2.0, 4.0, 11.0, 7.0, 9.0, 6.0, 15.0, 8.0}; + const index_type simple_asymmetric_expected_csr_n_ = 5; + const index_type simple_asymmetric_expected_csr_m_ = 5; + const index_type simple_asymmetric_expected_csr_nnz_ = 8; + index_type simple_asymmetric_expected_csr_i_[8] = {0, 1, 1, 1, 2, 3, 3, 4}; + index_type simple_asymmetric_expected_csr_j_[8] = {0, 1, 2, 3, 1, 1, 4, 3}; + real_type simple_asymmetric_expected_csr_a_[8] = {2.0, 4.0, 11.0, 7.0, 9.0, 6.0, 15.0, 8.0}; + + const index_type simple_asymmetric_expected_coo_col_n_ = 5; + const index_type simple_asymmetric_expected_coo_col_m_ = 5; + const index_type simple_asymmetric_expected_coo_col_nnz_ = 8; + index_type simple_asymmetric_expected_coo_col_i_[8] = {0, 1, 2, 3, 1, 1, 4, 3}; + index_type simple_asymmetric_expected_coo_col_j_[8] = {0, 1, 1, 1, 2, 3, 3, 4}; + real_type simple_asymmetric_expected_coo_col_a_[8] = {2.0, 4.0, 9.0, 6.0, 11.0, 7.0, 8.0, 15.0}; bool verifyAnswer(matrix::Csr* A, const index_type& n, @@ -206,6 +266,31 @@ namespace ReSolve return true; } + + bool verifyAnswer(matrix::Coo* A, + const index_type& n, + const index_type& m, + const index_type& nnz, + index_type* is, + index_type* js, + real_type* as) + { + if (n != A->getNumRows() || m != A->getNumColumns() || nnz != A->getNnz()) { + return false; + } + + index_type* rows = A->getRowData(memory::HOST); + index_type* columns = A->getColData(memory::HOST); + real_type* values = A->getValues(memory::HOST); + + for (index_type i = 0; i < nnz; i++) { + if (rows[i] != is[i] || columns[i] != js[i] || !isEqual(values[i], as[i])) { + return false; + } + } + + return true; + } }; } // namespace tests } // namespace ReSolve