From 02051ba70d3c540603a48a3a2f097491b7a39af8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 13:09:44 +0200 Subject: [PATCH 1/2] Add diffing against IndexedBase instances --- python/nmodl/ode.py | 7 ++++++- test/unit/ode/test_ode.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index cd6b2b27a..c3de73cfb 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -595,6 +595,11 @@ def differentiate2c( If the result coincides with one of the vars, or the LHS of one of the prev_expressions, then it is simplified to this expression. + Note that, in order to differentiate against indexed variables (such as + ``x[0]``), you must pass an instance of ``sympy.Indexed`` to + ``dependent_var`` (_not_ an instance of ``sympy.IndexedBase``), as well as + an instance of ``sympy.IndexedBase`` to ``vars``. + Some simple examples of use: - ``nmodl.ode.differentiate2c ("a*x", "x", {"a"}) == "a"`` @@ -619,7 +624,7 @@ def differentiate2c( # every symbol (a.k.a variable) that SymPy # is going to manipulate needs to be declared # explicitly - x = sp.symbols(dependent_var, real=True) + x = make_symbol(dependent_var) vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 6eae70699..82e0358d2 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -111,6 +111,19 @@ def test_differentiate2c(): {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, ) + # make sure we can diff against indexed vars as well + var = sp.IndexedBase("x", shape=[1]) + + assert _equivalent( + differentiate2c( + "a * x[0]", + var[0], + {"a", var}, + ), + "a", + {"a"}, + ) + result = differentiate2c( "-f(x)", "x", From 59ffe7dc582f84ab5cbd298acfc69496624e18b8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 10:54:05 +0200 Subject: [PATCH 2/2] Add example for diffing --- python/nmodl/ode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index c3de73cfb..c1c907eae 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -604,6 +604,7 @@ def differentiate2c( - ``nmodl.ode.differentiate2c ("a*x", "x", {"a"}) == "a"`` - ``differentiate2c ("cos(y) + b*y**2", "y", {"a","b"}) == "Dy = 2*b*y - sin(y)"`` + - ``differentiate2c("a * x[0]", sympy.IndexedBase("x", shape=[1])[0], {"a", sympy.IndexedBase("x", shape=[1])}) == "a"`` Args: expression: expression to be differentiated e.g. "a*x + b"