diff --git a/odetoolbox/__init__.py b/odetoolbox/__init__.py index f365d01a..66ee08ea 100644 --- a/odetoolbox/__init__.py +++ b/odetoolbox/__init__.py @@ -27,7 +27,7 @@ from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8 from .config import Config -from .sympy_helpers import _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter +from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter from .system_of_shapes import SystemOfShapes from .shapes import MalformedInputException, Shape @@ -109,7 +109,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym """ logging.info("Processing input shapes...") - shapes = [] + # first run for grabbing all the variable names. Coefficients might be incorrect. all_variable_symbols = [] all_parameter_symbols = set() @@ -124,6 +124,11 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym assert all([_is_sympy_type(sym) for sym in all_variable_symbols]) logging.info("All known variables: " + str(all_variable_symbols) + ", all parameters used in ODEs: " + str(all_parameter_symbols)) + # validate input for forbidden names + for var in set(all_variable_symbols) | all_parameter_symbols: + _check_forbidden_name(var) + + # validate parameters for param in all_parameter_symbols: if parameters is None: parameters = dict() @@ -135,6 +140,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym parameters[param] = None # second run with the now-known list of variable symbols + shapes = [] for shape_json in indict["dynamics"]: shape = Shape.from_json(shape_json, all_variable_symbols=all_variable_symbols, parameters=parameters, _debug=True) shapes.append(shape) @@ -210,6 +216,9 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so assert type(k) is sympy.Symbol parameters[k] = v + _check_forbidden_name(k) + + # # create Shapes and SystemOfShapes # @@ -292,8 +301,12 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so all_shape_symbols = [str(sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * i)) for i in range(shape.order)] for sym in all_shape_symbols: if sym in solver_json["state_variables"]: - solver_json["initial_values"][sym] = str(shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'"))) + iv_expr = shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'")) + solver_json["initial_values"][sym] = str(iv_expr) + # validate output for numerical problems + for var in iv_expr.atoms(): + _check_numerical_issue(var) # # copy the parameter values from the input to the output for convenience; convert into numeric values @@ -318,7 +331,20 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so break if symbol_appears_in_any_expr: - solver_json["parameters"][param_name] = str(sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals).n()) + sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals) + + # validate output for numerical problems + for var in sympy_expr.atoms(): + _check_numerical_issue(var) + + # convert to numeric value + sympy_expr = sympy_expr.n() + + # validate output for numerical problems + for var in sympy_expr.atoms(): + _check_numerical_issue(var) + + solver_json["parameters"][param_name] = str(sympy_expr) # diff --git a/odetoolbox/config.py b/odetoolbox/config.py index 8efb3e8c..6f4fe97e 100644 --- a/odetoolbox/config.py +++ b/odetoolbox/config.py @@ -34,7 +34,8 @@ class Config: "sim_time": 100E-3, "max_step_size": 999., "integration_accuracy_abs": 1E-6, - "integration_accuracy_rel": 1E-6 + "integration_accuracy_rel": 1E-6, + "forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"] } def __getitem__(self, key): diff --git a/odetoolbox/shapes.py b/odetoolbox/shapes.py index ba9fc0c1..507ab979 100644 --- a/odetoolbox/shapes.py +++ b/odetoolbox/shapes.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import List, Mapping, Tuple +from typing import List, Tuple import functools import logging @@ -30,11 +30,9 @@ import sympy.parsing.sympy_parser from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8 -from sympy.core.numbers import One as SympyOne -from sympy.core.numbers import Zero as SympyZero from .config import Config -from .sympy_helpers import _custom_simplify_expr, _is_sympy_type, _is_zero +from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _custom_simplify_expr, _is_constant_term, _is_sympy_type, _is_zero class MalformedInputException(Exception): @@ -44,17 +42,6 @@ class MalformedInputException(Exception): pass -def is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None): - r""" - :return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise. - """ - if parameters is None: - parameters = {} - assert all([type(k) is sympy.Symbol for k in parameters.keys()]) - return type(term) in [sympy.Float, sympy.Integer, SympyZero, SympyOne] \ - or all([sym in parameters.keys() for sym in term.free_symbols]) - - class Shape: r""" This class provides a canonical representation of a shape function independently of the way in which the user specified the shape. It assumes a differential equation of the general form (where bracketed superscript :math:`{}^{(n)}` indicates the :math:`n`-th derivative with respect to time): @@ -79,6 +66,7 @@ class Shape: "Integer": sympy.Integer, "Float": sympy.Float, "Function": sympy.Function, + "Mul": sympy.Mul, "Pow": sympy.Pow, "power": sympy.Pow, "exp": sympy.exp, @@ -326,6 +314,9 @@ def from_json(cls, indict, all_variable_symbols=None, parameters=None, _debug=Fa raise MalformedInputException("In defintion of initial value for variable \"" + iv_symbol + "\": differential order (" + str(iv_order) + ") exceeds that of overall equation order (" + str(order) + ")") if initial_val_specified[iv_order]: raise MalformedInputException("Initial value for order " + str(iv_order) + " specified more than once") + + _check_forbidden_name(iv_symbol) + initial_val_specified[iv_order] = True initial_values[iv_symbol + iv_order * "'"] = iv_rhs @@ -392,13 +383,13 @@ def split_lin_inhom_nonlin(expr, x, parameters=None): terms = [expr] for term in terms: - if is_constant_term(term, parameters=parameters): + if _is_constant_term(term, parameters=parameters): inhom_term += term else: # check if the term is linear in any of the symbols in `x` is_lin = False for j, sym in enumerate(x): - if is_constant_term(term / sym, parameters=parameters): + if _is_constant_term(term / sym, parameters=parameters): lin_factors[j] += term / sym is_lin = True break @@ -584,8 +575,21 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab order: int = len(initial_values) all_variable_symbols_dict = {str(el): el for el in all_variable_symbols} definition = sympy.parsing.sympy_parser.parse_expr(definition.replace("'", Config().differential_order_symbol), global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) # minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol) + + # validate input for forbidden names + _initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, evaluate=False, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()} + for iv_expr in _initial_values.values(): + for var in iv_expr.atoms(): + _check_forbidden_name(var) + + # parse input initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()} + # validate input for numerical issues + for iv_expr in initial_values.values(): + for var in iv_expr.atoms(): + _check_numerical_issue(var) + local_symbols = [symbol + Config().differential_order_symbol * i for i in range(order)] local_symbols_sympy = [sympy.Symbol(sym_name) for sym_name in local_symbols] if not symbol in all_variable_symbols: diff --git a/odetoolbox/sympy_helpers.py b/odetoolbox/sympy_helpers.py index 0d59f588..ba45457c 100644 --- a/odetoolbox/sympy_helpers.py +++ b/odetoolbox/sympy_helpers.py @@ -1,5 +1,5 @@ # -# sympy_printer.py +# sympy_helpers.py # # This file is part of the NEST ODE toolbox. # @@ -19,6 +19,8 @@ # along with NEST. If not, see . # +from typing import Mapping + import logging import sympy import sys @@ -26,6 +28,37 @@ from .config import Config +class NumericalIssueException(Exception): + r"""Thrown in case of numerical issues, like division by zero.""" + pass + + +def _is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None) -> bool: + r""" + :return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise. + """ + if parameters is None: + parameters = {} + assert all([type(k) is sympy.Symbol for k in parameters.keys()]) + return type(term) in [sympy.Float, sympy.Integer, sympy.core.numbers.Zero, sympy.core.numbers.One] \ + or all([sym in parameters.keys() for sym in term.free_symbols]) + + +def _check_numerical_issue(var: str) -> None: + forbidden_vars = ["zoo", "oo", "nan", "NaN"] + stripped_var_name = str(var).strip("'") + if stripped_var_name in forbidden_vars: + raise NumericalIssueException("The variable \"" + stripped_var_name + "\" was found. This indicates a numerical problem while solving the system of ODEs. Please check the input for correctness (such as the presence of divisions by zero).") + + +def _check_forbidden_name(var: str) -> None: + from .shapes import MalformedInputException + + stripped_var_name = str(var).strip("'") + if stripped_var_name in Config().forbidden_names + dir(sympy.core.numbers): + raise MalformedInputException("Variable by name \"" + stripped_var_name + "\" not allowed; this is a reserved name.") + + def _is_zero(x): r""" Check if a sympy expression is equal to zero. diff --git a/tests/test_malformed_input.py b/tests/test_malformed_input.py new file mode 100644 index 00000000..12b9dcf3 --- /dev/null +++ b/tests/test_malformed_input.py @@ -0,0 +1,67 @@ +# +# test_malformed_input.py +# +# This file is part of the NEST ODE toolbox. +# +# Copyright (C) 2017 The NEST Initiative +# +# The NEST ODE toolbox is free software: you can redistribute it +# and/or modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation, either version 2 of +# the License, or (at your option) any later version. +# +# The NEST ODE toolbox is distributed in the hope that it will be +# useful, but WITHOUT ANY WARRANTY; without even the implied warranty +# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . +# + +import pytest + +from .context import odetoolbox +from odetoolbox.shapes import MalformedInputException +from odetoolbox.sympy_helpers import NumericalIssueException + + +class TestMalformedInput: + r"""Test for failure when forbidden names are used in the input.""" + + @pytest.mark.xfail(strict=True, raises=MalformedInputException) + def test_malformed_input_iv(self): + indict = {"dynamics": [{"expression": "x' = 0", + "initial_value": "zoo"}]} + result = odetoolbox.analysis(indict, disable_stiffness_check=True) + + @pytest.mark.xfail(strict=True, raises=MalformedInputException) + def test_malformed_input_expr(self): + indict = {"dynamics": [{"expression": "x' = 42 * NaN", + "initial_value": "0."}]} + result = odetoolbox.analysis(indict, disable_stiffness_check=True) + + @pytest.mark.xfail(strict=True, raises=MalformedInputException) + def test_malformed_input_sym(self): + indict = {"dynamics": [{"expression": "oo' = 0", + "initial_value": "0."}]} + result = odetoolbox.analysis(indict, disable_stiffness_check=True) + + def test_correct_input(self): + indict = {"dynamics": [{"expression": "foo' = 0", + "initial_value": "0."}]} + result = odetoolbox.analysis(indict, disable_stiffness_check=True) + + @pytest.mark.xfail(strict=True, raises=NumericalIssueException) + def test_malformed_input_numerical_iv(self): + indict = {"dynamics": [{"expression": "foo' = 0", + "initial_value": "1/0"}]} + result = odetoolbox.analysis(indict, disable_stiffness_check=True) + import pdb;pdb.set_trace() + + @pytest.mark.xfail(strict=True, raises=NumericalIssueException) + def test_malformed_input_numerical_parameter(self): + indict = {"dynamics": [{"expression": "foo' = bar", + "initial_value": "1"}], + "parameters": {"bar": "1/0"}} + result = odetoolbox.analysis(indict, disable_stiffness_check=True)