Skip to content

Commit

Permalink
Handle input containing division by zero better
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Oct 17, 2024
1 parent 88c637b commit a2f5e21
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 23 deletions.
34 changes: 30 additions & 4 deletions odetoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8

from .config import Config
from .sympy_helpers import _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
from .system_of_shapes import SystemOfShapes
from .shapes import MalformedInputException, Shape

Expand Down Expand Up @@ -109,7 +109,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
"""

logging.info("Processing input shapes...")
shapes = []

# first run for grabbing all the variable names. Coefficients might be incorrect.
all_variable_symbols = []
all_parameter_symbols = set()
Expand All @@ -124,6 +124,11 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
assert all([_is_sympy_type(sym) for sym in all_variable_symbols])
logging.info("All known variables: " + str(all_variable_symbols) + ", all parameters used in ODEs: " + str(all_parameter_symbols))

# validate input for forbidden names
for var in set(all_variable_symbols) | all_parameter_symbols:
_check_forbidden_name(var)

# validate parameters
for param in all_parameter_symbols:
if parameters is None:
parameters = dict()
Expand All @@ -135,6 +140,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
parameters[param] = None

# second run with the now-known list of variable symbols
shapes = []
for shape_json in indict["dynamics"]:
shape = Shape.from_json(shape_json, all_variable_symbols=all_variable_symbols, parameters=parameters, _debug=True)
shapes.append(shape)
Expand Down Expand Up @@ -210,6 +216,9 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
assert type(k) is sympy.Symbol
parameters[k] = v

_check_forbidden_name(k)


#
# create Shapes and SystemOfShapes
#
Expand Down Expand Up @@ -292,8 +301,12 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
all_shape_symbols = [str(sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * i)) for i in range(shape.order)]
for sym in all_shape_symbols:
if sym in solver_json["state_variables"]:
solver_json["initial_values"][sym] = str(shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'")))
iv_expr = shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'"))
solver_json["initial_values"][sym] = str(iv_expr)

# validate output for numerical problems
for var in iv_expr.atoms():
_check_numerical_issue(var)

#
# copy the parameter values from the input to the output for convenience; convert into numeric values
Expand All @@ -318,7 +331,20 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
break

if symbol_appears_in_any_expr:
solver_json["parameters"][param_name] = str(sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals).n())
sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals)

# validate output for numerical problems
for var in sympy_expr.atoms():
_check_numerical_issue(var)

# convert to numeric value
sympy_expr = sympy_expr.n()

# validate output for numerical problems
for var in sympy_expr.atoms():
_check_numerical_issue(var)

solver_json["parameters"][param_name] = str(sympy_expr)


#
Expand Down
3 changes: 2 additions & 1 deletion odetoolbox/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class Config:
"sim_time": 100E-3,
"max_step_size": 999.,
"integration_accuracy_abs": 1E-6,
"integration_accuracy_rel": 1E-6
"integration_accuracy_rel": 1E-6,
"forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"]
}

def __getitem__(self, key):
Expand Down
38 changes: 21 additions & 17 deletions odetoolbox/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import annotations

from typing import List, Mapping, Tuple
from typing import List, Tuple

import functools
import logging
Expand All @@ -30,11 +30,9 @@
import sympy.parsing.sympy_parser

from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8
from sympy.core.numbers import One as SympyOne
from sympy.core.numbers import Zero as SympyZero

from .config import Config
from .sympy_helpers import _custom_simplify_expr, _is_sympy_type, _is_zero
from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _custom_simplify_expr, _is_constant_term, _is_sympy_type, _is_zero


class MalformedInputException(Exception):
Expand All @@ -44,17 +42,6 @@ class MalformedInputException(Exception):
pass


def is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None):
r"""
:return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
"""
if parameters is None:
parameters = {}
assert all([type(k) is sympy.Symbol for k in parameters.keys()])
return type(term) in [sympy.Float, sympy.Integer, SympyZero, SympyOne] \
or all([sym in parameters.keys() for sym in term.free_symbols])


class Shape:
r"""
This class provides a canonical representation of a shape function independently of the way in which the user specified the shape. It assumes a differential equation of the general form (where bracketed superscript :math:`{}^{(n)}` indicates the :math:`n`-th derivative with respect to time):
Expand All @@ -79,6 +66,7 @@ class Shape:
"Integer": sympy.Integer,
"Float": sympy.Float,
"Function": sympy.Function,
"Mul": sympy.Mul,
"Pow": sympy.Pow,
"power": sympy.Pow,
"exp": sympy.exp,
Expand Down Expand Up @@ -326,6 +314,9 @@ def from_json(cls, indict, all_variable_symbols=None, parameters=None, _debug=Fa
raise MalformedInputException("In defintion of initial value for variable \"" + iv_symbol + "\": differential order (" + str(iv_order) + ") exceeds that of overall equation order (" + str(order) + ")")
if initial_val_specified[iv_order]:
raise MalformedInputException("Initial value for order " + str(iv_order) + " specified more than once")

_check_forbidden_name(iv_symbol)

initial_val_specified[iv_order] = True
initial_values[iv_symbol + iv_order * "'"] = iv_rhs

Expand Down Expand Up @@ -392,13 +383,13 @@ def split_lin_inhom_nonlin(expr, x, parameters=None):
terms = [expr]

for term in terms:
if is_constant_term(term, parameters=parameters):
if _is_constant_term(term, parameters=parameters):
inhom_term += term
else:
# check if the term is linear in any of the symbols in `x`
is_lin = False
for j, sym in enumerate(x):
if is_constant_term(term / sym, parameters=parameters):
if _is_constant_term(term / sym, parameters=parameters):
lin_factors[j] += term / sym
is_lin = True
break
Expand Down Expand Up @@ -584,8 +575,21 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab
order: int = len(initial_values)
all_variable_symbols_dict = {str(el): el for el in all_variable_symbols}
definition = sympy.parsing.sympy_parser.parse_expr(definition.replace("'", Config().differential_order_symbol), global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) # minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol)

# validate input for forbidden names
_initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, evaluate=False, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()}
for iv_expr in _initial_values.values():
for var in iv_expr.atoms():
_check_forbidden_name(var)

