Simple ODE for state.
1uc committed Dec 8, 2023
1 parent 8c8b25e commit fda6976
Showing 2 changed files with 180 additions and 38 deletions.
205 changes: 168 additions & 37 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
@@ -199,7 +199,22 @@ void CodegenNeuronCppVisitor::print_namespace_stop() {
/// TODO: Edit for NEURON
std::string CodegenNeuronCppVisitor::float_variable_name(const SymbolType& symbol,
bool use_instance) const {
return symbol->get_name();
auto name = symbol->get_name();
auto dimension = symbol->get_length();
// auto position = position_of_float_var(name);
// clang-format off
if (symbol->is_array()) {
if (use_instance) {
return fmt::format("(inst.{}+id*{})", name, dimension);
}
throw std::runtime_error("Not implemented.");
// return fmt::format("(data + {}*pnodecount + id*{})", position, dimension);
}
if (use_instance) {
return fmt::format("inst.{}[id]", name);
}
throw std::runtime_error("Not implemented.");
// return fmt::format("data[{}*pnodecount + id]", position);
}
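// Illustrative sketch (names assumed, use_instance == true): for a scalar
// RANGE variable m and an array variable z of length 3, the accessors
// returned above read
//     inst.m[id]
//     (inst.z+id*3)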


@@ -221,7 +236,70 @@ std::string CodegenNeuronCppVisitor::global_variable_name(const SymbolType& symb
/// TODO: Edit for NEURON
std::string CodegenNeuronCppVisitor::get_variable_name(const std::string& name,
bool use_instance) const {
return name;
// const std::string& varname = update_if_ion_variable_name(name);
const std::string& varname = name;

// clang-format off
auto symbol_comparator = [&varname](const SymbolType& sym) {
return varname == sym->get_name();
};

auto index_comparator = [&varname](const IndexVariableInfo& var) {
return varname == var.symbol->get_name();
};
// clang-format on

// float variable
auto f = std::find_if(codegen_float_variables.begin(),
codegen_float_variables.end(),
symbol_comparator);
if (f != codegen_float_variables.end()) {
return float_variable_name(*f, use_instance);
}

// integer variable
auto i =
std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), index_comparator);
if (i != codegen_int_variables.end()) {
return int_variable_name(*i, varname, use_instance);
}

// global variable
auto g = std::find_if(codegen_global_variables.begin(),
codegen_global_variables.end(),
symbol_comparator);
if (g != codegen_global_variables.end()) {
return global_variable_name(*g, use_instance);
}

if (varname == naming::NTHREAD_DT_VARIABLE) {
return std::string("_nt->_") + naming::NTHREAD_DT_VARIABLE;
}

// t in the net_receive method is a function argument and hence it should
// be used instead of nt->_t, which is the current time of the thread
if (varname == naming::NTHREAD_T_VARIABLE && !printing_net_receive) {
return std::string("_nt->_") + naming::NTHREAD_T_VARIABLE;
}

auto const iter =
std::find_if(info.neuron_global_variables.begin(),
info.neuron_global_variables.end(),
[&varname](auto const& entry) { return entry.first->get_name() == varname; });
if (iter != info.neuron_global_variables.end()) {
std::string ret;
if (use_instance) {
ret = "*(inst->";
}
ret.append(varname);
if (use_instance) {
ret.append(")");
}
return ret;
}

// otherwise return original name
return varname;
}
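// Illustrative sketch of the lookup order above (names assumed): a float
// RANGE variable "m" resolves to "inst.m[id]", "dt" to "_nt->_dt", "t"
// (outside net_receive) to "_nt->_t", and unknown names fall through
// unchanged.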


@@ -384,6 +462,23 @@ void CodegenNeuronCppVisitor::print_global_variables_for_hoc() {
printer->add_line("};");
}

void CodegenNeuronCppVisitor::print_make_instance() const {
printer->add_newline(2);
printer->fmt_push_block("static {} make_{}_instance(_nrn_mechanism_cache_range& _ml)",
instance_struct(),
info.mod_suffix);
printer->fmt_push_block("return {}", instance_struct());

const auto codegen_float_variables_size = codegen_float_variables.size();
for (int i = 0; i < codegen_float_variables_size; ++i) {
const auto& float_var = codegen_float_variables[i];

GitHub Actions warning (flag_warnings ON, line 474 of src/codegen/codegen_neuron_cpp_visitor.cpp): unused variable ‘float_var’ [-Wunused-variable]
printer->fmt_line("&_ml.template fpfield<{}>(0){}",
i,
i < codegen_float_variables_size - 1 ? "," : "");
}
printer->pop_block(";");
printer->pop_block();
}
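// Illustrative sketch of the emitted helper, assuming SUFFIX leonhard and two
// float variables at field indices 0 and 1 (the actual list comes from
// codegen_float_variables):
//
//     static leonhard_Instance make_leonhard_instance(_nrn_mechanism_cache_range& _ml) {
//         return leonhard_Instance {
//             &_ml.template fpfield<0>(0),
//             &_ml.template fpfield<1>(0)
//         };
//     }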

void CodegenNeuronCppVisitor::print_mechanism_register() {
/// TODO: Write this according to NEURON
@@ -443,49 +538,78 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
}


void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(
[[maybe_unused]] bool print_initializers) {
void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) {
auto const value_initialize = print_initializers ? "{}" : "";
auto int_type = default_int_data_type();
printer->add_newline(2);
printer->add_line("/* NEURON RANGE variables macro definitions */");
for (auto i = 0; i < codegen_float_variables.size(); ++i) {
const auto float_var = codegen_float_variables[i];
if (float_var->is_array()) {
printer->add_line("#define ",
float_var->get_name(),
"(id) _ml->template data_array<",
std::to_string(i),
", ",
std::to_string(float_var->get_length()),
">(id)");
printer->add_line("/** all mechanism instance variables and global variables */");
printer->fmt_push_block("struct {} ", instance_struct());

for (auto const& [var, type]: info.neuron_global_variables) {
auto const name = var->get_name();
printer->fmt_line("{}* {}{};",
type,
name,
print_initializers ? fmt::format("{{&coreneuron::{}}}", name)
: std::string{});
}
for (auto& var: codegen_float_variables) {
const auto& name = var->get_name();
// auto type = get_range_var_float_type(var);
// auto qualifier = is_constant_variable(name) ? "const " : "";
printer->fmt_line("double* {}{};", name, value_initialize);
}
for (auto& var: codegen_int_variables) {
const auto& name = var.symbol->get_name();
if (var.is_index || var.is_integer) {
auto qualifier = var.is_constant ? "const " : "";
printer->fmt_line("{}{}* {}{};", qualifier, int_type, name, value_initialize);
} else {
printer->add_line("#define ",
float_var->get_name(),
"(id) _ml->template fpfield<",
std::to_string(i),
">(id)");
auto qualifier = var.is_constant ? "const " : "";
auto type = var.is_vdata ? "void*" : default_float_data_type();
printer->fmt_line("{}{}* {}{};", qualifier, type, name, value_initialize);
}
}

// printer->fmt_line("{}* {}{};",
// global_struct(),
// naming::INST_GLOBAL_MEMBER,
// print_initializers ? fmt::format("{{&{}}}", global_struct_instance())
// : std::string{});
printer->pop_block(";");
}
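// Illustrative sketch of the struct printed above, assuming SUFFIX leonhard,
// float variables m and Dm, and a NEURON global celsius, with
// print_initializers == true (all names assumed):
//
//     struct leonhard_Instance  {
//         double* celsius{&coreneuron::celsius};
//         double* m{};
//         double* Dm{};
//     };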


/// TODO: Edit for NEURON
void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type,
const std::string& function_name) {
return;
std::string method = function_name.empty() ? compute_method_name(type) : function_name;
std::string args =
"_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int "
"_type";
printer->fmt_push_block("void {}({})", method, args);

printer->add_line("_nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};");
printer->add_line("auto inst = make_leonhard_instance(_lmr);");
printer->add_line("auto nodecount = _ml_arg->nodecount;");
}
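// Illustrative sketch of the emitted prologue for BlockType::State (the exact
// method name comes from compute_method_name; the opening brace is closed
// later by the caller):
//
//     void nrn_state_leonhard(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {
//         _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};
//         auto inst = make_leonhard_instance(_lmr);
//         auto nodecount = _ml_arg->nodecount;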


void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
codegen = true;
printer->add_newline(2);
printer->add_line("/** initialize channel */");
print_global_function_common_code(BlockType::Initial);

printer->fmt_line(
"static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* "
"_ml_arg, int _type) {{}}",
method_name(naming::NRN_INIT_METHOD));
printer->push_block("for (int id = 0; id < nodecount; id++)");
print_initial_block(info.initial_node);
printer->pop_block();

codegen = false;
printer->pop_block();
}

void CodegenNeuronCppVisitor::print_initial_block(const InitialBlock* node) {
// initial block
if (node != nullptr) {
const auto& block = node->get_statement_block();
print_statement_block(*block, false, false);
}
}
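// Illustrative sketch of the emitted initialization, assuming an INITIAL
// block that sets a state variable m (method and variable names assumed):
//
//     /** initialize channel */
//     void nrn_init_leonhard(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {
//         _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};
//         auto inst = make_leonhard_instance(_lmr);
//         auto nodecount = _ml_arg->nodecount;
//         for (int id = 0; id < nodecount; id++) {
//             inst.m[id] = 0.0;
//         }
//     }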


@@ -532,17 +656,23 @@ void CodegenNeuronCppVisitor::print_nrn_state() {
if (!nrn_state_required()) {
return;
}
codegen = true;
printer->add_newline(2);

printer->fmt_line(
"void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* "
"_ml_arg, int _type) {{}}",
method_name(naming::NRN_STATE_METHOD));
print_global_function_common_code(BlockType::State);

/// TODO: Fill in
printer->push_block("for (int id = 0; id < nodecount; id++)");

codegen = false;
if (info.nrn_state_block) {
info.nrn_state_block->visit_children(*this);
}

if (info.currents.empty() && info.breakpoint_node != nullptr) {
auto block = info.breakpoint_node->get_statement_block();
print_statement_block(*block, false, false);
}

printer->pop_block();
printer->pop_block();
}
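// Illustrative sketch of the emitted state update for a simple ODE (the loop
// body is whatever the visited NRN_STATE / BREAKPOINT statements print; the
// update shown is assumed):
//
//     void nrn_state_leonhard(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {
//         _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};
//         auto inst = make_leonhard_instance(_lmr);
//         auto nodecount = _ml_arg->nodecount;
//         for (int id = 0; id < nodecount; id++) {
//             inst.m[id] = inst.m[id] + _nt->_dt * /* ... */ 0.0;  // per-instance update
//         }
//     }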


@@ -673,6 +803,7 @@ void CodegenNeuronCppVisitor::print_namespace_end() {
void CodegenNeuronCppVisitor::print_data_structures(bool print_initializers) {
print_mechanism_global_var_structure(print_initializers);
print_mechanism_range_var_structure(print_initializers);
print_make_instance();
}


13 changes: 12 additions & 1 deletion src/codegen/codegen_neuron_cpp_visitor.hpp
@@ -24,13 +24,13 @@
#include <string_view>
#include <utility>

#include "codegen/codegen_cpp_visitor.hpp"
#include "codegen/codegen_info.hpp"
#include "codegen/codegen_naming.hpp"
#include "printer/code_printer.hpp"
#include "symtab/symbol_table.hpp"
#include "utils/logger.hpp"
#include "visitors/ast_visitor.hpp"
#include <codegen/codegen_cpp_visitor.hpp>


/// encapsulates code generation backend implementations
@@ -83,6 +83,12 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
*/
virtual std::string backend_name() const override;

/**
* Name of structure that wraps range variables
*/
std::string instance_struct() const {
return fmt::format("{}_Instance", info.mod_suffix);
}
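// e.g. for a MOD file with SUFFIX leonhard this yields "leonhard_Instance"
// (illustrative example).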

/****************************************************************************************/
/* Common helper routines across codegen functions */
@@ -381,6 +387,8 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
*/
void print_nrn_init(bool skip_init_check = true);

/** Print the initial block. */
void print_initial_block(const ast::InitialBlock* node);

/**
* Print nrn_constructor function definition
Expand Down Expand Up @@ -527,6 +535,9 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
*/
void print_data_structures(bool print_initializers) override;

/** Print `make_*_instance`. */
void print_make_instance() const;

/**
* Set v_unused (voltage) for NRN_PRCELLSTATE feature
