Skip to content

Commit

Permalink
Refactor for/else statement in LissaOperator for better readability
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed Jun 10, 2024
1 parent 044aa7c commit 68cd60d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/pydvl/influence/torch/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
self.data.batch_size,
shuffle=True,
)
is_converged = False
for k in tqdm(
range(self.maxiter), disable=not self.progress, desc="Lissa iteration"
):
Expand All @@ -449,8 +450,10 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
f"{max_residual*100:.2f} % max residual and"
f" mean residual {mean_residual*100:.5f} %"
)
is_converged = True
break
else:

if not is_converged:
mean_residual = torch.mean(torch.abs(residual / h_estimate))
log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG
logger.log(
Expand Down

0 comments on commit 68cd60d

Please sign in to comment.