Skip to content

Commit

Permalink
Updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lisawim committed Mar 7, 2024
1 parent 0bfcae7 commit 827a8ed
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 26 deletions.
24 changes: 8 additions & 16 deletions pySDC/tests/test_projects/test_DAE/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def test_DiscontinuousTestDAE_singularity():

assert np.isclose(
abs(f_before_event), 0.0
), f"ERROR: Right-hand side after event does not match! Expected {(0.0, 0.0)}, got {f_before_event}"
), f"ERROR: Right-hand side after event does not match! Expected {(0.0, 0.0)}, got {f_before_event=}"

# test for t <= t^*
u_event = disc_test_DAE.u_exact(t_event)
Expand Down Expand Up @@ -424,7 +424,7 @@ def test_DiscontinuousTestDAE_SDC(M):
uend, _ = controller.run(u0=uinit, t0=t0, Tend=Tend)

err = abs(uex.diff[0] - uend.diff[0])
assert err < err_tol[M], f"ERROR: Error is too large! Expected {err_tol[M]}, got {err}"
assert err < err_tol[M], f"ERROR: Error is too large! Expected {err_tol[M]=}, got {err=}"


@pytest.mark.base
Expand All @@ -442,20 +442,14 @@ def test_DiscontinuousTestDAE_SDC_detection(M):
from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI

err_tol = {
3: 5.97e-13,
4: 1.43e-10,
5: 3.18e-10,
}

event_err_tol = {
3: 0.02,
4: 8.22e-13,
5: 1.64e-12,
4: 5e-10,
5: 1e-10,
}

level_params = {
'restol': 1e-13,
'restol': 1e-10,
'dt': 1e-2,
}

Expand Down Expand Up @@ -513,8 +507,7 @@ def test_DiscontinuousTestDAE_SDC_detection(M):

uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
err = abs(uex.diff[0] - uend.diff[0])
print(err)
assert err < err_tol[M], f"ERROR for M={M}: Error is too large! Expected {err_tol[M]}, got {err}"
assert err < 2e-9, f"ERROR for M={M}: Error is too large! Expected something lower than {2e-9}, got {err=}"

switches = get_sorted(stats, type='switch', sortby='time', recomputed=False)
assert len(switches) >= 1, 'ERROR for M={M}: No events found!'
Expand All @@ -523,10 +516,9 @@ def test_DiscontinuousTestDAE_SDC_detection(M):

t_switch_exact = P.t_switch_exact
event_err = abs(t_switch_exact - t_switch)
print(event_err)
assert (
event_err < event_err_tol[M]
), f"ERROR for M={M}: Event error is too large! Expected {event_err_tol[M]}, got {event_err}"
), f"ERROR for M={M}: Event error is too large! Expected {event_err_tol[M]=}, got {event_err=}"


@pytest.mark.base
Expand Down Expand Up @@ -695,7 +687,7 @@ def test_WSCC9_SDC_detection():
switches = get_sorted(stats, type='switch', sortby='time', recomputed=False)
assert len(switches) >= 1, 'ERROR: No events found!'
t_switch = [me[1] for me in switches][0]
assert np.isclose(t_switch, 0.528458886745887, atol=1e-3), f'Found event does not match a threshold! Got {t_switch}'
assert np.isclose(t_switch, 0.528458886745887, atol=1e-3), f'Found event does not match a threshold! Got {t_switch=}'


# @pytest.mark.base
Expand Down
18 changes: 8 additions & 10 deletions pySDC/tests/test_projects/test_pintsime/test_SwitchEstimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ def testAdaptInterpolationInfo(quad_type):
assert t_interp[0] != t_interp[1], 'Starting time from interpolation axis is not removed!'
assert (
len(t_interp) == num_nodes
), f'Number of values on interpolation axis does not match. Expected {num_nodes}, got {len(t_interp)}'
), f'Number of values on interpolation axis does not match. Expected {num_nodes=}, got {len(t_interp)}'

elif quad_type == 'RADAU-RIGHT':
assert (
len(t_interp) == num_nodes + 1
), f'Number of values on interpolation axis does not match. Expected {num_nodes + 1}, got {len(t_interp)}'
), f'Number of values on interpolation axis does not match. Expected {num_nodes + 1=}, got {len(t_interp)}'


@pytest.mark.base
Expand Down Expand Up @@ -364,7 +364,7 @@ def testDetectionODE(tol, num_nodes, quad_type):

t_switch = switches[-1]
event_err = abs(t_switch - t_switch_exact)
assert np.isclose(event_err, 0, atol=1.2e-11), f'Event time error {event_err} is not small enough!'
assert np.isclose(event_err, 0, atol=1.2e-11), f'Event time error {event_err=} is not small enough!'


@pytest.mark.base
Expand Down Expand Up @@ -411,7 +411,7 @@ def testDetectionDAE(num_nodes):

_, _, _, _, useA, useSE, exact_event_time_avail = getParamsRun()

restol = 1e-13
restol = 1e-11
maxiter = 60
max_restarts = 20
alpha = 0.97
Expand Down Expand Up @@ -447,18 +447,16 @@ def testDetectionDAE(num_nodes):

# in this specific example only one event has to be found
switches = [me[1] for me in get_sorted(stats, type='switch', sortby='time', recomputed=False)]
assert len(switches) >= 1, f'{problem.__name__}: No events found for tol={tol} and M={num_nodes}!'
assert len(switches) >= 1, f'{problem.__name__}: No events found for {tol=} and {num_nodes=}!'

t_switch = switches[-1]
event_err = abs(t_switch - t_switch_exact)
assert np.isclose(event_err, 0, atol=2.2e-6), f'Event time error {event_err} is not small enough!'
assert np.isclose(event_err, 0, atol=2.2e-6), f'Event time error {event_err=} is not small enough!'

h = np.array([val[1] for val in get_sorted(stats, type='state_function', sortby='time', recomputed=False)])
if h[-1] < 0:
assert abs(h[-1]) < 4.7e-10, f"State function has large negative value -> SE does switch too early! Got {h[-1]}"
assert np.isclose(abs(h[-1]), 0.0, atol=5e-10), f'State function is not close to zero; value is {h[-1]}'
assert np.isclose(abs(h[-1]), 0.0, atol=2e-9), f'State function is not close to zero; value is {h[-1]}'

e_global = np.array(get_sorted(stats, type='e_global_differential_post_step', sortby='time', recomputed=False))
assert np.isclose(
e_global[-1, 1], 0.0, atol=2.4e-10
e_global[-1, 1], 0.0, atol=9.93e-10
), f"Error at end time is too large! Expected {1e-11}, got {e_global[-1, 1]}"

0 comments on commit 827a8ed

Please sign in to comment.