Implement block-diagonal and Gauss-Newton approximation for direct solver
schroedk committed Jun 3, 2024
1 parent 8498524 commit b04f7a0
Showing 7 changed files with 242 additions and 99 deletions.
4 changes: 3 additions & 1 deletion src/pydvl/influence/torch/base.py
@@ -363,7 +363,7 @@ def grads_inner_prod(
    def mixed_grads_inner_prod(
        self,
        left: TorchBatch,
-       right: TorchBatch,
+       right: Optional[TorchBatch],
        gradient_provider: TorchGradientProvider,
    ) -> torch.Tensor:
r"""
@@ -386,6 +386,8 @@ def mixed_grads_inner_prod(
            A tensor representing the inner products of the mixed per-sample gradients
        """
        operator = cast(TensorDictOperator, self.operator)
+       if right is None:
+           right = left
        right_grads = gradient_provider.mixed_grads(right)
        left_grads = gradient_provider.grads(left)
        left_grads = operator.apply_to_dict(left_grads)
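The snippet below is a hedged, illustrative sketch (not part of the commit) of what the relaxed signature permits: passing `right=None` now computes the inner products against the left batch itself. The import locations of `TorchBatch`, `TorchGradientProvider` and `TorchOperatorGradientComposition` are assumptions based on the files touched in this commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed import locations; all three names appear in this commit's diffs.
from pydvl.influence.torch.base import (
    TorchBatch,
    TorchGradientProvider,
    TorchOperatorGradientComposition,
)
from pydvl.influence.torch.functional import hessian
from pydvl.influence.torch.operator import DirectSolveOperator

model = torch.nn.Linear(3, 1)
loss = torch.nn.functional.mse_loss
params = {k: p.detach() for k, p in model.named_parameters()}
data = DataLoader(TensorDataset(torch.randn(16, 3), torch.randn(16, 1)), batch_size=8)

# Build the operator/gradient-provider composition used by the direct solver.
gp = TorchGradientProvider(model, loss, restrict_to=params)
op = DirectSolveOperator(hessian(model, loss, data), regularization=0.01)
block = TorchOperatorGradientComposition(op, gp)

batch = TorchBatch(torch.randn(4, 3), torch.randn(4, 1))
# New in this commit: right=None falls back to the left batch.
prods = block.mixed_grads_inner_prod(batch, None, gp)
```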
87 changes: 77 additions & 10 deletions src/pydvl/influence/torch/functional.py
@@ -36,7 +36,14 @@
from torch.func import functional_call, grad, jvp, vjp
from torch.utils.data import DataLoader

from .util import align_structure, align_with_model, flatten_dimensions, to_model_device
from .util import (
BlockMode,
ModelParameterDictBuilder,
align_structure,
align_with_model,
flatten_dimensions,
to_model_device,
)

__all__ = [
"create_hvp_function",
@@ -383,6 +390,7 @@ def hessian(
    data_loader: DataLoader,
    use_hessian_avg: bool = True,
    track_gradients: bool = False,
+   restrict_to: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
"""
Computes the Hessian matrix for a given model and loss function.
@@ -397,18 +405,23 @@
            If False, the empirical loss across the entire dataset is used.
        track_gradients: Whether to track gradients for the resulting tensor of
            the hessian vector products.
+       restrict_to: The parameters to restrict the second order differentiation to,
+           i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian
+           is computed.

    Returns:
        A tensor representing the Hessian matrix. The shape of the tensor will be
        (n_parameters, n_parameters), where n_parameters is the number of trainable
        parameters in the model.
    """
+   params = restrict_to

-   params = {
-       k: p if track_gradients else p.detach()
-       for k, p in model.named_parameters()
-       if p.requires_grad
-   }
+   if params is None:
+       params = {
+           k: p if track_gradients else p.detach()
+           for k, p in model.named_parameters()
+           if p.requires_grad
+       }
    n_parameters = sum([p.numel() for p in params.values()])
    model_dtype = next((p.dtype for p in params.values()))
@@ -424,13 +437,16 @@
        def flat_input_batch_loss_function(
            p: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor
        ):
-           return blf(align_with_model(p, model), t_x, t_y)
+           return blf(align_structure(params, p), t_x, t_y)

        for x, y in iter(data_loader):
            n_samples += x.shape[0]
-           hessian_mat += x.shape[0] * torch.func.hessian(
-               flat_input_batch_loss_function
-           )(flat_params, to_model_device(x, model), to_model_device(y, model))
+           batch_hessian = torch.func.hessian(flat_input_batch_loss_function)(
+               flat_params, to_model_device(x, model), to_model_device(y, model)
+           )
+           if not track_gradients and batch_hessian.requires_grad:
+               batch_hessian = batch_hessian.detach()
+           hessian_mat += x.shape[0] * batch_hessian

hessian_mat /= n_samples
else:
@@ -447,6 +463,57 @@ def flat_input_empirical_loss(p: torch.Tensor):
return hessian_mat


+ def gauss_newton(
+     model: torch.nn.Module,
+     loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+     data_loader: DataLoader,
+     restrict_to: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+     r"""
+     Compute the empirical Gauss-Newton matrix, i.e. the average
+     $$ \frac{1}{N} \sum_{i=1}^N \nabla_{\theta}\ell(m(x_i; \theta), y_i)
+         \nabla_{\theta}\ell(m(x_i; \theta), y_i)^t,$$
+     for a loss function $\ell$ and a model $m$ with model parameters $\theta$.
+
+     Args:
+         model: The PyTorch model.
+         loss: A callable that computes the loss.
+         data_loader: A PyTorch DataLoader providing batches of input data and
+             corresponding output data.
+         restrict_to: The parameters to restrict the differentiation to,
+             i.e. the corresponding sub-matrix of the Jacobian. If None, the full
+             Jacobian is used.
+
+     Returns:
+         The Gauss-Newton matrix.
+     """
+
+     per_sample_grads = create_per_sample_gradient_function(model, loss)
+
+     params = restrict_to
+     if params is None:
+         params = {
+             k: p.detach() for k, p in model.named_parameters() if p.requires_grad
+         }
+
+     def generate_batch_matrices():
+         for x, y in data_loader:
+             grads = flatten_dimensions(
+                 per_sample_grads(params, x, y).values(), shape=(x.shape[0], -1)
+             )
+             batch_mat = grads.t() @ grads
+             # Also yield the batch size, so that the final average is taken
+             # over the number of samples rather than the matrix dimension.
+             yield batch_mat.detach(), x.shape[0]
+
+     matrices = generate_batch_matrices()
+     result, n_points = next(matrices)
+
+     for mat, batch_size in matrices:
+         result += mat
+         n_points += batch_size
+
+     return result / n_points


def create_per_sample_loss_function(
model: torch.nn.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> Callable[[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor], torch.Tensor]:
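For orientation, here is a minimal, hedged sketch (not part of the commit) of the two new entry points in `functional.py`: the `restrict_to` argument of `hessian` and the new `gauss_newton` function. The toy model, the data, and the `2.`-prefixed parameter names are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch.functional import gauss_newton, hessian

# Toy regression setup, purely for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(5, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)
loss = torch.nn.functional.mse_loss
data = DataLoader(TensorDataset(torch.randn(32, 5), torch.randn(32, 1)), batch_size=8)

# Restrict second-order differentiation to the last layer's parameters;
# the result is the corresponding diagonal block of the full Hessian.
last_layer = {
    k: p.detach() for k, p in model.named_parameters() if k.startswith("2.")
}
h_block = hessian(model, loss, data, restrict_to=last_layer)

# Gauss-Newton approximation over the same parameters.
gn_block = gauss_newton(model, loss, data, restrict_to=last_layer)

print(h_block.shape, gn_block.shape)  # both (9, 9): 8 weights + 1 bias
```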
172 changes: 90 additions & 82 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -35,16 +35,18 @@
    create_matrix_jacobian_product_function,
    create_per_sample_gradient_function,
    create_per_sample_mixed_derivative_function,
+   gauss_newton,
    hessian,
    model_hessian_low_rank,
    model_hessian_nystroem_approximation,
)
- from .operator import InverseHarmonicMeanOperator
+ from .operator import DirectSolveOperator, InverseHarmonicMeanOperator
from .pre_conditioner import PreConditioner
from .util import (
    BlockMode,
    EkfacRepresentation,
    LossType,
+   SecondOrderMode,
    empirical_cross_entropy_loss_fn,
    flatten_dimensions,
    safe_torch_linalg_eigh,
@@ -351,109 +353,115 @@ def to(self, device: torch.device):
return self


- class DirectInfluence(TorchInfluenceFunctionModel):
+ class DirectInfluence(TorchComposableInfluence[DirectSolveOperator]):
      r"""
      Given a model and training data, it finds x such that \(Hx = b\),
-     with \(H\) being the model hessian.
+     with \(H\) being the model hessian or Gauss-Newton matrix.
+     Block-mode:
+         This implementation is capable of using a block-matrix approximation. The
+         blocking structure can be specified via the `block_structure` parameter.
+         The `block_structure` parameter can either be a
+         [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides
+         layer-wise or parameter-wise blocking) or a custom block structure defined
+         by an ordered dictionary with the keys being the block identifiers
+         (arbitrary strings) and the values being lists of parameter names
+         contained in the block.
+
+         ```python
+         block_structure = OrderedDict(
+             (
+                 ("custom_block1", ["0.weight", "1.bias"]),
+                 ("custom_block2", ["1.weight", "0.bias"]),
+             )
+         )
+         ```
+
+         If you would like to apply a block-specific regularization, you can
+         provide a dictionary with the block names as keys and the regularization
+         values as values. In this case, the specification must be complete, i.e.
+         every block must have a positive regularization value.
+
+         ```python
+         regularization = {
+             "custom_block1": 0.1,
+             "custom_block2": 0.2,
+         }
+         ```
+
+         Accordingly, if you choose a layer-wise or parameter-wise structure
+         (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for
+         `block_structure`) the keys must be the layer names or parameter names,
+         respectively.
+
+         You can retrieve the block-wise influence information from the methods
+         with suffix `_by_block`. By default, `block_structure` is set to
+         `BlockMode.FULL` and in this case these methods will return a dictionary
+         with the empty string being the only key.
+
      Args:
-         model: A PyTorch model. The Hessian will be calculated with respect to
-             this model's parameters.
-         loss: A callable that takes the model's output and target as input and
-             returns the scalar loss.
-         hessian_regularization: Regularization of the hessian.
+         model: The model.
+         loss: The loss function.
+         regularization: The regularization parameter. In case a dictionary is
+             provided, the keys must match the blocking structure.
+         block_structure: The blocking structure, either a pre-defined enum or a
+             custom block structure, see the information regarding block-mode.
+         second_order_mode: The second order mode, either `SecondOrderMode.HESSIAN`
+             or `SecondOrderMode.GAUSS_NEWTON`.
      """

      def __init__(
          self,
          model: nn.Module,
-         loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-         hessian_regularization: float = 0.0,
+         loss: LossType,
+         regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None,
+         block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL,
+         second_order_mode: SecondOrderMode = SecondOrderMode.HESSIAN,
      ):
-         super().__init__(model, loss)
-         self.hessian_regularization = hessian_regularization
-
-     hessian: torch.Tensor
-
-     @property
-     def is_fitted(self):
-         try:
-             return self.hessian is not None
-         except AttributeError:
-             return False
+         super().__init__(
+             model,
+             block_structure=block_structure,
+             regularization=regularization,
+         )
+         self.second_order_mode = second_order_mode
+         self.loss = loss

-     @log_duration(log_level=logging.INFO)
-     def fit(self, data: DataLoader) -> DirectInfluence:
+     def with_regularization(
+         self, regularization: Union[float, Dict[str, Optional[float]]]
+     ) -> TorchComposableInfluence:
          """
-         Compute the hessian matrix based on a provided dataloader.
+         Update the regularization parameter.

          Args:
-             data: The data to compute the Hessian with.
+             regularization: Either a positive float or a dictionary with the
+                 block names as keys and the regularization values as values.

          Returns:
-             The fitted instance.
+             The modified instance.
          """
-         self.hessian = hessian(self.model, self.loss, data)
+         self._regularization_dict = self._build_regularization_dict(regularization)
+         for k, reg in self._regularization_dict.items():
+             self.block_mapper.composable_block_dict[k].op.regularization = reg
          return self

-     @log_duration
-     def influences(
-         self,
-         x_test: torch.Tensor,
-         y_test: torch.Tensor,
-         x: Optional[torch.Tensor] = None,
-         y: Optional[torch.Tensor] = None,
-         mode: InfluenceMode = InfluenceMode.Up,
-     ) -> torch.Tensor:
-         r"""
-         Compute approximation of
-         \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}},
-             f_{\theta}(x_{\text{test}})),
-             \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle, \]
-         for the case of up-weighting influence, resp.
-         \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}},
-             f_{\theta}(x_{\text{test}})),
-             \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
-         for the perturbation type influence case. The action of $H^{-1}$ is
-         achieved via a direct solver using [torch.linalg.solve][torch.linalg.solve].
-
-         Args:
-             x_test: model input to use in the gradient computations of
-                 $H^{-1}\nabla_{\theta} \ell(y_{\text{test}},
-                 f_{\theta}(x_{\text{test}}))$
-             y_test: label tensor to compute gradients
-             x: optional model input to use in the gradient computations
-                 $\nabla_{\theta}\ell(y, f_{\theta}(x))$,
-                 resp. $\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$,
-                 if None, use $x=x_{\text{test}}$
-             y: optional label tensor to compute gradients
-             mode: enum value of [InfluenceMode]
-                 [pydvl.influence.base_influence_function_model.InfluenceMode]
-
-         Returns:
-             A tensor representing the element-wise scalar products for the
-             provided batch.
-         """
-         return super().influences(x_test, y_test, x, y, mode=mode)
-
-     @log_duration
-     def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
-         return torch.linalg.solve(
-             self.hessian.to(self.model_device)
-             + self.hessian_regularization
-             * torch.eye(self.n_parameters, device=self.model_device),
-             rhs.T.to(self.model_device),
-         ).T
+     def _create_block(
+         self,
+         block_params: Dict[str, torch.nn.Parameter],
+         data: DataLoader,
+         regularization: Optional[float],
+     ) -> TorchOperatorGradientComposition:
+         gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params)
+         if self.second_order_mode is SecondOrderMode.GAUSS_NEWTON:
+             mat = gauss_newton(self.model, self.loss, data, restrict_to=block_params)
+         else:
+             mat = hessian(self.model, self.loss, data, restrict_to=block_params)
+         op = DirectSolveOperator(mat, regularization=regularization)
+         return TorchOperatorGradientComposition(op, gp)

-     def to(self, device: torch.device):
-         if self.is_fitted:
-             self.hessian = self.hessian.to(device)
-         return super().to(device)
+     @property
+     def is_thread_safe(self) -> bool:
+         return False


class CgInfluence(TorchInfluenceFunctionModel):
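Putting the pieces together, the reworked `DirectInfluence` could be driven as in the following sketch (not part of the commit; the model, data, and parameter values are illustrative assumptions about the post-commit API):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch import DirectInfluence
from pydvl.influence.torch.util import BlockMode, SecondOrderMode

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
)
loss = torch.nn.functional.mse_loss
x_train, y_train = torch.randn(64, 4), torch.randn(64, 1)
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=16)

# Layer-wise block-diagonal approximation, one Gauss-Newton matrix per block.
if_model = DirectInfluence(
    model,
    loss,
    regularization=0.1,
    block_structure=BlockMode.LAYER_WISE,
    second_order_mode=SecondOrderMode.GAUSS_NEWTON,
).fit(train_loader)

x_test, y_test = torch.randn(8, 4), torch.randn(8, 1)
scores = if_model.influences(x_test, y_test, x_train, y_train)  # (8, 64)

# Per the class docstring, block-wise results come from the `_by_block`
# variants:
by_block = if_model.influences_by_block(x_test, y_test, x_train, y_train)
```

A custom `OrderedDict` block structure with a matching per-block regularization dictionary (as shown in the class docstring) would be passed the same way via `block_structure` and `regularization`.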