Skip to content

Commit

Permalink
Fix SymPy errors when using reserved symbols in differentiate2c
Browse files Browse the repository at this point in the history
The user should now make sure to pass all of the symbols to the function
as SymPy objects.
  • Loading branch information
JCGoran committed Oct 23, 2024
1 parent edaa090 commit 492c7ff
Showing 1 changed file with 69 additions and 9 deletions.
78 changes: 69 additions & 9 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,51 @@ def make_symbol(var, /):
return sp.Symbol(var, real=True) if isinstance(var, str) else var


def rename_to_sympy(var, /):
"""
Symbol-aware renaming of reserved or built-in symbols.
Notes
-----
For renaming variables in complex expressions, see
:ref:``regex_rename_to_sympy``.
"""
var = make_symbol(var)

if var.name in forbidden_var:
new_name = f"_sympy_{var.name}"

# there is no "Pythonic" way of doing this since the SymPy API does not
# allow changing names of symbols, or actually copying objects properly
if isinstance(var, sp.Symbol):
assumptions = var.assumptions0
return sp.Symbol(new_name, **assumptions)

elif isinstance(var, sp.IndexedBase):
return sp.IndexedBase(new_name, shape=var.shape)

elif isinstance(var, sp.FunctionClass):
return sp.Function(new_name)

return var


def regex_rename_to_sympy(expression, /):
"""
Rename expression containing reserved or built-in symbols using a regex.
Notes
-----
For renaming single variables, see :ref:``rename_to_sympy``.
"""

for var in forbidden_var:
pattern = re.compile(rf"\b{var}\b")
expression = re.sub(pattern, f"_sympy_{var}", expression)

return expression


def solve_lin_system(
eq_strings,
vars,
Expand Down Expand Up @@ -622,18 +667,23 @@ def differentiate2c(
if stepsize <= 0:
raise ValueError("arg `stepsize` must be > 0")
prev_expressions = prev_expressions or []
# every symbol (a.k.a variable) that SymPy
# is going to manipulate needs to be declared
# explicitly
x = make_symbol(dependent_var)
vars = set(vars)
vars.discard(dependent_var)

# we keep the original symbol around as well so we can rename it back
x_original = make_symbol(dependent_var)
sympy_vars_original = {
**{str(var): make_symbol(var) for var in vars},
str(x_original): x_original,
}

# declare all other supplied variables
sympy_vars = {str(var): make_symbol(var) for var in vars}
sympy_vars[dependent_var] = x
x = rename_to_sympy(dependent_var)
sympy_vars = {
**{str(rename_to_sympy(var)): rename_to_sympy(var) for var in vars},
str(x): x,
}

# parse string into SymPy equation
expr = sp.sympify(expression, locals=sympy_vars)
expr = sp.sympify(regex_rename_to_sympy(expression), locals=sympy_vars)

# parse previous expressions in the order that they came in
# substitute any x-dependent vars in rhs with their rhs expressions,
Expand Down Expand Up @@ -672,6 +722,16 @@ def differentiate2c(
.evalf()
)

# once we have the derivative, it's safe to put back the original variables
reverse_map = {
old_var: new_var
for old_var, new_var in zip(
sympy_vars.values(),
sympy_vars_original.values(),
)
}
diff = diff.subs(reverse_map)

# the codegen method does not like undefined function calls, so we extract
# them here
custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)}
Expand Down

0 comments on commit 492c7ff

Please sign in to comment.