Skip to content

Commit

Permalink
First proposal of improvements on C++ generator
Browse files Browse the repository at this point in the history
* reduce number of copies
* `CodePrinter::add_text`:
  * now supports multiple arguments
  * use it instead of trivial but expensive call to `fmt::format`.
* `CodePrinter::add_multi_line`:
  * rework to work with C++ raw string literals
  * used it instead of 3 or more consecutive `add_line`
* `CodePrinter::start_block`: rename to `push_block`
* `CodePrinter::end_block`: rename to pop_block
  • Loading branch information
tristan0x committed Sep 19, 2023
1 parent 441a3b1 commit 0e59bd4
Show file tree
Hide file tree
Showing 17 changed files with 735 additions and 650 deletions.
102 changes: 54 additions & 48 deletions src/codegen/codegen_acc_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,19 @@ void CodegenAccVisitor::print_memory_allocation_routine() const {
}
printer->add_newline(2);
auto args = "size_t num, size_t size, size_t alignment = 16";
printer->fmt_start_block("static inline void* mem_alloc({})", args);
printer->add_line("void* ptr;");
printer->add_line("cudaMallocManaged(&ptr, num*size);");
printer->add_line("cudaMemset(ptr, 0, num*size);");
printer->add_line("return ptr;");
printer->end_block(1);
printer->fmt_push_block("static inline void* mem_alloc({})", args);
printer->add_multi_line(R"CODE(
void* ptr;
cudaMallocManaged(&ptr, num*size);
cudaMemset(ptr, 0, num*size);
return ptr;
)CODE");
printer->pop_block(1);

printer->add_newline(2);
printer->start_block("static inline void mem_free(void* ptr)");
printer->push_block("static inline void mem_free(void* ptr)");
printer->add_line("cudaFree(ptr);");
printer->end_block(1);
printer->pop_block(1);
}

/**
Expand All @@ -114,23 +116,23 @@ void CodegenAccVisitor::print_memory_allocation_routine() const {
*/
void CodegenAccVisitor::print_abort_routine() const {
printer->add_newline(2);
printer->start_block("static inline void coreneuron_abort()");
printer->add_line("printf(\"Error : Issue while running OpenACC kernel \\n\");");
printer->push_block("static inline void coreneuron_abort()");
printer->add_line(R"(printf("Error : Issue while running OpenACC kernel \n");)");
printer->add_line("assert(0==1);");
printer->end_block(1);
printer->pop_block(1);
}

void CodegenAccVisitor::print_net_send_buffering_cnt_update() const {
printer->fmt_start_block("if (nt->compute_gpu)");
printer->fmt_push_block("if (nt->compute_gpu)");
print_device_atomic_capture_annotation();
printer->add_line("i = nsb->_cnt++;");
printer->restart_block("else");
printer->chain_block("else");
printer->add_line("i = nsb->_cnt++;");
printer->end_block(1);
printer->pop_block(1);
}

void CodegenAccVisitor::print_net_send_buffering_grow() {
// can not grow buffer during gpu execution
// no-op since can not grow buffer during gpu execution
}

/**
Expand Down Expand Up @@ -174,7 +176,7 @@ void CodegenAccVisitor::print_net_init_acc_serial_annotation_block_begin() {

void CodegenAccVisitor::print_net_init_acc_serial_annotation_block_end() {
if (!info.artificial_cell) {
printer->end_block(1);
printer->pop_block(1);

Check warning on line 179 in src/codegen/codegen_acc_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_acc_visitor.cpp#L179

Added line #L179 was not covered by tests
}
}

Expand All @@ -198,7 +200,7 @@ void CodegenAccVisitor::print_fast_imem_calculation() {

auto rhs_op = operator_for_rhs();
auto d_op = operator_for_d();
printer->start_block("if (nt->nrn_fast_imem)");
printer->push_block("if (nt->nrn_fast_imem)");

Check warning on line 203 in src/codegen/codegen_acc_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_acc_visitor.cpp#L203

Added line #L203 was not covered by tests
if (info.point_process) {
print_atomic_reduction_pragma();
}
Expand All @@ -207,7 +209,7 @@ void CodegenAccVisitor::print_fast_imem_calculation() {
print_atomic_reduction_pragma();
}
printer->fmt_line("nt->nrn_fast_imem->nrn_sav_d[node_id] {} g;", d_op);
printer->end_block(1);
printer->pop_block(1);

Check warning on line 212 in src/codegen/codegen_acc_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_acc_visitor.cpp#L212

Added line #L212 was not covered by tests
}

void CodegenAccVisitor::print_nrn_cur_matrix_shadow_reduction() {
Expand All @@ -220,7 +222,7 @@ void CodegenAccVisitor::print_nrn_cur_matrix_shadow_reduction() {
*/
void CodegenAccVisitor::print_kernel_data_present_annotation_block_end() {
if (!info.artificial_cell) {
printer->end_block(1);
printer->pop_block(1);
}
}

