Skip to content

Commit

Permalink
Implement trivial cases of reading/writing ions. (#1152)
Browse files Browse the repository at this point in the history
This implements a trivial case of reading an ion variable `ena` and
writing `42.0` to `ina`; along with the functionality required to be
able to record `ina`.
  • Loading branch information
1uc authored and Omar Awile committed May 21, 2024
1 parent a511f17 commit e5c328f
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 29 deletions.
45 changes: 45 additions & 0 deletions src/codegen/codegen_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
* \brief Various types to store code generation specific information
*/

#include <fmt/format.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
Expand All @@ -20,6 +21,7 @@
#include "ast/ast.hpp"
#include "symtab/symbol_table.hpp"


namespace nmodl {
namespace codegen {

Expand Down Expand Up @@ -109,12 +111,55 @@ struct Ion {
return is_intra_cell_conc(text) || is_extra_cell_conc(text);
}

/// Is the variable name `text` related to this ion?
///
/// Example: For sodium this is true for any of `"ena"`, `"ina"`, `"nai"`
/// and `"nao"`; but not `ion_ina`, etc.
bool is_ionic_variable(const std::string& text) const {
return is_ionic_conc(text) || is_ionic_current(text) || is_rev_potential(text);
}

bool is_current_derivative(const std::string& text) const {
return text == ("di" + name + "dv");
}

/// for a given ion, return different variable names/properties
/// like internal/external concentration, reversial potential,
/// ionic current etc.
static std::vector<std::string> get_possible_variables(const std::string& ion_name) {
return {"i" + ion_name, ion_name + "i", ion_name + "o", "e" + ion_name};
}

/// Variable index in the ion mechanism.
///
/// For sodium (na), the `var_name` must be one of `ina`, `ena`, `nai`,
/// `nao` or `dinadv`. Replace `na` with the analogous for other ions.
///
/// In NRN the order is:
/// 0: ena
/// 1: nai
/// 2: nao
/// 3: ina
/// 4: dinadv
int variable_index(const std::string& var_name) const {
if (is_rev_potential(var_name)) {
return 0;
}
if (is_intra_cell_conc(var_name)) {
return 1;
}
if (is_extra_cell_conc(var_name)) {
return 2;
}
if (is_ionic_current(var_name)) {
return 3;
}
if (is_current_derivative(var_name)) {
return 4;
}

throw std::runtime_error(fmt::format("Invalid `var_name == {}`.", var_name));
}
};


Expand Down
190 changes: 165 additions & 25 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ void CodegenNeuronCppVisitor::print_namespace_stop() {
std::string CodegenNeuronCppVisitor::conc_write_statement(const std::string& ion_name,
const std::string& concentration,
int index) {
throw std::runtime_error("Not implemented.");
// throw std::runtime_error("Not implemented.");
return "";
}

/****************************************************************************************/
Expand Down Expand Up @@ -234,7 +235,30 @@ std::string CodegenNeuronCppVisitor::float_variable_name(const SymbolType& symbo
std::string CodegenNeuronCppVisitor::int_variable_name(const IndexVariableInfo& symbol,
const std::string& name,
bool use_instance) const {
return name;
auto position = position_of_int_var(name);
if (symbol.is_index) {
if (use_instance) {
throw std::runtime_error("Not implemented. [wiejo]");
// return fmt::format("inst->{}[{}]", name, position);
}
throw std::runtime_error("Not implemented. [ncuwi]");
// return fmt::format("indexes[{}]", position);
}
if (symbol.is_integer) {
if (use_instance) {
throw std::runtime_error("Not implemented. [cnuoe]");
// return fmt::format("inst->{}[{}*pnodecount+id]", name, position);
}
throw std::runtime_error("Not implemented. [u32ow]");
// return fmt::format("indexes[{}*pnodecount+id]", position);
}
if (use_instance) {
return fmt::format("(*inst.{}[id])", name);
}

throw std::runtime_error("Not implemented. [nvueir]");
// auto data = symbol.is_vdata ? "_vdata" : "_data";
// return fmt::format("nt->{}[indexes[{}*pnodecount + id]]", data, position);
}


Expand All @@ -250,8 +274,7 @@ std::string CodegenNeuronCppVisitor::global_variable_name(const SymbolType& symb

std::string CodegenNeuronCppVisitor::get_variable_name(const std::string& name,
bool use_instance) const {
// const std::string& varname = update_if_ion_variable_name(name);
const std::string& varname = name;
const std::string& varname = update_if_ion_variable_name(name);

auto symbol_comparator = [&varname](const SymbolType& sym) {
return varname == sym->get_name();
Expand Down Expand Up @@ -550,20 +573,39 @@ void CodegenNeuronCppVisitor::print_make_instance() const {
info.mod_suffix);
printer->fmt_push_block("return {}", instance_struct());

std::vector<std::string> make_instance_args;

const auto codegen_float_variables_size = codegen_float_variables.size();
for (int i = 0; i < codegen_float_variables_size; ++i) {
const auto& float_var = codegen_float_variables[i];
if (float_var->is_array()) {
printer->fmt_line("_ml.template data_array<{}, {}>(0){}",
i,
float_var->get_length(),
i < codegen_float_variables_size - 1 ? "," : "");
make_instance_args.push_back(
fmt::format("_ml.template data_array_ptr<{}, {}>()", i, float_var->get_length()));
} else {
printer->fmt_line("&_ml.template fpfield<{}>(0){}",
i,
i < codegen_float_variables_size - 1 ? "," : "");
make_instance_args.push_back(fmt::format("_ml.template fpfield_ptr<{}>()", i));
}
}

const auto codegen_int_variables_size = codegen_int_variables.size();
for (size_t i = 0; i < codegen_int_variables_size; ++i) {
const auto& var = codegen_int_variables[i];
auto name = var.symbol->get_name();
auto const variable = [&var, i]() -> std::string {
if (var.is_index || var.is_integer) {
return "";
} else if (var.is_vdata) {
return "";
} else {
return fmt::format("_ml.template dptr_field_ptr<{}>()", i);
}
}();
if (variable != "") {
make_instance_args.push_back(variable);
}
}

printer->add_multi_line(fmt::format("{}", fmt::join(make_instance_args, ",\n")));

printer->pop_block(";");
printer->pop_block();
}
Expand Down Expand Up @@ -600,28 +642,63 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
printer->add_line("_nrn_mechanism_register_data_fields(mech_type,");
printer->increase_indent();
const auto codegen_float_variables_size = codegen_float_variables.size();

std::vector<std::string> mech_register_args;
for (int i = 0; i < codegen_float_variables_size; ++i) {
const auto& float_var = codegen_float_variables[i];
const auto print_comma = i < codegen_float_variables_size - 1 || info.emit_cvode;
if (float_var->is_array()) {
printer->fmt_line("_nrn_mechanism_field<double>{{\"{}\", {}}} /* {} */{}",
float_var->get_name(),
float_var->get_length(),
i,
print_comma ? "," : "");
mech_register_args.push_back(
fmt::format("_nrn_mechanism_field<double>{{\"{}\", {}}} /* {} */",
float_var->get_name(),
float_var->get_length(),
i));
} else {
printer->fmt_line("_nrn_mechanism_field<double>{{\"{}\"}} /* {} */{}",
float_var->get_name(),
i,
print_comma ? "," : "");
mech_register_args.push_back(fmt::format(
"_nrn_mechanism_field<double>{{\"{}\"}} /* {} */", float_var->get_name(), i));
}
}

const auto codegen_int_variables_size = codegen_int_variables.size();
for (int i = 0; i < codegen_int_variables_size; ++i) {
const auto& int_var = codegen_int_variables[i];
const auto& name = int_var.symbol->get_name();
if (i != info.semantics[i].index) {
throw std::runtime_error("Broken logic.");
}

mech_register_args.push_back(
fmt::format("_nrn_mechanism_field<double*>{{\"{}\", \"{}\"}} /* {} */",
name,
info.semantics[i].name,
i));
}
if (info.emit_cvode) {
printer->add_line("_nrn_mechanism_field<int>{\"_cvode_ieq\", \"cvodeieq\"} /* 0 */");
mech_register_args.push_back(
"_nrn_mechanism_field<int>{\"_cvode_ieq\", \"cvodeieq\"} /* 0 */");
}

printer->add_multi_line(fmt::format("{}", fmt::join(mech_register_args, ",\n")));

printer->decrease_indent();
printer->add_line(");");
printer->add_newline();

printer->fmt_line("hoc_register_prop_size(mech_type, {}, {});",
codegen_float_variables_size,
codegen_int_variables_size);

for (int i = 0; i < codegen_int_variables_size; ++i) {
const auto& int_var = codegen_int_variables[i];
const auto& name = int_var.symbol->get_name();
if (i != info.semantics[i].index) {
throw std::runtime_error("Broken logic.");
}

printer->fmt_line("hoc_register_dparam_semantics(mech_type, {}, \"{}\");",
i,
info.semantics[i].name);
}

printer->pop_block();
}

Expand Down Expand Up @@ -649,11 +726,11 @@ void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_ini
const auto& name = var.symbol->get_name();
if (var.is_index || var.is_integer) {
auto qualifier = var.is_constant ? "const " : "";
printer->fmt_line("{}{}* {}{};", qualifier, int_type, name, value_initialize);
printer->fmt_line("{}{}* const* {}{};", qualifier, int_type, name, value_initialize);
} else {
auto qualifier = var.is_constant ? "const " : "";
auto type = var.is_vdata ? "void*" : default_float_data_type();
printer->fmt_line("{}{}* {}{};", qualifier, type, name, value_initialize);
printer->fmt_line("{}{}* const* {}{};", qualifier, type, name, value_initialize);
}
}

Expand All @@ -667,11 +744,24 @@ void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_ini


void CodegenNeuronCppVisitor::print_initial_block(const InitialBlock* node) {
// read ion statements
auto read_statements = ion_read_statements(BlockType::Initial);
for (auto& statement: read_statements) {
printer->add_line(statement);
}

// initial block
if (node != nullptr) {
const auto& block = node->get_statement_block();
print_statement_block(*block, false, false);
}

// write ion statements
auto write_statements = ion_write_statements(BlockType::Initial);
for (auto& statement: write_statements) {
auto text = process_shadow_update_statement(statement, BlockType::Initial);
printer->add_line(text);
}
}


Expand Down Expand Up @@ -729,7 +819,38 @@ void CodegenNeuronCppVisitor::print_nrn_alloc() {
printer->add_newline(2);
auto method = method_name(naming::NRN_ALLOC_METHOD);
printer->fmt_push_block("static void {}(Prop* _prop)", method);
printer->add_line("// do nothing");

const auto codegen_int_variables_size = codegen_int_variables.size();

// TODO number of datum is the number of integer vars.
printer->fmt_line("Datum *_ppvar = nrn_prop_datum_alloc(mech_type, {}, _prop);",
codegen_int_variables_size);
printer->fmt_line("_nrn_mechanism_access_dparam(_prop) = _ppvar;");

for (const auto& ion: info.ions) {
printer->fmt_line("Symbol * {}_sym = hoc_lookup(\"{}_ion\");", ion.name, ion.name);
printer->fmt_line("Prop * {}_prop = need_memb({}_sym);", ion.name, ion.name);

for (size_t i = 0; i < codegen_int_variables_size; ++i) {
const auto& var = codegen_int_variables[i];

// if(var.symbol->has_any_property(NmodlType::useion)) {
const std::string& var_name = var.symbol->get_name();
if (var_name.rfind("ion_", 0) != 0) {
continue;
}

std::string ion_var_name = std::string(var_name.begin() + 4, var_name.end());
if (ion.is_ionic_variable(ion_var_name)) {
printer->fmt_line("_ppvar[{}] = _nrn_mechanism_get_param_handle({}_prop, {});",
i,
ion.name,
ion.variable_index(ion_var_name));
}
// }
}
}

printer->pop_block();
}

Expand All @@ -750,6 +871,19 @@ void CodegenNeuronCppVisitor::print_nrn_state() {

printer->push_block("for (int id = 0; id < nodecount; id++)");

/**
* \todo Eigen solver node also emits IonCurVar variable in the functor
* but that shouldn't update ions in derivative block
*/
if (ion_variable_struct_required()) {
throw std::runtime_error("Not implemented.");
}

auto read_statements = ion_read_statements(BlockType::State);
for (auto& statement: read_statements) {
printer->add_line(statement);
}

if (info.nrn_state_block) {
info.nrn_state_block->visit_children(*this);
}
Expand All @@ -759,6 +893,12 @@ void CodegenNeuronCppVisitor::print_nrn_state() {
print_statement_block(*block, false, false);
}

const auto& write_statements = ion_write_statements(BlockType::State);
for (auto& statement: write_statements) {
const auto& text = process_shadow_update_statement(statement, BlockType::State);
printer->add_line(text);
}

printer->pop_block();
printer->pop_block();
}
Expand Down
15 changes: 11 additions & 4 deletions test/unit/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
double* Ds{};
double* v_unused{};
double* g_unused{};
const double* ion_ena{};
double* ion_ina{};
double* ion_dinadv{};
const double* const* ion_ena{};
double* const* ion_ina{};
double* const* ion_dinadv{};
pas_test_Store* global{&pas_test_global};
};)";

Expand Down Expand Up @@ -259,9 +259,16 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
_nrn_mechanism_field<double>{"ina"} /* 6 */,
_nrn_mechanism_field<double>{"Ds"} /* 7 */,
_nrn_mechanism_field<double>{"v_unused"} /* 8 */,
_nrn_mechanism_field<double>{"g_unused"} /* 9 */
_nrn_mechanism_field<double>{"g_unused"} /* 9 */,
_nrn_mechanism_field<double*>{"ion_ena", "na_ion"} /* 0 */,
_nrn_mechanism_field<double*>{"ion_ina", "na_ion"} /* 1 */,
_nrn_mechanism_field<double*>{"ion_dinadv", "na_ion"} /* 2 */
);
hoc_register_prop_size(mech_type, 10, 3);
hoc_register_dparam_semantics(mech_type, 0, "na_ion");
hoc_register_dparam_semantics(mech_type, 1, "na_ion");
hoc_register_dparam_semantics(mech_type, 2, "na_ion");
})CODE";

REQUIRE_THAT(generated,
Expand Down
Loading

0 comments on commit e5c328f

Please sign in to comment.