Skip to content

Commit

Permalink
Remaining changes for current state of adaptivity paper (#378)
Browse files Browse the repository at this point in the history
* Remaining changes for dt-k adaptivity

* Fix

* Made plots a tiny bit more pretty
  • Loading branch information
brownbaerchen authored Nov 20, 2023
1 parent 39903ea commit f05aed2
Show file tree
Hide file tree
Showing 12 changed files with 1,406 additions and 657 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,5 @@ def check_parameters(self, controller, params, description, **kwargs):
False,
"Switching the collocation problems requires solving them to some tolerance that can be reached. Please set attainable `restol` in the level params",
)
if description["step_params"].get("maxiter", -1.0) < 99:
return (
False,
"Switching the collocation problems requires solving them exactly, which may require many iterations please set `maxiter` to at least 99 in the step params",
)

return True, ""
90 changes: 74 additions & 16 deletions pySDC/implementations/convergence_controller_classes/adaptivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def setup(self, controller, params, description, **kwargs):

controller.add_hook(LogStepSize)

from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence

self.communicate_convergence = CheckConvergence.communicate_convergence

return {**defaults, **super().setup(controller, params, description, **kwargs)}

def dependencies(self, controller, description, **kwargs):
Expand Down Expand Up @@ -149,6 +153,32 @@ def determine_restart(self, controller, S, **kwargs):


class AdaptivityForConvergedCollocationProblems(AdaptivityBase):
def dependencies(self, controller, description, **kwargs):
"""
Load interpolation between restarts.
Args:
controller (pySDC.Controller): The controller
description (dict): The description object used to instantiate the controller
Returns:
None
"""
super().dependencies(controller, description, **kwargs)

if self.params.interpolate_between_restarts:
from pySDC.implementations.convergence_controller_classes.interpolate_between_restarts import (
InterpolateBetweenRestarts,
)

controller.add_convergence_controller(
InterpolateBetweenRestarts,
description=description,
params={},
)
self.interpolator = controller.convergence_controllers[-1]
return None

def get_convergence(self, controller, S, **kwargs):
raise NotImplementedError("Please implement a way to check if the collocation problem is converged!")

Expand All @@ -167,10 +197,13 @@ def setup(self, controller, params, description, **kwargs):
defaults = {
'restol_rel': None,
'e_tol_rel': None,
'restart_at_maxiter': False,
'restart_at_maxiter': True,
'restol_min': 1e-12,
'restol_max': 1e-5,
'factor_if_not_converged': 4.0,
'residual_max_tol': 1e9,
'maxiter': description['sweeper_params'].get('maxiter', 99),
'interpolate_between_restarts': True,
**super().setup(controller, params, description, **kwargs),
}
if defaults['restol_rel']:
Expand All @@ -182,20 +215,40 @@ def setup(self, controller, params, description, **kwargs):

if defaults['restart_at_maxiter']:
defaults['maxiter'] = description['step_params'].get('maxiter', 99)

self.res_last_iter = np.inf

return defaults

def determine_restart(self, controller, S, **kwargs):
if self.get_convergence(controller, S, **kwargs):
if self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol:
S.status.restart = True
elif S.status.iter >= self.params.maxiter and self.params.restart_at_maxiter:
self.res_last_iter = np.inf

if self.params.restart_at_maxiter and S.levels[0].status.residual > S.levels[0].params.restol:
self.trigger_restart_upon_nonconvergence(S)
elif self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol:
S.status.restart = True
for L in S.levels:
L.status.dt_new = L.params.dt / 2.0
self.log(
f'Collocation problem not converged after max. number of iterations, halving step size to {L.status.dt_new:.2e}',
S,
)
elif S.status.time_size == 1 and self.res_last_iter < S.levels[0].status.residual and S.status.iter > 0:
self.trigger_restart_upon_nonconvergence(S)
elif S.levels[0].status.residual > self.params.residual_max_tol:
self.trigger_restart_upon_nonconvergence(S)

if self.params.useMPI:
self.communicate_convergence(self, controller, S, **kwargs)

self.res_last_iter = S.levels[0].status.residual * 1.0

def trigger_restart_upon_nonconvergence(self, S):
S.status.restart = True
S.status.force_done = True
for L in S.levels:
L.status.dt_new = L.params.dt / self.params.factor_if_not_converged
self.log(
f'Collocation problem not converged. Reducing step size to {L.status.dt_new:.2e}',
S,
)
if self.params.interpolate_between_restarts:
self.interpolator.status.skip_interpolation = True


