Skip to content

Commit

Permalink
Fix FUNCTION_TABLE with POINT_PROCESS/ARTIFICIAL_CELL. (#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc authored Oct 17, 2024
1 parent a071982 commit 6bcacea
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 59 deletions.
57 changes: 35 additions & 22 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,10 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
return;
}
const auto block_name = function_or_procedure_block->get_node_name();
if (info.point_process) {
printer->fmt_push_block("static double _hoc_{}(void* _vptr)", block_name);
} else if (wrapper_type == InterpreterWrapper::HOC) {
printer->fmt_push_block("static void _hoc_{}(void)", block_name);
if (wrapper_type == InterpreterWrapper::HOC) {
printer->fmt_push_block("{}", hoc_function_signature(block_name));
} else {
printer->fmt_push_block("static double _npy_{}(Prop* _prop)", block_name);
printer->fmt_push_block("{}", py_function_signature(block_name));
}
printer->add_multi_line(R"CODE(
double _r{};
Expand Down Expand Up @@ -416,20 +414,24 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_definitions() {
}

// HOC
printer->fmt_push_block("static void {}()", hoc_function_name(name));
printer->fmt_line("hoc_retpushx({}({}));", method_name(name), fmt::join(args, ", "));
std::string return_statement = info.point_process ? "return _ret;" : "hoc_retpushx(_ret);";

printer->fmt_push_block("{}", hoc_function_signature(name));
printer->fmt_line("double _ret = {}({});", method_name(name), fmt::join(args, ", "));
printer->add_line(return_statement);
printer->pop_block();

printer->fmt_push_block("static void {}()", hoc_function_name(table_name));
printer->fmt_line("hoc_retpushx({}());", method_name(table_name));
printer->fmt_push_block("{}", hoc_function_signature(table_name));
printer->fmt_line("double _ret = {}();", method_name(table_name));
printer->add_line(return_statement);
printer->pop_block();

// Python
printer->fmt_push_block("static double {}(Prop* _prop)", py_function_name(name));
printer->fmt_push_block("{}", py_function_signature(name));
printer->fmt_line("return {}({});", method_name(name), fmt::join(args, ", "));
printer->pop_block();

printer->fmt_push_block("static double {}(Prop* _prop)", py_function_name(table_name));
printer->fmt_push_block("{}", py_function_signature(table_name));
printer->fmt_line("return {}();", method_name(table_name));
printer->pop_block();
}
Expand Down Expand Up @@ -604,10 +606,10 @@ std::string CodegenNeuronCppVisitor::hoc_function_name(

std::string CodegenNeuronCppVisitor::hoc_function_signature(
const std::string& function_or_procedure_name) const {
return fmt::format("static {} {}(void{})",
return fmt::format("static {} {}({})",
info.point_process ? "double" : "void",
hoc_function_name(function_or_procedure_name),
info.point_process ? "*" : "");
info.point_process ? "void * _vptr" : "");
}


Expand All @@ -619,7 +621,8 @@ std::string CodegenNeuronCppVisitor::py_function_name(

std::string CodegenNeuronCppVisitor::py_function_signature(
const std::string& function_or_procedure_name) const {
return fmt::format("static double {}(Prop*)", py_function_name(function_or_procedure_name));
return fmt::format("static double {}(Prop* _prop)",
py_function_name(function_or_procedure_name));
}


Expand Down Expand Up @@ -1218,16 +1221,26 @@ void CodegenNeuronCppVisitor::print_global_variables_for_hoc() {
printer->add_line("{nullptr, nullptr}");
printer->decrease_indent();
printer->add_line("};");


auto print_py_callable_reg = [this](const auto& callables, auto get_name) {
for (const auto& callable: callables) {
const auto name = get_name(callable);
printer->fmt_line("{{\"{}\", {}}},", name, py_function_name(name));
}
};

if (!info.point_process) {
printer->push_block("static NPyDirectMechFunc npy_direct_func_proc[] =");
for (const auto& procedure: info.procedures) {
const auto proc_name = procedure->get_node_name();
printer->fmt_line("{{\"{}\", {}}},", proc_name, py_function_name(proc_name));
}
for (const auto& function: info.functions) {
const auto func_name = function->get_node_name();
printer->fmt_line("{{\"{}\", {}}},", func_name, py_function_name(func_name));
}
print_py_callable_reg(info.procedures,
[](const auto& callable) { return callable->get_node_name(); });
print_py_callable_reg(info.functions,
[](const auto& callable) { return callable->get_node_name(); });
print_py_callable_reg(info.function_tables,
[](const auto& callable) { return callable->get_node_name(); });
print_py_callable_reg(info.function_tables, [](const auto& callable) {
return "table_" + callable->get_node_name();
});
printer->add_line("{nullptr, nullptr}");
printer->pop_block(";");
}
Expand Down
6 changes: 6 additions & 0 deletions test/usecases/function_table/art_function_table.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
NEURON {
ARTIFICIAL_CELL art_function_table
}

INCLUDE "function_table.inc"

8 changes: 8 additions & 0 deletions test/usecases/function_table/function_table.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FUNCTION_TABLE cnst1(v)
FUNCTION_TABLE cnst2(v, x)
FUNCTION_TABLE tau1(v)
FUNCTION_TABLE tau2(v, x)

FUNCTION use_tau2(v, x) {
use_tau2 = tau2(v, x)
}
8 changes: 1 addition & 7 deletions test/usecases/function_table/function_table.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,5 @@ NEURON {
SUFFIX function_table
}

FUNCTION_TABLE cnst1(v)
FUNCTION_TABLE cnst2(v, x)
FUNCTION_TABLE tau1(v)
FUNCTION_TABLE tau2(v, x)
INCLUDE "function_table.inc"

FUNCTION use_tau2(v, x) {
use_tau2 = tau2(v, x)
}
6 changes: 6 additions & 0 deletions test/usecases/function_table/point_function_table.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
NEURON {
POINT_PROCESS point_function_table
}

INCLUDE "function_table.inc"

94 changes: 64 additions & 30 deletions test/usecases/function_table/test_function_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,80 @@
import scipy


def test_constant_1d():
def make_callable(inst, name, mech_name):
if inst is None:
return getattr(h, f"{name}_{mech_name}")
else:
return getattr(inst, f"{name}")


def make_callbacks(inst, name, mech_name):
set_table = make_callable(inst, f"table_{name}", mech_name)
eval_table = make_callable(inst, name, mech_name)

return set_table, eval_table


def check_constant_1d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "cnst1", mech_name)

c = 42.0
h.table_cnst1_function_table(c)
set_table(c)

for vv in np.linspace(-10.0, 10.0, 14):
np.testing.assert_equal(h.cnst1_function_table(vv), c)
np.testing.assert_equal(eval_table(vv), c)


def test_constant_2d():
def check_constant_2d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "cnst2", mech_name)

c = 42.0
h.table_cnst2_function_table(c)
set_table(c)

for vv in np.linspace(-10.0, 10.0, 7):
for xx in np.linspace(-20.0, 10.0, 9):
np.testing.assert_equal(h.cnst2_function_table(vv, xx), c)
np.testing.assert_equal(eval_table(vv, xx), c)


def check_1d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "tau1", mech_name)

def test_1d():
v = np.array([0.0, 1.0])
tau1 = np.array([1.0, 2.0])

h.table_tau1_function_table(h.Vector(tau1), h.Vector(v))
set_table(h.Vector(tau1), h.Vector(v))

for vv in np.linspace(v[0], v[-1], 20):
expected = np.interp(vv, v, tau1)
actual = h.tau1_function_table(vv)
actual = eval_table(vv)

np.testing.assert_approx_equal(actual, expected, significant=11)


def test_2d():
def check_2d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "tau2", mech_name)
eval_use_table = make_callable(inst, "use_tau2", mech_name)

