diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 4383b996c..00ffac67a 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -432,6 +432,7 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: self.data.batch_size, shuffle=True, ) + is_converged = False for k in tqdm( range(self.maxiter), disable=not self.progress, desc="Lissa iteration" ): @@ -449,8 +450,10 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: f"{max_residual*100:.2f} % max residual and" f" mean residual {mean_residual*100:.5f} %" ) + is_converged = True break - else: + + if not is_converged: mean_residual = torch.mean(torch.abs(residual / h_estimate)) log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG logger.log(