class Adaptivity(AdaptivityBase):
Expand Down Expand Up @@ -251,7 +304,7 @@ def dependencies(self, controller, description, **kwargs):
"""
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError

super().dependencies(controller, description)
super().dependencies(controller, description, **kwargs)

controller.add_convergence_controller(
EstimateEmbeddedError.get_implementation(self.params.embedded_error_flavor, self.params.useMPI),
Expand Down Expand Up @@ -398,6 +451,7 @@ def setup(self, controller, params, description, **kwargs):
- control_order (int): The order relative to other convergence controllers
- e_tol_low (float): Lower absolute threshold for the residual
- e_tol (float): Upper absolute threshold for the residual
- use_restol (bool): Restart if the residual tolerance was not reached
- max_restarts: Override maximum number of restarts
Args:
Expand All @@ -412,6 +466,7 @@ def setup(self, controller, params, description, **kwargs):
"control_order": -45,
"e_tol_low": 0,
"e_tol": np.inf,
"use_restol": False,
"max_restarts": 99 if "e_tol_low" in params else None,
"allowed_modifications": ['increase', 'decrease'], # what we are allowed to do with the step size
}
Expand Down Expand Up @@ -481,7 +536,9 @@ def get_new_step_size(self, controller, S, **kwargs):

dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt

if res > self.params.e_tol and 'decrease' in self.params.allowed_modifications:
if (
res > self.params.e_tol or (res > L.params.restol and self.params.use_restol)
) and 'decrease' in self.params.allowed_modifications:
L.status.dt_new = min([dt_planned, L.params.dt / 2.0])
self.log(f'Adjusting step size from {L.params.dt:.2e} to {L.status.dt_new:.2e}', S)
elif res < self.params.e_tol_low and 'increase' in self.params.allowed_modifications:
Expand Down Expand Up @@ -561,7 +618,7 @@ def dependencies(self, controller, description, **kwargs):
EstimateEmbeddedErrorCollocation,
)

super().dependencies(controller, description)
super().dependencies(controller, description, **kwargs)

params = {'adaptive_coll_params': self.params.adaptive_coll_params}
controller.add_convergence_controller(
Expand Down Expand Up @@ -694,7 +751,7 @@ def dependencies(self, controller, description, **kwargs):
EstimateExtrapolationErrorWithinQ,
)

super().dependencies(controller, description)
super().dependencies(controller, description, **kwargs)

controller.add_convergence_controller(
EstimateExtrapolationErrorWithinQ,
Expand Down Expand Up @@ -761,11 +818,12 @@ def setup(self, controller, params, description, **kwargs):

defaults = {
'control_order': -50,
**super().setup(controller, params, description, **kwargs),
**params,
}

self.check_convergence = CheckConvergence.check_convergence
return {**defaults, **super().setup(controller, params, description, **kwargs)}
return defaults

def get_convergence(self, controller, S, **kwargs):
return self.check_convergence(S)
Expand All @@ -785,7 +843,7 @@ def dependencies(self, controller, description, **kwargs):
EstimatePolynomialError,
)

super().dependencies(controller, description)
super().dependencies(controller, description, **kwargs)

controller.add_convergence_controller(
EstimatePolynomialError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def setup(self, controller, params, description, **kwargs):
'You cannot interpolate with lower accuracy to the end point if the end point is a node!'
)

self.interpolation_matrix = None

return defaults

def reset_status_variables(self, controller, **kwargs):
Expand Down Expand Up @@ -131,16 +133,18 @@ def post_iteration_processing(self, controller, S, **kwargs):
nodes = np.append(np.append(0, coll.nodes), 1.0)
estimate_on_node = self.params.estimate_on_node

interpolator = LagrangeApproximation(
points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
)
interpolation_matrix = interpolator.getInterpolationMatrix([nodes[estimate_on_node]])
if self.interpolation_matrix is None:
interpolator = LagrangeApproximation(
points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
)
self.interpolation_matrix = interpolator.getInterpolationMatrix([nodes[estimate_on_node]])

u = [
L.u[i].flatten() if L.u[i] is not None else L.u[i]
for i in range(coll.num_nodes + 1)
if i != estimate_on_node
]
u_inter = self.matmul(interpolation_matrix, u)[0].reshape(L.prob.init[0])
u_inter = self.matmul(self.interpolation_matrix, u)[0].reshape(L.prob.init[0])

# compute end point if needed
if estimate_on_node == len(nodes) - 1:
Expand All @@ -161,6 +165,11 @@ def post_iteration_processing(self, controller, S, **kwargs):
else:
L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

self.debug(
f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
S,
)

def check_parameters(self, controller, params, description, **kwargs):
"""
Check if we allow the scheme to solve the collocation problems to convergence.
Expand Down
13 changes: 8 additions & 5 deletions pySDC/projects/Resilience/Lorenz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from pySDC.helpers.stats_helper import get_sorted
from pySDC.implementations.problem_classes.Lorenz import LorenzAttractor
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
from pySDC.core.Errors import ConvergenceError
from pySDC.projects.Resilience.hook import LogData, hook_collection
from pySDC.projects.Resilience.strategies import merge_descriptions
from pySDC.projects.Resilience.sweepers import generic_implicit_efficient


