Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Omar Awile committed Sep 20, 2023
1 parent 6d7f7e2 commit 9fb4266
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 16 deletions.
13 changes: 7 additions & 6 deletions src/codegen/codegen_helper_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,12 +774,13 @@ void CodegenHelperVisitor::visit_after_block(const ast::AfterBlock& node) {
info.before_after_blocks.push_back(&node);
}

void CodegenHelperVisitor::visit_random(const ast::Random& node) {
auto pdf = Distribution(node.get_distribution()->get_node_name(), node.get_distribution_params());
for (const auto& r : node.get_variables()) {
auto sym = psymtab->lookup(r->get_node_name());
info.random_vars.emplace(sym, pdf);
}
void CodegenHelperVisitor::visit_random(const ast::Random& node) {
auto pdf = Distribution(node.get_distribution()->get_node_name(),
node.get_distribution_params());
for (const auto& r: node.get_variables()) {
auto sym = psymtab->lookup(r->get_node_name());
info.random_vars.emplace(sym, pdf);
}
}

} // namespace codegen
Expand Down
8 changes: 5 additions & 3 deletions src/visitors/check_random_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ void CheckRandomStatementVisitor::visit_random(const ast::Random& node) {
auto& params = node.get_distribution_params();
if (distributions.find(distribution_name) != distributions.end()) {
if (distributions.at(distribution_name) != params.size()) {
throw std::logic_error("Validation Error: {} declared with {} instead of {} parameters"_format(distribution_name, params.size()
, distributions.at(distribution_name)));
throw std::logic_error(
"Validation Error: {} declared with {} instead of {} parameters"_format(
distribution_name, params.size(), distributions.at(distribution_name)));
}
} else {
throw std::logic_error("Validation Error: distribution {} unknown"_format(distribution_name));
throw std::logic_error(
"Validation Error: distribution {} unknown"_format(distribution_name));
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/visitors/check_random_statement_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace visitor {
* \{
*/

class CheckRandomStatementVisitor: protected ConstAstVisitor {
class CheckRandomStatementVisitor: protected ConstAstVisitor {
private:
void visit_random(const ast::Random& node) override;

Expand All @@ -45,5 +45,5 @@ class CheckRandomStatementVisitor: protected ConstAstVisitor {
* \}
*/

} // namespace visitor
} // namespace nmodl
} // namespace visitor
} // namespace nmodl
13 changes: 9 additions & 4 deletions test/unit/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,16 @@ SCENARIO("NEURON block can add RANDOM variable", "[parser][random]") {
nmodl::parser::NmodlDriver driver;
auto ast = driver.parse_string(construct);
nmodl::visitor::CheckRandomStatementVisitor().visit_program(*ast);
const auto& random_statements = nmodl::collect_nodes(*ast, {nmodl::ast::AstNodeType::RANDOM});
const auto& random_statements = nmodl::collect_nodes(*ast,
{nmodl::ast::AstNodeType::RANDOM});

REQUIRE(random_statements.size() == 1);
REQUIRE(static_cast<nmodl::ast::Random*>(random_statements[0].get())->get_distribution()->get_node_name() == "UNIFORM");
REQUIRE(static_cast<nmodl::ast::Random*>(random_statements[0].get())->get_distribution_params().size() == 2);
REQUIRE(static_cast<nmodl::ast::Random*>(random_statements[0].get())
->get_distribution()
->get_node_name() == "UNIFORM");
REQUIRE(static_cast<nmodl::ast::Random*>(random_statements[0].get())
->get_distribution_params()
.size() == 2);
}
}
GIVEN("Incomplete RANDOM variable declaration") {
Expand All @@ -312,7 +317,7 @@ SCENARIO("NEURON block can add RANDOM variable", "[parser][random]") {
nmodl::parser::NmodlDriver driver;
auto ast = driver.parse_string(construct);
REQUIRE_THROWS_WITH(nmodl::visitor::CheckRandomStatementVisitor().visit_program(
static_cast<const nmodl::ast::Program&>(*ast)),
static_cast<const nmodl::ast::Program&>(*ast)),
Catch::Contains("Validation Error"));
}
}
Expand Down

0 comments on commit 9fb4266

Please sign in to comment.