Skip to content

Commit

Permalink
Better naming
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran committed Oct 22, 2024
1 parent ea07cbc commit 80dbe90
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 55 deletions.
18 changes: 18 additions & 0 deletions src/codegen/codegen_naming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,24 @@ static constexpr char THREAD_ARGS_PROTO[] = "_threadargsproto_";
/// prefix for ion variable
static constexpr char ION_VARNAME_PREFIX[] = "ion_";

/// name of CVODE method for counting # of ODEs
static constexpr char CVODE_COUNT_NAME[] = "ode_count";

/// name of CVODE method for updating non-stiff systems
static constexpr char CVODE_UPDATE_NON_STIFF_NAME[] = "ode_update_nonstiff";

/// name of CVODE method for updating stiff systems
static constexpr char CVODE_UPDATE_STIFF_NAME[] = "ode_update_stiff";

/// name of CVODE method for setting up non-stiff systems
static constexpr char CVODE_SETUP_NON_STIFF_NAME[] = "ode_setup_nonstiff";

/// name of CVODE method for setting up stiff systems
static constexpr char CVODE_SETUP_STIFF_NAME[] = "ode_setup_stiff";

/// name of CVODE method for setting up tolerances
static constexpr char CVODE_SETUP_TOLERANCES_NAME[] = "ode_setup_tolerances";

/// commonly used variables in verbatim block and how they
/// should be mapped to new code generation backends
// clang-format off
Expand Down
106 changes: 51 additions & 55 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1443,13 +1443,15 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
if (info.emit_cvode) {
printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"cvodeieq\");",
codegen_int_variables_size);
printer->fmt_line(
"hoc_register_cvode(mech_type, ode_count_{}, ode_setup_tolerance_{}, "
"ode_setup_nonstiff_{}, ode_setup_stiff_{});",
info.mod_suffix,
info.mod_suffix,
info.mod_suffix,
info.mod_suffix);
printer->fmt_line("hoc_register_cvode(mech_type, {}_{}, {}_{}, {}_{}, {}_{});",
naming::CVODE_COUNT_NAME,
info.mod_suffix,
naming::CVODE_SETUP_TOLERANCES_NAME,
info.mod_suffix,
naming::CVODE_SETUP_NON_STIFF_NAME,
info.mod_suffix,
naming::CVODE_SETUP_STIFF_NAME,
info.mod_suffix);
printer->fmt_line("hoc_register_tolerance(mech_type, _hoc_state_tol, &_atollist);");
}

