From 167d38f03aa0ea90cd81d7da97ec950a5ca8aa95 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 24 Oct 2024 13:43:24 +0200 Subject: [PATCH] Address comments from review --- src/visitors/cvode_visitor.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 75e20c43d..ea86b14bf 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -153,10 +153,10 @@ class StiffVisitor: public CvodeHelperVisitor { auto rhs = node.get_rhs(); // all indexed variables (need special treatment in SymPy) - auto name_map = get_indexed_variables(*rhs, name->get_node_name()); + auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, - exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); + auto [jacobian, exception_message] = + diff2c(to_nmodl(*rhs), parse_independent_var(name), indexed_variables); if (!exception_message.empty()) { logger->warn("CvodeVisitor :: python exception: {}", exception_message); } @@ -172,11 +172,10 @@ class StiffVisitor: public CvodeHelperVisitor { } }; - -void CvodeVisitor::visit_program(ast::Program& node) { +static std::shared_ptr get_derivative_block(ast::Program& node) { auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); if (derivative_blocks.empty()) { - return; + return nullptr; } // steady state adds a DERIVATIVE block with a `_steadystate` suffix @@ -195,8 +194,15 @@ void CvodeVisitor::visit_program(ast::Program& node) { throw std::runtime_error(message); } - auto derivative_block = std::dynamic_pointer_cast( - derivative_blocks_copy[0]); + return std::dynamic_pointer_cast(derivative_blocks_copy[0]); +} + + +void CvodeVisitor::visit_program(ast::Program& node) { + auto derivative_block = get_derivative_block(node); + if (derivative_block == nullptr) { + return; + } auto non_stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*non_stiff_block);