Skip to content

Commit

Permalink
Setting use_range_ptr_var NmodlType in SymTab (#1139)
Browse files Browse the repository at this point in the history
Created FunctionCallpathVisitor that traverses the function calls
and sets `use_range_ptr_var` property
  • Loading branch information
iomaganaris authored Feb 9, 2024
1 parent d70c51b commit d1290e0
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "visitors/after_cvode_to_cnexp_visitor.hpp"
#include "visitors/ast_visitor.hpp"
#include "visitors/constant_folder_visitor.hpp"
#include "visitors/function_callpath_visitor.hpp"
#include "visitors/global_var_visitor.hpp"
#include "visitors/implicit_argument_visitor.hpp"
#include "visitors/indexedname_visitor.hpp"
Expand Down Expand Up @@ -528,6 +529,12 @@ int main(int argc, const char* argv[]) {
SymtabVisitor(update_symtab).visit_program(*ast);
}

{
FunctionCallpathVisitor{}.visit_program(*ast);
ast_to_nmodl(*ast, filepath("FunctionCallpathVisitor"));
SymtabVisitor(update_symtab).visit_program(*ast);
}

{
if (coreneuron_code && oacc_backend) {
logger->info("Running OpenACC backend code generator for CoreNEURON");
Expand Down
4 changes: 4 additions & 0 deletions src/symtab/symbol_properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ std::vector<std::string> to_string_vector(const NmodlType& obj) {
properties.emplace_back("codegen_var");
}

if (has_property(obj, NmodlType::use_range_ptr_var)) {
properties.emplace_back("use_range_ptr_var");
}

if (has_property(obj, NmodlType::random_var)) {
properties.emplace_back("random_var");
}
Expand Down
5 changes: 4 additions & 1 deletion src/symtab/symbol_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ enum class NmodlType : enum_type {
codegen_var = 1LL << 33,

/// Randomvar Type
random_var = 1LL << 34
random_var = 1LL << 34,

/// FUNCTION or PROCEDURE needs setdata check
use_range_ptr_var = 1LL << 35
};

template <typename T>
Expand Down
1 change: 1 addition & 0 deletions src/visitors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_library(
after_cvode_to_cnexp_visitor.cpp
constant_folder_visitor.cpp
defuse_analyze_visitor.cpp
function_callpath_visitor.cpp
global_var_visitor.cpp
implicit_argument_visitor.cpp
indexedname_visitor.cpp
Expand Down
95 changes: 95 additions & 0 deletions src/visitors/function_callpath_visitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@

/*
* Copyright 2024 Blue Brain Project, EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include "visitors/function_callpath_visitor.hpp"

namespace nmodl {
namespace visitor {

using symtab::Symbol;
using symtab::syminfo::NmodlType;

void FunctionCallpathVisitor::visit_var_name(const ast::VarName& node) {
if (visited_functions_or_procedures.empty()) {
return;
}
/// If node is either a RANGE var, a POINTER or a BBCOREPOINTER then
/// the FUNCTION or PROCEDURE it's used in should have the `use_range_ptr_var`
/// property
auto sym = psymtab->lookup(node.get_node_name());
const auto properties = NmodlType::range_var | NmodlType::pointer_var |
NmodlType::bbcore_pointer_var;
if (sym && sym->has_any_property(properties)) {
const auto top = visited_functions_or_procedures.back();
const auto caller_func_name =
top->is_function_block()
? dynamic_cast<const ast::FunctionBlock*>(top)->get_node_name()
: dynamic_cast<const ast::ProcedureBlock*>(top)->get_node_name();
auto caller_func_proc_sym = psymtab->lookup(caller_func_name);
caller_func_proc_sym->add_properties(NmodlType::use_range_ptr_var);
}
}

void FunctionCallpathVisitor::visit_function_call(const ast::FunctionCall& node) {
if (visited_functions_or_procedures.empty()) {
return;
}
const auto name = node.get_node_name();
const auto func_symbol = psymtab->lookup(name);
if (!func_symbol ||
!func_symbol->has_any_property(NmodlType::function_block | NmodlType::procedure_block) ||
func_symbol->get_nodes().empty()) {
return;
}
/// Visit the called FUNCTION/PROCEDURE AST node to check whether
/// it has `use_range_ptr_var` property. If it does the currently called
/// function needs to have it too.
const auto func_block = func_symbol->get_nodes()[0];
func_block->accept(*this);
if (func_symbol->has_any_property(NmodlType::use_range_ptr_var)) {
const auto top = visited_functions_or_procedures.back();
auto caller_func_name =
top->is_function_block()
? dynamic_cast<const ast::FunctionBlock*>(top)->get_node_name()
: dynamic_cast<const ast::ProcedureBlock*>(top)->get_node_name();
auto caller_func_proc_sym = psymtab->lookup(caller_func_name);
caller_func_proc_sym->add_properties(NmodlType::use_range_ptr_var);
}
}

void FunctionCallpathVisitor::visit_procedure_block(const ast::ProcedureBlock& node) {
/// Avoid recursive calls
if (std::find(visited_functions_or_procedures.begin(),
visited_functions_or_procedures.end(),
&node) != visited_functions_or_procedures.end()) {
return;
}
visited_functions_or_procedures.push_back(&node);
node.visit_children(*this);
visited_functions_or_procedures.pop_back();
}

void FunctionCallpathVisitor::visit_function_block(const ast::FunctionBlock& node) {
// Avoid recursive calls
if (std::find(visited_functions_or_procedures.begin(),
visited_functions_or_procedures.end(),
&node) != visited_functions_or_procedures.end()) {
return;
}
visited_functions_or_procedures.push_back(&node);
node.visit_children(*this);
visited_functions_or_procedures.pop_back();
}

void FunctionCallpathVisitor::visit_program(const ast::Program& node) {
psymtab = node.get_symbol_table();
node.visit_children(*this);
}

} // namespace visitor
} // namespace nmodl
65 changes: 65 additions & 0 deletions src/visitors/function_callpath_visitor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2024 Blue Brain Project, EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

/**
* \file
* \brief \copybrief nmodl::visitor::FunctionCallpathVisitor
*/

#include "ast/all.hpp"
#include "symtab/decl.hpp"
#include "visitors/ast_visitor.hpp"

namespace nmodl {
namespace visitor {

/**
* \addtogroup visitor_classes
* \{
*/

/**
* \class FunctionCallpathVisitor
* \brief %Visitor for traversing \c FunctionBlock s and \c ProcedureBlocks through
* their \c FunctionCall s
*
* This visitor is used to traverse the \c FUNCTION s and \c PROCEDURE s in the NMODL files.
* It visits the \c FunctionBlock s and \c ProcedureBlock s and if there is a \c FunctionCall
* in those, it visits the \c FunctionBlock or \c ProcedureBlock of the \c FunctionCall.
* Currently it only checks whether in this path of function calls there is any use of \c RANGE ,
* \c POINTER or \c BBCOREPOINTER variable. In case there is it adds the \c use_range_ptr_var
* property in the \c Symbol of the function or procedure in the program \c SymbolTable and does the
* same recursively for all the caller functions. The \c use_range_ptr_var property is used later in
* the \c CodegenNeuronCppVisitor .
*
*/
class FunctionCallpathVisitor: public ConstAstVisitor {
private:
/// Vector of currently visited functions or procedures (used as a searchable stack)
std::vector<const ast::Block*> visited_functions_or_procedures;

/// symbol table for the program
symtab::SymbolTable* psymtab = nullptr;

public:
void visit_var_name(const ast::VarName& node) override;

void visit_function_call(const ast::FunctionCall& node) override;

void visit_function_block(const ast::FunctionBlock& node) override;

void visit_procedure_block(const ast::ProcedureBlock& node) override;

void visit_program(const ast::Program& node) override;
};

/** \} */ // end of visitor_classes

} // namespace visitor
} // namespace nmodl
50 changes: 48 additions & 2 deletions test/unit/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "codegen/codegen_neuron_cpp_visitor.hpp"
#include "parser/nmodl_driver.hpp"
#include "test/unit/utils/test_utils.hpp"
#include "visitors/function_callpath_visitor.hpp"
#include "visitors/inline_visitor.hpp"
#include "visitors/neuron_solve_visitor.hpp"
#include "visitors/solve_block_visitor.hpp"
Expand All @@ -25,6 +26,7 @@ using namespace codegen;

using nmodl::parser::NmodlDriver;
using nmodl::test_utils::reindent_text;
using symtab::syminfo::NmodlType;

/// Helper for creating C codegen visitor
std::shared_ptr<CodegenNeuronCppVisitor> create_neuron_cpp_visitor(
Expand All @@ -38,6 +40,7 @@ std::shared_ptr<CodegenNeuronCppVisitor> create_neuron_cpp_visitor(
InlineVisitor().visit_program(*ast);
NeuronSolveVisitor().visit_program(*ast);
SolveBlockVisitor().visit_program(*ast);
FunctionCallpathVisitor().visit_program(*ast);

/// create C code generation visitor
auto cv = std::make_shared<CodegenNeuronCppVisitor>("_test", ss, "double", false);
Expand All @@ -47,8 +50,7 @@ std::shared_ptr<CodegenNeuronCppVisitor> create_neuron_cpp_visitor(


/// print entire code
std::string get_neuron_cpp_code(const std::string& nmodl_text,
const bool generate_gpu_code = false) {
std::string get_neuron_cpp_code(const std::string& nmodl_text) {
const auto& ast = NmodlDriver().parse_string(nmodl_text);
std::stringstream ss;
auto cvisitor = create_neuron_cpp_visitor(ast, nmodl_text, ss);
Expand Down Expand Up @@ -276,3 +278,47 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
}
}
}


SCENARIO("Check whether PROCEDURE and FUNCTION need setdata call", "[codegen][needsetdata]") {
GIVEN("mod file with GLOBAL and RANGE variables used in FUNC and PROC") {
std::string input_nmodl = R"(
NEURON {
SUFFIX test
RANGE x
GLOBAL s
}
PARAMETER {
s = 2
}
ASSIGNED {
x
}
PROCEDURE a() {
x = get_42()
}
FUNCTION b() {
a()
}
FUNCTION get_42() {
get_42 = 42
}
)";
const auto& ast = NmodlDriver().parse_string(input_nmodl);
std::stringstream ss;
auto cvisitor = create_neuron_cpp_visitor(ast, input_nmodl, ss);
cvisitor->visit_program(*ast);
const auto symtab = ast->get_symbol_table();
THEN("use_range_ptr_var property is added to needed FUNC and PROC") {
auto use_range_ptr_var_funcs = symtab->get_variables_with_properties(
NmodlType::use_range_ptr_var);
REQUIRE(use_range_ptr_var_funcs.size() == 2);
const auto a = symtab->lookup("a");
REQUIRE(a->has_any_property(NmodlType::use_range_ptr_var));
const auto b = symtab->lookup("b");
REQUIRE(b->has_any_property(NmodlType::use_range_ptr_var));
const auto get_42 = symtab->lookup("get_42");
REQUIRE(!get_42->has_any_property(NmodlType::use_range_ptr_var));
}
}
}

0 comments on commit d1290e0

Please sign in to comment.