From 31a6c488548f2a94d71da5beb9d797424284600b Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Tue, 29 Oct 2024 14:56:10 +0100 Subject: [PATCH] Handle voltage via instance specific copy. (#1532) --- src/codegen/codegen_neuron_cpp_visitor.cpp | 15 +++--- test/usecases/voltage/accessors.mod | 18 +++++++ test/usecases/voltage/ode.mod | 17 +++++++ test/usecases/voltage/state_ode.mod | 31 ++++++++++++ test/usecases/voltage/test_voltage.py | 55 ++++++++++++++++++++++ 5 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 test/usecases/voltage/accessors.mod create mode 100644 test/usecases/voltage/ode.mod create mode 100644 test/usecases/voltage/state_ode.mod create mode 100644 test/usecases/voltage/test_voltage.py diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 1d546f623..6d8c3c930 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -795,7 +795,10 @@ std::string CodegenNeuronCppVisitor::global_variable_name(const SymbolType& symb std::string CodegenNeuronCppVisitor::get_variable_name(const std::string& name, bool use_instance) const { - const std::string& varname = update_if_ion_variable_name(name); + std::string varname = update_if_ion_variable_name(name); + if (!info.artificial_cell && varname == "v") { + varname = naming::VOLTAGE_UNUSED_VARIABLE; + } auto name_comparator = [&varname](const auto& sym) { return varname == get_name(sym); }; @@ -956,9 +959,6 @@ void CodegenNeuronCppVisitor::print_sdlists_init(bool /* print_initializers */) CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::functor_params() { auto params = internal_method_parameters(); - if (!info.artificial_cell) { - params.push_back({"", "double", "", "v"}); - } return params; } @@ -1822,7 +1822,7 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) { printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); if (!info.artificial_cell) { printer->add_line("int node_id = node_data.nodeindices[id];"); - printer->add_line("auto v = node_data.node_voltages[node_id];"); + printer->add_line("inst.v_unused[id] = node_data.node_voltages[node_id];"); } print_rename_state_vars(); @@ -2069,7 +2069,9 @@ void CodegenNeuronCppVisitor::print_nrn_state() { printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); - printer->add_line("auto v = node_data.node_voltages[node_id];"); + if (!info.artificial_cell) { + printer->add_line("inst.v_unused[id] = node_data.node_voltages[node_id];"); + } /** * \todo Eigen solver node also emits IonCurVar variable in the functor @@ -2142,6 +2144,7 @@ void CodegenNeuronCppVisitor::print_nrn_current(const BreakpointBlock& node) { printer->fmt_push_block("static inline double nrn_current_{}({})", info.mod_suffix, get_parameter_str(args)); + printer->add_line("inst.v_unused[id] = v;"); printer->add_line("double current = 0.0;"); print_statement_block(*block, false, false); for (auto& current: info.currents) { diff --git a/test/usecases/voltage/accessors.mod b/test/usecases/voltage/accessors.mod new file mode 100644 index 000000000..2a3a72616 --- /dev/null +++ b/test/usecases/voltage/accessors.mod @@ -0,0 +1,18 @@ +NEURON { + SUFFIX accessors + NONSPECIFIC_CURRENT il +} + +ASSIGNED { + v + il +} + +BREAKPOINT { + il = 0.003 +} + + +FUNCTION get_voltage() { + get_voltage = v +} diff --git a/test/usecases/voltage/ode.mod b/test/usecases/voltage/ode.mod new file mode 100644 index 000000000..d0e636851 --- /dev/null +++ b/test/usecases/voltage/ode.mod @@ -0,0 +1,17 @@ +NEURON { + SUFFIX ode + NONSPECIFIC_CURRENT il +} + +ASSIGNED { + il + v +} + +FUNCTION voltage() { + voltage = 0.001 * v +} + +BREAKPOINT { + il = voltage() +} diff --git a/test/usecases/voltage/state_ode.mod b/test/usecases/voltage/state_ode.mod new file mode 100644 index 000000000..909645f32 --- /dev/null +++ b/test/usecases/voltage/state_ode.mod @@ -0,0 +1,31 @@ +NEURON { + SUFFIX state_ode + NONSPECIFIC_CURRENT il +} + +STATE { + X +} + +ASSIGNED { + il + v +} + +INITIAL { + X = v +} + +BREAKPOINT { + SOLVE eqn + il = 0.001 * X +} + +NONLINEAR eqn { LOCAL c + c = rate() + ~ X = c +} + +FUNCTION rate() { + rate = v +} diff --git a/test/usecases/voltage/test_voltage.py b/test/usecases/voltage/test_voltage.py new file mode 100644 index 000000000..59828de6b --- /dev/null +++ b/test/usecases/voltage/test_voltage.py @@ -0,0 +1,55 @@ +from neuron import h, gui + +import numpy as np + + +def test_voltage_access(): + s = h.Section() + s.insert("accessors") + + h.finitialize() + v = s(0.5).v + vinst = s(0.5).accessors.get_voltage() + # The voltage will be consistent right after + # finitialize. + assert vinst == v + + for _ in range(4): + v = s(0.5).v + h.fadvance() + vinst = s(0.5).accessors.get_voltage() + + # During timestepping the internal copy + # of the voltage lags behind the current + # voltage by some timestep. + assert vinst == v, f"{vinst = }, {v = }, delta = {vinst - v}" + + +def check_ode(mech_name, step): + s = h.Section() + s.insert(mech_name) + + h.finitialize() + + c = -0.001 / 1e-3 + + for _ in range(4): + v_expected = step(s(0.5).v, c) + h.fadvance() + np.testing.assert_approx_equal(s(0.5).v, v_expected, significant=10) + + +def test_breakpoint(): + # Results in backward Euler. + check_ode("ode", lambda v, c: (1.0 - c * h.dt) ** (-1.0) * v) + + +def test_state(): + # Effectively, the timing when states are computed results in backward Euler. + check_ode("state_ode", lambda v, c: (1.0 + c * h.dt) * v) + + +if __name__ == "__main__": + test_voltage_access() + test_breakpoint() + test_state()