diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index cd6b2b27a..c1c907eae 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -595,10 +595,16 @@ def differentiate2c( If the result coincides with one of the vars, or the LHS of one of the prev_expressions, then it is simplified to this expression. + Note that, in order to differentiate against indexed variables (such as + ``x[0]``), you must pass an instance of ``sympy.Indexed`` to + ``dependent_var`` (_not_ an instance of ``sympy.IndexedBase``), as well as + an instance of ``sympy.IndexedBase`` to ``vars``. + Some simple examples of use: - ``nmodl.ode.differentiate2c ("a*x", "x", {"a"}) == "a"`` - ``differentiate2c ("cos(y) + b*y**2", "y", {"a","b"}) == "Dy = 2*b*y - sin(y)"`` + - ``differentiate2c("a * x[0]", sympy.IndexedBase("x", shape=[1])[0], {"a", sympy.IndexedBase("x", shape=[1])}) == "a"`` Args: expression: expression to be differentiated e.g. "a*x + b" @@ -619,7 +625,7 @@ def differentiate2c( # every symbol (a.k.a variable) that SymPy # is going to manipulate needs to be declared # explicitly - x = sp.symbols(dependent_var, real=True) + x = make_symbol(dependent_var) vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 6eae70699..82e0358d2 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -111,6 +111,19 @@ def test_differentiate2c(): {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, ) + # make sure we can diff against indexed vars as well + var = sp.IndexedBase("x", shape=[1]) + + assert _equivalent( + differentiate2c( + "a * x[0]", + var[0], + {"a", var}, + ), + "a", + {"a"}, + ) + result = differentiate2c( "-f(x)", "x",