diff --git a/CHANGELOG.md b/CHANGELOG.md index a59958014..1d4bb84d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,8 @@ to `regularization` and change the type annotation to allow for block-wise regularization parameters [PR #593](https://github.com/aai-institute/pyDVL/pull/593) - + - Remove parameter `h0` from init of `LissaInfluence` + [PR #593](https://github.com/aai-institute/pyDVL/pull/593) ## 0.9.2 - 🏗 Bug fixes, logging improvement diff --git a/docs/influence/influence_function_model.md b/docs/influence/influence_function_model.md index f71861af9..adb581183 100644 --- a/docs/influence/influence_function_model.md +++ b/docs/influence/influence_function_model.md @@ -74,25 +74,29 @@ original paper [@agarwal_secondorder_2017]. ```python -from pydvl.influence.torch import LissaInfluence +from pydvl.influence.torch import LissaInfluence, BlockMode, SecondOrderMode if_model = LissaInfluence( model, loss, - hessian_regularization=0.0 + regularization=0.0, maxiter=1000, dampen=0.0, scale=10.0, - h0=None, rtol=1e-4, + block_structure=BlockMode.FULL, + second_order_mode=SecondOrderMode.GAUSS_NEWTON ) if_model.fit(train_loader) ``` -with the additional optional parameters `maxiter`, `dampen`, `scale`, `h0`, and +with the additional optional parameters `maxiter`, `dampen`, `scale`, and `rtol`, being the maximum number of iterations, the dampening factor, the scaling -factor, the initial guess for the solution and the relative tolerance, -respectively. +factor and the relative tolerance, +respectively. This implementation is capable of using a block-matrix +approximation, see +[Block-diagonal approximation](#block-diagonal-approximation), and can handle +[Gauss-Newton approximation](#gauss-newton-approximation). 
### Arnoldi diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 5b2fe900c..9e24790ee 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -808,6 +808,45 @@ class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]] see [Linear time Stochastic Second-Order Approximation (LiSSA)] [linear-time-stochastic-second-order-approximation-lissa] + Block-mode: + This implementation is capable of using a block-matrix approximation. The + blocking structure can be specified via the `block_structure` parameter. + The `block_structure` parameter can either be a + [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides + layer-wise or parameter-wise blocking) or a custom block structure defined + by an ordered dictionary with the keys being the block identifiers (arbitrary + strings) and the values being lists of parameter names contained in the block. + + ```python + block_structure = OrderedDict( + ( + ("custom_block1", ["0.weight", "1.bias"]), + ("custom_block2", ["1.weight", "0.bias"]), + ) + ) + ``` + + If you would like to apply a block-specific regularization, you can provide a + dictionary with the block names as keys and the regularization values as values. + In this case, the specification must be complete, i.e. every block must have + a positive regularization value. + + ```python + regularization = { + "custom_block1": 0.1, + "custom_block2": 0.2, + } + ``` + Accordingly, if you choose a layer-wise or parameter-wise structure + (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for + `block_structure`) the keys must be the layer names or parameter names, + respectively. + + You can retrieve the block-wise influence information from the methods + with suffix `_by_block`. 
By default, `block_structure` is set to + `BlockMode.FULL` and in this case these methods will return a dictionary + with the empty string being the only key. + Args: model: A PyTorch model. The Hessian will be calculated with respect to this model's parameters. @@ -823,6 +862,10 @@ class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]] warn_on_max_iteration: If True, logs a warning, if the desired tolerance is not achieved within `maxiter` iterations. If False, the log level for this information is `logging.DEBUG` + block_structure: The blocking structure, either a pre-defined enum or a + custom block structure, see the information regarding block-mode. + second_order_mode: The second-order mode, either `SecondOrderMode.HESSIAN` or + `SecondOrderMode.GAUSS_NEWTON`. """ def __init__(