From 3e552fd6ba90b5d2df5a46fdf11803c50a94cefb Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 26 Jul 2024 11:34:00 +0200 Subject: [PATCH 1/2] Generalize tests for solvers - rename `sparse_solver_exists` to `solver_exists` - add `name` argument to `solver_exists` to keep function generic - rename `cnexp` dir to `solve` - add tests for `derivimplicit` method --- src/main.cpp | 7 ++-- src/visitors/visitor_utils.cpp | 7 ++-- src/visitors/visitor_utils.hpp | 3 +- test/usecases/CMakeLists.txt | 2 +- test/usecases/cnexp/test_array.py | 26 -------------- test/usecases/cnexp/test_scalar.py | 26 -------------- .../usecases/{cnexp => solve}/cnexp_array.mod | 0 .../{cnexp => solve}/cnexp_scalar.mod | 0 test/usecases/solve/derivimplicit_array.mod | 30 ++++++++++++++++ test/usecases/solve/derivimplicit_scalar.mod | 16 +++++++++ test/usecases/solve/test_array.py | 34 +++++++++++++++++++ test/usecases/solve/test_scalar.py | 33 ++++++++++++++++++ 12 files changed, 125 insertions(+), 59 deletions(-) delete mode 100644 test/usecases/cnexp/test_array.py delete mode 100644 test/usecases/cnexp/test_scalar.py rename test/usecases/{cnexp => solve}/cnexp_array.mod (100%) rename test/usecases/{cnexp => solve}/cnexp_scalar.mod (100%) create mode 100644 test/usecases/solve/derivimplicit_array.mod create mode 100644 test/usecases/solve/derivimplicit_scalar.mod create mode 100644 test/usecases/solve/test_array.py create mode 100644 test/usecases/solve/test_scalar.py diff --git a/src/main.cpp b/src/main.cpp index be322c0e99..3e8d596a03 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -485,7 +485,10 @@ int main(int argc, const char* argv[]) { ast_to_nmodl(*ast, filepath("localize")); } - if (sympy_conductance || sympy_analytic || sparse_solver_exists(*ast)) { + const bool sympy_derivimplicit = neuron_code && solver_exists(*ast, "derivimplicit"); + + if (sympy_conductance || sympy_analytic || solver_exists(*ast, "sparse") || + sympy_derivimplicit) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); @@ -496,7 +499,7 @@ int main(int argc, const char* argv[]) { ast_to_nmodl(*ast, filepath("sympy_conductance")); } - if (sympy_analytic || sparse_solver_exists(*ast)) { + if (sympy_analytic || solver_exists(*ast, "sparse") || sympy_derivimplicit) { if (!sympy_analytic) { logger->info( "Automatically enable sympy_analytic because it exists solver of type " diff --git a/src/visitors/visitor_utils.cpp b/src/visitors/visitor_utils.cpp index 8bc49703c9..2a13a38af2 100644 --- a/src/visitors/visitor_utils.cpp +++ b/src/visitors/visitor_utils.cpp @@ -215,15 +215,16 @@ std::vector> collect_nodes(ast::Ast& node, return visitor.lookup(node, types); } -bool sparse_solver_exists(const ast::Ast& node) { +bool solver_exists(const ast::Ast& node, const std::string& name) { const auto solve_blocks = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK}); - return std::any_of(solve_blocks.begin(), solve_blocks.end(), [](auto const& solve_block) { + return std::any_of(solve_blocks.begin(), solve_blocks.end(), [&name](auto const& solve_block) { assert(solve_block); const auto& method = dynamic_cast(*solve_block).get_method(); - return method && method->get_node_name() == "sparse"; + return method && method->get_node_name() == name; }); } + std::string to_nmodl(const ast::Ast& node, const std::set& exclude_types) { std::stringstream stream; visitor::NmodlPrintVisitor v(stream, exclude_types); diff --git a/src/visitors/visitor_utils.hpp b/src/visitors/visitor_utils.hpp index 9e7163fdb0..7105fdfd54 100644 --- a/src/visitors/visitor_utils.hpp +++ b/src/visitors/visitor_utils.hpp @@ -106,7 +106,8 @@ std::vector> collect_nodes( ast::Ast& node, const std::vector& types = {}); -bool sparse_solver_exists(const ast::Ast& node); +/// Whether or not a solver of type name exists in the AST +bool solver_exists(const ast::Ast& node, const std::string& name); /// Given AST node, return the NMODL string representation std::string to_nmodl(const ast::Ast& node, const std::set& exclude_types = {}); diff --git a/test/usecases/CMakeLists.txt b/test/usecases/CMakeLists.txt index 3923727b48..e72f7c9492 100644 --- a/test/usecases/CMakeLists.txt +++ b/test/usecases/CMakeLists.txt @@ -1,5 +1,5 @@ set(NMODL_USECASE_DIRS - cnexp + solve constant function procedure diff --git a/test/usecases/cnexp/test_array.py b/test/usecases/cnexp/test_array.py deleted file mode 100644 index 5c6ae46cb2..0000000000 --- a/test/usecases/cnexp/test_array.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np - -from neuron import h, gui -from neuron.units import ms - -nseg = 1 - -s = h.Section() -s.insert("cnexp_array") -s.nseg = nseg - -x_hoc = h.Vector().record(s(0.5)._ref_x_cnexp_array) -t_hoc = h.Vector().record(h._ref_t) - -h.stdinit() -h.tstop = 5.0 * ms -h.run() - -x = np.array(x_hoc.as_numpy()) -t = np.array(t_hoc.as_numpy()) - -rate = (0.1 - 1.0) * (0.7 * 0.8 * 0.9) -x_exact = 42.0 * np.exp(rate * t) -rel_err = np.abs(x - x_exact) / x_exact - -assert np.all(rel_err < 1e-12) diff --git a/test/usecases/cnexp/test_scalar.py b/test/usecases/cnexp/test_scalar.py deleted file mode 100644 index 9c451e9c2c..0000000000 --- a/test/usecases/cnexp/test_scalar.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np - -from neuron import h, gui -from neuron.units import ms - -nseg = 1 - -s = h.Section() -s.insert("cnexp_scalar") -s.nseg = nseg - -x_hoc = h.Vector().record(s(0.5)._ref_x_cnexp_scalar) -t_hoc = h.Vector().record(h._ref_t) - -h.stdinit() -h.tstop = 5.0 * ms -h.run() - -x = np.array(x_hoc.as_numpy()) -t = np.array(t_hoc.as_numpy()) - -x0 = 42.0 -x_exact = 42.0 * np.exp(-t) -rel_err = np.abs(x - x_exact) / x_exact - -assert np.all(rel_err < 1e-12) diff --git a/test/usecases/cnexp/cnexp_array.mod b/test/usecases/solve/cnexp_array.mod similarity index 100% rename from test/usecases/cnexp/cnexp_array.mod rename to test/usecases/solve/cnexp_array.mod diff --git a/test/usecases/cnexp/cnexp_scalar.mod b/test/usecases/solve/cnexp_scalar.mod similarity index 100% rename from test/usecases/cnexp/cnexp_scalar.mod rename to test/usecases/solve/cnexp_scalar.mod diff --git a/test/usecases/solve/derivimplicit_array.mod b/test/usecases/solve/derivimplicit_array.mod new file mode 100644 index 0000000000..e43e343135 --- /dev/null +++ b/test/usecases/solve/derivimplicit_array.mod @@ -0,0 +1,30 @@ +NEURON { + SUFFIX derivimplicit_array + RANGE z +} + +ASSIGNED { + z[3] +} + +STATE { + x + s[2] +} + +INITIAL { + x = 42.0 + s[0] = 0.1 + s[1] = -1.0 + z[0] = 0.7 + z[1] = 0.8 + z[2] = 0.9 +} + +BREAKPOINT { + SOLVE dX METHOD derivimplicit +} + +DERIVATIVE dX { + x' = (s[0] + s[1])*(z[0]*z[1]*z[2])*x +} diff --git a/test/usecases/solve/derivimplicit_scalar.mod b/test/usecases/solve/derivimplicit_scalar.mod new file mode 100644 index 0000000000..05dfeab97b --- /dev/null +++ b/test/usecases/solve/derivimplicit_scalar.mod @@ -0,0 +1,16 @@ +NEURON { + SUFFIX derivimplicit_scalar +} + +STATE { x } + +INITIAL { + x = 42 +} + +BREAKPOINT { + SOLVE dX METHOD derivimplicit +} + +DERIVATIVE dX { x' = -x } + diff --git a/test/usecases/solve/test_array.py b/test/usecases/solve/test_array.py new file mode 100644 index 0000000000..34dc4e830f --- /dev/null +++ b/test/usecases/solve/test_array.py @@ -0,0 +1,34 @@ +from typing import Optional + +import numpy as np +from neuron import h, gui +from neuron.units import ms + + +def simulate(method: str, rtol: float, dt: Optional[float] = None): + nseg = 1 + + s = h.Section() + s.insert(f"{method}_array") + s.nseg = nseg + + x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_{method}_array")) + t_hoc = h.Vector().record(h._ref_t) + + h.stdinit() + if dt is not None: + h.dt = dt * ms + h.tstop = 5.0 * ms + h.run() + + x = np.array(x_hoc.as_numpy()) + t = np.array(t_hoc.as_numpy()) + + rate = (0.1 - 1.0) * (0.7 * 0.8 * 0.9) + x_exact = 42.0 * np.exp(rate * t) + np.testing.assert_allclose(x, x_exact, rtol=rtol) + + +if __name__ == "__main__": + simulate("cnexp", rtol=1e-12) + simulate("derivimplicit", rtol=1e-4, dt=1e-4) diff --git a/test/usecases/solve/test_scalar.py b/test/usecases/solve/test_scalar.py new file mode 100644 index 0000000000..b9c861aa08 --- /dev/null +++ b/test/usecases/solve/test_scalar.py @@ -0,0 +1,33 @@ +from typing import Optional + +import numpy as np +from neuron import h, gui +from neuron.units import ms + + +def simulate(method: str, rtol: float, dt: Optional[float] = None): + nseg = 1 + + s = h.Section() + s.insert(f"{method}_scalar") + s.nseg = nseg + + x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_{method}_scalar")) + t_hoc = h.Vector().record(h._ref_t) + + h.stdinit() + if dt is not None: + h.dt = dt * ms + h.tstop = 5.0 * ms + h.run() + + x = np.array(x_hoc.as_numpy()) + t = np.array(t_hoc.as_numpy()) + + x_exact = 42.0 * np.exp(-t) + np.testing.assert_allclose(x, x_exact, rtol=rtol) + + +if __name__ == "__main__": + simulate("cnexp", rtol=1e-12) + simulate("derivimplicit", rtol=1e-3, dt=1e-4) From 46547909409d8f9e657b5963d102a73610a28ee6 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 26 Jul 2024 14:05:52 +0200 Subject: [PATCH 2/2] Remove type hints for now --- test/usecases/solve/test_array.py | 4 +--- test/usecases/solve/test_scalar.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/test/usecases/solve/test_array.py b/test/usecases/solve/test_array.py index 34dc4e830f..856145dfd7 100644 --- a/test/usecases/solve/test_array.py +++ b/test/usecases/solve/test_array.py @@ -1,11 +1,9 @@ -from typing import Optional - import numpy as np from neuron import h, gui from neuron.units import ms -def simulate(method: str, rtol: float, dt: Optional[float] = None): +def simulate(method, rtol, dt=None): nseg = 1 s = h.Section() diff --git a/test/usecases/solve/test_scalar.py b/test/usecases/solve/test_scalar.py index b9c861aa08..2abd3fd777 100644 --- a/test/usecases/solve/test_scalar.py +++ b/test/usecases/solve/test_scalar.py @@ -1,11 +1,9 @@ -from typing import Optional - import numpy as np from neuron import h, gui from neuron.units import ms -def simulate(method: str, rtol: float, dt: Optional[float] = None): +def simulate(method, rtol, dt=None): nseg = 1 s = h.Section()