Skip to content

Commit

Permalink
Merge branch 'master' into jelic/cvode_codegen_only
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran committed Oct 22, 2024
2 parents 80dbe90 + 380207b commit 6700597
Show file tree
Hide file tree
Showing 21 changed files with 395 additions and 130 deletions.
2 changes: 1 addition & 1 deletion src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ void CodegenCppVisitor::print_global_var_struct_assertions() const {


void CodegenCppVisitor::print_global_var_struct_decl() {
printer->add_line(global_struct(), ' ', global_struct_instance(), ';');
printer->fmt_line("static {} {};", global_struct(), global_struct_instance());
}


Expand Down
4 changes: 2 additions & 2 deletions src/codegen/codegen_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1547,8 +1547,8 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
template <typename T>
void print_function_declaration(const T& node,
const std::string& name,
const std::unordered_set<CppObjectSpecifier>& = {
CppObjectSpecifier::Inline});
const std::unordered_set<CppObjectSpecifier>& =
{CppObjectSpecifier::Static, CppObjectSpecifier::Inline});

void print_rename_state_vars() const;
};
Expand Down
194 changes: 125 additions & 69 deletions src/codegen/codegen_neuron_cpp_visitor.cpp

Large diffs are not rendered by default.

40 changes: 38 additions & 2 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_function_prototypes() override;


/**
* Print function and procedures prototype definitions.
*
* This includes the HOC/Python wrappers.
*/
void print_function_definitions();


/**
* Print all `check_*` function declarations
*/
Expand All @@ -244,9 +252,37 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_function_procedure_helper(const ast::Block& node) override;


void print_hoc_py_wrapper_function_body(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);
/** Print the wrapper for calling FUNCION/PROCEDURES from HOC/Py.
*
* Usually the function is made up of the following parts:
* * Print setup code `inst`, etc.
* * Print code to call the function and return.
*/
void print_hoc_py_wrapper(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);

/** Print the setup code for HOC/Py wrapper.
*/
void print_hoc_py_wrapper_setup(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);


/** Print the code that calls the impl from the HOC/Py wrapper.
*/
void print_hoc_py_wrapper_call_impl(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);

/** Return the wrapper signature.
*
* Everything without the `{` or `;`. Roughly, as an example:
* <return_type> <function_name>(<internal_args>, <args>)
*
* were `<internal_args> is the list of arguments required by the
* codegen to be passed along, while <args> are the arguments of
* of the function as they appear in the MOD file.
*/
std::string hoc_py_wrapper_signature(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);

void print_hoc_py_wrapper_function_definitions();

Expand Down
11 changes: 11 additions & 0 deletions src/utils/string_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,16 @@ std::string to_string(double value, const std::string& format_spec) {
return fmt::format(format_spec, value);
}

std::string join_arguments(const std::string& lhs, const std::string& rhs) {
if (lhs.empty()) {
return rhs;
} else if (rhs.empty()) {
return lhs;
} else {
return fmt::format("{}", fmt::join({lhs, rhs}, ", "));
}
}


} // namespace stringutils
} // namespace nmodl
6 changes: 6 additions & 0 deletions src/utils/string_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ static inline bool starts_with(const std::string& haystack, const std::string& n
*/
std::string to_string(double value, const std::string& format_spec = "{:.16g}");

/** Joint two (list of) arguments.
*
* The tricks is to not add a ',' when either side is empty.
*/
std::string join_arguments(const std::string& lhs, const std::string& rhs);

/** \} */ // end of utils

} // namespace stringutils
Expand Down
9 changes: 9 additions & 0 deletions src/visitors/solve_block_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "utils/fmt.h"
#include <cassert>
#include <memory>

#include "ast/all.hpp"
#include "codegen/codegen_naming.hpp"
Expand Down Expand Up @@ -67,6 +68,14 @@ ast::SolutionExpression* SolveBlockVisitor::create_solution_expression(
return new ast::SolutionExpression(solve_block.clone(), callback_expr);
}

if (node_to_solve->get_node_type() == ast::AstNodeType::PROCEDURE_BLOCK) {
auto procedure_call = new ast::FunctionCall(solve_block.get_block_name()->clone(), {});
auto statement = std::make_shared<ast::ExpressionStatement>(procedure_call);
auto statement_block = new ast::StatementBlock({statement});

return new ast::SolutionExpression(solve_block.clone(), statement_block);
}

auto block_to_solve = node_to_solve->get_statement_block();
return new ast::SolutionExpression(solve_block.clone(), block_to_solve->clone());
}
Expand Down
18 changes: 18 additions & 0 deletions test/unit/utils/string_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,21 @@ TEST_CASE("starts_with") {
REQUIRE(!stringutils::starts_with("abcde", "abcde++"));
}
}

TEST_CASE("join_arguments") {
SECTION("both empty") {
REQUIRE(stringutils::join_arguments("", "") == "");
}

SECTION("lhs emtpy") {
REQUIRE(stringutils::join_arguments("", "foo, bar") == "foo, bar");
}

SECTION("rhs empty") {
REQUIRE(stringutils::join_arguments("foo", "") == "foo");
}

SECTION("neither empty") {
REQUIRE(stringutils::join_arguments("foo", "bar") == "foo, bar");
}
}
2 changes: 2 additions & 0 deletions test/usecases/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.pkl
*.png
6 changes: 6 additions & 0 deletions test/usecases/function_table/art_function_table.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
NEURON {
ARTIFICIAL_CELL art_function_table
}

INCLUDE "function_table.inc"

