Skip to content

Commit

Permalink
TIME-X Test Hackathon @ TUD: Test for SwitchEstimator (#404)
Browse files Browse the repository at this point in the history
* Added piecewise linear interpolation to SwitchEstimator

* Started with test for SwitchEstimator [WIP]

* Test to proof sum_restarts when event occuring at boundary

* Started with test to check adapt_interpolation_info [WIP]

* Added test for SE.adapt_interpolation_info()

* Update linear interpolation + logging + changing tolerances

* Test for linear interpolation + update of other test

* Correction for finite difference + adaption tolerance

* Added test for DAE case for SE

* Choice of FD seems to be important for performance of SE

* Removed attributes from dummy probs (since the parent classes have it)

* Test for dummy problems + using functions from battery_model.py

* Moved standard params for test to function

* Updated hardcoded solutions for battery models

* Updated hardcoded solutions for DiscontinuousTestODE

* Updated docu in SE for FDs

* Lagrange Interpolation works better with baclward FD and alpha=0.9

* Added test for state function + global error

* Updated LogEvent hooks

* Updated hardcoded solutions again

* Adapted test_problems.py

* Minor changes

* Updated tests

* Speed-up test for buck converter

* Black..

* Use msg about convergence info in Newton in SE

* Moved dummy problem to file

* Speed up loop using mask

* Removed loop
  • Loading branch information
lisawim authored Apr 21, 2024
1 parent b02181b commit 18989d3
Show file tree
Hide file tree
Showing 10 changed files with 686 additions and 179 deletions.
63 changes: 62 additions & 1 deletion pySDC/implementations/problem_classes/DiscontinuousTestODE.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,72 @@ def get_switching_info(self, u, t):
m_guess = m - 1
break

state_function = [u[m][0] - 5 for m in range(len(u))] if switch_detected else []
state_function = [u[m][0] - 5 for m in range(len(u))]
return switch_detected, m_guess, state_function

def count_switches(self):
"""
Setter to update the number of switches if one is found.
"""
self.nswitches += 1


class ExactDiscontinuousTestODE(DiscontinuousTestODE):
r"""
Dummy ODE problem for testing the ``SwitchEstimator`` class. The problem contains the exact dynamics
of the problem class ``DiscontinuousTestODE``.
"""

def __init__(self, newton_maxiter=100, newton_tol=1e-8):
"""Initialization routine"""
super().__init__(newton_maxiter, newton_tol)

def eval_f(self, u, t):
"""
Derivative.
Parameters
----------
u : dtype_u
Exact value of u.
t : float
Time :math:`t`.
Returns
-------
f : dtype_f
Derivative.
"""

f = self.dtype_f(self.init)

t_switch = np.inf if self.t_switch is None else self.t_switch
h = u[0] - 5
if h >= 0 or t >= t_switch:
f[:] = 1
else:
f[:] = np.exp(t)
return f

def solve_system(self, rhs, factor, u0, t):
"""
Just return the exact solution...
Parameters
----------
rhs : dtype_f
Right-hand side for the linear system.
factor : float
Abbrev. for the local stepsize (or any other factor required).
u0 : dtype_u
Initial guess for the iterative solver.
t : float
Current time (e.g. for time-dependent BCs).
Returns
-------
me : dtype_u
The solution as mesh.
"""

return self.u_exact(t)
21 changes: 12 additions & 9 deletions pySDC/projects/PinTSimE/battery_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,13 @@ def generateDescription(
'convergence_controllers': convergence_controllers,
}

return description, controller_params
# instantiate controller
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

return description, controller_params, controller


def controllerRun(description, controller_params, t0, Tend, exact_event_time_avail=False):
def controllerRun(description, controller_params, controller, t0, Tend, exact_event_time_avail=False):
"""
Executes a controller run for a problem defined in the description.
Expand All @@ -180,6 +183,8 @@ def controllerRun(description, controller_params, t0, Tend, exact_event_time_ava
Contains all information for a controller run.
controller_params : dict
Parameters needed for a controller run.
controller : pySDC.core.Controller
Controller to do the stuff.
t0 : float
Starting time of simulation.
Tend : float
Expand All @@ -193,9 +198,6 @@ def controllerRun(description, controller_params, t0, Tend, exact_event_time_ava
Raw statistics from a controller run.
"""

# instantiate controller
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

# get initial values on finest level
P = controller.MS[0].levels[0].prob
uinit = P.u_exact(t0)
Expand Down Expand Up @@ -233,7 +235,7 @@ def main():
'max_restarts': 50,
'recomputed': False,
'tol_event': 1e-10,
'alpha': 1.0,
'alpha': 0.96,
'exact_event_time_avail': None,
}