if inst is None:
setdata = getattr(h, f"setdata_{mech_name}")
setdata(s(0.5))

v = np.array([0.0, 1.0])
x = np.array([1.0, 2.0, 3.0])

Expand All @@ -50,36 +87,33 @@ def test_2d():
hoc_tau2 = h.Matrix(*tau2.shape)
hoc_tau2.from_vector(h.Vector(tau2.transpose().reshape(-1)))

h.table_tau2_function_table(
hoc_tau2._ref_x[0][0], v.size, v[0], v[-1], x.size, x[0], x[-1]
)
set_table(hoc_tau2._ref_x[0][0], v.size, v[0], v[-1], x.size, x[0], x[-1])

for vv in np.linspace(v[0], v[-1], 20):
for xx in np.linspace(x[0], x[-1], 20):
expected = scipy.interpolate.interpn((v, x), tau2, (vv, xx))
actual = h.tau2_function_table(vv, xx)
actual = eval_table(vv, xx)
actual_indirect = eval_use_table(vv, xx)

np.testing.assert_approx_equal(actual, expected, significant=11)
np.testing.assert_approx_equal(actual_indirect, expected, significant=11)


def test_use_table():
s = h.Section()
s.insert("function_table")

h.setdata_function_table(s(0.5))
def test_function_table():
variations = [
(lambda s: None, "function_table"),
(lambda s: s(0.5).function_table, "function_table"),
(lambda s: h.point_function_table(s(0.5)), "point_function_table"),
(lambda s: h.art_function_table(s(0.5)), "art_function_table"),
]

vv, xx = 0.33, 2.24
for make_instance, mech_name in variations:
check_constant_1d(make_instance, mech_name)
check_constant_2d(make_instance, mech_name)

expected = h.tau2_function_table(vv, xx)
actual = h.use_tau2_function_table(vv, xx)
np.testing.assert_approx_equal(actual, expected, significant=11)
check_1d(make_instance, mech_name)
check_2d(make_instance, mech_name)


if __name__ == "__main__":
test_constant_1d()
test_constant_2d()

test_1d()
test_2d()

test_use_table()
test_function_table()

0 comments on commit 6bcacea

Please sign in to comment.