8 changes: 8 additions & 0 deletions test/usecases/function_table/function_table.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FUNCTION_TABLE cnst1(v)
FUNCTION_TABLE cnst2(v, x)
FUNCTION_TABLE tau1(v)
FUNCTION_TABLE tau2(v, x)

FUNCTION use_tau2(v, x) {
use_tau2 = tau2(v, x)
}
8 changes: 1 addition & 7 deletions test/usecases/function_table/function_table.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,5 @@ NEURON {
SUFFIX function_table
}

FUNCTION_TABLE cnst1(v)
FUNCTION_TABLE cnst2(v, x)
FUNCTION_TABLE tau1(v)
FUNCTION_TABLE tau2(v, x)
INCLUDE "function_table.inc"

FUNCTION use_tau2(v, x) {
use_tau2 = tau2(v, x)
}
6 changes: 6 additions & 0 deletions test/usecases/function_table/point_function_table.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
NEURON {
POINT_PROCESS point_function_table
}

INCLUDE "function_table.inc"

94 changes: 64 additions & 30 deletions test/usecases/function_table/test_function_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,80 @@
import scipy


def test_constant_1d():
def make_callable(inst, name, mech_name):
if inst is None:
return getattr(h, f"{name}_{mech_name}")
else:
return getattr(inst, f"{name}")


def make_callbacks(inst, name, mech_name):
set_table = make_callable(inst, f"table_{name}", mech_name)
eval_table = make_callable(inst, name, mech_name)

return set_table, eval_table


def check_constant_1d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "cnst1", mech_name)

c = 42.0
h.table_cnst1_function_table(c)
set_table(c)

for vv in np.linspace(-10.0, 10.0, 14):
np.testing.assert_equal(h.cnst1_function_table(vv), c)
np.testing.assert_equal(eval_table(vv), c)


def test_constant_2d():
def check_constant_2d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "cnst2", mech_name)

c = 42.0
h.table_cnst2_function_table(c)
set_table(c)

for vv in np.linspace(-10.0, 10.0, 7):
for xx in np.linspace(-20.0, 10.0, 9):
np.testing.assert_equal(h.cnst2_function_table(vv, xx), c)
np.testing.assert_equal(eval_table(vv, xx), c)


def check_1d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "tau1", mech_name)

def test_1d():
v = np.array([0.0, 1.0])
tau1 = np.array([1.0, 2.0])

h.table_tau1_function_table(h.Vector(tau1), h.Vector(v))
set_table(h.Vector(tau1), h.Vector(v))

for vv in np.linspace(v[0], v[-1], 20):
expected = np.interp(vv, v, tau1)
actual = h.tau1_function_table(vv)
actual = eval_table(vv)

np.testing.assert_approx_equal(actual, expected, significant=11)


def test_2d():
def check_2d(make_inst, mech_name):
s = h.Section()
s.insert("function_table")

inst = make_inst(s)
set_table, eval_table = make_callbacks(inst, "tau2", mech_name)
eval_use_table = make_callable(inst, "use_tau2", mech_name)

if inst is None:
setdata = getattr(h, f"setdata_{mech_name}")
setdata(s(0.5))

v = np.array([0.0, 1.0])
x = np.array([1.0, 2.0, 3.0])

Expand All @@ -50,36 +87,33 @@ def test_2d():
hoc_tau2 = h.Matrix(*tau2.shape)
hoc_tau2.from_vector(h.Vector(tau2.transpose().reshape(-1)))

h.table_tau2_function_table(
hoc_tau2._ref_x[0][0], v.size, v[0], v[-1], x.size, x[0], x[-1]
)
set_table(hoc_tau2._ref_x[0][0], v.size, v[0], v[-1], x.size, x[0], x[-1])

for vv in np.linspace(v[0], v[-1], 20):
for xx in np.linspace(x[0], x[-1], 20):
expected = scipy.interpolate.interpn((v, x), tau2, (vv, xx))
actual = h.tau2_function_table(vv, xx)
actual = eval_table(vv, xx)
actual_indirect = eval_use_table(vv, xx)

np.testing.assert_approx_equal(actual, expected, significant=11)
np.testing.assert_approx_equal(actual_indirect, expected, significant=11)


def test_use_table():
s = h.Section()
s.insert("function_table")

h.setdata_function_table(s(0.5))
def test_function_table():
variations = [
(lambda s: None, "function_table"),
(lambda s: s(0.5).function_table, "function_table"),
(lambda s: h.point_function_table(s(0.5)), "point_function_table"),
(lambda s: h.art_function_table(s(0.5)), "art_function_table"),
]

vv, xx = 0.33, 2.24
for make_instance, mech_name in variations:
check_constant_1d(make_instance, mech_name)
check_constant_2d(make_instance, mech_name)

expected = h.tau2_function_table(vv, xx)
actual = h.use_tau2_function_table(vv, xx)
np.testing.assert_approx_equal(actual, expected, significant=11)
check_1d(make_instance, mech_name)
check_2d(make_instance, mech_name)


if __name__ == "__main__":
test_constant_1d()
test_constant_2d()

test_1d()
test_2d()

test_use_table()
test_function_table()
16 changes: 16 additions & 0 deletions test/usecases/global/global.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
NEURON {
SUFFIX global
GLOBAL gbl
}

ASSIGNED {
gbl
}

FUNCTION get_gbl() {
get_gbl = gbl
}

PROCEDURE set_gbl(value) {
gbl = value
}
11 changes: 11 additions & 0 deletions test/usecases/global/parameter.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
NEURON {
SUFFIX parameter
}

PARAMETER {
gbl = 42.0
}

FUNCTION get_gbl() {
get_gbl = gbl
}
Loading

0 comments on commit 6700597

Please sign in to comment.