diff --git a/src/codegen/codegen_acc_visitor.cpp b/src/codegen/codegen_acc_visitor.cpp index 2d327ddbd4..ada0f4f2e1 100644 --- a/src/codegen/codegen_acc_visitor.cpp +++ b/src/codegen/codegen_acc_visitor.cpp @@ -89,17 +89,19 @@ void CodegenAccVisitor::print_memory_allocation_routine() const { } printer->add_newline(2); auto args = "size_t num, size_t size, size_t alignment = 16"; - printer->fmt_start_block("static inline void* mem_alloc({})", args); - printer->add_line("void* ptr;"); - printer->add_line("cudaMallocManaged(&ptr, num*size);"); - printer->add_line("cudaMemset(ptr, 0, num*size);"); - printer->add_line("return ptr;"); - printer->end_block(1); + printer->fmt_push_block("static inline void* mem_alloc({})", args); + printer->add_multi_line(R"CODE( + void* ptr; + cudaMallocManaged(&ptr, num*size); + cudaMemset(ptr, 0, num*size); + return ptr; + )CODE"); + printer->pop_block(1); printer->add_newline(2); - printer->start_block("static inline void mem_free(void* ptr)"); + printer->push_block("static inline void mem_free(void* ptr)"); printer->add_line("cudaFree(ptr);"); - printer->end_block(1); + printer->pop_block(1); } /** @@ -114,19 +116,19 @@ void CodegenAccVisitor::print_memory_allocation_routine() const { */ void CodegenAccVisitor::print_abort_routine() const { printer->add_newline(2); - printer->start_block("static inline void coreneuron_abort()"); - printer->add_line("printf(\"Error : Issue while running OpenACC kernel \\n\");"); + printer->push_block("static inline void coreneuron_abort()"); + printer->add_line(R"(printf("Error : Issue while running OpenACC kernel \n");)"); printer->add_line("assert(0==1);"); - printer->end_block(1); + printer->pop_block(1); } void CodegenAccVisitor::print_net_send_buffering_cnt_update() const { - printer->fmt_start_block("if (nt->compute_gpu)"); + printer->fmt_push_block("if (nt->compute_gpu)"); print_device_atomic_capture_annotation(); printer->add_line("i = nsb->_cnt++;"); - printer->restart_block("else"); + printer->chain_block("else"); printer->add_line("i = nsb->_cnt++;"); - printer->end_block(1); + printer->pop_block(1); } void CodegenAccVisitor::print_net_send_buffering_grow() { @@ -174,7 +176,7 @@ void CodegenAccVisitor::print_net_init_acc_serial_annotation_block_begin() { void CodegenAccVisitor::print_net_init_acc_serial_annotation_block_end() { if (!info.artificial_cell) { - printer->end_block(1); + printer->pop_block(1); } } @@ -198,7 +200,7 @@ void CodegenAccVisitor::print_fast_imem_calculation() { auto rhs_op = operator_for_rhs(); auto d_op = operator_for_d(); - printer->start_block("if (nt->nrn_fast_imem)"); + printer->push_block("if (nt->nrn_fast_imem)"); if (info.point_process) { print_atomic_reduction_pragma(); } @@ -207,7 +209,7 @@ void CodegenAccVisitor::print_fast_imem_calculation() { print_atomic_reduction_pragma(); } printer->fmt_line("nt->nrn_fast_imem->nrn_sav_d[node_id] {} g;", d_op); - printer->end_block(1); + printer->pop_block(1); } void CodegenAccVisitor::print_nrn_cur_matrix_shadow_reduction() { @@ -220,7 +222,7 @@ void CodegenAccVisitor::print_nrn_cur_matrix_shadow_reduction() { */ void CodegenAccVisitor::print_kernel_data_present_annotation_block_end() { if (!info.artificial_cell) { - printer->end_block(1); + printer->pop_block(1); } } @@ -237,25 +239,27 @@ bool CodegenAccVisitor::nrn_cur_reduction_loop_required() { void CodegenAccVisitor::print_global_variable_device_update_annotation() { if (!info.artificial_cell) { - printer->start_block("if (nt->compute_gpu)"); + printer->push_block("if (nt->compute_gpu)"); printer->fmt_line("nrn_pragma_acc(update device ({}))", global_struct_instance()); printer->fmt_line("nrn_pragma_omp(target update to({}))", global_struct_instance()); - printer->end_block(1); + printer->pop_block(1); } } void CodegenAccVisitor::print_newtonspace_transfer_to_device() const { int list_num = info.derivimplicit_list_num; - printer->start_block("if(nt->compute_gpu)"); - printer->add_line("double* device_vec = cnrn_target_copyin(vec, vec_size / sizeof(double));"); - printer->add_line("void* device_ns = cnrn_target_deviceptr(*ns);"); - printer->add_line("ThreadDatum* device_thread = cnrn_target_deviceptr(thread);"); + printer->push_block("if(nt->compute_gpu)"); + printer->add_multi_line(R"CODE( + double* device_vec = cnrn_target_copyin(vec, vec_size / sizeof(double)); + void* device_ns = cnrn_target_deviceptr(*ns); + ThreadDatum* device_thread = cnrn_target_deviceptr(thread); + )CODE"); printer->fmt_line("cnrn_target_memcpy_to_device(&(device_thread[{}]._pvoid), &device_ns);", info.thread_data_index - 1); printer->fmt_line("cnrn_target_memcpy_to_device(&(device_thread[dith{}()].pval), &device_vec);", list_num); - printer->end_block(1); + printer->pop_block(1); } @@ -276,32 +280,34 @@ void CodegenAccVisitor::print_instance_struct_transfer_routines( if (info.artificial_cell) { return; } - printer->fmt_start_block( + printer->fmt_push_block( "static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst)", instance_struct()); - printer->start_block("if (!nt->compute_gpu)"); + printer->push_block("if (!nt->compute_gpu)"); printer->add_line("return;"); - printer->end_block(1); + printer->pop_block(1); printer->fmt_line("auto tmp = *inst;"); printer->add_line("auto* d_inst = cnrn_target_is_present(inst);"); - printer->start_block("if (!d_inst)"); + printer->push_block("if (!d_inst)"); printer->add_line("d_inst = cnrn_target_copyin(inst);"); - printer->end_block(1); + printer->pop_block(1); for (auto const& ptr_mem: ptr_members) { printer->fmt_line("tmp.{0} = cnrn_target_deviceptr(tmp.{0});", ptr_mem); } - printer->add_line("cnrn_target_memcpy_to_device(d_inst, &tmp);"); - printer->add_line("auto* d_ml = cnrn_target_deviceptr(ml);"); - printer->add_line("void* d_inst_void = d_inst;"); - printer->add_line("cnrn_target_memcpy_to_device(&(d_ml->instance), &d_inst_void);"); - printer->end_block(2); // copy_instance_to_device - - printer->fmt_start_block("static inline void delete_instance_from_device({}* inst)", - instance_struct()); - printer->start_block("if (cnrn_target_is_present(inst))"); + printer->add_multi_line(R"CODE( + cnrn_target_memcpy_to_device(d_inst, &tmp); + auto* d_ml = cnrn_target_deviceptr(ml); + void* d_inst_void = d_inst; + cnrn_target_memcpy_to_device(&(d_ml->instance), &d_inst_void); + )CODE"); + printer->pop_block(2); // copy_instance_to_device + + printer->fmt_push_block("static inline void delete_instance_from_device({}* inst)", + instance_struct()); + printer->push_block("if (cnrn_target_is_present(inst))"); printer->add_line("cnrn_target_delete(inst);"); - printer->end_block(1); - printer->end_block(2); // delete_instance_from_device + printer->pop_block(1); + printer->pop_block(2); // delete_instance_from_device } @@ -334,9 +340,9 @@ void CodegenAccVisitor::print_device_atomic_capture_annotation() const { void CodegenAccVisitor::print_device_stream_wait() const { - printer->start_block("if(nt->compute_gpu)"); + printer->push_block("if(nt->compute_gpu)"); printer->add_line("nrn_pragma_acc(wait(nt->stream_id))"); - printer->end_block(1); + printer->pop_block(1); } @@ -348,18 +354,18 @@ void CodegenAccVisitor::print_net_send_buf_count_update_to_host() const { void CodegenAccVisitor::print_net_send_buf_update_to_host() const { print_device_stream_wait(); - printer->start_block("if (nsb && nt->compute_gpu)"); + printer->push_block("if (nsb && nt->compute_gpu)"); print_net_send_buf_count_update_to_host(); printer->add_line("update_net_send_buffer_on_host(nt, nsb);"); - printer->end_block(1); + printer->pop_block(1); } void CodegenAccVisitor::print_net_send_buf_count_update_to_device() const { - printer->start_block("if (nt->compute_gpu)"); + printer->push_block("if (nt->compute_gpu)"); printer->add_line("nrn_pragma_acc(update device(nsb->_cnt))"); printer->add_line("nrn_pragma_omp(target update to(nsb->_cnt))"); - printer->end_block(1); + printer->pop_block(1); } diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp index 88fd8985dc..da018ecae0 100644 --- a/src/codegen/codegen_cpp_visitor.cpp +++ b/src/codegen/codegen_cpp_visitor.cpp @@ -156,8 +156,7 @@ void CodegenCppVisitor::visit_local_list_statement(const LocalListStatement& nod if (!codegen) { return; } - auto type = local_var_type() + " "; - printer->add_text(type); + printer->add_text(local_var_type(), ' '); print_vector_elements(node.get_variables(), ", "); } @@ -329,11 +328,11 @@ void CodegenCppVisitor::visit_protect_statement(const ast::ProtectStatement& nod void CodegenCppVisitor::visit_mutex_lock(const ast::MutexLock& node) { printer->fmt_line("#pragma omp critical ({})", info.mod_suffix); printer->add_indent(); - printer->start_block(); + printer->push_block(); } void CodegenCppVisitor::visit_mutex_unlock(const ast::MutexUnlock& node) { - printer->end_block(1); + printer->pop_block(1); } /****************************************************************************************/ @@ -471,23 +470,23 @@ std::string CodegenCppVisitor::format_float_string(const std::string& s_value) { * block can appear as statement using expression statement which need to * be inspected. */ -bool CodegenCppVisitor::need_semicolon(Statement* node) { +bool CodegenCppVisitor::need_semicolon(const Statement& node) { // clang-format off - if (node->is_if_statement() - || node->is_else_if_statement() - || node->is_else_statement() - || node->is_from_statement() - || node->is_verbatim() - || node->is_from_statement() - || node->is_conductance_hint() - || node->is_while_statement() - || node->is_protect_statement() - || node->is_mutex_lock() - || node->is_mutex_unlock()) { + if (node.is_if_statement() + || node.is_else_if_statement() + || node.is_else_statement() + || node.is_from_statement() + || node.is_verbatim() + || node.is_from_statement() + || node.is_conductance_hint() + || node.is_while_statement() + || node.is_protect_statement() + || node.is_mutex_lock() + || node.is_mutex_unlock()) { return false; } - if (node->is_expression_statement()) { - auto expression = dynamic_cast(node)->get_expression(); + if (node.is_expression_statement()) { + auto expression = dynamic_cast(node).get_expression(); if (expression->is_statement_block() || expression->is_eigen_newton_solver_block() || expression->is_eigen_linear_solver_block() @@ -554,7 +553,7 @@ int CodegenCppVisitor::int_variables_size() const { * different variable names, we rely on backend-specific read_ion_variable_name * and write_ion_variable_name method which will be overloaded. */ -std::vector CodegenCppVisitor::ion_read_statements(BlockType type) { +std::vector CodegenCppVisitor::ion_read_statements(BlockType type) const { if (optimize_ion_variable_copies()) { return ion_read_statements_optimized(type); } @@ -584,14 +583,14 @@ std::vector CodegenCppVisitor::ion_read_statements(BlockType type) } -std::vector CodegenCppVisitor::ion_read_statements_optimized(BlockType type) { +std::vector CodegenCppVisitor::ion_read_statements_optimized(BlockType type) const { std::vector statements; for (const auto& ion: info.ions) { for (const auto& var: ion.writes) { if (ion.is_ionic_conc(var)) { auto variables = read_ion_variable_name(var); auto first = "ionvar." + variables.first; - auto second = get_variable_name(variables.second); + const auto& second = get_variable_name(variables.second); statements.push_back(fmt::format("{} = {};", first, second)); } } @@ -789,7 +788,7 @@ void CodegenCppVisitor::update_index_semantics() { } -std::vector CodegenCppVisitor::get_float_variables() { +std::vector CodegenCppVisitor::get_float_variables() const { // sort with definition order auto comparator = [](const SymbolType& first, const SymbolType& second) -> bool { return first->get_definition_order() < second->get_definition_order(); @@ -979,18 +978,21 @@ std::vector CodegenCppVisitor::get_int_variables() { /****************************************************************************************/ std::string CodegenCppVisitor::get_parameter_str(const ParamVector& params) { - std::string param{}; - for (auto iter = params.begin(); iter != params.end(); iter++) { - param += fmt::format("{}{} {}{}", - std::get<0>(*iter), - std::get<1>(*iter), - std::get<2>(*iter), - std::get<3>(*iter)); - if (!nmodl::utils::is_last(iter, params)) { - param += ", "; + std::string str; + bool is_first = true; + for (const auto& param: params) { + if (is_first) { + is_first = false; + } else { + str += ", "; } + str += fmt::format("{}{} {}{}", + std::get<0>(param), + std::get<1>(param), + std::get<2>(param), + std::get<3>(param)); } - return param; + return str; } @@ -1172,25 +1174,25 @@ bool CodegenCppVisitor::optimize_ion_variable_copies() const { void CodegenCppVisitor::print_memory_allocation_routine() const { printer->add_newline(2); auto args = "size_t num, size_t size, size_t alignment = 16"; - printer->fmt_start_block("static inline void* mem_alloc({})", args); + printer->fmt_push_block("static inline void* mem_alloc({})", args); printer->add_line("void* ptr;"); printer->add_line("posix_memalign(&ptr, alignment, num*size);"); printer->add_line("memset(ptr, 0, size);"); printer->add_line("return ptr;"); - printer->end_block(1); + printer->pop_block(1); printer->add_newline(2); - printer->start_block("static inline void mem_free(void* ptr)"); + printer->push_block("static inline void mem_free(void* ptr)"); printer->add_line("free(ptr);"); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_abort_routine() const { printer->add_newline(2); - printer->start_block("static inline void coreneuron_abort()"); + printer->push_block("static inline void coreneuron_abort()"); printer->add_line("abort();"); - printer->end_block(1); + printer->pop_block(1); } @@ -1222,7 +1224,7 @@ std::string CodegenCppVisitor::global_var_struct_type_qualifier() { } void CodegenCppVisitor::print_global_var_struct_decl() { - printer->fmt_line("{} {};", global_struct(), global_struct_instance()); + printer->add_line(global_struct(), ' ', global_struct_instance(), ';'); } /****************************************************************************************/ @@ -1240,10 +1242,10 @@ void CodegenCppVisitor::print_statement_block(const ast::StatementBlock& node, bool open_brace, bool close_brace) { if (open_brace) { - printer->start_block(); + printer->push_block(); } - auto statements = node.get_statements(); + const auto& statements = node.get_statements(); for (const auto& statement: statements) { if (statement_to_skip(*statement)) { continue; @@ -1254,8 +1256,8 @@ void CodegenCppVisitor::print_statement_block(const ast::StatementBlock& node, printer->add_indent(); } statement->accept(*this); - if (need_semicolon(statement.get())) { - printer->add_text(";"); + if (need_semicolon(*statement)) { + printer->add_text(';'); } if (!statement->is_mutex_lock() && !statement->is_mutex_unlock()) { printer->add_newline(); @@ -1263,13 +1265,13 @@ void CodegenCppVisitor::print_statement_block(const ast::StatementBlock& node, } if (close_brace) { - printer->end_block(); + printer->pop_block(); } } void CodegenCppVisitor::print_function_call(const FunctionCall& node) { - auto name = node.get_node_name(); + const auto& name = node.get_node_name(); auto function_name = name; if (defined_method(name)) { function_name = method_name(name); @@ -1290,8 +1292,8 @@ void CodegenCppVisitor::print_function_call(const FunctionCall& node) { return; } - auto arguments = node.get_arguments(); - printer->fmt_text("{}(", function_name); + const auto& arguments = node.get_arguments(); + printer->add_text(function_name, '('); if (defined_method(name)) { printer->add_text(internal_method_arguments()); @@ -1301,7 +1303,7 @@ void CodegenCppVisitor::print_function_call(const FunctionCall& node) { } print_vector_elements(arguments, ", "); - printer->add_text(")"); + printer->add_text(')'); } @@ -1363,12 +1365,12 @@ void CodegenCppVisitor::print_function_prototypes() { printer->add_newline(2); for (const auto& node: info.functions) { print_function_declaration(*node, node->get_node_name()); - printer->add_text(";"); + printer->add_text(';'); printer->add_newline(); } for (const auto& node: info.procedures) { print_function_declaration(*node, node->get_node_name()); - printer->add_text(";"); + printer->add_text(';'); printer->add_newline(); } codegen = false; @@ -1420,13 +1422,13 @@ void CodegenCppVisitor::print_table_check_function(const Block& node) { printer->add_newline(2); print_device_method_annotation(); - printer->fmt_start_block("void check_{}({})", - method_name(name), - get_parameter_str(internal_params)); + printer->fmt_push_block("void check_{}({})", + method_name(name), + get_parameter_str(internal_params)); { - printer->fmt_start_block("if ({} == 0)", use_table_var); + printer->fmt_push_block("if ({} == 0)", use_table_var); printer->add_line("return;"); - printer->end_block(1); + printer->pop_block(1); printer->add_line("static bool make_table = true;"); for (const auto& variable: depend_variables) { @@ -1434,27 +1436,27 @@ void CodegenCppVisitor::print_table_check_function(const Block& node) { } for (const auto& variable: depend_variables) { - auto var_name = variable->get_node_name(); - auto instance_name = get_variable_name(var_name); - printer->fmt_start_block("if (save_{} != {})", var_name, instance_name); + const auto& var_name = variable->get_node_name(); + const auto& instance_name = get_variable_name(var_name); + printer->fmt_push_block("if (save_{} != {})", var_name, instance_name); printer->add_line("make_table = true;"); - printer->end_block(1); + printer->pop_block(1); } - printer->start_block("if (make_table)"); + printer->push_block("if (make_table)"); { printer->add_line("make_table = false;"); printer->add_indent(); - printer->fmt_text("{} = ", tmin_name); + printer->add_text(tmin_name, " = "); from->accept(*this); - printer->add_text(";"); + printer->add_text(';'); printer->add_newline(); printer->add_indent(); printer->add_text("double tmax = "); to->accept(*this); - printer->add_text(";"); + printer->add_text(';'); printer->add_newline(); @@ -1462,7 +1464,7 @@ void CodegenCppVisitor::print_table_check_function(const Block& node) { printer->fmt_line("{} = 1./dx;", mfac_name); printer->fmt_line("double x = {};", tmin_name); - printer->fmt_start_block("for (std::size_t i = 0; i < {}; x += dx, i++)", with + 1); + printer->fmt_push_block("for (std::size_t i = 0; i < {}; x += dx, i++)", with + 1); auto function = method_name("f_" + name); if (node.is_procedure_block()) { printer->fmt_line("{}({}, x);", function, internal_method_arguments()); @@ -1487,7 +1489,7 @@ void CodegenCppVisitor::print_table_check_function(const Block& node) { function, internal_method_arguments()); } - printer->end_block(1); + printer->pop_block(1); for (const auto& variable: depend_variables) { auto var_name = variable->get_node_name(); @@ -1495,9 +1497,9 @@ void CodegenCppVisitor::print_table_check_function(const Block& node) { printer->fmt_line("save_{} = {};", var_name, instance_name); } } - printer->end_block(1); + printer->pop_block(1); } - printer->end_block(1); + printer->pop_block(1); } @@ -1513,10 +1515,10 @@ void CodegenCppVisitor::print_table_replacement_function(const ast::Block& node) printer->add_newline(2); print_function_declaration(node, name); - printer->start_block(); + printer->push_block(); { const auto& params = node.get_parameters(); - printer->fmt_start_block("if ({} == 0)", use_table_var); + printer->fmt_push_block("if ({} == 0)", use_table_var); if (node.is_procedure_block()) { printer->fmt_line("{}({}, {});", function_name, @@ -1529,13 +1531,13 @@ void CodegenCppVisitor::print_table_replacement_function(const ast::Block& node) internal_method_arguments(), params[0].get()->get_node_name()); } - printer->end_block(1); + printer->pop_block(1); printer->fmt_line("double xi = {} * ({} - {});", mfac_name, params[0].get()->get_node_name(), tmin_name); - printer->start_block("if (isnan(xi))"); + printer->push_block("if (isnan(xi))"); if (node.is_procedure_block()) { for (const auto& var: table_variables) { auto var_name = get_variable_name(var->get_node_name()); @@ -1552,9 +1554,9 @@ void CodegenCppVisitor::print_table_replacement_function(const ast::Block& node) } else { printer->add_line("return xi;"); } - printer->end_block(1); + printer->pop_block(1); - printer->fmt_start_block("if (xi <= 0. || xi >= {}.)", with); + printer->fmt_push_block("if (xi <= 0. || xi >= {}.)", with); printer->fmt_line("int index = (xi <= 0.) ? 0 : {};", with); if (node.is_procedure_block()) { for (const auto& variable: table_variables) { @@ -1576,7 +1578,7 @@ void CodegenCppVisitor::print_table_replacement_function(const ast::Block& node) auto table_name = get_variable_name("t_" + name); printer->fmt_line("return {}[index];", table_name); } - printer->end_block(1); + printer->pop_block(1); printer->add_line("int i = int(xi);"); printer->add_line("double theta = xi - double(i);"); @@ -1606,7 +1608,7 @@ void CodegenCppVisitor::print_table_replacement_function(const ast::Block& node) printer->fmt_line("return {0}[i] + theta * ({0}[i+1] - {0}[i]);", table_name); } } - printer->end_block(1); + printer->pop_block(1); } @@ -1619,7 +1621,7 @@ void CodegenCppVisitor::print_check_table_thread_function() { auto name = method_name("check_table_thread"); auto parameters = external_method_parameters(true); - printer->fmt_start_block("static void {} ({})", name, parameters); + printer->fmt_push_block("static void {} ({})", name, parameters); printer->add_line("setup_instance(nt, ml);"); printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct()); printer->add_line("double v = 0;"); @@ -1630,7 +1632,7 @@ void CodegenCppVisitor::print_check_table_thread_function() { printer->fmt_line("{}({});", method_name_str, arguments); } - printer->end_block(1); + printer->pop_block(1); } @@ -1639,7 +1641,7 @@ void CodegenCppVisitor::print_function_or_procedure(const ast::Block& node, printer->add_newline(2); print_function_declaration(node, name); printer->add_text(" "); - printer->start_block(); + printer->push_block(); // function requires return variable declaration if (node.is_function_block()) { @@ -1651,7 +1653,7 @@ void CodegenCppVisitor::print_function_or_procedure(const ast::Block& node, print_statement_block(*node.get_statement_block(), false, false); printer->fmt_line("return ret_{};", name); - printer->end_block(1); + printer->pop_block(1); } @@ -1705,7 +1707,7 @@ void CodegenCppVisitor::print_function_tables(const ast::FunctionTableBlock& nod params.emplace_back("", "double", "", i->get_node_name()); } printer->fmt_line("double {}({})", method_name(name), get_parameter_str(params)); - printer->start_block(); + printer->push_block(); printer->fmt_line("double _arg[{}];", p.size()); for (size_t i = 0; i < p.size(); ++i) { printer->fmt_line("_arg[{}] = {};", i, p[i]->get_node_name()); @@ -1713,14 +1715,14 @@ void CodegenCppVisitor::print_function_tables(const ast::FunctionTableBlock& nod printer->fmt_line("return hoc_func_table({}, {}, _arg);", get_variable_name(std::string("_ptable_" + name), true), p.size()); - printer->end_block(1); + printer->pop_block(1); - printer->fmt_start_block("double table_{}()", method_name(name)); + printer->fmt_push_block("double table_{}()", method_name(name)); printer->fmt_line("hoc_spec_table(&{}, {});", get_variable_name(std::string("_ptable_" + name)), p.size()); printer->add_line("return 0.;"); - printer->end_block(1); + printer->pop_block(1); } /** @@ -1776,9 +1778,9 @@ void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlo int N = node.get_n_state_vars()->get_value(); const auto functor_name = info.functor_names[&node]; - printer->fmt_start_block("struct {0}", functor_name); + printer->fmt_push_block("struct {0}", functor_name); printer->add_line("NrnThread* nt;"); - printer->fmt_line("{0}* inst;", instance_struct()); + printer->add_line(instance_struct(), "* inst;"); printer->add_line("int id, pnodecount;"); printer->add_line("double v;"); printer->add_line("const Datum* indexes;"); @@ -1792,9 +1794,9 @@ void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlo print_statement_block(*node.get_variable_block(), false, false); printer->add_newline(); - printer->start_block("void initialize()"); + printer->push_block("void initialize()"); print_statement_block(*node.get_initialize_block(), false, false); - printer->end_block(2); + printer->pop_block(2); printer->fmt_line( "{0}(NrnThread* nt, {1}* inst, int id, int pnodecount, double v, const Datum* indexes, " @@ -1817,19 +1819,19 @@ void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlo float_type, N, is_functor_const(variable_block, functor_block) ? "const " : ""); - printer->start_block(); + printer->push_block(); printer->fmt_line("const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type); printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type); printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type); print_statement_block(functor_block, false, false); - printer->end_block(2); + printer->pop_block(2); // assign newton solver results in matrix X to state vars - printer->start_block("void finalize()"); + printer->push_block("void finalize()"); print_statement_block(*node.get_finalize_block(), false, false); - printer->end_block(1); + printer->pop_block(1); - printer->end_block(";"); + printer->pop_block(";"); } void CodegenCppVisitor::visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock& node) { @@ -1885,12 +1887,12 @@ void CodegenCppVisitor::visit_eigen_linear_solver_block(const ast::EigenLinearSo void CodegenCppVisitor::print_eigen_linear_solver(const std::string& float_type, int N) { if (N <= 4) { // Faster compared to LU, given the template specialization in Eigen. - printer->add_line("bool invertible;"); - printer->add_line("nmodl_eigen_jm.computeInverseWithCheck(nmodl_eigen_jm_inv,invertible);"); - printer->add_line("nmodl_eigen_xm = nmodl_eigen_jm_inv*nmodl_eigen_fm;"); - printer->add_line( - "if (!invertible) assert(false && \"Singular or ill-conditioned matrix " - "(Eigen::inverse)!\");"); + printer->add_multi_line(R"CODE( + bool invertible; + nmodl_eigen_jm.computeInverseWithCheck(nmodl_eigen_jm_inv,invertible); + nmodl_eigen_xm = nmodl_eigen_jm_inv*nmodl_eigen_fm; + if (!invertible) assert(false && "Singular or ill-conditioned matrix (Eigen::inverse)!"); + )CODE"); } else { // In Eigen the default storage order is ColMajor. // Crout's implementation requires matrices stored in RowMajor order (C-style arrays). @@ -1935,7 +1937,7 @@ std::string CodegenCppVisitor::internal_method_arguments() { * @todo: figure out how to correctly handle qualifiers */ CodegenCppVisitor::ParamVector CodegenCppVisitor::internal_method_parameters() { - auto params = ParamVector(); + ParamVector params; params.emplace_back("", "int", "", "id"); params.emplace_back("", "int", "", "pnodecount"); params.emplace_back("", fmt::format("{}*", instance_struct()), "", "inst"); @@ -1951,12 +1953,12 @@ CodegenCppVisitor::ParamVector CodegenCppVisitor::internal_method_parameters() { } -std::string CodegenCppVisitor::external_method_arguments() { +const char* CodegenCppVisitor::external_method_arguments() noexcept { return "id, pnodecount, data, indexes, thread, nt, ml, v"; } -std::string CodegenCppVisitor::external_method_parameters(bool table) { +const char* CodegenCppVisitor::external_method_parameters(bool table) noexcept { if (table) { return "int id, int pnodecount, double* data, Datum* indexes, " "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, int tml_id"; @@ -1966,7 +1968,7 @@ std::string CodegenCppVisitor::external_method_parameters(bool table) { } -std::string CodegenCppVisitor::nrn_thread_arguments() { +std::string CodegenCppVisitor::nrn_thread_arguments() const { if (ion_variable_struct_required()) { return "id, pnodecount, ionvar, data, indexes, thread, nt, ml, v"; } @@ -2142,24 +2144,24 @@ void CodegenCppVisitor::print_nmodl_constants() { void CodegenCppVisitor::print_first_pointer_var_index_getter() { printer->add_newline(2); print_device_method_annotation(); - printer->start_block("static inline int first_pointer_var_index()"); + printer->push_block("static inline int first_pointer_var_index()"); printer->fmt_line("return {};", info.first_pointer_var_index); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_num_variable_getter() { printer->add_newline(2); print_device_method_annotation(); - printer->start_block("static inline int float_variables_size()"); + printer->push_block("static inline int float_variables_size()"); printer->fmt_line("return {};", float_variables_size()); - printer->end_block(1); + printer->pop_block(1); printer->add_newline(2); print_device_method_annotation(); - printer->start_block("static inline int int_variables_size()"); + printer->push_block("static inline int int_variables_size()"); printer->fmt_line("return {};", int_variables_size()); - printer->end_block(1); + printer->pop_block(1); } @@ -2169,42 +2171,42 @@ void CodegenCppVisitor::print_net_receive_arg_size_getter() { } printer->add_newline(2); print_device_method_annotation(); - printer->start_block("static inline int num_net_receive_args()"); + printer->push_block("static inline int num_net_receive_args()"); printer->fmt_line("return {};", info.num_net_receive_parameters); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_mech_type_getter() { printer->add_newline(2); print_device_method_annotation(); - printer->start_block("static inline int get_mech_type()"); + printer->push_block("static inline int get_mech_type()"); // false => get it from the host-only global struct, not the instance structure printer->fmt_line("return {};", get_variable_name("mech_type", false)); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_memb_list_getter() { printer->add_newline(2); print_device_method_annotation(); - printer->start_block("static inline Memb_list* get_memb_list(NrnThread* nt)"); - printer->start_block("if (!nt->_ml_list)"); + printer->push_block("static inline Memb_list* get_memb_list(NrnThread* nt)"); + printer->push_block("if (!nt->_ml_list)"); printer->add_line("return nullptr;"); - printer->end_block(1); + printer->pop_block(1); printer->add_line("return nt->_ml_list[get_mech_type()];"); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_namespace_start() { printer->add_newline(2); - printer->start_block("namespace coreneuron"); + printer->push_block("namespace coreneuron"); } void CodegenCppVisitor::print_namespace_stop() { - printer->end_block(1); + printer->pop_block(1); } @@ -2230,33 +2232,33 @@ void CodegenCppVisitor::print_thread_getters() { printer->add_line("/** thread specific helper routines for derivimplicit */"); printer->add_newline(1); - printer->fmt_start_block("static inline int* deriv{}_advance(ThreadDatum* thread)", list); + printer->fmt_push_block("static inline int* deriv{}_advance(ThreadDatum* thread)", list); printer->fmt_line("return &(thread[{}].i);", tid); - printer->end_block(2); + printer->pop_block(2); - printer->fmt_start_block("static inline int dith{}()", list); + printer->fmt_push_block("static inline int dith{}()", list); printer->fmt_line("return {};", tid+1); - printer->end_block(2); + printer->pop_block(2); - printer->fmt_start_block("static inline void** newtonspace{}(ThreadDatum* thread)", list); + printer->fmt_push_block("static inline void** newtonspace{}(ThreadDatum* thread)", list); printer->fmt_line("return &(thread[{}]._pvoid);", tid+2); - printer->end_block(1); + printer->pop_block(1); } if (info.vectorize && !info.thread_variables.empty()) { printer->add_newline(2); printer->add_line("/** tid for thread variables */"); - printer->start_block("static inline int thread_var_tid()"); + printer->push_block("static inline int thread_var_tid()"); printer->fmt_line("return {};", info.thread_var_thread_id); - printer->end_block(1); + printer->pop_block(1); } if (info.vectorize && !info.top_local_variables.empty()) { printer->add_newline(2); printer->add_line("/** tid for top local tread variables */"); - printer->start_block("static inline int top_local_var_tid()"); + printer->push_block("static inline int top_local_var_tid()"); printer->fmt_line("return {};", info.top_local_thread_id); - printer->end_block(1); + printer->pop_block(1); } // clang-format on } @@ -2341,7 +2343,7 @@ std::string CodegenCppVisitor::update_if_ion_variable_name(const std::string& na std::string CodegenCppVisitor::get_variable_name(const std::string& name, bool use_instance) const { - std::string varname = update_if_ion_variable_name(name); + const std::string& varname = update_if_ion_variable_name(name); // clang-format off auto symbol_comparator = [&varname](const SymbolType& sym) { @@ -2419,39 +2421,43 @@ void CodegenCppVisitor::print_backend_info() { auto version = nmodl::Version::NMODL_VERSION + " [" + nmodl::Version::GIT_REVISION + "]"; printer->add_line("/*********************************************************"); - printer->fmt_line("Model Name : {}", info.mod_suffix); - printer->fmt_line("Filename : {}", info.mod_file + ".mod"); - printer->fmt_line("NMODL Version : {}", nmodl_version()); + printer->add_line("Model Name : ", info.mod_suffix); + printer->add_line("Filename : ", info.mod_file, ".mod"); + printer->add_line("NMODL Version : ", nmodl_version()); printer->fmt_line("Vectorized : {}", info.vectorize); printer->fmt_line("Threadsafe : {}", info.thread_safe); - printer->fmt_line("Created : {}", stringutils::trim(data_time_str)); - printer->fmt_line("Backend : {}", backend_name()); - printer->fmt_line("NMODL Compiler : {}", version); + printer->add_line("Created : ", stringutils::trim(data_time_str)); + printer->add_line("Backend : ", backend_name()); + printer->add_line("NMODL Compiler : ", version); printer->add_line("*********************************************************/"); } void CodegenCppVisitor::print_standard_includes() { printer->add_newline(); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); + printer->add_multi_line(R"CODE( + #include + #include + #include + #include + )CODE"); } void CodegenCppVisitor::print_coreneuron_includes() { printer->add_newline(); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); - printer->add_line("#include "); + printer->add_multi_line(R"CODE( + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + )CODE"); if (info.eigen_newton_solver_exist) { printer->add_line("#include "); } @@ -2487,14 +2493,14 @@ void CodegenCppVisitor::print_coreneuron_includes() { * same for some variables to keep same code as neuron. */ // NOLINTNEXTLINE(readability-function-cognitive-complexity) -void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initialisers) { - const auto value_initialise = print_initialisers ? "{}" : ""; +void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initializers) { + const auto value_initialise = print_initializers ? "{}" : ""; const auto qualifier = global_var_struct_type_qualifier(); auto float_type = default_float_data_type(); printer->add_newline(2); printer->add_line("/** all global variables */"); - printer->fmt_start_block("struct {}", global_struct()); + printer->fmt_push_block("struct {}", global_struct()); for (const auto& ion: info.ions) { auto name = fmt::format("{}_type", ion.name); @@ -2575,7 +2581,7 @@ void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initiali qualifier, float_type, name, - print_initialisers ? fmt::format("{{{:g}}}", value) : std::string{}); + print_initializers ? fmt::format("{{{:g}}}", value) : std::string{}); } codegen_global_variables.push_back(var); } @@ -2588,7 +2594,7 @@ void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initiali qualifier, float_type, name, - print_initialisers ? fmt::format("{{{:g}}}", value) : std::string{}); + print_initializers ? fmt::format("{{{:g}}}", value) : std::string{}); codegen_global_variables.push_back(var); } @@ -2601,7 +2607,7 @@ void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initiali info.prime_variables_by_order.size())}; } auto const initializer_list = [&](auto const& primes, const char* prefix) -> std::string { - if (!print_initialisers) { + if (!print_initializers) { return {}; } std::string list{"{"}; @@ -2637,7 +2643,7 @@ void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initiali } if (info.table_count > 0) { - printer->fmt_line("{}double usetable{};", qualifier, print_initialisers ? "{1}" : ""); + printer->fmt_line("{}double usetable{};", qualifier, print_initializers ? "{1}" : ""); codegen_global_variables.push_back(make_symbol(naming::USE_TABLE_VARIABLE)); for (const auto& block: info.functions_with_table) { @@ -2681,7 +2687,7 @@ void CodegenCppVisitor::print_mechanism_global_var_structure(bool print_initiali codegen_global_variables.push_back(make_symbol("ext_call_thread")); } - printer->end_block(";"); + printer->pop_block(";"); print_global_var_struct_assertions(); print_global_var_struct_decl(); @@ -2717,7 +2723,7 @@ void CodegenCppVisitor::print_mechanism_info() { if (v->is_array()) { name += fmt::format("[{}]", v->get_length()); } - printer->add_line(add_escape_quote(name) + ","); + printer->add_line(add_escape_quote(name), ","); } }; @@ -2725,8 +2731,8 @@ void CodegenCppVisitor::print_mechanism_info() { printer->add_line("/** channel information */"); printer->add_line("static const char *mechanism[] = {"); printer->increase_indent(); - printer->add_line(add_escape_quote(nmodl_version()) + ","); - printer->add_line(add_escape_quote(info.mod_suffix) + ","); + printer->add_line(add_escape_quote(nmodl_version()), ","); + printer->add_line(add_escape_quote(info.mod_suffix), ","); variable_printer(info.range_parameter_vars); printer->add_line("0,"); variable_printer(info.range_assigned_vars); @@ -2853,16 +2859,16 @@ static std::string get_register_type_for_ba_block(const ast::Block* block) { void CodegenCppVisitor::print_mechanism_register() { printer->add_newline(2); printer->add_line("/** register channel with the simulator */"); - printer->fmt_start_block("void _{}_reg()", info.mod_file); + printer->fmt_push_block("void _{}_reg()", info.mod_file); // type related information auto suffix = add_escape_quote(info.mod_suffix); printer->add_newline(); printer->fmt_line("int mech_type = nrn_get_mechtype({});", suffix); printer->fmt_line("{} = mech_type;", get_variable_name("mech_type", false)); - printer->start_block("if (mech_type == -1)"); + printer->push_block("if (mech_type == -1)"); printer->add_line("return;"); - printer->end_block(1); + printer->pop_block(1); printer->add_newline(); printer->add_line("_nrn_layout_reg(mech_type, 0);"); // 0 for SoA @@ -2988,7 +2994,7 @@ void CodegenCppVisitor::print_mechanism_register() { // register variables for hoc printer->add_line("hoc_register_var(hoc_scalar_double, hoc_vector_double, NULL);"); - printer->end_block(1); + printer->pop_block(1); } @@ -3000,7 +3006,7 @@ void CodegenCppVisitor::print_thread_memory_callbacks() { // thread_mem_init callback printer->add_newline(2); printer->add_line("/** thread memory allocation callback */"); - printer->start_block("static void thread_mem_init(ThreadDatum* thread) "); + printer->push_block("static void thread_mem_init(ThreadDatum* thread) "); if (info.vectorize && info.derivimplicit_used()) { printer->fmt_line("thread[dith{}()].pval = nullptr;", info.derivimplicit_list_num); @@ -3015,19 +3021,19 @@ void CodegenCppVisitor::print_thread_memory_callbacks() { auto thread_data = get_variable_name("thread_data"); auto thread_data_in_use = get_variable_name("thread_data_in_use"); auto allocation = fmt::format("(double*)mem_alloc({}, sizeof(double))", length); - printer->fmt_start_block("if ({})", thread_data_in_use); + printer->fmt_push_block("if ({})", thread_data_in_use); printer->fmt_line("thread[thread_var_tid()].pval = {};", allocation); - printer->restart_block("else"); + printer->chain_block("else"); printer->fmt_line("thread[thread_var_tid()].pval = {};", thread_data); printer->fmt_line("{} = 1;", thread_data_in_use); - printer->end_block(1); + printer->pop_block(1); } - printer->end_block(3); + printer->pop_block(3); // thread_mem_cleanup callback printer->add_line("/** thread memory cleanup callback */"); - printer->start_block("static void thread_mem_cleanup(ThreadDatum* thread) "); + printer->push_block("static void thread_mem_cleanup(ThreadDatum* thread) "); // clang-format off if (info.vectorize && info.derivimplicit_used()) { @@ -3044,29 +3050,29 @@ void CodegenCppVisitor::print_thread_memory_callbacks() { if (info.thread_var_data_size != 0) { auto thread_data = get_variable_name("thread_data"); auto thread_data_in_use = get_variable_name("thread_data_in_use"); - printer->fmt_start_block("if (thread[thread_var_tid()].pval == {})", thread_data); + printer->fmt_push_block("if (thread[thread_var_tid()].pval == {})", thread_data); printer->fmt_line("{} = 0;", thread_data_in_use); - printer->restart_block("else"); + printer->chain_block("else"); printer->add_line("free(thread[thread_var_tid()].pval);"); - printer->end_block(1); + printer->pop_block(1); } - printer->end_block(1); + printer->pop_block(1); } -void CodegenCppVisitor::print_mechanism_range_var_structure(bool print_initialisers) { - auto const value_initialise = print_initialisers ? "{}" : ""; +void CodegenCppVisitor::print_mechanism_range_var_structure(bool print_initializers) { + auto const value_initialise = print_initializers ? "{}" : ""; auto int_type = default_int_data_type(); printer->add_newline(2); printer->add_line("/** all mechanism instance variables and global variables */"); - printer->fmt_start_block("struct {} ", instance_struct()); + printer->fmt_push_block("struct {} ", instance_struct()); for (auto const& [var, type]: info.neuron_global_variables) { auto const name = var->get_name(); printer->fmt_line("{}* {}{};", type, name, - print_initialisers ? fmt::format("{{&coreneuron::{}}}", name) + print_initializers ? fmt::format("{{&coreneuron::{}}}", name) : std::string{}); } for (auto& var: codegen_float_variables) { @@ -3090,9 +3096,9 @@ void CodegenCppVisitor::print_mechanism_range_var_structure(bool print_initialis printer->fmt_line("{}* {}{};", global_struct(), naming::INST_GLOBAL_MEMBER, - print_initialisers ? fmt::format("{{&{}}}", global_struct_instance()) + print_initializers ? fmt::format("{{&{}}}", global_struct_instance()) : std::string{}); - printer->end_block(";"); + printer->pop_block(";"); } @@ -3102,7 +3108,7 @@ void CodegenCppVisitor::print_ion_var_structure() { } printer->add_newline(2); printer->add_line("/** ion write variables */"); - printer->start_block("struct IonCurVar"); + printer->push_block("struct IonCurVar"); std::string float_type = default_float_data_type(); std::vector members; @@ -3122,14 +3128,15 @@ void CodegenCppVisitor::print_ion_var_structure() { print_ion_var_constructor(members); - printer->end_block(";"); + printer->pop_block(";"); } void CodegenCppVisitor::print_ion_var_constructor(const std::vector& members) { // constructor printer->add_newline(); - printer->add_line("IonCurVar() : ", 0); + printer->add_indent(); + printer->add_text("IonCurVar() : "); for (int i = 0; i < members.size(); i++) { printer->fmt_text("{}(0)", members[i]); if (i + 1 < members.size()) { @@ -3155,14 +3162,14 @@ void CodegenCppVisitor::print_setup_range_variable() { auto type = float_data_type(); printer->add_newline(2); printer->add_line("/** allocate and setup array for range variable */"); - printer->fmt_start_block("static inline {}* setup_range_variable(double* variable, int n)", - type); + printer->fmt_push_block("static inline {}* setup_range_variable(double* variable, int n)", + type); printer->fmt_line("{0}* data = ({0}*) mem_alloc(n, sizeof({0}));", type); - printer->start_block("for(size_t i = 0; i < n; i++)"); + printer->push_block("for(size_t i = 0; i < n; i++)"); printer->add_line("data[i] = variable[i];"); - printer->end_block(1); + printer->pop_block(1); printer->add_line("return data;"); - printer->end_block(1); + printer->pop_block(1); } @@ -3195,8 +3202,8 @@ void CodegenCppVisitor::print_instance_variable_setup() { printer->add_newline(); printer->add_line("// Allocate instance structure"); - printer->fmt_start_block("static void {}(NrnThread* nt, Memb_list* ml, int type)", - method_name(naming::NRN_PRIVATE_CONSTRUCTOR_METHOD)); + printer->fmt_push_block("static void {}(NrnThread* nt, Memb_list* ml, int type)", + method_name(naming::NRN_PRIVATE_CONSTRUCTOR_METHOD)); printer->add_line("assert(!ml->instance);"); printer->add_line("assert(!ml->global_variables);"); printer->add_line("assert(ml->global_variables_size == 0);"); @@ -3207,7 +3214,7 @@ void CodegenCppVisitor::print_instance_variable_setup() { printer->add_line("ml->instance = inst;"); printer->fmt_line("ml->global_variables = inst->{};", naming::INST_GLOBAL_MEMBER); printer->fmt_line("ml->global_variables_size = sizeof({});", global_struct()); - printer->end_block(2); + printer->pop_block(2); auto const cast_inst_and_assert_validity = [&]() { printer->fmt_line("auto* const inst = static_cast<{}*>(ml->instance);", instance_struct()); @@ -3225,19 +3232,21 @@ void CodegenCppVisitor::print_instance_variable_setup() { print_instance_struct_transfer_routine_declarations(); printer->add_line("// Deallocate the instance structure"); - printer->fmt_start_block("static void {}(NrnThread* nt, Memb_list* ml, int type)", - method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD)); + printer->fmt_push_block("static void {}(NrnThread* nt, Memb_list* ml, int type)", + method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD)); cast_inst_and_assert_validity(); print_instance_struct_delete_from_device(); - printer->add_line("delete inst;"); - printer->add_line("ml->instance = nullptr;"); - printer->add_line("ml->global_variables = nullptr;"); - printer->add_line("ml->global_variables_size = 0;"); - printer->end_block(2); + printer->add_multi_line(R"CODE( + delete inst; + ml->instance = nullptr; + ml->global_variables = nullptr; + ml->global_variables_size = 0; + )CODE"); + printer->pop_block(2); printer->add_line("/** initialize mechanism instance variables */"); - printer->start_block("static inline void setup_instance(NrnThread* nt, Memb_list* ml)"); + printer->push_block("static inline void setup_instance(NrnThread* nt, Memb_list* ml)"); cast_inst_and_assert_validity(); std::string stride; @@ -3287,7 +3296,7 @@ void CodegenCppVisitor::print_instance_variable_setup() { ptr_members.push_back(std::move(name)); } print_instance_struct_copy_to_device(); - printer->end_block(2); // setup_instance + printer->pop_block(2); // setup_instance print_instance_struct_transfer_routines(ptr_members); } @@ -3353,7 +3362,7 @@ void CodegenCppVisitor::print_global_function_common_code(BlockType type, } print_global_method_annotation(); - printer->fmt_start_block("void {}({})", method, args); + printer->fmt_push_block("void {}({})", method, args); if (type != BlockType::Destructor && type != BlockType::Constructor) { // We do not (currently) support DESTRUCTOR and CONSTRUCTOR blocks // running anything on the GPU. @@ -3363,11 +3372,13 @@ void CodegenCppVisitor::print_global_function_common_code(BlockType type, /// Related to https://github.com/BlueBrain/nmodl/issues/692 printer->add_line("#ifndef CORENEURON_BUILD"); } - printer->add_line("int nodecount = ml->nodecount;"); - printer->add_line("int pnodecount = ml->_nodecount_padded;"); - printer->add_line("const int* node_index = ml->nodeindices;"); - printer->add_line("double* data = ml->data;"); - printer->add_line("const double* voltage = nt->_actual_v;"); + printer->add_multi_line(R"CODE( + int nodecount = ml->nodecount; + int pnodecount = ml->_nodecount_padded; + const int* node_index = ml->nodeindices; + double* data = ml->data; + const double* voltage = nt->_actual_v; + )CODE"); if (type == BlockType::Equation) { printer->add_line("double* vec_rhs = nt->_actual_rhs;"); @@ -3401,13 +3412,13 @@ void CodegenCppVisitor::print_nrn_init(bool skip_init_check) { print_deriv_advance_flag_transfer_to_device(); printer->fmt_line("auto ns = newtonspace{}(thread);", list_num); printer->fmt_line("auto& th = thread[dith{}()];", list_num); - printer->start_block("if (*ns == nullptr)"); + printer->push_block("if (*ns == nullptr)"); printer->fmt_line("int vec_size = 2*{}*pnodecount*sizeof(double);", nequation); printer->fmt_line("double* vec = makevector(vec_size);", nequation); printer->fmt_line("th.pval = vec;", list_num); printer->fmt_line("*ns = nrn_cons_newtonspace({}, pnodecount);", nequation); print_newtonspace_transfer_to_device(); - printer->end_block(1); + printer->pop_block(1); // clang-format on } @@ -3417,7 +3428,7 @@ void CodegenCppVisitor::print_nrn_init(bool skip_init_check) { print_global_variable_device_update_annotation(); if (skip_init_check) { - printer->start_block("if (_nrn_skip_initmodel == 0)"); + printer->push_block("if (_nrn_skip_initmodel == 0)"); } if (!info.changed_dt.empty()) { @@ -3430,21 +3441,21 @@ void CodegenCppVisitor::print_nrn_init(bool skip_init_check) { } print_channel_iteration_block_parallel_hint(BlockType::Initial, info.initial_node); - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); if (info.net_receive_node != nullptr) { printer->fmt_line("{} = -1e20;", get_variable_name("tsave")); } print_initial_block(info.initial_node); - printer->end_block(1); + printer->pop_block(1); if (!info.changed_dt.empty()) { printer->fmt_line("{} = _save_prev_dt;", get_variable_name(naming::NTHREAD_DT_VARIABLE)); print_dt_update_to_device(); } - printer->end_block(1); + printer->pop_block(1); if (info.derivimplicit_used()) { printer->add_line("deriv_advance_flag = 1;"); @@ -3457,7 +3468,7 @@ void CodegenCppVisitor::print_nrn_init(bool skip_init_check) { print_kernel_data_present_annotation_block_end(); if (skip_init_check) { - printer->end_block(1); + printer->pop_block(1); } codegen = false; } @@ -3487,7 +3498,7 @@ void CodegenCppVisitor::print_before_after_block(const ast::Block* node, size_t print_global_function_common_code(BlockType::BeforeAfter, function_name); print_channel_iteration_block_parallel_hint(BlockType::BeforeAfter, node); - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("int node_id = node_index[id];"); printer->add_line("double v = voltage[node_id];"); @@ -3512,8 +3523,8 @@ void CodegenCppVisitor::print_before_after_block(const ast::Block* node, size_t } /// loop end including data annotation block - printer->end_block(1); - printer->end_block(1); + printer->pop_block(1); + printer->pop_block(1); print_kernel_data_present_annotation_block_end(); codegen = false; @@ -3527,7 +3538,7 @@ void CodegenCppVisitor::print_nrn_constructor() { print_statement_block(*block, false, false); } printer->add_line("#endif"); - printer->end_block(1); + printer->pop_block(1); } @@ -3539,7 +3550,7 @@ void CodegenCppVisitor::print_nrn_destructor() { print_statement_block(*block, false, false); } printer->add_line("#endif"); - printer->end_block(1); + printer->pop_block(1); } @@ -3556,9 +3567,9 @@ void CodegenCppVisitor::print_functors_definitions() { void CodegenCppVisitor::print_nrn_alloc() { printer->add_newline(2); auto method = method_name(naming::NRN_ALLOC_METHOD); - printer->fmt_start_block("static void {}(double* data, Datum* indexes, int type)", method); + printer->fmt_push_block("static void {}(double* data, Datum* indexes, int type)", method); printer->add_line("// do nothing"); - printer->end_block(1); + printer->pop_block(1); } /** @@ -3574,19 +3585,19 @@ void CodegenCppVisitor::print_watch_activate() { printer->add_newline(2); auto inst = fmt::format("{}* inst", instance_struct()); - printer->fmt_start_block( + printer->fmt_push_block( "static void nrn_watch_activate({}, int id, int pnodecount, int watch_id, " "double v, bool &watch_remove)", inst); // initialize all variables only during first watch statement - printer->start_block("if (watch_remove == false)"); + printer->push_block("if (watch_remove == false)"); for (int i = 0; i < info.watch_count; i++) { auto name = get_variable_name(fmt::format("watch{}", i + 1)); printer->fmt_line("{} = 0;", name); } printer->add_line("watch_remove = true;"); - printer->end_block(1); + printer->pop_block(1); /** * \todo Similar to neuron/coreneuron we are using @@ -3594,7 +3605,7 @@ void CodegenCppVisitor::print_watch_activate() { */ for (int i = 0; i < info.watch_statements.size(); i++) { auto statement = info.watch_statements[i]; - printer->fmt_start_block("if (watch_id == {})", i); + printer->fmt_push_block("if (watch_id == {})", i); auto varname = get_variable_name(fmt::format("watch{}", i + 1)); printer->add_indent(); @@ -3604,9 +3615,9 @@ void CodegenCppVisitor::print_watch_activate() { printer->add_text(");"); printer->add_newline(); - printer->end_block(1); + printer->pop_block(1); } - printer->end_block(1); + printer->pop_block(1); codegen = false; } @@ -3630,7 +3641,7 @@ void CodegenCppVisitor::print_watch_check() { // we don't need to have ivdep pragma related check print_channel_iteration_block_parallel_hint(BlockType::Watch, nullptr); - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); if (info.is_voltage_used_by_watch_statements()) { printer->add_line("int node_id = node_index[id];"); @@ -3643,11 +3654,11 @@ void CodegenCppVisitor::print_watch_check() { for (int i = 0; i < info.watch_statements.size(); i++) { auto statement = info.watch_statements[i]; - auto watch = statement->get_statements().front(); - auto varname = get_variable_name(fmt::format("watch{}", i + 1)); + const auto& watch = statement->get_statements().front(); + const auto& varname = get_variable_name(fmt::format("watch{}", i + 1)); // start block 1 - printer->fmt_start_block("if ({}&2 && watch_untriggered)", varname); + printer->fmt_push_block("if ({}&2 && watch_untriggered)", varname); // start block 2 printer->add_indent(); @@ -3658,15 +3669,15 @@ void CodegenCppVisitor::print_watch_check() { printer->increase_indent(); // start block 3 - printer->fmt_start_block("if (({}&1) == 0)", varname); + printer->fmt_push_block("if (({}&1) == 0)", varname); printer->add_line("watch_untriggered = false;"); - auto tqitem = get_variable_name("tqitem"); - auto point_process = get_variable_name("point_process"); + const auto& tqitem = get_variable_name("tqitem"); + const auto& point_process = get_variable_name("point_process"); printer->add_indent(); printer->add_text("net_send_buffering("); - auto t = get_variable_name("t"); + const auto& t = get_variable_name("t"); printer->fmt_text("nt, ml->_net_send_buffer, 0, {}, -1, {}, {}+0.0, ", tqitem, point_process, @@ -3674,34 +3685,37 @@ void CodegenCppVisitor::print_watch_check() { watch->get_value()->accept(*this); printer->add_text(");"); printer->add_newline(); - printer->end_block(1); + printer->pop_block(1); - printer->fmt_line("{} = 3;", varname); + printer->add_line(varname, " = 3;"); // end block 3 // start block 3 printer->decrease_indent(); - printer->start_block("} else"); - printer->fmt_line("{} = 2;", varname); - printer->end_block(1); + printer->push_block("} else"); + printer->add_line(varname, " = 2;"); + printer->pop_block(1); // end block 3 - printer->end_block(1); + printer->pop_block(1); // end block 1 } - printer->end_block(1); + printer->pop_block(1); print_send_event_move(); print_kernel_data_present_annotation_block_end(); - printer->end_block(1); + printer->pop_block(1); codegen = false; } void CodegenCppVisitor::print_net_receive_common_code(const Block& node, bool need_mech_inst) { - printer->add_line("int tid = pnt->_tid;"); - printer->add_line("int id = pnt->_i_instance;"); - printer->add_line("double v = 0;"); + printer->add_multi_line(R"CODE( + int tid = pnt->_tid; + int id = pnt->_i_instance; + double v = 0; + )CODE"); + if (info.artificial_cell || node.is_initial_block()) { printer->add_line("NrnThread* nt = nrn_threads + tid;"); printer->add_line("Memb_list* ml = nt->_ml_list[pnt->_type];"); @@ -3710,12 +3724,14 @@ void CodegenCppVisitor::print_net_receive_common_code(const Block& node, bool ne print_kernel_data_present_annotation_block_begin(); } - printer->add_line("int nodecount = ml->nodecount;"); - printer->add_line("int pnodecount = ml->_nodecount_padded;"); - printer->add_line("double* data = ml->data;"); - printer->add_line("double* weights = nt->weights;"); - printer->add_line("Datum* indexes = ml->pdata;"); - printer->add_line("ThreadDatum* thread = ml->_thread;"); + printer->add_multi_line(R"CODE( + int nodecount = ml->nodecount; + int pnodecount = ml->_nodecount_padded; + double* data = ml->data; + double* weights = nt->weights; + Datum* indexes = ml->pdata; + ThreadDatum* thread = ml->_thread; + )CODE"); if (need_mech_inst) { printer->fmt_line("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct()); } @@ -3745,7 +3761,7 @@ void CodegenCppVisitor::print_net_receive_common_code(const Block& node, bool ne void CodegenCppVisitor::print_net_send_call(const FunctionCall& node) { auto const& arguments = node.get_arguments(); - auto tqitem = get_variable_name("tqitem"); + const auto& tqitem = get_variable_name("tqitem"); std::string weight_index = "weight_index"; std::string pnt = "pnt"; @@ -3764,14 +3780,14 @@ void CodegenCppVisitor::print_net_send_call(const FunctionCall& node) { if (info.artificial_cell) { printer->fmt_text("artcell_net_send(&{}, {}, {}, nt->_t+", tqitem, weight_index, pnt); } else { - auto point_process = get_variable_name("point_process"); - std::string t = get_variable_name("t"); + const auto& point_process = get_variable_name("point_process"); + const auto& t = get_variable_name("t"); printer->add_text("net_send_buffering("); printer->fmt_text("nt, ml->_net_send_buffer, 0, {}, {}, {}, {}+", tqitem, weight_index, point_process, t); } // clang-format off print_vector_elements(arguments, ", "); - printer->add_text(")"); + printer->add_text(')'); } @@ -3781,7 +3797,7 @@ void CodegenCppVisitor::print_net_move_call(const FunctionCall& node) { } auto const& arguments = node.get_arguments(); - auto tqitem = get_variable_name("tqitem"); + const auto& tqitem = get_variable_name("tqitem"); std::string weight_index = "-1"; std::string pnt = "pnt"; @@ -3792,7 +3808,7 @@ void CodegenCppVisitor::print_net_move_call(const FunctionCall& node) { print_vector_elements(arguments, ", "); printer->add_text(")"); } else { - auto point_process = get_variable_name("point_process"); + const auto& point_process = get_variable_name("point_process"); printer->add_text("net_send_buffering("); printer->fmt_text("nt, ml->_net_send_buffer, 2, {}, {}, {}, ", tqitem, weight_index, point_process); print_vector_elements(arguments, ", "); @@ -3808,7 +3824,7 @@ void CodegenCppVisitor::print_net_event_call(const FunctionCall& node) { printer->add_text("net_event(pnt, "); print_vector_elements(arguments, ", "); } else { - auto point_process = get_variable_name("point_process"); + const auto& point_process = get_variable_name("point_process"); printer->add_text("net_send_buffering("); printer->fmt_text("nt, ml->_net_send_buffer, 1, -1, -1, {}, ", point_process); print_vector_elements(arguments, ", "); @@ -3842,9 +3858,9 @@ void CodegenCppVisitor::print_net_event_call(const FunctionCall& node) { * So, the `R` in AST needs to be renamed with `(*R)`. */ static void rename_net_receive_arguments(const ast::NetReceiveBlock& net_receive_node, const ast::Node& node) { - auto parameters = net_receive_node.get_parameters(); + const auto& parameters = net_receive_node.get_parameters(); for (auto& parameter: parameters) { - auto name = parameter->get_node_name(); + const auto& name = parameter->get_node_name(); auto var_used = VarUsageVisitor().variable_used(node, name); if (var_used) { RenameVisitor vr(name, "(*" + name + ")"); @@ -3868,7 +3884,7 @@ void CodegenCppVisitor::print_net_init() { auto args = "Point_process* pnt, int weight_index, double flag"; printer->add_newline(2); printer->add_line("/** initialize block for net receive */"); - printer->fmt_start_block("static void net_init({})", args); + printer->fmt_push_block("static void net_init({})", args); auto block = node->get_statement_block().get(); if (block->get_statements().empty()) { printer->add_line("// do nothing"); @@ -3882,7 +3898,7 @@ void CodegenCppVisitor::print_net_init() { print_net_send_buf_update_to_host(); } } - printer->end_block(1); + printer->pop_block(1); codegen = false; printing_net_init = false; } @@ -3892,16 +3908,18 @@ void CodegenCppVisitor::print_send_event_move() { printer->add_newline(); printer->add_line("NetSendBuffer_t* nsb = ml->_net_send_buffer;"); print_net_send_buf_update_to_host(); - printer->start_block("for (int i=0; i < nsb->_cnt; i++)"); - printer->add_line("int type = nsb->_sendtype[i];"); - printer->add_line("int tid = nt->id;"); - printer->add_line("double t = nsb->_nsb_t[i];"); - printer->add_line("double flag = nsb->_nsb_flag[i];"); - printer->add_line("int vdata_index = nsb->_vdata_index[i];"); - printer->add_line("int weight_index = nsb->_weight_index[i];"); - printer->add_line("int point_index = nsb->_pnt_index[i];"); - printer->add_line("net_sem_from_gpu(type, vdata_index, weight_index, tid, point_index, t, flag);"); - printer->end_block(1); + printer->push_block("for (int i=0; i < nsb->_cnt; i++)"); + printer->add_multi_line(R"CODE( + int type = nsb->_sendtype[i]; + int tid = nt->id; + double t = nsb->_nsb_t[i]; + double flag = nsb->_nsb_flag[i]; + int vdata_index = nsb->_vdata_index[i]; + int weight_index = nsb->_weight_index[i]; + int point_index = nsb->_pnt_index[i]; + net_sem_from_gpu(type, vdata_index, weight_index, tid, point_index, t, flag); + )CODE"); + printer->pop_block(1); printer->add_line("nsb->_cnt = 0;"); print_net_send_buf_count_update_to_device(); } @@ -3914,21 +3932,21 @@ std::string CodegenCppVisitor::net_receive_buffering_declaration() { void CodegenCppVisitor::print_get_memb_list() { printer->add_line("Memb_list* ml = get_memb_list(nt);"); - printer->start_block("if (!ml)"); + printer->push_block("if (!ml)"); printer->add_line("return;"); - printer->end_block(2); + printer->pop_block(2); } void CodegenCppVisitor::print_net_receive_loop_begin() { printer->add_line("int count = nrb->_displ_cnt;"); print_channel_iteration_block_parallel_hint(BlockType::NetReceive, info.net_receive_node); - printer->start_block("for (int i = 0; i < count; i++)"); + printer->push_block("for (int i = 0; i < count; i++)"); } void CodegenCppVisitor::print_net_receive_loop_end() { - printer->end_block(1); + printer->pop_block(1); } @@ -3937,11 +3955,11 @@ void CodegenCppVisitor::print_net_receive_buffering(bool need_mech_inst) { return; } printer->add_newline(2); - printer->start_block(net_receive_buffering_declaration()); + printer->push_block(net_receive_buffering_declaration()); print_get_memb_list(); - auto net_receive = method_name("net_receive_kernel"); + const auto& net_receive = method_name("net_receive_kernel"); print_kernel_data_present_annotation_block_begin(); @@ -3952,15 +3970,17 @@ void CodegenCppVisitor::print_net_receive_buffering(bool need_mech_inst) { print_net_receive_loop_begin(); printer->add_line("int start = nrb->_displ[i];"); printer->add_line("int end = nrb->_displ[i+1];"); - printer->start_block("for (int j = start; j < end; j++)"); - printer->add_line("int index = nrb->_nrb_index[j];"); - printer->add_line("int offset = nrb->_pnt_index[index];"); - printer->add_line("double t = nrb->_nrb_t[index];"); - printer->add_line("int weight_index = nrb->_weight_index[index];"); - printer->add_line("double flag = nrb->_nrb_flag[index];"); - printer->add_line("Point_process* point_process = nt->pntprocs + offset;"); - printer->fmt_line("{}(t, point_process, inst, nt, ml, weight_index, flag);", net_receive); - printer->end_block(1); + printer->push_block("for (int j = start; j < end; j++)"); + printer->add_multi_line(R"CODE( + int index = nrb->_nrb_index[j]; + int offset = nrb->_pnt_index[index]; + double t = nrb->_nrb_t[index]; + int weight_index = nrb->_weight_index[index]; + double flag = nrb->_nrb_flag[index]; + Point_process* point_process = nt->pntprocs + offset; + )CODE"); + printer->add_line(net_receive, "(t, point_process, inst, nt, ml, weight_index, flag);"); + printer->pop_block(1); print_net_receive_loop_end(); print_device_stream_wait(); @@ -3972,7 +3992,7 @@ void CodegenCppVisitor::print_net_receive_buffering(bool need_mech_inst) { } print_kernel_data_present_annotation_block_end(); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_net_send_buffering_cnt_update() const { @@ -3980,9 +4000,9 @@ void CodegenCppVisitor::print_net_send_buffering_cnt_update() const { } void CodegenCppVisitor::print_net_send_buffering_grow() { - printer->start_block("if (i >= nsb->_size)"); + printer->push_block("if (i >= nsb->_size)"); printer->add_line("nsb->grow();"); - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_net_send_buffering() { @@ -3995,19 +4015,21 @@ void CodegenCppVisitor::print_net_send_buffering() { auto args = "const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, " "int weight_index, int point_index, double t, double flag"; - printer->fmt_start_block("static inline void net_send_buffering({})", args); + printer->fmt_push_block("static inline void net_send_buffering({})", args); printer->add_line("int i = 0;"); print_net_send_buffering_cnt_update(); print_net_send_buffering_grow(); - printer->start_block("if (i < nsb->_size)"); - printer->add_line("nsb->_sendtype[i] = type;"); - printer->add_line("nsb->_vdata_index[i] = vdata_index;"); - printer->add_line("nsb->_weight_index[i] = weight_index;"); - printer->add_line("nsb->_pnt_index[i] = point_index;"); - printer->add_line("nsb->_nsb_t[i] = t;"); - printer->add_line("nsb->_nsb_flag[i] = flag;"); - printer->end_block(1); - printer->end_block(1); + printer->push_block("if (i < nsb->_size)"); + printer->add_multi_line(R"CODE( + nsb->_sendtype[i] = type; + nsb->_vdata_index[i] = vdata_index; + nsb->_weight_index[i] = weight_index; + nsb->_pnt_index[i] = point_index; + nsb->_nsb_t[i] = t; + nsb->_nsb_flag[i] = flag; + )CODE"); + printer->pop_block(1); + printer->pop_block(1); } @@ -4023,7 +4045,7 @@ void CodegenCppVisitor::visit_for_netcon(const ast::ForNetcon& node) { // sanitize node_name since we want to substitute names like (*w) as they are auto old_name = std::regex_replace(args[i_arg]->get_node_name(), regex_special_chars, R"(\$&)"); - auto new_name = fmt::format("weights[{} + nt->_fornetcon_weight_perm[i]]", i_arg); + const auto& new_name = fmt::format("weights[{} + nt->_fornetcon_weight_perm[i]]", i_arg); v.set(old_name, new_name); statement_block->accept(v); } @@ -4060,7 +4082,7 @@ void CodegenCppVisitor::print_net_receive_kernel() { rename_net_receive_arguments(*info.net_receive_node, *node); std::string name; - auto params = ParamVector(); + ParamVector params; if (!info.artificial_cell) { name = method_name("net_receive_kernel"); params.emplace_back("", "double", "", "t"); @@ -4079,7 +4101,7 @@ void CodegenCppVisitor::print_net_receive_kernel() { } printer->add_newline(2); - printer->fmt_start_block("static inline void {}({})", name, get_parameter_str(params)); + printer->fmt_push_block("static inline void {}({})", name, get_parameter_str(params)); print_net_receive_common_code(*node, info.artificial_cell); if (info.artificial_cell) { printer->add_line("double t = nt->_t;"); @@ -4101,7 +4123,7 @@ void CodegenCppVisitor::print_net_receive_kernel() { printer->add_indent(); node->get_statement_block()->accept(*this); printer->add_newline(); - printer->end_block(); + printer->pop_block(); printer->add_newline(); printing_net_receive = false; @@ -4116,26 +4138,28 @@ void CodegenCppVisitor::print_net_receive() { codegen = true; printing_net_receive = true; if (!info.artificial_cell) { - std::string name = method_name("net_receive"); - auto params = ParamVector(); + const auto& name = method_name("net_receive"); + ParamVector params; params.emplace_back("", "Point_process*", "", "pnt"); params.emplace_back("", "int", "", "weight_index"); params.emplace_back("", "double", "", "flag"); printer->add_newline(2); - printer->fmt_start_block("static void {}({})", name, get_parameter_str(params)); + printer->fmt_push_block("static void {}({})", name, get_parameter_str(params)); printer->add_line("NrnThread* nt = nrn_threads + pnt->_tid;"); printer->add_line("Memb_list* ml = get_memb_list(nt);"); printer->add_line("NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;"); - printer->start_block("if (nrb->_cnt >= nrb->_size)"); + printer->push_block("if (nrb->_cnt >= nrb->_size)"); printer->add_line("realloc_net_receive_buffer(nt, ml);"); - printer->end_block(1); - printer->add_line("int id = nrb->_cnt;"); - printer->add_line("nrb->_pnt_index[id] = pnt-nt->pntprocs;"); - printer->add_line("nrb->_weight_index[id] = weight_index;"); - printer->add_line("nrb->_nrb_t[id] = nt->_t;"); - printer->add_line("nrb->_nrb_flag[id] = flag;"); - printer->add_line("nrb->_cnt++;"); - printer->end_block(1); + printer->pop_block(1); + printer->add_multi_line(R"CODE( + int id = nrb->_cnt; + nrb->_pnt_index[id] = pnt-nt->pntprocs; + nrb->_weight_index[id] = weight_index; + nrb->_nrb_t[id] = nt->_t; + nrb->_nrb_flag[id] = flag; + nrb->_cnt++; + )CODE"); + printer->pop_block(1); } printing_net_receive = false; codegen = false; @@ -4149,20 +4173,20 @@ void CodegenCppVisitor::print_net_receive() { * actual variable names? [resolved now?] * slist needs to added as local variable */ -void CodegenCppVisitor::print_derivimplicit_kernel(Block* block) { +void CodegenCppVisitor::print_derivimplicit_kernel(const Block& block) { auto ext_args = external_method_arguments(); auto ext_params = external_method_parameters(); auto suffix = info.mod_suffix; auto list_num = info.derivimplicit_list_num; - auto block_name = block->get_node_name(); + auto block_name = block.get_node_name(); auto primes_size = info.primes_size; auto stride = "*pnodecount+id"; printer->add_newline(2); - printer->start_block("namespace"); - printer->fmt_start_block("struct _newton_{}_{}", block_name, info.mod_suffix); - printer->fmt_start_block("int operator()({}) const", external_method_parameters()); + printer->push_block("namespace"); + printer->fmt_push_block("struct _newton_{}_{}", block_name, info.mod_suffix); + printer->fmt_push_block("int operator()({}) const", external_method_parameters()); auto const instance = fmt::format("auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct()); auto const slist1 = fmt::format("auto const& slist{} = {};", @@ -4190,37 +4214,37 @@ void CodegenCppVisitor::print_derivimplicit_kernel(Block* block) { printer->add_line(dlist1); printer->add_line(dlist2); codegen = true; - print_statement_block(*block->get_statement_block(), false, false); + print_statement_block(*block.get_statement_block(), false, false); codegen = false; printer->add_line("int counter = -1;"); - printer->fmt_start_block("for (int i=0; i<{}; i++)", info.num_primes); - printer->fmt_start_block("if (*deriv{}_advance(thread))", list_num); + printer->fmt_push_block("for (int i=0; i<{}; i++)", info.num_primes); + printer->fmt_push_block("if (*deriv{}_advance(thread))", list_num); printer->fmt_line( "dlist{0}[(++counter){1}] = " "data[dlist{2}[i]{1}]-(data[slist{2}[i]{1}]-savstate{2}[i{1}])/nt->_dt;", list_num + 1, stride, list_num); - printer->restart_block("else"); + printer->chain_block("else"); printer->fmt_line("dlist{0}[(++counter){1}] = data[slist{2}[i]{1}]-savstate{2}[i{1}];", list_num + 1, stride, list_num); - printer->end_block(1); - printer->end_block(1); + printer->pop_block(1); + printer->pop_block(1); printer->add_line("return 0;"); - printer->end_block(1); // operator() - printer->end_block(";"); // struct - printer->end_block(2); // namespace - printer->fmt_start_block("int {}_{}({})", block_name, suffix, ext_params); + printer->pop_block(1); // operator() + printer->pop_block(";"); // struct + printer->pop_block(2); // namespace + printer->fmt_push_block("int {}_{}({})", block_name, suffix, ext_params); printer->add_line(instance); printer->fmt_line("double* savstate{} = (double*) thread[dith{}()].pval;", list_num, list_num); printer->add_line(slist1); printer->add_line(slist2); printer->add_line(dlist2); - printer->fmt_start_block("for (int i=0; i<{}; i++)", info.num_primes); + printer->fmt_push_block("for (int i=0; i<{}; i++)", info.num_primes); printer->fmt_line("savstate{}[i{}] = data[slist{}[i]{}];", list_num, stride, list_num, stride); - printer->end_block(1); + printer->pop_block(1); printer->fmt_line( "int reset = nrn_newton_thread(static_cast(*newtonspace{}(thread)), {}, " "slist{}, _newton_{}_{}{{}}, dlist{}, {});", @@ -4232,7 +4256,7 @@ void CodegenCppVisitor::print_derivimplicit_kernel(Block* block) { list_num + 1, ext_args); printer->add_line("return reset;"); - printer->end_block(3); + printer->pop_block(3); } @@ -4277,7 +4301,7 @@ void CodegenCppVisitor::print_nrn_state() { printer->add_line("/** update state */"); print_global_function_common_code(BlockType::State); print_channel_iteration_block_parallel_hint(BlockType::State, info.nrn_state_block); - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("int node_id = node_index[id];"); printer->add_line("double v = voltage[node_id];"); @@ -4305,16 +4329,16 @@ void CodegenCppVisitor::print_nrn_state() { print_statement_block(*block, false, false); } - auto write_statements = ion_write_statements(BlockType::State); + const auto& write_statements = ion_write_statements(BlockType::State); for (auto& statement: write_statements) { - auto text = process_shadow_update_statement(statement, BlockType::State); + const auto& text = process_shadow_update_statement(statement, BlockType::State); printer->add_line(text); } - printer->end_block(1); + printer->pop_block(1); print_kernel_data_present_annotation_block_end(); - printer->end_block(1); + printer->pop_block(1); codegen = false; } @@ -4325,21 +4349,21 @@ void CodegenCppVisitor::print_nrn_state() { void CodegenCppVisitor::print_nrn_current(const BreakpointBlock& node) { - auto args = internal_method_parameters(); + const auto& args = internal_method_parameters(); const auto& block = node.get_statement_block(); printer->add_newline(2); print_device_method_annotation(); - printer->fmt_start_block("inline double nrn_current_{}({})", - info.mod_suffix, - get_parameter_str(args)); + printer->fmt_push_block("inline double nrn_current_{}({})", + info.mod_suffix, + get_parameter_str(args)); printer->add_line("double current = 0.0;"); print_statement_block(*block, false, false); for (auto& current: info.currents) { - auto name = get_variable_name(current); + const auto& name = get_variable_name(current); printer->fmt_line("current += {};", name); } printer->add_line("return current;"); - printer->end_block(1); + printer->pop_block(1); } @@ -4370,10 +4394,10 @@ void CodegenCppVisitor::print_nrn_cur_conductance_kernel(const BreakpointBlock& for (const auto& conductance: info.conductances) { if (!conductance.ion.empty()) { - auto lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + conductance.ion + "dv"; - auto rhs = get_variable_name(conductance.variable); - ShadowUseStatement statement{lhs, "+=", rhs}; - auto text = process_shadow_update_statement(statement, BlockType::Equation); + const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + conductance.ion + "dv"; + const auto& rhs = get_variable_name(conductance.variable); + const ShadowUseStatement statement{lhs, "+=", rhs}; + const auto& text = process_shadow_update_statement(statement, BlockType::Equation); printer->add_line(text); } } @@ -4387,7 +4411,7 @@ void CodegenCppVisitor::print_nrn_cur_non_conductance_kernel() { for (auto& ion: info.ions) { for (auto& var: ion.writes) { if (ion.is_ionic_current(var)) { - auto name = get_variable_name(var); + const auto& name = get_variable_name(var); printer->fmt_line("double di{} = {};", ion.name, name); } } @@ -4399,14 +4423,14 @@ void CodegenCppVisitor::print_nrn_cur_non_conductance_kernel() { for (auto& ion: info.ions) { for (auto& var: ion.writes) { if (ion.is_ionic_current(var)) { - auto lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + ion.name + "dv"; + const auto& lhs = std::string(naming::ION_VARNAME_PREFIX) + "di" + ion.name + "dv"; auto rhs = fmt::format("(di{}-{})/0.001", ion.name, get_variable_name(var)); if (info.point_process) { auto area = get_variable_name(naming::NODE_AREA_VARIABLE); rhs += fmt::format("*1.e2/{}", area); } - ShadowUseStatement statement{lhs, "+=", rhs}; - auto text = process_shadow_update_statement(statement, BlockType::Equation); + const ShadowUseStatement statement{lhs, "+=", rhs}; + const auto& text = process_shadow_update_statement(statement, BlockType::Equation); printer->add_line(text); } } @@ -4422,7 +4446,7 @@ void CodegenCppVisitor::print_nrn_cur_kernel(const BreakpointBlock& node) { print_ion_variable(); } - auto read_statements = ion_read_statements(BlockType::Equation); + const auto& read_statements = ion_read_statements(BlockType::Equation); for (auto& statement: read_statements) { printer->add_line(statement); } @@ -4433,14 +4457,14 @@ void CodegenCppVisitor::print_nrn_cur_kernel(const BreakpointBlock& node) { print_nrn_cur_conductance_kernel(node); } - auto write_statements = ion_write_statements(BlockType::Equation); + const auto& write_statements = ion_write_statements(BlockType::Equation); for (auto& statement: write_statements) { auto text = process_shadow_update_statement(statement, BlockType::Equation); printer->add_line(text); } if (info.point_process) { - auto area = get_variable_name(naming::NODE_AREA_VARIABLE); + const auto& area = get_variable_name(naming::NODE_AREA_VARIABLE); printer->fmt_line("double mfactor = 1.e2/{};", area); printer->add_line("g = g*mfactor;"); printer->add_line("rhs = rhs*mfactor;"); @@ -4464,17 +4488,17 @@ void CodegenCppVisitor::print_fast_imem_calculation() { d = "g"; } - printer->start_block("if (nt->nrn_fast_imem)"); + printer->push_block("if (nt->nrn_fast_imem)"); if (nrn_cur_reduction_loop_required()) { - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("int node_id = node_index[id];"); } printer->fmt_line("nt->nrn_fast_imem->nrn_sav_rhs[node_id] {} {};", rhs_op, rhs); printer->fmt_line("nt->nrn_fast_imem->nrn_sav_d[node_id] {} {};", d_op, d); if (nrn_cur_reduction_loop_required()) { - printer->end_block(1); + printer->pop_block(1); } - printer->end_block(1); + printer->pop_block(1); } void CodegenCppVisitor::print_nrn_cur() { @@ -4491,23 +4515,23 @@ void CodegenCppVisitor::print_nrn_cur() { printer->add_line("/** update current */"); print_global_function_common_code(BlockType::Equation); print_channel_iteration_block_parallel_hint(BlockType::Equation, info.breakpoint_node); - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); print_nrn_cur_kernel(*info.breakpoint_node); print_nrn_cur_matrix_shadow_update(); if (!nrn_cur_reduction_loop_required()) { print_fast_imem_calculation(); } - printer->end_block(1); + printer->pop_block(1); if (nrn_cur_reduction_loop_required()) { - printer->start_block("for (int id = 0; id < nodecount; id++)"); + printer->push_block("for (int id = 0; id < nodecount; id++)"); print_nrn_cur_matrix_shadow_reduction(); - printer->end_block(1); + printer->pop_block(1); print_fast_imem_calculation(); } print_kernel_data_present_annotation_block_end(); - printer->end_block(1); + printer->pop_block(1); codegen = false; } @@ -4545,9 +4569,9 @@ void CodegenCppVisitor::print_common_getters() { } -void CodegenCppVisitor::print_data_structures(bool print_initialisers) { - print_mechanism_global_var_structure(print_initialisers); - print_mechanism_range_var_structure(print_initialisers); +void CodegenCppVisitor::print_data_structures(bool print_initializers) { + print_mechanism_global_var_structure(print_initializers); + print_mechanism_range_var_structure(print_initializers); print_ion_var_structure(); } @@ -4555,15 +4579,19 @@ void CodegenCppVisitor::print_v_unused() const { if (!info.vectorize) { return; } - printer->add_line("#if NRN_PRCELLSTATE"); - printer->add_line("inst->v_unused[id] = v;"); - printer->add_line("#endif"); + printer->add_multi_line(R"CODE( + #if NRN_PRCELLSTATE + inst->v_unused[id] = v; + #endif + )CODE"); } void CodegenCppVisitor::print_g_unused() const { - printer->add_line("#if NRN_PRCELLSTATE"); - printer->add_line("inst->g_unused[id] = g;"); - printer->add_line("#endif"); + printer->add_multi_line(R"CODE( + #if NRN_PRCELLSTATE + inst->g_unused[id] = g; + #endif + )CODE"); } void CodegenCppVisitor::print_compute_functions() { @@ -4581,7 +4609,7 @@ void CodegenCppVisitor::print_compute_functions() { print_before_after_block(info.before_after_blocks[i], i); } for (const auto& callback: info.derivimplicit_callbacks) { - auto block = callback->get_node_to_solve().get(); + const auto& block = *callback->get_node_to_solve(); print_derivimplicit_kernel(block); } print_net_send_buffering(); diff --git a/src/codegen/codegen_cpp_visitor.hpp b/src/codegen/codegen_cpp_visitor.hpp index 0cb44081f5..0df63480c1 100644 --- a/src/codegen/codegen_cpp_visitor.hpp +++ b/src/codegen/codegen_cpp_visitor.hpp @@ -32,18 +32,19 @@ #include "visitors/ast_visitor.hpp" -namespace nmodl { /// encapsulates code generation backend implementations +namespace nmodl { + namespace codegen { /** - * @defgroup codegen Code Generation Implementation - * @brief Implementations of code generation backends + * \defgroup codegen Code Generation Implementation + * \brief Implementations of code generation backends * - * @defgroup codegen_details Codegen Helpers - * @ingroup codegen - * @brief Helper routines/types for code generation - * @{ + * \defgroup codegen_details Codegen Helpers + * \ingroup codegen + * \brief Helper routines/types for code generation + * \{ */ /** @@ -53,7 +54,7 @@ namespace codegen { * Note: do not assign integers to these enums * */ -enum BlockType { +enum class BlockType { /// initial block Initial, @@ -112,7 +113,7 @@ struct IndexVariableInfo { /// symbol for the variable const std::shared_ptr symbol; - /// if variable reside in vdata field of NrnThread + /// if variable resides in vdata field of NrnThread /// typically true for bbcore pointer bool is_vdata = false; @@ -153,7 +154,7 @@ struct ShadowUseStatement { std::string rhs; }; -/** @} */ // end of codegen_details +/** \} */ // end of codegen_details using printer::CodePrinter; @@ -188,6 +189,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * - type (e.g. \c double) * - pointer qualifier (e.g. \c \_\_restrict\_\_) * - parameter name (e.g. \c data) + * */ using ParamVector = std::vector>; @@ -305,7 +307,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Operator for rhs vector update (matrix update) */ - std::string operator_for_rhs() const noexcept { + const char* operator_for_rhs() const noexcept { return info.electrode_current ? "+=" : "-="; } @@ -313,7 +315,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Operator for diagonal vector update (matrix update) */ - std::string operator_for_d() const noexcept { + const char* operator_for_d() const noexcept { return info.electrode_current ? "-=" : "+="; } @@ -321,7 +323,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Data type for the local variables */ - std::string local_var_type() const noexcept { + const char* local_var_type() const noexcept { return codegen::naming::DEFAULT_LOCAL_VAR_TYPE; } @@ -329,7 +331,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Default data type for floating point elements */ - std::string default_float_data_type() const noexcept { + const char* default_float_data_type() const noexcept { return codegen::naming::DEFAULT_FLOAT_TYPE; } @@ -345,7 +347,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Default data type for integer (offset) elements */ - std::string default_int_data_type() const noexcept { + const char* default_int_data_type() const noexcept { return codegen::naming::DEFAULT_INTEGER_TYPE; } @@ -492,7 +494,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param node The AST Statement node to check * \return \c true if this Statement requires a semicolon */ - static bool need_semicolon(ast::Statement* node); + static bool need_semicolon(const ast::Statement& node); /** @@ -634,7 +636,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * Determine all \c float variables required during code generation * \return A \c vector of \c float variables */ - std::vector get_float_variables(); + std::vector get_float_variables() const; /** @@ -651,10 +653,10 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * current printer. Elements are expected to be of type nmodl::ast::Ast and are printed by being * visited. Care is taken to omit the separator after the the last element. * - * \tparam The element type in the vector, which must be of type nmodl::ast::Ast + * \tparam T The element type in the vector, which must be of type nmodl::ast::Ast * \param elements The vector of elements to be printed - * \param prefix A prefix string to printed before each element - * \param separator The seperator string to be printed between all elements + * \param separator The separator string to print between all elements + * \param prefix A prefix string to print before each element */ template void print_vector_elements(const std::vector& elements, @@ -667,7 +669,8 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * The procedure parameters are stored in a vector of 4-tuples each representing a parameter. * * \param params The parameters that should be concatenated into the function parameter - * declaration \return The string representing the declaration of function parameters + * declaration + * \return The string representing the declaration of function parameters */ static std::string get_parameter_str(const ParamVector& params); @@ -722,7 +725,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param type The type of code block being generated * \return A \c vector of strings representing the reading of ion variables */ - std::vector ion_read_statements(BlockType type); + std::vector ion_read_statements(BlockType type) const; /** @@ -731,7 +734,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param type The type of code block being generated * \return A \c vector of strings representing the reading of ion variables */ - std::vector ion_read_statements_optimized(BlockType type); + std::vector ion_read_statements_optimized(BlockType type) const; /** @@ -789,7 +792,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * Arguments for external functions called from generated code * \return A string representing the arguments passed to an external function */ - static std::string external_method_arguments(); + static const char* external_method_arguments() noexcept; /** @@ -801,7 +804,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { * \param table * \return A string representing the parameters of the function */ - static std::string external_method_parameters(bool table = false); + static const char* external_method_parameters(bool table = false) noexcept; /** @@ -813,7 +816,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Arguments for "_threadargs_" macro in neuron implementation */ - std::string nrn_thread_arguments(); + std::string nrn_thread_arguments() const; /** @@ -980,10 +983,10 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Print the structure that wraps all global variables used in the NMODL * - * @param print_initialisers Whether to include default values in the struct + * \param print_initializers Whether to include default values in the struct * definition (true: int foo{42}; false: int foo;) */ - void print_mechanism_global_var_structure(bool print_initialisers); + void print_mechanism_global_var_structure(bool print_initializers); /** @@ -1374,9 +1377,9 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Print derivative kernel when \c derivimplicit method is used * - * \param block The corresponding AST node represening an NMODL \c derivimplicit block + * \param block The corresponding AST node representing an NMODL \c derivimplicit block */ - void print_derivimplicit_kernel(ast::Block* block); + void print_derivimplicit_kernel(const ast::Block& block); /** @@ -1540,9 +1543,9 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Print all classes - * @param print_initialisers Whether to include default values. + * \param print_initializers Whether to include default values. */ - void print_data_structures(bool print_initialisers); + void print_data_structures(bool print_initializers); /** @@ -1744,8 +1747,8 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Print NMODL before / after block in target backend code - * @param node AST node of type before/after type being printed - * @param block_id Index of the before/after block + * \param node AST node of type before/after type being printed + * \param block_id Index of the before/after block */ virtual void print_before_after_block(const ast::Block* node, size_t block_id); @@ -1766,18 +1769,18 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { /** * Find unique variable name defined in nmodl::utils::SingletonRandomString by the * nmodl::visitor::SympySolverVisitor - * @param original_name Original name of variable to change - * @return std::string Unique name produced as [original_name]_[random_string] + * \param original_name Original name of variable to change + * \return std::string Unique name produced as [original_name]_[random_string] */ std::string find_var_unique_name(const std::string& original_name) const; /** * Print the structure that wraps all range and int variables required for the NMODL * - * @param print_initialisers Whether or not default values for variables + * \param print_initializers Whether or not default values for variables * be included in the struct declaration. */ - void print_mechanism_range_var_structure(bool print_initialisers); + void print_mechanism_range_var_structure(bool print_initializers); /** * Print the function that initialize instance structure @@ -1793,10 +1796,10 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { void print_functors_definitions(); /** - * @brief Based on the \c EigenNewtonSolverBlock passed print the definition needed for its + * \brief Based on the \c EigenNewtonSolverBlock passed print the definition needed for its * functor * - * @param node \c EigenNewtonSolverBlock for which to print the functor + * \param node \c EigenNewtonSolverBlock for which to print the functor */ void print_functor_definition(const ast::EigenNewtonSolverBlock& node); @@ -1890,7 +1893,7 @@ void CodegenCppVisitor::print_function_declaration(const T& node, const std::str } // procedures have "int" return type by default - std::string return_type = "int"; + const char* return_type = "int"; if (node.is_function_block()) { return_type = default_float_data_type(); } diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index a8ccea09ae..fa9e4438b6 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -464,9 +464,9 @@ void CodegenHelperVisitor::find_table_variables() { void CodegenHelperVisitor::find_neuron_global_variables() { // TODO: it would be nicer not to have this hardcoded list using pair = std::pair; - for (auto [var, type]: {pair{naming::CELSIUS_VARIABLE, "double"}, - pair{"secondorder", "int"}, - pair{"pi", "double"}}) { + for (const auto& [var, type]: {pair{naming::CELSIUS_VARIABLE, "double"}, + pair{"secondorder", "int"}, + pair{"pi", "double"}}) { auto sym = psymtab->lookup(var); if (sym && (sym->get_read_count() || sym->get_write_count())) { info.neuron_global_variables.emplace_back(std::move(sym), type); diff --git a/src/codegen/codegen_info.cpp b/src/codegen/codegen_info.cpp index 874f3bfe4d..3bec91b51f 100644 --- a/src/codegen/codegen_info.cpp +++ b/src/codegen/codegen_info.cpp @@ -18,7 +18,7 @@ namespace codegen { using visitor::VarUsageVisitor; /// if any ion has write variable -bool CodegenInfo::ion_has_write_variable() const { +bool CodegenInfo::ion_has_write_variable() const noexcept { return std::any_of(ions.begin(), ions.end(), [](auto const& ion) { return !ion.writes.empty(); }); @@ -26,7 +26,7 @@ bool CodegenInfo::ion_has_write_variable() const { /// if given variable is ion write variable -bool CodegenInfo::is_ion_write_variable(const std::string& name) const { +bool CodegenInfo::is_ion_write_variable(const std::string& name) const noexcept { return std::any_of(ions.begin(), ions.end(), [&name](auto const& ion) { return std::any_of(ion.writes.begin(), ion.writes.end(), [&name](auto const& var) { return var == name; @@ -36,7 +36,7 @@ bool CodegenInfo::is_ion_write_variable(const std::string& name) const { /// if given variable is ion read variable -bool CodegenInfo::is_ion_read_variable(const std::string& name) const { +bool CodegenInfo::is_ion_read_variable(const std::string& name) const noexcept { return std::any_of(ions.begin(), ions.end(), [&name](auto const& ion) { return std::any_of(ion.reads.begin(), ion.reads.end(), [&name](auto const& var) { return var == name; @@ -46,13 +46,13 @@ bool CodegenInfo::is_ion_read_variable(const std::string& name) const { /// if either read or write variable -bool CodegenInfo::is_ion_variable(const std::string& name) const { +bool CodegenInfo::is_ion_variable(const std::string& name) const noexcept { return is_ion_read_variable(name) || is_ion_write_variable(name); } /// if a current (ionic or non-specific) -bool CodegenInfo::is_current(const std::string& name) const { +bool CodegenInfo::is_current(const std::string& name) const noexcept { return std::any_of(currents.begin(), currents.end(), [&name](auto const& var) { return var == name; }); @@ -60,20 +60,20 @@ bool CodegenInfo::is_current(const std::string& name) const { /// true is a given variable name if a ionic current /// (i.e. currents excluding non-specific current) -bool CodegenInfo::is_ionic_current(const std::string& name) const { +bool CodegenInfo::is_ionic_current(const std::string& name) const noexcept { return std::any_of(ions.begin(), ions.end(), [&name](auto const& ion) { return ion.is_ionic_current(name); }); } /// true if given variable name is a ionic concentration -bool CodegenInfo::is_ionic_conc(const std::string& name) const { +bool CodegenInfo::is_ionic_conc(const std::string& name) const noexcept { return std::any_of(ions.begin(), ions.end(), [&name](auto const& ion) { return ion.is_ionic_conc(name); }); } -bool CodegenInfo::function_uses_table(std::string& name) const { +bool CodegenInfo::function_uses_table(std::string& name) const noexcept { return std::any_of(functions_with_table.begin(), functions_with_table.end(), [&name](auto const& function) { return name == function->get_node_name(); }); diff --git a/src/codegen/codegen_info.hpp b/src/codegen/codegen_info.hpp index 23a4dcc3a9..37220da56b 100644 --- a/src/codegen/codegen_info.hpp +++ b/src/codegen/codegen_info.hpp @@ -15,6 +15,7 @@ #include #include #include +#include #include "ast/ast.hpp" #include "symtab/symbol_table.hpp" @@ -134,7 +135,7 @@ struct IndexSemantics { IndexSemantics() = delete; IndexSemantics(int index, std::string name, int size) : index(index) - , name(name) + , name(std::move(name)) , size(size) {} }; @@ -390,25 +391,25 @@ struct CodegenInfo { bool eigen_linear_solver_exist = false; /// if any ion has write variable - bool ion_has_write_variable() const; + bool ion_has_write_variable() const noexcept; /// if given variable is ion write variable - bool is_ion_write_variable(const std::string& name) const; + bool is_ion_write_variable(const std::string& name) const noexcept; /// if given variable is ion read variable - bool is_ion_read_variable(const std::string& name) const; + bool is_ion_read_variable(const std::string& name) const noexcept; /// if either read or write variable - bool is_ion_variable(const std::string& name) const; + bool is_ion_variable(const std::string& name) const noexcept; /// if given variable is a current - bool is_current(const std::string& name) const; + bool is_current(const std::string& name) const noexcept; /// if given variable is a ionic current - bool is_ionic_current(const std::string& name) const; + bool is_ionic_current(const std::string& name) const noexcept; /// if given variable is a ionic concentration - bool is_ionic_conc(const std::string& name) const; + bool is_ionic_conc(const std::string& name) const noexcept; /// if watch statements are used bool is_watch_used() const noexcept { @@ -424,7 +425,7 @@ struct CodegenInfo { return !derivimplicit_callbacks.empty(); } - bool function_uses_table(std::string& name) const; + bool function_uses_table(std::string& name) const noexcept; /// true if EigenNewtonSolver is used in nrn_state block bool nrn_state_has_eigen_solver_block() const; diff --git a/src/printer/code_printer.cpp b/src/printer/code_printer.cpp index 9baf8342b9..0dcdf0f114 100644 --- a/src/printer/code_printer.cpp +++ b/src/printer/code_printer.cpp @@ -24,23 +24,23 @@ CodePrinter::CodePrinter(const std::string& filename) { } sbuf = ofs.rdbuf(); - result = std::make_shared(sbuf); + result = std::make_unique(sbuf); } -void CodePrinter::start_block() { - *result << "{"; +void CodePrinter::push_block() { + *result << '{'; add_newline(); indent_level++; } -void CodePrinter::start_block(std::string&& expression) { +void CodePrinter::push_block(const std::string& expression) { add_indent(); *result << expression << " {"; add_newline(); indent_level++; } -void CodePrinter::restart_block(std::string const& expression) { +void CodePrinter::chain_block(std::string const& expression) { --indent_level; add_indent(); *result << "} " << expression << " {"; @@ -49,41 +49,59 @@ void CodePrinter::restart_block(std::string const& expression) { } void CodePrinter::add_indent() { - *result << std::string(indent_level * NUM_SPACES, ' '); -} - -void CodePrinter::add_text(const std::string& text) { - *result << text; -} - -void CodePrinter::add_line(const std::string& text, int num_new_lines) { - add_indent(); - *result << text; - add_newline(num_new_lines); + for (std::size_t i = 0; i < indent_level * NUM_SPACES; ++i) { + *result << ' '; + } } void CodePrinter::add_multi_line(const std::string& text) { - auto lines = stringutils::split_string(text, '\n'); + const auto& lines = stringutils::split_string(text, '\n'); + + int prefix_length{}; + int start_line{}; + while (start_line < lines.size()) { + const auto& line = lines[start_line]; + // skip first empty line, if any + if (line.empty()) { + ++start_line; + continue; + } + // The common indentation of all blocks if the number of spaces + // at the beginning of the first non-empty line. + for (auto line_it = line.begin(); line_it != line.end() && *line_it == ' '; ++line_it) { + prefix_length += 1; + } + break; + } + for (const auto& line: lines) { - add_line(line); + if (line.size() < prefix_length) { + // ignore lines made of ' ' characters + if (std::find_if_not(line.begin(), line.end(), [](char c) { return c == ' '; }) != + line.end()) { + throw std::invalid_argument("Indentation mismatch"); + } + } else { + add_line(line.substr(prefix_length)); + } } } -void CodePrinter::add_newline(int n) { - for (int i = 0; i < n; i++) { - *result << std::endl; +void CodePrinter::add_newline(std::size_t n) { + for (std::size_t i{}; i < n; ++i) { + *result << '\n'; } } -void CodePrinter::end_block(int num_newlines) { +void CodePrinter::pop_block(int num_newlines) { indent_level--; add_indent(); - *result << "}"; + *result << '}'; add_newline(num_newlines); } -void CodePrinter::end_block(std::string_view suffix, std::size_t num_newlines) { - end_block(0); +void CodePrinter::pop_block(const std::string_view& suffix, std::size_t num_newlines) { + pop_block(0); *result << suffix; add_newline(num_newlines); } diff --git a/src/printer/code_printer.hpp b/src/printer/code_printer.hpp index 051edcc3a6..f99b1c217f 100644 --- a/src/printer/code_printer.hpp +++ b/src/printer/code_printer.hpp @@ -43,16 +43,16 @@ class CodePrinter { private: std::ofstream ofs; std::streambuf* sbuf = nullptr; - std::shared_ptr result; + std::unique_ptr result; size_t indent_level = 0; - const int NUM_SPACES = 4; + const size_t NUM_SPACES = 4; public: CodePrinter() - : result(std::make_shared(std::cout.rdbuf())) {} + : result(std::make_unique(std::cout.rdbuf())) {} CodePrinter(std::ostream& stream) - : result(std::make_shared(stream.rdbuf())) {} + : result(std::make_unique(stream.rdbuf())) {} CodePrinter(const std::string& filename); @@ -64,17 +64,25 @@ class CodePrinter { void add_indent(); /// start a block scope without indentation (i.e. "{\n") - void start_block(); + void push_block(); /// start a block scope with an expression (i.e. "[indent][expression] {\n") - void start_block(std::string&& expression); + void push_block(const std::string& expression); /// end a block and immediately start a new one (i.e. "[indent-1]} [expression] {\n") - void restart_block(std::string const& expression); + void chain_block(std::string const& expression); - void add_text(const std::string&); + template + void add_text(Args&&... args) { + (operator<<(*result, args), ...); + } - void add_line(const std::string&, int num_new_lines = 1); + template + void add_line(Args&&... args) { + add_indent(); + add_text(std::forward(args)...); + add_newline(1); + } /// fmt_line(x, y, z) is just shorthand for add_line(fmt::format(x, y, z)) template @@ -82,10 +90,10 @@ class CodePrinter { add_line(fmt::format(std::forward(args)...)); } - /// fmt_start_block(args...) is just shorthand for start_block(fmt::format(args...)) + /// fmt_push_block(args...) is just shorthand for push_block(fmt::format(args...)) template - void fmt_start_block(Args&&... args) { - start_block(fmt::format(std::forward(args)...)); + void fmt_push_block(Args&&... args) { + push_block(fmt::format(std::forward(args)...)); } /// fmt_text(args...) is just shorthand for add_text(fmt::format(args...)) @@ -96,7 +104,7 @@ class CodePrinter { void add_multi_line(const std::string&); - void add_newline(int n = 1); + void add_newline(std::size_t n = 1); void increase_indent() { indent_level++; @@ -107,10 +115,10 @@ class CodePrinter { } /// end of current block scope (i.e. end with "}") - void end_block(int num_newlines = 0); + void pop_block(int num_newlines = 0); /// end a block with `suffix` before the newline(s) (i.e. [indent]}[suffix]\n*num_newlines) - void end_block(std::string_view suffix, std::size_t num_newlines = 1); + void pop_block(const std::string_view& suffix, std::size_t num_newlines = 1); int indent_spaces() { return NUM_SPACES * indent_level; diff --git a/src/symtab/symbol.hpp b/src/symtab/symbol.hpp index d356abeee6..cd00bf9330 100644 --- a/src/symtab/symbol.hpp +++ b/src/symtab/symbol.hpp @@ -244,14 +244,14 @@ class Symbol { nodes.push_back(node); } - std::vector get_nodes() const noexcept { + const std::vector& get_nodes() const noexcept { return nodes; } std::vector get_nodes_by_type( std::initializer_list l) const noexcept; - ModToken get_token() const noexcept { + const ModToken& get_token() const noexcept { return token; } diff --git a/src/symtab/symbol_properties.cpp b/src/symtab/symbol_properties.cpp index cddd6e56ff..8782332ee2 100644 --- a/src/symtab/symbol_properties.cpp +++ b/src/symtab/symbol_properties.cpp @@ -199,13 +199,11 @@ std::vector to_string_vector(const Status& obj) { } std::ostream& operator<<(std::ostream& os, const NmodlType& obj) { - os << to_string(obj); - return os; + return os << to_string(obj); } std::ostream& operator<<(std::ostream& os, const Status& obj) { - os << to_string(obj); - return os; + return os << to_string(obj); } } // namespace syminfo diff --git a/src/symtab/symbol_table.cpp b/src/symtab/symbol_table.cpp index 37d1f02ec6..4a5e51f774 100644 --- a/src/symtab/symbol_table.cpp +++ b/src/symtab/symbol_table.cpp @@ -28,7 +28,7 @@ int SymbolTable::Table::counter = 0; // NOLINT(cppcoreguidelines-avoid-non-cons * cases where we were getting re-insertion errors. */ void SymbolTable::Table::insert(const std::shared_ptr& symbol) { - std::string name = symbol->get_name(); + const auto& name = symbol->get_name(); if (lookup(name) != nullptr) { throw std::runtime_error("Trying to re-insert symbol " + name); } @@ -54,11 +54,11 @@ SymbolTable::SymbolTable(const SymbolTable& table) , parent{nullptr} {} bool SymbolTable::is_method_defined(const std::string& name) const { - auto symbol = lookup_in_scope(name); + const auto& symbol = lookup_in_scope(name); if (symbol == nullptr) { return false; } - auto nodes = symbol->get_nodes_by_type( + const auto& nodes = symbol->get_nodes_by_type( {AstNodeType::FUNCTION_BLOCK, AstNodeType::PROCEDURE_BLOCK}); return !nodes.empty(); } @@ -76,7 +76,7 @@ std::string SymbolTable::position() const { } -void SymbolTable::insert_table(const std::string& name, std::shared_ptr table) { +void SymbolTable::insert_table(const std::string& name, const std::shared_ptr& table) { if (children.find(name) != children.end()) { throw std::runtime_error("Trying to re-insert SymbolTable " + name); } @@ -106,8 +106,8 @@ std::vector> SymbolTable::get_variables_with_properties( /// return all symbol which has all "with" properties and none of the "without" properties std::vector> SymbolTable::get_variables(NmodlType with, NmodlType without) const { - auto variables = get_variables_with_properties(with, true); - decltype(variables) result; + const auto& variables = get_variables_with_properties(with, true); + std::decay_t result; for (auto& variable: variables) { if (!variable->has_any_property(without)) { result.push_back(variable); @@ -176,18 +176,17 @@ std::shared_ptr ModelSymbolTable::lookup(const std::string& name) { } auto symbol = current_symtab->lookup(name); - if (symbol) { - return symbol; - } - // check into all parent symbol tables - auto parent = current_symtab->get_parent_table(); - while (parent != nullptr) { - symbol = parent->lookup(name); - if (symbol) { - break; + if (!symbol) { + // check into all parent symbol tables + auto parent = current_symtab->get_parent_table(); + while (parent != nullptr) { + symbol = parent->lookup(name); + if (symbol) { + break; + } + parent = parent->get_parent_table(); } - parent = parent->get_parent_table(); } return symbol; } @@ -200,8 +199,8 @@ std::shared_ptr ModelSymbolTable::lookup(const std::string& name) { void ModelSymbolTable::emit_message(const std::shared_ptr& first, const std::shared_ptr& second, bool redefinition) { - auto nodes = first->get_nodes(); - std::string name = first->get_name(); + const auto& nodes = first->get_nodes(); + const auto& name = first->get_name(); auto properties = to_string(second->get_properties()); std::string type = "UNKNOWN"; if (!nodes.empty()) { @@ -247,7 +246,7 @@ std::shared_ptr ModelSymbolTable::update_mode_insert( symbol->set_scope(current_symtab->name()); symbol->mark_created(); - std::string name = symbol->get_name(); + const auto& name = symbol->get_name(); auto search_symbol = lookup(name); /// if no symbol found then safe to insert @@ -271,10 +270,11 @@ std::shared_ptr ModelSymbolTable::update_mode_insert( void ModelSymbolTable::update_order(const std::shared_ptr& present_symbol, const std::shared_ptr& new_symbol) { - auto symbol = (present_symbol != nullptr) ? present_symbol : new_symbol; + const auto& symbol = (present_symbol != nullptr) ? present_symbol : new_symbol; - bool is_parameter = new_symbol->has_any_property(NmodlType::param_assign); - bool is_assigned_definition = new_symbol->has_any_property(NmodlType::assigned_definition); + const bool is_parameter = new_symbol->has_any_property(NmodlType::param_assign); + const bool is_assigned_definition = new_symbol->has_any_property( + NmodlType::assigned_definition); if (symbol->get_definition_order() == -1) { if (is_parameter || is_assigned_definition) { @@ -288,7 +288,7 @@ std::shared_ptr ModelSymbolTable::insert(const std::shared_ptr& throw std::logic_error("Can not insert symbol without entering scope"); } - auto search_symbol = lookup(symbol->get_name()); + const auto& search_symbol = lookup(symbol->get_name()); update_order(search_symbol, symbol); /// handle update mode insertion @@ -396,7 +396,7 @@ SymbolTable* ModelSymbolTable::enter_scope(const std::string& name, } if (node_symtab == nullptr || !update_table) { - auto new_name = get_unique_name(name, node, global); + const auto& new_name = get_unique_name(name, node, global); auto new_symtab = std::make_shared(new_name, node, global); new_symtab->set_parent_table(current_symtab); if (symtab == nullptr) { @@ -421,9 +421,7 @@ void ModelSymbolTable::leave_scope() { if (current_symtab == nullptr) { throw std::logic_error("Trying leave scope without entering"); } - if (current_symtab != nullptr) { - current_symtab = current_symtab->get_parent_table(); - } + current_symtab = current_symtab->get_parent_table(); if (current_symtab == nullptr) { current_symtab = symtab.get(); } @@ -481,13 +479,13 @@ void SymbolTable::Table::print(std::ostream& stream, std::string title, int inde if (symbol->is_array()) { name += "[" + std::to_string(symbol->get_length()) + "]"; } - auto position = symbol->get_token().position(); - auto properties = syminfo::to_string(symbol->get_properties()); - auto status = syminfo::to_string(symbol->get_status()); - auto reads = std::to_string(symbol->get_read_count()); - auto nodes = std::to_string(symbol->get_nodes().size()); + const auto& position = symbol->get_token().position(); + const auto& properties = syminfo::to_string(symbol->get_properties()); + const auto status = syminfo::to_string(symbol->get_status()); + const auto reads = std::to_string(symbol->get_read_count()); + const auto nodes = std::to_string(symbol->get_nodes().size()); std::string value; - auto sym_value = symbol->get_value(); + const auto& sym_value = symbol->get_value(); if (sym_value) { value = std::to_string(*sym_value); } @@ -501,10 +499,10 @@ void SymbolTable::Table::print(std::ostream& stream, std::string title, int inde /// construct title for symbol table std::string SymbolTable::title() const { - auto node_type = node->get_node_type_name(); - auto name = symtab_name + " [" + node_type + " IN " + get_parent_table_name() + "] "; - auto location = "POSITION : " + position(); - auto scope = global ? "GLOBAL" : "LOCAL"; + const auto& node_type = node->get_node_type_name(); + const auto& name = symtab_name + " [" + node_type + " IN " + get_parent_table_name() + "] "; + const auto& location = "POSITION : " + position(); + const auto scope = global ? "GLOBAL" : "LOCAL"; return name + location + " SCOPE : " + scope; } diff --git a/src/symtab/symbol_table.hpp b/src/symtab/symbol_table.hpp index d6316a568c..dcc916c5ef 100644 --- a/src/symtab/symbol_table.hpp +++ b/src/symtab/symbol_table.hpp @@ -84,7 +84,7 @@ class SymbolTable { }; /// name of the block - std::string symtab_name; + const std::string symtab_name; /// table holding all symbols in the current block Table table; @@ -111,7 +111,7 @@ class SymbolTable { /// \{ SymbolTable(std::string name, ast::Ast* node, bool global = false) - : symtab_name(name) + : symtab_name(std::move(name)) , node(node) , global(global) {} @@ -123,7 +123,7 @@ class SymbolTable { /// \name Getter /// \{ - SymbolTable* get_parent_table() const { + SymbolTable* get_parent_table() const noexcept { return parent; } @@ -167,7 +167,7 @@ class SymbolTable { return s.str(); } - std::string name() const { + const std::string& name() const noexcept { return symtab_name; } @@ -175,7 +175,7 @@ class SymbolTable { return global; } - void insert(std::shared_ptr symbol) { + void insert(const std::shared_ptr& symbol) { table.insert(symbol); } @@ -207,7 +207,7 @@ class SymbolTable { bool under_global_scope(); /// insert new symbol table as one of the children block - void insert_table(const std::string& name, std::shared_ptr table); + void insert_table(const std::string& name, const std::shared_ptr& table); void print(std::ostream& ss, int level) const; diff --git a/src/utils/string_utils.hpp b/src/utils/string_utils.hpp index c11ab18b20..749ee84c0f 100644 --- a/src/utils/string_utils.hpp +++ b/src/utils/string_utils.hpp @@ -90,7 +90,6 @@ enum class text_alignment { left, right, center }; * \return a copy of the given string with every " and \ characters prefixed with an extra \ */ [[nodiscard]] static inline std::string escape_quotes(const std::string& text) { - std::ostringstream oss; std::string result; for (auto c: text) {