diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index 5faa288ac..7a8c5e881 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -405,8 +405,7 @@ def _initialize_zarr_array( class SumAggregator(SequenceAggregator): def __call__(self, tensor_sequence: LazyChunkSequence): """ - Aggregates tensors from a single-level generator by summing up. This method simply - collects each tensor emitted by the generator into a single list. + Aggregates tensors from a single-level generator by summing up. Args: tensor_sequence: Object wrapping a generator that yields `TensorType` @@ -418,5 +417,5 @@ def __call__(self, tensor_sequence: LazyChunkSequence): tensor_generator = tensor_sequence.generator_factory() result = next(tensor_generator) for tensor in tensor_generator: - result += tensor + result = result + tensor return result diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 5077f8d88..51feb8ab4 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -30,7 +30,10 @@ def __init__(self, object_type: Type): class NotImplementedLayerRepresentationException(ValueError): def __init__(self, module_id: str): - message = f"Only Linear layers are supported, but found module {module_id} requiring grad." + message = ( + f"Only Linear layers are supported, but found module {module_id} " + f"requiring grad." + ) super().__init__(message) @@ -82,6 +85,23 @@ def wrapper(self, *args, **kwargs): return wrapper def influence_factors(self, x: TensorType, y: TensorType) -> TensorType: + r""" + Computes the approximation of + + \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \] + + where the gradient is meant to be per sample of the batch $(x, y)$. + For all input tensors it is assumed, + that the first dimension is the batch dimension. + + Args: + x: model input to use in the gradient computations + y: label tensor to compute gradients + + Returns: + Tensor representing the element-wise inverse Hessian matrix vector products + + """ if not self.is_fitted: raise NotFittedException(type(self)) return self._influence_factors(x, y) @@ -112,6 +132,36 @@ def influences( y: Optional[TensorType] = None, mode: InfluenceMode = InfluenceMode.Up, ) -> TensorType: + r""" + Computes the approximation of + + \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}})), + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle H^{-1}\nabla_{\theta} \ell(y_{test}, f_{\theta}(x_{test})), + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. + + Args: + x_test: model input to use in the gradient computations + of $H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test}))$ + y_test: label tensor to compute gradients + x: optional model input to use in the gradient computations + $\nabla_{theta}\ell(y, f_{\theta}(x))$, + resp. 
$\nabla_{x}\nabla_{theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{test}$ + y: optional label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Tensor representing the element-wise scalar products for the provided batch + + """ if not self.is_fitted: raise NotFittedException(type(self)) @@ -214,6 +264,11 @@ class ComposableInfluence( Generic[TensorType, BatchType, DataLoaderType, BlockMapperType], ABC, ): + """ + Generic abstract base class, that allow for block-wise computation of influence + quantities. Inherit from this base class for specific influence algorithms and + tensor frameworks. + """ block_mapper: BlockMapperType @@ -226,11 +281,30 @@ def is_fitted(self): @log_duration(log_level=logging.INFO) def fit(self, data: DataLoaderType) -> InfluenceFunctionModel: + """ + Fitting to provided data, by internally creating a block mapper instance from + it. + Args: + data: iterable of tensors + + Returns: + Fitted instance + """ self.block_mapper = self._create_block_mapper(data) return self @abstractmethod def _create_block_mapper(self, data: DataLoaderType) -> BlockMapperType: + """ + Override this method to create a block mapper instance, that can be used + to compute block-wise influence quantities. + + Args: + data: iterable of tensors + + Returns: + BlockMapper instance + """ pass @InfluenceFunctionModel.fit_required @@ -242,6 +316,39 @@ def influences_by_block( y: Optional[TensorType] = None, mode: InfluenceMode = InfluenceMode.Up, ) -> OrderedDict[str, TensorType]: + r""" + Compute the block-wise influence values for the provided data, i.e. an + approximation of + + \[ \langle H^{-1}\nabla_{theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}})), + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test})), + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. + + Args: + x_test: model input to use in the gradient computations + of the approximation of + $H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test}))$ + y_test: label tensor to compute gradients + x: optional model input to use in the gradient computations + $\nabla_{theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{test}$ + y: optional label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Ordered dictionary of tensors representing the element-wise scalar products + for the provided batch per block. + + """ left_batch = self._create_batch(x_test, y_test) if x is None: @@ -259,13 +366,29 @@ def influences_by_block( ) right_batch = self._create_batch(x, y) - return self.block_mapper.block_interactions(left_batch, right_batch, mode) + return self.block_mapper.interactions(left_batch, right_batch, mode) @InfluenceFunctionModel.fit_required def influence_factors_by_block( self, x: TensorType, y: TensorType ) -> OrderedDict[str, TensorType]: - return self.block_mapper.block_transformed_gradients(self._create_batch(x, y)) + r""" + Compute the block-wise approximation of + + \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \] + + where the gradient is meant to be per sample of the batch $(x, y)$. 
+ + Args: + x: model input to use in the gradient computations + y: label tensor to compute gradients + + Returns: + Ordered dictionary of tensors representing the element-wise + approximate inverse Hessian matrix vector products per block. + + """ + return self.block_mapper.transformed_grads(self._create_batch(x, y)) @InfluenceFunctionModel.fit_required def influences_from_factors_by_block( @@ -275,13 +398,44 @@ def influences_from_factors_by_block( y: TensorType, mode: InfluenceMode = InfluenceMode.Up, ) -> OrderedDict[str, TensorType]: - return self.block_mapper.block_interactions_from_transformed_gradients( + r""" + Block-wise computation of + + \[ \langle z_{\text{test_factors}}, + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle z_{\text{test_factors}}, + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. The gradient is meant to be per sample + of the batch $(x, y)$. + + Args: + z_test_factors: pre-computed array, approximating + $H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}}))$ + x: model input to use in the gradient computations + $\nabla_{\theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{\text{test}}$ + y: label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Ordered dictionary of tensors representing the element-wise scalar products + for the provided batch per block + + """ + return self.block_mapper.interactions_from_transformed_grads( z_test_factors, self._create_batch(x, y), mode ) def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType: tensor_gen_factory = partial( - self.block_mapper.generate_transformed_gradients, self._create_batch(x, y) + self.block_mapper.generate_transformed_grads, self._create_batch(x, y) ) aggregator = SumAggregator() result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) @@ -308,7 +462,7 @@ def _influences( right_batch = self._create_batch(x, y) tensor_gen_factory = partial( - self.block_mapper.generate_gradient_interactions, + self.block_mapper.generate_interactions, left_batch, right_batch, mode, @@ -325,8 +479,39 @@ def influences_from_factors( y: TensorType, mode: InfluenceMode = InfluenceMode.Up, ) -> TensorType: + r""" + Computation of + + \[ \langle z_{\text{test_factors}}, + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle z_{\text{test_factors}}, + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. The gradient is meant to be per sample + of the batch $(x, y)$. + + Args: + z_test_factors: pre-computed array, approximating + $H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}}))$ + x: model input to use in the gradient computations + $\nabla_{\theta}\ell(y, f_{\theta}(x))$, + resp. 
$\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{\text{test}}$ + y: label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Tensor representing the element-wise scalar products for the provided batch + + """ + tensor_gen_factory = partial( - self.block_mapper.generate_interactions_from_transformed_gradients, + self.block_mapper.generate_interactions_from_transformed_grads, z_test_factors, self._create_batch(x, y), mode, @@ -339,4 +524,6 @@ def influences_from_factors( @staticmethod @abstractmethod def _create_batch(x: TensorType, y: TensorType) -> BatchType: - pass + """Implement this method to provide the creation of a subtype of + [Batch][pydvl.influence.types.Batch] for a specific framework + """ diff --git a/src/pydvl/influence/torch/__init__.py b/src/pydvl/influence/torch/__init__.py index 3bbd9552c..a1e0bb09a 100644 --- a/src/pydvl/influence/torch/__init__.py +++ b/src/pydvl/influence/torch/__init__.py @@ -3,6 +3,7 @@ CgInfluence, DirectInfluence, EkfacInfluence, + InverseHarmonicMeanInfluence, LissaInfluence, NystroemSketchInfluence, ) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index 982af400c..73a0f2933 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, TypeVar, Union import torch from torch.func import functional_call @@ -14,9 +14,9 @@ Batch, BilinearForm, BlockMapper, + GradientProvider, Operator, OperatorGradientComposition, - PerSampleGradientProvider, ) from .functional import ( create_matrix_jacobian_product_function, @@ -61,9 +61,7 @@ def to(self, device: torch.device): return TorchBatch(self.x.to(device), self.y.to(device)) -class TorchPerSampleGradientProvider( - PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC -): +class TorchGradientProvider(GradientProvider[TorchBatch, torch.Tensor], ABC): r""" Abstract base class for calculating per-sample gradients of a function defined by a [torch.nn.Module][torch.nn.Module] and a loss function. 
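[Editor's note, not part of the patch] A minimal sketch of the renamed per-sample gradient API (`grads`, `flat_grads`, `mixed_grads`, `jacobian_prod`) introduced in this file, using the concrete `TorchAutoGrad` provider. The `(model, loss, restrict_to)` constructor order is an assumption inferred from the `GradientProviderFactoryType` alias below; everything else follows the docstrings in this hunk.

```python
import torch
from pydvl.influence.torch.base import TorchAutoGrad, TorchBatch

model = torch.nn.Linear(5, 1)
loss = torch.nn.functional.mse_loss

# Assumed constructor order (model, loss, restrict_to), matching the
# GradientProviderFactoryType alias; restrict_to=None falls back to all
# trainable parameters (BlockMode.FULL).
gp = TorchAutoGrad(model, loss, restrict_to=None)

batch = TorchBatch(torch.randn(8, 5), torch.randn(8, 1))

per_param = gp.grads(batch)        # dict: parameter name -> per-sample gradients
flat = gp.flat_grads(batch)        # tensor of shape (8, n_parameters)
g = torch.randn(3, flat.shape[1])  # assumed shape (K, n_parameters)
mjp = gp.jacobian_prod(batch, g)   # matrix-Jacobian product with per-sample grads
```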
@@ -97,7 +95,9 @@ def __init__( self.model = model if restrict_to is None: - restrict_to = ModelParameterDictBuilder(model).build(BlockMode.FULL) + restrict_to = ModelParameterDictBuilder(model).build_from_block_mode( + BlockMode.FULL + ) self.params_to_restrict_to = restrict_to @@ -119,17 +119,15 @@ def dtype(self): return next(self.model.parameters()).dtype @abstractmethod - def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: pass @abstractmethod - def _per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: + def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: pass @abstractmethod - def _matrix_jacobian_product( + def _jacobian_prod( self, batch: TorchBatch, g: torch.Tensor, @@ -140,9 +138,9 @@ def _matrix_jacobian_product( def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} - def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + def grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: r""" - Computes and returns a dictionary mapping gradient names to their respective + Computes and returns a dictionary mapping parameter names to their respective per-sample gradients. Given the example in the class docstring, this means $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, @@ -159,12 +157,10 @@ def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor] A dictionary where keys are gradient identifiers and values are the gradients computed per sample. """ - gradient_dict = self._per_sample_gradient_dict(batch.to(self.device)) + gradient_dict = self._grads(batch.to(self.device)) return self._detach_dict(gradient_dict) - def per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: + def mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: r""" Computes and returns a dictionary mapping gradient names to their respective per-sample mixed gradients. In this context, mixed gradients refer to computing @@ -187,10 +183,10 @@ def per_sample_mixed_gradient_dict( A dictionary where keys are gradient identifiers and values are the mixed gradients computed per sample. """ - gradient_dict = self._per_sample_mixed_gradient_dict(batch.to(self.device)) + gradient_dict = self._mixed_grads(batch.to(self.device)) return self._detach_dict(gradient_dict) - def matrix_jacobian_product( + def jacobian_prod( self, batch: TorchBatch, g: torch.Tensor, @@ -215,24 +211,22 @@ def matrix_jacobian_product( Returns: The resulting tensor from the matrix-Jacobian product computation. 
""" - result = self._matrix_jacobian_product(batch.to(self.device), g.to(self.device)) + result = self._jacobian_prod(batch.to(self.device), g.to(self.device)) if result.requires_grad: result = result.detach() return result - def per_sample_flat_gradient(self, batch: TorchBatch) -> torch.Tensor: + def flat_grads(self, batch: TorchBatch) -> torch.Tensor: return flatten_dimensions( - self.per_sample_gradient_dict(batch).values(), shape=(batch.x.shape[0], -1) + self.grads(batch).values(), shape=(batch.x.shape[0], -1) ) - def per_sample_flat_mixed_gradient(self, batch: TorchBatch) -> torch.Tensor: + def flat_mixed_grads(self, batch: TorchBatch) -> torch.Tensor: shape = (*batch.x.shape, -1) - return flatten_dimensions( - self.per_sample_mixed_gradient_dict(batch).values(), shape=shape - ) + return flatten_dimensions(self.mixed_grads(batch).values(), shape=shape) -class TorchPerSampleAutoGrad(TorchPerSampleGradientProvider): +class TorchAutoGrad(TorchGradientProvider): r""" Compute per-sample gradients of a function defined by a [torch.nn.Module][torch.nn.Module] and a loss function using @@ -273,19 +267,17 @@ def _compute_loss( outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) return self.loss(outputs, y.unsqueeze(0)) - def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: return self._per_sample_gradient_function( self.params_to_restrict_to, batch.x, batch.y ) - def _per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: + def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: return self._per_sample_mixed_gradient_func( self.params_to_restrict_to, batch.x, batch.y ) - def _matrix_jacobian_product( + def _jacobian_prod( self, batch: TorchBatch, g: torch.Tensor, @@ -300,12 +292,12 @@ def _matrix_jacobian_product( GradientProviderFactoryType = Callable[ [torch.nn.Module, LossType, Optional[Dict[str, torch.nn.Parameter]]], - TorchPerSampleGradientProvider, + TorchGradientProvider, ] class OperatorBilinearForm( - BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] + BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider] ): r""" Base class for bilinear forms based on an instance of @@ -322,13 +314,13 @@ def __init__( ): self.operator = operator - def inner_product( + def inner_prod( self, left: torch.Tensor, right: Optional[torch.Tensor] ) -> torch.Tensor: r""" Computes the weighted inner product of two vectors, i.e. - $$ \langle x, y \rangle_{B} = \langle \operatorname{Op}(x), y \rangle $$ + $$ \langle \operatorname{Op}(\text{left}), \text{right} \rangle $$ Args: left: The first tensor in the inner product computation. @@ -354,30 +346,10 @@ def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tenso class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): - def __init__(self, regularization: float = 0.0): - """ - Initializes the Operator with an optional regularization parameter. - - Args: - regularization: A non-negative float that represents the regularization - strength (default is 0.0). - - Raises: - ValueError: If the regularization parameter is negative. 
- """ - if regularization < 0: - raise ValueError("regularization must be non-negative") - self._regularization = regularization - - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - if value < 0: - raise ValueError("regularization must be non-negative") - self._regularization = value + """ + Abstract base class for operators that can be applied to instances of + [torch.Tensor][torch.Tensor]. + """ @property @abstractmethod @@ -398,18 +370,46 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: pass def as_bilinear_form(self): + """ + Represent this operator as a + [OperatorBilinearForm][pydvl.influence.torch.base.OperatorBilinearForm]. + """ return OperatorBilinearForm(self) def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a single vector. + Args: + vec: A single vector consistent to the operator, i.e. it's length + must be equal to the property `input_size`. + + Returns: + A single vector after applying the batch operation + """ return self._apply_to_vec(vec.to(self.device)) def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a matrix. + Args: + mat: A matrix to apply the operator to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) +TorchOperatorType = TypeVar("TorchOperatorType", bound=TorchOperator) + + class TorchOperatorGradientComposition( OperatorGradientComposition[ - torch.Tensor, TorchBatch, TorchOperator, TorchPerSampleGradientProvider + torch.Tensor, TorchBatch, TorchOperatorType, TorchGradientProvider ] ): """ @@ -422,7 +422,7 @@ class TorchOperatorGradientComposition( an abstract operator and gradient provider. """ - def __init__(self, op: TorchOperator, gp: TorchPerSampleGradientProvider): + def __init__(self, op: TorchOperatorType, gp: TorchGradientProvider): super().__init__(op, gp) def to(self, device: torch.device): @@ -432,7 +432,9 @@ def to(self, device: torch.device): class TorchBlockMapper( - BlockMapper[torch.Tensor, TorchBatch, TorchOperatorGradientComposition] + BlockMapper[ + torch.Tensor, TorchBatch, TorchOperatorGradientComposition[TorchOperatorType] + ] ): """ Class for mapping operations across multiple compositional blocks represented by @@ -469,24 +471,31 @@ def to(self, device: torch.device): class TorchComposableInfluence( - ComposableInfluence[torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper], + ComposableInfluence[ + torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper[TorchOperatorType] + ], ModelInfoMixin, ABC, ): + """ + Abstract base class, that allow for block-wise computation of influence + quantities with the [torch][torch] framework. + Inherit from this base class for specific influence algorithms. 
+ """ + def __init__( self, model: torch.nn.Module, - block_structure: Union[ - BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]] - ] = BlockMode.FULL, + block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None, ): + parameter_dict_builder = ModelParameterDictBuilder(model) if isinstance(block_structure, BlockMode): - self.parameter_dict = ModelParameterDictBuilder(model).build( + self.parameter_dict = parameter_dict_builder.build_from_block_mode( block_structure ) else: - self.parameter_dict = block_structure + self.parameter_dict = parameter_dict_builder.build(block_structure) self._regularization_dict = self._build_regularization_dict(regularization) diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 4a881de6e..4c7075924 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -1,39 +1,48 @@ +r""" +This module contains abstractions and implementations for operations carried out on a +batch $b$. These operations are of the form + +$$ m(b) \cdot v$$, + +where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector or matrix. +These batch operations can be used to conveniently build aggregations or recursions +over sequence of batches, e.g. an average of the form + +$$ \frac{1}{|B|} \sum_{b in B}m(b)\cdot v$$, + +which is useful in the case that keeping $B$ in memory is not feasible. + +""" from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Type, Union +from typing import Callable, Dict, Optional, Type, TypeVar, Union import torch from .base import ( GradientProviderFactoryType, + TorchAutoGrad, TorchBatch, - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, + TorchGradientProvider, ) from .functional import create_batch_hvp_function from .util import LossType, inverse_rank_one_update, rank_one_mvp class BatchOperation(ABC): - def __init__(self, regularization: float = 0.0): - if regularization < 0: - raise ValueError("regularization must be non-negative") - self._regularization = regularization + r""" + Abstract base class to implement operations of the form + + $$ m(b) \cdot v $$ + + where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector + or matrix. + """ @property @abstractmethod - def n_parameters(self): + def input_size(self): pass - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - if value < 0: - raise ValueError("regularization must be non-negative") - self._regularization = value - @property @abstractmethod def device(self): @@ -53,9 +62,32 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: pass def apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor): + """ + Applies the batch operation to a single vector. + Args: + batch: Batch of data for computation + vec: A single vector consistent to the operation, i.e. it's length + must be equal to the property `input_size`. + + Returns: + A single vector after applying the batch operation + """ return self._apply_to_vec(batch.to(self.device), vec.to(self.device)) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. 
The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return torch.func.vmap( lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m), in_dims=(None, None, 0), @@ -64,20 +96,25 @@ def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: class ModelBasedBatchOperation(BatchOperation, ABC): + r""" + Abstract base class to implement operations of the form + + $$ m(\text{model}, b) \cdot v $$ + + where model is a [torch.nn.Module][torch.nn.Module]. + + """ + def __init__( self, model: torch.nn.Module, - loss: LossType, - regularization: float = 0.0, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): - super().__init__(regularization) if restrict_to is None: restrict_to = { k: p.detach() for k, p in model.named_parameters() if p.requires_grad } self.params_to_restrict_to = restrict_to - self.loss = loss self.model = model @property @@ -89,7 +126,7 @@ def dtype(self): return next(self.model.parameters()).dtype @property - def n_parameters(self): + def input_size(self): return sum(p.numel() for p in self.params_to_restrict_to.values()) def to(self, device: torch.device): @@ -103,17 +140,36 @@ def to(self, device: torch.device): class HessianBatchOperation(ModelBasedBatchOperation): + r""" + Given a model and loss function computes the Hessian vector or matrix product + with respect to the model parameters, i.e. + + \begin{align*} + &\nabla^2_{\theta} L(b;\theta) \cdot v \\\ + &L(b;\theta) = \left( \frac{1}{|b|} \sum_{(x,y) \in b} + \text{loss}(\text{model}(x; \theta), y)\right), + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the second order differentiation to, + i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian + is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + reverse_only: If True only the reverse mode is used in the autograd computation. + """ + def __init__( self, model: torch.nn.Module, loss: LossType, - regularization: float = 0.0, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, reverse_only: bool = True, ): - super().__init__( - model, loss, regularization=regularization, restrict_to=restrict_to - ) + super().__init__(model, restrict_to=restrict_to) self._batch_hvp = create_batch_hvp_function( model, loss, reverse_only=reverse_only ) @@ -123,34 +179,68 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: class GaussNewtonBatchOperation(ModelBasedBatchOperation): + r""" + Given a model and loss function computes the Gauss-Newton vector or matrix product + with respect to the model parameters, i.e. + + \begin{align*} + G(\text{model}, \text{loss}, b, \theta) &\cdot v, \\\ + G(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + Args: + model: The model. + loss: The loss function. 
+ gradient_provider_factory: An optional factory to create an object of type + [TorchGradientProvider][pydvl.influence.torch.base.TorchGradientProvider], + depending on the model, loss and optional parameters to restrict to. + If not provided, the implementation + [TorchAutograd][pydvl.influence.torch.base.TorchAutograd] is used. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + def __init__( self, model: torch.nn.Module, loss: LossType, - regularization: float = 0.0, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): - super().__init__( - model, loss, regularization=regularization, restrict_to=restrict_to - ) + super().__init__(model, restrict_to=restrict_to) self.gradient_provider = gradient_provider_factory( model, loss, self.params_to_restrict_to ) def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: - flat_grads = self.gradient_provider.per_sample_flat_gradient(batch) + flat_grads = self.gradient_provider.flat_grads(batch) result = rank_one_mvp(flat_grads, vec) - - if self.regularization > 0.0: - result += self.regularization * vec - return result def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return self.apply_to_vec(batch, mat) def to(self, device: torch.device): @@ -159,6 +249,52 @@ def to(self, device: torch.device): class InverseHarmonicMeanBatchOperation(ModelBasedBatchOperation): + r""" + Given a model and loss function computes an approximation of the inverse + Gauss-Newton vector or matrix product. Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this operation replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\text{model}, \text{loss}, b, \theta) = + \left(n \sum_{(x, y) \in b} \left( \nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and computes + + $$ \tilde{G}_{\lambda}^{-1}(\text{model}, \text{loss}, b, \theta) + \cdot v.$$ + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. 
+ + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + def __init__( self, model: torch.nn.Module, @@ -166,32 +302,56 @@ def __init__( regularization: float, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): if regularization <= 0: raise ValueError("regularization must be positive") - - super().__init__( - model, loss, regularization=regularization, restrict_to=restrict_to - ) self.regularization = regularization + + super().__init__(model, restrict_to=restrict_to) self.gradient_provider = gradient_provider_factory( model, loss, self.params_to_restrict_to ) + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value <= 0: + raise ValueError("regularization must be positive") + self._regularization = value + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: - grads = self.gradient_provider.per_sample_flat_gradient(batch) + grads = self.gradient_provider.flat_grads(batch) return ( inverse_rank_one_update(grads, vec, self.regularization) / self.regularization ) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return self.apply_to_vec(batch, mat) def to(self, device: torch.device): super().to(device) self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to return self + + +BatchOperationType = TypeVar("BatchOperationType", bound=BatchOperation) diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 96d57632e..42be0149f 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -24,9 +24,9 @@ UnsupportedInfluenceModeException, ) from .base import ( + TorchAutoGrad, TorchComposableInfluence, TorchOperatorGradientComposition, - TorchPerSampleAutoGrad, ) from .functional import ( LowRankProductRepresentation, @@ -57,6 +57,7 @@ "ArnoldiInfluence", "EkfacInfluence", "NystroemSketchInfluence", + "InverseHarmonicMeanInfluence", ] logger = logging.getLogger(__name__) @@ -1802,23 +1803,113 @@ def fit(self, data: DataLoader): return self -class InverseHarmonicMeanInfluence(TorchComposableInfluence): +class InverseHarmonicMeanInfluence( + TorchComposableInfluence[InverseHarmonicMeanOperator] +): + r""" + This implementation replaces the inverse Hessian matrix in the influence computation + an approximation of the inverse Gauss-Newton vector product. 
+ + Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\theta) &= + \frac{1}{N}\sum_{i}^N\nabla_{\theta}\ell (x_i,y_i; \theta) + \nabla_{\theta}\ell (x_i, y_i; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this implementation replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\theta) = + \left(N \cdot \sum_{i=1}^N \left( \nabla_{\theta}\ell (x_i,y_i; \theta) + \nabla_{\theta}\ell (x_i,y_i; \theta)^t + + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and uses the matrix + + $$ \tilde{G}_{\lambda}^{-1}(\theta)$$ + + instead of the inverse Hessian. + + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. The results are averaged + over the batches provided by the data loader. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Block-mode: + This implementation is capable of using a block-matrix approximation. The + blocking structure can be specified via the `block_structure` parameter. + The `block_structure` parameter can either be a + [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides + layer-wise or parameter-wise blocking) or a custom block structure defined + by an ordered dictionary with the keys being the block identifiers (arbitrary + strings) and the values being lists of parameter names contained in the block. + + ```python + block_structure = OrderedDict( + ( + ("custom_block1", ["0.weight", "1.bias"]), + ("custom_block2", ["1.weight", "0.bias"]), + ) + ) + ``` + + If you would like to apply a block-specific regularization, you can provide a + dictionary with the block names as keys and the regularization values as values. + In this case, the specification must be complete, i.e. every block must have + a positive regularization value. + + ```python + regularization = { + "custom_block1": 0.1, + "custom_block2": 0.2, + } + ``` + Accordingly, if you choose a layer-wise or parameter-wise structure + (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for + `block_structure`) the keys must be the layer names or parameter names, + respectively. + + You can retrieve the block-wise influence information from the methods + with suffix `_by_block`. By default, `block_structure` is set to + `BlockMode.FULL` and in this case these methods will return a dictionary + with the empty string being the only key. + + + Args: + model: The model. + loss: The loss function. + regularization: The regularization parameter. In case a dictionary is provided, + the keys must match the blocking structure. + block_structure: The blocking structure, either a pre-defined enum or a + custom block structure, see the information regarding block-mode. 
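Example (editor's sketch, not part of the original patch): combining the two snippets above into a constructor call; the `torch.nn.Sequential` model is only assumed so that parameter names such as `"0.weight"` and `"1.bias"` actually exist.

```python
from collections import OrderedDict

import torch
from pydvl.influence.torch import InverseHarmonicMeanInfluence

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
loss = torch.nn.functional.mse_loss

block_structure = OrderedDict(
    (
        ("custom_block1", ["0.weight", "1.bias"]),
        ("custom_block2", ["1.weight", "0.bias"]),
    )
)
# Block-specific regularization: the specification must be complete.
regularization = {"custom_block1": 0.1, "custom_block2": 0.2}

if_model = InverseHarmonicMeanInfluence(
    model, loss, regularization, block_structure=block_structure
)
```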
+ """ + def __init__( self, model: torch.nn.Module, loss: LossType, regularization: Union[float, Dict[str, Optional[float]]], - block_structure: Union[ - BlockMode, OrderedDict[str, OrderedDict[str, torch.Tensor]] - ] = BlockMode.FULL, + block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, ): super().__init__(model, block_structure, regularization=regularization) - self.gradient_provider_factory = TorchPerSampleAutoGrad + self.gradient_provider_factory = TorchAutoGrad self.loss = loss @property def n_parameters(self): - return super().n_parameters() + return sum(block.op.input_size for _, block in self.block_mapper.items()) @property def is_thread_safe(self) -> bool: @@ -1858,6 +1949,16 @@ def _create_block( def with_regularization( self, regularization: Union[float, Dict[str, Optional[float]]] ) -> TorchComposableInfluence: + """ + Update the regularization parameter. + Args: + regularization: Either a positive float or a dictionary with the + block names as keys and the regularization values as values. + + Returns: + The modified instance + + """ self._regularization_dict = self._build_regularization_dict(regularization) for k, reg in self._regularization_dict.items(): self.block_mapper.composable_block_dict[k].op.regularization = reg diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 5b58b9f70..81d6b0442 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Generator, Optional, Type, Union +from typing import Callable, Dict, Generator, Generic, Optional, Type, Union import torch from torch import nn as nn @@ -7,13 +7,13 @@ from ..array import LazyChunkSequence, SequenceAggregator from .base import ( GradientProviderFactoryType, + TorchAutoGrad, TorchBatch, + TorchGradientProvider, TorchOperator, - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, ) from .batch_operation import ( - BatchOperation, + BatchOperationType, GaussNewtonBatchOperation, HessianBatchOperation, InverseHarmonicMeanBatchOperation, @@ -21,17 +21,30 @@ from .util import TorchChunkAverageAggregator, TorchPointAverageAggregator -class AggregateBatchOperator(TorchOperator): +class AggregateBatchOperator(TorchOperator, Generic[BatchOperationType]): + """ + Class for aggregating batch operations over a dataset using a provided data loader + and aggregator. + + This class facilitates the application of a batch operation across multiple batches + of data, aggregating the results using a specified sequence aggregator. + + Attributes: + batch_operation: The batch operation to apply. + dataloader: The data loader providing batches of data. + aggregator: The sequence aggregator to aggregate the results of the batch + operations. 
+ """ + def __init__( self, - batch_operation: BatchOperation, + batch_operation: BatchOperationType, dataloader: DataLoader, aggregator: SequenceAggregator[torch.Tensor], ): self.batch_operation = batch_operation self.dataloader = dataloader self.aggregator = aggregator - super().__init__(self.batch_operation.regularization) @property def device(self): @@ -45,23 +58,26 @@ def to(self, device: torch.device): self.batch_operation = self.batch_operation.to(device) return self - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - self._regularization = value - self.batch_operation.regularization = value - @property def input_size(self): - return self.batch_operation.n_parameters + return self.batch_operation.input_size def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: return self._apply(vec, self.batch_operation.apply_to_vec) def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a matrix. + Args: + mat: A matrix to apply the operator to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return self._apply(mat, self.batch_operation.apply_to_mat) def _apply( @@ -83,7 +99,32 @@ def tensor_gen_factory() -> Generator[torch.Tensor, None, None]: return self.aggregator(lazy_tensor_sequence) -class GaussNewtonOperator(AggregateBatchOperator): +class GaussNewtonOperator(AggregateBatchOperator[GaussNewtonBatchOperation]): + r""" + Given a model and loss function computes the Gauss-Newton vector or matrix product + with respect to the model parameters on a batch, i.e. + + \begin{align*} + G(\text{model}, \text{loss}, b, \theta) &\cdot v, \\\ + G(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix, + and average the results over the batches provided by the data loader. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + def __init__( self, model: nn.Module, @@ -91,8 +132,8 @@ def __init__( dataloader: DataLoader, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, nn.Parameter]] = None, ): batch_op = GaussNewtonBatchOperation( @@ -105,7 +146,31 @@ def __init__( super().__init__(batch_op, dataloader, aggregator) -class HessianOperator(AggregateBatchOperator): +class HessianOperator(AggregateBatchOperator[HessianBatchOperation]): + r""" + Given a model and loss function computes the Hessian vector or matrix product + with respect to the model parameters for a given batch, i.e. 
+ + \begin{align*} + &\nabla^2_{\theta} L(b;\theta) \cdot v \\\ + &L(b;\theta) = \left( \frac{1}{|b|} \sum_{(x,y) \in b} + \text{loss}(\text{model}(x; \theta), y)\right), + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix, + and average the results over the batches provided by the data loader. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the second order differentiation to, + i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian + is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + reverse_only: If True only the reverse mode is used in the autograd computation. + """ + def __init__( self, model: nn.Module, @@ -121,7 +186,61 @@ def __init__( super().__init__(batch_op, dataloader, aggregator) -class InverseHarmonicMeanOperator(AggregateBatchOperator): +class InverseHarmonicMeanOperator( + AggregateBatchOperator[InverseHarmonicMeanBatchOperation] +): + r""" + Given a model and loss function computes an approximation of the inverse + Gauss-Newton vector or matrix product per batch and averages the results. + + Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this operator replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\text{model}, \text{loss}, b, \theta) = + \left(n \sum_{(x, y) \in b} \left( \nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and computes + + $$ \tilde{G}_{\lambda}^{-1}(\text{model}, \text{loss}, b, \theta) + \cdot v.$$ + + for any given batch $b$, + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. The results are averaged + over the batches provided by the data loader. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. 
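Example (editor's sketch, not part of the original patch): applying the operator, assuming the `(model, loss, dataloader, regularization, ...)` argument order used throughout this module and a standard `TensorDataset`/`DataLoader` pair.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch.operator import InverseHarmonicMeanOperator

model = torch.nn.Linear(5, 1)
loss = torch.nn.functional.mse_loss
data = DataLoader(TensorDataset(torch.randn(32, 5), torch.randn(32, 1)), batch_size=8)

op = InverseHarmonicMeanOperator(model, loss, data, regularization=0.1)

v = torch.rand(op.input_size)       # length must equal the operator's input_size
w = op.apply_to_vec(v)              # approximate inverse-GNVP, averaged over batches
mat = torch.rand(3, op.input_size)
w_mat = op.apply_to_mat(mat)        # row-wise application to a matrix
```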
+ """ + def __init__( self, model: nn.Module, @@ -130,10 +249,15 @@ def __init__( regularization: float, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, nn.Parameter]] = None, ): + if regularization <= 0: + raise ValueError("regularization must be positive") + + self._regularization = regularization + batch_op = InverseHarmonicMeanBatchOperation( model, loss, @@ -143,3 +267,14 @@ def __init__( ) aggregator = TorchPointAverageAggregator(weighted=False) super().__init__(batch_op, dataloader, aggregator) + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value <= 0: + raise ValueError("regularization must be positive") + self._regularization = value + self.batch_operation.regularization = value diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index d7d4b0226..bae7e4372 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -2,6 +2,7 @@ import logging import math +import warnings from collections import OrderedDict from dataclasses import dataclass from enum import Enum @@ -59,6 +60,8 @@ "LossType", "ModelParameterDictBuilder", "BlockMode", + "ModelInfoMixin", + "safe_torch_linalg_eigh", ] @@ -665,6 +668,29 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: def inverse_rank_one_update( x: torch.Tensor, v: torch.Tensor, regularization: float ) -> torch.Tensor: + r""" + Performs an inverse-rank one update on x and v. More precisely, it computes + + $$ \sum_{i=1}^n \left(x[i]x[i]^t+\lambda \operatorname{I}\right)^{-1}v $$ + + where $\operatorname{I}$ is the identity matrix and $\lambda$ is positive + regularization parameter. The inverse matrices are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + Args: + x: Input matrix used for the rank one expressions. First dimension is + assumed to be the batch dimension. + v: Matrix to multiply with. First dimension is + assumed to be the batch dimension. + regularization: Regularization parameter to make the rank-one expressions + invertible, must be positive. + + Returns: + Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape $(M, D)$. + """ nominator = torch.einsum("ij,kj->ki", x, v) denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) return (v - (nominator / denominator) @ x) / regularization @@ -674,6 +700,15 @@ def inverse_rank_one_update( class BlockMode(Enum): + """ + Enumeration for different modes of grouping model parameters. + + Attributes: + LAYER_WISE: Groups parameters by layers of the model. + PARAMETER_WISE: Groups parameters individually. + FULL: Groups all parameters together. + """ + LAYER_WISE: str = "layer_wise" PARAMETER_WISE: str = "parameter_wise" FULL: str = "full" @@ -681,6 +716,15 @@ class BlockMode(Enum): @dataclass class ModelParameterDictBuilder: + """ + A builder class for creating ordered dictionaries of model parameters based on + specified block modes or custom blocking structures. + + Attributes: + model: The neural network model. + detach: Whether to detach the parameters from the computation graph. 
+ """ + model: torch.nn.Module detach: bool = True @@ -689,9 +733,58 @@ def _optional_detach(self, p: torch.nn.Parameter): return p.detach() return p - def build( + def build(self, block_structure: OrderedDict[str, List[str]]): + """ + Builds an ordered dictionary of model parameters based on the specified block + structure represented by an ordered dictionary, where the keys are block + identifiers and the values are lists of model parameter names contained in + this block. + + Args: + block_structure: The block structure specifying how to group the parameters. + + Returns: + An ordered dictionary of ordered dictionaries, where the outer dictionary's + keys are block identifiers and the inner dictionaries map parameter names + to parameters. + """ + parameter_dict = OrderedDict() + + for block_name, parameter_names in block_structure.items(): + inner_ordered_dict = OrderedDict() + for parameter_name in parameter_names: + parameter = self.model.state_dict()[parameter_name] + if parameter.requires_grad: + inner_ordered_dict[parameter_name] = self._optional_detach( + parameter + ) + else: + warnings.warn( + f"The parameter {parameter_name} from the block " + f"{block_name} is mark as not trainable in the model " + f"and will be excluded from the computation." + ) + parameter_dict[block_name] = inner_ordered_dict + + return parameter_dict + + def build_from_block_mode( self, block_mode: BlockMode ) -> OrderedDict[str, OrderedDict[str, torch.nn.Parameter]]: + """ + Builds an ordered dictionary of model parameters based on the specified block + mode or custom blocking structure represented by an ordered dictionary, where + the keys are block identifiers and the values are lists of model parameter names + contained in this block. + + Args: + block_mode: The block mode specifying how to group the parameters. + + Returns: + An ordered dictionary of ordered dictionaries, where the outer dictionary's + keys are block identifiers and the inner dictionaries map parameter names + to parameters. + """ parameter_dict = OrderedDict() if block_mode is BlockMode.FULL: diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index 0a4d5b0fd..5cb405fda 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -120,7 +120,7 @@ class GradientProvider(Generic[BatchType, TensorType], ABC): """ @abstractmethod - def per_sample_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: + def grads(self, batch: BatchType) -> Dict[str, TensorType]: r""" Computes and returns a dictionary mapping gradient names to their respective per-sample gradients. Given the example in the class docstring, this means @@ -141,7 +141,7 @@ def per_sample_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: """ @abstractmethod - def per_sample_mixed_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: + def mixed_grads(self, batch: BatchType) -> Dict[str, TensorType]: r""" Computes and returns a dictionary mapping gradient names to their respective per-sample mixed gradients. 
In this context, mixed gradients refer to computing @@ -166,7 +166,7 @@ def per_sample_mixed_gradient_dict(self, batch: BatchType) -> Dict[str, TensorTy """ @abstractmethod - def matrix_jacobian_product( + def jacobian_prod( self, batch: BatchType, g: TensorType, @@ -193,7 +193,7 @@ def matrix_jacobian_product( """ @abstractmethod - def per_sample_flat_gradient(self, batch: BatchType) -> TensorType: + def flat_grads(self, batch: BatchType) -> TensorType: r""" Computes and returns the flat per-sample gradients for the provided batch. Given the example in the class docstring, this means @@ -215,7 +215,7 @@ def per_sample_flat_gradient(self, batch: BatchType) -> TensorType: """ @abstractmethod - def per_sample_flat_mixed_gradient(self, batch: BatchType) -> TensorType: + def flat_mixed_grads(self, batch: BatchType) -> TensorType: r""" Computes and returns the flat per-sample mixed gradients for the provided batch. Given the example in the class docstring, this means @@ -248,9 +248,7 @@ class BilinearForm(Generic[TensorType, BatchType, GradientProviderType], ABC): """ @abstractmethod - def inner_product( - self, left: TensorType, right: Optional[TensorType] - ) -> TensorType: + def inner_prod(self, left: TensorType, right: Optional[TensorType]) -> TensorType: r""" Computes the inner product of two vectors, i.e. @@ -273,7 +271,7 @@ def inner_product( A tensor representing the inner product. """ - def gradient_inner_product( + def grads_inner_prod( self, left: BatchType, right: Optional[BatchType], @@ -298,14 +296,14 @@ def gradient_inner_product( Returns: A tensor representing the inner products of the per-sample gradients """ - left_grad = gradient_provider.per_sample_flat_gradient(left) + left_grad = gradient_provider.flat_grads(left) if right is None: right_grad = left_grad else: - right_grad = gradient_provider.per_sample_flat_gradient(right) - return self.inner_product(left_grad, right_grad) + right_grad = gradient_provider.flat_grads(right) + return self.inner_prod(left_grad, right_grad) - def mixed_gradient_inner_product( + def mixed_grads_inner_prod( self, left: BatchType, right: BatchType, gradient_provider: GradientProviderType ) -> TensorType: r""" @@ -327,9 +325,9 @@ def mixed_gradient_inner_product( Returns: A tensor representing the inner products of the mixed per-sample gradients """ - left_grad = gradient_provider.per_sample_flat_gradient(left) - right_mixed_grad = gradient_provider.per_sample_flat_mixed_gradient(right) - return self.inner_product(left_grad, right_mixed_grad) + left_grad = gradient_provider.flat_grads(left) + right_mixed_grad = gradient_provider.flat_mixed_grads(right) + return self.inner_prod(left_grad, right_mixed_grad) BilinearFormType = TypeVar("BilinearFormType", bound=BilinearForm) @@ -414,7 +412,7 @@ def __init__(self, op: OperatorType, gp: GradientProviderType): self.gp = gp self.op = op - def gradient_interaction( + def interactions( self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode ): r""" @@ -445,14 +443,10 @@ def gradient_interaction( """ bilinear_form = self.op.as_bilinear_form() if mode is InfluenceMode.Up: - return bilinear_form.gradient_inner_product( - left_batch, right_batch, self.gp - ) - return bilinear_form.mixed_gradient_inner_product( - left_batch, right_batch, self.gp - ) + return bilinear_form.grads_inner_prod(left_batch, right_batch, self.gp) + return bilinear_form.mixed_grads_inner_prod(left_batch, right_batch, self.gp) - def transformed_gradients(self, batch: BatchType): + def transformed_grads(self, batch: 
BatchType): r""" Computes the gradients of a data batch, transformed by the operator application , i.e. the expressions @@ -467,10 +461,10 @@ def transformed_gradients(self, batch: BatchType): A tensor representing the application of the operator to the gradients. """ - grads = self.gp.per_sample_flat_gradient(batch) + grads = self.gp.flat_grads(batch) return self.op.apply_to_mat(grads) - def interaction_from_transformed_gradients( + def interactions_from_transformed_grads( self, left_factors: TensorType, right_batch: BatchType, mode: InfluenceMode ): r""" @@ -499,16 +493,18 @@ def interaction_from_transformed_gradients( batch gradients. """ if mode is InfluenceMode.Up: - right_grads = self.gp.per_sample_flat_gradient(right_batch) + right_grads = self.gp.flat_grads(right_batch) else: - right_grads = self.gp.per_sample_flat_mixed_gradient(right_batch) - return self.op.as_bilinear_form().inner_product(left_factors, right_grads) + right_grads = self.gp.flat_mixed_grads(right_batch) + return self.op.as_bilinear_form().inner_prod(left_factors, right_grads) -ComposableBlockType = TypeVar("ComposableBlockType", bound=OperatorGradientComposition) +OperatorGradientCompositionType = TypeVar( + "OperatorGradientCompositionType", bound=OperatorGradientComposition +) -class BlockMapper(Generic[TensorType, BatchType, ComposableBlockType], ABC): +class BlockMapper(Generic[TensorType, BatchType, OperatorGradientCompositionType], ABC): """ Abstract base class for mapping operations across multiple compositional blocks. @@ -521,9 +517,17 @@ class BlockMapper(Generic[TensorType, BatchType, ComposableBlockType], ABC): interactions. """ - def __init__(self, composable_block_dict: OrderedDict[str, ComposableBlockType]): + def __init__( + self, composable_block_dict: OrderedDict[str, OperatorGradientCompositionType] + ): self.composable_block_dict = composable_block_dict + def __getitem__(self, item: str): + return self.composable_block_dict[item] + + def items(self): + return self.composable_block_dict.items() + def _to_ordered_dict( self, tensor_generator: Generator[TensorType, None, None] ) -> OrderedDict[str, TensorType]: @@ -539,7 +543,7 @@ def _split_to_blocks( """Must be implemented in a way to preserve the ordering defined by the `composable_block_dict` attribute""" - def block_transformed_gradients( + def transformed_grads( self, batch: BatchType, ) -> OrderedDict[str, TensorType]: @@ -553,10 +557,10 @@ def block_transformed_gradients( Returns: An ordered dictionary of transformed gradients by block. """ - tensor_gen = self.generate_transformed_gradients(batch) + tensor_gen = self.generate_transformed_grads(batch) return self._to_ordered_dict(tensor_gen) - def block_interactions( + def interactions( self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode ) -> OrderedDict[str, TensorType]: """ @@ -571,10 +575,10 @@ def block_interactions( Returns: An ordered dictionary of gradient interactions by block. """ - tensor_gen = self.generate_gradient_interactions(left_batch, right_batch, mode) + tensor_gen = self.generate_interactions(left_batch, right_batch, mode) return self._to_ordered_dict(tensor_gen) - def block_interactions_from_transformed_gradients( + def interactions_from_transformed_grads( self, left_factors: OrderedDict[str, TensorType], right_batch: BatchType, @@ -594,12 +598,12 @@ def block_interactions_from_transformed_gradients( Returns: An ordered dictionary of interactions from transformed gradients by block. 
""" - tensor_gen = self.generate_interactions_from_transformed_gradients( + tensor_gen = self.generate_interactions_from_transformed_grads( left_factors, right_batch, mode ) return self._to_ordered_dict(tensor_gen) - def generate_transformed_gradients( + def generate_transformed_grads( self, batch: BatchType ) -> Generator[TensorType, None, None]: """ @@ -613,9 +617,9 @@ def generate_transformed_gradients( Transformed gradients for each block. """ for comp_block in self.composable_block_dict.values(): - yield comp_block.transformed_gradients(batch) + yield comp_block.transformed_grads(batch) - def generate_gradient_interactions( + def generate_interactions( self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode ) -> Generator[TensorType, None, None]: """ @@ -631,9 +635,9 @@ def generate_gradient_interactions( TensorType: Gradient interactions for each block. """ for comp_block in self.composable_block_dict.values(): - yield comp_block.gradient_interaction(left_batch, right_batch, mode) + yield comp_block.interactions(left_batch, right_batch, mode) - def generate_interactions_from_transformed_gradients( + def generate_interactions_from_transformed_grads( self, left_factors: Union[TensorType, OrderedDict[str, TensorType]], right_batch: BatchType, @@ -657,7 +661,7 @@ def generate_interactions_from_transformed_gradients( else: left_factors_dict = cast(OrderedDict[str, TensorType], left_factors) for k, comp_block in self.composable_block_dict.items(): - yield comp_block.interaction_from_transformed_gradients( + yield comp_block.interactions_from_transformed_grads( left_factors_dict[k], right_batch, mode ) diff --git a/tests/influence/test_influence_calculator.py b/tests/influence/test_influence_calculator.py index 854321f8f..9a82e89cf 100644 --- a/tests/influence/test_influence_calculator.py +++ b/tests/influence/test_influence_calculator.py @@ -28,10 +28,6 @@ EkfacInfluence, ) from pydvl.influence.torch.influence_function_model import NystroemSketchInfluence -from pydvl.influence.torch.pre_conditioner import ( - JacobiPreConditioner, - NystroemPreConditioner, -) from pydvl.influence.torch.util import ( NestedTorchCatAggregator, TorchCatAggregator, diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index b6141a6c8..304de4518 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -2,7 +2,6 @@ import pytest import torch -from influence.torch.test_util import model_data, test_parameters from pydvl.influence.torch.base import TorchBatch from pydvl.influence.torch.batch_operation import ( @@ -10,6 +9,8 @@ HessianBatchOperation, ) +from .test_util import model_data, test_parameters + @pytest.mark.torch @pytest.mark.parametrize( diff --git a/tests/influence/torch/test_gradient_provider.py b/tests/influence/torch/test_gradient_provider.py index c15814e37..8fab7fadf 100644 --- a/tests/influence/torch/test_gradient_provider.py +++ b/tests/influence/torch/test_gradient_provider.py @@ -1,11 +1,12 @@ import numpy as np import pytest import torch -from influence.conftest import linear_mixed_second_derivative_analytical, linear_model -from influence.torch.conftest import DATA_OUTPUT_NOISE, linear_mvp_model from pydvl.influence.torch.base import TorchAutoGrad, TorchBatch +from ..conftest import linear_mixed_second_derivative_analytical, linear_model +from .conftest import DATA_OUTPUT_NOISE, linear_mvp_model + class TestTorchPerSampleAutograd: @pytest.mark.torch @@ -22,8 +23,8 @@ 
def test_per_sample_gradient(self, in_features, out_features, batch_size): params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad} gp = TorchAutoGrad(model, loss, restrict_to=params) - gradients = gp.per_sample_gradient_dict(TorchBatch(x, y)) - flat_gradients = gp.per_sample_flat_gradient(TorchBatch(x, y)) + gradients = gp.grads(TorchBatch(x, y)) + flat_gradients = gp.flat_grads(TorchBatch(x, y)) # Compute analytical gradients y_pred = model(x) @@ -69,7 +70,7 @@ def test_mixed_derivatives(self, in_features, out_features, train_set_size): torch_train_x = torch.as_tensor(train_x) torch_train_y = torch.as_tensor(train_y) gp = TorchAutoGrad(model, loss, restrict_to=params) - flat_functorch_mixed_derivatives = gp.per_sample_flat_mixed_gradient( + flat_functorch_mixed_derivatives = gp.flat_mixed_grads( TorchBatch(torch_train_x, torch_train_y) ) assert torch.allclose( @@ -95,7 +96,7 @@ def test_matrix_jacobian_product( gp = TorchAutoGrad(model, torch.nn.functional.mse_loss, restrict_to=params) G = torch.randn((10, out_features * (in_features + 1))) - mjp = gp.matrix_jacobian_product(TorchBatch(x, y), G) + mjp = gp.jacobian_prod(TorchBatch(x, y), G) dL_dw = torch.vmap( lambda r, s, t: 2 diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index 82510e640..a3c1caa29 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -387,7 +387,7 @@ def test_build(self, block_mode, model): model=model, detach=True, ) - param_dict = builder.build(block_mode) + param_dict = builder.build_from_block_mode(block_mode) if block_mode is BlockMode.FULL: assert "" in param_dict
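Usage sketch (illustrative only, not part of the change set above): the renamed gradient-provider API can be exercised the same way the tests do, here with an assumed small linear model, MSE loss and random data chosen purely for illustration.

import torch

from pydvl.influence.torch.base import TorchAutoGrad, TorchBatch

# Assumed toy setup: a linear model with 5 inputs and 3 outputs, trained parameters only.
model = torch.nn.Linear(5, 3)
loss = torch.nn.functional.mse_loss
params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad}

gp = TorchAutoGrad(model, loss, restrict_to=params)

x, y = torch.randn(8, 5), torch.randn(8, 3)
batch = TorchBatch(x, y)

grads = gp.grads(batch)      # dict: parameter name -> per-sample gradients
flat = gp.flat_grads(batch)  # flat per-sample gradients, one row per sample

# Matrix-Jacobian product with an arbitrary matrix G whose second dimension
# matches the number of flattened parameters.
G = torch.randn(10, flat.shape[-1])
mjp = gp.jacobian_prod(batch, G)

The renames introduced by this change are one-to-one (per_sample_gradient_dict -> grads, per_sample_flat_gradient -> flat_grads, matrix_jacobian_product -> jacobian_prod, inner_product -> inner_prod, block_interactions -> interactions, build -> build_from_block_mode for the enum-based variant), so existing call sites only need to switch to the new names.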