Skip to content

Commit

Permalink
Moved dummy problem to file
Browse files Browse the repository at this point in the history
  • Loading branch information
lisawim committed Apr 11, 2024
1 parent 872e81e commit bc751fd
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 68 deletions.
61 changes: 61 additions & 0 deletions pySDC/implementations/problem_classes/DiscontinuousTestODE.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,64 @@ def count_switches(self):
Setter to update the number of switches if one is found.
"""
self.nswitches += 1


class ExactDiscontinuousTestODE(DiscontinuousTestODE):
r"""
Dummy ODE problem for testing the ``SwitchEstimator`` class. The problem contains the exact dynamics
of the problem class ``DiscontinuousTestODE``.
"""

def __init__(self, newton_maxiter=100, newton_tol=1e-8):
"""Initialization routine"""
super().__init__(newton_maxiter, newton_tol)

def eval_f(self, u, t):
"""
Derivative.
Parameters
----------
u : dtype_u
Exact value of u.
t : float
Time :math:`t`.
Returns
-------
f : dtype_f
Derivative.
"""

f = self.dtype_f(self.init)

t_switch = np.inf if self.t_switch is None else self.t_switch
h = u[0] - 5
if h >= 0 or t >= t_switch:
f[:] = 1
else:
f[:] = np.exp(t)
return f

def solve_system(self, rhs, factor, u0, t):
"""
Just return the exact solution...
Parameters
----------
rhs : dtype_f
Right-hand side for the linear system.
factor : float
Abbrev. for the local stepsize (or any other factor required).
u0 : dtype_u
Initial guess for the iterative solver.
t : float
Current time (e.g. for time-dependent BCs).
Returns
-------
me : dtype_u
The solution as mesh.
"""

return self.u_exact(t)
79 changes: 11 additions & 68 deletions pySDC/tests/test_projects/test_pintsime/test_SwitchEstimator.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,6 @@
import numpy as np
import pytest

from pySDC.implementations.problem_classes.DiscontinuousTestODE import DiscontinuousTestODE
from pySDC.projects.DAE.problems.DiscontinuousTestDAE import DiscontinuousTestDAE


class ExactDiscontinuousTestODE(DiscontinuousTestODE):
r"""
Dummy ODE problem for testing. The problem contains the exact dynamics of the problem class ``DiscontinuousTestODE``.
"""

def __init__(self, newton_maxiter=100, newton_tol=1e-8):
"""Initialization routine"""
super().__init__(newton_maxiter, newton_tol)

def eval_f(self, u, t):
"""
Derivative.
Parameters
----------
u : dtype_u
Exact value of u.
t : float
Time :math:`t`.
Returns
-------
f : dtype_f
Derivative.
"""

f = self.dtype_f(self.init)

t_switch = np.inf if self.t_switch is None else self.t_switch
h = u[0] - 5
if h >= 0 or t >= t_switch:
f[:] = 1
else:
f[:] = np.exp(t)
return f

def solve_system(self, rhs, factor, u0, t):
"""
Just return the exact solution...
Parameters
----------
rhs : dtype_f
Right-hand side for the linear system.
factor : float
Abbrev. for the local stepsize (or any other factor required).
u0 : dtype_u
Initial guess for the iterative solver.
t : float
Current time (e.g. for time-dependent BCs).
Returns
-------
me : dtype_u
The solution as mesh.
"""

return self.u_exact(t)


def getParamsRun():
r"""
Expand All @@ -87,11 +24,14 @@ def testExactDummyProblem():
a random right-hand side to enforce the sweeper to do not stop to compute.
"""

from pySDC.implementations.datatype_classes.mesh import mesh
from pySDC.implementations.problem_classes.DiscontinuousTestODE import (
DiscontinuousTestODE,
ExactDiscontinuousTestODE,
)

childODE = ExactDiscontinuousTestODE(**{})
parentODE = DiscontinuousTestODE(**{})
assert childODE.t_switch_exact == parentODE.t_switch_exact, f"Exact event times between classes does not match!"
assert childODE.t_switch_exact == parentODE.t_switch_exact, "Exact event times between classes does not match!"

t0 = 1.0
dt = 0.1
Expand All @@ -114,11 +54,11 @@ def testExactDummyProblem():

fExactOde = childODE.eval_f(u0, t0)
fOde = parentODE.eval_f(u0, t0)
assert np.allclose(fExactOde, fOde), f"Right-hand sides do not match!"
assert np.allclose(fExactOde, fOde), "Right-hand sides do not match!"

fExactOdeEvent = childODE.eval_f(u0Event, tExactEventODE)
fOdeEvent = parentODE.eval_f(u0Event, tExactEventODE)
assert np.allclose(fExactOdeEvent, fOdeEvent), f"Right-hand sides at event do not match!"
assert np.allclose(fExactOdeEvent, fOdeEvent), "Right-hand sides at event do not match!"


@pytest.mark.base
Expand All @@ -143,6 +83,7 @@ def testAdaptInterpolationInfo(quad_type):
from pySDC.projects.PinTSimE.battery_model import generateDescription
from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.implementations.problem_classes.DiscontinuousTestODE import ExactDiscontinuousTestODE

problem = ExactDiscontinuousTestODE
problem_params = dict()
Expand Down Expand Up @@ -237,6 +178,7 @@ def testDetectionBoundary(num_nodes):

from pySDC.projects.PinTSimE.battery_model import generateDescription, controllerRun
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.implementations.problem_classes.DiscontinuousTestODE import ExactDiscontinuousTestODE
from pySDC.implementations.hooks.log_solution import LogSolution
from pySDC.implementations.hooks.log_restarts import LogRestarts
from pySDC.helpers.stats_helper import get_sorted
Expand Down Expand Up @@ -313,6 +255,7 @@ def testDetectionODE(tol, num_nodes, quad_type):
"""

from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.implementations.problem_classes.DiscontinuousTestODE import ExactDiscontinuousTestODE
from pySDC.helpers.stats_helper import get_sorted
from pySDC.projects.PinTSimE.battery_model import generateDescription, controllerRun
from pySDC.implementations.hooks.log_solution import LogSolution
Expand Down Expand Up @@ -399,7 +342,7 @@ def testDetectionDAE(num_nodes):
from pySDC.projects.PinTSimE.paper_PSCC2024.log_event import LogEventDiscontinuousTestDAE

problem = DiscontinuousTestDAE
problem_params = dict({'newton_tol': 1e-6})
problem_params = {'newton_tol': 1e-6}
t0 = 4.6
Tend = 4.62
dt = Tend - t0
Expand Down

0 comments on commit bc751fd

Please sign in to comment.