# parse input
initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()}

# validate input for numerical issues
for iv_expr in initial_values.values():
for var in iv_expr.atoms():
_check_numerical_issue(var)

local_symbols = [symbol + Config().differential_order_symbol * i for i in range(order)]
local_symbols_sympy = [sympy.Symbol(sym_name) for sym_name in local_symbols]
if not symbol in all_variable_symbols:
Expand Down
35 changes: 34 additions & 1 deletion odetoolbox/sympy_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# sympy_printer.py
# sympy_helpers.py
#
# This file is part of the NEST ODE toolbox.
#
Expand All @@ -19,13 +19,46 @@
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

from typing import Mapping

import logging
import sympy
import sys

from .config import Config


class NumericalIssueException(Exception):
r"""Thrown in case of numerical issues, like division by zero."""
pass


def _is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None) -> bool:
r"""
:return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
"""
if parameters is None:
parameters = {}
assert all([type(k) is sympy.Symbol for k in parameters.keys()])
return type(term) in [sympy.Float, sympy.Integer, sympy.core.numbers.Zero, sympy.core.numbers.One] \
or all([sym in parameters.keys() for sym in term.free_symbols])


def _check_numerical_issue(var: str) -> None:
forbidden_vars = ["zoo", "oo", "nan", "NaN"]
stripped_var_name = str(var).strip("'")
if stripped_var_name in forbidden_vars:
raise NumericalIssueException("The variable \"" + stripped_var_name + "\" was found. This indicates a numerical problem while solving the system of ODEs. Please check the input for correctness (such as the presence of divisions by zero).")


def _check_forbidden_name(var: str) -> None:
from .shapes import MalformedInputException

stripped_var_name = str(var).strip("'")
if stripped_var_name in Config().forbidden_names + dir(sympy.core.numbers):
raise MalformedInputException("Variable by name \"" + stripped_var_name + "\" not allowed; this is a reserved name.")


def _is_zero(x):
r"""
Check if a sympy expression is equal to zero.
Expand Down
66 changes: 66 additions & 0 deletions tests/test_malformed_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#
# test_malformed_input.py
#
# This file is part of the NEST ODE toolbox.
#
# Copyright (C) 2017 The NEST Initiative
#
# The NEST ODE toolbox is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation, either version 2 of
# the License, or (at your option) any later version.
#
# The NEST ODE toolbox is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

import pytest

from .context import odetoolbox
from odetoolbox.shapes import MalformedInputException
from odetoolbox.sympy_helpers import NumericalIssueException


class TestMalformedInput:
r"""Test for failure when forbidden names are used in the input."""

@pytest.mark.xfail(strict=True, raises=MalformedInputException)
def test_malformed_input_iv(self):
indict = {"dynamics": [{"expression": "x' = 0",
"initial_value": "zoo"}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=MalformedInputException)
def test_malformed_input_expr(self):
indict = {"dynamics": [{"expression": "x' = 42 * NaN",
"initial_value": "0."}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=MalformedInputException)
def test_malformed_input_sym(self):
indict = {"dynamics": [{"expression": "oo' = 0",
"initial_value": "0."}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

def test_correct_input(self):
indict = {"dynamics": [{"expression": "foo' = 0",
"initial_value": "0."}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=NumericalIssueException)
def test_malformed_input_numerical_iv(self):
indict = {"dynamics": [{"expression": "foo' = 0",
"initial_value": "1/0"}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=NumericalIssueException)
def test_malformed_input_numerical_parameter(self):
indict = {"dynamics": [{"expression": "foo' = bar",
"initial_value": "1"}],
"parameters": {"bar": "1/0"}}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

0 comments on commit a2f5e21

Please sign in to comment.