diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index c1c907eae..55547101a 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -252,6 +252,51 @@ def make_symbol(var, /): return sp.Symbol(var, real=True) if isinstance(var, str) else var +def rename_to_sympy(var, /): + """ + Symbol-aware renaming of reserved or built-in symbols. + + Notes + ----- + For renaming variables in complex expressions, see + :ref:``regex_rename_to_sympy``. + """ + var = make_symbol(var) + + if var.name in forbidden_var: + new_name = f"_sympy_{var.name}" + + # there is no "Pythonic" way of doing this since the SymPy API does not + # allow changing names of symbols, or actually copying objects properly + if isinstance(var, sp.Symbol): + assumptions = var.assumptions0 + return sp.Symbol(new_name, **assumptions) + + elif isinstance(var, sp.IndexedBase): + return sp.IndexedBase(new_name, shape=var.shape) + + elif isinstance(var, sp.FunctionClass): + return sp.Function(new_name) + + return var + + +def regex_rename_to_sympy(expression, /): + """ + Rename expression containing reserved or built-in symbols using a regex. + + Notes + ----- + For renaming single variables, see :ref:``rename_to_sympy``. + """ + + for var in forbidden_var: + pattern = re.compile(rf"\b{var}\b") + expression = re.sub(pattern, f"_sympy_{var}", expression) + + return expression + + def solve_lin_system( eq_strings, vars, @@ -622,18 +667,23 @@ def differentiate2c( if stepsize <= 0: raise ValueError("arg `stepsize` must be > 0") prev_expressions = prev_expressions or [] - # every symbol (a.k.a variable) that SymPy - # is going to manipulate needs to be declared - # explicitly - x = make_symbol(dependent_var) - vars = set(vars) - vars.discard(dependent_var) + + # we keep the original symbol around as well so we can rename it back + x_original = make_symbol(dependent_var) + sympy_vars_original = { + **{str(var): make_symbol(var) for var in vars}, + str(x_original): x_original, + } + # declare all other supplied variables - sympy_vars = {str(var): make_symbol(var) for var in vars} - sympy_vars[dependent_var] = x + x = rename_to_sympy(dependent_var) + sympy_vars = { + **{str(rename_to_sympy(var)): rename_to_sympy(var) for var in vars}, + str(x): x, + } # parse string into SymPy equation - expr = sp.sympify(expression, locals=sympy_vars) + expr = sp.sympify(regex_rename_to_sympy(expression), locals=sympy_vars) # parse previous expressions in the order that they came in # substitute any x-dependent vars in rhs with their rhs expressions, @@ -672,6 +722,16 @@ def differentiate2c( .evalf() ) + # once we have the derivative, it's safe to put back the original variables + reverse_map = { + old_var: new_var + for old_var, new_var in zip( + sympy_vars.values(), + sympy_vars_original.values(), + ) + } + diff = diff.subs(reverse_map) + # the codegen method does not like undefined function calls, so we extract # them here custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)}