diff --git a/pySDC/projects/DAE/tests/test_problems.py b/pySDC/projects/DAE/tests/test_problems.py index 3bcbf33233..671988c331 100644 --- a/pySDC/projects/DAE/tests/test_problems.py +++ b/pySDC/projects/DAE/tests/test_problems.py @@ -614,7 +614,7 @@ def test_WSCC9_update_YBus(): @pytest.mark.timeout(360) @pytest.mark.base -def test_WSCC9_SDC_detection(): +def test_WSCC9_get_switching_info(): """ Test if state function states a root. """ @@ -623,8 +623,6 @@ def test_WSCC9_SDC_detection(): from pySDC.projects.DAE.problems.WSCC9BusSystem import WSCC9BusSystem from pySDC.projects.DAE.sweepers.fully_implicit_DAE import fully_implicit_DAE from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI - from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator - from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI dt = 0.75 level_params = { @@ -650,21 +648,6 @@ def test_WSCC9_SDC_detection(): 'logger_level': 30, } - switch_estimator_params = { - 'tol': 1e-10, - 'alpha': 0.97, - } - - restarting_params = { - 'max_restarts': 200, - 'crash_after_max_restarts': False, - } - - convergence_controllers = { - SwitchEstimator: switch_estimator_params, - BasicRestartingNonMPI: restarting_params, - } - description = { 'problem_class': WSCC9BusSystem, 'problem_params': problem_params, @@ -672,7 +655,6 @@ def test_WSCC9_SDC_detection(): 'sweeper_params': sweeper_params, 'level_params': level_params, 'step_params': step_params, - 'convergence_controllers': convergence_controllers, } controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) @@ -680,17 +662,18 @@ def test_WSCC9_SDC_detection(): t0 = 0.0 Tend = dt - P = controller.MS[0].levels[0].prob + L = controller.MS[0].levels[0] + P = L.prob uinit = P.u_exact(t0) _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - switches = get_sorted(stats, type='switch', sortby='time', recomputed=False) - assert len(switches) >= 1, 'ERROR: No events found!' - t_switch = [me[1] for me in switches][0] - assert np.isclose( - t_switch, 0.528458886745887, atol=1e-3 - ), f'Found event does not match a threshold! Got {t_switch=}' + switch_detected, _, state_function = P.get_switching_info(L.u, L.time) + + assert switch_detected, f"Event should found here, but no event is found!" + + sign_change = True if state_function[0] * state_function[-1] < 0 else False + assert sign_change, f"State function does not have sign change" # @pytest.mark.base