diff --git a/src/main.cpp b/src/main.cpp
index be322c0e9..3e8d596a0 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -485,7 +485,10 @@ int main(int argc, const char* argv[]) {
             ast_to_nmodl(*ast, filepath("localize"));
         }
 
-        if (sympy_conductance || sympy_analytic || sparse_solver_exists(*ast)) {
+        const bool sympy_derivimplicit = neuron_code && solver_exists(*ast, "derivimplicit");
+
+        if (sympy_conductance || sympy_analytic || solver_exists(*ast, "sparse") ||
+            sympy_derivimplicit) {
             nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance()
                 .api()
                 .initialize_interpreter();
@@ -496,7 +499,7 @@ int main(int argc, const char* argv[]) {
                 ast_to_nmodl(*ast, filepath("sympy_conductance"));
             }
 
-            if (sympy_analytic || sparse_solver_exists(*ast)) {
+            if (sympy_analytic || solver_exists(*ast, "sparse") || sympy_derivimplicit) {
                 if (!sympy_analytic) {
                     logger->info(
                         "Automatically enable sympy_analytic because it exists solver of type "
diff --git a/src/visitors/visitor_utils.cpp b/src/visitors/visitor_utils.cpp
index 8bc49703c..2a13a38af 100644
--- a/src/visitors/visitor_utils.cpp
+++ b/src/visitors/visitor_utils.cpp
@@ -215,15 +215,16 @@ std::vector<std::shared_ptr<ast::Ast>> collect_nodes(ast::Ast& node,
     return visitor.lookup(node, types);
 }
 
-bool sparse_solver_exists(const ast::Ast& node) {
+bool solver_exists(const ast::Ast& node, const std::string& name) {
     const auto solve_blocks = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
-    return std::any_of(solve_blocks.begin(), solve_blocks.end(), [](auto const& solve_block) {
+    return std::any_of(solve_blocks.begin(), solve_blocks.end(), [&name](auto const& solve_block) {
         assert(solve_block);
         const auto& method = dynamic_cast<const ast::SolveBlock&>(*solve_block).get_method();
-        return method && method->get_node_name() == "sparse";
+        return method && method->get_node_name() == name;
     });
 }
 
+
 std::string to_nmodl(const ast::Ast& node, const std::set<ast::AstNodeType>& exclude_types) {
     std::stringstream stream;
     visitor::NmodlPrintVisitor v(stream, exclude_types);
diff --git a/src/visitors/visitor_utils.hpp b/src/visitors/visitor_utils.hpp
index 9e7163fdb..7105fdfd5 100644
--- a/src/visitors/visitor_utils.hpp
+++ b/src/visitors/visitor_utils.hpp
@@ -106,7 +106,8 @@ std::vector<std::shared_ptr<ast::Ast>> collect_nodes(
     ast::Ast& node,
     const std::vector<ast::AstNodeType>& types = {});
 
-bool sparse_solver_exists(const ast::Ast& node);
+/// Whether or not a solver of type `name` exists in the AST
+bool solver_exists(const ast::Ast& node, const std::string& name);
 
 /// Given AST node, return the NMODL string representation
 std::string to_nmodl(const ast::Ast& node, const std::set<ast::AstNodeType>& exclude_types = {});
diff --git a/test/usecases/CMakeLists.txt b/test/usecases/CMakeLists.txt
index 3923727b4..e72f7c949 100644
--- a/test/usecases/CMakeLists.txt
+++ b/test/usecases/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(NMODL_USECASE_DIRS
-    cnexp
+    solve
     constant
     function
     procedure
diff --git a/test/usecases/cnexp/test_array.py b/test/usecases/cnexp/test_array.py
deleted file mode 100644
index 5c6ae46cb..000000000
--- a/test/usecases/cnexp/test_array.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import numpy as np
-
-from neuron import h, gui
-from neuron.units import ms
-
-nseg = 1
-
-s = h.Section()
-s.insert("cnexp_array")
-s.nseg = nseg
-
-x_hoc = h.Vector().record(s(0.5)._ref_x_cnexp_array)
-t_hoc = h.Vector().record(h._ref_t)
-
-h.stdinit()
-h.tstop = 5.0 * ms
-h.run()
-
-x = np.array(x_hoc.as_numpy())
-t = np.array(t_hoc.as_numpy())
-
-rate = (0.1 - 1.0) * (0.7 * 0.8 * 0.9)
-x_exact = 42.0 * np.exp(rate * t)
-rel_err = np.abs(x - x_exact) / x_exact
-
-assert np.all(rel_err < 1e-12)
diff --git a/test/usecases/cnexp/test_scalar.py b/test/usecases/cnexp/test_scalar.py
deleted file mode 100644
index 9c451e9c2..000000000
--- a/test/usecases/cnexp/test_scalar.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import numpy as np
-
-from neuron import h, gui
-from neuron.units import ms
-
-nseg = 1
-
-s = h.Section()
-s.insert("cnexp_scalar")
-s.nseg = nseg
-
-x_hoc = h.Vector().record(s(0.5)._ref_x_cnexp_scalar)
-t_hoc = h.Vector().record(h._ref_t)
-
-h.stdinit()
-h.tstop = 5.0 * ms
-h.run()
-
-x = np.array(x_hoc.as_numpy())
-t = np.array(t_hoc.as_numpy())
-
-x0 = 42.0
-x_exact = 42.0 * np.exp(-t)
-rel_err = np.abs(x - x_exact) / x_exact
-
-assert np.all(rel_err < 1e-12)
diff --git a/test/usecases/cnexp/cnexp_array.mod b/test/usecases/solve/cnexp_array.mod
similarity index 100%
rename from test/usecases/cnexp/cnexp_array.mod
rename to test/usecases/solve/cnexp_array.mod
diff --git a/test/usecases/cnexp/cnexp_scalar.mod b/test/usecases/solve/cnexp_scalar.mod
similarity index 100%
rename from test/usecases/cnexp/cnexp_scalar.mod
rename to test/usecases/solve/cnexp_scalar.mod
diff --git a/test/usecases/solve/derivimplicit_array.mod b/test/usecases/solve/derivimplicit_array.mod
new file mode 100644
index 000000000..e43e34313
--- /dev/null
+++ b/test/usecases/solve/derivimplicit_array.mod
@@ -0,0 +1,30 @@
+NEURON {
+    SUFFIX derivimplicit_array
+    RANGE z
+}
+
+ASSIGNED {
+    z[3]
+}
+
+STATE {
+    x
+    s[2]
+}
+
+INITIAL {
+    x = 42.0
+    s[0] = 0.1
+    s[1] = -1.0
+    z[0] = 0.7
+    z[1] = 0.8
+    z[2] = 0.9
+}
+
+BREAKPOINT {
+    SOLVE dX METHOD derivimplicit
+}
+
+DERIVATIVE dX {
+    x' = (s[0] + s[1])*(z[0]*z[1]*z[2])*x
+}
diff --git a/test/usecases/solve/derivimplicit_scalar.mod b/test/usecases/solve/derivimplicit_scalar.mod
new file mode 100644
index 000000000..05dfeab97
--- /dev/null
+++ b/test/usecases/solve/derivimplicit_scalar.mod
@@ -0,0 +1,16 @@
+NEURON {
+    SUFFIX derivimplicit_scalar
+}
+
+STATE { x }
+
+INITIAL {
+    x = 42
+}
+
+BREAKPOINT {
+    SOLVE dX METHOD derivimplicit
+}
+
+DERIVATIVE dX { x' = -x }
+
diff --git a/test/usecases/solve/test_array.py b/test/usecases/solve/test_array.py
new file mode 100644
index 000000000..856145dfd
--- /dev/null
+++ b/test/usecases/solve/test_array.py
@@ -0,0 +1,32 @@
+import numpy as np
+from neuron import h, gui
+from neuron.units import ms
+
+
+def simulate(method, rtol, dt=None):
+    nseg = 1
+
+    s = h.Section()
+    s.insert(f"{method}_array")
+    s.nseg = nseg
+
+    x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_{method}_array"))
+    t_hoc = h.Vector().record(h._ref_t)
+
+    h.stdinit()
+    if dt is not None:
+        h.dt = dt * ms
+    h.tstop = 5.0 * ms
+    h.run()
+
+    x = np.array(x_hoc.as_numpy())
+    t = np.array(t_hoc.as_numpy())
+
+    rate = (0.1 - 1.0) * (0.7 * 0.8 * 0.9)
+    x_exact = 42.0 * np.exp(rate * t)
+    np.testing.assert_allclose(x, x_exact, rtol=rtol)
+
+
+if __name__ == "__main__":
+    simulate("cnexp", rtol=1e-12)
+    simulate("derivimplicit", rtol=1e-4, dt=1e-4)
diff --git a/test/usecases/solve/test_scalar.py b/test/usecases/solve/test_scalar.py
new file mode 100644
index 000000000..2abd3fd77
--- /dev/null
+++ b/test/usecases/solve/test_scalar.py
@@ -0,0 +1,31 @@
+import numpy as np
+from neuron import h, gui
+from neuron.units import ms
+
+
+def simulate(method, rtol, dt=None):
+    nseg = 1
+
+    s = h.Section()
+    s.insert(f"{method}_scalar")
+    s.nseg = nseg
+
+    x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_{method}_scalar"))
+    t_hoc = h.Vector().record(h._ref_t)
+
+    h.stdinit()
+    if dt is not None:
+        h.dt = dt * ms
+    h.tstop = 5.0 * ms
+    h.run()
+
+    x = np.array(x_hoc.as_numpy())
+    t = np.array(t_hoc.as_numpy())
+
+    x_exact = 42.0 * np.exp(-t)
+    np.testing.assert_allclose(x, x_exact, rtol=rtol)
+
+
+if __name__ == "__main__":
+    simulate("cnexp", rtol=1e-12)
+    simulate("derivimplicit", rtol=1e-3, dt=1e-4)
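
Note on the relative tolerances in the new tests: cnexp integrates these linear ODEs exactly, so those runs keep rtol=1e-12, while derivimplicit is a first-order implicit scheme, so its error with dt=1e-4 over 5 ms lands in the 1e-4 to 1e-3 range. The standalone NumPy sketch below is not part of the patch; it assumes derivimplicit amounts to one backward-Euler step per dt, and the helper name implicit_euler_rel_err is made up for illustration. It reproduces the expected error magnitudes without running NEURON:

import numpy as np


def implicit_euler_rel_err(rate, dt=1e-4, t_stop=5.0, x0=42.0):
    # Backward Euler for x' = rate * x; for a linear ODE the implicit update
    # has the closed form x_{n+1} = x_n / (1 - rate * dt).
    n = np.arange(int(round(t_stop / dt)) + 1)
    x = x0 / (1.0 - rate * dt) ** n
    x_exact = x0 * np.exp(rate * dt * n)
    return np.max(np.abs(x - x_exact) / x_exact)


# Scalar usecase, x' = -x (cf. derivimplicit_scalar.mod and rtol=1e-3 in test_scalar.py):
print(implicit_euler_rel_err(-1.0))  # ~2.5e-4
# Array usecase, rate = (s[0] + s[1]) * z[0] * z[1] * z[2] (cf. rtol=1e-4 in test_array.py):
print(implicit_euler_rel_err((0.1 - 1.0) * (0.7 * 0.8 * 0.9)))  # ~5e-5

Under that assumption the chosen tolerances leave roughly a 2-4x margin over the expected first-order error of the implicit solver.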