diff --git a/odetoolbox/__init__.py b/odetoolbox/__init__.py
index f365d01a..66ee08ea 100644
--- a/odetoolbox/__init__.py
+++ b/odetoolbox/__init__.py
@@ -27,7 +27,7 @@
from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8
from .config import Config
-from .sympy_helpers import _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
+from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
from .system_of_shapes import SystemOfShapes
from .shapes import MalformedInputException, Shape
@@ -109,7 +109,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
"""
logging.info("Processing input shapes...")
- shapes = []
+
# first run for grabbing all the variable names. Coefficients might be incorrect.
all_variable_symbols = []
all_parameter_symbols = set()
@@ -124,6 +124,11 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
assert all([_is_sympy_type(sym) for sym in all_variable_symbols])
logging.info("All known variables: " + str(all_variable_symbols) + ", all parameters used in ODEs: " + str(all_parameter_symbols))
+ # validate input for forbidden names
+ for var in set(all_variable_symbols) | all_parameter_symbols:
+ _check_forbidden_name(var)
+
+ # validate parameters
for param in all_parameter_symbols:
if parameters is None:
parameters = dict()
@@ -135,6 +140,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
parameters[param] = None
# second run with the now-known list of variable symbols
+ shapes = []
for shape_json in indict["dynamics"]:
shape = Shape.from_json(shape_json, all_variable_symbols=all_variable_symbols, parameters=parameters, _debug=True)
shapes.append(shape)
@@ -210,6 +216,9 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
assert type(k) is sympy.Symbol
parameters[k] = v
+ _check_forbidden_name(k)
+
+
#
# create Shapes and SystemOfShapes
#
@@ -292,8 +301,12 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
all_shape_symbols = [str(sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * i)) for i in range(shape.order)]
for sym in all_shape_symbols:
if sym in solver_json["state_variables"]:
- solver_json["initial_values"][sym] = str(shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'")))
+ iv_expr = shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'"))
+ solver_json["initial_values"][sym] = str(iv_expr)
+ # validate output for numerical problems
+ for var in iv_expr.atoms():
+ _check_numerical_issue(var)
#
# copy the parameter values from the input to the output for convenience; convert into numeric values
@@ -318,7 +331,20 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
break
if symbol_appears_in_any_expr:
- solver_json["parameters"][param_name] = str(sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals).n())
+ sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals)
+
+ # validate output for numerical problems
+ for var in sympy_expr.atoms():
+ _check_numerical_issue(var)
+
+ # convert to numeric value
+ sympy_expr = sympy_expr.n()
+
+ # validate output for numerical problems
+ for var in sympy_expr.atoms():
+ _check_numerical_issue(var)
+
+ solver_json["parameters"][param_name] = str(sympy_expr)
#
diff --git a/odetoolbox/config.py b/odetoolbox/config.py
index 8efb3e8c..6f4fe97e 100644
--- a/odetoolbox/config.py
+++ b/odetoolbox/config.py
@@ -34,7 +34,8 @@ class Config:
"sim_time": 100E-3,
"max_step_size": 999.,
"integration_accuracy_abs": 1E-6,
- "integration_accuracy_rel": 1E-6
+ "integration_accuracy_rel": 1E-6,
+ "forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"]
}
def __getitem__(self, key):
diff --git a/odetoolbox/shapes.py b/odetoolbox/shapes.py
index ba9fc0c1..507ab979 100644
--- a/odetoolbox/shapes.py
+++ b/odetoolbox/shapes.py
@@ -21,7 +21,7 @@
from __future__ import annotations
-from typing import List, Mapping, Tuple
+from typing import List, Tuple
import functools
import logging
@@ -30,11 +30,9 @@
import sympy.parsing.sympy_parser
from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8
-from sympy.core.numbers import One as SympyOne
-from sympy.core.numbers import Zero as SympyZero
from .config import Config
-from .sympy_helpers import _custom_simplify_expr, _is_sympy_type, _is_zero
+from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _custom_simplify_expr, _is_constant_term, _is_sympy_type, _is_zero
class MalformedInputException(Exception):
@@ -44,17 +42,6 @@ class MalformedInputException(Exception):
pass
-def is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None):
- r"""
- :return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
- """
- if parameters is None:
- parameters = {}
- assert all([type(k) is sympy.Symbol for k in parameters.keys()])
- return type(term) in [sympy.Float, sympy.Integer, SympyZero, SympyOne] \
- or all([sym in parameters.keys() for sym in term.free_symbols])
-
-
class Shape:
r"""
This class provides a canonical representation of a shape function independently of the way in which the user specified the shape. It assumes a differential equation of the general form (where bracketed superscript :math:`{}^{(n)}` indicates the :math:`n`-th derivative with respect to time):
@@ -79,6 +66,7 @@ class Shape:
"Integer": sympy.Integer,
"Float": sympy.Float,
"Function": sympy.Function,
+ "Mul": sympy.Mul,
"Pow": sympy.Pow,
"power": sympy.Pow,
"exp": sympy.exp,
@@ -326,6 +314,9 @@ def from_json(cls, indict, all_variable_symbols=None, parameters=None, _debug=Fa
raise MalformedInputException("In defintion of initial value for variable \"" + iv_symbol + "\": differential order (" + str(iv_order) + ") exceeds that of overall equation order (" + str(order) + ")")
if initial_val_specified[iv_order]:
raise MalformedInputException("Initial value for order " + str(iv_order) + " specified more than once")
+
+ _check_forbidden_name(iv_symbol)
+
initial_val_specified[iv_order] = True
initial_values[iv_symbol + iv_order * "'"] = iv_rhs
@@ -392,13 +383,13 @@ def split_lin_inhom_nonlin(expr, x, parameters=None):
terms = [expr]
for term in terms:
- if is_constant_term(term, parameters=parameters):
+ if _is_constant_term(term, parameters=parameters):
inhom_term += term
else:
# check if the term is linear in any of the symbols in `x`
is_lin = False
for j, sym in enumerate(x):
- if is_constant_term(term / sym, parameters=parameters):
+ if _is_constant_term(term / sym, parameters=parameters):
lin_factors[j] += term / sym
is_lin = True
break
@@ -584,8 +575,21 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab
order: int = len(initial_values)
all_variable_symbols_dict = {str(el): el for el in all_variable_symbols}
definition = sympy.parsing.sympy_parser.parse_expr(definition.replace("'", Config().differential_order_symbol), global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) # minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol)
+
+ # validate input for forbidden names
+ _initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, evaluate=False, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()}
+ for iv_expr in _initial_values.values():
+ for var in iv_expr.atoms():
+ _check_forbidden_name(var)
+
+ # parse input
initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()}
+ # validate input for numerical issues
+ for iv_expr in initial_values.values():
+ for var in iv_expr.atoms():
+ _check_numerical_issue(var)
+
local_symbols = [symbol + Config().differential_order_symbol * i for i in range(order)]
local_symbols_sympy = [sympy.Symbol(sym_name) for sym_name in local_symbols]
if not symbol in all_variable_symbols:
diff --git a/odetoolbox/sympy_helpers.py b/odetoolbox/sympy_helpers.py
index 0d59f588..ba45457c 100644
--- a/odetoolbox/sympy_helpers.py
+++ b/odetoolbox/sympy_helpers.py
@@ -1,5 +1,5 @@
#
-# sympy_printer.py
+# sympy_helpers.py
#
# This file is part of the NEST ODE toolbox.
#
@@ -19,6 +19,8 @@
# along with NEST. If not, see .
#
+from typing import Mapping
+
import logging
import sympy
import sys
@@ -26,6 +28,37 @@
from .config import Config
+class NumericalIssueException(Exception):
+ r"""Thrown in case of numerical issues, like division by zero."""
+ pass
+
+
+def _is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None) -> bool:
+ r"""
+ :return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
+ """
+ if parameters is None:
+ parameters = {}
+ assert all([type(k) is sympy.Symbol for k in parameters.keys()])
+ return type(term) in [sympy.Float, sympy.Integer, sympy.core.numbers.Zero, sympy.core.numbers.One] \
+ or all([sym in parameters.keys() for sym in term.free_symbols])
+
+
+def _check_numerical_issue(var: str) -> None:
+ forbidden_vars = ["zoo", "oo", "nan", "NaN"]
+ stripped_var_name = str(var).strip("'")
+ if stripped_var_name in forbidden_vars:
+ raise NumericalIssueException("The variable \"" + stripped_var_name + "\" was found. This indicates a numerical problem while solving the system of ODEs. Please check the input for correctness (such as the presence of divisions by zero).")
+
+
+def _check_forbidden_name(var: str) -> None:
+ from .shapes import MalformedInputException
+
+ stripped_var_name = str(var).strip("'")
+ if stripped_var_name in Config().forbidden_names + dir(sympy.core.numbers):
+ raise MalformedInputException("Variable by name \"" + stripped_var_name + "\" not allowed; this is a reserved name.")
+
+
def _is_zero(x):
r"""
Check if a sympy expression is equal to zero.
diff --git a/tests/test_malformed_input.py b/tests/test_malformed_input.py
new file mode 100644
index 00000000..12b9dcf3
--- /dev/null
+++ b/tests/test_malformed_input.py
@@ -0,0 +1,67 @@
+#
+# test_malformed_input.py
+#
+# This file is part of the NEST ODE toolbox.
+#
+# Copyright (C) 2017 The NEST Initiative
+#
+# The NEST ODE toolbox is free software: you can redistribute it
+# and/or modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation, either version 2 of
+# the License, or (at your option) any later version.
+#
+# The NEST ODE toolbox is distributed in the hope that it will be
+# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+#
+
+import pytest
+
+from .context import odetoolbox
+from odetoolbox.shapes import MalformedInputException
+from odetoolbox.sympy_helpers import NumericalIssueException
+
+
+class TestMalformedInput:
+ r"""Test for failure when forbidden names are used in the input."""
+
+ @pytest.mark.xfail(strict=True, raises=MalformedInputException)
+ def test_malformed_input_iv(self):
+ indict = {"dynamics": [{"expression": "x' = 0",
+ "initial_value": "zoo"}]}
+ result = odetoolbox.analysis(indict, disable_stiffness_check=True)
+
+ @pytest.mark.xfail(strict=True, raises=MalformedInputException)
+ def test_malformed_input_expr(self):
+ indict = {"dynamics": [{"expression": "x' = 42 * NaN",
+ "initial_value": "0."}]}
+ result = odetoolbox.analysis(indict, disable_stiffness_check=True)
+
+ @pytest.mark.xfail(strict=True, raises=MalformedInputException)
+ def test_malformed_input_sym(self):
+ indict = {"dynamics": [{"expression": "oo' = 0",
+ "initial_value": "0."}]}
+ result = odetoolbox.analysis(indict, disable_stiffness_check=True)
+
+ def test_correct_input(self):
+ indict = {"dynamics": [{"expression": "foo' = 0",
+ "initial_value": "0."}]}
+ result = odetoolbox.analysis(indict, disable_stiffness_check=True)
+
+ @pytest.mark.xfail(strict=True, raises=NumericalIssueException)
+ def test_malformed_input_numerical_iv(self):
+ indict = {"dynamics": [{"expression": "foo' = 0",
+ "initial_value": "1/0"}]}
+ result = odetoolbox.analysis(indict, disable_stiffness_check=True)
+ import pdb;pdb.set_trace()
+
+ @pytest.mark.xfail(strict=True, raises=NumericalIssueException)
+ def test_malformed_input_numerical_parameter(self):
+ indict = {"dynamics": [{"expression": "foo' = bar",
+ "initial_value": "1"}],
+ "parameters": {"bar": "1/0"}}
+ result = odetoolbox.analysis(indict, disable_stiffness_check=True)