Expand All @@ -237,25 +239,27 @@ bool CodegenAccVisitor::nrn_cur_reduction_loop_required() {

void CodegenAccVisitor::print_global_variable_device_update_annotation() {
if (!info.artificial_cell) {
printer->start_block("if (nt->compute_gpu)");
printer->push_block("if (nt->compute_gpu)");
printer->fmt_line("nrn_pragma_acc(update device ({}))", global_struct_instance());
printer->fmt_line("nrn_pragma_omp(target update to({}))", global_struct_instance());
printer->end_block(1);
printer->pop_block(1);
}
}


void CodegenAccVisitor::print_newtonspace_transfer_to_device() const {
int list_num = info.derivimplicit_list_num;
printer->start_block("if(nt->compute_gpu)");
printer->add_line("double* device_vec = cnrn_target_copyin(vec, vec_size / sizeof(double));");
printer->add_line("void* device_ns = cnrn_target_deviceptr(*ns);");
printer->add_line("ThreadDatum* device_thread = cnrn_target_deviceptr(thread);");
printer->push_block("if(nt->compute_gpu)");
printer->add_multi_line(R"CODE(

Check warning on line 253 in src/codegen/codegen_acc_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_acc_visitor.cpp#L252-L253

Added lines #L252 - L253 were not covered by tests
double* device_vec = cnrn_target_copyin(vec, vec_size / sizeof(double));
void* device_ns = cnrn_target_deviceptr(*ns);
ThreadDatum* device_thread = cnrn_target_deviceptr(thread);
)CODE");
printer->fmt_line("cnrn_target_memcpy_to_device(&(device_thread[{}]._pvoid), &device_ns);",
info.thread_data_index - 1);
printer->fmt_line("cnrn_target_memcpy_to_device(&(device_thread[dith{}()].pval), &device_vec);",
list_num);
printer->end_block(1);
printer->pop_block(1);

Check warning on line 262 in src/codegen/codegen_acc_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_acc_visitor.cpp#L262

Added line #L262 was not covered by tests
}


Expand All @@ -276,32 +280,34 @@ void CodegenAccVisitor::print_instance_struct_transfer_routines(
if (info.artificial_cell) {
return;
}
printer->fmt_start_block(
printer->fmt_push_block(
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst)",
instance_struct());
printer->start_block("if (!nt->compute_gpu)");
printer->push_block("if (!nt->compute_gpu)");
printer->add_line("return;");
printer->end_block(1);
printer->pop_block(1);
printer->fmt_line("auto tmp = *inst;");
printer->add_line("auto* d_inst = cnrn_target_is_present(inst);");
printer->start_block("if (!d_inst)");
printer->push_block("if (!d_inst)");
printer->add_line("d_inst = cnrn_target_copyin(inst);");
printer->end_block(1);
printer->pop_block(1);
for (auto const& ptr_mem: ptr_members) {
printer->fmt_line("tmp.{0} = cnrn_target_deviceptr(tmp.{0});", ptr_mem);
}
printer->add_line("cnrn_target_memcpy_to_device(d_inst, &tmp);");
printer->add_line("auto* d_ml = cnrn_target_deviceptr(ml);");
printer->add_line("void* d_inst_void = d_inst;");
printer->add_line("cnrn_target_memcpy_to_device(&(d_ml->instance), &d_inst_void);");
printer->end_block(2); // copy_instance_to_device

printer->fmt_start_block("static inline void delete_instance_from_device({}* inst)",
instance_struct());
printer->start_block("if (cnrn_target_is_present(inst))");
printer->add_multi_line(R"CODE(
cnrn_target_memcpy_to_device(d_inst, &tmp);
auto* d_ml = cnrn_target_deviceptr(ml);
void* d_inst_void = d_inst;
cnrn_target_memcpy_to_device(&(d_ml->instance), &d_inst_void);
)CODE");
printer->pop_block(2); // copy_instance_to_device

printer->fmt_push_block("static inline void delete_instance_from_device({}* inst)",
instance_struct());
printer->push_block("if (cnrn_target_is_present(inst))");
printer->add_line("cnrn_target_delete(inst);");
printer->end_block(1);
printer->end_block(2); // delete_instance_from_device
printer->pop_block(1);
printer->pop_block(2); // delete_instance_from_device
}


Expand Down Expand Up @@ -334,9 +340,9 @@ void CodegenAccVisitor::print_device_atomic_capture_annotation() const {


void CodegenAccVisitor::print_device_stream_wait() const {
printer->start_block("if(nt->compute_gpu)");
printer->push_block("if(nt->compute_gpu)");
printer->add_line("nrn_pragma_acc(wait(nt->stream_id))");
printer->end_block(1);
printer->pop_block(1);
}


Expand All @@ -348,18 +354,18 @@ void CodegenAccVisitor::print_net_send_buf_count_update_to_host() const {

void CodegenAccVisitor::print_net_send_buf_update_to_host() const {
print_device_stream_wait();
printer->start_block("if (nsb && nt->compute_gpu)");
printer->push_block("if (nsb && nt->compute_gpu)");
print_net_send_buf_count_update_to_host();
printer->add_line("update_net_send_buffer_on_host(nt, nsb);");
printer->end_block(1);
printer->pop_block(1);
}


void CodegenAccVisitor::print_net_send_buf_count_update_to_device() const {
printer->start_block("if (nt->compute_gpu)");
printer->push_block("if (nt->compute_gpu)");
printer->add_line("nrn_pragma_acc(update device(nsb->_cnt))");
printer->add_line("nrn_pragma_omp(target update to(nsb->_cnt))");
printer->end_block(1);
printer->pop_block(1);
}


Expand Down
25 changes: 14 additions & 11 deletions src/codegen/codegen_acc_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace nmodl {
namespace codegen {

/**
* @addtogroup codegen_backends
* @{
* \addtogroup codegen_backends
* \{
*/

/**
Expand Down Expand Up @@ -97,38 +97,41 @@ class CodegenAccVisitor: public CodegenCppVisitor {
void print_instance_struct_transfer_routine_declarations() override;

/// define helper functions for copying the instance struct to the device
void print_instance_struct_transfer_routines(std::vector<std::string> const&) override;
void print_instance_struct_transfer_routines(
const std::vector<std::string>& ptr_members) override;

/// call helper function for copying the instance struct to the device
void print_instance_struct_copy_to_device() override;

/// call helper function that deletes the instance struct from the device
void print_instance_struct_delete_from_device() override;

// update derivimplicit advance flag on the gpu device
/// update derivimplicit advance flag on the gpu device
void print_deriv_advance_flag_transfer_to_device() const override;

// update NetSendBuffer_t count from device to host
/// update NetSendBuffer_t count from device to host
void print_net_send_buf_count_update_to_host() const override;

// update NetSendBuffer_t from device to host
/// update NetSendBuffer_t from device to host
void print_net_send_buf_update_to_host() const override;

// update NetSendBuffer_t count from host to device
/// update NetSendBuffer_t count from host to device
virtual void print_net_send_buf_count_update_to_device() const override;

// update dt from host to device
/// update dt from host to device
virtual void print_dt_update_to_device() const override;

// synchronise/wait on stream specific to NrnThread
virtual void print_device_stream_wait() const override;

// print atomic capture pragma
/// print atomic capture pragma
void print_device_atomic_capture_annotation() const override;

// print atomic update of NetSendBuffer_t cnt
/// print atomic update of NetSendBuffer_t cnt
void print_net_send_buffering_cnt_update() const override;

/// Replace default implementation by a no-op
/// since the buffer cannot be grown up during gpu execution
void print_net_send_buffering_grow() override;


Expand All @@ -146,7 +149,7 @@ class CodegenAccVisitor: public CodegenCppVisitor {
: CodegenCppVisitor(mod_file, stream, float_type, optimize_ionvar_copies) {}
};

/** @} */ // end of codegen_backends
/** \} */ // end of codegen_backends

} // namespace codegen
} // namespace nmodl
18 changes: 9 additions & 9 deletions src/codegen/codegen_compatibility_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const std::map<ast::AstNodeType, CodegenCompatibilityVisitor::FunctionPointer>

std::string CodegenCompatibilityVisitor::return_error_if_solve_method_is_unhandled(
ast::Ast& /* node */,
const std::shared_ptr<ast::Ast>& ast_node) {
const std::shared_ptr<ast::Ast>& ast_node) const {
auto solve_block_ast_node = std::dynamic_pointer_cast<ast::SolveBlock>(ast_node);
std::stringstream unhandled_method_error_message;
auto method = solve_block_ast_node->get_method();
Expand All @@ -53,7 +53,7 @@ std::string CodegenCompatibilityVisitor::return_error_if_solve_method_is_unhandl
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
std::string CodegenCompatibilityVisitor::return_error_global_var(
ast::Ast& node,
const std::shared_ptr<ast::Ast>& ast_node) {
const std::shared_ptr<ast::Ast>& ast_node) const {
auto global_var = std::dynamic_pointer_cast<ast::GlobalVar>(ast_node);
std::stringstream error_message_global_var;
if (node.get_symbol_table()->lookup(global_var->get_node_name())->get_write_count() > 0) {
Expand All @@ -69,7 +69,7 @@ std::string CodegenCompatibilityVisitor::return_error_global_var(
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
std::string CodegenCompatibilityVisitor::return_error_param_var(
ast::Ast& node,
const std::shared_ptr<ast::Ast>& ast_node) {
const std::shared_ptr<ast::Ast>& ast_node) const {
auto param_assign = std::dynamic_pointer_cast<ast::ParamAssign>(ast_node);
std::stringstream error_message_global_var;
auto symbol = node.get_symbol_table()->lookup(param_assign->get_node_name());
Expand All @@ -85,7 +85,7 @@ std::string CodegenCompatibilityVisitor::return_error_param_var(
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
std::string CodegenCompatibilityVisitor::return_error_if_no_bbcore_read_write(
ast::Ast& node,
const std::shared_ptr<ast::Ast>& /* ast_node */) {
const std::shared_ptr<ast::Ast>& /* ast_node */) const {
std::stringstream error_message_no_bbcore_read_write;
const auto& verbatim_nodes = collect_nodes(node, {AstNodeType::VERBATIM});
auto found_bbcore_read = false;
Expand Down Expand Up @@ -129,15 +129,15 @@ std::string CodegenCompatibilityVisitor::return_error_if_no_bbcore_read_write(
* some kind of incompatibility return false.
*/

bool CodegenCompatibilityVisitor::find_unhandled_ast_nodes(Ast& node) {
bool CodegenCompatibilityVisitor::find_unhandled_ast_nodes(Ast& node) const {
std::vector<ast::AstNodeType> unhandled_ast_types;
unhandled_ast_types.reserve(unhandled_ast_types_func.size());
for (auto kv: unhandled_ast_types_func) {
unhandled_ast_types.push_back(kv.first);
for (auto [node_type, _]: unhandled_ast_types_func) {
unhandled_ast_types.push_back(node_type);
}
unhandled_ast_nodes = collect_nodes(node, unhandled_ast_types);
const auto& unhandled_ast_nodes = collect_nodes(node, unhandled_ast_types);

std::stringstream ss;
std::ostringstream ss;
for (const auto& it: unhandled_ast_nodes) {
auto node_type = it->get_node_type();
ss << (this->*unhandled_ast_types_func.find(node_type)->second)(node, it);
Expand Down
Loading

0 comments on commit 0e59bd4

Please sign in to comment.