def run_Lorenz(
Expand Down Expand Up @@ -37,7 +37,7 @@ def run_Lorenz(
Returns:
dict: The stats object
controller: The controller
Tend: The time that was supposed to be integrated to
bool: Whether the code crashed
"""

# initialize level parameters
Expand Down Expand Up @@ -72,7 +72,7 @@ def run_Lorenz(
description = dict()
description['problem_class'] = LorenzAttractor
description['problem_params'] = problem_params
description['sweeper_class'] = generic_implicit
description['sweeper_class'] = generic_implicit_efficient
description['sweeper_params'] = sweeper_params
description['level_params'] = level_params
description['step_params'] = step_params
Expand Down Expand Up @@ -105,12 +105,15 @@ def run_Lorenz(
prepare_controller_for_faults(controller, fault_stuff)

# call main function to get things done...
crash = False
try:
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
except ConvergenceError:
except (ConvergenceError, ZeroDivisionError) as e:
print(f'Warning: Premature termination: {e}')
stats = controller.return_stats()
crash = True

return stats, controller, Tend
return stats, controller, crash


def plot_solution(stats): # pragma: no cover
Expand Down
25 changes: 15 additions & 10 deletions pySDC/projects/Resilience/Schroedinger.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from mpi4py import MPI
import numpy as np

from pySDC.helpers.stats_helper import get_sorted

from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import (
nonlinearschroedinger_imex,
nonlinearschroedinger_fully_implicit,
)
from pySDC.implementations.transfer_classes.TransferMesh_MPIFFT import fft_to_fft
from pySDC.projects.Resilience.hook import LogData, hook_collection
from pySDC.projects.Resilience.strategies import merge_descriptions
from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient
from pySDC.core.Errors import ConvergenceError

from pySDC.core.Hooks import hooks

Expand Down Expand Up @@ -95,7 +91,7 @@ def run_Schroedinger(
Returns:
dict: The stats object
controller: The controller
Tend: The time that was supposed to be integrated to
bool: If the code crashed
"""
if custom_description is not None:
problem_params = custom_description.get('problem_params', {})
Expand All @@ -119,6 +115,7 @@ def run_Schroedinger(
sweeper_params['quad_type'] = 'RADAU-RIGHT'
sweeper_params['num_nodes'] = 3
sweeper_params['QI'] = 'IE'
sweeper_params['QE'] = 'PIC'
sweeper_params['initial_guess'] = 'spread'

# initialize problem parameters
Expand Down Expand Up @@ -148,7 +145,7 @@ def run_Schroedinger(
description = dict()
description['problem_params'] = problem_params
description['problem_class'] = nonlinearschroedinger_imex if imex else nonlinearschroedinger_fully_implicit
description['sweeper_class'] = imex_1st_order if imex else generic_implicit
description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
description['sweeper_params'] = sweeper_params
description['level_params'] = level_params
description['step_params'] = step_params
Expand Down Expand Up @@ -187,12 +184,20 @@ def run_Schroedinger(
prepare_controller_for_faults(controller, fault_stuff, rnd_args)

# call main function to get things done...
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
crash = False
try:
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
except (ConvergenceError, OverflowError) as e:
print(f'Warning: Premature termination: {e}')
stats = controller.return_stats()
crash = True

return stats, controller, Tend
return stats, controller, crash


def main():
from mpi4py import MPI

stats, _, _ = run_Schroedinger(space_comm=MPI.COMM_WORLD, hook_class=live_plotting, imex=False)
plt.show()

Expand Down
Loading

0 comments on commit f05aed2

Please sign in to comment.