diff --git a/pySDC/implementations/problem_classes/DiscontinuousTestODE.py b/pySDC/implementations/problem_classes/DiscontinuousTestODE.py index af797577bd..111a8267ea 100644 --- a/pySDC/implementations/problem_classes/DiscontinuousTestODE.py +++ b/pySDC/implementations/problem_classes/DiscontinuousTestODE.py @@ -222,3 +222,64 @@ def count_switches(self): Setter to update the number of switches if one is found. """ self.nswitches += 1 + + +class ExactDiscontinuousTestODE(DiscontinuousTestODE): + r""" + Dummy ODE problem for testing the ``SwitchEstimator`` class. The problem contains the exact dynamics + of the problem class ``DiscontinuousTestODE``. + """ + + def __init__(self, newton_maxiter=100, newton_tol=1e-8): + """Initialization routine""" + super().__init__(newton_maxiter, newton_tol) + + def eval_f(self, u, t): + """ + Derivative. + + Parameters + ---------- + u : dtype_u + Exact value of u. + t : float + Time :math:`t`. + + Returns + ------- + f : dtype_f + Derivative. + """ + + f = self.dtype_f(self.init) + + t_switch = np.inf if self.t_switch is None else self.t_switch + h = u[0] - 5 + if h >= 0 or t >= t_switch: + f[:] = 1 + else: + f[:] = np.exp(t) + return f + + def solve_system(self, rhs, factor, u0, t): + """ + Just return the exact solution... + + Parameters + ---------- + rhs : dtype_f + Right-hand side for the linear system. + factor : float + Abbrev. for the local stepsize (or any other factor required). + u0 : dtype_u + Initial guess for the iterative solver. + t : float + Current time (e.g. for time-dependent BCs). + + Returns + ------- + me : dtype_u + The solution as mesh. + """ + + return self.u_exact(t) diff --git a/pySDC/tests/test_projects/test_pintsime/test_SwitchEstimator.py b/pySDC/tests/test_projects/test_pintsime/test_SwitchEstimator.py index f4958e0b58..4de3df9db5 100644 --- a/pySDC/tests/test_projects/test_pintsime/test_SwitchEstimator.py +++ b/pySDC/tests/test_projects/test_pintsime/test_SwitchEstimator.py @@ -1,69 +1,6 @@ import numpy as np import pytest -from pySDC.implementations.problem_classes.DiscontinuousTestODE import DiscontinuousTestODE -from pySDC.projects.DAE.problems.DiscontinuousTestDAE import DiscontinuousTestDAE - - -class ExactDiscontinuousTestODE(DiscontinuousTestODE): - r""" - Dummy ODE problem for testing. The problem contains the exact dynamics of the problem class ``DiscontinuousTestODE``. - """ - - def __init__(self, newton_maxiter=100, newton_tol=1e-8): - """Initialization routine""" - super().__init__(newton_maxiter, newton_tol) - - def eval_f(self, u, t): - """ - Derivative. - - Parameters - ---------- - u : dtype_u - Exact value of u. - t : float - Time :math:`t`. - - Returns - ------- - f : dtype_f - Derivative. - """ - - f = self.dtype_f(self.init) - - t_switch = np.inf if self.t_switch is None else self.t_switch - h = u[0] - 5 - if h >= 0 or t >= t_switch: - f[:] = 1 - else: - f[:] = np.exp(t) - return f - - def solve_system(self, rhs, factor, u0, t): - """ - Just return the exact solution... - - Parameters - ---------- - rhs : dtype_f - Right-hand side for the linear system. - factor : float - Abbrev. for the local stepsize (or any other factor required). - u0 : dtype_u - Initial guess for the iterative solver. - t : float - Current time (e.g. for time-dependent BCs). - - Returns - ------- - me : dtype_u - The solution as mesh. - """ - - return self.u_exact(t) - def getParamsRun(): r""" @@ -87,11 +24,14 @@ def testExactDummyProblem(): a random right-hand side to enforce the sweeper to do not stop to compute. """ - from pySDC.implementations.datatype_classes.mesh import mesh + from pySDC.implementations.problem_classes.DiscontinuousTestODE import ( + DiscontinuousTestODE, + ExactDiscontinuousTestODE, + ) childODE = ExactDiscontinuousTestODE(**{}) parentODE = DiscontinuousTestODE(**{}) - assert childODE.t_switch_exact == parentODE.t_switch_exact, f"Exact event times between classes does not match!" + assert childODE.t_switch_exact == parentODE.t_switch_exact, "Exact event times between classes does not match!" t0 = 1.0 dt = 0.1 @@ -114,11 +54,11 @@ def testExactDummyProblem(): fExactOde = childODE.eval_f(u0, t0) fOde = parentODE.eval_f(u0, t0) - assert np.allclose(fExactOde, fOde), f"Right-hand sides do not match!" + assert np.allclose(fExactOde, fOde), "Right-hand sides do not match!" fExactOdeEvent = childODE.eval_f(u0Event, tExactEventODE) fOdeEvent = parentODE.eval_f(u0Event, tExactEventODE) - assert np.allclose(fExactOdeEvent, fOdeEvent), f"Right-hand sides at event do not match!" + assert np.allclose(fExactOdeEvent, fOdeEvent), "Right-hand sides at event do not match!" @pytest.mark.base @@ -143,6 +83,7 @@ def testAdaptInterpolationInfo(quad_type): from pySDC.projects.PinTSimE.battery_model import generateDescription from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit + from pySDC.implementations.problem_classes.DiscontinuousTestODE import ExactDiscontinuousTestODE problem = ExactDiscontinuousTestODE problem_params = dict() @@ -237,6 +178,7 @@ def testDetectionBoundary(num_nodes): from pySDC.projects.PinTSimE.battery_model import generateDescription, controllerRun from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit + from pySDC.implementations.problem_classes.DiscontinuousTestODE import ExactDiscontinuousTestODE from pySDC.implementations.hooks.log_solution import LogSolution from pySDC.implementations.hooks.log_restarts import LogRestarts from pySDC.helpers.stats_helper import get_sorted @@ -313,6 +255,7 @@ def testDetectionODE(tol, num_nodes, quad_type): """ from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit + from pySDC.implementations.problem_classes.DiscontinuousTestODE import ExactDiscontinuousTestODE from pySDC.helpers.stats_helper import get_sorted from pySDC.projects.PinTSimE.battery_model import generateDescription, controllerRun from pySDC.implementations.hooks.log_solution import LogSolution @@ -399,7 +342,7 @@ def testDetectionDAE(num_nodes): from pySDC.projects.PinTSimE.paper_PSCC2024.log_event import LogEventDiscontinuousTestDAE problem = DiscontinuousTestDAE - problem_params = dict({'newton_tol': 1e-6}) + problem_params = {'newton_tol': 1e-6} t0 = 4.6 Tend = 4.62 dt = Tend - t0