From 68cd60d50a76f7abbda32bcfe4c9f4a4dcc5f611 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 10 Jun 2024 16:36:14 +0200 Subject: [PATCH] Refactor for/else statement in LissaOperator for better readability --- src/pydvl/influence/torch/operator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 4383b996c..00ffac67a 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -432,6 +432,7 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: self.data.batch_size, shuffle=True, ) + is_converged = False for k in tqdm( range(self.maxiter), disable=not self.progress, desc="Lissa iteration" ): @@ -449,8 +450,10 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: f"{max_residual*100:.2f} % max residual and" f" mean residual {mean_residual*100:.5f} %" ) + is_converged = True break - else: + + if not is_converged: mean_residual = torch.mean(torch.abs(residual / h_estimate)) log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG logger.log(