Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SymPy solver when using derivimplicit method with NEURON codegen #1366

Merged
merged 2 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ int main(int argc, const char* argv[]) {
ast_to_nmodl(*ast, filepath("localize"));
}

if (sympy_conductance || sympy_analytic || sparse_solver_exists(*ast)) {
const bool sympy_derivimplicit = neuron_code && solver_exists(*ast, "derivimplicit");

if (sympy_conductance || sympy_analytic || solver_exists(*ast, "sparse") ||
sympy_derivimplicit) {
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance()
.api()
.initialize_interpreter();
Expand All @@ -496,7 +499,7 @@ int main(int argc, const char* argv[]) {
ast_to_nmodl(*ast, filepath("sympy_conductance"));
}

if (sympy_analytic || sparse_solver_exists(*ast)) {
if (sympy_analytic || solver_exists(*ast, "sparse") || sympy_derivimplicit) {
if (!sympy_analytic) {
logger->info(
"Automatically enable sympy_analytic because it exists solver of type "
Expand Down
7 changes: 4 additions & 3 deletions src/visitors/visitor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,16 @@ std::vector<std::shared_ptr<ast::Ast>> collect_nodes(ast::Ast& node,
return visitor.lookup(node, types);
}

bool sparse_solver_exists(const ast::Ast& node) {
bool solver_exists(const ast::Ast& node, const std::string& name) {
const auto solve_blocks = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
return std::any_of(solve_blocks.begin(), solve_blocks.end(), [](auto const& solve_block) {
return std::any_of(solve_blocks.begin(), solve_blocks.end(), [&name](auto const& solve_block) {
assert(solve_block);
const auto& method = dynamic_cast<ast::SolveBlock const&>(*solve_block).get_method();
return method && method->get_node_name() == "sparse";
return method && method->get_node_name() == name;
});
}


std::string to_nmodl(const ast::Ast& node, const std::set<ast::AstNodeType>& exclude_types) {
std::stringstream stream;
visitor::NmodlPrintVisitor v(stream, exclude_types);
Expand Down
3 changes: 2 additions & 1 deletion src/visitors/visitor_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ std::vector<std::shared_ptr<ast::Ast>> collect_nodes(
ast::Ast& node,
const std::vector<ast::AstNodeType>& types = {});

bool sparse_solver_exists(const ast::Ast& node);
/// Whether or not a solver of type name exists in the AST
bool solver_exists(const ast::Ast& node, const std::string& name);

/// Given AST node, return the NMODL string representation
std::string to_nmodl(const ast::Ast& node, const std::set<ast::AstNodeType>& exclude_types = {});
Expand Down
2 changes: 1 addition & 1 deletion test/usecases/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
set(NMODL_USECASE_DIRS
cnexp
solve
constant
function
procedure
Expand Down
26 changes: 0 additions & 26 deletions test/usecases/cnexp/test_array.py

This file was deleted.

26 changes: 0 additions & 26 deletions test/usecases/cnexp/test_scalar.py

This file was deleted.

30 changes: 30 additions & 0 deletions test/usecases/solve/derivimplicit_array.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
NEURON {
SUFFIX derivimplicit_array
RANGE z
}

ASSIGNED {
z[3]
}

STATE {
x
s[2]
}

INITIAL {
x = 42.0
s[0] = 0.1
s[1] = -1.0
z[0] = 0.7
z[1] = 0.8
z[2] = 0.9
}

BREAKPOINT {
SOLVE dX METHOD derivimplicit
}

DERIVATIVE dX {
x' = (s[0] + s[1])*(z[0]*z[1]*z[2])*x
}
16 changes: 16 additions & 0 deletions test/usecases/solve/derivimplicit_scalar.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
NEURON {
SUFFIX derivimplicit_scalar
}

STATE { x }

INITIAL {
x = 42
}

BREAKPOINT {
SOLVE dX METHOD derivimplicit
}

DERIVATIVE dX { x' = -x }

32 changes: 32 additions & 0 deletions test/usecases/solve/test_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from neuron import h, gui
from neuron.units import ms


def simulate(method, rtol, dt=None):
nseg = 1

s = h.Section()
s.insert(f"{method}_array")
s.nseg = nseg

x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_{method}_array"))
t_hoc = h.Vector().record(h._ref_t)

h.stdinit()
if dt is not None:
h.dt = dt * ms
h.tstop = 5.0 * ms
h.run()

x = np.array(x_hoc.as_numpy())
t = np.array(t_hoc.as_numpy())

rate = (0.1 - 1.0) * (0.7 * 0.8 * 0.9)
x_exact = 42.0 * np.exp(rate * t)
np.testing.assert_allclose(x, x_exact, rtol=rtol)


if __name__ == "__main__":
simulate("cnexp", rtol=1e-12)
simulate("derivimplicit", rtol=1e-4, dt=1e-4)
31 changes: 31 additions & 0 deletions test/usecases/solve/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
from neuron import h, gui
from neuron.units import ms


def simulate(method, rtol, dt=None):
nseg = 1

s = h.Section()
s.insert(f"{method}_scalar")
s.nseg = nseg

x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_{method}_scalar"))
t_hoc = h.Vector().record(h._ref_t)

h.stdinit()
if dt is not None:
h.dt = dt * ms
h.tstop = 5.0 * ms
h.run()

x = np.array(x_hoc.as_numpy())
t = np.array(t_hoc.as_numpy())

x_exact = 42.0 * np.exp(-t)
np.testing.assert_allclose(x, x_exact, rtol=rtol)
1uc marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
simulate("cnexp", rtol=1e-12)
simulate("derivimplicit", rtol=1e-3, dt=1e-4)
1uc marked this conversation as resolved.
Show resolved Hide resolved
Loading