Skip to content

Commit

Permalink
Make "sparse" solver check if equations are linear.
Browse files Browse the repository at this point in the history
If the system is linear, then newtons method always converges
in exactly one iteration. When using the sparse solver on
linear systems omit the newtons iteration and solve directly.

This should make the resulting code run marginally faster by
skipping the check for convergence. Currently the check for
convergence is implemented as "error = sqrt(|F|^2)".
  • Loading branch information
ctrl-z-9000-times committed May 5, 2022
1 parent cde5dbf commit 00a0296
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 42 deletions.
16 changes: 15 additions & 1 deletion nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from importlib import import_module

import sympy as sp
import itertools

# import known_functions through low-level mechanism because the ccode
# module is overwritten in sympy and contents of that submodule cannot be
Expand Down Expand Up @@ -272,6 +273,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)

linear = _is_linear(eqs, state_vars, sympy_vars)

custom_fcts = _get_custom_functions(function_calls)

jacobian = sp.Matrix(eqs).jacobian(state_vars)
Expand All @@ -291,7 +294,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
# interweave
code = _interweave_eqs(vecFcode, vecJcode)

return code
return code, linear


def _is_linear(eqs, state_vars, sympy_vars):
for expr in eqs:
for (x, y) in itertools.combinations_with_replacement(state_vars, 2):
try:
if not sp.Eq(sp.diff(expr, x, y), 0):
return False
except TypeError:
return False
return True


def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):
Expand Down
2 changes: 2 additions & 0 deletions src/pybind/pyembed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor {
// output
// returns a vector of solutions, i.e. new statements to add to block:
std::vector<std::string> solutions;
// returns if the system is linear or not.
bool linear;
// may also return a python exception message:
std::string exception_message;

Expand Down
4 changes: 3 additions & 1 deletion src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,22 @@ void SolveNonLinearSystemExecutor::operator()() {
from nmodl.ode import solve_non_lin_system
exception_message = ""
try:
solutions = solve_non_lin_system(equation_strings,
solutions, linear = solve_non_lin_system(equation_strings,
state_vars,
vars,
function_calls)
except Exception as e:
# if we fail, fail silently and return empty string
solutions = [""]
linear = False
new_local_vars = [""]
exception_message = str(e)
)",
py::globals(),
locals);
// returns a vector of solutions, i.e. new statements to add to block:
solutions = locals["solutions"].cast<std::vector<std::string>>();
linear = locals["linear"].cast<bool>();
// may also return a python exception message:
exception_message = locals["exception_message"].cast<std::string>();
}
Expand Down
10 changes: 8 additions & 2 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system(
(*solver)();
// returns a vector of solutions, i.e. new statements to add to block:
auto solutions = solver->solutions;
bool linear = solver->linear;
// may also return a python exception message:
auto exception_message = solver->exception_message;
pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver);
Expand All @@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system(
exception_message);
return;
}
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
construct_eigen_solver_block(pre_solve_statements, solutions, false);
if (!linear) {
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
}
else {
logger->debug("SympySolverVisitor :: Constructing eigen solve block");
}
construct_eigen_solver_block(pre_solve_statements, solutions, linear);
}

void SympySolverVisitor::visit_var_name(ast::VarName& node) {
Expand Down
Loading

0 comments on commit 00a0296

Please sign in to comment.