diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp index e661f5685b..2f8dcb838c 100644 --- a/src/codegen/codegen_cpp_visitor.cpp +++ b/src/codegen/codegen_cpp_visitor.cpp @@ -3330,7 +3330,13 @@ void CodegenCppVisitor::print_initial_block(const InitialBlock* node) { if (!info.is_ionic_conc(name)) { auto lhs = get_variable_name(name); auto rhs = get_variable_name(name + "0"); - printer->fmt_line("{} = {};", lhs, rhs); + if (var->is_array()) { + for (int i = 0; i < var->get_length(); ++i) { + printer->fmt_line("{}[{}] = {};", lhs, i, rhs); + } + } else { + printer->fmt_line("{} = {};", lhs, rhs); + } } } diff --git a/test/unit/codegen/codegen_cpp_visitor.cpp b/test/unit/codegen/codegen_cpp_visitor.cpp index 967c5d3267..be30ce24d6 100644 --- a/test/unit/codegen/codegen_cpp_visitor.cpp +++ b/test/unit/codegen/codegen_cpp_visitor.cpp @@ -1181,3 +1181,57 @@ SCENARIO("Check codegen for MUTEX and PROTECT", "[codegen][mutex_protect]") { } } } + + +SCENARIO("Array STATE variable", "[codegen][array_state]") { + GIVEN("A mod file containing an array STATE variable") { + std::string const nmodl_text = R"( + DEFINE NANN 4 + + NEURON { + SUFFIX ca_test + } + STATE { + ca[NANN] + k + } + )"; + + THEN("nrn_init is printed with proper initialization of the whole array") { + auto const generated = get_cpp_code(nmodl_text); + std::string expected_code_init = + R"(/** initialize channel */ + void nrn_init_ca_test(NrnThread* nt, Memb_list* ml, int type) { + int nodecount = ml->nodecount; + int pnodecount = ml->_nodecount_padded; + const int* node_index = ml->nodeindices; + double* data = ml->data; + const double* voltage = nt->_actual_v; + Datum* indexes = ml->pdata; + ThreadDatum* thread = ml->_thread; + + setup_instance(nt, ml); + auto* const inst = static_cast(ml->instance); + + if (_nrn_skip_initmodel == 0) { + #pragma omp simd + #pragma ivdep + for (int id = 0; id < nodecount; id++) { + int node_id = node_index[id]; + double v = voltage[node_id]; + #if NRN_PRCELLSTATE + inst->v_unused[id] = v; + #endif + (inst->ca+id*4)[0] = inst->global->ca0; + (inst->ca+id*4)[1] = inst->global->ca0; + (inst->ca+id*4)[2] = inst->global->ca0; + (inst->ca+id*4)[3] = inst->global->ca0; + inst->k[id] = inst->global->k0; + } + } + })"; + + REQUIRE_THAT(generated, ContainsSubstring(expected_code_init)); + } + } +}