Skip to content

Commit

Permalink
Set log_duration level for fit methods to INFO
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed Apr 21, 2024
1 parent a033f87 commit f09f020
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def is_fitted(self):
except AttributeError:
return False

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader) -> DirectInfluence:
"""
Compute the hessian matrix based on a provided dataloader.
Expand Down Expand Up @@ -500,6 +501,7 @@ def is_fitted(self):
except AttributeError:
return False

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader) -> CgInfluence:
self.train_dataloader = data
if self.pre_conditioner is not None:
Expand Down Expand Up @@ -816,6 +818,7 @@ def is_fitted(self):
except AttributeError:
return False

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader) -> LissaInfluence:
self.train_dataloader = data
return self
Expand Down Expand Up @@ -948,6 +951,7 @@ def is_fitted(self):
except AttributeError:
return False

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader) -> ArnoldiInfluence:
r"""
Fitting corresponds to the computation of the low rank decomposition
Expand Down Expand Up @@ -1204,6 +1208,7 @@ def _get_kfac_blocks(

return forward_x, grad_y

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader) -> EkfacInfluence:
"""
Compute the KFAC blocks for each layer of the model, using the provided data.
Expand Down Expand Up @@ -1712,6 +1717,7 @@ def is_fitted(self):
except AttributeError:
return False

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader):
self.low_rank_representation = model_hessian_nystroem_approximation(
self.model, self.loss, data, self.rank
Expand Down

0 comments on commit f09f020

Please sign in to comment.