Commit

Update and extend documentation, update CHANGELOG.md
schroedk committed Jun 7, 2024
1 parent ce05780 commit 0ac0f86
Showing 3 changed files with 55 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -28,7 +28,8 @@
   to `regularization` and change the type annotation to allow
   for block-wise regularization parameters
   [PR #593](https://github.com/aai-institute/pyDVL/pull/593)
-
+- Remove parameter `h0` from init of `LissaInfluence`
+  [PR #593](https://github.com/aai-institute/pyDVL/pull/593)

 ## 0.9.2 - 🏗 Bug fixes, logging improvement

16 changes: 10 additions & 6 deletions docs/influence/influence_function_model.md
@@ -74,25 +74,29 @@ original paper [@agarwal_secondorder_2017].


 ```python
-from pydvl.influence.torch import LissaInfluence
+from pydvl.influence.torch import LissaInfluence, BlockMode, SecondOrderMode
 if_model = LissaInfluence(
     model,
     loss,
-    hessian_regularization=0.0,
+    regularization=0.0,
     maxiter=1000,
     dampen=0.0,
     scale=10.0,
-    h0=None,
     rtol=1e-4,
+    block_structure=BlockMode.FULL,
+    second_order_mode=SecondOrderMode.GAUSS_NEWTON
 )
 if_model.fit(train_loader)
 ```

-with the additional optional parameters `maxiter`, `dampen`, `scale`, `h0`, and
+with the additional optional parameters `maxiter`, `dampen`, `scale`, and
 `rtol`,
 being the maximum number of iterations, the dampening factor, the scaling
-factor, the initial guess for the solution and the relative tolerance,
-respectively.
+factor and the relative tolerance,
+respectively. This implementation is capable of using a block-matrix
+approximation, see
+[Block-diagonal approximation](#block-diagonal-approximation), and can handle
+[Gauss-Newton approximation](#gauss-newton-approximation).
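
Note on the parameters above: the following is a minimal, dependency-free sketch of how `maxiter`, `dampen`, `scale`, `rtol` and `regularization` typically interact in a LiSSA-style recursion. It assumes the commonly stated update `h_next = b + (1 - dampen) * h - (H h + regularization * h) / scale`, with the result divided by `scale` at the end. The function name `lissa_solve`, the stopping rule based on the relative change between iterates, and the toy 2x2 matrix are illustrative assumptions, not pyDVL code.

```python
from typing import Callable, List

Vector = List[float]


def lissa_solve(
    hvp: Callable[[Vector], Vector],
    b: Vector,
    regularization: float = 0.0,
    maxiter: int = 1000,
    dampen: float = 0.0,
    scale: float = 10.0,
    rtol: float = 1e-4,
) -> Vector:
    """Approximate (H + regularization * I)^{-1} b with a LiSSA-style recursion.

    `hvp(v)` must return the Hessian-vector product H v (hypothetical callback).
    """
    h = list(b)  # the recursion starts from b itself
    for _ in range(maxiter):
        hv = hvp(h)
        h_next = [
            b_i + (1.0 - dampen) * h_i - (hv_i + regularization * h_i) / scale
            for b_i, h_i, hv_i in zip(b, h, hv)
        ]
        # Assumed stopping rule: relative change between consecutive iterates.
        diff = sum((x - y) ** 2 for x, y in zip(h_next, h)) ** 0.5
        norm = sum(x * x for x in h_next) ** 0.5
        h = h_next
        if diff <= rtol * norm:
            break
    # With dampen == 0 the fixed point is scale * (H + regularization * I)^{-1} b,
    # so the scaling is undone before returning.
    return [x / scale for x in h]


# Toy problem: a fixed symmetric positive definite "Hessian".
H = [[2.0, 0.5], [0.5, 1.0]]


def toy_hvp(v: Vector) -> Vector:
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]


b = [1.0, -1.0]
x = lissa_solve(toy_hvp, b, regularization=0.1, rtol=1e-10)
```

In this sketch the recursion only converges when the eigenvalues of `(H + regularization * I) / scale` lie strictly between 0 and 2, which is why `scale` has to be chosen large enough for the problem at hand.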

### Arnoldi

43 changes: 43 additions & 0 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -808,6 +808,45 @@ class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]]
     see [Linear time Stochastic Second-Order Approximation (LiSSA)]
     [linear-time-stochastic-second-order-approximation-lissa]
+    Block-mode:
+        This implementation is capable of using a block-matrix approximation. The
+        blocking structure can be specified via the `block_structure` parameter.
+        The `block_structure` parameter can either be a
+        [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides
+        layer-wise or parameter-wise blocking) or a custom block structure defined
+        by an ordered dictionary with the keys being the block identifiers (arbitrary
+        strings) and the values being lists of parameter names contained in the block.
+        ```python
+        block_structure = OrderedDict(
+            (
+                ("custom_block1", ["0.weight", "1.bias"]),
+                ("custom_block2", ["1.weight", "0.bias"]),
+            )
+        )
+        ```
+        If you would like to apply a block-specific regularization, you can provide a
+        dictionary with the block names as keys and the regularization values as values.
+        In this case, the specification must be complete, i.e. every block must have
+        a positive regularization value.
+        ```python
+        regularization = {
+            "custom_block1": 0.1,
+            "custom_block2": 0.2,
+        }
+        ```
+        Accordingly, if you choose a layer-wise or parameter-wise structure
+        (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for
+        `block_structure`) the keys must be the layer names or parameter names,
+        respectively.
+        You can retrieve the block-wise influence information from the methods
+        with suffix `_by_block`. By default, `block_structure` is set to
+        `BlockMode.FULL` and in this case these methods will return a dictionary
+        with the empty string being the only key.
     Args:
         model: A PyTorch model. The Hessian will be calculated with respect to
             this model's parameters.
@@ -823,6 +862,10 @@ class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]]
         warn_on_max_iteration: If True, logs a warning, if the desired tolerance is not
             achieved within `maxiter` iterations. If False, the log level for this
             information is `logging.DEBUG`
+        block_structure: The blocking structure, either a pre-defined enum or a
+            custom block structure, see the information regarding block-mode.
+        second_order_mode: The second order mode, either `SecondOrderMode.HESSIAN` or
+            `SecondOrderMode.GAUSS_NEWTON`.
     """

     def __init__(
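
Note on the block-mode docstring above: the rules it states (a custom block structure maps block names to lists of parameter names, and a block-wise regularization dictionary must cover every block with a positive value) can be sketched in plain Python as follows. The helpers `layer_wise_blocks` and `check_block_regularization` are hypothetical and not part of pyDVL; grouping by the prefix before the last dot is only an assumption about what a layer-wise structure looks like for parameter names such as `"0.weight"`.

```python
from collections import OrderedDict
from typing import Dict, List, Mapping, Union


def layer_wise_blocks(param_names: List[str]) -> Dict[str, List[str]]:
    """Group parameter names like "0.weight" by their layer prefix ("0")."""
    blocks: Dict[str, List[str]] = OrderedDict()
    for name in param_names:
        layer = name.rsplit(".", 1)[0]
        blocks.setdefault(layer, []).append(name)
    return blocks


def check_block_regularization(
    blocks: Mapping[str, List[str]],
    regularization: Union[float, Mapping[str, float]],
) -> Dict[str, float]:
    """Return one regularization value per block, or raise ValueError."""
    if not isinstance(regularization, Mapping):
        # A scalar is simply shared by all blocks.
        return {name: float(regularization) for name in blocks}
    missing = [name for name in blocks if name not in regularization]
    unknown = [name for name in regularization if name not in blocks]
    if missing or unknown:
        raise ValueError(f"missing blocks: {missing}, unknown blocks: {unknown}")
    if any(value <= 0 for value in regularization.values()):
        raise ValueError("every block needs a positive regularization value")
    return {name: float(regularization[name]) for name in blocks}


# Layer-wise grouping for a two-layer model's parameter names.
layer_blocks = layer_wise_blocks(["0.weight", "0.bias", "1.weight", "1.bias"])

# The custom structure and regularization from the docstring.
custom_blocks = OrderedDict(
    (
        ("custom_block1", ["0.weight", "1.bias"]),
        ("custom_block2", ["1.weight", "0.bias"]),
    )
)
per_block = check_block_regularization(
    custom_blocks, {"custom_block1": 0.1, "custom_block2": 0.2}
)

# An incomplete specification is rejected.
try:
    check_block_regularization(custom_blocks, {"custom_block1": 0.1})
    incomplete_rejected = False
except ValueError:
    incomplete_rejected = True
```

Validating the dictionary up front, as sketched here, surfaces a missing or misspelled block name before any fitting work is done.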