Expand All @@ -244,8 +246,8 @@ def main():

hook_class = [LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize]

use_detection = [True]
use_adaptivity = [True]
use_detection = [True, False]
use_adaptivity = [True, False]

for problem, sweeper in zip([battery, battery_implicit], [imex_1st_order, generic_implicit]):
for defaults in [False, True]:
Expand Down Expand Up @@ -360,7 +362,7 @@ def runSimulation(problem, sweeper, all_params, use_adaptivity, use_detection, h

restol = -1 if use_A else handling_params['restol']

description, controller_params = generateDescription(
description, controller_params, controller = generateDescription(
dt=dt,
problem=problem,
sweeper=sweeper,
Expand All @@ -381,6 +383,7 @@ def runSimulation(problem, sweeper, all_params, use_adaptivity, use_detection, h
stats, t_switch_exact = controllerRun(
description=description,
controller_params=controller_params,
controller=controller,
t0=interval[0],
Tend=interval[-1],
exact_event_time_avail=handling_params['exact_event_time_avail'],
Expand Down
2 changes: 1 addition & 1 deletion pySDC/projects/PinTSimE/buck_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main():
use_adaptivity=use_adaptivity,
use_detection=use_detection,
hook_class=hook_class,
interval=(0.0, 2e-2),
interval=(0.0, 1e-2),
dt_list=[1e-5, 2e-5],
nnodes=[M_fix],
)
Expand Down
3 changes: 2 additions & 1 deletion pySDC/projects/PinTSimE/discontinuous_test_ODE.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def main():
'max_restarts': 50,
'recomputed': False,
'tol_event': 1e-12,
'alpha': 1.0,
'alpha': 0.96,
'exact_event_time_avail': True,
'typeFD': 'backward',
}

# ---- all parameters are stored in this dictionary ----
Expand Down
14 changes: 10 additions & 4 deletions pySDC/projects/PinTSimE/estimation_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run_estimation_check():
'max_restarts': 50,
'recomputed': False,
'tol_event': 1e-10,
'alpha': 1.0,
'alpha': 0.96,
'exact_event_time_avail': None,
}

Expand Down Expand Up @@ -114,7 +114,7 @@ def run_estimation_check():

plotAccuracyCheck(u_num, prob_cls_name, M_fix)

plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix)
# plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix)

plotStateFunctionOverTime(u_num, prob_cls_name, M_fix)

Expand Down Expand Up @@ -187,6 +187,9 @@ def plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix): # pragma: no cov
Routine that plots the state function at time before the event, exactly at the event, and after the event. Note
that this routine does make sense only for a state function that remains constant after the event.
TODO: Function still does not work as expected. Every time when the switch estimator is adapted, the tolerances
does not suit anymore!
Parameters
----------
u_num : dict
Expand Down Expand Up @@ -239,15 +242,18 @@ def plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix): # pragma: no cov

if use_SE:
t_switches = [u_num[dt][M_fix][use_SE][use_A]['t_switches'] for dt in dt_list]
t_switch = [t_event[i] for t_event in t_switches]
for t_switch_item in t_switches:
mask = np.append([True], np.abs(t_switch_item[1:] - t_switch_item[:-1]) > 1e-10)
t_switch_item = t_switch_item[mask]

t_switch = [t_event[i] for t_event in t_switches]
ax[0, ind].plot(
dt_list,
[
h_item[m]
for (t_item, h_item, t_switch_item) in zip(t, h, t_switch)
for m in range(len(t_item))
if abs(t_item[m] - t_switch_item) <= 1e-14
if abs(t_item[m] - t_switch_item) <= 2.7961188919789493e-11
],
color='limegreen',
marker='s',
Expand Down
Loading

0 comments on commit 18989d3

Please sign in to comment.