diff --git a/pySDC/implementations/convergence_controller_classes/adaptivity.py b/pySDC/implementations/convergence_controller_classes/adaptivity.py index b2ec90e5c..209225d20 100644 --- a/pySDC/implementations/convergence_controller_classes/adaptivity.py +++ b/pySDC/implementations/convergence_controller_classes/adaptivity.py @@ -229,7 +229,16 @@ def determine_restart(self, controller, S, **kwargs): if self.get_convergence(controller, S, **kwargs): self.res_last_iter = np.inf - if self.params.restart_at_maxiter and S.levels[0].status.residual > S.levels[0].params.restol: + L = S.levels[0] + e_tol_converged = ( + L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False + ) + + if ( + self.params.restart_at_maxiter + and S.levels[0].status.residual > S.levels[0].params.restol + and not e_tol_converged + ): self.trigger_restart_upon_nonconvergence(S) elif self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol: S.status.restart = True diff --git a/pySDC/implementations/convergence_controller_classes/check_convergence.py b/pySDC/implementations/convergence_controller_classes/check_convergence.py index 9cbe85e25..36cb8e4a2 100644 --- a/pySDC/implementations/convergence_controller_classes/check_convergence.py +++ b/pySDC/implementations/convergence_controller_classes/check_convergence.py @@ -75,9 +75,7 @@ def check_convergence(S, self=None): iter_converged = S.status.iter >= S.params.maxiter res_converged = L.status.residual <= L.params.restol e_tol_converged = ( - L.status.error_embedded_estimate < L.params.e_tol - if (L.params.get('e_tol') and L.status.get('error_embedded_estimate')) - else False + L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False ) converged = ( iter_converged or res_converged or e_tol_converged or S.status.force_done diff --git a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py index 545c2f6ef..08aa14730 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py @@ -109,12 +109,13 @@ def estimate_embedded_error_serial(self, L): def setup_status_variables(self, controller, **kwargs): """ - Add the embedded error variable to the error function. + Add the embedded error to the level status Args: controller (pySDC.Controller): The controller """ self.add_status_variable_to_level('error_embedded_estimate') + self.add_status_variable_to_level('increment') def post_iteration_processing(self, controller, S, **kwargs): """ @@ -134,6 +135,7 @@ def post_iteration_processing(self, controller, S, **kwargs): if S.status.iter > 0 or self.params.sweeper_type == "RK": for L in S.levels: L.status.error_embedded_estimate = max([self.estimate_embedded_error_serial(L), np.finfo(float).eps]) + L.status.increment = L.status.error_embedded_estimate * 1 self.debug(f'L.status.error_embedded_estimate={L.status.error_embedded_estimate:.5e}', S) return None diff --git a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py index f083651e3..cce409df6 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py @@ -150,7 +150,7 @@ def post_iteration_processing(self, controller, S, **kwargs): if self.comm: buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0) self.comm.Bcast(buf, root=rank) - L.status.error_embedded_estimate = buf + L.status.error_embedded_estimate = float(buf) else: L.status.error_embedded_estimate = abs(u_inter - high_order_sol)