Expand Down Expand Up @@ -2542,39 +2544,33 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() {
printer->add_line("/* Functions related to CVODE codegen */");

/* return # of ODEs to solve */
printer->push_block(
fmt::format("static constexpr int ode_count_{}(int _type)", info.mod_suffix));
printer->fmt_push_block("static constexpr int {}_{}(int _type)",
naming::CVODE_COUNT_NAME,
info.mod_suffix);
printer->fmt_line("return {};", info.cvode_block->get_n_odes()->get_value());
printer->pop_block();

printer->add_newline(2);

auto update_nonstiff_name = fmt::format("ode_update_nonstiff_{}", info.mod_suffix);

/* The update function for non-stiff systems */
printer->fmt_push_block("static int {}({})",
update_nonstiff_name,
get_parameter_str(cvode_update_parameters())); // begin function
// definition
printer->fmt_push_block("static int {}_{}({})",
naming::CVODE_UPDATE_NON_STIFF_NAME,
info.mod_suffix,
get_parameter_str(cvode_update_parameters())); // begin fn
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
if (info.cvode_block) {
auto block = info.cvode_block->get_non_stiff_block();
print_statement_block(*block, false, false);
}
print_statement_block(*info.cvode_block->get_non_stiff_block(), false, false);

printer->add_line("return 0;");
printer->pop_block(); // end function definition
printer->pop_block(); // end fn

printer->add_newline(2);

auto setup_nonstiff_name = fmt::format("ode_setup_nonstiff_{}", info.mod_suffix);

/* The setup function for non-stiff systems */
printer->push_block(
fmt::format("static void {}({})",
setup_nonstiff_name,
get_parameter_str(cvode_setup_parameters()))); // begin function definition
printer->fmt_push_block("static void {}_{}({})",
naming::CVODE_SETUP_NON_STIFF_NAME,
info.mod_suffix,
get_parameter_str(cvode_setup_parameters())); // begin fn
printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};");
printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
printer->add_line("auto nodecount = _ml_arg->nodecount;");
Expand All @@ -2589,55 +2585,52 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() {
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
printer->fmt_line("{}({});", update_nonstiff_name, get_arg_str(cvode_update_parameters()));
printer->fmt_line("{}_{}({});",
naming::CVODE_UPDATE_NON_STIFF_NAME,
info.mod_suffix,
get_arg_str(cvode_update_parameters()));

printer->pop_block(); // end for loop
printer->pop_block(); // end function definition
printer->pop_block(); // end fn

printer->add_newline(2);

/* The function for setup of tolerance */
printer->push_block(
fmt::format("static void ode_setup_tolerance_{}(Prop* _prop, int equation_index, "
"neuron::container::data_handle<double>* _pv, "
"neuron::container::data_handle<double>* _pvdot, double* _atol, int _type)",
info.mod_suffix)); // begin function definition
printer->fmt_push_block(
"static void {}_{}(Prop* _prop, int equation_index, "
"neuron::container::data_handle<double>* _pv, "
"neuron::container::data_handle<double>* _pvdot, double* _atol, int _type)",
naming::CVODE_SETUP_TOLERANCES_NAME,
info.mod_suffix); // begin fn
printer->add_line("auto* _ppvar = _nrn_mechanism_access_dparam(_prop);");
printer->fmt_line("_ppvar[{}].literal_value<int>() = equation_index;", int_variables_size());
printer->push_block(fmt::format("for (int i = 0; i < ode_count_{}(0); i++)",
info.mod_suffix)); // begin for loop
printer->fmt_push_block("for (int i = 0; i < ode_count_{}(0); i++)",
info.mod_suffix); // begin for loop
printer->add_line("_pv[i] = _nrn_mechanism_get_param_handle(_prop, _slist1[i]);");
printer->add_line("_pvdot[i] = _nrn_mechanism_get_param_handle(_prop, _dlist1[i]);");
printer->add_line("_cvode_abstol(_atollist, _atol, i);");
printer->pop_block(); // end for loop
printer->pop_block(); // end function definition
printer->pop_block(); // end fn

printer->add_newline(2);

auto update_stiff_name = fmt::format("ode_update_stiff_{}", info.mod_suffix);

/* The update function for stiff systems */
printer->push_block(
fmt::format("static void {}({})",
update_stiff_name,
get_parameter_str(cvode_update_parameters()))); // begin function definition
printer->fmt_push_block("static void {}_{}({})",
naming::CVODE_UPDATE_STIFF_NAME,
info.mod_suffix,
get_parameter_str(cvode_update_parameters())); // begin fn

if (info.cvode_block) {
auto block = info.cvode_block->get_stiff_block();
print_statement_block(*block, false, false);
}
print_statement_block(*info.cvode_block->get_stiff_block(), false, false);

printer->pop_block(); // end function definition
printer->pop_block(); // end fn

printer->add_newline(2);

auto setup_stiff_name = fmt::format("ode_setup_stiff_{}", info.mod_suffix);

/* The setup function for stiff systems */
printer->push_block(
fmt::format("static void {}({})",
setup_stiff_name,
get_parameter_str(cvode_setup_parameters()))); // begin function definition
printer->fmt_push_block("static void {}_{}({})",
naming::CVODE_SETUP_STIFF_NAME,
info.mod_suffix,
get_parameter_str(cvode_setup_parameters())); // begin fn
printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml_arg, _type};");
printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
printer->add_line("auto nodecount = _ml_arg->nodecount;");
Expand All @@ -2654,10 +2647,13 @@ void CodegenNeuronCppVisitor::print_cvode_definitions() {
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
printer->fmt_line("{}({});", update_stiff_name, get_arg_str(cvode_update_parameters()));
printer->fmt_line("{}_{}({});",
naming::CVODE_UPDATE_STIFF_NAME,
info.mod_suffix,
get_arg_str(cvode_update_parameters()));

printer->pop_block(); // end for loop
printer->pop_block(); // end function definition
printer->pop_block(); // end fn
}

void CodegenNeuronCppVisitor::print_net_receive() {
Expand Down

0 comments on commit 80dbe90

Please sign in to comment.