diff --git a/pySDC/implementations/problem_classes/nonlinear_ODE_1.py b/pySDC/implementations/problem_classes/nonlinear_ODE_1.py index 867b2a75cf..3bb0d01208 100755 --- a/pySDC/implementations/problem_classes/nonlinear_ODE_1.py +++ b/pySDC/implementations/problem_classes/nonlinear_ODE_1.py @@ -13,7 +13,7 @@ class nonlinear_ODE_1(ptype): given by .. math:: - \frac{du(t)}{dt} = \sqrt(1 - u(t)) + \frac{du(t)}{dt} = \sqrt{1 - u(t)} with initial condition :math:`u(0) = 0`. The exact solution is @@ -30,6 +30,13 @@ class nonlinear_ODE_1(ptype): Tolerance for Newton's method to terminate. stop_at_nan : bool, optional Indicates that Newton solver has to stop if ``nan`` values arise. + + Attributes + ---------- + newton_itercount : int + Counts the Newton iterations. + newton_ncalls : int + Counts calls of Newton method. """ dtype_u = mesh @@ -42,6 +49,9 @@ def __init__(self, u0=0.0, newton_maxiter=200, newton_tol=5e-11, stop_at_nan=Tru 'u0', 'newton_maxiter', 'newton_tol', 'stop_at_nan', localVars=locals(), readOnly=True ) + self.newton_itercount = 0 + self.newton_ncalls = 0 + def u_exact(self, t): r""" Routine to compute the exact solution at time :math:`t`. @@ -125,12 +135,16 @@ def solve_system(self, rhs, dt, u0, t): # increase iteration count n += 1 + self.newton_itercount += 1 + if np.isnan(res) and self.stop_at_nan: raise ProblemError('Newton got nan after %i iterations, aborting...' % n) elif np.isnan(res): self.logger.warning('Newton got nan after %i iterations...' % n) if n == self.newton_maxiter: - raise ProblemError('Newton did not converge after %i iterations, error is %s' % (n, res)) + self.logger.warning('Newton did not converge after %i iterations, error is %s' % (n, res)) + + self.newton_ncalls += 1 return u diff --git a/pySDC/tests/test_problems/test_nonlinear_ODE_1.py b/pySDC/tests/test_problems/test_nonlinear_ODE_1.py new file mode 100644 index 0000000000..5ee69fb338 --- /dev/null +++ b/pySDC/tests/test_problems/test_nonlinear_ODE_1.py @@ -0,0 +1,95 @@ +import pytest + + +@pytest.mark.base +def test_singularity(): + """ + Test if the singularity occurs at correct time. + """ + + import numpy as np + from pySDC.implementations.problem_classes.nonlinear_ODE_1 import nonlinear_ODE_1 + + problem_params = { + 'stop_at_nan': False, + } + + nonlinear_ODE_class = nonlinear_ODE_1(**problem_params) + t_event = 2 + + u_event = nonlinear_ODE_class.u_exact(t_event) + f = nonlinear_ODE_class.eval_f(u_event, t_event) + + assert f == 0, "Evaluation of right-hand side at singularity does not match with zero!" + + dt = 1e-1 + t0 = 1.9 + u0 = nonlinear_ODE_class.u_exact(t0) + args = { + 'rhs': u0, + 'dt': dt, + 'u0': u0, + 't': t0, + } + + sol = nonlinear_ODE_class.solve_system(**args) + assert abs(sol - 1) < 1e-14, f"Solution is not close enough to the value at singularity! Expected 1, got {sol}" + assert ( + nonlinear_ODE_class.newton_itercount == nonlinear_ODE_class.newton_maxiter + ), f"Expected {nonlinear_ODE_class.newton_maxiter} Newton iterations, got {nonlinear_ODE_class.newton_itercount}" + + +@pytest.mark.base +def test_SDC_on_problem_class(): + """ + Test for SDC applied on problem class. + """ + + import numpy as np + from pySDC.implementations.problem_classes.nonlinear_ODE_1 import nonlinear_ODE_1 + from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit + from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI + + level_params = { + 'restol': 1e-13, + 'dt': 1e-3, + } + + sweeper_params = { + 'quad_type': 'LOBATTO', + 'num_nodes': 3, + 'QI': 'IE', + } + + problem_params = { + 'stop_at_nan': False, + } + + step_params = { + 'maxiter': 30, + } + + controller_params = { + 'logger_level': 30, + } + + description = dict() + description['problem_class'] = nonlinear_ODE_1 + description['problem_params'] = problem_params + description['sweeper_class'] = generic_implicit + description['sweeper_params'] = sweeper_params + description['level_params'] = level_params + description['step_params'] = step_params + + controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) + + t0 = 1.999 + Tend = 2.0 + + P = controller.MS[0].levels[0].prob + uex = P.u_exact(Tend) + + uend, _ = controller.run(u0=uex, t0=t0, Tend=Tend) + + err = abs(uex - uend) + assert err < 1.968e-8, f"Error is too large! Got {err}"