diff --git a/resolve/LinSolverDirectLUSOL.cpp b/resolve/LinSolverDirectLUSOL.cpp index c734cb83..5cd545c5 100644 --- a/resolve/LinSolverDirectLUSOL.cpp +++ b/resolve/LinSolverDirectLUSOL.cpp @@ -192,10 +192,25 @@ namespace ReSolve int LinSolverDirectLUSOL::solve(vector_type* rhs, vector_type* x) { - if (rhs->getSize() != m_ || x->getSize() != n_ || !is_factorized_) { + if (rhs->getSize() != m_ || x->getSize() != n_) { return -1; } + if (!is_factorized_) { + out::warning() << "LinSolverDirect::solve(vector_type*, vector_type*) " + << "called on LinSolverDirectLUSOL without factorizing " + << "first!\n"; + + if (m_ == 0) { + return -1; + } + + index_type inform = factorize(); + if (inform < 0) { + return inform; + } + } + index_type mode = 5; index_type inform = 0; @@ -480,7 +495,7 @@ namespace ReSolve int LinSolverDirectLUSOL::allocateSolverData() { // LUSOL does not do symbolic analysis to determine workspace size to store - // L and U factors, so we have to guess something. See documentation for + // L and U factors, so we have to guess something. See documentation for // lena_ in resolve/lusol/lusol.f90 file. lena_ = std::max({20 * nelem_, 10 * m_, 10 * n_, 10000}); diff --git a/tests/unit/matrix/LUSOLTests.hpp b/tests/unit/matrix/LUSOLTests.hpp index c400a3d1..33b1c4fa 100644 --- a/tests/unit/matrix/LUSOLTests.hpp +++ b/tests/unit/matrix/LUSOLTests.hpp @@ -43,6 +43,36 @@ namespace ReSolve return status.report(__func__); } + TestOutcome automaticFactorization() + { + TestStatus status; + + LinSolverDirectLUSOL solver; + matrix::Coo* A = createMatrix(); + + vector::Vector rhs(A->getNumRows()); + rhs.setToConst(constants::ONE, memory::HOST); + + vector::Vector x(A->getNumColumns()); + x.allocate(memory::HOST); + + if (solver.setup(A) < 0) { + status *= false; + } + if (solver.analyze() < 0) { + status *= false; + } + if (solver.solve(&rhs, &x) < 0) { + status *= false; + } + + status *= verifyAnswer(x, solX_); + + delete A; + + return status.report(__func__); + } + TestOutcome simpleSolve() { TestStatus status; diff --git a/tests/unit/matrix/runLUSOLTests.cpp b/tests/unit/matrix/runLUSOLTests.cpp index b2fed2e9..ca49d768 100644 --- a/tests/unit/matrix/runLUSOLTests.cpp +++ b/tests/unit/matrix/runLUSOLTests.cpp @@ -11,6 +11,7 @@ int main(int, char**) ReSolve::tests::LUSOLTests test; result += test.lusolConstructor(); + result += test.automaticFactorization(); result += test.simpleSolve(); std::cout << "\n";