Skip to content

Commit

Permalink
Merge branch 'master' into jelic/fix_broken_cvode_equations
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran authored Dec 3, 2024
2 parents 4a075a9 + 2c3c7dd commit 30d1cb8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,12 @@ def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):
if _a1 == 0 and _a2 == 0:
solution = _a0

custom_fcts = {str(f.func): str(f.func) for f in solution.atoms(sp.Function)}

# return result as C code in NEURON format:
# - in the lhs x_0 refers to the state var at time (t+dt)
# - in the rhs x_0 refers to the state var at time t
return f"{sp.ccode(x)} = {sp.ccode(solution.evalf())}"
return f"{sp.ccode(x)} = {sp.ccode(solution.evalf(), user_functions=custom_fcts)}"


def forwards_euler2c(diff_string, dt_var, vars, function_calls):
Expand Down
23 changes: 23 additions & 0 deletions src/visitors/cvode_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "utils/logger.hpp"
#include "visitors/visitor_utils.hpp"
#include <optional>
#include <regex>
#include <utility>

namespace pywrap = nmodl::pybind_wrappers;
Expand All @@ -35,6 +36,25 @@ static void remove_conserve_statements(ast::StatementBlock& node) {
}
}

// remove units from CVODE block so sympy can parse it properly
static void remove_units(ast::BinaryExpression& node) {
// matches either an int or a float, followed by any (including zero)
// number of spaces, followed by an expression in parentheses, that only
// has letters of the alphabet
std::regex unit_pattern(R"((\d+\.?\d*|\.\d+)\s*\([a-zA-Z]+\))");
auto rhs_string = to_nmodl(node.get_rhs());
auto rhs_string_no_units = fmt::format("{} = {}",
to_nmodl(node.get_lhs()),
std::regex_replace(rhs_string, unit_pattern, "$1"));
logger->debug("CvodeVisitor :: removing units from statement {}", to_nmodl(node));
logger->debug("CvodeVisitor :: result: {}", rhs_string_no_units);
auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
create_statement(rhs_string_no_units));
const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
expr_statement->get_expression());
node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
}

static std::pair<std::string, std::optional<int>> parse_independent_var(
std::shared_ptr<ast::Identifier> node) {
auto variable = std::make_pair(node->get_node_name(), std::optional<int>());
Expand Down Expand Up @@ -152,7 +172,10 @@ class StiffVisitor: public CvodeHelperVisitor {
program_symtab->insert(symbol);
}

remove_units(node);

auto rhs = node.get_rhs();

// all indexed variables (need special treatment in SymPy)
auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name());
auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c;
Expand Down
2 changes: 2 additions & 0 deletions test/unit/ode/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def test_integrate2c():
("a", "x + a*dt"),
("a*x", "x*exp(a*dt)"),
("a*x+b", "(-b + (a*x + b)*exp(a*dt))/a"),
# assume custom_function is defined in mod file
("custom_function(a)*x", "x*exp(custom_function(a)*dt)"),
]
for eq, sol in test_cases:
assert _equivalent(
Expand Down
8 changes: 7 additions & 1 deletion test/usecases/cvode/derivative.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ NEURON {
SUFFIX scalar
}

UNITS {
(um) = (micron)
}

PARAMETER {
freq = 10
a = 5
Expand All @@ -14,7 +18,7 @@ PARAMETER {
k = 0.2
}

STATE {var1 var2 var3}
STATE {var1 var2 var3 var4}

INITIAL {
var1 = v1
Expand All @@ -34,4 +38,6 @@ DERIVATIVE equation {
var2' = -var2 * a
: logistic ODE
var3' = r * var3 * (1 - var3 / k)
: ODE with some units
var4' = 1(um) * var4 + a * .1(um) + r * 1.(um) + 1.0 (um)
}

0 comments on commit 30d1cb8

Please sign in to comment.