diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp index c395bb7dc..ed7abe674 100644 --- a/src/codegen/codegen_cpp_visitor.cpp +++ b/src/codegen/codegen_cpp_visitor.cpp @@ -391,8 +391,15 @@ std::string CodegenCppVisitor::breakpoint_current(std::string current) const { * for(int id = 0; id < nodecount; id++) { * \endcode */ -void CodegenCppVisitor::print_parallel_iteration_hint(BlockType /* type */, - const ast::Block* block) { +void CodegenCppVisitor::print_parallel_iteration_hint(BlockType type, const ast::Block* block) { + if (parallel_iteration_condition(type, block)) { + printer->add_line("#pragma omp simd"); + printer->add_line("#pragma ivdep"); + } +} + +bool CodegenCppVisitor::parallel_iteration_condition(BlockType /* type */, + const ast::Block* block) { // ivdep allows SIMD parallelisation of a block/loop but doesn't provide // a standard mechanism for atomics. Also, even with openmp 5.0, openmp // atomics do not enable vectorisation under "omp simd" (gives compiler @@ -406,10 +413,8 @@ void CodegenCppVisitor::print_parallel_iteration_hint(BlockType /* type */, ast::AstNodeType::MUTEX_LOCK, ast::AstNodeType::MUTEX_UNLOCK}); } - if (nodes.empty()) { - printer->add_line("#pragma omp simd"); - printer->add_line("#pragma ivdep"); - } + + return nodes.empty(); } diff --git a/src/codegen/codegen_cpp_visitor.hpp b/src/codegen/codegen_cpp_visitor.hpp index 1b2eda1e9..7a1351d04 100644 --- a/src/codegen/codegen_cpp_visitor.hpp +++ b/src/codegen/codegen_cpp_visitor.hpp @@ -797,6 +797,10 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor { virtual void print_parallel_iteration_hint(BlockType type, const ast::Block* block); + /** Return true if the loop emitted for \c block may be annotated with `omp simd` / `ivdep`.
*/ + virtual bool parallel_iteration_condition(BlockType type, const ast::Block* block); + + /****************************************************************************************/ /* Backend specific routines */ /****************************************************************************************/ diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index e4aec03b5..5ba23cabe 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -76,6 +76,12 @@ int CodegenNeuronCppVisitor::position_of_int_var(const std::string& name) const } +bool CodegenNeuronCppVisitor::parallel_iteration_condition(BlockType type, + const ast::Block* block) { + return info.thread_safe && CodegenCppVisitor::parallel_iteration_condition(type, block); +} + + /****************************************************************************************/ /* Backend specific routines */ /****************************************************************************************/ @@ -2041,6 +2047,7 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) { print_global_function_common_code(BlockType::Initial); + print_parallel_iteration_hint(BlockType::Initial, info.initial_node); printer->push_block("for (int id = 0; id < nodecount; id++)"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); @@ -2084,6 +2091,8 @@ void CodegenNeuronCppVisitor::print_nrn_jacob() { print_entrypoint_setup_code_from_memb_list(); printer->fmt_line("auto nodecount = _ml_arg->nodecount;"); + + print_parallel_iteration_hint(BlockType::Equation, nullptr); printer->push_block("for (int id = 0; id < nodecount; id++)"); // begin for if (breakpoint_exist()) { @@ -2291,6 +2300,7 @@ void CodegenNeuronCppVisitor::print_nrn_state() { printer->add_newline(2); print_global_function_common_code(BlockType::State); + print_parallel_iteration_hint(BlockType::State, info.nrn_state_block); printer->push_block("for (int id = 
0; id < nodecount; id++)"); printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto* _ppvar = _ml_arg->pdata[id];"); @@ -2504,7 +2514,7 @@ void CodegenNeuronCppVisitor::print_nrn_cur() { printer->add_newline(2); printer->add_line("/** update current */"); print_global_function_common_code(BlockType::Equation); - // print_channel_iteration_block_parallel_hint(BlockType::Equation, info.breakpoint_node); + print_parallel_iteration_hint(BlockType::Equation, info.breakpoint_node); printer->push_block("for (int id = 0; id < nodecount; id++)"); print_nrn_cur_kernel(*info.breakpoint_node); // print_nrn_cur_matrix_shadow_update(); diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index c0aa55f83..e1fc9fc28 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -149,6 +149,9 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { int position_of_int_var(const std::string& name) const override; + bool parallel_iteration_condition(BlockType type, const ast::Block* block) override; + + /****************************************************************************************/ /* Backend specific routines */ /****************************************************************************************/