-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added convergence controller that can crash the code if the solution (#…
…360) exceeds some threshold or contains NaN
- Loading branch information
1 parent
f40728b
commit cf89515
Showing
2 changed files
with
255 additions
and
0 deletions.
There are no files selected for viewing
79 changes: 79 additions & 0 deletions
79
pySDC/implementations/convergence_controller_classes/stop_at_nan.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from pySDC.core.ConvergenceController import ConvergenceController | ||
from pySDC.core.Errors import ConvergenceError | ||
import numpy as np | ||
|
||
|
||
class StopAtNan(ConvergenceController): | ||
""" | ||
Crash the code when the norm of the solution exceeds some limit or contains nan. | ||
This class is useful when running with MPI in the sweeper or controller. | ||
""" | ||
|
||
def __init__(self, controller, params, description, **kwargs): | ||
super().__init__(controller, params, description, **kwargs) | ||
if self.comm or self.params.useMPI: | ||
from mpi4py import MPI | ||
|
||
self.MPI_OR = MPI.LOR | ||
|
||
def setup(self, controller, params, description, **kwargs): | ||
""" | ||
Define parameters here. | ||
Default parameters are: | ||
- tresh (float): Crash the code when the norm of the solution exceeds this threshold | ||
Args: | ||
controller (pySDC.Controller): The controller | ||
params (dict): The params passed for this specific convergence controller | ||
description (dict): The description object used to instantiate the controller | ||
Returns: | ||
(dict): The updated params dictionary | ||
""" | ||
self.comm = description['sweeper_params'].get('comm', None) | ||
defaults = { | ||
"control_order": 95, | ||
"thresh": np.inf, | ||
} | ||
|
||
return {**defaults, **super().setup(controller, params, description, **kwargs)} | ||
|
||
def post_iteration_processing(self, controller, S, comm=None, **kwargs): | ||
""" | ||
Check if we need to crash the code. | ||
Args: | ||
controller (pySDC.Controller.controller): Controller | ||
S (pySDC.Step.step): Step | ||
comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable | ||
Raises: | ||
ConvergenceError: If the solution does not fall within the allowed space | ||
""" | ||
isfinite, below_limit = True, True | ||
crash = False | ||
|
||
for lvl in S.levels: | ||
for u in lvl.u: | ||
if u is None: | ||
break | ||
isfinite = all(np.isfinite(u)) | ||
below_limit = abs(u) < self.params.thresh | ||
|
||
crash = not (isfinite and below_limit) | ||
|
||
if crash: | ||
break | ||
if crash: | ||
break | ||
|
||
if self.comm: | ||
crash = self.comm.allreduce(crash, op=self.MPI_OR) | ||
elif comm: | ||
crash = comm.allreduce(crash, op=self.MPI_OR) | ||
else: | ||
crash = not isfinite or not below_limit | ||
|
||
if crash: | ||
raise ConvergenceError(f'Solution exceeds bounds! Crashing code at {S.time}!') |
176 changes: 176 additions & 0 deletions
176
pySDC/tests/test_convergence_controllers/test_stop_at_nan.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import pytest | ||
|
||
|
||
def get_controller(MPIsweeper, MPIcontroller): | ||
""" | ||
Runs a single advection problem with certain parameters | ||
Args: | ||
MPIsweeper (bool): Use MPI parallel sweeper | ||
MPIcontroller (bool): Use MPI parallel controller | ||
Returns: | ||
(pySDC.Controller.controller): Controller used in the run | ||
""" | ||
from pySDC.implementations.problem_classes.polynomial_test_problem import polynomial_testequation | ||
from pySDC.implementations.convergence_controller_classes.stop_at_nan import StopAtNan | ||
|
||
if MPIcontroller: | ||
from pySDC.implementations.controller_classes.controller_MPI import controller_MPI as controller_class | ||
from mpi4py import MPI | ||
|
||
controller_args = {'comm': MPI.COMM_WORLD} | ||
else: | ||
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI as controller_class | ||
|
||
controller_args = {'num_procs': 1} | ||
|
||
if MPIsweeper: | ||
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper_class | ||
from mpi4py import MPI | ||
|
||
comm = MPI.COMM_WORLD | ||
else: | ||
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class | ||
|
||
comm = None | ||
|
||
# initialize level parameters | ||
level_params = {} | ||
level_params['dt'] = 1.0 | ||
level_params['restol'] = 1.0 | ||
|
||
# initialize sweeper parameters | ||
sweeper_params = {} | ||
sweeper_params['quad_type'] = 'RADAU-RIGHT' | ||
sweeper_params['num_nodes'] = 3 | ||
sweeper_params['comm'] = comm | ||
|
||
problem_params = {'degree': 12} | ||
|
||
# initialize step parameters | ||
step_params = {} | ||
step_params['maxiter'] = 0 | ||
|
||
# initialize controller parameters | ||
controller_params = {} | ||
controller_params['logger_level'] = 30 | ||
controller_params['mssdc_jac'] = False | ||
|
||
# fill description dictionary for easy step instantiation | ||
description = {} | ||
description['problem_class'] = polynomial_testequation | ||
description['problem_params'] = problem_params | ||
description['sweeper_class'] = sweeper_class | ||
description['sweeper_params'] = sweeper_params | ||
description['level_params'] = level_params | ||
description['step_params'] = step_params | ||
description['convergence_controllers'] = {StopAtNan: {'thresh': 1e3}} | ||
|
||
controller = controller_class(controller_params=controller_params, description=description, **controller_args) | ||
return controller | ||
|
||
|
||
def single_test(MPIsweeper=False, MPIcontroller=False): | ||
""" | ||
Run a single test where the solution is replaced by a polynomial and the nodes are changed. | ||
Because we know the polynomial going in, we can check if the interpolation based change was | ||
exact. If the solution is not a polynomial or a polynomial of higher degree then the number | ||
of nodes, the change in nodes does add some error, of course, but here it is on the order of | ||
machine precision. | ||
""" | ||
import numpy as np | ||
from pySDC.core.Errors import ConvergenceError | ||
|
||
args = { | ||
'MPIsweeper': MPIsweeper, | ||
'MPIcontroller': MPIcontroller, | ||
} | ||
|
||
# prepare variables | ||
controller = get_controller(**args) | ||
|
||
if MPIcontroller: | ||
step = controller.S | ||
modify = controller.comm.rank == 0 | ||
comm = controller.comm | ||
else: | ||
step = controller.MS[0] | ||
comm = None | ||
modify = True | ||
level = step.levels[0] | ||
prob = level.prob | ||
cont = controller.convergence_controllers[ | ||
np.arange(len(controller.convergence_controllers))[ | ||
[type(me).__name__ == 'StopAtNan' for me in controller.convergence_controllers] | ||
][0] | ||
] | ||
|
||
nodes = np.append([0], level.sweep.coll.nodes) | ||
|
||
# initialize variables | ||
step.status.slot = 0 | ||
step.status.iter = 1 | ||
level.status.time = 0.0 | ||
level.status.residual = 0.0 | ||
level.u[0] = prob.u_exact(t=0) | ||
level.sweep.predict() | ||
|
||
for i in range(len(level.u)): | ||
if level.u[i] is not None: | ||
level.u[i][:] = prob.u_exact(nodes[i] * level.dt) | ||
|
||
cont.post_iteration_processing(controller, step, comm=comm) | ||
|
||
try: | ||
if modify: | ||
level.u[0][:] = np.nan | ||
cont.post_iteration_processing(controller, step, comm=comm) | ||
raise Exception('Did not raise error!') | ||
except ConvergenceError: | ||
print('Successfully raised error when nan is part of the solution') | ||
|
||
try: | ||
if modify: | ||
level.u[0][:] = 1e99 | ||
cont.post_iteration_processing(controller, step, comm=comm) | ||
raise Exception('Did not raise error!') | ||
except ConvergenceError: | ||
print('Successfully raised error solution exceeds limit') | ||
|
||
|
||
@pytest.mark.base | ||
def test_stop_at_nan(): | ||
single_test() | ||
|
||
|
||
@pytest.mark.mpi4py | ||
@pytest.mark.parametrize('mode', ['0 1', '1 0']) | ||
def test_interpolation_error_MPI(mode): | ||
import subprocess | ||
import os | ||
|
||
# Set python path once | ||
my_env = os.environ.copy() | ||
my_env['PYTHONPATH'] = '../../..:.' | ||
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml' | ||
|
||
cmd = f"mpirun -np {3} python {__file__} {mode}".split() | ||
|
||
p = subprocess.Popen(cmd, env=my_env, cwd=".") | ||
|
||
p.wait() | ||
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (p.returncode, 3) | ||
|
||
|
||
if __name__ == "__main__": | ||
import sys | ||
|
||
if len(sys.argv) > 1: | ||
kwargs = { | ||
'MPIsweeper': bool(int(sys.argv[1])), | ||
'MPIcontroller': bool(int(sys.argv[2])), | ||
} | ||
single_test(**kwargs) | ||
else: | ||
single_test() |