From 93b6ecdbc734f22bafa83b588fbbda1ed4e98b2d Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 16:54:28 +0200 Subject: [PATCH 01/43] Add new module influence.types containing generic types to build up influence computation in a very flexible way --- src/pydvl/influence/types.py | 623 +++++++++++++++++++++++++++++++++++ 1 file changed, 623 insertions(+) create mode 100644 src/pydvl/influence/types.py diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py new file mode 100644 index 000000000..093518275 --- /dev/null +++ b/src/pydvl/influence/types.py @@ -0,0 +1,623 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from collections import OrderedDict +from typing import ( + TypeVar, + Iterable, + Generic, + Optional, + Generator, + Union, + Collection, + cast, + Dict, +) + + +class InfluenceMode(str, Enum): + """ + Enum representation for the types of influence. + + Attributes: + Up: [Approximating the influence of a point] + [approximating-the-influence-of-a-point] + Perturbation: [Perturbation definition of the influence score] + [perturbation-definition-of-the-influence-score] + + """ + + Up = "up" + Perturbation = "perturbation" + + +"""Type variable for tensors, i.e. sequences of numbers""" +TensorType = TypeVar("TensorType", bound=Collection) +DataLoaderType = TypeVar("DataLoaderType", bound=Iterable) + + +@dataclass(frozen=True) +class Batch(Generic[TensorType]): + """ + Represents a batch of data containing features and labels. + + Attributes: + x: Represents the input features of the batch. + y: Represents the labels or targets associated with the input features. + """ + + x: TensorType + y: TensorType + + +BatchType = TypeVar("BatchType", bound=Batch) + + +class PerSampleGradientProvider(Generic[BatchType, TensorType], ABC): + r""" + Provides an interface for calculating per-sample gradients and other related + computations for a given batch of data. + + This class must be subclassed with implementations for its abstract methods tailored + to specific gradient computation needs, e.g. using an autograd engine for + a model loss function. Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + + @abstractmethod + def per_sample_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample gradients. Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension, so the shape of the resulting tensors are $(N, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + gradients computed per sample. 
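For orientation, the per-sample gradients this method describes can be reproduced with `torch.func` for a plain linear model. The sketch below is illustrative only (the model, loss and shapes are arbitrary choices, not part of this interface):

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(5, 3)
loss = torch.nn.functional.mse_loss
params = {k: p.detach() for k, p in model.named_parameters()}

def single_sample_loss(params, x, y):
    # evaluate the model functionally on a single sample (add a batch dim of one)
    out = functional_call(model, params, (x.unsqueeze(0),))
    return loss(out, y.unsqueeze(0))

x, y = torch.randn(10, 5), torch.randn(10, 3)
# vmap the gradient of the single-sample loss over the batch dimension
grads = vmap(grad(single_sample_loss), in_dims=(None, 0, 0))(params, x, y)
assert grads["weight"].shape == (10, 3, 5) and grads["bias"].shape == (10, 3)
```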
+ """ + + @abstractmethod + def per_sample_mixed_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample mixed gradients. In this context, mixed gradients refer to computing + gradients with respect to the instance definition in addition to + compute derivatives with respect to the input batch. + Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension and the last to be the non-batch input related derivatives. + So the shape of the resulting tensors are $(N, n, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute mixed gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + mixed gradients computed per sample. + """ + + @abstractmethod + def matrix_jacobian_product( + self, + batch: BatchType, + g: TensorType, + ) -> TensorType: + r""" + Computes the matrix-Jacobian product for the provided batch and input tensor. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y})) \cdot g^T$$ + + where g must be a tensor of shape $(K, d_1+d_2)$, so the resulting tensor + is of shape $(N, K)$. + + Args: + batch: The batch of data for which to compute the Jacobian. + g: The tensor to be used in the matrix-Jacobian product + calculation. + + Returns: + The resulting tensor from the matrix-Jacobian product computation. + """ + + @abstractmethod + def per_sample_flat_gradient(self, batch: BatchType) -> TensorType: + r""" + Computes and returns the flat per-sample gradients for the provided batch. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}))$$ + + where the first dimension of the resulting tensor is always considered to be + the batch dimension, so the shape of the resulting tensor is $(N, d_1+d_2)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute the gradients. + + Returns: + A tensor containing the flat gradients computed per sample. + """ + + @abstractmethod + def per_sample_flat_mixed_gradient(self, batch: BatchType) -> TensorType: + r""" + Computes and returns the flat per-sample mixed gradients for the provided batch. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_1}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y}), + \nabla_{\omega_1}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y} ))$$ + + where the first dimension of the resulting tensor is always considered to be + the batch dimension and the last to be the non-batch input related derivatives. + So the shape of the resulting tensor is $(N, n, d_1 + d_2)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute the flat mixed gradients. + + Returns: + A tensor containing the flat mixed gradients computed per sample. 
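For intuition, flattening a dictionary of per-sample gradients into the documented $(N, d_1 + d_2)$ shape amounts to reshaping every entry to `(N, -1)` and concatenating along the trailing dimension. A minimal sketch with a hypothetical helper (the tensors stand in for the per-parameter gradients of a `Linear(5, 3)` layer):

```python
import torch

def flatten_per_sample_grads(grad_dict, batch_size):
    # reshape each parameter gradient to (N, -1) and concatenate -> (N, d_1 + ... + d_k)
    return torch.cat([g.reshape(batch_size, -1) for g in grad_dict.values()], dim=-1)

grads = {"weight": torch.randn(10, 3, 5), "bias": torch.randn(10, 3)}
flat = flatten_per_sample_grads(grads, batch_size=10)
assert flat.shape == (10, 3 * 5 + 3)
```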
+ """ + + +GradientProviderType = TypeVar("GradientProviderType", bound=PerSampleGradientProvider) + + +class BilinearForm(Generic[TensorType, BatchType, GradientProviderType], ABC): + """ + Abstract base class for bilinear forms, which facilitates the computation of inner + products involving gradients of batches of data. + """ + + @abstractmethod + def inner_product( + self, left: TensorType, right: Optional[TensorType] + ) -> TensorType: + r""" + Computes the inner product of two vectors, i.e. + + $$ \langle x, y \rangle_{B}$$ + + if we denote the bilinear-form by $\langle \cdot, \cdot \rangle_{B}$. + The implementations must take care of according vectorization to make + it applicable to the case, where `left` and `right` are not one-dimensional. + In this case, the trailing dimension of the `left` and `right` tensors are + considered for the computation of the inner product. For example, + if `left` is a tensor of shape $(N, D)$ and, `right` is of shape $(M,..., D)$, + then the result is of shape $(N,..., M)$ + + Args: + left: The first tensor in the inner product computation. + right: The second tensor, optional; if not provided, the inner product will + use `left` tensor for both arguments. + + Returns: + A tensor representing the inner product. + """ + pass + + def gradient_inner_product( + self, + left: BatchType, + right: Optional[BatchType], + gradient_provider: GradientProviderType, + ) -> TensorType: + r""" + Computes the gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the + `gradient_provider` and the expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation, + optional; if not provided, the inner product will use the gradient + computed for `left` for both arguments. + gradient_provider: The gradient provider to compute the gradients. + + Returns: + A tensor representing the inner products of the per-sample gradients + """ + left_grad = gradient_provider.per_sample_flat_gradient(left) + if right is None: + right_grad = left_grad + else: + right_grad = gradient_provider.per_sample_flat_gradient(right) + return self.inner_product(left_grad, right_grad) + + def mixed_gradient_inner_product( + self, left: BatchType, right: BatchType, gradient_provider: GradientProviderType + ) -> TensorType: + r""" + Computes the mixed gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) + \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot)$ and + $\nabla_{\omega}\nabla_{x}\ell(\omega, \cdot)$ are represented by the + `gradient_provider`. The expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation + gradient_provider: The gradient provider to compute the gradients. 
+ + Returns: + A tensor representing the inner products of the mixed per-sample gradients + """ + left_grad = gradient_provider.per_sample_flat_gradient(left) + right_mixed_grad = gradient_provider.per_sample_flat_mixed_gradient(right) + return self.inner_product(left_grad, right_mixed_grad) + + +BilinearFormType = TypeVar("BilinearFormType", bound=BilinearForm) + + +class Operator(Generic[TensorType, BilinearFormType], ABC): + """ + Abstract base class for operators, capable of applying transformations to + vectors and matrices, and can be represented as a bilinear form. + """ + + @property + @abstractmethod + def input_size(self) -> int: + """ + Abstract property to get the needed size for inputs to the operator + instance + + Returns: + An integer representing the input size. + """ + + @abstractmethod + def apply_to_vec(self, vec: TensorType) -> TensorType: + """ + Applies the operator to a vector. + + Args: + vec: A tensor representing the vector to which the operator is applied, + must conform to the operator's input size. + + Returns: + A tensor representing the result of the operator application. + """ + + @abstractmethod + def apply_to_mat(self, mat: TensorType) -> TensorType: + """ + Applies the operator to a matrix. + + Args: + mat: A tensor representing the matrix to which the operator is applied, + where the first dimension is the batch dimension and last dimension + of the matrix must conform to the operator's input size + + Returns: + A tensor representing the result of the operator application. + """ + + @abstractmethod + def as_bilinear_form(self) -> BilinearFormType: + r""" + Represents the operator as a bilinear form, i.e. the weighted inner product + + $$ \langle \operatorname{Op}(x), y \rangle$$ + + Returns: + An instance of type [BilinearForm][pydvl.influence.types.BilinearForm] + representing this operator. + """ + + +OperatorType = TypeVar("OperatorType", bound=Operator) + + +class OperatorGradientComposition( + Generic[TensorType, BatchType, OperatorType, GradientProviderType] +): + """ + Generic base class representing a composable block that integrates an operator and + a gradient provider to compute influences between batches of data. + + This block is designed to be flexible, handling different computational modes via + an abstract operator and gradient provider. + + Attributes: + op: The operator used for transformations and influence computations. + gp: The gradient provider used for obtaining necessary gradients. + """ + + def __init__(self, op: OperatorType, gp: GradientProviderType): + self.gp = gp + self.op = op + + def gradient_interaction( + self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode + ): + r""" + Computes the interaction between the gradients on two batches of data based on + the specified mode weighted by the operator action, + i.e. + + $$ \langle \operatorname{Op}(\nabla_{\omega}\ell(\omega, \text{left.x}, + \text{left.y})), + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle$$ + + for the case `InfluenceMode.Up` and + + $$ \langle \operatorname{Op}(\nabla_{\omega}\ell(\omega, \text{left.x}, + \text{left.y})), + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) \rangle $$ + + for the case `InfluenceMode.Perturbation`. + + Args: + left_batch: The left data batch for gradient computation. + right_batch: The right data batch for gradient computation. + mode: An instance of InfluenceMode determining the type of influence + computation. 
+ + Returns: + The result of the influence computation as dictated by the mode. + """ + bilinear_form = self.op.as_bilinear_form() + if mode is InfluenceMode.Up: + return bilinear_form.gradient_inner_product( + left_batch, right_batch, self.gp + ) + return bilinear_form.mixed_gradient_inner_product( + left_batch, right_batch, self.gp + ) + + def transformed_gradients(self, batch: BatchType): + r""" + Computes the gradients of a data batch, transformed by the operator application + , i.e. the expressions + + $$ \operatorname{Op}(\nabla_{\omega}\ell(\omega, \text{batch.x}, + \text{batch.y})) $$ + + Args: + batch: The data batch for gradient computation. + + Returns: + A tensor representing the application of the operator to the gradients. + + """ + grads = self.gp.per_sample_flat_gradient(batch) + return self.op.apply_to_mat(grads) + + def interaction_from_transformed_gradients( + self, left_factors: TensorType, right_batch: BatchType, mode: InfluenceMode + ): + r""" + Computes the interaction between the transformed gradients on two batches of + data using pre-computed factors and a batch of data, + based on the specified mode. This means + + $$ \langle \text{left_factors}, + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle$$ + + for the case `InfluenceMode.Up` and + + $$ \langle \text{left_factors}, + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) \rangle $$ + + for the case `InfluenceMode.Perturbation`. + + Args: + left_factors: Pre-computed tensor factors from a left batch. + right_batch: The right data batch for influence computation. + mode: An instance of InfluenceMode determining the type of influence + computation. + + Returns: + The result of the interaction computation using the provided factors and + batch gradients. + """ + if mode is InfluenceMode.Up: + right_grads = self.gp.per_sample_flat_gradient(right_batch) + else: + right_grads = self.gp.per_sample_flat_mixed_gradient(right_batch) + return self.op.as_bilinear_form().inner_product(left_factors, right_grads) + + +ComposableBlockType = TypeVar("ComposableBlockType", bound=OperatorGradientComposition) + + +class BlockMapper(Generic[TensorType, BatchType, ComposableBlockType], ABC): + """ + Abstract base class for mapping operations across multiple compositional blocks. + + This class takes a dictionary of compositional blocks and applies their methods to + batches or tensors, and aggregates the results. + + Attributes: + composable_block_dict: A dictionary mapping string identifiers to + composable blocks which define operations like transformations and + interactions. + """ + + def __init__(self, composable_block_dict: OrderedDict[str, ComposableBlockType]): + self.composable_block_dict = composable_block_dict + + def _to_ordered_dict( + self, tensor_generator: Generator[TensorType, None, None] + ) -> OrderedDict[str, TensorType]: + tensor_dict = OrderedDict() + for k, t in zip(self.composable_block_dict.keys(), tensor_generator): + tensor_dict[k] = t + return tensor_dict + + @abstractmethod + def _split_to_blocks( + self, z: TensorType, dim: int = -1 + ) -> OrderedDict[str, TensorType]: + """Must be implemented in a way to preserve the ordering defined by the + `composable_block_dict` attribute""" + + def block_transformed_gradients( + self, + batch: BatchType, + ) -> OrderedDict[str, TensorType]: + """ + Computes and returns the transformed gradients for a batch in dictionary + with the keys defined by the block names. 
+ + Args: + batch: The batch of data for which to compute transformed gradients. + + Returns: + An ordered dictionary of transformed gradients by block. + """ + tensor_gen = self.generate_transformed_gradients(batch) + return self._to_ordered_dict(tensor_gen) + + def block_interactions( + self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode + ) -> OrderedDict[str, TensorType]: + """ + Computes interactions between two batches, aggregated by block, + based on a specified mode. + + Args: + left_batch: The left batch for interaction computation. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Returns: + An ordered dictionary of gradient interactions by block. + """ + tensor_gen = self.generate_gradient_interactions(left_batch, right_batch, mode) + return self._to_ordered_dict(tensor_gen) + + def block_interactions_from_transformed_gradients( + self, + left_factors: OrderedDict[str, TensorType], + right_batch: BatchType, + mode: InfluenceMode, + ) -> OrderedDict[str, TensorType]: + """ + Computes interactions from transformed gradients and a right batch, + aggregated by block and based on a mode. + + Args: + left_factors: Pre-computed factors as a tensor or an ordered dictionary of + tensors by block. If the input is a tensor, it is split into blocks + according to the ordering in the `composable_block_dict` attribute. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Returns: + An ordered dictionary of interactions from transformed gradients by block. + """ + tensor_gen = self.generate_interactions_from_transformed_gradients( + left_factors, right_batch, mode + ) + return self._to_ordered_dict(tensor_gen) + + def generate_transformed_gradients( + self, batch: BatchType + ) -> Generator[TensorType, None, None]: + """ + Generator that yields transformed gradients for a given batch, + processed by each block. + + Args: + batch: The batch of data for which to generate transformed gradients. + + Yields: + Transformed gradients for each block. + """ + for comp_block in self.composable_block_dict.values(): + yield comp_block.transformed_gradients(batch) + + def generate_gradient_interactions( + self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode + ) -> Generator[TensorType, None, None]: + """ + Generator that yields gradient interactions between two batches, processed by + each block based on a mode. + + Args: + left_batch: The left batch for interaction computation. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Yields: + TensorType: Gradient interactions for each block. + """ + for comp_block in self.composable_block_dict.values(): + yield comp_block.gradient_interaction(left_batch, right_batch, mode) + + def generate_interactions_from_transformed_gradients( + self, + left_factors: Union[TensorType, OrderedDict[str, TensorType]], + right_batch: BatchType, + mode: InfluenceMode, + ) -> Generator[TensorType, None, None]: + """ + Generator that yields interactions computed from pre-computed factors and a + right batch, processed by each block based on a mode. + + Args: + left_factors: Pre-computed factors as a tensor or an ordered dictionary of + tensors by block. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Yields: + TensorType: Interactions for each block. 
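For intuition, a concrete `_split_to_blocks` can be as simple as splitting the flat tensor along its trailing dimension into per-block parameter counts while keeping the key order of `composable_block_dict`. A minimal sketch (the helper and its `block_sizes` argument are hypothetical):

```python
import torch
from collections import OrderedDict

def split_to_blocks(z: torch.Tensor, block_sizes: "OrderedDict[str, int]", dim: int = -1):
    # split the trailing dimension into chunks matching the block ordering
    chunks = torch.split(z, list(block_sizes.values()), dim=dim)
    return OrderedDict(zip(block_sizes.keys(), chunks))

z = torch.randn(4, 10)
blocks = split_to_blocks(z, OrderedDict([("layer_1", 6), ("layer_2", 4)]))
assert blocks["layer_1"].shape == (4, 6) and blocks["layer_2"].shape == (4, 4)
```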
+ """ + if not isinstance(left_factors, dict): + left_factors_dict = self._split_to_blocks(left_factors) + else: + left_factors_dict = cast(OrderedDict[str, TensorType], left_factors) + for k, comp_block in self.composable_block_dict.items(): + yield comp_block.interaction_from_transformed_gradients( + left_factors_dict[k], right_batch, mode + ) + + +BlockMapperType = TypeVar("BlockMapperType", bound=BlockMapper) From db0fa0714be2673c02c2b9e4ed2ced24180605eb Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 16:56:35 +0200 Subject: [PATCH 02/43] Refactor influence.base_influence_function_model: * move types to influence.types * add fit_required decorator * add class ComposableInfluence, based on generic types from influence.types * add class SumAggregator to influence.array --- src/pydvl/influence/__init__.py | 2 +- src/pydvl/influence/array.py | 22 +- .../base_influence_function_model.py | 197 +++++++++++++++--- src/pydvl/influence/influence_calculator.py | 3 +- 4 files changed, 196 insertions(+), 28 deletions(-) diff --git a/src/pydvl/influence/__init__.py b/src/pydvl/influence/__init__.py index 6065b7cf9..187c98de1 100644 --- a/src/pydvl/influence/__init__.py +++ b/src/pydvl/influence/__init__.py @@ -10,9 +10,9 @@ probably change. """ -from .base_influence_function_model import InfluenceMode from .influence_calculator import ( DaskInfluenceCalculator, DisableClientSingleThreadCheck, SequentialInfluenceCalculator, ) +from .types import InfluenceMode diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index 7e71050f9..5faa288ac 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -28,7 +28,7 @@ from zarr.storage import StoreLike from ..utils import log_duration -from .base_influence_function_model import TensorType +from .types import TensorType class NumpyConverter(Generic[TensorType], ABC): @@ -400,3 +400,23 @@ def _initialize_zarr_array( chunks=chunk_size, dtype=block.dtype, ) + + +class SumAggregator(SequenceAggregator): + def __call__(self, tensor_sequence: LazyChunkSequence): + """ + Aggregates tensors from a single-level generator by summing up. This method simply + collects each tensor emitted by the generator into a single list. + + Args: + tensor_sequence: Object wrapping a generator that yields `TensorType` + objects. + + Returns: + A single tensor representing the sum of all tensors from the generator. + """ + tensor_generator = tensor_sequence.generator_factory() + result = next(tensor_generator) + for tensor in tensor_generator: + result += tensor + return result diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 541fbedf0..b6854c250 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -1,26 +1,14 @@ from __future__ import annotations +import logging from abc import ABC, abstractmethod -from enum import Enum -from typing import Collection, Generic, Iterable, Optional, Type, TypeVar +from collections import OrderedDict +from functools import partial, wraps +from typing import Generic, Optional, Type -__all__ = ["InfluenceMode"] - - -class InfluenceMode(str, Enum): - """ - Enum representation for the types of influence. 
- - Attributes: - Up: [Approximating the influence of a point] - [approximating-the-influence-of-a-point] - Perturbation: [Perturbation definition of the influence score] - [perturbation-definition-of-the-influence-score] - - """ - - Up = "up" - Perturbation = "perturbation" +from ..utils.progress import log_duration +from .array import LazyChunkSequence, SumAggregator +from .types import BatchType, BlockMapperType, DataLoaderType, InfluenceMode, TensorType class UnsupportedInfluenceModeException(ValueError): @@ -46,11 +34,6 @@ def __init__(self, module_id: str): super().__init__(message) -"""Type variable for tensors, i.e. sequences of numbers""" -TensorType = TypeVar("TensorType", bound=Collection) -DataLoaderType = TypeVar("DataLoaderType", bound=Iterable) - - class InfluenceFunctionModel(Generic[TensorType, DataLoaderType], ABC): """ Generic abstract base class for computing influence related quantities. @@ -86,6 +69,18 @@ def fit(self, data: DataLoaderType) -> InfluenceFunctionModel: The fitted instance """ + @staticmethod + def fit_required(method): + """Decorator to enforce the fitted check""" + + @wraps(method) + def wrapper(self, *args, **kwargs): + if not self.is_fitted: + raise NotFittedException(type(self)) + return method(self, *args, **kwargs) + + return wrapper + def influence_factors(self, x: TensorType, y: TensorType) -> TensorType: if not self.is_fitted: raise NotFittedException(type(self)) @@ -119,6 +114,19 @@ def influences( ) -> TensorType: if not self.is_fitted: raise NotFittedException(type(self)) + + if x is None and y is not None: + raise ValueError( + "Providing labels y, without providing model input x " + "is not supported" + ) + + if x is not None and y is None: + raise ValueError( + "Providing model input x, without providing labels y " + "is not supported" + ) + return self._influences(x_test, y_test, x, y, mode) @abstractmethod @@ -199,3 +207,144 @@ def influences_from_factors( Tensor representing the element-wise scalar products for the provided batch """ + + +class ComposableInfluence( + InfluenceFunctionModel, + Generic[TensorType, BatchType, DataLoaderType, BlockMapperType], + ABC, +): + + block_mapper: BlockMapperType + + @property + def n_parameters(self): + return super().n_parameters() + + @property + def is_thread_safe(self) -> bool: + return False + + @property + def is_fitted(self): + try: + return self.block_mapper is not None + except AttributeError: + return False + + @log_duration(log_level=logging.INFO) + def fit(self, data: DataLoaderType) -> InfluenceFunctionModel: + self.block_mapper = self._create_block_mapper(data) + return self + + @abstractmethod + def _create_block_mapper(self, data: DataLoaderType) -> BlockMapperType: + pass + + @InfluenceFunctionModel.fit_required + def influences_by_block( + self, + x_test: TensorType, + y_test: TensorType, + x: Optional[TensorType] = None, + y: Optional[TensorType] = None, + mode: InfluenceMode = InfluenceMode.Up, + ) -> OrderedDict[str, TensorType]: + left_batch = self._create_batch(x_test, y_test) + + if x is None: + if y is not None: + raise ValueError( + "Providing labels y, without providing model input x " + "is not supported" + ) + right_batch = left_batch + else: + if y is None: + raise ValueError( + "Providing model input x, without providing labels y " + "is not supported" + ) + right_batch = self._create_batch(x, y) + + return self.block_mapper.block_interactions(left_batch, right_batch, mode) + + @InfluenceFunctionModel.fit_required + def influence_factors_by_block( + self, x: 
TensorType, y: TensorType + ) -> OrderedDict[str, TensorType]: + return self.block_mapper.block_transformed_gradients(self._create_batch(x, y)) + + @InfluenceFunctionModel.fit_required + def influences_from_factors_by_block( + self, + z_test_factors: OrderedDict[str, TensorType], + x: TensorType, + y: TensorType, + mode: InfluenceMode = InfluenceMode.Up, + ) -> OrderedDict[str, TensorType]: + return self.block_mapper.block_interactions_from_transformed_gradients( + z_test_factors, self._create_batch(x, y), mode + ) + + def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType: + tensor_gen_factory = partial( + self.block_mapper.generate_transformed_gradients, self._create_batch(x, y) + ) + aggregator = SumAggregator() + result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) + return result + + def _influences( + self, + x_test: TensorType, + y_test: TensorType, + x: Optional[TensorType] = None, + y: Optional[TensorType] = None, + mode: InfluenceMode = InfluenceMode.Up, + ) -> TensorType: + left_batch = self._create_batch(x_test, y_test) + + if x is None: + right_batch = left_batch + elif y is None: + raise ValueError( + "Providing model input x, without providing labels y " + "is not supported" + ) + else: + right_batch = self._create_batch(x, y) + + tensor_gen_factory = partial( + self.block_mapper.generate_gradient_interactions, + left_batch, + right_batch, + mode, + ) + aggregator = SumAggregator() + result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) + return result + + @InfluenceFunctionModel.fit_required + def influences_from_factors( + self, + z_test_factors: TensorType, + x: TensorType, + y: TensorType, + mode: InfluenceMode = InfluenceMode.Up, + ) -> TensorType: + tensor_gen_factory = partial( + self.block_mapper.generate_interactions_from_transformed_gradients, + z_test_factors, + self._create_batch(x, y), + mode, + ) + + aggregator = SumAggregator() + result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) + return result + + @staticmethod + @abstractmethod + def _create_batch(x: TensorType, y: TensorType) -> BatchType: + pass diff --git a/src/pydvl/influence/influence_calculator.py b/src/pydvl/influence/influence_calculator.py index 7c48e8636..327a4137d 100644 --- a/src/pydvl/influence/influence_calculator.py +++ b/src/pydvl/influence/influence_calculator.py @@ -18,10 +18,9 @@ from .array import LazyChunkSequence, NestedLazyChunkSequence, NumpyConverter from .base_influence_function_model import ( InfluenceFunctionModel, - InfluenceMode, - TensorType, UnsupportedInfluenceModeException, ) +from .types import InfluenceMode, TensorType __all__ = [ "DaskInfluenceCalculator", From c4b7d3a54db298542f9864661cf3747a4663485e Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 20:54:28 +0200 Subject: [PATCH 03/43] Fix linting issues in influence.types --- src/pydvl/influence/types.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index 093518275..e9adafc1a 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -1,18 +1,19 @@ from __future__ import annotations + from abc import ABC, abstractmethod +from collections import OrderedDict from dataclasses import dataclass from enum import Enum -from collections import OrderedDict from typing import ( - TypeVar, - Iterable, + Collection, + Dict, + Generator, Generic, + Iterable, Optional, - Generator, + TypeVar, Union, - Collection, cast, - 
Dict, ) From c18984833bcffe814cc929865e2c6667e6e48c5f Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 20:56:37 +0200 Subject: [PATCH 04/43] Add class ModelInfoMixin to influence.torch.util --- src/pydvl/influence/torch/util.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index d157d5455..db33a8ee6 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -598,3 +598,22 @@ def __init__(self, original_exception: RuntimeError): f" Inspect the original exception message: \n{str(original_exception)}" ) super().__init__(err_msg) +class ModelInfoMixin: + """ + A mixin class for classes that contain information about a model. + """ + + def __init__(self, model: torch.nn.Module): + self.model = model + + @property + def device(self) -> torch.device: + return next(self.model.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.model.parameters()).dtype + + @property + def n_parameters(self) -> int: + return sum(p.numel() for p in self.model.parameters() if p.requires_grad) From bda35043c194b21a42788772438a1256740cac4a Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 20:58:27 +0200 Subject: [PATCH 05/43] Add class TorchBatch to influence.torch.util --- src/pydvl/influence/torch/util.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index db33a8ee6..45b3c38ca 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -32,6 +32,7 @@ NumpyConverter, SequenceAggregator, ) +from ..types import Batch logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ "torch_dataset_to_dask_array", "EkfacRepresentation", "empirical_cross_entropy_loss_fn", + "TorchBatch", ] @@ -598,6 +600,34 @@ def __init__(self, original_exception: RuntimeError): f" Inspect the original exception message: \n{str(original_exception)}" ) super().__init__(err_msg) +@dataclass(frozen=True) +class TorchBatch(Batch): + """ + A convenience class for handling batches of data. Validates, the alignment + of the first dimension (batch dimension) of the input and target tensor + + Attributes: + x: The input tensor that contains features or data points. + y: The target tensor that contains labels corresponding to the inputs. + + """ + + x: torch.Tensor + y: torch.Tensor + + def __post_init__(self): + if self.x.shape[0] != self.y.shape[0]: + raise ValueError( + f"The first dimension of x and y must be the same, " + f"got {self.x.shape[0]} and {self.y.shape[0]}" + ) + + def __len__(self): + return self.x.shape[0] + + def to(self, device: torch.device): + return TorchBatch(self.x.to(device), self.y.to(device)) + class ModelInfoMixin: """ A mixin class for classes that contain information about a model. 
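A short usage sketch of the added `TorchBatch` (assuming the import path introduced by this patch):

```python
import torch
from pydvl.influence.torch.util import TorchBatch

batch = TorchBatch(torch.randn(8, 5), torch.randn(8, 3))
assert len(batch) == 8
batch = batch.to(torch.device("cpu"))  # returns a new TorchBatch on the target device

# mismatched batch dimensions are rejected in __post_init__
try:
    TorchBatch(torch.randn(8, 5), torch.randn(4, 3))
except ValueError as e:
    print(e)
```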
From ddec527647c869b0bbfc90ad0ad6047b6ffcb3f8 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 20:59:54 +0200 Subject: [PATCH 06/43] Add classes TorchChunkAverageAggregator and TorchPointAverageAggregator to influence.torch.util --- src/pydvl/influence/torch/util.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index 45b3c38ca..c1e063965 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -438,6 +438,33 @@ def __call__( return torch.cat(list(t_gen)) +class TorchChunkAverageAggregator(SequenceAggregator[torch.Tensor]): + def __call__(self, tensor_sequence: LazyChunkSequence): + t_gen = tensor_sequence.generator_factory() + result = next(t_gen) + n_chunks = 1 + for t in t_gen: + result += t + n_chunks += 1 + return result / n_chunks + + +class TorchPointAverageAggregator(SequenceAggregator[torch.Tensor]): + def __init__(self, batch_dim: int = 0, weighted: bool = True): + self.weighted = weighted + self.batch_dim = batch_dim + + def __call__(self, tensor_sequence: LazyChunkSequence): + tensor_generator = tensor_sequence.generator_factory() + result = next(tensor_generator) + n_points = result.shape[self.batch_dim] + for tensor in tensor_generator: + n_points_in_batch = tensor.shape[self.batch_dim] + result += n_points_in_batch * tensor if self.weighted else tensor + n_points += n_points_in_batch + return result / n_points + + class NestedTorchCatAggregator(NestedSequenceAggregator[torch.Tensor]): """ An aggregator that concatenates tensors using PyTorch's [torch.cat][torch.cat] From 74728406619504c146f8e9e4a3a51cc21f44008a Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 21:01:01 +0200 Subject: [PATCH 07/43] Add functions rank_one_mvp and inverse_rank_one_update to influence.torch.util --- src/pydvl/influence/torch/util.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index c1e063965..e044dddd9 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -627,6 +627,36 @@ def __init__(self, original_exception: RuntimeError): f" Inspect the original exception message: \n{str(original_exception)}" ) super().__init__(err_msg) + + +def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + r""" + Computes the matrix-vector product of xx^T and v for each row in X and V without + forming xx^T and sums the result. Here, X and V are matrices where each row + represents an individual vector. Effectively it is computing + + $$ V@(\sum_i^N x[i]x[i]^T) $$ + + Args: + x: Matrix of vectors of size `(N, M)`. + v: Matrix of vectors of size `(B, M)` to be multiplied by the corresponding + $xx^T$. + + Returns: + A matrix of size `(B, N)` where each column is the result of xx^T v for + corresponding rows in x and v. 
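Concretely, for `x` of shape $(N, M)$ and `v` of shape $(B, M)$ the product $V (\sum_i x_i x_i^T)$ has shape $(B, M)$. A quick numerical check of the identity behind the einsum formulation used here (illustrative sketch, outside the patch):

```python
import torch

N, M, B = 8, 5, 3
x, v = torch.randn(N, M), torch.randn(B, M)

expected = v @ sum(xi.outer(xi) for xi in x)  # explicit sum of outer products, forms the (M, M) matrix
actual = torch.einsum("ij,kj->ki", x, v) @ x  # the matrix-free formulation

assert actual.shape == (B, M)
assert torch.allclose(actual, expected, atol=1e-5)
```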
+ """ + return torch.einsum("ij,kj->ki", x, v) @ x + + +def inverse_rank_one_update( + x: torch.Tensor, v: torch.Tensor, regularization: float +) -> torch.Tensor: + nominator = torch.einsum("ij,kj->ki", x, v) + denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) + return (v - (nominator / denominator) @ x) / regularization + + @dataclass(frozen=True) class TorchBatch(Batch): """ From e32ae46cb7720a63cc3849c7df491ea73ecf2ecb Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 21:02:37 +0200 Subject: [PATCH 08/43] Add LossType (annotation), enum BlockMode and class ModelParameterDictBuilder to influence.torch.util --- src/pydvl/influence/torch/util.py | 63 +++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index e044dddd9..0d6a55f99 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -1,8 +1,13 @@ +from __future__ import annotations + import logging import math +from collections import OrderedDict from dataclasses import dataclass +from enum import Enum from functools import partial from typing import ( + Callable, Collection, Dict, Iterable, @@ -48,7 +53,14 @@ "torch_dataset_to_dask_array", "EkfacRepresentation", "empirical_cross_entropy_loss_fn", + "rank_one_mvp", + "inverse_rank_one_update", + "TorchPointAverageAggregator", + "TorchChunkAverageAggregator", "TorchBatch", + "LossType", + "ModelParameterDictBuilder", + "BlockMode", ] @@ -657,6 +669,9 @@ def inverse_rank_one_update( return (v - (nominator / denominator) @ x) / regularization +LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + @dataclass(frozen=True) class TorchBatch(Batch): """ @@ -685,6 +700,54 @@ def __len__(self): def to(self, device: torch.device): return TorchBatch(self.x.to(device), self.y.to(device)) + +class BlockMode(Enum): + LAYER_WISE: str = "layer_wise" + PARAMETER_WISE: str = "parameter_wise" + FULL: str = "full" + + +@dataclass +class ModelParameterDictBuilder: + model: torch.nn.Module + detach: bool = True + + def _optional_detach(self, p: torch.nn.Parameter): + if self.detach: + return p.detach() + return p + + def build( + self, block_mode: BlockMode + ) -> OrderedDict[str, OrderedDict[str, torch.nn.Parameter]]: + parameter_dict = OrderedDict() + + if block_mode is BlockMode.FULL: + inner_ordered_dict = OrderedDict() + for k, v in self.model.named_parameters(): + if v.requires_grad: + inner_ordered_dict[k] = self._optional_detach(v) + parameter_dict[""] = inner_ordered_dict + + elif block_mode is BlockMode.PARAMETER_WISE: + for k, v in self.model.named_parameters(): + if v.requires_grad: + parameter_dict[k] = OrderedDict({k: self._optional_detach(v)}) + + if block_mode is BlockMode.LAYER_WISE: + for name, submodule in self.model.named_children(): + inner_ordered_dict = OrderedDict() + for param_name, param in submodule.named_parameters(): + if param.requires_grad: + inner_ordered_dict[ + f"{name}.{param_name}" + ] = self._optional_detach(param) + if inner_ordered_dict: + parameter_dict[name] = inner_ordered_dict + + return parameter_dict + + class ModelInfoMixin: """ A mixin class for classes that contain information about a model. 
From d804e6b8879174e35b53ec5639a252b73c223bbc Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 21:03:48 +0200 Subject: [PATCH 09/43] Add dtype property to LowRankProductRepresentation --- src/pydvl/influence/torch/functional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pydvl/influence/torch/functional.py b/src/pydvl/influence/torch/functional.py index 1028b6acd..6db6f12fa 100644 --- a/src/pydvl/influence/torch/functional.py +++ b/src/pydvl/influence/torch/functional.py @@ -632,6 +632,10 @@ def device(self) -> torch.device: else torch.device("cpu") ) + @property + def dtype(self) -> torch.dtype: + return self.projections.dtype + def to(self, device: torch.device): """ Move the representing tensors to a device From 993bcfa67fb9841e07bcb5a672de6da03256c7f2 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 21:11:53 +0200 Subject: [PATCH 10/43] Add new subpackage influence.torch.operator with submodule operator.gradient_provider --- .../influence/torch/operator/__init__.py | 0 .../torch/operator/gradient_provider.py | 161 ++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 src/pydvl/influence/torch/operator/__init__.py create mode 100644 src/pydvl/influence/torch/operator/gradient_provider.py diff --git a/src/pydvl/influence/torch/operator/__init__.py b/src/pydvl/influence/torch/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pydvl/influence/torch/operator/gradient_provider.py b/src/pydvl/influence/torch/operator/gradient_provider.py new file mode 100644 index 000000000..5b06496bb --- /dev/null +++ b/src/pydvl/influence/torch/operator/gradient_provider.py @@ -0,0 +1,161 @@ +from abc import ABC, abstractmethod +from typing import Dict, Callable, Optional + +import torch +from torch.func import functional_call + +from ..functional import ( + create_per_sample_gradient_function, + create_per_sample_mixed_derivative_function, + create_matrix_jacobian_product_function, +) + +from ..util import ( + flatten_dimensions, + LossType, + TorchBatch, + ModelParameterDictBuilder, + BlockMode, +) + +from ...types import PerSampleGradientProvider + + +class TorchPerSampleGradientProvider( + PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC +): + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]], + ): + self.loss = loss + self.model = model + + if restrict_to is None: + restrict_to = ModelParameterDictBuilder(model).build(BlockMode.FULL) + + self.params_to_restrict_to = restrict_to + + def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + @property + def device(self): + return next(self.model.parameters()).device + + @property + def dtype(self): + return next(self.model.parameters()).dtype + + @abstractmethod + def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + pass + + @abstractmethod + def _per_sample_mixed_gradient_dict( + self, batch: TorchBatch + ) -> Dict[str, torch.Tensor]: + pass + + @abstractmethod + def _matrix_jacobian_product( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + pass + + @staticmethod + def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} + + def 
per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + gradient_dict = self._per_sample_gradient_dict(batch.to(self.device)) + return self._detach_dict(gradient_dict) + + def per_sample_mixed_gradient_dict( + self, batch: TorchBatch + ) -> Dict[str, torch.Tensor]: + gradient_dict = self._per_sample_mixed_gradient_dict(batch.to(self.device)) + return self._detach_dict(gradient_dict) + + def matrix_jacobian_product( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + result = self._matrix_jacobian_product(batch.to(self.device), g.to(self.device)) + if result.requires_grad: + result = result.detach() + return result + + def per_sample_flat_gradient(self, batch: TorchBatch) -> torch.Tensor: + return flatten_dimensions( + self.per_sample_gradient_dict(batch).values(), shape=(batch.x.shape[0], -1) + ) + + def per_sample_flat_mixed_gradient(self, batch: TorchBatch) -> torch.Tensor: + shape = (*batch.x.shape, -1) + return flatten_dimensions( + self.per_sample_mixed_gradient_dict(batch).values(), shape=shape + ) + + +class TorchPerSampleAutoGrad(TorchPerSampleGradientProvider): + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + super().__init__(model, loss, restrict_to) + self._per_sample_gradient_function = create_per_sample_gradient_function( + model, loss + ) + self._per_sample_mixed_gradient_func = ( + create_per_sample_mixed_derivative_function(model, loss) + ) + + def _compute_loss( + self, params: Dict[str, torch.Tensor], x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) + return self.loss(outputs, y.unsqueeze(0)) + + def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + return self._per_sample_gradient_function( + self.params_to_restrict_to, batch.x, batch.y + ) + + def _per_sample_mixed_gradient_dict( + self, batch: TorchBatch + ) -> Dict[str, torch.Tensor]: + return self._per_sample_mixed_gradient_func( + self.params_to_restrict_to, batch.x, batch.y + ) + + def _matrix_jacobian_product( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + matrix_jacobian_product_func = create_matrix_jacobian_product_function( + self.model, self.loss, g + ) + return matrix_jacobian_product_func( + self.params_to_restrict_to, batch.x, batch.y + ) + + +GradientProviderFactoryType = Callable[ + [torch.nn.Module, LossType, Optional[Dict[str, torch.nn.Parameter]]], + TorchPerSampleGradientProvider, +] From 3292253591a4d00f0a4af90a9af7ec52b04872cc Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 22:36:40 +0200 Subject: [PATCH 11/43] Add operator subpackages batch_operation, add tests for bnatch_operation and gradient_provider --- .../torch/operator/batch_operation.py | 202 ++++++++++++++++++ .../torch/operator/gradient_provider.py | 14 +- tests/influence/torch/operator/__init__.py | 0 .../torch/operator/test_batch_operation.py | 58 +++++ .../torch/operator/test_gradient_provider.py | 103 +++++++++ 5 files changed, 369 insertions(+), 8 deletions(-) create mode 100644 src/pydvl/influence/torch/operator/batch_operation.py create mode 100644 tests/influence/torch/operator/__init__.py create mode 100644 tests/influence/torch/operator/test_batch_operation.py create mode 100644 tests/influence/torch/operator/test_gradient_provider.py diff --git a/src/pydvl/influence/torch/operator/batch_operation.py 
b/src/pydvl/influence/torch/operator/batch_operation.py new file mode 100644 index 000000000..45dcb8682 --- /dev/null +++ b/src/pydvl/influence/torch/operator/batch_operation.py @@ -0,0 +1,202 @@ +from abc import ABC, abstractmethod +from typing import Callable, Optional, Dict, Union, Type + +import torch + +from ..functional import create_batch_hvp_function +from ..util import ( + inverse_rank_one_update, + rank_one_mvp, + LossType, + TorchBatch, +) + +from .gradient_provider import ( + TorchPerSampleGradientProvider, + TorchPerSampleAutoGrad, + GradientProviderFactoryType, +) + + +class BatchOperation(ABC): + def __init__(self, regularization: float = 0.0): + if regularization < 0: + raise ValueError("regularization must be non-negative") + self._regularization = regularization + + @property + @abstractmethod + def n_parameters(self): + pass + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value < 0: + raise ValueError("regularization must be non-negative") + self._regularization = value + + @property + @abstractmethod + def device(self): + pass + + @property + @abstractmethod + def dtype(self): + pass + + @abstractmethod + def to(self, device: torch.device): + pass + + @abstractmethod + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + pass + + def apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor): + return self._apply_to_vec(batch.to(self.device), vec.to(self.device)) + + def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + return torch.func.vmap( + lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m), + in_dims=(None, None, 0), + randomness="same", + )(batch.x, batch.y, mat) + + +class ModelBasedBatchOperation(BatchOperation, ABC): + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + regularization: float = 0.0, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + super().__init__(regularization) + if restrict_to is None: + restrict_to = { + k: p.detach() for k, p in model.named_parameters() if p.requires_grad + } + self.params_to_restrict_to = restrict_to + self.loss = loss + self.model = model + + @property + def device(self): + return next(self.model.parameters()).device + + @property + def dtype(self): + return next(self.model.parameters()).dtype + + @property + def n_parameters(self): + return sum(p.numel() for p in self.params_to_restrict_to.values()) + + def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + +class HessianBatchOperation(ModelBasedBatchOperation): + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + regularization: float = 0.0, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + reverse_only: bool = True, + ): + super().__init__( + model, loss, regularization=regularization, restrict_to=restrict_to + ) + self._batch_hvp = create_batch_hvp_function( + model, loss, reverse_only=reverse_only + ) + + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) + + +class GaussNewtonBatchOperation(ModelBasedBatchOperation): + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + regularization: float = 0.0, + gradient_provider_factory: Union[ + 
GradientProviderFactoryType, + Type[TorchPerSampleGradientProvider], + ] = TorchPerSampleAutoGrad, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + super().__init__( + model, loss, regularization=regularization, restrict_to=restrict_to + ) + self.gradient_provider = gradient_provider_factory( + model, loss, self.params_to_restrict_to + ) + + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + flat_grads = self.gradient_provider.per_sample_flat_gradient(batch) + result = rank_one_mvp(flat_grads, vec) + + if self.regularization > 0.0: + result += self.regularization * vec + + return result + + def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + return self.apply_to_vec(batch, mat) + + def to(self, device: torch.device): + self.gradient_provider = self.gradient_provider.to(device) + return super().to(device) + + +class InverseHarmonicMeanBatchOperation(ModelBasedBatchOperation): + def __init__( + self, + model: torch.nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + regularization: float, + gradient_provider_factory: Union[ + GradientProviderFactoryType, + Type[TorchPerSampleGradientProvider], + ] = TorchPerSampleAutoGrad, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + if regularization <= 0: + raise ValueError("regularization must be positive") + + super().__init__( + model, loss, regularization=regularization, restrict_to=restrict_to + ) + self.regularization = regularization + self.gradient_provider = gradient_provider_factory( + model, loss, self.params_to_restrict_to + ) + + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + grads = self.gradient_provider.per_sample_flat_gradient(batch) + return ( + inverse_rank_one_update(grads, vec, self.regularization) + / self.regularization + ) + + def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + return self.apply_to_vec(batch, mat) + + def to(self, device: torch.device): + super().to(device) + self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to + return self diff --git a/src/pydvl/influence/torch/operator/gradient_provider.py b/src/pydvl/influence/torch/operator/gradient_provider.py index 5b06496bb..76ca84010 100644 --- a/src/pydvl/influence/torch/operator/gradient_provider.py +++ b/src/pydvl/influence/torch/operator/gradient_provider.py @@ -1,25 +1,23 @@ from abc import ABC, abstractmethod -from typing import Dict, Callable, Optional +from typing import Callable, Dict, Optional import torch from torch.func import functional_call +from ...types import PerSampleGradientProvider from ..functional import ( + create_matrix_jacobian_product_function, create_per_sample_gradient_function, create_per_sample_mixed_derivative_function, - create_matrix_jacobian_product_function, ) - from ..util import ( - flatten_dimensions, + BlockMode, LossType, - TorchBatch, ModelParameterDictBuilder, - BlockMode, + TorchBatch, + flatten_dimensions, ) -from ...types import PerSampleGradientProvider - class TorchPerSampleGradientProvider( PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC diff --git a/tests/influence/torch/operator/__init__.py b/tests/influence/torch/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/influence/torch/operator/test_batch_operation.py b/tests/influence/torch/operator/test_batch_operation.py new file mode 100644 index 000000000..7c77c6b19 --- /dev/null +++ 
b/tests/influence/torch/operator/test_batch_operation.py @@ -0,0 +1,58 @@ +import pytest +import torch +from dataclasses import astuple + +from pydvl.influence.torch.operator.batch_operation import (HessianBatchOperation, + GaussNewtonBatchOperation) +from pydvl.influence.torch.util import TorchBatch + +from ..test_util import model_data, test_parameters + + +@pytest.mark.torch +@pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-5) for tp in test_parameters], + indirect=["model_data"], +) +def test_hessian_batch_operation(model_data, tol: float): + torch_model, x, y, vec, h_analytical = model_data + + params = dict(torch_model.named_parameters()) + + hessian_op = HessianBatchOperation(torch_model, torch.nn.functional.mse_loss, + restrict_to=params) + hvp_autograd = hessian_op.apply_to_vec(TorchBatch(x, y), vec) + + assert torch.allclose(hvp_autograd, h_analytical @ vec, rtol=tol) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-3) for tp in test_parameters], + indirect=["model_data"], +) +def test_gauss_newton_batch_operation(model_data, tol: float): + torch_model, x, y, vec, _ = model_data + + y_pred = torch_model(x) + out_features = y_pred.shape[1] + dl_dw = torch.vmap( + lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1) + )(x, y_pred, y) + dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) + grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1) + gn_mat_analytical = torch.sum( + torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t()) + (grad_analytical), dim=0) + + params = dict(torch_model.named_parameters()) + + gn_op = GaussNewtonBatchOperation(torch_model, torch.nn.functional.mse_loss, + restrict_to=params) + gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec) + + gn_analytical = gn_mat_analytical @ vec + + assert torch.allclose(gn_autograd, gn_analytical, atol=1e-5, rtol=tol) diff --git a/tests/influence/torch/operator/test_gradient_provider.py b/tests/influence/torch/operator/test_gradient_provider.py new file mode 100644 index 000000000..051f98d74 --- /dev/null +++ b/tests/influence/torch/operator/test_gradient_provider.py @@ -0,0 +1,103 @@ +import torch +import pytest +import numpy as np + +from pydvl.influence.torch.operator.gradient_provider import TorchPerSampleAutoGrad +from pydvl.influence.torch.util import TorchBatch + +from ...conftest import ( + linear_mixed_second_derivative_analytical, + linear_model, +) + +from ..conftest import DATA_OUTPUT_NOISE, linear_mvp_model + + +class TestTorchPerSampleAutograd: + + @pytest.mark.torch + @pytest.mark.parametrize( + "in_features, out_features, batch_size", + [(46, 6, 632), (50, 3, 120), (100, 5, 120), (25, 10, 550)], + ) + def test_per_sample_gradient(self, in_features, out_features, batch_size): + model = torch.nn.Linear(in_features, out_features) + loss = torch.nn.functional.mse_loss + + x = torch.randn(batch_size, in_features, requires_grad=True) + y = torch.randn(batch_size, out_features) + params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad} + + gp = TorchPerSampleAutoGrad(model, loss, restrict_to=params) + gradients = gp.per_sample_gradient_dict(TorchBatch(x, y)) + flat_gradients = gp.per_sample_flat_gradient(TorchBatch(x, y)) + + # Compute analytical gradients + y_pred = model(x) + dL_dw = torch.vmap( + lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1) + )(x, y_pred, y) + dL_db = 
torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) + + # Assert the gradient values for equality with analytical gradients + assert torch.allclose(gradients["weight"], dL_dw, atol=1e-5) + assert torch.allclose(gradients["bias"], dL_db, atol=1e-5) + assert torch.allclose(flat_gradients, torch.cat([dL_dw.reshape(batch_size, -1), dL_db], dim=-1), atol=1e-5) + + @pytest.mark.torch + @pytest.mark.parametrize( + "in_features, out_features, train_set_size", + [(46, 1, 1000), (50, 3, 100), (100, 5, 512), (25, 10, 734)], + ) + def test_mixed_derivatives(self, in_features, out_features, train_set_size): + A, b = linear_model((out_features, in_features), 5) + loss = torch.nn.functional.mse_loss + model = linear_mvp_model(A, b) + + data_model = lambda x: np.random.normal(x @ A.T + b, DATA_OUTPUT_NOISE) + train_x = np.random.uniform(size=[train_set_size, in_features]) + train_y = data_model(train_x) + + params = {k: p for k, p in model.named_parameters() if p.requires_grad} + + test_derivative = linear_mixed_second_derivative_analytical( + (A, b), + train_x, + train_y, + ) + + torch_train_x = torch.as_tensor(train_x) + torch_train_y = torch.as_tensor(train_y) + gp = TorchPerSampleAutoGrad(model, loss, restrict_to=params) + flat_functorch_mixed_derivatives = gp.per_sample_flat_mixed_gradient(TorchBatch(torch_train_x, torch_train_y)) + assert torch.allclose( + torch.as_tensor(test_derivative), + flat_functorch_mixed_derivatives.transpose(2, 1), + ) + + @pytest.mark.torch + @pytest.mark.parametrize( + "in_features, out_features, batch_size", + [(46, 1, 632), (50, 3, 120), (100, 5, 110), (25, 10, 500)], + ) + def test_matrix_jacobian_product(self, in_features, out_features, batch_size, pytorch_seed): + model = torch.nn.Linear(in_features, out_features) + params = {k: p for k, p in model.named_parameters() if p.requires_grad} + + x = torch.randn(batch_size, in_features, requires_grad=True) + y = torch.randn(batch_size, out_features, requires_grad=True) + y_pred = model(x) + + gp = TorchPerSampleAutoGrad(model, torch.nn.functional.mse_loss, restrict_to=params) + + G = torch.randn((10, out_features * (in_features + 1))) + mjp = gp.matrix_jacobian_product(TorchBatch(x, y), G) + + dL_dw = torch.vmap( + lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1) + )(x, y_pred, y) + dL_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) + analytic_grads = torch.cat([dL_dw.reshape(dL_dw.shape[0], -1), dL_db], dim=1) + analytical_mjp = G @ analytic_grads.T + + assert torch.allclose(analytical_mjp, mjp, atol=1e-5, rtol=1e-3) \ No newline at end of file From ec572828b60625e6460f60134a80c8a0905cd8b0 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 22:43:20 +0200 Subject: [PATCH 12/43] Add submodules operator.base, operator.bilinear_form and operator.solve --- src/pydvl/influence/torch/operator/base.py | 178 ++++++++++++++++++ .../influence/torch/operator/bilinear_form.py | 36 ++++ src/pydvl/influence/torch/operator/solve.py | 39 ++++ 3 files changed, 253 insertions(+) create mode 100644 src/pydvl/influence/torch/operator/base.py create mode 100644 src/pydvl/influence/torch/operator/bilinear_form.py create mode 100644 src/pydvl/influence/torch/operator/solve.py diff --git a/src/pydvl/influence/torch/operator/base.py b/src/pydvl/influence/torch/operator/base.py new file mode 100644 index 000000000..06480e36b --- /dev/null +++ b/src/pydvl/influence/torch/operator/base.py @@ -0,0 +1,178 @@ +from abc import ABC, abstractmethod +from typing 
import Callable, Generator, Union, Type, Optional, Dict + +import torch +from torch import nn as nn +from torch.utils.data import DataLoader + +from ...array import SequenceAggregator, LazyChunkSequence +from ..util import TorchPointAverageAggregator, TorchChunkAverageAggregator, TorchBatch + +from .batch_operation import ( + BatchOperation, + GaussNewtonBatchOperation, + HessianBatchOperation, +) + +from .bilinear_form import OperatorBilinearForm +from .gradient_provider import ( + TorchPerSampleGradientProvider, + GradientProviderFactoryType, + TorchPerSampleAutoGrad, +) + +from ...types import Operator + + +class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): + def __init__(self, regularization: float = 0.0): + """ + Initializes the Operator with an optional regularization parameter. + + Args: + regularization: A non-negative float that represents the regularization + strength (default is 0.0). + + Raises: + ValueError: If the regularization parameter is negative. + """ + if regularization < 0: + raise ValueError("regularization must be non-negative") + self._regularization = regularization + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value < 0: + raise ValueError("regularization must be non-negative") + self._regularization = value + + @property + @abstractmethod + def device(self): + pass + + @property + @abstractmethod + def dtype(self): + pass + + @abstractmethod + def to(self, device: torch.device): + pass + + @abstractmethod + def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + pass + + def as_bilinear_form(self): + return OperatorBilinearForm(self) + + def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + return self._apply_to_vec(vec.to(self.device)) + + def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) + + +class AggregateBatchOperator(TorchOperator): + def __init__( + self, + batch_operation: BatchOperation, + dataloader: DataLoader, + aggregator: SequenceAggregator[torch.Tensor], + ): + self.batch_operation = batch_operation + self.dataloader = dataloader + self.aggregator = aggregator + super().__init__(self.batch_operation.regularization) + + @property + def device(self): + return self.batch_operation.device + + @property + def dtype(self): + return self.batch_operation.dtype + + def to(self, device: torch.device): + self.batch_operation = self.batch_operation.to(device) + return self + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + self._regularization = value + self.batch_operation.regularization = value + + @property + def input_size(self): + return self.batch_operation.n_parameters + + def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + return self._apply(vec, self.batch_operation.apply_to_vec) + + def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + return self._apply(mat, self.batch_operation.apply_to_mat) + + def _apply( + self, + z: torch.Tensor, + batch_ops: Callable[[TorchBatch, torch.Tensor], torch.Tensor], + ): + def tensor_gen_factory() -> Generator[torch.Tensor, None, None]: + return ( + batch_ops( + TorchBatch(x.to(self.device), y.to(self.device)), z.to(self.device) + ) + for x, y in self.dataloader + ) + + lazy_tensor_sequence = LazyChunkSequence( + tensor_gen_factory, len_generator=len(self.dataloader) + ) + return 
self.aggregator(lazy_tensor_sequence) + + +class GaussNewtonOperator(AggregateBatchOperator): + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + gradient_provider_factory: Union[ + GradientProviderFactoryType, + Type[TorchPerSampleGradientProvider], + ] = TorchPerSampleAutoGrad, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + ): + batch_op = GaussNewtonBatchOperation( + model, + loss, + gradient_provider_factory=gradient_provider_factory, + restrict_to=restrict_to, + ) + aggregator = TorchPointAverageAggregator() + super().__init__(batch_op, dataloader, aggregator) + + +class HessianOperator(AggregateBatchOperator): + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + reverse_only: bool = True, + ): + batch_op = HessianBatchOperation( + model, loss, restrict_to=restrict_to, reverse_only=reverse_only + ) + aggregator = TorchChunkAverageAggregator() + super().__init__(batch_op, dataloader, aggregator) diff --git a/src/pydvl/influence/torch/operator/bilinear_form.py b/src/pydvl/influence/torch/operator/bilinear_form.py new file mode 100644 index 000000000..8b77d519d --- /dev/null +++ b/src/pydvl/influence/torch/operator/bilinear_form.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING, Optional, cast +import torch + +from .gradient_provider import TorchPerSampleGradientProvider +from ..util import TorchBatch +from ...types import BilinearForm + +if TYPE_CHECKING: + from .base import TorchOperator + +class OperatorBilinearForm( + BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] +): + def __init__( + self, + operator: "TorchOperator", + ): + self.operator = operator + + def inner_product( + self, left: torch.Tensor, right: Optional[torch.Tensor] + ) -> torch.Tensor: + if right is None: + right = left + if left.shape[0] <= right.shape[0]: + return self._inner_product(left, right) + return self._inner_product(right, left).T + + def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + left_result = self.operator.apply_to_mat(left) + + if left_result.ndim == right.ndim and left.shape[-1] == right.shape[-1]: + return left_result @ right.T + + return torch.einsum("ia,j...a->ij...", left_result, right) + diff --git a/src/pydvl/influence/torch/operator/solve.py b/src/pydvl/influence/torch/operator/solve.py new file mode 100644 index 000000000..69a640c9e --- /dev/null +++ b/src/pydvl/influence/torch/operator/solve.py @@ -0,0 +1,39 @@ +from typing import Callable, Union, Type, Optional, Dict +import torch +from torch import nn as nn +from torch.utils.data import DataLoader +from ..util import TorchPointAverageAggregator +from .base import TorchOperator, AggregateBatchOperator +from .batch_operation import InverseHarmonicMeanBatchOperation +from .gradient_provider import ( + GradientProviderFactoryType, + TorchPerSampleGradientProvider, + TorchPerSampleAutoGrad, +) + +__all__ = ["InverseHarmonicMeanOperator"] + + +class InverseHarmonicMeanOperator(AggregateBatchOperator): + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + regularization: float, + gradient_provider_factory: Union[ + GradientProviderFactoryType, + Type[TorchPerSampleGradientProvider], + ] = TorchPerSampleAutoGrad, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + ): + 
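# Descriptive note (sketch of the composition set up below): the per-batch
+        # computation is delegated to InverseHarmonicMeanBatchOperation, and the
+        # per-batch results over `dataloader` are combined by an unweighted
+        # TorchPointAverageAggregator.
+        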
batch_op = InverseHarmonicMeanBatchOperation( + model, + loss, + regularization, + gradient_provider_factory=gradient_provider_factory, + restrict_to=restrict_to, + ) + aggregator = TorchPointAverageAggregator(weighted=False) + super().__init__(batch_op, dataloader, aggregator) + From 5e37a4bbc3501616d2a5102bffb9a9c20e599340 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 22:45:52 +0200 Subject: [PATCH 13/43] Add tests for influence.torch.util --- src/pydvl/influence/torch/util.py | 3 ++ tests/influence/torch/test_util.py | 87 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index 0d6a55f99..78df32fd4 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -658,6 +658,9 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: A matrix of size `(B, N)` where each column is the result of xx^T v for corresponding rows in x and v. """ + if v.ndim == 1: + result = torch.einsum("ij,kj->ki", x, v.unsqueeze(0)) @ x + return result.squeeze() return torch.einsum("ij,kj->ki", x, v) @ x diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index c63a34253..82510e640 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -17,10 +17,14 @@ lanzcos_low_rank_hessian_approx, ) from pydvl.influence.torch.util import ( + BlockMode, + ModelParameterDictBuilder, TorchLinalgEighException, TorchTensorContainerType, align_structure, flatten_dimensions, + inverse_rank_one_update, + rank_one_mvp, safe_torch_linalg_eigh, torch_dataset_to_dask_array, ) @@ -318,3 +322,86 @@ def test_safe_torch_linalg_eigh(): def test_safe_torch_linalg_eigh_exception(): with pytest.raises(TorchLinalgEighException): safe_torch_linalg_eigh(torch.randn([53000, 53000])) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_0, x_dim_1, v_dim_0", + [(10, 1, 12), (3, 2, 5), (4, 5, 30), (6, 6, 6), (1, 7, 7)], +) +def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): + X = torch.randn(x_dim_0, x_dim_1) + V = torch.randn(v_dim_0, x_dim_1) + + expected = ( + (torch.vmap(lambda x: x.unsqueeze(-1) * x.unsqueeze(-1).t())(X) @ V.t()) + .sum(dim=0) + .t() + ) + + result = rank_one_mvp(X, V) + + assert result.shape == V.shape + assert torch.allclose(result, expected, atol=1e-5, rtol=1e-4) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_0, x_dim_1, v_dim_0", + [(10, 1, 12), (3, 2, 5), (4, 5, 10), (6, 6, 6), (1, 7, 7)], +) +@pytest.mark.parametrize("reg", [0.1, 100, 1.0, 10]) +def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): + X = torch.randn(x_dim_0, x_dim_1) + V = torch.randn(v_dim_0, x_dim_1) + + inverse_result = torch.zeros_like(V) + + for x in X: + rank_one_matrix = x.unsqueeze(-1) * x.unsqueeze(-1).t() + inverse_result += torch.linalg.solve( + rank_one_matrix + reg * torch.eye(rank_one_matrix.shape[0]), V, left=False + ) + + inverse_result /= X.shape[0] + result = inverse_rank_one_update(X, V, reg) + + assert torch.allclose(result, inverse_result, atol=1e-5) + + +class TestModelParameterDictBuilder: + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(5, 10) + self.fc2 = torch.nn.Linear(10, 5) + self.fc1.weight.requires_grad = False + + @pytest.fixture + def model(self): + return TestModelParameterDictBuilder.SimpleModel() + + @pytest.mark.parametrize("block_mode", [mode for mode in BlockMode]) + def test_build(self, block_mode, 
model): + builder = ModelParameterDictBuilder( + model=model, + detach=True, + ) + param_dict = builder.build(block_mode) + + if block_mode is BlockMode.FULL: + assert "" in param_dict + assert "fc1.weight" not in param_dict[""] + elif block_mode is BlockMode.PARAMETER_WISE: + assert "fc2.bias" in param_dict + assert len(param_dict["fc2.bias"]) > 0 + assert "fc1.weight" not in param_dict + elif block_mode is BlockMode.LAYER_WISE: + assert "fc2" in param_dict + assert "fc2.bias" in param_dict["fc2"] + assert "fc1.weight" not in param_dict["fc1"] + assert "fc1.bias" in param_dict["fc1"] + + assert all( + (not p.requires_grad for q in param_dict.values() for p in q.values()) + ) From 29eb69afd2c8971c2b6b83fd25272f9ecc8c9417 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 22:47:53 +0200 Subject: [PATCH 14/43] Add implementation of generic types to influence.torch.influence_function_model, implement InverseHarmonicMeanInfluence and add corresponding test --- .../torch/influence_function_model.py | 211 +++++++++++++++++- tests/influence/torch/test_influence_model.py | 55 +++++ 2 files changed, 262 insertions(+), 4 deletions(-) diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index b3d608a23..a3d490fcf 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -8,21 +8,23 @@ import logging from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional, Tuple +from collections import OrderedDict +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn as nn from torch.utils.data import DataLoader from tqdm.auto import tqdm -from pydvl.utils.progress import log_duration - +from ...utils.progress import log_duration +from .. import InfluenceMode from ..base_influence_function_model import ( + ComposableInfluence, InfluenceFunctionModel, - InfluenceMode, NotImplementedLayerRepresentationException, UnsupportedInfluenceModeException, ) +from ..types import BlockMapper, OperatorGradientComposition from .functional import ( LowRankProductRepresentation, create_batch_hvp_function, @@ -34,9 +36,20 @@ model_hessian_low_rank, model_hessian_nystroem_approximation, ) +from .operator.base import TorchOperator +from .operator.gradient_provider import ( + TorchPerSampleAutoGrad, + TorchPerSampleGradientProvider, +) +from .operator.solve import InverseHarmonicMeanOperator, LowRankOperator from .pre_conditioner import PreConditioner from .util import ( + BlockMode, EkfacRepresentation, + LossType, + ModelInfoMixin, + ModelParameterDictBuilder, + TorchBatch, empirical_cross_entropy_loss_fn, flatten_dimensions, safe_torch_linalg_eigh, @@ -986,6 +999,7 @@ class ArnoldiInfluence(TorchInfluenceFunctionModel): Set this to False, if you can't afford to keep the full computation graph in memory. 
""" + low_rank_representation: LowRankProductRepresentation def __init__( @@ -1791,3 +1805,192 @@ def fit(self, data: DataLoader): self.model, self.loss, data, self.rank ) return self + + +class TorchOperatorGradientComposition( + OperatorGradientComposition[ + torch.Tensor, TorchBatch, TorchOperator, TorchPerSampleGradientProvider + ] +): + def to(self, device: torch.device): + self.gp = self.gp.to(device) + self.op = self.op.to(device) + return self + + +class TorchBlockMapper( + BlockMapper[torch.Tensor, TorchBatch, TorchOperatorGradientComposition] +): + def _split_to_blocks( + self, z: torch.Tensor, dim: int = -1 + ) -> OrderedDict[str, torch.Tensor]: + block_sizes = [bi.op.input_size for bi in self.composable_block_dict.values()] + + block_dict = OrderedDict( + zip( + list(self.composable_block_dict.keys()), + torch.split(z, block_sizes, dim=dim), + ) + ) + return block_dict + + def to(self, device: torch.device): + self.composable_block_dict = OrderedDict( + [(k, bi.to(device)) for k, bi in self.composable_block_dict.items()] + ) + return self + + +class TorchComposableInfluence( + ComposableInfluence[torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper], + ModelInfoMixin, +): + def __init__( + self, + model: torch.nn.Module, + block_structure: Union[ + BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]] + ] = BlockMode.FULL, + regularization: Optional[Union[float, Dict[str, float]]] = None, + ): + if isinstance(block_structure, BlockMode): + self.parameter_dict = ModelParameterDictBuilder(model).build( + block_structure + ) + else: + self.parameter_dict = block_structure + + self._regularization_dict = self._build_regularization_dict(regularization) + + super().__init__(model) + + @property + def regularization(self) -> Dict[str, float]: + return self._regularization_dict + + @regularization.setter + def regularization(self, value: Union[float, Dict[str, float]]): + self._regularization_dict = self._build_regularization_dict(value) + + @property + def block_names(self) -> List[str]: + return list(self.parameter_dict.keys()) + + @abstractmethod + def with_regularization( + self, regularization: Union[float, Dict[str, float]] + ) -> TorchComposableInfluence: + pass + + def _build_regularization_dict( + self, regularization: Optional[Union[float, Dict[str, Optional[float]]]] + ) -> Dict[str, Optional[float]]: + if regularization is None or isinstance(regularization, float): + return { + k: self._validate_regularization(k, regularization) + for k in self.block_names + } + + if set(regularization.keys()).issubset(set(self.block_names)): + raise ValueError( + f"The regularization must be a float or the keys of the regularization" + f"dictionary must match a subset of" + f"block names: \n {self.block_names}.\n Found not in block names: \n" + f"{set(regularization.keys()).difference(set(self.block_names))}" + ) + return { + k: self._validate_regularization(k, regularization.get(k, None)) + for k in self.block_names + } + + @staticmethod + def _validate_regularization( + block_name: str, value: Optional[float] + ) -> Optional[float]: + if isinstance(value, float) and value < 0.0: + raise ValueError( + f"The regularization for block '{block_name}' must be non-negative, " + f"but found {value=}" + ) + return value + + @abstractmethod + def _create_block( + self, + block_params: Dict[str, torch.nn.Parameter], + data: DataLoader, + regularization: Optional[float], + ) -> TorchOperatorGradientComposition: + pass + + def _create_block_mapper(self, data: DataLoader) -> 
TorchBlockMapper: + block_influence_dict = OrderedDict() + for k, p in self.parameter_dict.items(): + reg = self._regularization_dict.get(k, None) + reg = self._validate_regularization(k, reg) + block_influence_dict[k] = self._create_block(p, data, reg).to(self.device) + + return TorchBlockMapper(block_influence_dict) + + @staticmethod + def _create_batch(x: torch.Tensor, y: torch.Tensor) -> TorchBatch: + return TorchBatch(x, y) + + def to(self, device: torch.device): + self.model = self.model.to(device) + if hasattr(self, "block_mapper") and self.block_mapper is not None: + self.block_mapper = self.block_mapper.to(device) + return self + + +class InverseHarmonicMeanInfluence(TorchComposableInfluence): + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + regularization: Union[float, Dict[str, float]], + block_structure: Union[ + BlockMode, OrderedDict[str, OrderedDict[str, torch.Tensor]] + ] = BlockMode.FULL, + ): + super().__init__(model, block_structure, regularization=regularization) + self.gradient_provider_factory = TorchPerSampleAutoGrad + self.loss = loss + + @staticmethod + def _validate_regularization( + block_name: str, value: Optional[float] + ) -> Optional[float]: + if value is None or value <= 0.0: + raise ValueError( + f"The regularization for block '{block_name}' must be a positive float," + f"but found {value=}" + ) + return value + + def _create_block( + self, + block_params: Dict[str, torch.nn.Parameter], + data: DataLoader, + regularization: Optional[float], + ) -> TorchOperatorGradientComposition: + op = InverseHarmonicMeanOperator( + self.model, + self.loss, + data, + regularization, + self.gradient_provider_factory, + restrict_to=block_params, + ) + gp = self.gradient_provider_factory( + self.model, self.loss, restrict_to=block_params + ) + return TorchOperatorGradientComposition(op, gp) + + def with_regularization( + self, regularization: Union[float, Dict[str, float]] + ) -> TorchComposableInfluence: + self._regularization_dict = self._build_regularization_dict(regularization) + for k, reg in self._regularization_dict.items(): + self.block_mapper.composable_block_dict[k].op.regularization = reg + return self diff --git a/tests/influence/torch/test_influence_model.py b/tests/influence/torch/test_influence_model.py index d2203a84e..6cc2ad0de 100644 --- a/tests/influence/torch/test_influence_model.py +++ b/tests/influence/torch/test_influence_model.py @@ -4,6 +4,7 @@ import numpy as np import pytest from numpy.typing import NDArray +from scipy.stats import pearsonr, spearmanr from pydvl.influence.base_influence_function_model import ( NotFittedException, @@ -14,6 +15,7 @@ CgInfluence, DirectInfluence, EkfacInfluence, + InverseHarmonicMeanInfluence, LissaInfluence, NystroemSketchInfluence, ) @@ -22,6 +24,7 @@ NystroemPreConditioner, PreConditioner, ) +from pydvl.influence.torch.util import BlockMode from tests.influence.torch.conftest import minimal_training torch = pytest.importorskip("torch") @@ -754,3 +757,55 @@ def test_influences_cg( .numpy() ) assert np.allclose(single_influence, direct_factors[0], atol=1e-6, rtol=1e-4) + + +composable_influence_factories = [InverseHarmonicMeanInfluence] + + +@pytest.mark.parametrize("composable_influence_factory", composable_influence_factories) +@pytest.mark.parametrize("block_mode", [mode for mode in BlockMode]) +@pytest.mark.torch +def test_composable_influence( + test_case: TestCase, + model_and_data: Tuple[ + torch.nn.Module, + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + torch.Tensor, + 
torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + direct_influences, + direct_sym_influences, + device: torch.device, + block_mode, + composable_influence_factory, +): + model, loss, x_train, y_train, x_test, y_test = model_and_data + + train_dataloader = DataLoader( + TensorDataset(x_train, y_train), batch_size=test_case.batch_size + ) + + harmonic_mean_influence = composable_influence_factory( + model, loss, test_case.hessian_reg, block_structure=block_mode + ).to(device) + harmonic_mean_influence = harmonic_mean_influence.fit(train_dataloader) + harmonic_mean_influence_values = ( + harmonic_mean_influence.influences( + x_test, y_test, x_train, y_train, mode=test_case.mode + ) + .cpu() + .numpy() + ) + + threshold = 0.999 + flat_direct_influences = direct_influences.reshape(-1) + flat_harmonic_influences = harmonic_mean_influence_values.reshape(-1) + assert np.all( + pearsonr(flat_direct_influences, flat_harmonic_influences).statistic > threshold + ) + assert np.all( + spearmanr(flat_direct_influences, flat_harmonic_influences).statistic + > threshold + ) From 07f80bdbfec33b56734b2a381d98a882b263d0e2 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Tue, 21 May 2024 22:49:41 +0200 Subject: [PATCH 15/43] Fix linting issues in operator subpackage --- src/pydvl/influence/torch/operator/base.py | 13 +++--- .../torch/operator/batch_operation.py | 14 ++----- .../influence/torch/operator/bilinear_form.py | 7 ++-- src/pydvl/influence/torch/operator/solve.py | 9 ++-- .../torch/operator/test_batch_operation.py | 26 ++++++++---- .../torch/operator/test_gradient_provider.py | 41 ++++++++++++------- 6 files changed, 61 insertions(+), 49 deletions(-) diff --git a/src/pydvl/influence/torch/operator/base.py b/src/pydvl/influence/torch/operator/base.py index 06480e36b..c56398af9 100644 --- a/src/pydvl/influence/torch/operator/base.py +++ b/src/pydvl/influence/torch/operator/base.py @@ -1,28 +1,25 @@ from abc import ABC, abstractmethod -from typing import Callable, Generator, Union, Type, Optional, Dict +from typing import Callable, Dict, Generator, Optional, Type, Union import torch from torch import nn as nn from torch.utils.data import DataLoader -from ...array import SequenceAggregator, LazyChunkSequence -from ..util import TorchPointAverageAggregator, TorchChunkAverageAggregator, TorchBatch - +from ...array import LazyChunkSequence, SequenceAggregator +from ...types import Operator +from ..util import TorchBatch, TorchChunkAverageAggregator, TorchPointAverageAggregator from .batch_operation import ( BatchOperation, GaussNewtonBatchOperation, HessianBatchOperation, ) - from .bilinear_form import OperatorBilinearForm from .gradient_provider import ( - TorchPerSampleGradientProvider, GradientProviderFactoryType, TorchPerSampleAutoGrad, + TorchPerSampleGradientProvider, ) -from ...types import Operator - class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): def __init__(self, regularization: float = 0.0): diff --git a/src/pydvl/influence/torch/operator/batch_operation.py b/src/pydvl/influence/torch/operator/batch_operation.py index 45dcb8682..9b7d67b25 100644 --- a/src/pydvl/influence/torch/operator/batch_operation.py +++ b/src/pydvl/influence/torch/operator/batch_operation.py @@ -1,20 +1,14 @@ from abc import ABC, abstractmethod -from typing import Callable, Optional, Dict, Union, Type +from typing import Callable, Dict, Optional, Type, Union import torch from ..functional import create_batch_hvp_function -from ..util import ( - inverse_rank_one_update, - rank_one_mvp, - 
LossType, - TorchBatch, -) - +from ..util import LossType, TorchBatch, inverse_rank_one_update, rank_one_mvp from .gradient_provider import ( - TorchPerSampleGradientProvider, - TorchPerSampleAutoGrad, GradientProviderFactoryType, + TorchPerSampleAutoGrad, + TorchPerSampleGradientProvider, ) diff --git a/src/pydvl/influence/torch/operator/bilinear_form.py b/src/pydvl/influence/torch/operator/bilinear_form.py index 8b77d519d..4e7cf94a2 100644 --- a/src/pydvl/influence/torch/operator/bilinear_form.py +++ b/src/pydvl/influence/torch/operator/bilinear_form.py @@ -1,13 +1,15 @@ from typing import TYPE_CHECKING, Optional, cast + import torch -from .gradient_provider import TorchPerSampleGradientProvider -from ..util import TorchBatch from ...types import BilinearForm +from ..util import TorchBatch +from .gradient_provider import TorchPerSampleGradientProvider if TYPE_CHECKING: from .base import TorchOperator + class OperatorBilinearForm( BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] ): @@ -33,4 +35,3 @@ def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tenso return left_result @ right.T return torch.einsum("ia,j...a->ij...", left_result, right) - diff --git a/src/pydvl/influence/torch/operator/solve.py b/src/pydvl/influence/torch/operator/solve.py index 69a640c9e..8c4b4ee55 100644 --- a/src/pydvl/influence/torch/operator/solve.py +++ b/src/pydvl/influence/torch/operator/solve.py @@ -1,14 +1,16 @@ -from typing import Callable, Union, Type, Optional, Dict +from typing import Callable, Dict, Optional, Type, Union + import torch from torch import nn as nn from torch.utils.data import DataLoader + from ..util import TorchPointAverageAggregator -from .base import TorchOperator, AggregateBatchOperator +from .base import AggregateBatchOperator, TorchOperator from .batch_operation import InverseHarmonicMeanBatchOperation from .gradient_provider import ( GradientProviderFactoryType, - TorchPerSampleGradientProvider, TorchPerSampleAutoGrad, + TorchPerSampleGradientProvider, ) __all__ = ["InverseHarmonicMeanOperator"] @@ -36,4 +38,3 @@ def __init__( ) aggregator = TorchPointAverageAggregator(weighted=False) super().__init__(batch_op, dataloader, aggregator) - diff --git a/tests/influence/torch/operator/test_batch_operation.py b/tests/influence/torch/operator/test_batch_operation.py index 7c77c6b19..1f5deed71 100644 --- a/tests/influence/torch/operator/test_batch_operation.py +++ b/tests/influence/torch/operator/test_batch_operation.py @@ -1,9 +1,12 @@ +from dataclasses import astuple + import pytest import torch -from dataclasses import astuple -from pydvl.influence.torch.operator.batch_operation import (HessianBatchOperation, - GaussNewtonBatchOperation) +from pydvl.influence.torch.operator.batch_operation import ( + GaussNewtonBatchOperation, + HessianBatchOperation, +) from pydvl.influence.torch.util import TorchBatch from ..test_util import model_data, test_parameters @@ -20,8 +23,9 @@ def test_hessian_batch_operation(model_data, tol: float): params = dict(torch_model.named_parameters()) - hessian_op = HessianBatchOperation(torch_model, torch.nn.functional.mse_loss, - restrict_to=params) + hessian_op = HessianBatchOperation( + torch_model, torch.nn.functional.mse_loss, restrict_to=params + ) hvp_autograd = hessian_op.apply_to_vec(TorchBatch(x, y), vec) assert torch.allclose(hvp_autograd, h_analytical @ vec, rtol=tol) @@ -44,13 +48,17 @@ def test_gauss_newton_batch_operation(model_data, tol: float): dl_db = torch.vmap(lambda s, t: 2 / float(out_features) 
* (s - t))(y_pred, y) grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1) gn_mat_analytical = torch.sum( - torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t()) - (grad_analytical), dim=0) + torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())( + grad_analytical + ), + dim=0, + ) params = dict(torch_model.named_parameters()) - gn_op = GaussNewtonBatchOperation(torch_model, torch.nn.functional.mse_loss, - restrict_to=params) + gn_op = GaussNewtonBatchOperation( + torch_model, torch.nn.functional.mse_loss, restrict_to=params + ) gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec) gn_analytical = gn_mat_analytical @ vec diff --git a/tests/influence/torch/operator/test_gradient_provider.py b/tests/influence/torch/operator/test_gradient_provider.py index 051f98d74..47324acee 100644 --- a/tests/influence/torch/operator/test_gradient_provider.py +++ b/tests/influence/torch/operator/test_gradient_provider.py @@ -1,20 +1,15 @@ -import torch -import pytest import numpy as np +import pytest +import torch from pydvl.influence.torch.operator.gradient_provider import TorchPerSampleAutoGrad from pydvl.influence.torch.util import TorchBatch -from ...conftest import ( - linear_mixed_second_derivative_analytical, - linear_model, -) - +from ...conftest import linear_mixed_second_derivative_analytical, linear_model from ..conftest import DATA_OUTPUT_NOISE, linear_mvp_model class TestTorchPerSampleAutograd: - @pytest.mark.torch @pytest.mark.parametrize( "in_features, out_features, batch_size", @@ -35,14 +30,21 @@ def test_per_sample_gradient(self, in_features, out_features, batch_size): # Compute analytical gradients y_pred = model(x) dL_dw = torch.vmap( - lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1) + lambda r, s, t: 2 + / float(out_features) + * (s - t).view(-1, 1) + @ r.view(1, -1) )(x, y_pred, y) dL_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) # Assert the gradient values for equality with analytical gradients assert torch.allclose(gradients["weight"], dL_dw, atol=1e-5) assert torch.allclose(gradients["bias"], dL_db, atol=1e-5) - assert torch.allclose(flat_gradients, torch.cat([dL_dw.reshape(batch_size, -1), dL_db], dim=-1), atol=1e-5) + assert torch.allclose( + flat_gradients, + torch.cat([dL_dw.reshape(batch_size, -1), dL_db], dim=-1), + atol=1e-5, + ) @pytest.mark.torch @pytest.mark.parametrize( @@ -69,7 +71,9 @@ def test_mixed_derivatives(self, in_features, out_features, train_set_size): torch_train_x = torch.as_tensor(train_x) torch_train_y = torch.as_tensor(train_y) gp = TorchPerSampleAutoGrad(model, loss, restrict_to=params) - flat_functorch_mixed_derivatives = gp.per_sample_flat_mixed_gradient(TorchBatch(torch_train_x, torch_train_y)) + flat_functorch_mixed_derivatives = gp.per_sample_flat_mixed_gradient( + TorchBatch(torch_train_x, torch_train_y) + ) assert torch.allclose( torch.as_tensor(test_derivative), flat_functorch_mixed_derivatives.transpose(2, 1), @@ -80,7 +84,9 @@ def test_mixed_derivatives(self, in_features, out_features, train_set_size): "in_features, out_features, batch_size", [(46, 1, 632), (50, 3, 120), (100, 5, 110), (25, 10, 500)], ) - def test_matrix_jacobian_product(self, in_features, out_features, batch_size, pytorch_seed): + def test_matrix_jacobian_product( + self, in_features, out_features, batch_size, pytorch_seed + ): model = torch.nn.Linear(in_features, out_features) params = {k: p for k, p in model.named_parameters() if p.requires_grad} @@ -88,16 +94,21 @@ def 
test_matrix_jacobian_product(self, in_features, out_features, batch_size, py y = torch.randn(batch_size, out_features, requires_grad=True) y_pred = model(x) - gp = TorchPerSampleAutoGrad(model, torch.nn.functional.mse_loss, restrict_to=params) + gp = TorchPerSampleAutoGrad( + model, torch.nn.functional.mse_loss, restrict_to=params + ) G = torch.randn((10, out_features * (in_features + 1))) mjp = gp.matrix_jacobian_product(TorchBatch(x, y), G) dL_dw = torch.vmap( - lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1) + lambda r, s, t: 2 + / float(out_features) + * (s - t).view(-1, 1) + @ r.view(1, -1) )(x, y_pred, y) dL_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) analytic_grads = torch.cat([dL_dw.reshape(dL_dw.shape[0], -1), dL_db], dim=1) analytical_mjp = G @ analytic_grads.T - assert torch.allclose(analytical_mjp, mjp, atol=1e-5, rtol=1e-3) \ No newline at end of file + assert torch.allclose(analytical_mjp, mjp, atol=1e-5, rtol=1e-3) From 3b7289ca32a8f4d774efe7ffed6bd38527f26e29 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 22 May 2024 00:38:18 +0200 Subject: [PATCH 16/43] Fix type-checking issues --- .../torch/influence_function_model.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index a3d490fcf..1114a641e 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -41,7 +41,7 @@ TorchPerSampleAutoGrad, TorchPerSampleGradientProvider, ) -from .operator.solve import InverseHarmonicMeanOperator, LowRankOperator +from .operator.solve import InverseHarmonicMeanOperator from .pre_conditioner import PreConditioner from .util import ( BlockMode, @@ -1851,7 +1851,7 @@ def __init__( block_structure: Union[ BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]] ] = BlockMode.FULL, - regularization: Optional[Union[float, Dict[str, float]]] = None, + regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None, ): if isinstance(block_structure, BlockMode): self.parameter_dict = ModelParameterDictBuilder(model).build( @@ -1864,21 +1864,13 @@ def __init__( super().__init__(model) - @property - def regularization(self) -> Dict[str, float]: - return self._regularization_dict - - @regularization.setter - def regularization(self, value: Union[float, Dict[str, float]]): - self._regularization_dict = self._build_regularization_dict(value) - @property def block_names(self) -> List[str]: return list(self.parameter_dict.keys()) @abstractmethod def with_regularization( - self, regularization: Union[float, Dict[str, float]] + self, regularization: Union[float, Dict[str, Optional[float]]] ) -> TorchComposableInfluence: pass @@ -1948,7 +1940,7 @@ def __init__( self, model: torch.nn.Module, loss: LossType, - regularization: Union[float, Dict[str, float]], + regularization: Union[float, Dict[str, Optional[float]]], block_structure: Union[ BlockMode, OrderedDict[str, OrderedDict[str, torch.Tensor]] ] = BlockMode.FULL, @@ -1974,6 +1966,7 @@ def _create_block( data: DataLoader, regularization: Optional[float], ) -> TorchOperatorGradientComposition: + assert regularization is not None op = InverseHarmonicMeanOperator( self.model, self.loss, @@ -1988,7 +1981,7 @@ def _create_block( return TorchOperatorGradientComposition(op, gp) def with_regularization( - self, regularization: Union[float, Dict[str, float]] + 
self, regularization: Union[float, Dict[str, Optional[float]]] ) -> TorchComposableInfluence: self._regularization_dict = self._build_regularization_dict(regularization) for k, reg in self._regularization_dict.items(): From 531aaccc2d05b1caf09d237e0eef95fd5176d79c Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 22 May 2024 16:57:33 +0200 Subject: [PATCH 17/43] Add and improve docstrings --- .../base_influence_function_model.py | 8 -- .../torch/influence_function_model.py | 36 ++++++ .../influence/torch/operator/bilinear_form.py | 24 +++- .../torch/operator/gradient_provider.py | 103 ++++++++++++++++++ src/pydvl/influence/types.py | 3 +- 5 files changed, 163 insertions(+), 11 deletions(-) diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index b6854c250..5077f8d88 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -217,14 +217,6 @@ class ComposableInfluence( block_mapper: BlockMapperType - @property - def n_parameters(self): - return super().n_parameters() - - @property - def is_thread_safe(self) -> bool: - return False - @property def is_fitted(self): try: diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 1114a641e..4fd89a9e6 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -1812,6 +1812,19 @@ class TorchOperatorGradientComposition( torch.Tensor, TorchBatch, TorchOperator, TorchPerSampleGradientProvider ] ): + """ + Representing a composable block that integrates an [TorchOperator] + [pydvl.influence.torch.operator.base.TorchOperator] and + a [TorchPerSampleGradientProvider] + [pydvl.influence.torch.operator.gradient_provider.TorchPerSampleGradientProvider] + + This block is designed to be flexible, handling different computational modes via + an abstract operator and gradient provider. + """ + + def __init__(self, op: TorchOperator, gp: TorchPerSampleGradientProvider): + super().__init__(op, gp) + def to(self, device: torch.device): self.gp = self.gp.to(device) self.op = self.op.to(device) @@ -1821,6 +1834,20 @@ def to(self, device: torch.device): class TorchBlockMapper( BlockMapper[torch.Tensor, TorchBatch, TorchOperatorGradientComposition] ): + """ + Class for mapping operations across multiple compositional blocks represented by + instances of [TorchOperatorGradientComposition] + [pydvl.influence.torch.influence_function_model.TorchOperatorGradientComposition]. + + This class takes a dictionary of compositional blocks and applies their methods to + batches or tensors, and aggregates the results. 
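+
+    For instance, `_split_to_blocks` splits a tensor whose last dimension equals
+    the total number of parameters into per-block chunks, one chunk of size
+    `op.input_size` per block, before the block-wise operations are applied.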
+ """ + + def __init__( + self, composable_block_dict: OrderedDict[str, TorchOperatorGradientComposition] + ): + super().__init__(composable_block_dict) + def _split_to_blocks( self, z: torch.Tensor, dim: int = -1 ) -> OrderedDict[str, torch.Tensor]: @@ -1844,6 +1871,7 @@ def to(self, device: torch.device): class TorchComposableInfluence( ComposableInfluence[torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper], ModelInfoMixin, + ABC, ): def __init__( self, @@ -1949,6 +1977,14 @@ def __init__( self.gradient_provider_factory = TorchPerSampleAutoGrad self.loss = loss + @property + def n_parameters(self): + return super().n_parameters() + + @property + def is_thread_safe(self) -> bool: + return False + @staticmethod def _validate_regularization( block_name: str, value: Optional[float] diff --git a/src/pydvl/influence/torch/operator/bilinear_form.py b/src/pydvl/influence/torch/operator/bilinear_form.py index 4e7cf94a2..0fbacd718 100644 --- a/src/pydvl/influence/torch/operator/bilinear_form.py +++ b/src/pydvl/influence/torch/operator/bilinear_form.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional import torch @@ -13,6 +13,15 @@ class OperatorBilinearForm( BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] ): + r""" + Base class for bilinear forms based on an instance of + [TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. This means it + computes weighted inner products of the form: + + $$ \langle \operatorname{Op}(x), y \rangle $$ + + """ + def __init__( self, operator: "TorchOperator", @@ -22,6 +31,19 @@ def __init__( def inner_product( self, left: torch.Tensor, right: Optional[torch.Tensor] ) -> torch.Tensor: + r""" + Computes the weighted inner product of two vectors, i.e. + + $$ \langle x, y \rangle_{B} = \langle \operatorname{Op}(x), y \rangle $$ + + Args: + left: The first tensor in the inner product computation. + right: The second tensor, optional; if not provided, the inner product will + use `left` tensor for both arguments. + + Returns: + A tensor representing the inner product. + """ if right is None: right = left if left.shape[0] <= right.shape[0]: diff --git a/src/pydvl/influence/torch/operator/gradient_provider.py b/src/pydvl/influence/torch/operator/gradient_provider.py index 76ca84010..f3cf8e554 100644 --- a/src/pydvl/influence/torch/operator/gradient_provider.py +++ b/src/pydvl/influence/torch/operator/gradient_provider.py @@ -22,6 +22,29 @@ class TorchPerSampleGradientProvider( PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC ): + r""" + Abstract base class for calculating per-sample gradients of a function defined by + a [torch.nn.Module][torch.nn.Module] and a loss function. + + This class must be subclassed with implementations for its abstract methods tailored + to specific gradient computation needs, e.g. using [torch.autograd][torch.autograd] + or stochastic finite differences. + + Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. 
a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + def __init__( self, model: torch.nn.Module, @@ -76,12 +99,52 @@ def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample gradients. Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension, so the shape of the resulting tensors are $(N, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + gradients computed per sample. + """ gradient_dict = self._per_sample_gradient_dict(batch.to(self.device)) return self._detach_dict(gradient_dict) def per_sample_mixed_gradient_dict( self, batch: TorchBatch ) -> Dict[str, torch.Tensor]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample mixed gradients. In this context, mixed gradients refer to computing + gradients with respect to the instance definition in addition to + compute derivatives with respect to the input batch. + Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension and the last to be the non-batch input related derivatives. + So the shape of the resulting tensors are $(N, n, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute mixed gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + mixed gradients computed per sample. + """ gradient_dict = self._per_sample_mixed_gradient_dict(batch.to(self.device)) return self._detach_dict(gradient_dict) @@ -90,6 +153,26 @@ def matrix_jacobian_product( batch: TorchBatch, g: torch.Tensor, ) -> torch.Tensor: + r""" + Computes the matrix-Jacobian product for the provided batch and input tensor. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y})) \cdot g^T$$ + + where g must be a tensor of shape $(K, d_1+d_2)$, so the resulting tensor + is of shape $(N, K)$. + + Args: + batch: The batch of data for which to compute the Jacobian. + g: The tensor to be used in the matrix-Jacobian product + calculation. + + Returns: + The resulting tensor from the matrix-Jacobian product computation. 
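+
+        Example:
+            Assuming a provider `gp` for a model with 10 flat parameters and a
+            batch of 5 samples, a tensor `g` of shape `(3, 10)` yields a result
+            of shape `(5, 3)`:
+
+            ```python
+            g = torch.randn(3, 10)
+            result = gp.matrix_jacobian_product(batch, g)  # shape (5, 3)
+            ```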
+ """ result = self._matrix_jacobian_product(batch.to(self.device), g.to(self.device)) if result.requires_grad: result = result.detach() @@ -108,6 +191,26 @@ def per_sample_flat_mixed_gradient(self, batch: TorchBatch) -> torch.Tensor: class TorchPerSampleAutoGrad(TorchPerSampleGradientProvider): + r""" + Compute per-sample gradients of a function defined by + a [torch.nn.Module][torch.nn.Module] and a loss function using + [torch.func][torch.func]. + + Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + def __init__( self, model: torch.nn.Module, diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index e9adafc1a..b43ddcb4a 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -230,7 +230,6 @@ def inner_product( Returns: A tensor representing the inner product. """ - pass def gradient_inner_product( self, @@ -359,7 +358,7 @@ class OperatorGradientComposition( ): """ Generic base class representing a composable block that integrates an operator and - a gradient provider to compute influences between batches of data. + a gradient provider to compute interactions between batches of data. This block is designed to be flexible, handling different computational modes via an abstract operator and gradient provider. From 1e5f18b944c3aefb680db0f5b9d87e487c763d3b Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 22 May 2024 19:45:35 +0200 Subject: [PATCH 18/43] Simplify package structure: * move operator submodules to influence.torch level * move implementations of generic classes to influence.torch.base --- src/pydvl/influence/torch/base.py | 549 ++++++++++++++++++ .../torch/{operator => }/batch_operation.py | 7 +- .../torch/influence_function_model.py | 173 +----- .../torch/{operator/base.py => operator.py} | 102 +--- .../influence/torch/operator/__init__.py | 0 .../influence/torch/operator/bilinear_form.py | 59 -- .../torch/operator/gradient_provider.py | 262 --------- src/pydvl/influence/torch/operator/solve.py | 40 -- src/pydvl/influence/torch/util.py | 31 - .../torch/operator/test_batch_operation.py | 4 +- .../torch/operator/test_gradient_provider.py | 3 +- 11 files changed, 594 insertions(+), 636 deletions(-) create mode 100644 src/pydvl/influence/torch/base.py rename src/pydvl/influence/torch/{operator => }/batch_operation.py (97%) rename src/pydvl/influence/torch/{operator/base.py => operator.py} (60%) delete mode 100644 src/pydvl/influence/torch/operator/__init__.py delete mode 100644 src/pydvl/influence/torch/operator/bilinear_form.py delete mode 100644 src/pydvl/influence/torch/operator/gradient_provider.py delete mode 100644 src/pydvl/influence/torch/operator/solve.py diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py new file mode 100644 index 000000000..6c570fb72 --- /dev/null +++ b/src/pydvl/influence/torch/base.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import OrderedDict +from dataclasses import dataclass +from typing import Optional, Dict, Callable, Union, List + 
+import torch +from torch.func import functional_call +from torch.utils.data import DataLoader + +from .functional import create_per_sample_gradient_function, \ + create_per_sample_mixed_derivative_function, create_matrix_jacobian_product_function +from .util import LossType, ModelParameterDictBuilder, \ + BlockMode, flatten_dimensions, ModelInfoMixin +from ..base_influence_function_model import ComposableInfluence +from ..types import PerSampleGradientProvider, Operator, BilinearForm, Batch, \ + OperatorGradientComposition, BlockMapper + + +@dataclass(frozen=True) +class TorchBatch(Batch): + """ + A convenience class for handling batches of data. Validates, the alignment + of the first dimension (batch dimension) of the input and target tensor + + Attributes: + x: The input tensor that contains features or data points. + y: The target tensor that contains labels corresponding to the inputs. + + """ + + x: torch.Tensor + y: torch.Tensor + + def __post_init__(self): + if self.x.shape[0] != self.y.shape[0]: + raise ValueError( + f"The first dimension of x and y must be the same, " + f"got {self.x.shape[0]} and {self.y.shape[0]}" + ) + + def __len__(self): + return self.x.shape[0] + + def to(self, device: torch.device): + return TorchBatch(self.x.to(device), self.y.to(device)) + + +class TorchPerSampleGradientProvider( + PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC +): + r""" + Abstract base class for calculating per-sample gradients of a function defined by + a [torch.nn.Module][torch.nn.Module] and a loss function. + + This class must be subclassed with implementations for its abstract methods tailored + to specific gradient computation needs, e.g. using [torch.autograd][torch.autograd] + or stochastic finite differences. + + Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. 
a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]], + ): + self.loss = loss + self.model = model + + if restrict_to is None: + restrict_to = ModelParameterDictBuilder(model).build(BlockMode.FULL) + + self.params_to_restrict_to = restrict_to + + def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + @property + def device(self): + return next(self.model.parameters()).device + + @property + def dtype(self): + return next(self.model.parameters()).dtype + + @abstractmethod + def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + pass + + @abstractmethod + def _per_sample_mixed_gradient_dict( + self, batch: TorchBatch + ) -> Dict[str, torch.Tensor]: + pass + + @abstractmethod + def _matrix_jacobian_product( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + pass + + @staticmethod + def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} + + def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample gradients. Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension, so the shape of the resulting tensors are $(N, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + gradients computed per sample. + """ + gradient_dict = self._per_sample_gradient_dict(batch.to(self.device)) + return self._detach_dict(gradient_dict) + + def per_sample_mixed_gradient_dict( + self, batch: TorchBatch + ) -> Dict[str, torch.Tensor]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample mixed gradients. In this context, mixed gradients refer to computing + gradients with respect to the instance definition in addition to + compute derivatives with respect to the input batch. + Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension and the last to be the non-batch input related derivatives. + So the shape of the resulting tensors are $(N, n, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute mixed gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + mixed gradients computed per sample. 
+ """ + gradient_dict = self._per_sample_mixed_gradient_dict(batch.to(self.device)) + return self._detach_dict(gradient_dict) + + def matrix_jacobian_product( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + r""" + Computes the matrix-Jacobian product for the provided batch and input tensor. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y})) \cdot g^T$$ + + where g must be a tensor of shape $(K, d_1+d_2)$, so the resulting tensor + is of shape $(N, K)$. + + Args: + batch: The batch of data for which to compute the Jacobian. + g: The tensor to be used in the matrix-Jacobian product + calculation. + + Returns: + The resulting tensor from the matrix-Jacobian product computation. + """ + result = self._matrix_jacobian_product(batch.to(self.device), g.to(self.device)) + if result.requires_grad: + result = result.detach() + return result + + def per_sample_flat_gradient(self, batch: TorchBatch) -> torch.Tensor: + return flatten_dimensions( + self.per_sample_gradient_dict(batch).values(), shape=(batch.x.shape[0], -1) + ) + + def per_sample_flat_mixed_gradient(self, batch: TorchBatch) -> torch.Tensor: + shape = (*batch.x.shape, -1) + return flatten_dimensions( + self.per_sample_mixed_gradient_dict(batch).values(), shape=shape + ) + + +class TorchPerSampleAutoGrad(TorchPerSampleGradientProvider): + r""" + Compute per-sample gradients of a function defined by + a [torch.nn.Module][torch.nn.Module] and a loss function using + [torch.func][torch.func]. + + Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. 
a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + super().__init__(model, loss, restrict_to) + self._per_sample_gradient_function = create_per_sample_gradient_function( + model, loss + ) + self._per_sample_mixed_gradient_func = ( + create_per_sample_mixed_derivative_function(model, loss) + ) + + def _compute_loss( + self, params: Dict[str, torch.Tensor], x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) + return self.loss(outputs, y.unsqueeze(0)) + + def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + return self._per_sample_gradient_function( + self.params_to_restrict_to, batch.x, batch.y + ) + + def _per_sample_mixed_gradient_dict( + self, batch: TorchBatch + ) -> Dict[str, torch.Tensor]: + return self._per_sample_mixed_gradient_func( + self.params_to_restrict_to, batch.x, batch.y + ) + + def _matrix_jacobian_product( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + matrix_jacobian_product_func = create_matrix_jacobian_product_function( + self.model, self.loss, g + ) + return matrix_jacobian_product_func( + self.params_to_restrict_to, batch.x, batch.y + ) + + +GradientProviderFactoryType = Callable[ + [torch.nn.Module, LossType, Optional[Dict[str, torch.nn.Parameter]]], + TorchPerSampleGradientProvider, +] + + +class OperatorBilinearForm( + BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] +): + r""" + Base class for bilinear forms based on an instance of + [TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. This means it + computes weighted inner products of the form: + + $$ \langle \operatorname{Op}(x), y \rangle $$ + + """ + + def __init__( + self, + operator: "TorchOperator", + ): + self.operator = operator + + def inner_product( + self, left: torch.Tensor, right: Optional[torch.Tensor] + ) -> torch.Tensor: + r""" + Computes the weighted inner product of two vectors, i.e. + + $$ \langle x, y \rangle_{B} = \langle \operatorname{Op}(x), y \rangle $$ + + Args: + left: The first tensor in the inner product computation. + right: The second tensor, optional; if not provided, the inner product will + use `left` tensor for both arguments. + + Returns: + A tensor representing the inner product. + """ + if right is None: + right = left + if left.shape[0] <= right.shape[0]: + return self._inner_product(left, right) + return self._inner_product(right, left).T + + def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + left_result = self.operator.apply_to_mat(left) + + if left_result.ndim == right.ndim and left.shape[-1] == right.shape[-1]: + return left_result @ right.T + + return torch.einsum("ia,j...a->ij...", left_result, right) + + +class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): + def __init__(self, regularization: float = 0.0): + """ + Initializes the Operator with an optional regularization parameter. + + Args: + regularization: A non-negative float that represents the regularization + strength (default is 0.0). + + Raises: + ValueError: If the regularization parameter is negative. 
+ """ + if regularization < 0: + raise ValueError("regularization must be non-negative") + self._regularization = regularization + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value < 0: + raise ValueError("regularization must be non-negative") + self._regularization = value + + @property + @abstractmethod + def device(self): + pass + + @property + @abstractmethod + def dtype(self): + pass + + @abstractmethod + def to(self, device: torch.device): + pass + + @abstractmethod + def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + pass + + def as_bilinear_form(self): + return OperatorBilinearForm(self) + + def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + return self._apply_to_vec(vec.to(self.device)) + + def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) + + +class TorchOperatorGradientComposition( + OperatorGradientComposition[ + torch.Tensor, TorchBatch, TorchOperator, TorchPerSampleGradientProvider + ] +): + """ + Representing a composable block that integrates an [TorchOperator] + [pydvl.influence.torch.operator.base.TorchOperator] and + a [TorchPerSampleGradientProvider] + [pydvl.influence.torch.operator.gradient_provider.TorchPerSampleGradientProvider] + + This block is designed to be flexible, handling different computational modes via + an abstract operator and gradient provider. + """ + + def __init__(self, op: TorchOperator, gp: TorchPerSampleGradientProvider): + super().__init__(op, gp) + + def to(self, device: torch.device): + self.gp = self.gp.to(device) + self.op = self.op.to(device) + return self + + +class TorchBlockMapper( + BlockMapper[torch.Tensor, TorchBatch, TorchOperatorGradientComposition] +): + """ + Class for mapping operations across multiple compositional blocks represented by + instances of [TorchOperatorGradientComposition] + [pydvl.influence.torch.influence_function_model.TorchOperatorGradientComposition]. + + This class takes a dictionary of compositional blocks and applies their methods to + batches or tensors, and aggregates the results. 
+    """
+
+    def __init__(
+        self, composable_block_dict: OrderedDict[str, TorchOperatorGradientComposition]
+    ):
+        super().__init__(composable_block_dict)
+
+    def _split_to_blocks(
+        self, z: torch.Tensor, dim: int = -1
+    ) -> OrderedDict[str, torch.Tensor]:
+        block_sizes = [bi.op.input_size for bi in self.composable_block_dict.values()]
+
+        block_dict = OrderedDict(
+            zip(
+                list(self.composable_block_dict.keys()),
+                torch.split(z, block_sizes, dim=dim),
+            )
+        )
+        return block_dict
+
+    def to(self, device: torch.device):
+        self.composable_block_dict = OrderedDict(
+            [(k, bi.to(device)) for k, bi in self.composable_block_dict.items()]
+        )
+        return self
+
+
+class TorchComposableInfluence(
+    ComposableInfluence[torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper],
+    ModelInfoMixin,
+    ABC,
+):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        block_structure: Union[
+            BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]]
+        ] = BlockMode.FULL,
+        regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None,
+    ):
+        if isinstance(block_structure, BlockMode):
+            self.parameter_dict = ModelParameterDictBuilder(model).build(
+                block_structure
+            )
+        else:
+            self.parameter_dict = block_structure
+
+        self._regularization_dict = self._build_regularization_dict(regularization)
+
+        super().__init__(model)
+
+    @property
+    def block_names(self) -> List[str]:
+        return list(self.parameter_dict.keys())
+
+    @abstractmethod
+    def with_regularization(
+        self, regularization: Union[float, Dict[str, Optional[float]]]
+    ) -> TorchComposableInfluence:
+        pass
+
+    def _build_regularization_dict(
+        self, regularization: Optional[Union[float, Dict[str, Optional[float]]]]
+    ) -> Dict[str, Optional[float]]:
+        if regularization is None or isinstance(regularization, float):
+            return {
+                k: self._validate_regularization(k, regularization)
+                for k in self.block_names
+            }
+
+        if not set(regularization.keys()).issubset(set(self.block_names)):
+            raise ValueError(
+                f"The regularization must be a float or the keys of the regularization "
+                f"dictionary must match a subset of "
+                f"block names: \n {self.block_names}.\n Found not in block names: \n"
+                f"{set(regularization.keys()).difference(set(self.block_names))}"
+            )
+        return {
+            k: self._validate_regularization(k, regularization.get(k, None))
+            for k in self.block_names
+        }
+
+    @staticmethod
+    def _validate_regularization(
+        block_name: str, value: Optional[float]
+    ) -> Optional[float]:
+        if isinstance(value, float) and value < 0.0:
+            raise ValueError(
+                f"The regularization for block '{block_name}' must be non-negative, "
+                f"but found {value=}"
+            )
+        return value
+
+    @abstractmethod
+    def _create_block(
+        self,
+        block_params: Dict[str, torch.nn.Parameter],
+        data: DataLoader,
+        regularization: Optional[float],
+    ) -> TorchOperatorGradientComposition:
+        pass
+
+    def _create_block_mapper(self, data: DataLoader) -> TorchBlockMapper:
+        block_influence_dict = OrderedDict()
+        for k, p in self.parameter_dict.items():
+            reg = self._regularization_dict.get(k, None)
+            reg = self._validate_regularization(k, reg)
+            block_influence_dict[k] = self._create_block(p, data, reg).to(self.device)
+
+        return TorchBlockMapper(block_influence_dict)
+
+    @staticmethod
+    def _create_batch(x: torch.Tensor, y: torch.Tensor) -> TorchBatch:
+        return TorchBatch(x, y)
+
+    def to(self, device: torch.device):
+        self.model = self.model.to(device)
+        if hasattr(self, "block_mapper") and self.block_mapper is not None:
+            self.block_mapper = self.block_mapper.to(device)
+        return self
diff
--git a/src/pydvl/influence/torch/operator/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py similarity index 97% rename from src/pydvl/influence/torch/operator/batch_operation.py rename to src/pydvl/influence/torch/batch_operation.py index 9b7d67b25..4a881de6e 100644 --- a/src/pydvl/influence/torch/operator/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -3,13 +3,14 @@ import torch -from ..functional import create_batch_hvp_function -from ..util import LossType, TorchBatch, inverse_rank_one_update, rank_one_mvp -from .gradient_provider import ( +from .base import ( GradientProviderFactoryType, + TorchBatch, TorchPerSampleAutoGrad, TorchPerSampleGradientProvider, ) +from .functional import create_batch_hvp_function +from .util import LossType, inverse_rank_one_update, rank_one_mvp class BatchOperation(ABC): diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 4fd89a9e6..96d57632e 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -19,12 +19,15 @@ from ...utils.progress import log_duration from .. import InfluenceMode from ..base_influence_function_model import ( - ComposableInfluence, InfluenceFunctionModel, NotImplementedLayerRepresentationException, UnsupportedInfluenceModeException, ) -from ..types import BlockMapper, OperatorGradientComposition +from .base import ( + TorchComposableInfluence, + TorchOperatorGradientComposition, + TorchPerSampleAutoGrad, +) from .functional import ( LowRankProductRepresentation, create_batch_hvp_function, @@ -36,20 +39,12 @@ model_hessian_low_rank, model_hessian_nystroem_approximation, ) -from .operator.base import TorchOperator -from .operator.gradient_provider import ( - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, -) -from .operator.solve import InverseHarmonicMeanOperator +from .operator import InverseHarmonicMeanOperator from .pre_conditioner import PreConditioner from .util import ( BlockMode, EkfacRepresentation, LossType, - ModelInfoMixin, - ModelParameterDictBuilder, - TorchBatch, empirical_cross_entropy_loss_fn, flatten_dimensions, safe_torch_linalg_eigh, @@ -1807,162 +1802,6 @@ def fit(self, data: DataLoader): return self -class TorchOperatorGradientComposition( - OperatorGradientComposition[ - torch.Tensor, TorchBatch, TorchOperator, TorchPerSampleGradientProvider - ] -): - """ - Representing a composable block that integrates an [TorchOperator] - [pydvl.influence.torch.operator.base.TorchOperator] and - a [TorchPerSampleGradientProvider] - [pydvl.influence.torch.operator.gradient_provider.TorchPerSampleGradientProvider] - - This block is designed to be flexible, handling different computational modes via - an abstract operator and gradient provider. - """ - - def __init__(self, op: TorchOperator, gp: TorchPerSampleGradientProvider): - super().__init__(op, gp) - - def to(self, device: torch.device): - self.gp = self.gp.to(device) - self.op = self.op.to(device) - return self - - -class TorchBlockMapper( - BlockMapper[torch.Tensor, TorchBatch, TorchOperatorGradientComposition] -): - """ - Class for mapping operations across multiple compositional blocks represented by - instances of [TorchOperatorGradientComposition] - [pydvl.influence.torch.influence_function_model.TorchOperatorGradientComposition]. - - This class takes a dictionary of compositional blocks and applies their methods to - batches or tensors, and aggregates the results. 
- """ - - def __init__( - self, composable_block_dict: OrderedDict[str, TorchOperatorGradientComposition] - ): - super().__init__(composable_block_dict) - - def _split_to_blocks( - self, z: torch.Tensor, dim: int = -1 - ) -> OrderedDict[str, torch.Tensor]: - block_sizes = [bi.op.input_size for bi in self.composable_block_dict.values()] - - block_dict = OrderedDict( - zip( - list(self.composable_block_dict.keys()), - torch.split(z, block_sizes, dim=dim), - ) - ) - return block_dict - - def to(self, device: torch.device): - self.composable_block_dict = OrderedDict( - [(k, bi.to(device)) for k, bi in self.composable_block_dict.items()] - ) - return self - - -class TorchComposableInfluence( - ComposableInfluence[torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper], - ModelInfoMixin, - ABC, -): - def __init__( - self, - model: torch.nn.Module, - block_structure: Union[ - BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]] - ] = BlockMode.FULL, - regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None, - ): - if isinstance(block_structure, BlockMode): - self.parameter_dict = ModelParameterDictBuilder(model).build( - block_structure - ) - else: - self.parameter_dict = block_structure - - self._regularization_dict = self._build_regularization_dict(regularization) - - super().__init__(model) - - @property - def block_names(self) -> List[str]: - return list(self.parameter_dict.keys()) - - @abstractmethod - def with_regularization( - self, regularization: Union[float, Dict[str, Optional[float]]] - ) -> TorchComposableInfluence: - pass - - def _build_regularization_dict( - self, regularization: Optional[Union[float, Dict[str, Optional[float]]]] - ) -> Dict[str, Optional[float]]: - if regularization is None or isinstance(regularization, float): - return { - k: self._validate_regularization(k, regularization) - for k in self.block_names - } - - if set(regularization.keys()).issubset(set(self.block_names)): - raise ValueError( - f"The regularization must be a float or the keys of the regularization" - f"dictionary must match a subset of" - f"block names: \n {self.block_names}.\n Found not in block names: \n" - f"{set(regularization.keys()).difference(set(self.block_names))}" - ) - return { - k: self._validate_regularization(k, regularization.get(k, None)) - for k in self.block_names - } - - @staticmethod - def _validate_regularization( - block_name: str, value: Optional[float] - ) -> Optional[float]: - if isinstance(value, float) and value < 0.0: - raise ValueError( - f"The regularization for block '{block_name}' must be non-negative, " - f"but found {value=}" - ) - return value - - @abstractmethod - def _create_block( - self, - block_params: Dict[str, torch.nn.Parameter], - data: DataLoader, - regularization: Optional[float], - ) -> TorchOperatorGradientComposition: - pass - - def _create_block_mapper(self, data: DataLoader) -> TorchBlockMapper: - block_influence_dict = OrderedDict() - for k, p in self.parameter_dict.items(): - reg = self._regularization_dict.get(k, None) - reg = self._validate_regularization(k, reg) - block_influence_dict[k] = self._create_block(p, data, reg).to(self.device) - - return TorchBlockMapper(block_influence_dict) - - @staticmethod - def _create_batch(x: torch.Tensor, y: torch.Tensor) -> TorchBatch: - return TorchBatch(x, y) - - def to(self, device: torch.device): - self.model = self.model.to(device) - if hasattr(self, "block_mapper") and self.block_mapper is not None: - self.block_mapper = self.block_mapper.to(device) - return self - - 
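# --- Editorial usage sketch (not part of the patch hunks above or below) ------
# The class that follows, InverseHarmonicMeanInfluence, is the first concrete
# subclass of TorchComposableInfluence in this series. The snippet below is a
# minimal, hedged example of how the composable influence API is intended to be
# used. The import path and the exact constructor arguments (loss,
# regularization, block_structure) are assumptions inferred from the
# surrounding diffs, not taken verbatim from this patch.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence import InfluenceMode
from pydvl.influence.torch import InverseHarmonicMeanInfluence  # assumed re-export
from pydvl.influence.torch.util import BlockMode

# Toy regression model and data.
model = torch.nn.Linear(5, 1)
loss = torch.nn.functional.mse_loss
x_train, y_train = torch.randn(100, 5), torch.randn(100, 1)
x_test, y_test = torch.randn(10, 5), torch.randn(10, 1)
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32)

# Hypothetical construction; the final signature may differ.
if_model = InverseHarmonicMeanInfluence(
    model, loss, regularization=0.1, block_structure=BlockMode.PARAMETER_WISE
)
if_model = if_model.fit(train_loader)

# Factors approximate H^{-1} grad_theta l(y_test, f_theta(x_test)) per sample.
factors = if_model.influence_factors(x_test, y_test)

# Pairwise influence values between test and training points.
values = if_model.influences(
    x_test, y_test, x_train, y_train, mode=InfluenceMode.Up
)

# Block-wise variants return an OrderedDict keyed by parameter block names.
block_values = if_model.influences_by_block(x_test, y_test, x_train, y_train)
# -------------------------------------------------------------------------------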
class InverseHarmonicMeanInfluence(TorchComposableInfluence): def __init__( self, diff --git a/src/pydvl/influence/torch/operator/base.py b/src/pydvl/influence/torch/operator.py similarity index 60% rename from src/pydvl/influence/torch/operator/base.py rename to src/pydvl/influence/torch/operator.py index c56398af9..2f2b02175 100644 --- a/src/pydvl/influence/torch/operator/base.py +++ b/src/pydvl/influence/torch/operator.py @@ -1,78 +1,16 @@ -from abc import ABC, abstractmethod -from typing import Callable, Dict, Generator, Optional, Type, Union +from typing import Callable, Generator, Union, Type, Optional, Dict import torch from torch import nn as nn from torch.utils.data import DataLoader -from ...array import LazyChunkSequence, SequenceAggregator -from ...types import Operator -from ..util import TorchBatch, TorchChunkAverageAggregator, TorchPointAverageAggregator -from .batch_operation import ( - BatchOperation, - GaussNewtonBatchOperation, - HessianBatchOperation, -) -from .bilinear_form import OperatorBilinearForm -from .gradient_provider import ( - GradientProviderFactoryType, - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, -) - - -class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): - def __init__(self, regularization: float = 0.0): - """ - Initializes the Operator with an optional regularization parameter. - - Args: - regularization: A non-negative float that represents the regularization - strength (default is 0.0). - - Raises: - ValueError: If the regularization parameter is negative. - """ - if regularization < 0: - raise ValueError("regularization must be non-negative") - self._regularization = regularization - - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - if value < 0: - raise ValueError("regularization must be non-negative") - self._regularization = value - - @property - @abstractmethod - def device(self): - pass - - @property - @abstractmethod - def dtype(self): - pass - - @abstractmethod - def to(self, device: torch.device): - pass - - @abstractmethod - def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: - pass - - def as_bilinear_form(self): - return OperatorBilinearForm(self) - - def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: - return self._apply_to_vec(vec.to(self.device)) - - def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: - return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) +from ..array import SequenceAggregator, LazyChunkSequence +from .base import TorchOperator, TorchBatch, \ + GradientProviderFactoryType, TorchPerSampleGradientProvider, TorchPerSampleAutoGrad +from .batch_operation import BatchOperation, \ + GaussNewtonBatchOperation, HessianBatchOperation, InverseHarmonicMeanBatchOperation +from .util import TorchPointAverageAggregator, \ + TorchChunkAverageAggregator class AggregateBatchOperator(TorchOperator): @@ -173,3 +111,27 @@ def __init__( ) aggregator = TorchChunkAverageAggregator() super().__init__(batch_op, dataloader, aggregator) + + +class InverseHarmonicMeanOperator(AggregateBatchOperator): + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + regularization: float, + gradient_provider_factory: Union[ + GradientProviderFactoryType, + Type[TorchPerSampleGradientProvider], + ] = TorchPerSampleAutoGrad, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + ): + batch_op = 
InverseHarmonicMeanBatchOperation( + model, + loss, + regularization, + gradient_provider_factory=gradient_provider_factory, + restrict_to=restrict_to, + ) + aggregator = TorchPointAverageAggregator(weighted=False) + super().__init__(batch_op, dataloader, aggregator) diff --git a/src/pydvl/influence/torch/operator/__init__.py b/src/pydvl/influence/torch/operator/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/pydvl/influence/torch/operator/bilinear_form.py b/src/pydvl/influence/torch/operator/bilinear_form.py deleted file mode 100644 index 0fbacd718..000000000 --- a/src/pydvl/influence/torch/operator/bilinear_form.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import TYPE_CHECKING, Optional - -import torch - -from ...types import BilinearForm -from ..util import TorchBatch -from .gradient_provider import TorchPerSampleGradientProvider - -if TYPE_CHECKING: - from .base import TorchOperator - - -class OperatorBilinearForm( - BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] -): - r""" - Base class for bilinear forms based on an instance of - [TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. This means it - computes weighted inner products of the form: - - $$ \langle \operatorname{Op}(x), y \rangle $$ - - """ - - def __init__( - self, - operator: "TorchOperator", - ): - self.operator = operator - - def inner_product( - self, left: torch.Tensor, right: Optional[torch.Tensor] - ) -> torch.Tensor: - r""" - Computes the weighted inner product of two vectors, i.e. - - $$ \langle x, y \rangle_{B} = \langle \operatorname{Op}(x), y \rangle $$ - - Args: - left: The first tensor in the inner product computation. - right: The second tensor, optional; if not provided, the inner product will - use `left` tensor for both arguments. - - Returns: - A tensor representing the inner product. - """ - if right is None: - right = left - if left.shape[0] <= right.shape[0]: - return self._inner_product(left, right) - return self._inner_product(right, left).T - - def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: - left_result = self.operator.apply_to_mat(left) - - if left_result.ndim == right.ndim and left.shape[-1] == right.shape[-1]: - return left_result @ right.T - - return torch.einsum("ia,j...a->ij...", left_result, right) diff --git a/src/pydvl/influence/torch/operator/gradient_provider.py b/src/pydvl/influence/torch/operator/gradient_provider.py deleted file mode 100644 index f3cf8e554..000000000 --- a/src/pydvl/influence/torch/operator/gradient_provider.py +++ /dev/null @@ -1,262 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional - -import torch -from torch.func import functional_call - -from ...types import PerSampleGradientProvider -from ..functional import ( - create_matrix_jacobian_product_function, - create_per_sample_gradient_function, - create_per_sample_mixed_derivative_function, -) -from ..util import ( - BlockMode, - LossType, - ModelParameterDictBuilder, - TorchBatch, - flatten_dimensions, -) - - -class TorchPerSampleGradientProvider( - PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC -): - r""" - Abstract base class for calculating per-sample gradients of a function defined by - a [torch.nn.Module][torch.nn.Module] and a loss function. - - This class must be subclassed with implementations for its abstract methods tailored - to specific gradient computation needs, e.g. using [torch.autograd][torch.autograd] - or stochastic finite differences. 
- - Consider a function - - $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times - \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = - \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ - - e.g. a two layer neural network $f$ with a loss function, then this object should - compute the expressions: - - $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), - \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), - \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ - - """ - - def __init__( - self, - model: torch.nn.Module, - loss: LossType, - restrict_to: Optional[Dict[str, torch.nn.Parameter]], - ): - self.loss = loss - self.model = model - - if restrict_to is None: - restrict_to = ModelParameterDictBuilder(model).build(BlockMode.FULL) - - self.params_to_restrict_to = restrict_to - - def to(self, device: torch.device): - self.model = self.model.to(device) - self.params_to_restrict_to = { - k: p.detach() - for k, p in self.model.named_parameters() - if k in self.params_to_restrict_to - } - return self - - @property - def device(self): - return next(self.model.parameters()).device - - @property - def dtype(self): - return next(self.model.parameters()).dtype - - @abstractmethod - def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - pass - - @abstractmethod - def _per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: - pass - - @abstractmethod - def _matrix_jacobian_product( - self, - batch: TorchBatch, - g: torch.Tensor, - ) -> torch.Tensor: - pass - - @staticmethod - def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} - - def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - r""" - Computes and returns a dictionary mapping gradient names to their respective - per-sample gradients. Given the example in the class docstring, this means - - $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, - \text{batch.x}, \text{batch.y}), $$ - - where the first dimension of the resulting tensors is always considered to be - the batch dimension, so the shape of the resulting tensors are $(N, d_i)$, - where $N$ is the number of samples in the batch. - - Args: - batch: The batch of data for which to compute gradients. - - Returns: - A dictionary where keys are gradient identifiers and values are the - gradients computed per sample. - """ - gradient_dict = self._per_sample_gradient_dict(batch.to(self.device)) - return self._detach_dict(gradient_dict) - - def per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: - r""" - Computes and returns a dictionary mapping gradient names to their respective - per-sample mixed gradients. In this context, mixed gradients refer to computing - gradients with respect to the instance definition in addition to - compute derivatives with respect to the input batch. - Given the example in the class docstring, this means - - $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, - \omega_2, \text{batch.x}, \text{batch.y}), $$ - - where the first dimension of the resulting tensors is always considered to be - the batch dimension and the last to be the non-batch input related derivatives. - So the shape of the resulting tensors are $(N, n, d_i)$, - where $N$ is the number of samples in the batch. 
- - Args: - batch: The batch of data for which to compute mixed gradients. - - Returns: - A dictionary where keys are gradient identifiers and values are the - mixed gradients computed per sample. - """ - gradient_dict = self._per_sample_mixed_gradient_dict(batch.to(self.device)) - return self._detach_dict(gradient_dict) - - def matrix_jacobian_product( - self, - batch: TorchBatch, - g: torch.Tensor, - ) -> torch.Tensor: - r""" - Computes the matrix-Jacobian product for the provided batch and input tensor. - Given the example in the class docstring, this means - - $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, - \text{batch.x}, \text{batch.y}), - \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, - \text{batch.x}, \text{batch.y})) \cdot g^T$$ - - where g must be a tensor of shape $(K, d_1+d_2)$, so the resulting tensor - is of shape $(N, K)$. - - Args: - batch: The batch of data for which to compute the Jacobian. - g: The tensor to be used in the matrix-Jacobian product - calculation. - - Returns: - The resulting tensor from the matrix-Jacobian product computation. - """ - result = self._matrix_jacobian_product(batch.to(self.device), g.to(self.device)) - if result.requires_grad: - result = result.detach() - return result - - def per_sample_flat_gradient(self, batch: TorchBatch) -> torch.Tensor: - return flatten_dimensions( - self.per_sample_gradient_dict(batch).values(), shape=(batch.x.shape[0], -1) - ) - - def per_sample_flat_mixed_gradient(self, batch: TorchBatch) -> torch.Tensor: - shape = (*batch.x.shape, -1) - return flatten_dimensions( - self.per_sample_mixed_gradient_dict(batch).values(), shape=shape - ) - - -class TorchPerSampleAutoGrad(TorchPerSampleGradientProvider): - r""" - Compute per-sample gradients of a function defined by - a [torch.nn.Module][torch.nn.Module] and a loss function using - [torch.func][torch.func]. - - Consider a function - - $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times - \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = - \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ - - e.g. 
a two layer neural network $f$ with a loss function, then this object should - compute the expressions: - - $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), - \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), - \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ - - """ - - def __init__( - self, - model: torch.nn.Module, - loss: LossType, - restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, - ): - super().__init__(model, loss, restrict_to) - self._per_sample_gradient_function = create_per_sample_gradient_function( - model, loss - ) - self._per_sample_mixed_gradient_func = ( - create_per_sample_mixed_derivative_function(model, loss) - ) - - def _compute_loss( - self, params: Dict[str, torch.Tensor], x: torch.Tensor, y: torch.Tensor - ) -> torch.Tensor: - outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) - return self.loss(outputs, y.unsqueeze(0)) - - def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - return self._per_sample_gradient_function( - self.params_to_restrict_to, batch.x, batch.y - ) - - def _per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: - return self._per_sample_mixed_gradient_func( - self.params_to_restrict_to, batch.x, batch.y - ) - - def _matrix_jacobian_product( - self, - batch: TorchBatch, - g: torch.Tensor, - ) -> torch.Tensor: - matrix_jacobian_product_func = create_matrix_jacobian_product_function( - self.model, self.loss, g - ) - return matrix_jacobian_product_func( - self.params_to_restrict_to, batch.x, batch.y - ) - - -GradientProviderFactoryType = Callable[ - [torch.nn.Module, LossType, Optional[Dict[str, torch.nn.Parameter]]], - TorchPerSampleGradientProvider, -] diff --git a/src/pydvl/influence/torch/operator/solve.py b/src/pydvl/influence/torch/operator/solve.py deleted file mode 100644 index 8c4b4ee55..000000000 --- a/src/pydvl/influence/torch/operator/solve.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Callable, Dict, Optional, Type, Union - -import torch -from torch import nn as nn -from torch.utils.data import DataLoader - -from ..util import TorchPointAverageAggregator -from .base import AggregateBatchOperator, TorchOperator -from .batch_operation import InverseHarmonicMeanBatchOperation -from .gradient_provider import ( - GradientProviderFactoryType, - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, -) - -__all__ = ["InverseHarmonicMeanOperator"] - - -class InverseHarmonicMeanOperator(AggregateBatchOperator): - def __init__( - self, - model: nn.Module, - loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - dataloader: DataLoader, - regularization: float, - gradient_provider_factory: Union[ - GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, - restrict_to: Optional[Dict[str, nn.Parameter]] = None, - ): - batch_op = InverseHarmonicMeanBatchOperation( - model, - loss, - regularization, - gradient_provider_factory=gradient_provider_factory, - restrict_to=restrict_to, - ) - aggregator = TorchPointAverageAggregator(weighted=False) - super().__init__(batch_op, dataloader, aggregator) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index 78df32fd4..d7d4b0226 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -37,7 +37,6 @@ NumpyConverter, SequenceAggregator, ) -from ..types import Batch logger = logging.getLogger(__name__) @@ -57,7 +56,6 @@ "inverse_rank_one_update", 
"TorchPointAverageAggregator", "TorchChunkAverageAggregator", - "TorchBatch", "LossType", "ModelParameterDictBuilder", "BlockMode", @@ -675,35 +673,6 @@ def inverse_rank_one_update( LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] -@dataclass(frozen=True) -class TorchBatch(Batch): - """ - A convenience class for handling batches of data. Validates, the alignment - of the first dimension (batch dimension) of the input and target tensor - - Attributes: - x: The input tensor that contains features or data points. - y: The target tensor that contains labels corresponding to the inputs. - - """ - - x: torch.Tensor - y: torch.Tensor - - def __post_init__(self): - if self.x.shape[0] != self.y.shape[0]: - raise ValueError( - f"The first dimension of x and y must be the same, " - f"got {self.x.shape[0]} and {self.y.shape[0]}" - ) - - def __len__(self): - return self.x.shape[0] - - def to(self, device: torch.device): - return TorchBatch(self.x.to(device), self.y.to(device)) - - class BlockMode(Enum): LAYER_WISE: str = "layer_wise" PARAMETER_WISE: str = "parameter_wise" diff --git a/tests/influence/torch/operator/test_batch_operation.py b/tests/influence/torch/operator/test_batch_operation.py index 1f5deed71..5c6c89c1b 100644 --- a/tests/influence/torch/operator/test_batch_operation.py +++ b/tests/influence/torch/operator/test_batch_operation.py @@ -3,11 +3,11 @@ import pytest import torch -from pydvl.influence.torch.operator.batch_operation import ( +from pydvl.influence.torch.base import TorchBatch +from pydvl.influence.torch.batch_operation import ( GaussNewtonBatchOperation, HessianBatchOperation, ) -from pydvl.influence.torch.util import TorchBatch from ..test_util import model_data, test_parameters diff --git a/tests/influence/torch/operator/test_gradient_provider.py b/tests/influence/torch/operator/test_gradient_provider.py index 47324acee..59a986233 100644 --- a/tests/influence/torch/operator/test_gradient_provider.py +++ b/tests/influence/torch/operator/test_gradient_provider.py @@ -2,8 +2,7 @@ import pytest import torch -from pydvl.influence.torch.operator.gradient_provider import TorchPerSampleAutoGrad -from pydvl.influence.torch.util import TorchBatch +from pydvl.influence.torch.base import TorchBatch, TorchPerSampleAutoGrad from ...conftest import linear_mixed_second_derivative_analytical, linear_model from ..conftest import DATA_OUTPUT_NOISE, linear_mvp_model From 964f566f9281397bb8295c0291fd0b79a04d6016 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 22 May 2024 19:51:36 +0200 Subject: [PATCH 19/43] Fix linting issues --- src/pydvl/influence/torch/base.py | 28 ++++++++++++++++++++------- src/pydvl/influence/torch/operator.py | 24 +++++++++++++++-------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index 6c570fb72..982af400c 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -3,19 +3,33 @@ from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Optional, Dict, Callable, Union, List +from typing import Callable, Dict, List, Optional, Union import torch from torch.func import functional_call from torch.utils.data import DataLoader -from .functional import create_per_sample_gradient_function, \ - create_per_sample_mixed_derivative_function, create_matrix_jacobian_product_function -from .util import LossType, ModelParameterDictBuilder, \ - BlockMode, flatten_dimensions, 
ModelInfoMixin from ..base_influence_function_model import ComposableInfluence -from ..types import PerSampleGradientProvider, Operator, BilinearForm, Batch, \ - OperatorGradientComposition, BlockMapper +from ..types import ( + Batch, + BilinearForm, + BlockMapper, + Operator, + OperatorGradientComposition, + PerSampleGradientProvider, +) +from .functional import ( + create_matrix_jacobian_product_function, + create_per_sample_gradient_function, + create_per_sample_mixed_derivative_function, +) +from .util import ( + BlockMode, + LossType, + ModelInfoMixin, + ModelParameterDictBuilder, + flatten_dimensions, +) @dataclass(frozen=True) diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 2f2b02175..5b58b9f70 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -1,16 +1,24 @@ -from typing import Callable, Generator, Union, Type, Optional, Dict +from typing import Callable, Dict, Generator, Optional, Type, Union import torch from torch import nn as nn from torch.utils.data import DataLoader -from ..array import SequenceAggregator, LazyChunkSequence -from .base import TorchOperator, TorchBatch, \ - GradientProviderFactoryType, TorchPerSampleGradientProvider, TorchPerSampleAutoGrad -from .batch_operation import BatchOperation, \ - GaussNewtonBatchOperation, HessianBatchOperation, InverseHarmonicMeanBatchOperation -from .util import TorchPointAverageAggregator, \ - TorchChunkAverageAggregator +from ..array import LazyChunkSequence, SequenceAggregator +from .base import ( + GradientProviderFactoryType, + TorchBatch, + TorchOperator, + TorchPerSampleAutoGrad, + TorchPerSampleGradientProvider, +) +from .batch_operation import ( + BatchOperation, + GaussNewtonBatchOperation, + HessianBatchOperation, + InverseHarmonicMeanBatchOperation, +) +from .util import TorchChunkAverageAggregator, TorchPointAverageAggregator class AggregateBatchOperator(TorchOperator): From 3fa9a95ba5aa6f5f455895aea304ac6912a3ed12 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Thu, 23 May 2024 12:54:50 +0200 Subject: [PATCH 20/43] Move test modules --- tests/influence/torch/operator/__init__.py | 0 .../torch/{operator => }/test_batch_operation.py | 3 +-- .../{operator => }/test_gradient_provider.py | 15 ++++++--------- 3 files changed, 7 insertions(+), 11 deletions(-) delete mode 100644 tests/influence/torch/operator/__init__.py rename tests/influence/torch/{operator => }/test_batch_operation.py (96%) rename tests/influence/torch/{operator => }/test_gradient_provider.py (89%) diff --git a/tests/influence/torch/operator/__init__.py b/tests/influence/torch/operator/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/influence/torch/operator/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py similarity index 96% rename from tests/influence/torch/operator/test_batch_operation.py rename to tests/influence/torch/test_batch_operation.py index 5c6c89c1b..b6141a6c8 100644 --- a/tests/influence/torch/operator/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -2,6 +2,7 @@ import pytest import torch +from influence.torch.test_util import model_data, test_parameters from pydvl.influence.torch.base import TorchBatch from pydvl.influence.torch.batch_operation import ( @@ -9,8 +10,6 @@ HessianBatchOperation, ) -from ..test_util import model_data, test_parameters - @pytest.mark.torch @pytest.mark.parametrize( diff --git 
a/tests/influence/torch/operator/test_gradient_provider.py b/tests/influence/torch/test_gradient_provider.py similarity index 89% rename from tests/influence/torch/operator/test_gradient_provider.py rename to tests/influence/torch/test_gradient_provider.py index 59a986233..c15814e37 100644 --- a/tests/influence/torch/operator/test_gradient_provider.py +++ b/tests/influence/torch/test_gradient_provider.py @@ -1,11 +1,10 @@ import numpy as np import pytest import torch +from influence.conftest import linear_mixed_second_derivative_analytical, linear_model +from influence.torch.conftest import DATA_OUTPUT_NOISE, linear_mvp_model -from pydvl.influence.torch.base import TorchBatch, TorchPerSampleAutoGrad - -from ...conftest import linear_mixed_second_derivative_analytical, linear_model -from ..conftest import DATA_OUTPUT_NOISE, linear_mvp_model +from pydvl.influence.torch.base import TorchAutoGrad, TorchBatch class TestTorchPerSampleAutograd: @@ -22,7 +21,7 @@ def test_per_sample_gradient(self, in_features, out_features, batch_size): y = torch.randn(batch_size, out_features) params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad} - gp = TorchPerSampleAutoGrad(model, loss, restrict_to=params) + gp = TorchAutoGrad(model, loss, restrict_to=params) gradients = gp.per_sample_gradient_dict(TorchBatch(x, y)) flat_gradients = gp.per_sample_flat_gradient(TorchBatch(x, y)) @@ -69,7 +68,7 @@ def test_mixed_derivatives(self, in_features, out_features, train_set_size): torch_train_x = torch.as_tensor(train_x) torch_train_y = torch.as_tensor(train_y) - gp = TorchPerSampleAutoGrad(model, loss, restrict_to=params) + gp = TorchAutoGrad(model, loss, restrict_to=params) flat_functorch_mixed_derivatives = gp.per_sample_flat_mixed_gradient( TorchBatch(torch_train_x, torch_train_y) ) @@ -93,9 +92,7 @@ def test_matrix_jacobian_product( y = torch.randn(batch_size, out_features, requires_grad=True) y_pred = model(x) - gp = TorchPerSampleAutoGrad( - model, torch.nn.functional.mse_loss, restrict_to=params - ) + gp = TorchAutoGrad(model, torch.nn.functional.mse_loss, restrict_to=params) G = torch.randn((10, out_features * (in_features + 1))) mjp = gp.matrix_jacobian_product(TorchBatch(x, y), G) From a1d4c9a024e182e46f5286fd26c592b908d580e5 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Thu, 23 May 2024 12:55:45 +0200 Subject: [PATCH 21/43] Add module description to pydvl.influence.types --- src/pydvl/influence/types.py | 46 ++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index b43ddcb4a..0a4d5b0fd 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -1,3 +1,45 @@ +""" +This module offers a set of generic types, which can be used to build modular and +flexible components for influence computation for different tensor frameworks. + + +Key components include: + +1. [GradientProvider][pydvl.influence.types.GradientProvider]: A generic + abstract base class designed to provide methods for computing per-sample + gradients and other related computations for given data batches. + +2. [BilinearForm][pydvl.influence.types.BilinearForm]: A generic abstract base class + for representing bilinear forms for computing inner products involving gradients. + +3. [Operator][pydvl.influence.types.Operator]: A generic abstract base class for + operators that can apply transformations to vectors and matrices and can be + represented as bilinear forms. + +4. 
[OperatorGradientComposition][pydvl.influence.types.OperatorGradientComposition]: A + generic abstract composition class that integrates an operator with a gradient + provider to compute interactions between batches of data. + +5. [BlockMapper][pydvl.influence.types.BlockMapper]: A generic abstract base class + for mapping operations across multiple compositional blocks, given by objects + of type + [OperatorGradientComposition][pydvl.influence.types.OperatorGradientComposition], + and aggregating the results. + +To see the usage of these types, see the implementation +[ComposableInfluence][pydvl.influence.base_influence_function_model.ComposableInfluence] +. Using these components allows the straightforward implementation of various +combinations of approximations of inverse Hessian applications +(or Gauss-Newton approximations), different blocking strategies +(e.g. layer-wise or block-wise) and different ways to +compute gradients. + +For the usage with a specific tensor framework, these types must be subclassed. An +example for [torch][torch] is provided in the module +[pydvl.influence.torch.base][pydvl.influence.torch.base] and the base class +[TorchComposableInfluence][pydvl.influence.torch.base.TorchComposableInfluence]. +""" + from __future__ import annotations from abc import ABC, abstractmethod @@ -55,7 +97,7 @@ class Batch(Generic[TensorType]): BatchType = TypeVar("BatchType", bound=Batch) -class PerSampleGradientProvider(Generic[BatchType, TensorType], ABC): +class GradientProvider(Generic[BatchType, TensorType], ABC): r""" Provides an interface for calculating per-sample gradients and other related computations for a given batch of data. @@ -196,7 +238,7 @@ def per_sample_flat_mixed_gradient(self, batch: BatchType) -> TensorType: """ -GradientProviderType = TypeVar("GradientProviderType", bound=PerSampleGradientProvider) +GradientProviderType = TypeVar("GradientProviderType", bound=GradientProvider) class BilinearForm(Generic[TensorType, BatchType, GradientProviderType], ABC): From 545308a90dbf5767420d486dcdd1073ca8258cd6 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Fri, 24 May 2024 22:41:45 +0200 Subject: [PATCH 22/43] Update reference bibtex file --- docs/assets/pydvl.bib | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/docs/assets/pydvl.bib b/docs/assets/pydvl.bib index 724e75f20..ed4dc30d3 100644 --- a/docs/assets/pydvl.bib +++ b/docs/assets/pydvl.bib @@ -122,7 +122,8 @@ @inproceedings{george_fast_2018 publisher = {Curran Associates, Inc.}, url = {https://proceedings.neurips.cc/paper/2018/hash/48000647b315f6f00f913caa757a70b3-Abstract.html}, urldate = {2024-01-12}, - abstract = {Optimization algorithms that leverage gradient covariance information, such as variants of natural gradient descent (Amari, 1998), offer the prospect of yielding more effective descent directions. For models with many parameters, the covari- ance matrix they are based on becomes gigantic, making them inapplicable in their original form. This has motivated research into both simple diagonal approxima- tions and more sophisticated factored approximations such as KFAC (Heskes, 2000; Martens \& Grosse, 2015; Grosse \& Martens, 2016). In the present work we draw inspiration from both to propose a novel approximation that is provably better than KFAC and amendable to cheap partial updates. 
It consists in tracking a diagonal variance, not in parameter coordinates, but in a Kronecker-factored eigenbasis, in which the diagonal approximation is likely to be more effective. Experiments show improvements over KFAC in optimization speed for several deep network architectures.} + abstract = {Optimization algorithms that leverage gradient covariance information, such as variants of natural gradient descent (Amari, 1998), offer the prospect of yielding more effective descent directions. For models with many parameters, the covari- ance matrix they are based on becomes gigantic, making them inapplicable in their original form. This has motivated research into both simple diagonal approxima- tions and more sophisticated factored approximations such as KFAC (Heskes, 2000; Martens \& Grosse, 2015; Grosse \& Martens, 2016). In the present work we draw inspiration from both to propose a novel approximation that is provably better than KFAC and amendable to cheap partial updates. It consists in tracking a diagonal variance, not in parameter coordinates, but in a Kronecker-factored eigenbasis, in which the diagonal approximation is likely to be more effective. Experiments show improvements over KFAC in optimization speed for several deep network architectures.}, + keywords = {notion} } @inproceedings{ghorbani_data_2019, @@ -175,7 +176,8 @@ @inproceedings{hataya_nystrom_2023 urldate = {2024-02-26}, abstract = {The essential difficulty of gradient-based bilevel optimization using implicit differentiation is to estimate the inverse Hessian vector product with respect to neural network parameters. This paper proposes to tackle this problem by the Nyström method and the Woodbury matrix identity, exploiting the low-rankness of the Hessian. Compared to existing methods using iterative approximation, such as conjugate gradient and the Neumann series approximation, the proposed method avoids numerical instability and can be efficiently computed in matrix operations without iterations. As a result, the proposed method works stably in various tasks and is faster than iterative approximations. Throughout experiments including large-scale hyperparameter optimization and meta learning, we demonstrate that the Nyström method consistently achieves comparable or even superior performance to other approaches. The source code is available from https://github.com/moskomule/hypergrad.}, eventtitle = {International {{Conference}} on {{Artificial Intelligence}} and {{Statistics}}}, - langid = {english} + langid = {english}, + keywords = {notion} } @article{ji_breakdownfree_2017, @@ -292,6 +294,18 @@ @inproceedings{kwon_beta_2022 keywords = {notion} } +@inproceedings{kwon_datainf_2023, + title = {{{DataInf}}: {{Efficiently Estimating Data Influence}} in {{LoRA-tuned LLMs}} and {{Diffusion Models}}}, + shorttitle = {{{DataInf}}}, + author = {Kwon, Yongchan and Wu, Eric and Wu, Kevin and Zou, James}, + date = {2023-10-13}, + doi = {10.48550/arXiv.2310.00902}, + url = {https://openreview.net/forum?id=9m02ib92Wz}, + urldate = {2023-10-27}, + abstract = {Quantifying the impact of training data points is crucial for understanding the outputs of machine learning models and for improving the transparency of the AI pipeline. The influence function is a principled and popular data attribution method, but its computational cost often makes it challenging to use. This issue becomes more pronounced in the setting of large language models and text-to-image models. 
In this work, we propose DataInf, an efficient influence approximation method that is practical for large-scale generative AI models. Leveraging an easy-to-compute closed-form expression, DataInf outperforms existing influence computation algorithms in terms of computational and memory efficiency. Our theoretical analysis shows that DataInf is particularly well-suited for parameter-efficient fine-tuning techniques such as LoRA. Through systematic empirical evaluations, we show that DataInf accurately approximates influence scores and is orders of magnitude faster than existing methods. In applications to RoBERTa-large, Llama-2-13B-chat, and stable-diffusion-v1.5 models, DataInf effectively identifies the most influential fine-tuning examples better than other approximate influence scores. Moreover, it can help to identify which data points are mislabeled.}, + eventtitle = {The {{Twelfth International Conference}} on {{Learning Representations}}} +} + @inproceedings{kwon_dataoob_2023, title = {Data-{{OOB}}: {{Out-of-bag Estimate}} as a {{Simple}} and {{Efficient Data Value}}}, shorttitle = {Data-{{OOB}}}, @@ -303,7 +317,7 @@ @inproceedings{kwon_dataoob_2023 issn = {2640-3498}, url = {https://proceedings.mlr.press/v202/kwon23e.html}, urldate = {2023-09-06}, - abstract = {Data valuation is a powerful framework for providing statistical insights into which data are beneficial or detrimental to model training. Many Shapley-based data valuation methods have shown promising results in various downstream tasks, however, they are well known to be computationally challenging as it requires training a large number of models. As a result, it has been recognized as infeasible to apply to large datasets. To address this issue, we propose Data-OOB, a new data valuation method for a bagging model that utilizes the out-of-bag estimate. The proposed method is computationally efficient and can scale to millions of data by reusing trained weak learners. Specifically, Data-OOB takes less than 2.25 hours on a single CPU processor when there are \$10\^{}6\$ samples to evaluate and the input dimension is 100. Furthermore, Data-OOB has solid theoretical interpretations in that it identifies the same important data point as the infinitesimal jackknife influence function when two different points are compared. We conduct comprehensive experiments using 12 classification datasets, each with thousands of sample sizes. We demonstrate that the proposed method significantly outperforms existing state-of-the-art data valuation methods in identifying mislabeled data and finding a set of helpful (or harmful) data points, highlighting the potential for applying data values in real-world applications.}, + abstract = {Data valuation is a powerful framework for providing statistical insights into which data are beneficial or detrimental to model training. Many Shapley-based data valuation methods have shown promising results in various downstream tasks, however, they are well known to be computationally challenging as it requires training a large number of models. As a result, it has been recognized as infeasible to apply to large datasets. To address this issue, we propose Data-OOB, a new data valuation method for a bagging model that utilizes the out-of-bag estimate. The proposed method is computationally efficient and can scale to millions of data by reusing trained weak learners. 
Specifically, Data-OOB takes less than 2.25 hours on a single CPU processor when there are \$10\textasciicircum 6\$ samples to evaluate and the input dimension is 100. Furthermore, Data-OOB has solid theoretical interpretations in that it identifies the same important data point as the infinitesimal jackknife influence function when two different points are compared. We conduct comprehensive experiments using 12 classification datasets, each with thousands of sample sizes. We demonstrate that the proposed method significantly outperforms existing state-of-the-art data valuation methods in identifying mislabeled data and finding a set of helpful (or harmful) data points, highlighting the potential for applying data values in real-world applications.}, eventtitle = {International {{Conference}} on {{Machine Learning}}}, langid = {english}, keywords = {notion} From 6337a5f150bd19a1b7bfa5a7dae63f6cfc9f3a63 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Fri, 24 May 2024 22:42:40 +0200 Subject: [PATCH 23/43] Refactor influence package: * improve naming * add and extend documentation * simplify classes --- src/pydvl/influence/array.py | 5 +- .../base_influence_function_model.py | 203 +++++++++++++- src/pydvl/influence/torch/__init__.py | 1 + src/pydvl/influence/torch/base.py | 147 +++++----- src/pydvl/influence/torch/batch_operation.py | 250 ++++++++++++++---- .../torch/influence_function_model.py | 115 +++++++- src/pydvl/influence/torch/operator.py | 183 +++++++++++-- src/pydvl/influence/torch/util.py | 95 ++++++- src/pydvl/influence/types.py | 92 ++++--- tests/influence/test_influence_calculator.py | 4 - tests/influence/torch/test_batch_operation.py | 3 +- .../influence/torch/test_gradient_provider.py | 13 +- tests/influence/torch/test_util.py | 2 +- 13 files changed, 900 insertions(+), 213 deletions(-) diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index 5faa288ac..7a8c5e881 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -405,8 +405,7 @@ def _initialize_zarr_array( class SumAggregator(SequenceAggregator): def __call__(self, tensor_sequence: LazyChunkSequence): """ - Aggregates tensors from a single-level generator by summing up. This method simply - collects each tensor emitted by the generator into a single list. + Aggregates tensors from a single-level generator by summing up. Args: tensor_sequence: Object wrapping a generator that yields `TensorType` @@ -418,5 +417,5 @@ def __call__(self, tensor_sequence: LazyChunkSequence): tensor_generator = tensor_sequence.generator_factory() result = next(tensor_generator) for tensor in tensor_generator: - result += tensor + result = result + tensor return result diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 5077f8d88..51feb8ab4 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -30,7 +30,10 @@ def __init__(self, object_type: Type): class NotImplementedLayerRepresentationException(ValueError): def __init__(self, module_id: str): - message = f"Only Linear layers are supported, but found module {module_id} requiring grad." + message = ( + f"Only Linear layers are supported, but found module {module_id} " + f"requiring grad." 
+ ) super().__init__(message) @@ -82,6 +85,23 @@ def wrapper(self, *args, **kwargs): return wrapper def influence_factors(self, x: TensorType, y: TensorType) -> TensorType: + r""" + Computes the approximation of + + \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \] + + where the gradient is meant to be per sample of the batch $(x, y)$. + For all input tensors it is assumed, + that the first dimension is the batch dimension. + + Args: + x: model input to use in the gradient computations + y: label tensor to compute gradients + + Returns: + Tensor representing the element-wise inverse Hessian matrix vector products + + """ if not self.is_fitted: raise NotFittedException(type(self)) return self._influence_factors(x, y) @@ -112,6 +132,36 @@ def influences( y: Optional[TensorType] = None, mode: InfluenceMode = InfluenceMode.Up, ) -> TensorType: + r""" + Computes the approximation of + + \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}})), + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle H^{-1}\nabla_{\theta} \ell(y_{test}, f_{\theta}(x_{test})), + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. + + Args: + x_test: model input to use in the gradient computations + of $H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test}))$ + y_test: label tensor to compute gradients + x: optional model input to use in the gradient computations + $\nabla_{theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{test}$ + y: optional label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Tensor representing the element-wise scalar products for the provided batch + + """ if not self.is_fitted: raise NotFittedException(type(self)) @@ -214,6 +264,11 @@ class ComposableInfluence( Generic[TensorType, BatchType, DataLoaderType, BlockMapperType], ABC, ): + """ + Generic abstract base class, that allow for block-wise computation of influence + quantities. Inherit from this base class for specific influence algorithms and + tensor frameworks. + """ block_mapper: BlockMapperType @@ -226,11 +281,30 @@ def is_fitted(self): @log_duration(log_level=logging.INFO) def fit(self, data: DataLoaderType) -> InfluenceFunctionModel: + """ + Fitting to provided data, by internally creating a block mapper instance from + it. + Args: + data: iterable of tensors + + Returns: + Fitted instance + """ self.block_mapper = self._create_block_mapper(data) return self @abstractmethod def _create_block_mapper(self, data: DataLoaderType) -> BlockMapperType: + """ + Override this method to create a block mapper instance, that can be used + to compute block-wise influence quantities. + + Args: + data: iterable of tensors + + Returns: + BlockMapper instance + """ pass @InfluenceFunctionModel.fit_required @@ -242,6 +316,39 @@ def influences_by_block( y: Optional[TensorType] = None, mode: InfluenceMode = InfluenceMode.Up, ) -> OrderedDict[str, TensorType]: + r""" + Compute the block-wise influence values for the provided data, i.e. an + approximation of + + \[ \langle H^{-1}\nabla_{theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}})), + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. 
+ + \[ \langle H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test})), + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. + + Args: + x_test: model input to use in the gradient computations + of the approximation of + $H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test}))$ + y_test: label tensor to compute gradients + x: optional model input to use in the gradient computations + $\nabla_{theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{test}$ + y: optional label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Ordered dictionary of tensors representing the element-wise scalar products + for the provided batch per block. + + """ left_batch = self._create_batch(x_test, y_test) if x is None: @@ -259,13 +366,29 @@ def influences_by_block( ) right_batch = self._create_batch(x, y) - return self.block_mapper.block_interactions(left_batch, right_batch, mode) + return self.block_mapper.interactions(left_batch, right_batch, mode) @InfluenceFunctionModel.fit_required def influence_factors_by_block( self, x: TensorType, y: TensorType ) -> OrderedDict[str, TensorType]: - return self.block_mapper.block_transformed_gradients(self._create_batch(x, y)) + r""" + Compute the block-wise approximation of + + \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \] + + where the gradient is meant to be per sample of the batch $(x, y)$. + + Args: + x: model input to use in the gradient computations + y: label tensor to compute gradients + + Returns: + Ordered dictionary of tensors representing the element-wise + approximate inverse Hessian matrix vector products per block. + + """ + return self.block_mapper.transformed_grads(self._create_batch(x, y)) @InfluenceFunctionModel.fit_required def influences_from_factors_by_block( @@ -275,13 +398,44 @@ def influences_from_factors_by_block( y: TensorType, mode: InfluenceMode = InfluenceMode.Up, ) -> OrderedDict[str, TensorType]: - return self.block_mapper.block_interactions_from_transformed_gradients( + r""" + Block-wise computation of + + \[ \langle z_{\text{test_factors}}, + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle z_{\text{test_factors}}, + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. The gradient is meant to be per sample + of the batch $(x, y)$. + + Args: + z_test_factors: pre-computed array, approximating + $H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}}))$ + x: model input to use in the gradient computations + $\nabla_{\theta}\ell(y, f_{\theta}(x))$, + resp. 
$\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{\text{test}}$ + y: label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Ordered dictionary of tensors representing the element-wise scalar products + for the provided batch per block + + """ + return self.block_mapper.interactions_from_transformed_grads( z_test_factors, self._create_batch(x, y), mode ) def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType: tensor_gen_factory = partial( - self.block_mapper.generate_transformed_gradients, self._create_batch(x, y) + self.block_mapper.generate_transformed_grads, self._create_batch(x, y) ) aggregator = SumAggregator() result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) @@ -308,7 +462,7 @@ def _influences( right_batch = self._create_batch(x, y) tensor_gen_factory = partial( - self.block_mapper.generate_gradient_interactions, + self.block_mapper.generate_interactions, left_batch, right_batch, mode, @@ -325,8 +479,39 @@ def influences_from_factors( y: TensorType, mode: InfluenceMode = InfluenceMode.Up, ) -> TensorType: + r""" + Computation of + + \[ \langle z_{\text{test_factors}}, + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle z_{\text{test_factors}}, + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. The gradient is meant to be per sample + of the batch $(x, y)$. + + Args: + z_test_factors: pre-computed array, approximating + $H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}}))$ + x: model input to use in the gradient computations + $\nabla_{\theta}\ell(y, f_{\theta}(x))$, + resp. 
$\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{\text{test}}$ + y: label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Tensor representing the element-wise scalar products for the provided batch + + """ + tensor_gen_factory = partial( - self.block_mapper.generate_interactions_from_transformed_gradients, + self.block_mapper.generate_interactions_from_transformed_grads, z_test_factors, self._create_batch(x, y), mode, @@ -339,4 +524,6 @@ def influences_from_factors( @staticmethod @abstractmethod def _create_batch(x: TensorType, y: TensorType) -> BatchType: - pass + """Implement this method to provide the creation of a subtype of + [Batch][pydvl.influence.types.Batch] for a specific framework + """ diff --git a/src/pydvl/influence/torch/__init__.py b/src/pydvl/influence/torch/__init__.py index 3bbd9552c..a1e0bb09a 100644 --- a/src/pydvl/influence/torch/__init__.py +++ b/src/pydvl/influence/torch/__init__.py @@ -3,6 +3,7 @@ CgInfluence, DirectInfluence, EkfacInfluence, + InverseHarmonicMeanInfluence, LissaInfluence, NystroemSketchInfluence, ) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index 982af400c..73a0f2933 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, TypeVar, Union import torch from torch.func import functional_call @@ -14,9 +14,9 @@ Batch, BilinearForm, BlockMapper, + GradientProvider, Operator, OperatorGradientComposition, - PerSampleGradientProvider, ) from .functional import ( create_matrix_jacobian_product_function, @@ -61,9 +61,7 @@ def to(self, device: torch.device): return TorchBatch(self.x.to(device), self.y.to(device)) -class TorchPerSampleGradientProvider( - PerSampleGradientProvider[TorchBatch, torch.Tensor], ABC -): +class TorchGradientProvider(GradientProvider[TorchBatch, torch.Tensor], ABC): r""" Abstract base class for calculating per-sample gradients of a function defined by a [torch.nn.Module][torch.nn.Module] and a loss function. 
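    For orientation, here is a minimal standalone sketch of the per-sample gradient
    technique that the autograd-based implementation below builds on (vmap over the
    gradient of a single-sample loss); the model, loss and names are illustrative
    only and not part of this patch:

    ```python
    # Illustrative sketch: per-sample gradients via torch.func (not the pyDVL code)
    import torch
    from torch.func import functional_call, grad, vmap

    model = torch.nn.Linear(3, 1)
    loss = torch.nn.functional.mse_loss
    params = {k: p.detach() for k, p in model.named_parameters()}

    def single_loss(p, x, y):
        # loss of a single sample, evaluated with explicit parameters
        return loss(functional_call(model, p, (x.unsqueeze(0),)), y.unsqueeze(0))

    x, y = torch.randn(8, 3), torch.randn(8, 1)
    per_sample_grads = vmap(grad(single_loss), in_dims=(None, 0, 0))(params, x, y)
    # per_sample_grads["weight"] has shape (8, 1, 3): one gradient per sample
    ```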
@@ -97,7 +95,9 @@ def __init__( self.model = model if restrict_to is None: - restrict_to = ModelParameterDictBuilder(model).build(BlockMode.FULL) + restrict_to = ModelParameterDictBuilder(model).build_from_block_mode( + BlockMode.FULL + ) self.params_to_restrict_to = restrict_to @@ -119,17 +119,15 @@ def dtype(self): return next(self.model.parameters()).dtype @abstractmethod - def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: pass @abstractmethod - def _per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: + def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: pass @abstractmethod - def _matrix_jacobian_product( + def _jacobian_prod( self, batch: TorchBatch, g: torch.Tensor, @@ -140,9 +138,9 @@ def _matrix_jacobian_product( def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} - def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + def grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: r""" - Computes and returns a dictionary mapping gradient names to their respective + Computes and returns a dictionary mapping parameter names to their respective per-sample gradients. Given the example in the class docstring, this means $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, @@ -159,12 +157,10 @@ def per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor] A dictionary where keys are gradient identifiers and values are the gradients computed per sample. """ - gradient_dict = self._per_sample_gradient_dict(batch.to(self.device)) + gradient_dict = self._grads(batch.to(self.device)) return self._detach_dict(gradient_dict) - def per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: + def mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: r""" Computes and returns a dictionary mapping gradient names to their respective per-sample mixed gradients. In this context, mixed gradients refer to computing @@ -187,10 +183,10 @@ def per_sample_mixed_gradient_dict( A dictionary where keys are gradient identifiers and values are the mixed gradients computed per sample. """ - gradient_dict = self._per_sample_mixed_gradient_dict(batch.to(self.device)) + gradient_dict = self._mixed_grads(batch.to(self.device)) return self._detach_dict(gradient_dict) - def matrix_jacobian_product( + def jacobian_prod( self, batch: TorchBatch, g: torch.Tensor, @@ -215,24 +211,22 @@ def matrix_jacobian_product( Returns: The resulting tensor from the matrix-Jacobian product computation. 
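        To make the shape contract concrete, an illustrative sketch (the flat
        gradient matrix here merely stands in for the per-sample Jacobian):

        ```python
        # Illustrative shape check for the matrix-Jacobian product contract
        import torch

        N, D, K = 8, 10, 3
        per_sample_flat_grads = torch.randn(N, D)  # stands in for the batch Jacobian
        g = torch.randn(K, D)
        result = per_sample_flat_grads @ g.T       # shape (N, K)
        assert result.shape == (N, K)
        ```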
""" - result = self._matrix_jacobian_product(batch.to(self.device), g.to(self.device)) + result = self._jacobian_prod(batch.to(self.device), g.to(self.device)) if result.requires_grad: result = result.detach() return result - def per_sample_flat_gradient(self, batch: TorchBatch) -> torch.Tensor: + def flat_grads(self, batch: TorchBatch) -> torch.Tensor: return flatten_dimensions( - self.per_sample_gradient_dict(batch).values(), shape=(batch.x.shape[0], -1) + self.grads(batch).values(), shape=(batch.x.shape[0], -1) ) - def per_sample_flat_mixed_gradient(self, batch: TorchBatch) -> torch.Tensor: + def flat_mixed_grads(self, batch: TorchBatch) -> torch.Tensor: shape = (*batch.x.shape, -1) - return flatten_dimensions( - self.per_sample_mixed_gradient_dict(batch).values(), shape=shape - ) + return flatten_dimensions(self.mixed_grads(batch).values(), shape=shape) -class TorchPerSampleAutoGrad(TorchPerSampleGradientProvider): +class TorchAutoGrad(TorchGradientProvider): r""" Compute per-sample gradients of a function defined by a [torch.nn.Module][torch.nn.Module] and a loss function using @@ -273,19 +267,17 @@ def _compute_loss( outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) return self.loss(outputs, y.unsqueeze(0)) - def _per_sample_gradient_dict(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: return self._per_sample_gradient_function( self.params_to_restrict_to, batch.x, batch.y ) - def _per_sample_mixed_gradient_dict( - self, batch: TorchBatch - ) -> Dict[str, torch.Tensor]: + def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: return self._per_sample_mixed_gradient_func( self.params_to_restrict_to, batch.x, batch.y ) - def _matrix_jacobian_product( + def _jacobian_prod( self, batch: TorchBatch, g: torch.Tensor, @@ -300,12 +292,12 @@ def _matrix_jacobian_product( GradientProviderFactoryType = Callable[ [torch.nn.Module, LossType, Optional[Dict[str, torch.nn.Parameter]]], - TorchPerSampleGradientProvider, + TorchGradientProvider, ] class OperatorBilinearForm( - BilinearForm[torch.Tensor, TorchBatch, TorchPerSampleGradientProvider] + BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider] ): r""" Base class for bilinear forms based on an instance of @@ -322,13 +314,13 @@ def __init__( ): self.operator = operator - def inner_product( + def inner_prod( self, left: torch.Tensor, right: Optional[torch.Tensor] ) -> torch.Tensor: r""" Computes the weighted inner product of two vectors, i.e. - $$ \langle x, y \rangle_{B} = \langle \operatorname{Op}(x), y \rangle $$ + $$ \langle \operatorname{Op}(\text{left}), \text{right} \rangle $$ Args: left: The first tensor in the inner product computation. @@ -354,30 +346,10 @@ def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tenso class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): - def __init__(self, regularization: float = 0.0): - """ - Initializes the Operator with an optional regularization parameter. - - Args: - regularization: A non-negative float that represents the regularization - strength (default is 0.0). - - Raises: - ValueError: If the regularization parameter is negative. 
- """ - if regularization < 0: - raise ValueError("regularization must be non-negative") - self._regularization = regularization - - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - if value < 0: - raise ValueError("regularization must be non-negative") - self._regularization = value + """ + Abstract base class for operators that can be applied to instances of + [torch.Tensor][torch.Tensor]. + """ @property @abstractmethod @@ -398,18 +370,46 @@ def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: pass def as_bilinear_form(self): + """ + Represent this operator as a + [OperatorBilinearForm][pydvl.influence.torch.base.OperatorBilinearForm]. + """ return OperatorBilinearForm(self) def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a single vector. + Args: + vec: A single vector consistent to the operator, i.e. it's length + must be equal to the property `input_size`. + + Returns: + A single vector after applying the batch operation + """ return self._apply_to_vec(vec.to(self.device)) def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a matrix. + Args: + mat: A matrix to apply the operator to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) +TorchOperatorType = TypeVar("TorchOperatorType", bound=TorchOperator) + + class TorchOperatorGradientComposition( OperatorGradientComposition[ - torch.Tensor, TorchBatch, TorchOperator, TorchPerSampleGradientProvider + torch.Tensor, TorchBatch, TorchOperatorType, TorchGradientProvider ] ): """ @@ -422,7 +422,7 @@ class TorchOperatorGradientComposition( an abstract operator and gradient provider. """ - def __init__(self, op: TorchOperator, gp: TorchPerSampleGradientProvider): + def __init__(self, op: TorchOperatorType, gp: TorchGradientProvider): super().__init__(op, gp) def to(self, device: torch.device): @@ -432,7 +432,9 @@ def to(self, device: torch.device): class TorchBlockMapper( - BlockMapper[torch.Tensor, TorchBatch, TorchOperatorGradientComposition] + BlockMapper[ + torch.Tensor, TorchBatch, TorchOperatorGradientComposition[TorchOperatorType] + ] ): """ Class for mapping operations across multiple compositional blocks represented by @@ -469,24 +471,31 @@ def to(self, device: torch.device): class TorchComposableInfluence( - ComposableInfluence[torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper], + ComposableInfluence[ + torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper[TorchOperatorType] + ], ModelInfoMixin, ABC, ): + """ + Abstract base class, that allow for block-wise computation of influence + quantities with the [torch][torch] framework. + Inherit from this base class for specific influence algorithms. 
+ """ + def __init__( self, model: torch.nn.Module, - block_structure: Union[ - BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]] - ] = BlockMode.FULL, + block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None, ): + parameter_dict_builder = ModelParameterDictBuilder(model) if isinstance(block_structure, BlockMode): - self.parameter_dict = ModelParameterDictBuilder(model).build( + self.parameter_dict = parameter_dict_builder.build_from_block_mode( block_structure ) else: - self.parameter_dict = block_structure + self.parameter_dict = parameter_dict_builder.build(block_structure) self._regularization_dict = self._build_regularization_dict(regularization) diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 4a881de6e..4c7075924 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -1,39 +1,48 @@ +r""" +This module contains abstractions and implementations for operations carried out on a +batch $b$. These operations are of the form + +$$ m(b) \cdot v$$, + +where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector or matrix. +These batch operations can be used to conveniently build aggregations or recursions +over sequence of batches, e.g. an average of the form + +$$ \frac{1}{|B|} \sum_{b in B}m(b)\cdot v$$, + +which is useful in the case that keeping $B$ in memory is not feasible. + +""" from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Type, Union +from typing import Callable, Dict, Optional, Type, TypeVar, Union import torch from .base import ( GradientProviderFactoryType, + TorchAutoGrad, TorchBatch, - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, + TorchGradientProvider, ) from .functional import create_batch_hvp_function from .util import LossType, inverse_rank_one_update, rank_one_mvp class BatchOperation(ABC): - def __init__(self, regularization: float = 0.0): - if regularization < 0: - raise ValueError("regularization must be non-negative") - self._regularization = regularization + r""" + Abstract base class to implement operations of the form + + $$ m(b) \cdot v $$ + + where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector + or matrix. + """ @property @abstractmethod - def n_parameters(self): + def input_size(self): pass - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - if value < 0: - raise ValueError("regularization must be non-negative") - self._regularization = value - @property @abstractmethod def device(self): @@ -53,9 +62,32 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: pass def apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor): + """ + Applies the batch operation to a single vector. + Args: + batch: Batch of data for computation + vec: A single vector consistent to the operation, i.e. it's length + must be equal to the property `input_size`. + + Returns: + A single vector after applying the batch operation + """ return self._apply_to_vec(batch.to(self.device), vec.to(self.device)) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. 
The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return torch.func.vmap( lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m), in_dims=(None, None, 0), @@ -64,20 +96,25 @@ def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: class ModelBasedBatchOperation(BatchOperation, ABC): + r""" + Abstract base class to implement operations of the form + + $$ m(\text{model}, b) \cdot v $$ + + where model is a [torch.nn.Module][torch.nn.Module]. + + """ + def __init__( self, model: torch.nn.Module, - loss: LossType, - regularization: float = 0.0, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): - super().__init__(regularization) if restrict_to is None: restrict_to = { k: p.detach() for k, p in model.named_parameters() if p.requires_grad } self.params_to_restrict_to = restrict_to - self.loss = loss self.model = model @property @@ -89,7 +126,7 @@ def dtype(self): return next(self.model.parameters()).dtype @property - def n_parameters(self): + def input_size(self): return sum(p.numel() for p in self.params_to_restrict_to.values()) def to(self, device: torch.device): @@ -103,17 +140,36 @@ def to(self, device: torch.device): class HessianBatchOperation(ModelBasedBatchOperation): + r""" + Given a model and loss function computes the Hessian vector or matrix product + with respect to the model parameters, i.e. + + \begin{align*} + &\nabla^2_{\theta} L(b;\theta) \cdot v \\\ + &L(b;\theta) = \left( \frac{1}{|b|} \sum_{(x,y) \in b} + \text{loss}(\text{model}(x; \theta), y)\right), + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the second order differentiation to, + i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian + is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + reverse_only: If True only the reverse mode is used in the autograd computation. + """ + def __init__( self, model: torch.nn.Module, loss: LossType, - regularization: float = 0.0, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, reverse_only: bool = True, ): - super().__init__( - model, loss, regularization=regularization, restrict_to=restrict_to - ) + super().__init__(model, restrict_to=restrict_to) self._batch_hvp = create_batch_hvp_function( model, loss, reverse_only=reverse_only ) @@ -123,34 +179,68 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: class GaussNewtonBatchOperation(ModelBasedBatchOperation): + r""" + Given a model and loss function computes the Gauss-Newton vector or matrix product + with respect to the model parameters, i.e. + + \begin{align*} + G(\text{model}, \text{loss}, b, \theta) &\cdot v, \\\ + G(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + Args: + model: The model. + loss: The loss function. 
+ gradient_provider_factory: An optional factory to create an object of type + [TorchGradientProvider][pydvl.influence.torch.base.TorchGradientProvider], + depending on the model, loss and optional parameters to restrict to. + If not provided, the implementation + [TorchAutograd][pydvl.influence.torch.base.TorchAutograd] is used. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + def __init__( self, model: torch.nn.Module, loss: LossType, - regularization: float = 0.0, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): - super().__init__( - model, loss, regularization=regularization, restrict_to=restrict_to - ) + super().__init__(model, restrict_to=restrict_to) self.gradient_provider = gradient_provider_factory( model, loss, self.params_to_restrict_to ) def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: - flat_grads = self.gradient_provider.per_sample_flat_gradient(batch) + flat_grads = self.gradient_provider.flat_grads(batch) result = rank_one_mvp(flat_grads, vec) - - if self.regularization > 0.0: - result += self.regularization * vec - return result def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return self.apply_to_vec(batch, mat) def to(self, device: torch.device): @@ -159,6 +249,52 @@ def to(self, device: torch.device): class InverseHarmonicMeanBatchOperation(ModelBasedBatchOperation): + r""" + Given a model and loss function computes an approximation of the inverse + Gauss-Newton vector or matrix product. Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this operation replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\text{model}, \text{loss}, b, \theta) = + \left(n \sum_{(x, y) \in b} \left( \nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and computes + + $$ \tilde{G}_{\lambda}^{-1}(\text{model}, \text{loss}, b, \theta) + \cdot v.$$ + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. 
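    As a sanity check of the identity used here, the following self-contained sketch
    compares the Sherman–Morrison expression with an explicit solve for a single
    regularized rank-one matrix (it does not use the pyDVL helpers):

    ```python
    # Self-contained numerical check of the Sherman–Morrison identity (illustrative)
    import torch

    d, lam = 4, 0.1
    g, v = torch.randn(d), torch.randn(d)  # a per-sample gradient and a test vector

    explicit = torch.linalg.solve(torch.outer(g, g) + lam * torch.eye(d), v)
    sherman_morrison = (v - g * (g @ v) / (lam + g @ g)) / lam

    assert torch.allclose(explicit, sherman_morrison, atol=1e-5)
    ```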
+ + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + def __init__( self, model: torch.nn.Module, @@ -166,32 +302,56 @@ def __init__( regularization: float, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): if regularization <= 0: raise ValueError("regularization must be positive") - - super().__init__( - model, loss, regularization=regularization, restrict_to=restrict_to - ) self.regularization = regularization + + super().__init__(model, restrict_to=restrict_to) self.gradient_provider = gradient_provider_factory( model, loss, self.params_to_restrict_to ) + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value <= 0: + raise ValueError("regularization must be positive") + self._regularization = value + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: - grads = self.gradient_provider.per_sample_flat_gradient(batch) + grads = self.gradient_provider.flat_grads(batch) return ( inverse_rank_one_update(grads, vec, self.regularization) / self.regularization ) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return self.apply_to_vec(batch, mat) def to(self, device: torch.device): super().to(device) self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to return self + + +BatchOperationType = TypeVar("BatchOperationType", bound=BatchOperation) diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 96d57632e..42be0149f 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -24,9 +24,9 @@ UnsupportedInfluenceModeException, ) from .base import ( + TorchAutoGrad, TorchComposableInfluence, TorchOperatorGradientComposition, - TorchPerSampleAutoGrad, ) from .functional import ( LowRankProductRepresentation, @@ -57,6 +57,7 @@ "ArnoldiInfluence", "EkfacInfluence", "NystroemSketchInfluence", + "InverseHarmonicMeanInfluence", ] logger = logging.getLogger(__name__) @@ -1802,23 +1803,113 @@ def fit(self, data: DataLoader): return self -class InverseHarmonicMeanInfluence(TorchComposableInfluence): +class InverseHarmonicMeanInfluence( + TorchComposableInfluence[InverseHarmonicMeanOperator] +): + r""" + This implementation replaces the inverse Hessian matrix in the influence computation + an approximation of the inverse Gauss-Newton vector product. 
+ + Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\theta) &= + \frac{1}{N}\sum_{i}^N\nabla_{\theta}\ell (x_i,y_i; \theta) + \nabla_{\theta}\ell (x_i, y_i; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this implementation replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\theta) = + \left(N \cdot \sum_{i=1}^N \left( \nabla_{\theta}\ell (x_i,y_i; \theta) + \nabla_{\theta}\ell (x_i,y_i; \theta)^t + + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and uses the matrix + + $$ \tilde{G}_{\lambda}^{-1}(\theta)$$ + + instead of the inverse Hessian. + + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. The results are averaged + over the batches provided by the data loader. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Block-mode: + This implementation is capable of using a block-matrix approximation. The + blocking structure can be specified via the `block_structure` parameter. + The `block_structure` parameter can either be a + [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides + layer-wise or parameter-wise blocking) or a custom block structure defined + by an ordered dictionary with the keys being the block identifiers (arbitrary + strings) and the values being lists of parameter names contained in the block. + + ```python + block_structure = OrderedDict( + ( + ("custom_block1", ["0.weight", "1.bias"]), + ("custom_block2", ["1.weight", "0.bias"]), + ) + ) + ``` + + If you would like to apply a block-specific regularization, you can provide a + dictionary with the block names as keys and the regularization values as values. + In this case, the specification must be complete, i.e. every block must have + a positive regularization value. + + ```python + regularization = { + "custom_block1": 0.1, + "custom_block2": 0.2, + } + ``` + Accordingly, if you choose a layer-wise or parameter-wise structure + (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for + `block_structure`) the keys must be the layer names or parameter names, + respectively. + + You can retrieve the block-wise influence information from the methods + with suffix `_by_block`. By default, `block_structure` is set to + `BlockMode.FULL` and in this case these methods will return a dictionary + with the empty string being the only key. + + + Args: + model: The model. + loss: The loss function. + regularization: The regularization parameter. In case a dictionary is provided, + the keys must match the blocking structure. + block_structure: The blocking structure, either a pre-defined enum or a + custom block structure, see the information regarding block-mode. 
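    For orientation, a minimal end-to-end usage sketch follows; the model, data and
    block names are placeholders, and the calls assume the constructor, `fit` and
    `influences_by_block` introduced in this patch:

    ```python
    # Minimal usage sketch (placeholders for model, data and block names)
    from collections import OrderedDict

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pydvl.influence.torch import InverseHarmonicMeanInfluence

    model = torch.nn.Sequential(
        torch.nn.Linear(5, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
    )
    loss = torch.nn.functional.mse_loss
    x, y = torch.randn(32, 5), torch.randn(32, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=8)

    block_structure = OrderedDict(
        (("first_layer", ["0.weight", "0.bias"]), ("last_layer", ["2.weight", "2.bias"]))
    )
    regularization = {"first_layer": 0.1, "last_layer": 0.2}

    if_model = InverseHarmonicMeanInfluence(
        model, loss, regularization, block_structure=block_structure
    ).fit(train_loader)
    values_by_block = if_model.influences_by_block(x[:4], y[:4], x, y)
    ```

    If factors are reused across several right-hand batches, `influence_factors_by_block`
    and `influences_from_factors_by_block` allow computing and caching them once.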
+ """ + def __init__( self, model: torch.nn.Module, loss: LossType, regularization: Union[float, Dict[str, Optional[float]]], - block_structure: Union[ - BlockMode, OrderedDict[str, OrderedDict[str, torch.Tensor]] - ] = BlockMode.FULL, + block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, ): super().__init__(model, block_structure, regularization=regularization) - self.gradient_provider_factory = TorchPerSampleAutoGrad + self.gradient_provider_factory = TorchAutoGrad self.loss = loss @property def n_parameters(self): - return super().n_parameters() + return sum(block.op.input_size for _, block in self.block_mapper.items()) @property def is_thread_safe(self) -> bool: @@ -1858,6 +1949,16 @@ def _create_block( def with_regularization( self, regularization: Union[float, Dict[str, Optional[float]]] ) -> TorchComposableInfluence: + """ + Update the regularization parameter. + Args: + regularization: Either a positive float or a dictionary with the + block names as keys and the regularization values as values. + + Returns: + The modified instance + + """ self._regularization_dict = self._build_regularization_dict(regularization) for k, reg in self._regularization_dict.items(): self.block_mapper.composable_block_dict[k].op.regularization = reg diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 5b58b9f70..81d6b0442 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Generator, Optional, Type, Union +from typing import Callable, Dict, Generator, Generic, Optional, Type, Union import torch from torch import nn as nn @@ -7,13 +7,13 @@ from ..array import LazyChunkSequence, SequenceAggregator from .base import ( GradientProviderFactoryType, + TorchAutoGrad, TorchBatch, + TorchGradientProvider, TorchOperator, - TorchPerSampleAutoGrad, - TorchPerSampleGradientProvider, ) from .batch_operation import ( - BatchOperation, + BatchOperationType, GaussNewtonBatchOperation, HessianBatchOperation, InverseHarmonicMeanBatchOperation, @@ -21,17 +21,30 @@ from .util import TorchChunkAverageAggregator, TorchPointAverageAggregator -class AggregateBatchOperator(TorchOperator): +class AggregateBatchOperator(TorchOperator, Generic[BatchOperationType]): + """ + Class for aggregating batch operations over a dataset using a provided data loader + and aggregator. + + This class facilitates the application of a batch operation across multiple batches + of data, aggregating the results using a specified sequence aggregator. + + Attributes: + batch_operation: The batch operation to apply. + dataloader: The data loader providing batches of data. + aggregator: The sequence aggregator to aggregate the results of the batch + operations. 
+ """ + def __init__( self, - batch_operation: BatchOperation, + batch_operation: BatchOperationType, dataloader: DataLoader, aggregator: SequenceAggregator[torch.Tensor], ): self.batch_operation = batch_operation self.dataloader = dataloader self.aggregator = aggregator - super().__init__(self.batch_operation.regularization) @property def device(self): @@ -45,23 +58,26 @@ def to(self, device: torch.device): self.batch_operation = self.batch_operation.to(device) return self - @property - def regularization(self): - return self._regularization - - @regularization.setter - def regularization(self, value: float): - self._regularization = value - self.batch_operation.regularization = value - @property def input_size(self): - return self.batch_operation.n_parameters + return self.batch_operation.input_size def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: return self._apply(vec, self.batch_operation.apply_to_vec) def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a matrix. + Args: + mat: A matrix to apply the operator to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ return self._apply(mat, self.batch_operation.apply_to_mat) def _apply( @@ -83,7 +99,32 @@ def tensor_gen_factory() -> Generator[torch.Tensor, None, None]: return self.aggregator(lazy_tensor_sequence) -class GaussNewtonOperator(AggregateBatchOperator): +class GaussNewtonOperator(AggregateBatchOperator[GaussNewtonBatchOperation]): + r""" + Given a model and loss function computes the Gauss-Newton vector or matrix product + with respect to the model parameters on a batch, i.e. + + \begin{align*} + G(\text{model}, \text{loss}, b, \theta) &\cdot v, \\\ + G(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix, + and average the results over the batches provided by the data loader. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + def __init__( self, model: nn.Module, @@ -91,8 +132,8 @@ def __init__( dataloader: DataLoader, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, nn.Parameter]] = None, ): batch_op = GaussNewtonBatchOperation( @@ -105,7 +146,31 @@ def __init__( super().__init__(batch_op, dataloader, aggregator) -class HessianOperator(AggregateBatchOperator): +class HessianOperator(AggregateBatchOperator[HessianBatchOperation]): + r""" + Given a model and loss function computes the Hessian vector or matrix product + with respect to the model parameters for a given batch, i.e. 
+ + \begin{align*} + &\nabla^2_{\theta} L(b;\theta) \cdot v \\\ + &L(b;\theta) = \left( \frac{1}{|b|} \sum_{(x,y) \in b} + \text{loss}(\text{model}(x; \theta), y)\right), + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix, + and average the results over the batches provided by the data loader. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the second order differentiation to, + i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian + is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + reverse_only: If True only the reverse mode is used in the autograd computation. + """ + def __init__( self, model: nn.Module, @@ -121,7 +186,61 @@ def __init__( super().__init__(batch_op, dataloader, aggregator) -class InverseHarmonicMeanOperator(AggregateBatchOperator): +class InverseHarmonicMeanOperator( + AggregateBatchOperator[InverseHarmonicMeanBatchOperation] +): + r""" + Given a model and loss function computes an approximation of the inverse + Gauss-Newton vector or matrix product per batch and averages the results. + + Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this operator replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\text{model}, \text{loss}, b, \theta) = + \left(n \sum_{(x, y) \in b} \left( \nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and computes + + $$ \tilde{G}_{\lambda}^{-1}(\text{model}, \text{loss}, b, \theta) + \cdot v.$$ + + for any given batch $b$, + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. The results are averaged + over the batches provided by the data loader. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. 
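    For illustration, a standalone sketch of the reverse-mode Hessian-vector product
    that such an operator averages over the batches of the data loader (the model and
    names are illustrative, not the pyDVL implementation):

    ```python
    # Illustrative sketch: batch Hessian-vector product via torch.func
    import torch
    from torch.func import functional_call, grad, vjp

    model = torch.nn.Linear(3, 1)
    loss = torch.nn.functional.mse_loss
    params = {k: p.detach() for k, p in model.named_parameters()}
    x, y = torch.randn(8, 3), torch.randn(8, 1)

    def batch_loss(p):
        return loss(functional_call(model, p, (x,)), y)

    _, vjp_fn = vjp(grad(batch_loss), params)
    v = {k: torch.randn_like(p) for k, p in params.items()}
    (hvp,) = vjp_fn(v)  # dict with the same structure as params
    ```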
+ """ + def __init__( self, model: nn.Module, @@ -130,10 +249,15 @@ def __init__( regularization: float, gradient_provider_factory: Union[ GradientProviderFactoryType, - Type[TorchPerSampleGradientProvider], - ] = TorchPerSampleAutoGrad, + Type[TorchGradientProvider], + ] = TorchAutoGrad, restrict_to: Optional[Dict[str, nn.Parameter]] = None, ): + if regularization <= 0: + raise ValueError("regularization must be positive") + + self._regularization = regularization + batch_op = InverseHarmonicMeanBatchOperation( model, loss, @@ -143,3 +267,14 @@ def __init__( ) aggregator = TorchPointAverageAggregator(weighted=False) super().__init__(batch_op, dataloader, aggregator) + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value <= 0: + raise ValueError("regularization must be positive") + self._regularization = value + self.batch_operation.regularization = value diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index d7d4b0226..bae7e4372 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -2,6 +2,7 @@ import logging import math +import warnings from collections import OrderedDict from dataclasses import dataclass from enum import Enum @@ -59,6 +60,8 @@ "LossType", "ModelParameterDictBuilder", "BlockMode", + "ModelInfoMixin", + "safe_torch_linalg_eigh", ] @@ -665,6 +668,29 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: def inverse_rank_one_update( x: torch.Tensor, v: torch.Tensor, regularization: float ) -> torch.Tensor: + r""" + Performs an inverse-rank one update on x and v. More precisely, it computes + + $$ \sum_{i=1}^n \left(x[i]x[i]^t+\lambda \operatorname{I}\right)^{-1}v $$ + + where $\operatorname{I}$ is the identity matrix and $\lambda$ is positive + regularization parameter. The inverse matrices are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + Args: + x: Input matrix used for the rank one expressions. First dimension is + assumed to be the batch dimension. + v: Matrix to multiply with. First dimension is + assumed to be the batch dimension. + regularization: Regularization parameter to make the rank-one expressions + invertible, must be positive. + + Returns: + Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape $(M, D)$. + """ nominator = torch.einsum("ij,kj->ki", x, v) denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) return (v - (nominator / denominator) @ x) / regularization @@ -674,6 +700,15 @@ def inverse_rank_one_update( class BlockMode(Enum): + """ + Enumeration for different modes of grouping model parameters. + + Attributes: + LAYER_WISE: Groups parameters by layers of the model. + PARAMETER_WISE: Groups parameters individually. + FULL: Groups all parameters together. + """ + LAYER_WISE: str = "layer_wise" PARAMETER_WISE: str = "parameter_wise" FULL: str = "full" @@ -681,6 +716,15 @@ class BlockMode(Enum): @dataclass class ModelParameterDictBuilder: + """ + A builder class for creating ordered dictionaries of model parameters based on + specified block modes or custom blocking structures. + + Attributes: + model: The neural network model. + detach: Whether to detach the parameters from the computation graph. 
+ """ + model: torch.nn.Module detach: bool = True @@ -689,9 +733,58 @@ def _optional_detach(self, p: torch.nn.Parameter): return p.detach() return p - def build( + def build(self, block_structure: OrderedDict[str, List[str]]): + """ + Builds an ordered dictionary of model parameters based on the specified block + structure represented by an ordered dictionary, where the keys are block + identifiers and the values are lists of model parameter names contained in + this block. + + Args: + block_structure: The block structure specifying how to group the parameters. + + Returns: + An ordered dictionary of ordered dictionaries, where the outer dictionary's + keys are block identifiers and the inner dictionaries map parameter names + to parameters. + """ + parameter_dict = OrderedDict() + + for block_name, parameter_names in block_structure.items(): + inner_ordered_dict = OrderedDict() + for parameter_name in parameter_names: + parameter = self.model.state_dict()[parameter_name] + if parameter.requires_grad: + inner_ordered_dict[parameter_name] = self._optional_detach( + parameter + ) + else: + warnings.warn( + f"The parameter {parameter_name} from the block " + f"{block_name} is mark as not trainable in the model " + f"and will be excluded from the computation." + ) + parameter_dict[block_name] = inner_ordered_dict + + return parameter_dict + + def build_from_block_mode( self, block_mode: BlockMode ) -> OrderedDict[str, OrderedDict[str, torch.nn.Parameter]]: + """ + Builds an ordered dictionary of model parameters based on the specified block + mode or custom blocking structure represented by an ordered dictionary, where + the keys are block identifiers and the values are lists of model parameter names + contained in this block. + + Args: + block_mode: The block mode specifying how to group the parameters. + + Returns: + An ordered dictionary of ordered dictionaries, where the outer dictionary's + keys are block identifiers and the inner dictionaries map parameter names + to parameters. + """ parameter_dict = OrderedDict() if block_mode is BlockMode.FULL: diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index 0a4d5b0fd..5cb405fda 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -120,7 +120,7 @@ class GradientProvider(Generic[BatchType, TensorType], ABC): """ @abstractmethod - def per_sample_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: + def grads(self, batch: BatchType) -> Dict[str, TensorType]: r""" Computes and returns a dictionary mapping gradient names to their respective per-sample gradients. Given the example in the class docstring, this means @@ -141,7 +141,7 @@ def per_sample_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: """ @abstractmethod - def per_sample_mixed_gradient_dict(self, batch: BatchType) -> Dict[str, TensorType]: + def mixed_grads(self, batch: BatchType) -> Dict[str, TensorType]: r""" Computes and returns a dictionary mapping gradient names to their respective per-sample mixed gradients. 
In this context, mixed gradients refer to computing @@ -166,7 +166,7 @@ def per_sample_mixed_gradient_dict(self, batch: BatchType) -> Dict[str, TensorTy """ @abstractmethod - def matrix_jacobian_product( + def jacobian_prod( self, batch: BatchType, g: TensorType, @@ -193,7 +193,7 @@ def matrix_jacobian_product( """ @abstractmethod - def per_sample_flat_gradient(self, batch: BatchType) -> TensorType: + def flat_grads(self, batch: BatchType) -> TensorType: r""" Computes and returns the flat per-sample gradients for the provided batch. Given the example in the class docstring, this means @@ -215,7 +215,7 @@ def per_sample_flat_gradient(self, batch: BatchType) -> TensorType: """ @abstractmethod - def per_sample_flat_mixed_gradient(self, batch: BatchType) -> TensorType: + def flat_mixed_grads(self, batch: BatchType) -> TensorType: r""" Computes and returns the flat per-sample mixed gradients for the provided batch. Given the example in the class docstring, this means @@ -248,9 +248,7 @@ class BilinearForm(Generic[TensorType, BatchType, GradientProviderType], ABC): """ @abstractmethod - def inner_product( - self, left: TensorType, right: Optional[TensorType] - ) -> TensorType: + def inner_prod(self, left: TensorType, right: Optional[TensorType]) -> TensorType: r""" Computes the inner product of two vectors, i.e. @@ -273,7 +271,7 @@ def inner_product( A tensor representing the inner product. """ - def gradient_inner_product( + def grads_inner_prod( self, left: BatchType, right: Optional[BatchType], @@ -298,14 +296,14 @@ def gradient_inner_product( Returns: A tensor representing the inner products of the per-sample gradients """ - left_grad = gradient_provider.per_sample_flat_gradient(left) + left_grad = gradient_provider.flat_grads(left) if right is None: right_grad = left_grad else: - right_grad = gradient_provider.per_sample_flat_gradient(right) - return self.inner_product(left_grad, right_grad) + right_grad = gradient_provider.flat_grads(right) + return self.inner_prod(left_grad, right_grad) - def mixed_gradient_inner_product( + def mixed_grads_inner_prod( self, left: BatchType, right: BatchType, gradient_provider: GradientProviderType ) -> TensorType: r""" @@ -327,9 +325,9 @@ def mixed_gradient_inner_product( Returns: A tensor representing the inner products of the mixed per-sample gradients """ - left_grad = gradient_provider.per_sample_flat_gradient(left) - right_mixed_grad = gradient_provider.per_sample_flat_mixed_gradient(right) - return self.inner_product(left_grad, right_mixed_grad) + left_grad = gradient_provider.flat_grads(left) + right_mixed_grad = gradient_provider.flat_mixed_grads(right) + return self.inner_prod(left_grad, right_mixed_grad) BilinearFormType = TypeVar("BilinearFormType", bound=BilinearForm) @@ -414,7 +412,7 @@ def __init__(self, op: OperatorType, gp: GradientProviderType): self.gp = gp self.op = op - def gradient_interaction( + def interactions( self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode ): r""" @@ -445,14 +443,10 @@ def gradient_interaction( """ bilinear_form = self.op.as_bilinear_form() if mode is InfluenceMode.Up: - return bilinear_form.gradient_inner_product( - left_batch, right_batch, self.gp - ) - return bilinear_form.mixed_gradient_inner_product( - left_batch, right_batch, self.gp - ) + return bilinear_form.grads_inner_prod(left_batch, right_batch, self.gp) + return bilinear_form.mixed_grads_inner_prod(left_batch, right_batch, self.gp) - def transformed_gradients(self, batch: BatchType): + def transformed_grads(self, batch: 
BatchType): r""" Computes the gradients of a data batch, transformed by the operator application , i.e. the expressions @@ -467,10 +461,10 @@ def transformed_gradients(self, batch: BatchType): A tensor representing the application of the operator to the gradients. """ - grads = self.gp.per_sample_flat_gradient(batch) + grads = self.gp.flat_grads(batch) return self.op.apply_to_mat(grads) - def interaction_from_transformed_gradients( + def interactions_from_transformed_grads( self, left_factors: TensorType, right_batch: BatchType, mode: InfluenceMode ): r""" @@ -499,16 +493,18 @@ def interaction_from_transformed_gradients( batch gradients. """ if mode is InfluenceMode.Up: - right_grads = self.gp.per_sample_flat_gradient(right_batch) + right_grads = self.gp.flat_grads(right_batch) else: - right_grads = self.gp.per_sample_flat_mixed_gradient(right_batch) - return self.op.as_bilinear_form().inner_product(left_factors, right_grads) + right_grads = self.gp.flat_mixed_grads(right_batch) + return self.op.as_bilinear_form().inner_prod(left_factors, right_grads) -ComposableBlockType = TypeVar("ComposableBlockType", bound=OperatorGradientComposition) +OperatorGradientCompositionType = TypeVar( + "OperatorGradientCompositionType", bound=OperatorGradientComposition +) -class BlockMapper(Generic[TensorType, BatchType, ComposableBlockType], ABC): +class BlockMapper(Generic[TensorType, BatchType, OperatorGradientCompositionType], ABC): """ Abstract base class for mapping operations across multiple compositional blocks. @@ -521,9 +517,17 @@ class BlockMapper(Generic[TensorType, BatchType, ComposableBlockType], ABC): interactions. """ - def __init__(self, composable_block_dict: OrderedDict[str, ComposableBlockType]): + def __init__( + self, composable_block_dict: OrderedDict[str, OperatorGradientCompositionType] + ): self.composable_block_dict = composable_block_dict + def __getitem__(self, item: str): + return self.composable_block_dict[item] + + def items(self): + return self.composable_block_dict.items() + def _to_ordered_dict( self, tensor_generator: Generator[TensorType, None, None] ) -> OrderedDict[str, TensorType]: @@ -539,7 +543,7 @@ def _split_to_blocks( """Must be implemented in a way to preserve the ordering defined by the `composable_block_dict` attribute""" - def block_transformed_gradients( + def transformed_grads( self, batch: BatchType, ) -> OrderedDict[str, TensorType]: @@ -553,10 +557,10 @@ def block_transformed_gradients( Returns: An ordered dictionary of transformed gradients by block. """ - tensor_gen = self.generate_transformed_gradients(batch) + tensor_gen = self.generate_transformed_grads(batch) return self._to_ordered_dict(tensor_gen) - def block_interactions( + def interactions( self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode ) -> OrderedDict[str, TensorType]: """ @@ -571,10 +575,10 @@ def block_interactions( Returns: An ordered dictionary of gradient interactions by block. """ - tensor_gen = self.generate_gradient_interactions(left_batch, right_batch, mode) + tensor_gen = self.generate_interactions(left_batch, right_batch, mode) return self._to_ordered_dict(tensor_gen) - def block_interactions_from_transformed_gradients( + def interactions_from_transformed_grads( self, left_factors: OrderedDict[str, TensorType], right_batch: BatchType, @@ -594,12 +598,12 @@ def block_interactions_from_transformed_gradients( Returns: An ordered dictionary of interactions from transformed gradients by block. 
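        As a dense analogue (for intuition only), the per-block computation boils
        down to a matrix of scalar products between the pre-computed factors and the
        per-sample gradients of that block:

        ```python
        # Illustrative dense analogue of the per-block scalar products
        import torch

        n_test, n, d_block = 3, 5, 7
        left_factors = torch.randn(n_test, d_block)  # e.g. transformed test gradients
        right_grads = torch.randn(n, d_block)        # gradients of the right batch
        block_influences = left_factors @ right_grads.T  # shape (n_test, n)
        ```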
""" - tensor_gen = self.generate_interactions_from_transformed_gradients( + tensor_gen = self.generate_interactions_from_transformed_grads( left_factors, right_batch, mode ) return self._to_ordered_dict(tensor_gen) - def generate_transformed_gradients( + def generate_transformed_grads( self, batch: BatchType ) -> Generator[TensorType, None, None]: """ @@ -613,9 +617,9 @@ def generate_transformed_gradients( Transformed gradients for each block. """ for comp_block in self.composable_block_dict.values(): - yield comp_block.transformed_gradients(batch) + yield comp_block.transformed_grads(batch) - def generate_gradient_interactions( + def generate_interactions( self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode ) -> Generator[TensorType, None, None]: """ @@ -631,9 +635,9 @@ def generate_gradient_interactions( TensorType: Gradient interactions for each block. """ for comp_block in self.composable_block_dict.values(): - yield comp_block.gradient_interaction(left_batch, right_batch, mode) + yield comp_block.interactions(left_batch, right_batch, mode) - def generate_interactions_from_transformed_gradients( + def generate_interactions_from_transformed_grads( self, left_factors: Union[TensorType, OrderedDict[str, TensorType]], right_batch: BatchType, @@ -657,7 +661,7 @@ def generate_interactions_from_transformed_gradients( else: left_factors_dict = cast(OrderedDict[str, TensorType], left_factors) for k, comp_block in self.composable_block_dict.items(): - yield comp_block.interaction_from_transformed_gradients( + yield comp_block.interactions_from_transformed_grads( left_factors_dict[k], right_batch, mode ) diff --git a/tests/influence/test_influence_calculator.py b/tests/influence/test_influence_calculator.py index 854321f8f..9a82e89cf 100644 --- a/tests/influence/test_influence_calculator.py +++ b/tests/influence/test_influence_calculator.py @@ -28,10 +28,6 @@ EkfacInfluence, ) from pydvl.influence.torch.influence_function_model import NystroemSketchInfluence -from pydvl.influence.torch.pre_conditioner import ( - JacobiPreConditioner, - NystroemPreConditioner, -) from pydvl.influence.torch.util import ( NestedTorchCatAggregator, TorchCatAggregator, diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index b6141a6c8..304de4518 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -2,7 +2,6 @@ import pytest import torch -from influence.torch.test_util import model_data, test_parameters from pydvl.influence.torch.base import TorchBatch from pydvl.influence.torch.batch_operation import ( @@ -10,6 +9,8 @@ HessianBatchOperation, ) +from .test_util import model_data, test_parameters + @pytest.mark.torch @pytest.mark.parametrize( diff --git a/tests/influence/torch/test_gradient_provider.py b/tests/influence/torch/test_gradient_provider.py index c15814e37..8fab7fadf 100644 --- a/tests/influence/torch/test_gradient_provider.py +++ b/tests/influence/torch/test_gradient_provider.py @@ -1,11 +1,12 @@ import numpy as np import pytest import torch -from influence.conftest import linear_mixed_second_derivative_analytical, linear_model -from influence.torch.conftest import DATA_OUTPUT_NOISE, linear_mvp_model from pydvl.influence.torch.base import TorchAutoGrad, TorchBatch +from ..conftest import linear_mixed_second_derivative_analytical, linear_model +from .conftest import DATA_OUTPUT_NOISE, linear_mvp_model + class TestTorchPerSampleAutograd: @pytest.mark.torch @@ -22,8 +23,8 @@ 
def test_per_sample_gradient(self, in_features, out_features, batch_size): params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad} gp = TorchAutoGrad(model, loss, restrict_to=params) - gradients = gp.per_sample_gradient_dict(TorchBatch(x, y)) - flat_gradients = gp.per_sample_flat_gradient(TorchBatch(x, y)) + gradients = gp.grads(TorchBatch(x, y)) + flat_gradients = gp.flat_grads(TorchBatch(x, y)) # Compute analytical gradients y_pred = model(x) @@ -69,7 +70,7 @@ def test_mixed_derivatives(self, in_features, out_features, train_set_size): torch_train_x = torch.as_tensor(train_x) torch_train_y = torch.as_tensor(train_y) gp = TorchAutoGrad(model, loss, restrict_to=params) - flat_functorch_mixed_derivatives = gp.per_sample_flat_mixed_gradient( + flat_functorch_mixed_derivatives = gp.flat_mixed_grads( TorchBatch(torch_train_x, torch_train_y) ) assert torch.allclose( @@ -95,7 +96,7 @@ def test_matrix_jacobian_product( gp = TorchAutoGrad(model, torch.nn.functional.mse_loss, restrict_to=params) G = torch.randn((10, out_features * (in_features + 1))) - mjp = gp.matrix_jacobian_product(TorchBatch(x, y), G) + mjp = gp.jacobian_prod(TorchBatch(x, y), G) dL_dw = torch.vmap( lambda r, s, t: 2 diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index 82510e640..a3c1caa29 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -387,7 +387,7 @@ def test_build(self, block_mode, model): model=model, detach=True, ) - param_dict = builder.build(block_mode) + param_dict = builder.build_from_block_mode(block_mode) if block_mode is BlockMode.FULL: assert "" in param_dict From f00750785b587cf39277e9ee09d2aa0093299299 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 27 May 2024 15:12:16 +0200 Subject: [PATCH 24/43] Add inverse rank-1 update approximation for dictionary input --- src/pydvl/influence/torch/util.py | 26 ++++++++++++++++++++++++++ tests/influence/torch/test_util.py | 26 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index bae7e4372..a3553af5c 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -696,6 +696,32 @@ def inverse_rank_one_update( return (v - (nominator / denominator) @ x) / regularization +def inverse_rank_one_update_dict( + x: Dict[str, torch.Tensor], v: Dict[str, torch.Tensor], regularization: float +) -> Dict[str, torch.Tensor]: + + denominator = regularization + nominator = None + batch_size = None + for x_, v_ in zip(x.values(), v.values()): + if batch_size is None: + batch_size = x_.shape[0] + if nominator is None: + nominator = torch.einsum("i..., k...->ki", x_, v_) + else: + nominator += torch.einsum("i..., k...->ki", x_, v_) + denominator += torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) + denominator = batch_size * denominator + + result = {} + for key in x.keys(): + result[key] = ( + v[key] - torch.einsum("ji, i... 
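The dictionary variant being added here applies the same Sherman-Morrison trick as the flat inverse_rank_one_update. For a single vector the identity it relies on can be checked directly; a minimal standalone sketch, not part of the patch:

import torch

torch.manual_seed(0)
d, lam = 5, 0.3
x, w = torch.randn(d), torch.randn(d)

# explicit solve against the damped rank-one matrix
explicit = torch.linalg.solve(torch.outer(x, x) + lam * torch.eye(d), w)

# Sherman-Morrison closed form, whose vectorized version the update uses
closed_form = (w - x * (x @ w) / (lam + x @ x)) / lam

assert torch.allclose(explicit, closed_form, atol=1e-5)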
-> j...", nominator / denominator, x[key]) + ) / regularization + + return result + + LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index a3c1caa29..fc80ad805 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -24,6 +24,7 @@ align_structure, flatten_dimensions, inverse_rank_one_update, + inverse_rank_one_update_dict, rank_one_mvp, safe_torch_linalg_eigh, torch_dataset_to_dask_array, @@ -369,6 +370,31 @@ def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): assert torch.allclose(result, inverse_result, atol=1e-5) +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_1", + [{"1": (4, 2, 3), "2": (5, 7), "3": ()}, {"1": (3, 6, 8, 9), "2": (1, 2)}, {"1": (1,)}], +) +@pytest.mark.parametrize( + "x_dim_0, v_dim_0", + [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], +) +@pytest.mark.parametrize("reg", [0.5, 100, 1.0, 10]) +def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): + X_dict = {k: torch.randn(x_dim_0, *d) for k, d in x_dim_1.items()} + V_dict = {k: torch.randn(v_dim_0, *d) for k, d in x_dim_1.items()} + + X = flatten_dimensions(X_dict.values(), shape=(x_dim_0, -1)) + V = flatten_dimensions(V_dict.values(), shape=(v_dim_0, -1)) + result = inverse_rank_one_update(X, V, reg) + + inverse_result = flatten_dimensions( + inverse_rank_one_update_dict(X_dict, V_dict, reg).values(), shape=(v_dim_0, -1) + ) + + assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) + + class TestModelParameterDictBuilder: class SimpleModel(torch.nn.Module): def __init__(self): From ed5f14d4f276184175b6c4fc44967a448e97b013 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 11:47:00 +0200 Subject: [PATCH 25/43] Refactor new structure: * renaming classes * add cocept 'TensorDictOperator', which can act on tensor dictioniaries, to avoid intermeditate flatten and concat to reduce memory consumption --- src/pydvl/influence/torch/base.py | 204 +++++++++++- src/pydvl/influence/torch/batch_operation.py | 311 ++++++++++++++---- src/pydvl/influence/torch/operator.py | 72 ++-- src/pydvl/influence/torch/util.py | 127 ++++--- tests/influence/test_influence_calculator.py | 2 - tests/influence/torch/test_batch_operation.py | 157 ++++++++- tests/influence/torch/test_util.py | 57 +++- 7 files changed, 747 insertions(+), 183 deletions(-) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index 73a0f2933..9bc1dc225 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast import torch from torch.func import functional_call @@ -13,6 +13,7 @@ from ..types import ( Batch, BilinearForm, + BilinearFormType, BlockMapper, GradientProvider, Operator, @@ -47,6 +48,9 @@ class TorchBatch(Batch): x: torch.Tensor y: torch.Tensor + def __iter__(self): + return iter((self.x, self.y)) + def __post_init__(self): if self.x.shape[0] != self.y.shape[0]: raise ValueError( @@ -310,7 +314,7 @@ class OperatorBilinearForm( def __init__( self, - operator: "TorchOperator", + operator: "TensorOperator", ): self.operator = operator @@ -345,7 +349,128 @@ def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tenso 
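OperatorBilinearForm turns an operator into the weighted inner product of per-sample vectors, evaluated for all row pairs. In the plain two-dimensional case, with an explicit symmetric matrix standing in for the operator (purely illustrative, not the patch's API), the einsum used in _inner_product reduces to a weighted Gram matrix:

import torch

torch.manual_seed(0)
d = 4
A = torch.randn(d, d)
A = A @ A.T + torch.eye(d)      # SPD stand-in for the operator
left, right = torch.randn(3, d), torch.randn(5, d)

# inner_prod[i, j] = < Op(left_i), right_j >
weighted = torch.einsum("ia,ja->ij", left @ A, right)
assert torch.allclose(weighted, left @ A @ right.T, atol=1e-5)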
return torch.einsum("ia,j...a->ij...", left_result, right) -class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): +class DictBilinearForm(OperatorBilinearForm): + r""" + Base class for bilinear forms based on an instance of + [TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. This means it + computes weighted inner products of the form: + + $$ \langle \operatorname{Op}(x), y \rangle $$ + + """ + + def __init__( + self, + operator: "TensorDictOperator", + ): + super().__init__(operator) + + def grads_inner_prod( + self, + left: TorchBatch, + right: Optional[TorchBatch], + gradient_provider: TorchGradientProvider, + ) -> torch.Tensor: + r""" + Computes the gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the + `gradient_provider` and the expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation, + optional; if not provided, the inner product will use the gradient + computed for `left` for both arguments. + gradient_provider: The gradient provider to compute the gradients. + + Returns: + A tensor representing the inner products of the per-sample gradients + """ + operator = cast(TensorDictOperator, self.operator) + left_grads = gradient_provider.grads(left) + if right is None: + right_grads = left_grads + else: + right_grads = gradient_provider.grads(right) + + left_batch_size, right_batch_size = next( + ( + (l.shape[0], r.shape[0]) + for r, l in zip(left_grads.values(), right_grads.values()) + ) + ) + + if left_batch_size <= right_batch_size: + left_grads = operator.apply_to_mat_dict(left_grads) + tensor_pairs = zip(left_grads.values(), right_grads.values()) + else: + right_grads = operator.apply_to_mat_dict(right_grads) + tensor_pairs = zip(left_grads.values(), right_grads.values()) + + tensors_to_reduce = ( + self._aggregate_grads(left, right) for left, right in tensor_pairs + ) + + return cast(torch.Tensor, sum(tensors_to_reduce)) + + def mixed_grads_inner_prod( + self, + left: TorchBatch, + right: TorchBatch, + gradient_provider: TorchGradientProvider, + ) -> torch.Tensor: + r""" + Computes the mixed gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) + \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot)$ and + $\nabla_{\omega}\nabla_{x}\ell(\omega, \cdot)$ are represented by the + `gradient_provider`. The expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation + gradient_provider: The gradient provider to compute the gradients. 
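DictBilinearForm reduces per-parameter gradient blocks with _aggregate_grads and sums the results, which is equivalent to taking inner products of the flattened per-sample gradients. A standalone sketch of that equivalence with toy shapes, no model involved:

import torch

torch.manual_seed(0)
left = {"weight": torch.randn(3, 2, 4), "bias": torch.randn(3, 2)}
right = {"weight": torch.randn(5, 2, 4), "bias": torch.randn(5, 2)}

# blockwise reduction, as in _aggregate_grads, summed over parameter blocks
blockwise = sum(torch.einsum("i..., j... -> ij", left[k], right[k]) for k in left)

# the same inner products on flattened per-sample gradients
flat_left = torch.cat([t.reshape(t.shape[0], -1) for t in left.values()], dim=1)
flat_right = torch.cat([t.reshape(t.shape[0], -1) for t in right.values()], dim=1)
assert torch.allclose(blockwise, flat_left @ flat_right.T, atol=1e-5)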
+ + Returns: + A tensor representing the inner products of the mixed per-sample gradients + """ + operator = cast(TensorDictOperator, self.operator) + right_grads = gradient_provider.mixed_grads(right) + left_grads = gradient_provider.grads(left) + left_grads = operator.apply_to_mat_dict(left_grads) + left_grads_views = (t.reshape(t.shape[0], -1) for t in left_grads.values()) + right_grads_views = ( + t.reshape(*right.x.shape, -1) for t in right_grads.values() + ) + tensor_pairs = zip(left_grads_views, right_grads_views) + tensors_to_reduce = ( + self._aggregate_mixed_grads(left, right) for left, right in tensor_pairs + ) + return cast(torch.Tensor, sum(tensors_to_reduce)) + + @staticmethod + def _aggregate_mixed_grads(left: torch.Tensor, right: torch.Tensor): + return torch.einsum("ik, j...k -> ij...", left, right) + + @staticmethod + def _aggregate_grads(left: torch.Tensor, right: torch.Tensor): + return torch.einsum("i..., j... -> ij", left, right) + + +OperatorBilinearFormType = TypeVar( + "OperatorBilinearFormType", bound=OperatorBilinearForm +) + + +class TensorOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): """ Abstract base class for operators that can be applied to instances of [torch.Tensor][torch.Tensor]. @@ -369,13 +494,6 @@ def to(self, device: torch.device): def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: pass - def as_bilinear_form(self): - """ - Represent this operator as a - [OperatorBilinearForm][pydvl.influence.torch.base.OperatorBilinearForm]. - """ - return OperatorBilinearForm(self) - def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: """ Applies the operator to a single vector. @@ -403,8 +521,72 @@ def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: """ return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) + def as_bilinear_form(self) -> OperatorBilinearForm: + return OperatorBilinearForm(self) + + +class TensorDictOperator(TensorOperator, ABC): + """ + Abstract base class for operators that can be applied to instances of + [torch.Tensor][torch.Tensor] and compatible dictionaries mapping strings to tensors. + Input dictionaries must conform to the structure defined by the property + `input_dict_structure`. Useful for operators involving autograd functionality + to avoid intermediate flattening and concatenating of gradient inputs. + """ + + def apply_to_mat_dict( + self, mat: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Applies the operator to a dictionary of tensors, compatible to the structure + defined by the property `input_dict_structure`. + + Args: + mat: dictionary of tensors, whose keys and shapes match the property + `input_dict_structure`. + + Returns: + A dictionary of tensors after applying the operator + """ + + if not self._validate_mat_dict(mat): + raise ValueError( + f"Incompatible input structure, expected (excluding batch" + f"dimension): \n {self.input_dict_structure}" + ) + + return self._apply_to_mat_dict(self._dict_to_device(mat)) + + def _dict_to_device(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {k: v.to(self.device) for k, v in mat.items()} + + @property + @abstractmethod + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + """ + Implement this to expose the expected structure of the input tensor dict, i.e. + a dictionary of shapes (excluding the first batch dimension), in order + to validate the input tensor dicts. 
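apply_to_mat_dict only compares shapes beyond the leading batch dimension against input_dict_structure. A minimal standalone sketch of that check; the helper below is illustrative and not part of the patch:

import torch
from typing import Dict, Tuple

def matches_structure(
    mat: Dict[str, torch.Tensor], structure: Dict[str, Tuple[int, ...]]
) -> bool:
    # compare every entry's shape, ignoring the first (batch) dimension
    return all(
        k in structure and t.shape[1:] == structure[k] for k, t in mat.items()
    )

structure = {"weight": (2, 3), "bias": (2,)}
batch_of_mats = {"weight": torch.zeros(7, 2, 3), "bias": torch.zeros(7, 2)}
assert matches_structure(batch_of_mats, structure)
assert not matches_structure({"weight": torch.zeros(7, 3, 2)}, structure)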
+ """ + + @abstractmethod + def _apply_to_mat_dict( + self, mat: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + pass + + def _validate_mat_dict(self, mat: Dict[str, torch.Tensor]) -> bool: + for keys, val in mat.items(): + if val.shape[1:] != self.input_dict_structure[keys]: + return False + else: + return True + + def as_bilinear_form(self) -> DictBilinearForm: + return DictBilinearForm(self) + -TorchOperatorType = TypeVar("TorchOperatorType", bound=TorchOperator) +TorchOperatorType = TypeVar("TorchOperatorType", bound=TensorOperator) class TorchOperatorGradientComposition( diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 4c7075924..30235899c 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -13,8 +13,21 @@ which is useful in the case that keeping $B$ in memory is not feasible. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Type, TypeVar, Union +from typing import ( + Callable, + Dict, + Generator, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import torch @@ -24,37 +37,99 @@ TorchBatch, TorchGradientProvider, ) -from .functional import create_batch_hvp_function -from .util import LossType, inverse_rank_one_update, rank_one_mvp +from .functional import create_batch_hvp_function, create_batch_loss_function, hvp +from .util import ( + LossType, + generate_inverse_rank_one_updates, + generate_rank_one_mvp, + inverse_rank_one_update, + rank_one_mvp, +) -class BatchOperation(ABC): +class _ModelBasedBatchOperation(ABC): r""" Abstract base class to implement operations of the form - $$ m(b) \cdot v $$ + $$ m(\text{model}, b) \cdot v $$ + + where model is a [torch.nn.Module][torch.nn.Module]. - where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector - or matrix. """ + def __init__( + self, + model: torch.nn.Module, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + if restrict_to is None: + restrict_to = { + k: p.detach() for k, p in model.named_parameters() if p.requires_grad + } + self.params_to_restrict_to = restrict_to + self.model = model + @property - @abstractmethod - def input_size(self): - pass + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + return {k: p.shape for k, p in self.params_to_restrict_to.items()} @property - @abstractmethod def device(self): - pass + return next(self.model.parameters()).device @property - @abstractmethod def dtype(self): - pass + return next(self.model.parameters()).dtype + + @property + def input_size(self): + return sum(p.numel() for p in self.params_to_restrict_to.values()) - @abstractmethod def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + def apply_to_tensor_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + if mat_dict.keys() != self.params_to_restrict_to.keys(): + raise ValueError( + "The keys of the matrix dictionary must match the keys of the " + "parameters to restrict to." 
+ ) + + return self._apply_to_tensor_dict( + batch, {k: v.to(self.device) for k, v in mat_dict.items()} + ) + + def _has_batch_dim(self, tensor_dict: Dict[str, torch.Tensor]): + batch_dim_flags = [ + tensor_dict[key].shape == val.shape + for key, val in self.params_to_restrict_to.items() + ] + if len(set(batch_dim_flags)) == 2: + raise ValueError("Existence of batch dim must be consistent") + return not all(batch_dim_flags) + + def _add_batch_dim(self, vec_dict: Dict[str, torch.Tensor]): + result = {} + for key, value in self.params_to_restrict_to.items(): + if value.shape == vec_dict[key].shape: + result[key] = vec_dict[key].unsqueeze(0) + else: + result[key] = vec_dict[key] + return result + + @abstractmethod + def _apply_to_tensor_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: pass @abstractmethod @@ -95,51 +170,7 @@ def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: )(batch.x, batch.y, mat) -class ModelBasedBatchOperation(BatchOperation, ABC): - r""" - Abstract base class to implement operations of the form - - $$ m(\text{model}, b) \cdot v $$ - - where model is a [torch.nn.Module][torch.nn.Module]. - - """ - - def __init__( - self, - model: torch.nn.Module, - restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, - ): - if restrict_to is None: - restrict_to = { - k: p.detach() for k, p in model.named_parameters() if p.requires_grad - } - self.params_to_restrict_to = restrict_to - self.model = model - - @property - def device(self): - return next(self.model.parameters()).device - - @property - def dtype(self): - return next(self.model.parameters()).dtype - - @property - def input_size(self): - return sum(p.numel() for p in self.params_to_restrict_to.values()) - - def to(self, device: torch.device): - self.model = self.model.to(device) - self.params_to_restrict_to = { - k: p.detach() - for k, p in self.model.named_parameters() - if k in self.params_to_restrict_to - } - return self - - -class HessianBatchOperation(ModelBasedBatchOperation): +class HessianBatchOperation(_ModelBasedBatchOperation): r""" Given a model and loss function computes the Hessian vector or matrix product with respect to the model parameters, i.e. 
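_add_batch_dim lets callers pass either a single parameter-shaped vector or a batch of them per key; entries that match the parameter shape exactly get a leading batch dimension added. A standalone sketch of this normalization, using an illustrative helper rather than the patch's code:

import torch
from typing import Dict

def add_batch_dim(
    vec_dict: Dict[str, torch.Tensor], param_shapes: Dict[str, torch.Size]
) -> Dict[str, torch.Tensor]:
    # a tensor matching the parameter shape exactly has no batch dimension yet
    return {
        k: v.unsqueeze(0) if v.shape == param_shapes[k] else v
        for k, v in vec_dict.items()
    }

shapes = {"weight": torch.Size([2, 3]), "bias": torch.Size([2])}
single = {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}
batched = add_batch_dim(single, shapes)
assert all(t.shape[0] == 1 for t in batched.values())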
@@ -173,12 +204,39 @@ def __init__( self._batch_hvp = create_batch_hvp_function( model, loss, reverse_only=reverse_only ) + self.loss = loss + self.reverse_only = reverse_only def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) + def _apply_to_tensor_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + func = self._create_seq_func(*batch) + + if self._has_batch_dim(mat_dict): + func = torch.func.vmap( + func, in_dims=tuple((0 for _ in self.params_to_restrict_to)) + ) + + result: Dict[str, torch.Tensor] = func(*mat_dict.values()) + return result + + def _create_seq_func(self, x: torch.Tensor, y: torch.Tensor): + def seq_func(*vec: torch.Tensor) -> Dict[str, torch.Tensor]: + return hvp( + lambda p: create_batch_loss_function(self.model, self.loss)(p, x, y), + self.params_to_restrict_to, + dict(zip(self.params_to_restrict_to.keys(), vec)), + reverse_only=self.reverse_only, + ) + + return seq_func + -class GaussNewtonBatchOperation(ModelBasedBatchOperation): +class GaussNewtonBatchOperation(_ModelBasedBatchOperation): r""" Given a model and loss function computes the Gauss-Newton vector or matrix product with respect to the model parameters, i.e. @@ -222,10 +280,18 @@ def __init__( model, loss, self.params_to_restrict_to ) + def _apply_to_tensor_dict( + self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + vec_values = list(self._add_batch_dim(vec_dict).values()) + grads_dict = self.gradient_provider.grads(batch) + grads_values = list(self._add_batch_dim(grads_dict).values()) + gen_result = generate_rank_one_mvp(grads_values, vec_values) + return dict(zip(vec_dict.keys(), gen_result)) + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: flat_grads = self.gradient_provider.flat_grads(batch) - result = rank_one_mvp(flat_grads, vec) - return result + return rank_one_mvp(flat_grads, vec) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ @@ -248,7 +314,7 @@ def to(self, device: torch.device): return super().to(device) -class InverseHarmonicMeanBatchOperation(ModelBasedBatchOperation): +class InverseHarmonicMeanBatchOperation(_ModelBasedBatchOperation): r""" Given a model and loss function computes an approximation of the inverse Gauss-Newton vector or matrix product. 
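HessianBatchOperation._apply_to_tensor_dict keeps the Hessian-vector product in parameter-dict form by composing the patch's hvp helper with torch.func.vmap. Conceptually the per-parameter product is plain reverse-over-reverse autograd; a minimal standalone sketch, not the helper itself:

import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 2)
loss_fn = torch.nn.functional.mse_loss
x, y = torch.randn(8, 3), torch.randn(8, 2)

params = {k: p for k, p in model.named_parameters() if p.requires_grad}
vec = {k: torch.randn_like(p) for k, p in params.items()}

loss = loss_fn(model(x), y)
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
# differentiate <grad, vec> once more to obtain the Hessian-vector product
grad_dot_vec = sum((g * vec[k]).sum() for g, k in zip(grads, params))
hvp = dict(zip(params, torch.autograd.grad(grad_dot_vec, list(params.values()))))

# dict in, dict out: the result keeps the parameter structure
assert all(hvp[k].shape == p.shape for k, p in params.items())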
Viewing the damped Gauss-newton matrix @@ -327,10 +393,11 @@ def regularization(self, value: float): def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: grads = self.gradient_provider.flat_grads(batch) - return ( - inverse_rank_one_update(grads, vec, self.regularization) - / self.regularization - ) + if vec.ndim == 1: + input_vec = vec.unsqueeze(0) + else: + input_vec = vec + return inverse_rank_one_update(grads, input_vec, self.regularization) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ @@ -353,5 +420,111 @@ def to(self, device: torch.device): self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to return self + def _apply_to_tensor_dict( + self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + vec_values = list(self._add_batch_dim(vec_dict).values()) + grads_dict = self.gradient_provider.grads(batch) + grads_values = list(self._add_batch_dim(grads_dict).values()) + gen_result = generate_inverse_rank_one_updates( + grads_values, vec_values, self.regularization + ) + return dict(zip(vec_dict.keys(), gen_result)) + + +BatchOperationType = TypeVar("BatchOperationType", bound=_ModelBasedBatchOperation) + + +class _TensorDictAveraging(ABC): + @abstractmethod + def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]): + pass + + +_TensorDictAveragingType = TypeVar( + "_TensorDictAveragingType", bound=_TensorDictAveraging +) + + +class _TensorAveraging(Generic[_TensorDictAveragingType], ABC): + @abstractmethod + def __call__(self, tensors: Generator[torch.Tensor, None, None]): + pass + + @abstractmethod + def as_dict_averaging(self) -> _TensorDictAveraging: + pass + + +TensorAveragingType = TypeVar("TensorAveragingType", bound=_TensorAveraging) + + +class _TensorDictChunkAveraging(_TensorDictAveraging): + def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]): + result = next(tensor_dicts) + n_chunks = 1.0 + for tensor_dict in tensor_dicts: + for key, tensor in tensor_dict.items(): + result[key] += tensor + n_chunks += 1.0 + return {k: t / n_chunks for k, t in result.items()} + + +class ChunkAveraging(_TensorAveraging[_TensorDictChunkAveraging]): + """ + Averages tensors, provided by a generator, and normalizes by the number + of tensors. + """ + + def __call__(self, tensors: Generator[torch.Tensor, None, None]): + result = next(tensors) + n_chunks = 1.0 + for tensor in tensors: + result += tensor + n_chunks += 1.0 + return result / n_chunks + + def as_dict_averaging(self) -> _TensorDictChunkAveraging: + return _TensorDictChunkAveraging() + + +class _TensorDictPointAveraging(_TensorDictAveraging): + def __init__(self, batch_dim: int = 0): + self.batch_dim = batch_dim + + def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]): + result = next(tensor_dicts) + n_points = next(iter(result.values())).shape[self.batch_dim] + for tensor_dict in tensor_dicts: + n_points_in_batch = next(iter(tensor_dict.values())).shape[self.batch_dim] + for key, tensor in tensor_dict.items(): + result[key] += n_points_in_batch * tensor + n_points += n_points_in_batch + return {k: t / float(n_points) for k, t in result.items()} + + +class PointAveraging(_TensorAveraging[_TensorDictPointAveraging]): + """ + Averages tensors provided by a generator. The averaging is weighted by + the number of points in each tensor and the final result is normalized by the + number of total points. 
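The two averaging strategies introduced here differ only in how chunks are weighted: ChunkAveraging treats every chunk result equally, while PointAveraging weights each result by its number of points before normalizing. On plain tensors with explicit weights (the classes read the weight from the batch dimension instead), the difference looks like this:

import torch

results = [torch.tensor([1.0, 1.0]), torch.tensor([4.0, 4.0])]   # per-chunk results
sizes = [2, 6]                                                   # points per chunk

chunk_avg = sum(results) / len(results)                              # tensor([2.5, 2.5])
point_avg = sum(n * r for n, r in zip(sizes, results)) / sum(sizes)  # tensor([3.25, 3.25])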
+ + Args: + batch_dim: Dimension to extract the number of points for the weighting. + + """ + + def __init__(self, batch_dim: int = 0): + self.batch_dim = batch_dim + + def __call__(self, tensors: Generator[torch.Tensor, None, None]): + result = next(tensors) + n_points = result.shape[self.batch_dim] + for tensor in tensors: + n_points_in_batch = tensor.shape[self.batch_dim] + result += n_points_in_batch * tensor + n_points += n_points_in_batch + return result / float(n_points) -BatchOperationType = TypeVar("BatchOperationType", bound=BatchOperation) + def as_dict_averaging(self) -> _TensorDictPointAveraging: + return _TensorDictPointAveraging(self.batch_dim) diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 81d6b0442..745210f05 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -1,27 +1,30 @@ -from typing import Callable, Dict, Generator, Generic, Optional, Type, Union +from typing import Callable, Dict, Generic, Optional, Tuple, Type, Union import torch from torch import nn as nn from torch.utils.data import DataLoader -from ..array import LazyChunkSequence, SequenceAggregator from .base import ( GradientProviderFactoryType, + TensorDictOperator, TorchAutoGrad, TorchBatch, TorchGradientProvider, - TorchOperator, ) from .batch_operation import ( BatchOperationType, + ChunkAveraging, GaussNewtonBatchOperation, HessianBatchOperation, InverseHarmonicMeanBatchOperation, + PointAveraging, + TensorAveragingType, ) -from .util import TorchChunkAverageAggregator, TorchPointAverageAggregator -class AggregateBatchOperator(TorchOperator, Generic[BatchOperationType]): +class _AveragingBatchOperator( + TensorDictOperator, Generic[BatchOperationType, TensorAveragingType] +): """ Class for aggregating batch operations over a dataset using a provided data loader and aggregator. @@ -32,7 +35,7 @@ class AggregateBatchOperator(TorchOperator, Generic[BatchOperationType]): Attributes: batch_operation: The batch operation to apply. dataloader: The data loader providing batches of data. - aggregator: The sequence aggregator to aggregate the results of the batch + averaging: The sequence aggregator to aggregate the results of the batch operations. 
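The operators built on top of these averagers replace one pass over the full dataset by an average of per-batch results. For a loss with mean reduction and equal-sized batches, the chunk average of per-batch gradients recovers the full-data gradient exactly, and the same linearity argument carries over to the Hessian and Gauss-Newton products that are averaged here. A standalone sketch, assuming MSE with its default mean reduction:

import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
loss_fn = torch.nn.functional.mse_loss
X, Y = torch.randn(16, 3), torch.randn(16, 1)
loader = DataLoader(TensorDataset(X, Y), batch_size=4)

def flat_grad(x, y):
    loss = loss_fn(model(x), y)
    grads = torch.autograd.grad(loss, list(model.parameters()))
    return torch.cat([g.reshape(-1) for g in grads])

per_batch = [flat_grad(x, y) for x, y in loader]
averaged = sum(per_batch) / len(per_batch)          # uniform chunk average

assert torch.allclose(averaged, flat_grad(X, Y), atol=1e-5)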
""" @@ -40,11 +43,27 @@ def __init__( self, batch_operation: BatchOperationType, dataloader: DataLoader, - aggregator: SequenceAggregator[torch.Tensor], + averager: TensorAveragingType, ): self.batch_operation = batch_operation self.dataloader = dataloader - self.aggregator = aggregator + self.averaging = averager + + @property + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + return self.batch_operation.input_dict_structure + + def _apply_to_mat_dict( + self, mat: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + tensor_dicts = ( + self.batch_operation.apply_to_tensor_dict(TorchBatch(x, y), mat) + for x, y in self.dataloader + ) + dict_averaging = self.averaging.as_dict_averaging() + result: Dict[str, torch.Tensor] = dict_averaging(tensor_dicts) + return result @property def device(self): @@ -85,21 +104,20 @@ def _apply( z: torch.Tensor, batch_ops: Callable[[TorchBatch, torch.Tensor], torch.Tensor], ): - def tensor_gen_factory() -> Generator[torch.Tensor, None, None]: - return ( - batch_ops( - TorchBatch(x.to(self.device), y.to(self.device)), z.to(self.device) - ) - for x, y in self.dataloader - ) - lazy_tensor_sequence = LazyChunkSequence( - tensor_gen_factory, len_generator=len(self.dataloader) + tensors = ( + batch_ops( + TorchBatch(x.to(self.device), y.to(self.device)), z.to(self.device) + ) + for x, y in self.dataloader ) - return self.aggregator(lazy_tensor_sequence) + return self.averaging(tensors) -class GaussNewtonOperator(AggregateBatchOperator[GaussNewtonBatchOperation]): + +class GaussNewtonOperator( + _AveragingBatchOperator[GaussNewtonBatchOperation, PointAveraging] +): r""" Given a model and loss function computes the Gauss-Newton vector or matrix product with respect to the model parameters on a batch, i.e. @@ -142,11 +160,11 @@ def __init__( gradient_provider_factory=gradient_provider_factory, restrict_to=restrict_to, ) - aggregator = TorchPointAverageAggregator() - super().__init__(batch_op, dataloader, aggregator) + averaging = PointAveraging() + super().__init__(batch_op, dataloader, averaging) -class HessianOperator(AggregateBatchOperator[HessianBatchOperation]): +class HessianOperator(_AveragingBatchOperator[HessianBatchOperation, ChunkAveraging]): r""" Given a model and loss function computes the Hessian vector or matrix product with respect to the model parameters for a given batch, i.e. 
@@ -182,12 +200,12 @@ def __init__( batch_op = HessianBatchOperation( model, loss, restrict_to=restrict_to, reverse_only=reverse_only ) - aggregator = TorchChunkAverageAggregator() - super().__init__(batch_op, dataloader, aggregator) + averaging = ChunkAveraging() + super().__init__(batch_op, dataloader, averaging) class InverseHarmonicMeanOperator( - AggregateBatchOperator[InverseHarmonicMeanBatchOperation] + _AveragingBatchOperator[InverseHarmonicMeanBatchOperation, PointAveraging] ): r""" Given a model and loss function computes an approximation of the inverse @@ -265,8 +283,8 @@ def __init__( gradient_provider_factory=gradient_provider_factory, restrict_to=restrict_to, ) - aggregator = TorchPointAverageAggregator(weighted=False) - super().__init__(batch_op, dataloader, aggregator) + averaging = PointAveraging() + super().__init__(batch_op, dataloader, averaging) @property def regularization(self): diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index a3553af5c..661ace7da 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -11,6 +11,7 @@ Callable, Collection, Dict, + Generator, Iterable, Iterator, List, @@ -48,15 +49,11 @@ "align_with_model", "flatten_dimensions", "TorchNumpyConverter", - "TorchCatAggregator", - "NestedTorchCatAggregator", "torch_dataset_to_dask_array", "EkfacRepresentation", "empirical_cross_entropy_loss_fn", "rank_one_mvp", "inverse_rank_one_update", - "TorchPointAverageAggregator", - "TorchChunkAverageAggregator", "LossType", "ModelParameterDictBuilder", "BlockMode", @@ -451,33 +448,6 @@ def __call__( return torch.cat(list(t_gen)) -class TorchChunkAverageAggregator(SequenceAggregator[torch.Tensor]): - def __call__(self, tensor_sequence: LazyChunkSequence): - t_gen = tensor_sequence.generator_factory() - result = next(t_gen) - n_chunks = 1 - for t in t_gen: - result += t - n_chunks += 1 - return result / n_chunks - - -class TorchPointAverageAggregator(SequenceAggregator[torch.Tensor]): - def __init__(self, batch_dim: int = 0, weighted: bool = True): - self.weighted = weighted - self.batch_dim = batch_dim - - def __call__(self, tensor_sequence: LazyChunkSequence): - tensor_generator = tensor_sequence.generator_factory() - result = next(tensor_generator) - n_points = result.shape[self.batch_dim] - for tensor in tensor_generator: - n_points_in_batch = tensor.shape[self.batch_dim] - result += n_points_in_batch * tensor if self.weighted else tensor - n_points += n_points_in_batch - return result / n_points - - class NestedTorchCatAggregator(NestedSequenceAggregator[torch.Tensor]): """ An aggregator that concatenates tensors using PyTorch's [torch.cat][torch.cat] @@ -648,7 +618,7 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: forming xx^T and sums the result. Here, X and V are matrices where each row represents an individual vector. Effectively it is computing - $$ V@(\sum_i^N x[i]x[i]^T) $$ + $$ V@( \frac{1}{N}\sum_i^N x[i]x[i]^T) $$ Args: x: Matrix of vectors of size `(N, M)`. 
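The rank-one matrix-vector product, now normalized by the batch size, avoids materializing the M-by-M Gram matrix. The identity stated in the docstring can be checked directly; a standalone sketch:

import torch

torch.manual_seed(0)
N, M, B = 6, 4, 3
X = torch.randn(N, M)
V = torch.randn(B, M)

# explicit average of outer products (O(M^2) memory)
G = X.T @ X / N
explicit = V @ G

# the same product without forming G, as in rank_one_mvp
implicit = (torch.einsum("ij,kj->ki", X, V) @ X) / N

assert torch.allclose(explicit, implicit, atol=1e-5)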
@@ -661,8 +631,23 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: """ if v.ndim == 1: result = torch.einsum("ij,kj->ki", x, v.unsqueeze(0)) @ x - return result.squeeze() - return torch.einsum("ij,kj->ki", x, v) @ x + return result.squeeze() / x.shape[0] + return (torch.einsum("ij,kj->ki", x, v) @ x) / x.shape[0] + + +def generate_rank_one_mvp( + x: List[torch.Tensor], v: List[torch.Tensor] +) -> Generator[torch.Tensor, None, None]: + x_v_iterator = zip(x, v) + x_, v_ = next(x_v_iterator) + + nominator = torch.einsum("i..., k...->ki", x_, v_) + + for x_, v_ in x_v_iterator: + nominator += torch.einsum("i..., k...->ki", x_, v_) + + for x_, v_ in zip(x, v): + yield torch.einsum("ji, i... -> j...", nominator, x_) / x_.shape[0] def inverse_rank_one_update( @@ -696,30 +681,64 @@ def inverse_rank_one_update( return (v - (nominator / denominator) @ x) / regularization -def inverse_rank_one_update_dict( - x: Dict[str, torch.Tensor], v: Dict[str, torch.Tensor], regularization: float -) -> Dict[str, torch.Tensor]: +def generate_inverse_rank_one_updates( + x: List[torch.Tensor], v: List[torch.Tensor], regularization: float +) -> Generator[torch.Tensor, None, None]: + def _check_batch_dim(t_x, t_v, idx: int): + if t_x.ndim <= 1: + raise ValueError( + f"Provided tensors in the lists must have at least " + f"2 dimensions, " + f"but found {t_x.ndim=} at {idx=} in list x" + ) + + if v_.ndim <= 1: + raise ValueError( + f"Provided tensors in the lists must have at least " + f"2 dimensions, " + f"but found shape {t_v.ndim=} at {idx=} in list v" + ) + + def _create_dim_error(x_shape, v_shape, idx: int): + return ValueError( + f"Entries in the tensor lists must have the same " + f"(excluding the first batch dimensions), " + f"but found shapes {x_shape} and {v_shape}" + f"at {idx=}" + ) + + if not len(x) == len(v): + raise ValueError( + f"Provided tensor lists must have the same length, but got" + f"{len(x)=} and {len(v)=}" + ) - denominator = regularization - nominator = None - batch_size = None - for x_, v_ in zip(x.values(), v.values()): - if batch_size is None: - batch_size = x_.shape[0] - if nominator is None: - nominator = torch.einsum("i..., k...->ki", x_, v_) - else: - nominator += torch.einsum("i..., k...->ki", x_, v_) + x_v_iterator = enumerate(zip(x, v)) + index, (x_, v_) = next(x_v_iterator) + + _check_batch_dim(x_, v_, index) + + if x_.shape[1:] != v_.shape[1:]: + raise _create_dim_error(x_.shape[1:], v_.shape[1:], index) + + denominator = regularization + torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) + nominator = torch.einsum("i..., k...->ki", x_, v_) + num_data_points = x_.shape[0] + + for k, (x_, v_) in x_v_iterator: + _check_batch_dim(x_, v_, k) + if x_.shape[1:] != v_.shape[1:]: + raise _create_dim_error(x_.shape[1:], v_.shape[1:], k) + + nominator += torch.einsum("i..., k...->ki", x_, v_) denominator += torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) - denominator = batch_size * denominator - result = {} - for key in x.keys(): - result[key] = ( - v[key] - torch.einsum("ji, i... -> j...", nominator / denominator, x[key]) - ) / regularization + denominator = num_data_points * denominator - return result + for x_, v_ in zip(x, v): + yield ( + v_ - torch.einsum("ji, i... 
-> j...", nominator / denominator, x_) + ) / regularization LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] diff --git a/tests/influence/test_influence_calculator.py b/tests/influence/test_influence_calculator.py index 9a82e89cf..70a29bf1a 100644 --- a/tests/influence/test_influence_calculator.py +++ b/tests/influence/test_influence_calculator.py @@ -1,5 +1,3 @@ -import uuid - import dask.array as da import numpy as np import pytest diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index 304de4518..a987857e6 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -7,7 +7,9 @@ from pydvl.influence.torch.batch_operation import ( GaussNewtonBatchOperation, HessianBatchOperation, + InverseHarmonicMeanBatchOperation, ) +from pydvl.influence.torch.util import align_structure, flatten_dimensions from .test_util import model_data, test_parameters @@ -18,7 +20,7 @@ [(astuple(tp.model_params), 1e-5) for tp in test_parameters], indirect=["model_data"], ) -def test_hessian_batch_operation(model_data, tol: float): +def test_hessian_batch_operation(model_data, tol: float, pytorch_seed): torch_model, x, y, vec, h_analytical = model_data params = dict(torch_model.named_parameters()) @@ -26,9 +28,36 @@ def test_hessian_batch_operation(model_data, tol: float): hessian_op = HessianBatchOperation( torch_model, torch.nn.functional.mse_loss, restrict_to=params ) + batch_size = 10 + rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} + flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) + hvp_autograd_mat_dict = hessian_op.apply_to_tensor_dict( + TorchBatch(x, y), rand_mat_dict + ) + hvp_autograd = hessian_op.apply_to_vec(TorchBatch(x, y), vec) + hvp_autograd_dict = hessian_op.apply_to_tensor_dict( + TorchBatch(x, y), align_structure(params, vec) + ) + hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values()) assert torch.allclose(hvp_autograd, h_analytical @ vec, rtol=tol) + assert torch.allclose(hvp_autograd_dict_flat, h_analytical @ vec, rtol=tol) + + op_then_flat = flatten_dimensions( + hvp_autograd_mat_dict.values(), shape=(batch_size, -1) + ) + flat_then_op_analytical = torch.einsum("ik, jk -> ji", h_analytical, flat_rand_mat) + + assert torch.allclose( + op_then_flat, + flat_then_op_analytical, + atol=1e-5, + rtol=tol, + ) + assert torch.allclose( + hessian_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat), op_then_flat + ) @pytest.mark.torch @@ -37,7 +66,7 @@ def test_hessian_batch_operation(model_data, tol: float): [(astuple(tp.model_params), 1e-3) for tp in test_parameters], indirect=["model_data"], ) -def test_gauss_newton_batch_operation(model_data, tol: float): +def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): torch_model, x, y, vec, _ = model_data y_pred = torch_model(x) @@ -47,11 +76,14 @@ def test_gauss_newton_batch_operation(model_data, tol: float): )(x, y_pred, y) dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1) - gn_mat_analytical = torch.sum( - torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())( - grad_analytical - ), - dim=0, + gn_mat_analytical = ( + torch.sum( + torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())( + grad_analytical + ), + dim=0, + ) + / x.shape[0] ) params = dict(torch_model.named_parameters()) @@ -59,8 +91,115 @@ 
def test_gauss_newton_batch_operation(model_data, tol: float): gn_op = GaussNewtonBatchOperation( torch_model, torch.nn.functional.mse_loss, restrict_to=params ) + batch_size = 10 + + gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec) + gn_autograd_dict = gn_op.apply_to_tensor_dict( + TorchBatch(x, y), align_structure(params, vec) + ) + gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) + analytical_vec = gn_mat_analytical @ vec + assert torch.allclose(gn_autograd, analytical_vec, atol=1e-5, rtol=tol) + assert torch.allclose(gn_autograd_dict_flat, analytical_vec, atol=1e-5, rtol=tol) + + rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} + flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) + gn_autograd_mat_dict = gn_op.apply_to_tensor_dict(TorchBatch(x, y), rand_mat_dict) + + op_then_flat = flatten_dimensions( + gn_autograd_mat_dict.values(), shape=(batch_size, -1) + ) + flat_then_op = gn_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat) + + assert torch.allclose( + op_then_flat, + flat_then_op, + atol=1e-5, + rtol=tol, + ) + + flat_then_op_analytical = torch.einsum( + "ik, jk -> ji", gn_mat_analytical, flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op_analytical, + atol=1e-5, + rtol=1e-2, + ) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-3) for tp in test_parameters], + indirect=["model_data"], +) +@pytest.mark.parametrize("reg", [0.4]) +def test_inverse_harmonic_mean_batch_operation( + model_data, tol: float, reg, pytorch_seed +): + torch_model, x, y, vec, _ = model_data + y_pred = torch_model(x) + out_features = y_pred.shape[1] + dl_dw = torch.vmap( + lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1) + )(x, y_pred, y) + dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) + grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1) + params = { + k: p.detach() for k, p in torch_model.named_parameters() if p.requires_grad + } + + ihm_mat_analytical = torch.sum( + torch.func.vmap( + lambda z: torch.linalg.inv( + z.unsqueeze(-1) * z.unsqueeze(-1).t() + reg * torch.eye(len(z)) + ) + )(grad_analytical), + dim=0, + ) + ihm_mat_analytical /= x.shape[0] + + gn_op = InverseHarmonicMeanBatchOperation( + torch_model, torch.nn.functional.mse_loss, reg, restrict_to=params + ) + batch_size = 10 + gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec) + gn_autograd_dict = gn_op.apply_to_tensor_dict( + TorchBatch(x, y), align_structure(params, vec) + ) + gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) + analytical = ihm_mat_analytical @ vec + + assert torch.allclose(gn_autograd, analytical, atol=1e-5, rtol=tol) + assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-5, rtol=tol) - gn_analytical = gn_mat_analytical @ vec + rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} + flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) + gn_autograd_mat_dict = gn_op.apply_to_tensor_dict(TorchBatch(x, y), rand_mat_dict) - assert torch.allclose(gn_autograd, gn_analytical, atol=1e-5, rtol=tol) + op_then_flat = flatten_dimensions( + gn_autograd_mat_dict.values(), shape=(batch_size, -1) + ) + flat_then_op = gn_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat) + + assert torch.allclose( + op_then_flat, + flat_then_op, + atol=1e-5, + rtol=tol, + ) + + flat_then_op_analytical = 
torch.einsum( + "ik, jk -> ji", ihm_mat_analytical, flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op_analytical, + atol=1e-5, + rtol=1e-2, + ) diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index fc80ad805..bd11b05e2 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -23,8 +23,9 @@ TorchTensorContainerType, align_structure, flatten_dimensions, + generate_inverse_rank_one_updates, + generate_rank_one_mvp, inverse_rank_one_update, - inverse_rank_one_update_dict, rank_one_mvp, safe_torch_linalg_eigh, torch_dataset_to_dask_array, @@ -338,7 +339,7 @@ def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): (torch.vmap(lambda x: x.unsqueeze(-1) * x.unsqueeze(-1).t())(X) @ V.t()) .sum(dim=0) .t() - ) + ) / x_dim_0 result = rank_one_mvp(X, V) @@ -346,6 +347,35 @@ def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): assert torch.allclose(result, expected, atol=1e-5, rtol=1e-4) +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_1", + [ + [(4, 2, 3), (5, 7), (5,)], + [(3, 6, 8, 9), (1, 2)], + [(1,)], + ], +) +@pytest.mark.parametrize( + "x_dim_0, v_dim_0", + [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], +) +def test_generate_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): + x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] + v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] + + x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) + v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) + result = rank_one_mvp(x, v) + + inverse_result = flatten_dimensions( + generate_rank_one_mvp(x_list, v_list), + shape=(v_dim_0, -1), + ) + + assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) + + @pytest.mark.torch @pytest.mark.parametrize( "x_dim_0, x_dim_1, v_dim_0", @@ -373,26 +403,31 @@ def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): @pytest.mark.torch @pytest.mark.parametrize( "x_dim_1", - [{"1": (4, 2, 3), "2": (5, 7), "3": ()}, {"1": (3, 6, 8, 9), "2": (1, 2)}, {"1": (1,)}], + [ + [(4, 2, 3), (5, 7), (5,)], + [(3, 6, 8, 9), (1, 2)], + [(1,)], + ], ) @pytest.mark.parametrize( "x_dim_0, v_dim_0", [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], ) @pytest.mark.parametrize("reg", [0.5, 100, 1.0, 10]) -def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): - X_dict = {k: torch.randn(x_dim_0, *d) for k, d in x_dim_1.items()} - V_dict = {k: torch.randn(v_dim_0, *d) for k, d in x_dim_1.items()} +def test_generate_inverse_rank_one_updates(x_dim_0, x_dim_1, v_dim_0, reg): + x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] + v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] - X = flatten_dimensions(X_dict.values(), shape=(x_dim_0, -1)) - V = flatten_dimensions(V_dict.values(), shape=(v_dim_0, -1)) - result = inverse_rank_one_update(X, V, reg) + x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) + v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) + result = inverse_rank_one_update(x, v, reg) inverse_result = flatten_dimensions( - inverse_rank_one_update_dict(X_dict, V_dict, reg).values(), shape=(v_dim_0, -1) + generate_inverse_rank_one_updates(x_list, v_list, reg), + shape=(v_dim_0, -1), ) - assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) + assert torch.allclose(result, inverse_result) class TestModelParameterDictBuilder: From f4f593117572ebc27070f07995c606459cbc2203 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 12:37:31 +0200 Subject: [PATCH 26/43] Move functions from util to static methods, increase tolerance for test --- 
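The equivalence tests above repeatedly convert between flat tensors and parameter-shaped dictionaries via align_structure and flatten_dimensions. The round trip they rely on is just a split, reshape and concatenation; a standalone sketch without the pydvl helpers:

import torch

params = {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}
vec = torch.arange(8.0)                      # flat vector, numel matches the params

# split the flat vector into the parameter structure (cf. align_structure) ...
aligned, offset = {}, 0
for k, p in params.items():
    aligned[k] = vec[offset : offset + p.numel()].reshape(p.shape)
    offset += p.numel()

# ... and flatten it back (cf. flatten_dimensions)
flat = torch.cat([t.reshape(-1) for t in aligned.values()])
assert torch.equal(flat, vec)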
src/pydvl/influence/torch/batch_operation.py | 108 +++++++++++++-- src/pydvl/influence/torch/util.py | 131 ------------------ tests/influence/torch/test_batch_operation.py | 110 ++++++++++++++- tests/influence/torch/test_util.py | 108 --------------- 4 files changed, 206 insertions(+), 251 deletions(-) diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 30235899c..88bae26ee 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -38,13 +38,7 @@ TorchGradientProvider, ) from .functional import create_batch_hvp_function, create_batch_loss_function, hvp -from .util import ( - LossType, - generate_inverse_rank_one_updates, - generate_rank_one_mvp, - inverse_rank_one_update, - rank_one_mvp, -) +from .util import LossType class _ModelBasedBatchOperation(ABC): @@ -286,12 +280,12 @@ def _apply_to_tensor_dict( vec_values = list(self._add_batch_dim(vec_dict).values()) grads_dict = self.gradient_provider.grads(batch) grads_values = list(self._add_batch_dim(grads_dict).values()) - gen_result = generate_rank_one_mvp(grads_values, vec_values) + gen_result = self._generate_rank_one_mvp(grads_values, vec_values) return dict(zip(vec_dict.keys(), gen_result)) def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: flat_grads = self.gradient_provider.flat_grads(batch) - return rank_one_mvp(flat_grads, vec) + return self._rank_one_mvp(flat_grads, vec) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ @@ -313,6 +307,44 @@ def to(self, device: torch.device): self.gradient_provider = self.gradient_provider.to(device) return super().to(device) + @staticmethod + def _rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + r""" + Computes the matrix-vector product of xx^T and v for each row in X and V without + forming xx^T and sums the result. Here, X and V are matrices where each row + represents an individual vector. Effectively it is computing + + $$ V@( \frac{1}{N}\sum_i^N x[i]x[i]^T) $$ + + Args: + x: Matrix of vectors of size `(N, M)`. + v: Matrix of vectors of size `(B, M)` to be multiplied by the corresponding + $xx^T$. + + Returns: + A matrix of size `(B, N)` where each column is the result of xx^T v for + corresponding rows in x and v. + """ + if v.ndim == 1: + result = torch.einsum("ij,kj->ki", x, v.unsqueeze(0)) @ x + return result.squeeze() / x.shape[0] + return (torch.einsum("ij,kj->ki", x, v) @ x) / x.shape[0] + + @staticmethod + def _generate_rank_one_mvp( + x: List[torch.Tensor], v: List[torch.Tensor] + ) -> Generator[torch.Tensor, None, None]: + x_v_iterator = zip(x, v) + x_, v_ = next(x_v_iterator) + + nominator = torch.einsum("i..., k...->ki", x_, v_) + + for x_, v_ in x_v_iterator: + nominator += torch.einsum("i..., k...->ki", x_, v_) + + for x_, v_ in zip(x, v): + yield torch.einsum("ji, i... 
-> j...", nominator, x_) / x_.shape[0] + class InverseHarmonicMeanBatchOperation(_ModelBasedBatchOperation): r""" @@ -397,7 +429,7 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: input_vec = vec.unsqueeze(0) else: input_vec = vec - return inverse_rank_one_update(grads, input_vec, self.regularization) + return self._inverse_rank_one_update(grads, input_vec, self.regularization) def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ @@ -426,11 +458,65 @@ def _apply_to_tensor_dict( vec_values = list(self._add_batch_dim(vec_dict).values()) grads_dict = self.gradient_provider.grads(batch) grads_values = list(self._add_batch_dim(grads_dict).values()) - gen_result = generate_inverse_rank_one_updates( + gen_result = self._generate_inverse_rank_one_updates( grads_values, vec_values, self.regularization ) return dict(zip(vec_dict.keys(), gen_result)) + @staticmethod + def _inverse_rank_one_update( + x: torch.Tensor, v: torch.Tensor, regularization: float + ) -> torch.Tensor: + r""" + Performs an inverse-rank one update on x and v. More precisely, it computes + + $$ \sum_{i=1}^n \left(x[i]x[i]^t+\lambda \operatorname{I}\right)^{-1}v $$ + + where $\operatorname{I}$ is the identity matrix and $\lambda$ is positive + regularization parameter. The inverse matrices are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + Args: + x: Input matrix used for the rank one expressions. First dimension is + assumed to be the batch dimension. + v: Matrix to multiply with. First dimension is + assumed to be the batch dimension. + regularization: Regularization parameter to make the rank-one expressions + invertible, must be positive. + + Returns: + Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape $(M, D)$. + """ + nominator = torch.einsum("ij,kj->ki", x, v) + denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) + return (v - (nominator / denominator) @ x) / regularization + + @staticmethod + def _generate_inverse_rank_one_updates( + x: List[torch.Tensor], v: List[torch.Tensor], regularization: float + ) -> Generator[torch.Tensor, None, None]: + + x_v_iterator = enumerate(zip(x, v)) + index, (x_, v_) = next(x_v_iterator) + + denominator = regularization + torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) + nominator = torch.einsum("i..., k...->ki", x_, v_) + num_data_points = x_.shape[0] + + for k, (x_, v_) in x_v_iterator: + nominator += torch.einsum("i..., k...->ki", x_, v_) + denominator += torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) + + denominator = num_data_points * denominator + + for x_, v_ in zip(x, v): + yield ( + v_ - torch.einsum("ji, i... 
-> j...", nominator / denominator, x_) + ) / regularization + BatchOperationType = TypeVar("BatchOperationType", bound=_ModelBasedBatchOperation) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index 661ace7da..d1018941e 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -52,8 +52,6 @@ "torch_dataset_to_dask_array", "EkfacRepresentation", "empirical_cross_entropy_loss_fn", - "rank_one_mvp", - "inverse_rank_one_update", "LossType", "ModelParameterDictBuilder", "BlockMode", @@ -612,135 +610,6 @@ def __init__(self, original_exception: RuntimeError): super().__init__(err_msg) -def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - r""" - Computes the matrix-vector product of xx^T and v for each row in X and V without - forming xx^T and sums the result. Here, X and V are matrices where each row - represents an individual vector. Effectively it is computing - - $$ V@( \frac{1}{N}\sum_i^N x[i]x[i]^T) $$ - - Args: - x: Matrix of vectors of size `(N, M)`. - v: Matrix of vectors of size `(B, M)` to be multiplied by the corresponding - $xx^T$. - - Returns: - A matrix of size `(B, N)` where each column is the result of xx^T v for - corresponding rows in x and v. - """ - if v.ndim == 1: - result = torch.einsum("ij,kj->ki", x, v.unsqueeze(0)) @ x - return result.squeeze() / x.shape[0] - return (torch.einsum("ij,kj->ki", x, v) @ x) / x.shape[0] - - -def generate_rank_one_mvp( - x: List[torch.Tensor], v: List[torch.Tensor] -) -> Generator[torch.Tensor, None, None]: - x_v_iterator = zip(x, v) - x_, v_ = next(x_v_iterator) - - nominator = torch.einsum("i..., k...->ki", x_, v_) - - for x_, v_ in x_v_iterator: - nominator += torch.einsum("i..., k...->ki", x_, v_) - - for x_, v_ in zip(x, v): - yield torch.einsum("ji, i... -> j...", nominator, x_) / x_.shape[0] - - -def inverse_rank_one_update( - x: torch.Tensor, v: torch.Tensor, regularization: float -) -> torch.Tensor: - r""" - Performs an inverse-rank one update on x and v. More precisely, it computes - - $$ \sum_{i=1}^n \left(x[i]x[i]^t+\lambda \operatorname{I}\right)^{-1}v $$ - - where $\operatorname{I}$ is the identity matrix and $\lambda$ is positive - regularization parameter. The inverse matrices are not calculated explicitly, - but instead a vectorized version of the - [Sherman–Morrison formula]( - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) - is applied. - - Args: - x: Input matrix used for the rank one expressions. First dimension is - assumed to be the batch dimension. - v: Matrix to multiply with. First dimension is - assumed to be the batch dimension. - regularization: Regularization parameter to make the rank-one expressions - invertible, must be positive. - - Returns: - Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape $(M, D)$. 
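The operation implemented by these updates is an approximation: it averages the per-sample inverses of the damped rank-one matrices instead of inverting the averaged Gauss-Newton matrix, which is why the class name refers to an inverse harmonic mean. In the scalar case the difference is easy to see; illustrative numbers only:

import torch

lam = 0.1
g = torch.tensor([0.5, 2.0, 3.0])            # stand-in per-sample gradients, 1-d model

inverse_of_mean = 1.0 / (g.pow(2).mean() + lam)        # invert the damped GN mean
mean_of_inverses = (1.0 / (g.pow(2) + lam)).mean()     # what the batch operation averages

print(float(inverse_of_mean), float(mean_of_inverses))  # the two differ in general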
- """ - nominator = torch.einsum("ij,kj->ki", x, v) - denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) - return (v - (nominator / denominator) @ x) / regularization - - -def generate_inverse_rank_one_updates( - x: List[torch.Tensor], v: List[torch.Tensor], regularization: float -) -> Generator[torch.Tensor, None, None]: - def _check_batch_dim(t_x, t_v, idx: int): - if t_x.ndim <= 1: - raise ValueError( - f"Provided tensors in the lists must have at least " - f"2 dimensions, " - f"but found {t_x.ndim=} at {idx=} in list x" - ) - - if v_.ndim <= 1: - raise ValueError( - f"Provided tensors in the lists must have at least " - f"2 dimensions, " - f"but found shape {t_v.ndim=} at {idx=} in list v" - ) - - def _create_dim_error(x_shape, v_shape, idx: int): - return ValueError( - f"Entries in the tensor lists must have the same " - f"(excluding the first batch dimensions), " - f"but found shapes {x_shape} and {v_shape}" - f"at {idx=}" - ) - - if not len(x) == len(v): - raise ValueError( - f"Provided tensor lists must have the same length, but got" - f"{len(x)=} and {len(v)=}" - ) - - x_v_iterator = enumerate(zip(x, v)) - index, (x_, v_) = next(x_v_iterator) - - _check_batch_dim(x_, v_, index) - - if x_.shape[1:] != v_.shape[1:]: - raise _create_dim_error(x_.shape[1:], v_.shape[1:], index) - - denominator = regularization + torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) - nominator = torch.einsum("i..., k...->ki", x_, v_) - num_data_points = x_.shape[0] - - for k, (x_, v_) in x_v_iterator: - _check_batch_dim(x_, v_, k) - if x_.shape[1:] != v_.shape[1:]: - raise _create_dim_error(x_.shape[1:], v_.shape[1:], k) - - nominator += torch.einsum("i..., k...->ki", x_, v_) - denominator += torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) - - denominator = num_data_points * denominator - - for x_, v_ in zip(x, v): - yield ( - v_ - torch.einsum("ji, i... 
-> j...", nominator / denominator, x_) - ) / regularization - - LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index a987857e6..f4714ba87 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -11,7 +11,7 @@ ) from pydvl.influence.torch.util import align_structure, flatten_dimensions -from .test_util import model_data, test_parameters +from .test_util import model_data, test_parameters, torch @pytest.mark.torch @@ -203,3 +203,111 @@ def test_inverse_harmonic_mean_batch_operation( atol=1e-5, rtol=1e-2, ) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_0, x_dim_1, v_dim_0", + [(10, 1, 12), (3, 2, 5), (4, 5, 30), (6, 6, 6), (1, 7, 7)], +) +def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): + X = torch.randn(x_dim_0, x_dim_1) + V = torch.randn(v_dim_0, x_dim_1) + + expected = ( + (torch.vmap(lambda x: x.unsqueeze(-1) * x.unsqueeze(-1).t())(X) @ V.t()) + .sum(dim=0) + .t() + ) / x_dim_0 + + result = GaussNewtonBatchOperation._rank_one_mvp(X, V) + + assert result.shape == V.shape + assert torch.allclose(result, expected, atol=1e-5, rtol=1e-4) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_1", + [ + [(4, 2, 3), (5, 7), (5,)], + [(3, 6, 8, 9), (1, 2)], + [(1,)], + ], +) +@pytest.mark.parametrize( + "x_dim_0, v_dim_0", + [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], +) +def test_generate_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): + x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] + v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] + + x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) + v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) + result = GaussNewtonBatchOperation._rank_one_mvp(x, v) + + inverse_result = flatten_dimensions( + GaussNewtonBatchOperation._generate_rank_one_mvp(x_list, v_list), + shape=(v_dim_0, -1), + ) + + assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_0, x_dim_1, v_dim_0", + [(10, 1, 12), (3, 2, 5), (4, 5, 10), (6, 6, 6), (1, 7, 7)], +) +@pytest.mark.parametrize("reg", [0.1, 100, 1.0, 10]) +def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): + X = torch.randn(x_dim_0, x_dim_1) + V = torch.randn(v_dim_0, x_dim_1) + + inverse_result = torch.zeros_like(V) + + for x in X: + rank_one_matrix = x.unsqueeze(-1) * x.unsqueeze(-1).t() + inverse_result += torch.linalg.solve( + rank_one_matrix + reg * torch.eye(rank_one_matrix.shape[0]), V, left=False + ) + + inverse_result /= X.shape[0] + result = InverseHarmonicMeanBatchOperation._inverse_rank_one_update(X, V, reg) + + assert torch.allclose(result, inverse_result, atol=1e-5) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_1", + [ + [(4, 2, 3), (5, 7), (5,)], + [(3, 6, 8, 9), (1, 2)], + [(1,)], + ], +) +@pytest.mark.parametrize( + "x_dim_0, v_dim_0", + [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], +) +@pytest.mark.parametrize("reg", [0.5, 100, 1.0, 10]) +def test_generate_inverse_rank_one_updates( + x_dim_0, x_dim_1, v_dim_0, reg, pytorch_seed +): + x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] + v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] + + x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) + v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) + result = InverseHarmonicMeanBatchOperation._inverse_rank_one_update(x, v, reg) + + inverse_result = flatten_dimensions( + 
InverseHarmonicMeanBatchOperation._generate_inverse_rank_one_updates( + x_list, v_list, reg + ), + shape=(v_dim_0, -1), + ) + + assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index bd11b05e2..a1b782a8c 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -23,10 +23,6 @@ TorchTensorContainerType, align_structure, flatten_dimensions, - generate_inverse_rank_one_updates, - generate_rank_one_mvp, - inverse_rank_one_update, - rank_one_mvp, safe_torch_linalg_eigh, torch_dataset_to_dask_array, ) @@ -326,110 +322,6 @@ def test_safe_torch_linalg_eigh_exception(): safe_torch_linalg_eigh(torch.randn([53000, 53000])) -@pytest.mark.torch -@pytest.mark.parametrize( - "x_dim_0, x_dim_1, v_dim_0", - [(10, 1, 12), (3, 2, 5), (4, 5, 30), (6, 6, 6), (1, 7, 7)], -) -def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): - X = torch.randn(x_dim_0, x_dim_1) - V = torch.randn(v_dim_0, x_dim_1) - - expected = ( - (torch.vmap(lambda x: x.unsqueeze(-1) * x.unsqueeze(-1).t())(X) @ V.t()) - .sum(dim=0) - .t() - ) / x_dim_0 - - result = rank_one_mvp(X, V) - - assert result.shape == V.shape - assert torch.allclose(result, expected, atol=1e-5, rtol=1e-4) - - -@pytest.mark.torch -@pytest.mark.parametrize( - "x_dim_1", - [ - [(4, 2, 3), (5, 7), (5,)], - [(3, 6, 8, 9), (1, 2)], - [(1,)], - ], -) -@pytest.mark.parametrize( - "x_dim_0, v_dim_0", - [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], -) -def test_generate_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): - x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] - v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] - - x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) - v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) - result = rank_one_mvp(x, v) - - inverse_result = flatten_dimensions( - generate_rank_one_mvp(x_list, v_list), - shape=(v_dim_0, -1), - ) - - assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) - - -@pytest.mark.torch -@pytest.mark.parametrize( - "x_dim_0, x_dim_1, v_dim_0", - [(10, 1, 12), (3, 2, 5), (4, 5, 10), (6, 6, 6), (1, 7, 7)], -) -@pytest.mark.parametrize("reg", [0.1, 100, 1.0, 10]) -def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): - X = torch.randn(x_dim_0, x_dim_1) - V = torch.randn(v_dim_0, x_dim_1) - - inverse_result = torch.zeros_like(V) - - for x in X: - rank_one_matrix = x.unsqueeze(-1) * x.unsqueeze(-1).t() - inverse_result += torch.linalg.solve( - rank_one_matrix + reg * torch.eye(rank_one_matrix.shape[0]), V, left=False - ) - - inverse_result /= X.shape[0] - result = inverse_rank_one_update(X, V, reg) - - assert torch.allclose(result, inverse_result, atol=1e-5) - - -@pytest.mark.torch -@pytest.mark.parametrize( - "x_dim_1", - [ - [(4, 2, 3), (5, 7), (5,)], - [(3, 6, 8, 9), (1, 2)], - [(1,)], - ], -) -@pytest.mark.parametrize( - "x_dim_0, v_dim_0", - [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], -) -@pytest.mark.parametrize("reg", [0.5, 100, 1.0, 10]) -def test_generate_inverse_rank_one_updates(x_dim_0, x_dim_1, v_dim_0, reg): - x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] - v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] - - x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) - v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) - result = inverse_rank_one_update(x, v, reg) - - inverse_result = flatten_dimensions( - generate_inverse_rank_one_updates(x_list, v_list, reg), - shape=(v_dim_0, -1), - ) - - assert torch.allclose(result, inverse_result) - - class 
TestModelParameterDictBuilder: class SimpleModel(torch.nn.Module): def __init__(self): From 01791d503afb9345403f19a2e8a591771f4a4416 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 13:55:44 +0200 Subject: [PATCH 27/43] Simplify operator interface --- src/pydvl/influence/torch/base.py | 46 +++++++++------- src/pydvl/influence/torch/batch_operation.py | 55 ++++++++++++------- src/pydvl/influence/torch/operator.py | 37 +++---------- src/pydvl/influence/types.py | 33 +++++++---- tests/influence/torch/test_batch_operation.py | 26 ++++----- 5 files changed, 105 insertions(+), 92 deletions(-) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index 9bc1dc225..f186b9e60 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast +from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast import torch from torch.func import functional_call @@ -13,7 +13,6 @@ from ..types import ( Batch, BilinearForm, - BilinearFormType, BlockMapper, GradientProvider, Operator, @@ -341,7 +340,7 @@ def inner_prod( return self._inner_product(right, left).T def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: - left_result = self.operator.apply_to_mat(left) + left_result = self.operator.apply(left) if left_result.ndim == right.ndim and left.shape[-1] == right.shape[-1]: return left_result @ right.T @@ -405,10 +404,10 @@ def grads_inner_prod( ) if left_batch_size <= right_batch_size: - left_grads = operator.apply_to_mat_dict(left_grads) + left_grads = operator.apply_to_dict(left_grads) tensor_pairs = zip(left_grads.values(), right_grads.values()) else: - right_grads = operator.apply_to_mat_dict(right_grads) + right_grads = operator.apply_to_dict(right_grads) tensor_pairs = zip(left_grads.values(), right_grads.values()) tensors_to_reduce = ( @@ -445,7 +444,7 @@ def mixed_grads_inner_prod( operator = cast(TensorDictOperator, self.operator) right_grads = gradient_provider.mixed_grads(right) left_grads = gradient_provider.grads(left) - left_grads = operator.apply_to_mat_dict(left_grads) + left_grads = operator.apply_to_dict(left_grads) left_grads_views = (t.reshape(t.shape[0], -1) for t in left_grads.values()) right_grads_views = ( t.reshape(*right.x.shape, -1) for t in right_grads.values() @@ -490,11 +489,25 @@ def dtype(self): def to(self, device: torch.device): pass + def _validate_tensor_input(self, tensor: torch.Tensor) -> None: + if not (1 <= tensor.ndim <= 2): + raise ValueError( + f"Expected a 1 or 2 dimensional tensor, got {tensor.ndim} dimensions." + ) + if tensor.shape[-1] != self.input_size: + raise ValueError( + f"Expected the last dimension to be of size {self.input_size}." + ) + + def _apply(self, tensor: torch.Tensor) -> torch.Tensor: + + if tensor.ndim == 2: + return self._apply_to_mat(tensor.to(self.device)) + + return self._apply_to_vec(tensor.to(self.device)) + @abstractmethod def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: - pass - - def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: """ Applies the operator to a single vector. 
Args: @@ -504,9 +517,8 @@ def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: Returns: A single vector after applying the batch operation """ - return self._apply_to_vec(vec.to(self.device)) - def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: """ Applies the operator to a matrix. Args: @@ -519,7 +531,7 @@ def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: $(N, \text{input_size})$ """ - return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat) + return torch.func.vmap(self._apply_to_vec, in_dims=0, randomness="same")(mat) def as_bilinear_form(self) -> OperatorBilinearForm: return OperatorBilinearForm(self) @@ -534,9 +546,7 @@ class TensorDictOperator(TensorOperator, ABC): to avoid intermediate flattening and concatenating of gradient inputs. """ - def apply_to_mat_dict( - self, mat: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: + def apply_to_dict(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Applies the operator to a dictionary of tensors, compatible to the structure defined by the property `input_dict_structure`. @@ -555,7 +565,7 @@ def apply_to_mat_dict( f"dimension): \n {self.input_dict_structure}" ) - return self._apply_to_mat_dict(self._dict_to_device(mat)) + return self._apply_to_dict(self._dict_to_device(mat)) def _dict_to_device(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {k: v.to(self.device) for k, v in mat.items()} @@ -570,9 +580,7 @@ def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: """ @abstractmethod - def _apply_to_mat_dict( - self, mat: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: + def _apply_to_dict(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: pass def _validate_mat_dict(self, mat: Dict[str, torch.Tensor]) -> bool: diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 88bae26ee..6e50a27e7 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -88,7 +88,7 @@ def to(self, device: torch.device): } return self - def apply_to_tensor_dict( + def apply_to_dict( self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: @@ -98,11 +98,11 @@ def apply_to_tensor_dict( "parameters to restrict to." ) - return self._apply_to_tensor_dict( + return self._apply_to_dict( batch, {k: v.to(self.device) for k, v in mat_dict.items()} ) - def _has_batch_dim(self, tensor_dict: Dict[str, torch.Tensor]): + def _has_batch_dim_dict(self, tensor_dict: Dict[str, torch.Tensor]): batch_dim_flags = [ tensor_dict[key].shape == val.shape for key, val in self.params_to_restrict_to.items() @@ -121,7 +121,7 @@ def _add_batch_dim(self, vec_dict: Dict[str, torch.Tensor]): return result @abstractmethod - def _apply_to_tensor_dict( + def _apply_to_dict( self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: pass @@ -130,20 +130,34 @@ def _apply_to_tensor_dict( def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: pass - def apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor): + def apply(self, batch: TorchBatch, tensor: torch.Tensor): """ - Applies the batch operation to a single vector. + Applies the batch operation to a tensor. Args: batch: Batch of data for computation - vec: A single vector consistent to the operation, i.e. it's length - must be equal to the property `input_size`. 
+ tensor: A tensor consistent to the operation, i.e. it must be + at most 2-dim, and it's tailing dimension must + be equal to the property `input_size`. Returns: - A single vector after applying the batch operation + A tensor after applying the batch operation """ - return self._apply_to_vec(batch.to(self.device), vec.to(self.device)) - def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + if not tensor.ndim <= 2: + raise ValueError( + f"The input tensor must be at most 2-dimensional, got {tensor.ndim}" + ) + + if tensor.shape[-1] != self.input_size: + raise ValueError( + "The last dimension of the input tensor must be equal to the " + "property `input_size`." + ) + if tensor.ndim == 2: + return self._apply_to_mat(batch.to(self.device), tensor.to(self.device)) + return self._apply_to_vec(batch.to(self.device), tensor.to(self.device)) + + def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ Applies the batch operation to a matrix. Args: @@ -204,13 +218,13 @@ def __init__( def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) - def _apply_to_tensor_dict( + def _apply_to_dict( self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: func = self._create_seq_func(*batch) - if self._has_batch_dim(mat_dict): + if self._has_batch_dim_dict(mat_dict): func = torch.func.vmap( func, in_dims=tuple((0 for _ in self.params_to_restrict_to)) ) @@ -274,7 +288,7 @@ def __init__( model, loss, self.params_to_restrict_to ) - def _apply_to_tensor_dict( + def _apply_to_dict( self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: vec_values = list(self._add_batch_dim(vec_dict).values()) @@ -287,7 +301,7 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: flat_grads = self.gradient_provider.flat_grads(batch) return self._rank_one_mvp(flat_grads, vec) - def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ Applies the batch operation to a matrix. Args: @@ -301,7 +315,7 @@ def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: $(N, \text{input_size})$ """ - return self.apply_to_vec(batch, mat) + return self._apply_to_vec(batch, mat) def to(self, device: torch.device): self.gradient_provider = self.gradient_provider.to(device) @@ -431,7 +445,7 @@ def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: input_vec = vec return self._inverse_rank_one_update(grads, input_vec, self.regularization) - def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: """ Applies the batch operation to a matrix. Args: @@ -445,14 +459,14 @@ def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: $(N, \text{input_size})$ """ - return self.apply_to_vec(batch, mat) + return self._apply_to_vec(batch, mat) def to(self, device: torch.device): super().to(device) self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to return self - def _apply_to_tensor_dict( + def _apply_to_dict( self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: vec_values = list(self._add_batch_dim(vec_dict).values()) @@ -488,7 +502,8 @@ def _inverse_rank_one_update( invertible, must be positive. 
Returns: - Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape $(M, D)$. + Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape + $(M, D)$. """ nominator = torch.einsum("ij,kj->ki", x, v) denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 745210f05..638bafe76 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -4,6 +4,7 @@ from torch import nn as nn from torch.utils.data import DataLoader +from ..types import TensorType from .base import ( GradientProviderFactoryType, TensorDictOperator, @@ -53,12 +54,10 @@ def __init__( def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: return self.batch_operation.input_dict_structure - def _apply_to_mat_dict( - self, mat: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: + def _apply_to_dict(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: tensor_dicts = ( - self.batch_operation.apply_to_tensor_dict(TorchBatch(x, y), mat) + self.batch_operation.apply_to_dict(TorchBatch(x, y), mat) for x, y in self.dataloader ) dict_averaging = self.averaging.as_dict_averaging() @@ -81,33 +80,13 @@ def to(self, device: torch.device): def input_size(self): return self.batch_operation.input_size - def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: - return self._apply(vec, self.batch_operation.apply_to_vec) - - def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: - """ - Applies the operator to a matrix. - Args: - mat: A matrix to apply the operator to. The last dimension is - assumed to be consistent to the operation, i.e. it must equal - to the property `input_size`. - - Returns: - A matrix of shape $(N, \text{input_size})$, given the shape of mat is - $(N, \text{input_size})$ - - """ - return self._apply(mat, self.batch_operation.apply_to_mat) - - def _apply( - self, - z: torch.Tensor, - batch_ops: Callable[[TorchBatch, torch.Tensor], torch.Tensor], - ): + def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + return self._apply_to_vec(mat) + def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: tensors = ( - batch_ops( - TorchBatch(x.to(self.device), y.to(self.device)), z.to(self.device) + self.batch_operation.apply( + TorchBatch(x.to(self.device), y.to(self.device)), vec.to(self.device) ) for x, y in self.dataloader ) diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index 5cb405fda..fee7c9077 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -351,27 +351,40 @@ def input_size(self) -> int: """ @abstractmethod - def apply_to_vec(self, vec: TensorType) -> TensorType: + def _validate_tensor_input(self, tensor: TensorType) -> None: """ - Applies the operator to a vector. + Validates the input tensor for the operator. Args: - vec: A tensor representing the vector to which the operator is applied, - must conform to the operator's input size. + tensor: A tensor to validate. + + Raises: + ValueError: If the tensor is invalid for the operator. + """ + + def apply(self, tensor: TensorType) -> TensorType: + """ + Applies the operator to a tensor. + + Args: + tensor: A tensor, whose tailing dimension must conform to the + operator's input size Returns: A tensor representing the result of the operator application. 
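A minimal sketch checking the Sherman–Morrison identity that `_inverse_rank_one_update` above relies on, namely

$$ (x x^T + \lambda I)^{-1} v = \frac{1}{\lambda}\left(v -
\frac{(x^T v)\, x}{\lambda + \|x\|_2^2}\right), $$

against a direct solve:

```python
import torch

d, reg = 5, 0.5
x, v = torch.randn(d), torch.randn(d)

# direct inversion of the regularized rank-one matrix
direct = torch.linalg.solve(torch.outer(x, x) + reg * torch.eye(d), v)

# closed form via the Sherman–Morrison identity
sherman_morrison = (v - (x @ v) / (reg + x @ x) * x) / reg

assert torch.allclose(direct, sherman_morrison, atol=1e-5)
```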
""" + self._validate_tensor_input(tensor) + return self._apply(tensor) @abstractmethod - def apply_to_mat(self, mat: TensorType) -> TensorType: + def _apply(self, tensor: TensorType) -> TensorType: """ - Applies the operator to a matrix. + Applies the operator to a tensor. Implement this to handle + batched input. Args: - mat: A tensor representing the matrix to which the operator is applied, - where the first dimension is the batch dimension and last dimension - of the matrix must conform to the operator's input size + tensor: A tensor, whose tailing dimension must conform to the + operator's input size Returns: A tensor representing the result of the operator application. @@ -462,7 +475,7 @@ def transformed_grads(self, batch: BatchType): """ grads = self.gp.flat_grads(batch) - return self.op.apply_to_mat(grads) + return self.op.apply(grads) def interactions_from_transformed_grads( self, left_factors: TensorType, right_batch: BatchType, mode: InfluenceMode diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index f4714ba87..7b5983b5a 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -31,12 +31,10 @@ def test_hessian_batch_operation(model_data, tol: float, pytorch_seed): batch_size = 10 rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) - hvp_autograd_mat_dict = hessian_op.apply_to_tensor_dict( - TorchBatch(x, y), rand_mat_dict - ) + hvp_autograd_mat_dict = hessian_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict) - hvp_autograd = hessian_op.apply_to_vec(TorchBatch(x, y), vec) - hvp_autograd_dict = hessian_op.apply_to_tensor_dict( + hvp_autograd = hessian_op.apply(TorchBatch(x, y), vec) + hvp_autograd_dict = hessian_op.apply_to_dict( TorchBatch(x, y), align_structure(params, vec) ) hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values()) @@ -56,7 +54,7 @@ def test_hessian_batch_operation(model_data, tol: float, pytorch_seed): rtol=tol, ) assert torch.allclose( - hessian_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat), op_then_flat + hessian_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat), op_then_flat ) @@ -93,8 +91,8 @@ def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): ) batch_size = 10 - gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec) - gn_autograd_dict = gn_op.apply_to_tensor_dict( + gn_autograd = gn_op.apply(TorchBatch(x, y), vec) + gn_autograd_dict = gn_op.apply_to_dict( TorchBatch(x, y), align_structure(params, vec) ) gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) @@ -104,12 +102,12 @@ def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) - gn_autograd_mat_dict = gn_op.apply_to_tensor_dict(TorchBatch(x, y), rand_mat_dict) + gn_autograd_mat_dict = gn_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict) op_then_flat = flatten_dimensions( gn_autograd_mat_dict.values(), shape=(batch_size, -1) ) - flat_then_op = gn_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat) + flat_then_op = gn_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat) assert torch.allclose( op_then_flat, @@ -167,8 +165,8 @@ def test_inverse_harmonic_mean_batch_operation( ) batch_size = 10 - gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), 
vec) - gn_autograd_dict = gn_op.apply_to_tensor_dict( + gn_autograd = gn_op.apply(TorchBatch(x, y), vec) + gn_autograd_dict = gn_op.apply_to_dict( TorchBatch(x, y), align_structure(params, vec) ) gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) @@ -179,12 +177,12 @@ def test_inverse_harmonic_mean_batch_operation( rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) - gn_autograd_mat_dict = gn_op.apply_to_tensor_dict(TorchBatch(x, y), rand_mat_dict) + gn_autograd_mat_dict = gn_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict) op_then_flat = flatten_dimensions( gn_autograd_mat_dict.values(), shape=(batch_size, -1) ) - flat_then_op = gn_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat) + flat_then_op = gn_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat) assert torch.allclose( op_then_flat, From 79f75696ad653d0a5801bd3d97eeeb34b679e943 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 14:14:18 +0200 Subject: [PATCH 28/43] Simplify tensor summation --- src/pydvl/influence/array.py | 19 ----------- .../base_influence_function_model.py | 33 +++++-------------- 2 files changed, 9 insertions(+), 43 deletions(-) diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index 7a8c5e881..7ad9a59f0 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -400,22 +400,3 @@ def _initialize_zarr_array( chunks=chunk_size, dtype=block.dtype, ) - - -class SumAggregator(SequenceAggregator): - def __call__(self, tensor_sequence: LazyChunkSequence): - """ - Aggregates tensors from a single-level generator by summing up. - - Args: - tensor_sequence: Object wrapping a generator that yields `TensorType` - objects. - - Returns: - A single tensor representing the sum of all tensors from the generator. 
- """ - tensor_generator = tensor_sequence.generator_factory() - result = next(tensor_generator) - for tensor in tensor_generator: - result = result + tensor - return result diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 51feb8ab4..72775eea7 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -3,11 +3,10 @@ import logging from abc import ABC, abstractmethod from collections import OrderedDict -from functools import partial, wraps -from typing import Generic, Optional, Type +from functools import wraps +from typing import Generic, Optional, Type, cast from ..utils.progress import log_duration -from .array import LazyChunkSequence, SumAggregator from .types import BatchType, BlockMapperType, DataLoaderType, InfluenceMode, TensorType @@ -434,12 +433,10 @@ def influences_from_factors_by_block( ) def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType: - tensor_gen_factory = partial( - self.block_mapper.generate_transformed_grads, self._create_batch(x, y) + transformed_grads = self.block_mapper.transformed_grads( + self._create_batch(x, y) ) - aggregator = SumAggregator() - result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) - return result + return cast(TensorType, sum(transformed_grads)) def _influences( self, @@ -461,15 +458,8 @@ def _influences( else: right_batch = self._create_batch(x, y) - tensor_gen_factory = partial( - self.block_mapper.generate_interactions, - left_batch, - right_batch, - mode, - ) - aggregator = SumAggregator() - result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) - return result + tensors = self.block_mapper.generate_interactions(left_batch, right_batch, mode) + return cast(TensorType, sum(tensors)) @InfluenceFunctionModel.fit_required def influences_from_factors( @@ -509,17 +499,12 @@ def influences_from_factors( Tensor representing the element-wise scalar products for the provided batch """ - - tensor_gen_factory = partial( - self.block_mapper.generate_interactions_from_transformed_grads, + tensors = self.block_mapper.generate_interactions_from_transformed_grads( z_test_factors, self._create_batch(x, y), mode, ) - - aggregator = SumAggregator() - result: TensorType = aggregator(LazyChunkSequence(tensor_gen_factory)) - return result + return cast(TensorType, sum(tensors)) @staticmethod @abstractmethod From 5bab8e41c84ed4c1103dc26b841a4a8441ff5b43 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 14:56:51 +0200 Subject: [PATCH 29/43] Simplify Hessian operations --- src/pydvl/influence/torch/batch_operation.py | 9 ++------- src/pydvl/influence/torch/functional.py | 2 ++ src/pydvl/influence/torch/operator.py | 6 +----- tests/influence/torch/test_batch_operation.py | 2 +- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 6e50a27e7..5f35f9173 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -198,7 +198,6 @@ class HessianBatchOperation(_ModelBasedBatchOperation): i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian is used. Make sure the input matches the corrct dimension, i.e. the last dimension must be equal to the property `input_size`. - reverse_only: If True only the reverse mode is used in the autograd computation. 
""" def __init__( @@ -206,14 +205,10 @@ def __init__( model: torch.nn.Module, loss: LossType, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, - reverse_only: bool = True, ): super().__init__(model, restrict_to=restrict_to) - self._batch_hvp = create_batch_hvp_function( - model, loss, reverse_only=reverse_only - ) + self._batch_hvp = create_batch_hvp_function(model, loss, reverse_only=True) self.loss = loss - self.reverse_only = reverse_only def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) @@ -238,7 +233,7 @@ def seq_func(*vec: torch.Tensor) -> Dict[str, torch.Tensor]: lambda p: create_batch_loss_function(self.model, self.loss)(p, x, y), self.params_to_restrict_to, dict(zip(self.params_to_restrict_to.keys(), vec)), - reverse_only=self.reverse_only, + reverse_only=True, ) return seq_func diff --git a/src/pydvl/influence/torch/functional.py b/src/pydvl/influence/torch/functional.py index 6db6f12fa..ba5acdd3e 100644 --- a/src/pydvl/influence/torch/functional.py +++ b/src/pydvl/influence/torch/functional.py @@ -50,6 +50,8 @@ "LowRankProductRepresentation", "randomized_nystroem_approximation", "model_hessian_nystroem_approximation", + "create_batch_loss_function", + "hvp", ] diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 638bafe76..6911ba455 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -165,7 +165,6 @@ class HessianOperator(_AveragingBatchOperator[HessianBatchOperation, ChunkAverag i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian is used. Make sure the input matches the corrct dimension, i.e. the last dimension must be equal to the property `input_size`. - reverse_only: If True only the reverse mode is used in the autograd computation. 
""" def __init__( @@ -174,11 +173,8 @@ def __init__( loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], dataloader: DataLoader, restrict_to: Optional[Dict[str, nn.Parameter]] = None, - reverse_only: bool = True, ): - batch_op = HessianBatchOperation( - model, loss, restrict_to=restrict_to, reverse_only=reverse_only - ) + batch_op = HessianBatchOperation(model, loss, restrict_to=restrict_to) averaging = ChunkAveraging() super().__init__(batch_op, dataloader, averaging) diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index 7b5983b5a..c07f65ab7 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -23,7 +23,7 @@ def test_hessian_batch_operation(model_data, tol: float, pytorch_seed): torch_model, x, y, vec, h_analytical = model_data - params = dict(torch_model.named_parameters()) + params = {k: p.detach() for k, p in torch_model.named_parameters()} hessian_op = HessianBatchOperation( torch_model, torch.nn.functional.mse_loss, restrict_to=params From ae133065c7c071eea18f157e992015c0587ef31e Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 17:06:56 +0200 Subject: [PATCH 30/43] Fix typo, import BlockMode in influence.torch --- src/pydvl/influence/torch/__init__.py | 1 + src/pydvl/influence/torch/influence_function_model.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pydvl/influence/torch/__init__.py b/src/pydvl/influence/torch/__init__.py index a1e0bb09a..9b2299d0b 100644 --- a/src/pydvl/influence/torch/__init__.py +++ b/src/pydvl/influence/torch/__init__.py @@ -8,3 +8,4 @@ NystroemSketchInfluence, ) from .pre_conditioner import JacobiPreConditioner, NystroemPreConditioner +from .util import BlockMode diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 42be0149f..1fc249f97 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -1808,7 +1808,7 @@ class InverseHarmonicMeanInfluence( ): r""" This implementation replaces the inverse Hessian matrix in the influence computation - an approximation of the inverse Gauss-Newton vector product. + with an approximation of the inverse Gauss-Newton vector product. Viewing the damped Gauss-newton matrix From 6cf70e5350279121178b124b2b547203eb6a383b Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 17:07:49 +0200 Subject: [PATCH 31/43] Add documentation for InverseHarmonicMeanInfluence --- docs/influence/influence_function_model.md | 125 +++++++++++++++++++++ docs/influence/scaling_computation.md | 3 +- 2 files changed, 126 insertions(+), 2 deletions(-) diff --git a/docs/influence/influence_function_model.md b/docs/influence/influence_function_model.md index 0a424e918..131cce052 100644 --- a/docs/influence/influence_function_model.md +++ b/docs/influence/influence_function_model.md @@ -207,7 +207,132 @@ if_model = NystroemSketchInfluence( if_model.fit(train_loader) ``` +### Inverse Harmonic Mean + +This implementation replaces the inverse Hessian matrix in the influence computation +with an approximation of the inverse Gauss-Newton vector product and was +proposed in [@kwon_datainf_2023]. + +The approximation method comprises +the following steps: + +1. 
Replace the Hessian $H(\theta)$ with the Gauss-Newton matrix + $G(\theta)$: + + \begin{equation*} + G(\theta)=n^{-1} \sum_{i=1}^n \nabla_{\theta}\ell_i\nabla_{\theta}\ell_i^T + \end{equation*} + + which results in + + \begin{equation*} + \mathcal{I}(z_{t}, z) \approx \nabla_{\theta} \ell(z_{t}, \theta)^T + (G(\theta) + \lambda I_d)^{-1} + \nabla_{\theta} \ell(z, \theta) + \end{equation*} + +2. Simplify the problem by breaking it down into a block diagonal structure, + where each block $G_l(\theta)$ corresponds to the l-th block: + + \begin{equation*} + G_{l}(\theta) = n^{-1} \sum_{i=1}^n \nabla_{\theta_l} \ell_i + \nabla_{\theta_l} \ell_i^{T} + \lambda_l I_{d_l}, + \end{equation*} + + which leads to + + \begin{equation*} + \mathcal{I}(z_{t}, z) \approx \nabla_{\theta} \ell(z_{t}, \theta)^T + \operatorname{diag}(G_1(\theta)^{-1}, + \dots, G_L(\theta)^{-1}) + \nabla_{\theta} \ell(z, \theta) + \end{equation*} + +3. Substitute the arithmetic mean of the rank-$1$ updates in + $G_l(\theta)$, with the inverse harmonic mean $R_l(\theta)$ of the rank-1 + updates: + + \begin{align*} + G_l(\theta)^{-1} &= \left( n^{-1} \sum_{i=1}^n \nabla_{\theta_l} + \ell(z_i, \theta) \nabla_{\theta_l} + \ell(z_i, \theta)^{T} + + \lambda_l I_{d_l}\right)^{-1} \\\ + R_{l}(\theta)&= n^{-1} \sum_{i=1}^n \left( \nabla_{\theta_l} + \ell(z_i, \theta) \nabla_{\theta_l} \ell(z_i, \theta)^{T} + + \lambda_l I_{d_l} \right)^{-1} + \end{align*} + +4. Use the + + Sherman–Morrison formula + + to get an explicit representation of the inverses in the definition of + $R_l(\theta):$ + + \begin{align*} + R_l(\theta) &= n^{-1} \sum_{i=1}^n \left( \nabla_{\theta_l} \ell_i + \nabla_{\theta_l} \ell_i^{T} + + \lambda_l I_{d_l}\right)^{-1} \\\ + &= n^{-1} \sum_{i=1}^n \lambda_l^{-1} \left(I_{d_l} + - \frac{\nabla_{\theta_l} \ell_i \nabla_{\theta_l} + \ell_i^{T}}{\lambda_l + + \\|\nabla_{\theta_l} \ell_i\\|_2^2}\right) + , + \end{align*} + + which means application of $R_l(\theta)$ boils down to computing $n$ + rank-$1$ updates. + +```python +from pydvl.influence.torch import InverseHarmonicMeanInfluence, BlockMode + +if_model = InverseHarmonicMeanInfluence( + model, + loss, + regularization=1e-1, + block_structure=BlockMode.LAYER_WISE +) +if_model.fit(train_loader) +``` + +!!! Info + This implementation is capable of using a block-matrix approximation. The + blocking structure can be specified via the `block_structure` parameter. + The `block_structure` parameter can either be a + [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides + layer-wise or parameter-wise blocking) or a custom block structure defined + by an ordered dictionary with the keys being the block identifiers (arbitrary + strings) and the values being lists of parameter names contained in the block. + ```python + block_structure = OrderedDict( + ( + ("custom_block1", ["0.weight", "1.bias"]), + ("custom_block2", ["1.weight", "0.bias"]), + ) + ) + ``` + If you would like to apply a block-specific regularization, you can provide a + dictionary with the block names as keys and the regularization values as values. + In this case, the specification must be complete, i.e. every block must have + a positive regularization value. + ```python + regularization = { + "custom_block1": 0.1, + "custom_block2": 0.2, + } + ``` + Accordingly, if you choose a layer-wise or parameter-wise structure + (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for + `block_structure`) the keys must be the layer names or parameter names, + respectively. 
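+    For example, a custom block structure and a matching block-wise
+    regularization can be combined as follows (a sketch, assuming the
+    constructor accepts the regularization dictionary directly and reusing
+    the block names from above):
+    ```python
+    from collections import OrderedDict
+
+    from pydvl.influence.torch import InverseHarmonicMeanInfluence
+
+    block_structure = OrderedDict(
+        (
+            ("custom_block1", ["0.weight", "1.bias"]),
+            ("custom_block2", ["1.weight", "0.bias"]),
+        )
+    )
+    regularization = {"custom_block1": 0.1, "custom_block2": 0.2}
+
+    if_model = InverseHarmonicMeanInfluence(
+        model,
+        loss,
+        regularization=regularization,
+        block_structure=block_structure,
+    )
+    if_model.fit(train_loader)
+    ```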
+ You can retrieve the block-wise influence information from the methods + with suffix `_by_block`. By default, `block_structure` is set to + `BlockMode.FULL` and in this case these methods will return a dictionary + with the empty string being the only key. + These implementations represent the calculation logic on in memory tensors. To scale up to large collection of data, we map these influence function models over these collections. For a detailed discussion see the documentation page [Scaling Computation](scaling_computation.md). + + diff --git a/docs/influence/scaling_computation.md b/docs/influence/scaling_computation.md index b8ffbe98f..32a2088ee 100644 --- a/docs/influence/scaling_computation.md +++ b/docs/influence/scaling_computation.md @@ -24,8 +24,7 @@ into memory. ```python from pydvl.influence import SequentialInfluenceCalculator from pydvl.influence.torch.util import ( - NestedTorchCatAggregator, - TorchNumpyConverter, + TorchNumpyConverter, NestedTorchCatAggregator, ) from pydvl.influence.torch import CgInfluence From 7564662c86b572a389c0a9c017974ee526064067 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 29 May 2024 18:06:56 +0200 Subject: [PATCH 32/43] Update CHANGELOG.md --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 076d4d927..dfbe91943 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## Unreleased + +### Added + +- New method `InverseHarmonicMeanInfluence`, implementation for the paper + `DataInf: Efficiently Estimating Data Influence in LoRA-tuned LLMs and + Diffusion Models` + [PR #582](https://github.com/aai-institute/pyDVL/pull/582) +- Add new backend implementations for influence computation + to account for block-diagonal approximations + [PR #582](https://github.com/aai-institute/pyDVL/pull/582) + + ## 0.9.2 - 🏗 Bug fixes, logging improvement ### Added From 983b3793eb2b02985fac832fcd2af3bdacef8636 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Thu, 30 May 2024 09:59:16 +0200 Subject: [PATCH 33/43] Add test cases, increase tolerance --- tests/influence/torch/test_batch_operation.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index c07f65ab7..f0838aaf1 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -64,7 +64,7 @@ def test_hessian_batch_operation(model_data, tol: float, pytorch_seed): [(astuple(tp.model_params), 1e-3) for tp in test_parameters], indirect=["model_data"], ) -def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): +def test_gauss_newton_batch_operation(model_data, tol: float): torch_model, x, y, vec, _ = model_data y_pred = torch_model(x) @@ -97,8 +97,8 @@ def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): ) gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) analytical_vec = gn_mat_analytical @ vec - assert torch.allclose(gn_autograd, analytical_vec, atol=1e-5, rtol=tol) - assert torch.allclose(gn_autograd_dict_flat, analytical_vec, atol=1e-5, rtol=tol) + assert torch.allclose(gn_autograd, analytical_vec, atol=1e-4, rtol=tol) + assert torch.allclose(gn_autograd_dict_flat, analytical_vec, atol=1e-4, rtol=tol) rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) 
@@ -112,7 +112,7 @@ def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): assert torch.allclose( op_then_flat, flat_then_op, - atol=1e-5, + atol=1e-4, rtol=tol, ) @@ -123,7 +123,7 @@ def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): assert torch.allclose( op_then_flat, flat_then_op_analytical, - atol=1e-5, + atol=1e-4, rtol=1e-2, ) @@ -134,10 +134,8 @@ def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed): [(astuple(tp.model_params), 1e-3) for tp in test_parameters], indirect=["model_data"], ) -@pytest.mark.parametrize("reg", [0.4]) -def test_inverse_harmonic_mean_batch_operation( - model_data, tol: float, reg, pytorch_seed -): +@pytest.mark.parametrize("reg", [1.0, 10, 100]) +def test_inverse_harmonic_mean_batch_operation(model_data, tol: float, reg): torch_model, x, y, vec, _ = model_data y_pred = torch_model(x) out_features = y_pred.shape[1] @@ -172,8 +170,8 @@ def test_inverse_harmonic_mean_batch_operation( gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) analytical = ihm_mat_analytical @ vec - assert torch.allclose(gn_autograd, analytical, atol=1e-5, rtol=tol) - assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-5, rtol=tol) + assert torch.allclose(gn_autograd, analytical, atol=1e-4, rtol=tol) + assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-4, rtol=tol) rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()} flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1)) @@ -187,7 +185,7 @@ def test_inverse_harmonic_mean_batch_operation( assert torch.allclose( op_then_flat, flat_then_op, - atol=1e-5, + atol=1e-4, rtol=tol, ) @@ -198,7 +196,7 @@ def test_inverse_harmonic_mean_batch_operation( assert torch.allclose( op_then_flat, flat_then_op_analytical, - atol=1e-5, + atol=1e-4, rtol=1e-2, ) From ea89feb92b404ffd7342d6b94eed43384bc31ff4 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Sat, 1 Jun 2024 10:20:22 +0200 Subject: [PATCH 34/43] Fix issues --- .../influence/base_influence_function_model.py | 4 ++-- src/pydvl/influence/torch/batch_operation.py | 1 + src/pydvl/influence/torch/util.py | 14 +++++++------- src/pydvl/influence/types.py | 10 ++++++++-- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 72775eea7..ac8d0e15a 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -436,7 +436,7 @@ def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType: transformed_grads = self.block_mapper.transformed_grads( self._create_batch(x, y) ) - return cast(TensorType, sum(transformed_grads)) + return cast(TensorType, sum(transformed_grads.values())) def _influences( self, @@ -449,7 +449,7 @@ def _influences( left_batch = self._create_batch(x_test, y_test) if x is None: - right_batch = left_batch + right_batch = None elif y is None: raise ValueError( "Providing model input x, without providing labels y " diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 5f35f9173..9aa47453a 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -153,6 +153,7 @@ def apply(self, batch: TorchBatch, tensor: torch.Tensor): "The last dimension of the input tensor must be equal to the " "property `input_size`." 
) + if tensor.ndim == 2: return self._apply_to_mat(batch.to(self.device), tensor.to(self.device)) return self._apply_to_vec(batch.to(self.device), tensor.to(self.device)) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index d1018941e..34c8af56e 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -662,10 +662,10 @@ def build(self, block_structure: OrderedDict[str, List[str]]): keys are block identifiers and the inner dictionaries map parameter names to parameters. """ - parameter_dict = OrderedDict() + parameter_dict = {} for block_name, parameter_names in block_structure.items(): - inner_ordered_dict = OrderedDict() + inner_ordered_dict = {} for parameter_name in parameter_names: parameter = self.model.state_dict()[parameter_name] if parameter.requires_grad: @@ -684,7 +684,7 @@ def build(self, block_structure: OrderedDict[str, List[str]]): def build_from_block_mode( self, block_mode: BlockMode - ) -> OrderedDict[str, OrderedDict[str, torch.nn.Parameter]]: + ) -> Dict[str, Dict[str, torch.nn.Parameter]]: """ Builds an ordered dictionary of model parameters based on the specified block mode or custom blocking structure represented by an ordered dictionary, where @@ -699,10 +699,10 @@ def build_from_block_mode( keys are block identifiers and the inner dictionaries map parameter names to parameters. """ - parameter_dict = OrderedDict() + parameter_dict = {} if block_mode is BlockMode.FULL: - inner_ordered_dict = OrderedDict() + inner_ordered_dict = {} for k, v in self.model.named_parameters(): if v.requires_grad: inner_ordered_dict[k] = self._optional_detach(v) @@ -711,11 +711,11 @@ def build_from_block_mode( elif block_mode is BlockMode.PARAMETER_WISE: for k, v in self.model.named_parameters(): if v.requires_grad: - parameter_dict[k] = OrderedDict({k: self._optional_detach(v)}) + parameter_dict[k] = {k: self._optional_detach(v)} if block_mode is BlockMode.LAYER_WISE: for name, submodule in self.model.named_children(): - inner_ordered_dict = OrderedDict() + inner_ordered_dict = {} for param_name, param in submodule.named_parameters(): if param.requires_grad: inner_ordered_dict[ diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py index fee7c9077..45f93fe4a 100644 --- a/src/pydvl/influence/types.py +++ b/src/pydvl/influence/types.py @@ -426,7 +426,10 @@ def __init__(self, op: OperatorType, gp: GradientProviderType): self.op = op def interactions( - self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode + self, + left_batch: BatchType, + right_batch: Optional[BatchType], + mode: InfluenceMode, ): r""" Computes the interaction between the gradients on two batches of data based on @@ -633,7 +636,10 @@ def generate_transformed_grads( yield comp_block.transformed_grads(batch) def generate_interactions( - self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode + self, + left_batch: BatchType, + right_batch: Optional[BatchType], + mode: InfluenceMode, ) -> Generator[TensorType, None, None]: """ Generator that yields gradient interactions between two batches, processed by From d21da08ae4e09318d8f5951228bc2b441ef9fca7 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Sat, 1 Jun 2024 10:38:44 +0200 Subject: [PATCH 35/43] Add property to base class --- src/pydvl/influence/torch/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index f186b9e60..5d7b7bde4 100644 --- 
a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -501,7 +501,7 @@ def _validate_tensor_input(self, tensor: torch.Tensor) -> None: def _apply(self, tensor: torch.Tensor) -> torch.Tensor: - if tensor.ndim == 2: + if tensor.ndim == 2 and tensor.shape[0] > 1: return self._apply_to_mat(tensor.to(self.device)) return self._apply_to_vec(tensor.to(self.device)) @@ -695,6 +695,10 @@ def __init__( def block_names(self) -> List[str]: return list(self.parameter_dict.keys()) + @property + def n_parameters(self): + return sum(block.op.input_size for _, block in self.block_mapper.items()) + @abstractmethod def with_regularization( self, regularization: Union[float, Dict[str, Optional[float]]] From 9c5ddeabfd5b0c8b1cdb9c670e98291d03008c8e Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 3 Jun 2024 09:36:09 +0200 Subject: [PATCH 36/43] Simplify gradient computation for torch, remove ABC for now, do not expose provider to user --- src/pydvl/influence/torch/base.py | 131 +++++------------- src/pydvl/influence/torch/batch_operation.py | 37 +---- .../torch/influence_function_model.py | 8 +- src/pydvl/influence/torch/operator.py | 21 +-- .../influence/torch/test_gradient_provider.py | 10 +- 5 files changed, 52 insertions(+), 155 deletions(-) diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index 5d7b7bde4..99fb359b5 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -64,14 +64,11 @@ def to(self, device: torch.device): return TorchBatch(self.x.to(device), self.y.to(device)) -class TorchGradientProvider(GradientProvider[TorchBatch, torch.Tensor], ABC): +class TorchGradientProvider(GradientProvider[TorchBatch, torch.Tensor]): r""" - Abstract base class for calculating per-sample gradients of a function defined by - a [torch.nn.Module][torch.nn.Module] and a loss function. - - This class must be subclassed with implementations for its abstract methods tailored - to specific gradient computation needs, e.g. using [torch.autograd][torch.autograd] - or stochastic finite differences. + Compute per-sample gradients of a function defined by + a [torch.nn.Module][torch.nn.Module] and a loss function using + [torch.func][torch.func]. 
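A self-contained sketch of the per-sample gradient mechanism referred to here, i.e. `torch.func.vmap` applied to `torch.func.grad` of a functional per-sample loss (a toy linear model is used for illustration):

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(3, 1)
loss = torch.nn.functional.mse_loss
params = {k: p.detach() for k, p in model.named_parameters()}


def single_sample_loss(p, x, y):
    # x, y are single samples; add a batch dimension for the forward pass
    return loss(functional_call(model, p, (x.unsqueeze(0),)), y.unsqueeze(0))


x, y = torch.randn(8, 3), torch.randn(8, 1)
per_sample_grads = vmap(grad(single_sample_loss), in_dims=(None, 0, 0))(params, x, y)

# one gradient per sample and per parameter tensor
assert per_sample_grads["weight"].shape == (8, 1, 3)
assert per_sample_grads["bias"].shape == (8, 1)
```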
Consider a function @@ -94,6 +91,12 @@ def __init__( loss: LossType, restrict_to: Optional[Dict[str, torch.nn.Parameter]], ): + self._per_sample_gradient_function = create_per_sample_gradient_function( + model, loss + ) + self._per_sample_mixed_gradient_func = ( + create_per_sample_mixed_derivative_function(model, loss) + ) self.loss = loss self.model = model @@ -104,6 +107,34 @@ def __init__( self.params_to_restrict_to = restrict_to + def _compute_loss( + self, params: Dict[str, torch.Tensor], x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) + return self.loss(outputs, y.unsqueeze(0)) + + def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + return self._per_sample_gradient_function( + self.params_to_restrict_to, batch.x, batch.y + ) + + def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + return self._per_sample_mixed_gradient_func( + self.params_to_restrict_to, batch.x, batch.y + ) + + def _jacobian_prod( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + matrix_jacobian_product_func = create_matrix_jacobian_product_function( + self.model, self.loss, g + ) + return matrix_jacobian_product_func( + self.params_to_restrict_to, batch.x, batch.y + ) + def to(self, device: torch.device): self.model = self.model.to(device) self.params_to_restrict_to = { @@ -121,22 +152,6 @@ def device(self): def dtype(self): return next(self.model.parameters()).dtype - @abstractmethod - def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - pass - - @abstractmethod - def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - pass - - @abstractmethod - def _jacobian_prod( - self, - batch: TorchBatch, - g: torch.Tensor, - ) -> torch.Tensor: - pass - @staticmethod def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} @@ -229,76 +244,6 @@ def flat_mixed_grads(self, batch: TorchBatch) -> torch.Tensor: return flatten_dimensions(self.mixed_grads(batch).values(), shape=shape) -class TorchAutoGrad(TorchGradientProvider): - r""" - Compute per-sample gradients of a function defined by - a [torch.nn.Module][torch.nn.Module] and a loss function using - [torch.func][torch.func]. - - Consider a function - - $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times - \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = - \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ - - e.g. 
a two layer neural network $f$ with a loss function, then this object should - compute the expressions: - - $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), - \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), - \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ - - """ - - def __init__( - self, - model: torch.nn.Module, - loss: LossType, - restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, - ): - super().__init__(model, loss, restrict_to) - self._per_sample_gradient_function = create_per_sample_gradient_function( - model, loss - ) - self._per_sample_mixed_gradient_func = ( - create_per_sample_mixed_derivative_function(model, loss) - ) - - def _compute_loss( - self, params: Dict[str, torch.Tensor], x: torch.Tensor, y: torch.Tensor - ) -> torch.Tensor: - outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) - return self.loss(outputs, y.unsqueeze(0)) - - def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - return self._per_sample_gradient_function( - self.params_to_restrict_to, batch.x, batch.y - ) - - def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: - return self._per_sample_mixed_gradient_func( - self.params_to_restrict_to, batch.x, batch.y - ) - - def _jacobian_prod( - self, - batch: TorchBatch, - g: torch.Tensor, - ) -> torch.Tensor: - matrix_jacobian_product_func = create_matrix_jacobian_product_function( - self.model, self.loss, g - ) - return matrix_jacobian_product_func( - self.params_to_restrict_to, batch.x, batch.y - ) - - -GradientProviderFactoryType = Callable[ - [torch.nn.Module, LossType, Optional[Dict[str, torch.nn.Parameter]]], - TorchGradientProvider, -] - - class OperatorBilinearForm( BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider] ): diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 9aa47453a..908da3ef3 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -16,27 +16,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import ( - Callable, - Dict, - Generator, - Generic, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Callable, Dict, Generator, Generic, List, Optional, Tuple, TypeVar import torch -from .base import ( - GradientProviderFactoryType, - TorchAutoGrad, - TorchBatch, - TorchGradientProvider, -) +from .base import TorchBatch, TorchGradientProvider from .functional import create_batch_hvp_function, create_batch_loss_function, hvp from .util import LossType @@ -258,11 +242,6 @@ class GaussNewtonBatchOperation(_ModelBasedBatchOperation): Args: model: The model. loss: The loss function. - gradient_provider_factory: An optional factory to create an object of type - [TorchGradientProvider][pydvl.influence.torch.base.TorchGradientProvider], - depending on the model, loss and optional parameters to restrict to. - If not provided, the implementation - [TorchAutograd][pydvl.influence.torch.base.TorchAutograd] is used. restrict_to: The parameters to restrict the differentiation to, i.e. the corresponding sub-matrix of the Jacobian. If None, the full Jacobian is used. Make sure the input matches the corrct dimension, i.e. 
the @@ -273,14 +252,10 @@ def __init__( self, model: torch.nn.Module, loss: LossType, - gradient_provider_factory: Union[ - GradientProviderFactoryType, - Type[TorchGradientProvider], - ] = TorchAutoGrad, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): super().__init__(model, restrict_to=restrict_to) - self.gradient_provider = gradient_provider_factory( + self.gradient_provider = TorchGradientProvider( model, loss, self.params_to_restrict_to ) @@ -408,10 +383,6 @@ def __init__( model: torch.nn.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], regularization: float, - gradient_provider_factory: Union[ - GradientProviderFactoryType, - Type[TorchGradientProvider], - ] = TorchAutoGrad, restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, ): if regularization <= 0: @@ -419,7 +390,7 @@ def __init__( self.regularization = regularization super().__init__(model, restrict_to=restrict_to) - self.gradient_provider = gradient_provider_factory( + self.gradient_provider = TorchGradientProvider( model, loss, self.params_to_restrict_to ) diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 1fc249f97..35ab09501 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -24,8 +24,8 @@ UnsupportedInfluenceModeException, ) from .base import ( - TorchAutoGrad, TorchComposableInfluence, + TorchGradientProvider, TorchOperatorGradientComposition, ) from .functional import ( @@ -1904,7 +1904,6 @@ def __init__( block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, ): super().__init__(model, block_structure, regularization=regularization) - self.gradient_provider_factory = TorchAutoGrad self.loss = loss @property @@ -1938,12 +1937,9 @@ def _create_block( self.loss, data, regularization, - self.gradient_provider_factory, restrict_to=block_params, ) - gp = self.gradient_provider_factory( - self.model, self.loss, restrict_to=block_params - ) + gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params) return TorchOperatorGradientComposition(op, gp) def with_regularization( diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py index 6911ba455..2396a2efb 100644 --- a/src/pydvl/influence/torch/operator.py +++ b/src/pydvl/influence/torch/operator.py @@ -1,17 +1,10 @@ -from typing import Callable, Dict, Generic, Optional, Tuple, Type, Union +from typing import Callable, Dict, Generic, Optional, Tuple import torch from torch import nn as nn from torch.utils.data import DataLoader -from ..types import TensorType -from .base import ( - GradientProviderFactoryType, - TensorDictOperator, - TorchAutoGrad, - TorchBatch, - TorchGradientProvider, -) +from .base import TensorDictOperator, TorchBatch from .batch_operation import ( BatchOperationType, ChunkAveraging, @@ -127,16 +120,11 @@ def __init__( model: nn.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], dataloader: DataLoader, - gradient_provider_factory: Union[ - GradientProviderFactoryType, - Type[TorchGradientProvider], - ] = TorchAutoGrad, restrict_to: Optional[Dict[str, nn.Parameter]] = None, ): batch_op = GaussNewtonBatchOperation( model, loss, - gradient_provider_factory=gradient_provider_factory, restrict_to=restrict_to, ) averaging = PointAveraging() @@ -240,10 +228,6 @@ def __init__( loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], dataloader: DataLoader, regularization: 
-        gradient_provider_factory: Union[
-            GradientProviderFactoryType,
-            Type[TorchGradientProvider],
-        ] = TorchAutoGrad,
         restrict_to: Optional[Dict[str, nn.Parameter]] = None,
     ):
         if regularization <= 0:
@@ -255,7 +239,6 @@ def __init__(
             model,
             loss,
             regularization,
-            gradient_provider_factory=gradient_provider_factory,
             restrict_to=restrict_to,
         )
         averaging = PointAveraging()
diff --git a/tests/influence/torch/test_gradient_provider.py b/tests/influence/torch/test_gradient_provider.py
index 8fab7fadf..ebee2923d 100644
--- a/tests/influence/torch/test_gradient_provider.py
+++ b/tests/influence/torch/test_gradient_provider.py
@@ -2,7 +2,7 @@
 import pytest
 import torch
 
-from pydvl.influence.torch.base import TorchAutoGrad, TorchBatch
+from pydvl.influence.torch.base import TorchBatch, TorchGradientProvider
 
 from ..conftest import linear_mixed_second_derivative_analytical, linear_model
 from .conftest import DATA_OUTPUT_NOISE, linear_mvp_model
@@ -22,7 +22,7 @@ def test_per_sample_gradient(self, in_features, out_features, batch_size):
         y = torch.randn(batch_size, out_features)
         params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad}
 
-        gp = TorchAutoGrad(model, loss, restrict_to=params)
+        gp = TorchGradientProvider(model, loss, restrict_to=params)
         gradients = gp.grads(TorchBatch(x, y))
         flat_gradients = gp.flat_grads(TorchBatch(x, y))
@@ -69,7 +69,7 @@ def test_mixed_derivatives(self, in_features, out_features, train_set_size):
         torch_train_x = torch.as_tensor(train_x)
         torch_train_y = torch.as_tensor(train_y)
 
-        gp = TorchAutoGrad(model, loss, restrict_to=params)
+        gp = TorchGradientProvider(model, loss, restrict_to=params)
         flat_functorch_mixed_derivatives = gp.flat_mixed_grads(
             TorchBatch(torch_train_x, torch_train_y)
         )
@@ -93,7 +93,9 @@ def test_matrix_jacobian_product(
         y = torch.randn(batch_size, out_features, requires_grad=True)
         y_pred = model(x)
 
-        gp = TorchAutoGrad(model, torch.nn.functional.mse_loss, restrict_to=params)
+        gp = TorchGradientProvider(
+            model, torch.nn.functional.mse_loss, restrict_to=params
+        )
         G = torch.randn((10, out_features * (in_features + 1)))
         mjp = gp.jacobian_prod(TorchBatch(x, y), G)

From 8498524c92f0ec98a572f44bc9bc9d6ef6c2ac5b Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Mon, 3 Jun 2024 10:29:39 +0200
Subject: [PATCH 37/43] Simplify gradient computation

---
 src/pydvl/influence/torch/base.py | 49 +++++++++++++++----------------
 1 file changed, 24 insertions(+), 25 deletions(-)
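The rewrite in the diff below drops the `create_per_sample_*` helper functions and calls `torch.func` directly. As a rough standalone illustration of the per-sample gradient pattern used there (the model, loss and shapes here are made up for the example and are not pyDVL code):

    import torch
    from torch.func import functional_call, grad, vmap

    model = torch.nn.Linear(3, 2)
    params = {k: p.detach() for k, p in model.named_parameters()}

    def compute_loss(p, x, y):
        # single-sample loss; unsqueeze adds a batch dimension of size 1
        out = functional_call(model, p, (x.unsqueeze(0),))
        return torch.nn.functional.mse_loss(out, y.unsqueeze(0))

    x, y = torch.randn(10, 3), torch.randn(10, 2)
    # vmap over the data, not over the (shared) parameter dict
    per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, x, y)
    # per_sample_grads["weight"]: shape (10, 2, 3); per_sample_grads["bias"]: shape (10, 2)

Mixed second derivatives follow the same pattern with `jacrev(grad(..., argnums=1))`, as in the hunk below.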
diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py
index 99fb359b5..01efe605a 100644
--- a/src/pydvl/influence/torch/base.py
+++ b/src/pydvl/influence/torch/base.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast
+from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast
 
 import torch
 from torch.func import functional_call
@@ -18,16 +18,12 @@
     Operator,
     OperatorGradientComposition,
 )
-from .functional import (
-    create_matrix_jacobian_product_function,
-    create_per_sample_gradient_function,
-    create_per_sample_mixed_derivative_function,
-)
 from .util import (
     BlockMode,
     LossType,
     ModelInfoMixin,
     ModelParameterDictBuilder,
+    align_structure,
     flatten_dimensions,
 )
@@ -91,14 +87,8 @@ def __init__(
         loss: LossType,
         restrict_to: Optional[Dict[str, torch.nn.Parameter]],
     ):
-        self._per_sample_gradient_function = create_per_sample_gradient_function(
-            model, loss
-        )
-        self._per_sample_mixed_gradient_func = (
-            create_per_sample_mixed_derivative_function(model, loss)
-        )
-        self.loss = loss
         self.model = model
+        self.loss = loss
 
         if restrict_to is None:
             restrict_to = ModelParameterDictBuilder(model).build_from_block_mode(
@@ -114,26 +104,35 @@ def _compute_loss(
         return self.loss(outputs, y.unsqueeze(0))
 
     def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]:
-        return self._per_sample_gradient_function(
-            self.params_to_restrict_to, batch.x, batch.y
-        )
+        result: Dict[str, torch.Tensor] = torch.vmap(
+            torch.func.grad(self._compute_loss), in_dims=(None, 0, 0)
+        )(self.params_to_restrict_to, batch.x, batch.y)
+        return result
 
     def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]:
-        return self._per_sample_mixed_gradient_func(
-            self.params_to_restrict_to, batch.x, batch.y
-        )
+        result: Dict[str, torch.Tensor] = torch.vmap(
+            torch.func.jacrev(torch.func.grad(self._compute_loss, argnums=1)),
+            in_dims=(None, 0, 0),
+        )(self.params_to_restrict_to, batch.x, batch.y)
+        return result
 
     def _jacobian_prod(
         self,
         batch: TorchBatch,
         g: torch.Tensor,
     ) -> torch.Tensor:
-        matrix_jacobian_product_func = create_matrix_jacobian_product_function(
-            self.model, self.loss, g
-        )
-        return matrix_jacobian_product_func(
-            self.params_to_restrict_to, batch.x, batch.y
-        )
+        def single_jvp(
+            _g: torch.Tensor,
+        ):
+            return torch.func.jvp(
+                lambda p: torch.vmap(self._compute_loss, in_dims=(None, 0, 0))(
+                    p, *batch
+                ),
+                (self.params_to_restrict_to,),
+                (align_structure(self.params_to_restrict_to, _g),),
+            )[1]
+
+        return torch.func.vmap(single_jvp)(g)
 
     def to(self, device: torch.device):
         self.model = self.model.to(device)

From 5b7bbb99c344f5224c2cefa7c9dee51486a26e0a Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Wed, 5 Jun 2024 12:38:40 +0200
Subject: [PATCH 38/43] Fix issue in building factors from factors dict

---
 .../influence/base_influence_function_model.py | 17 +++++++++++++----
 src/pydvl/influence/torch/base.py              |  9 ++++++++-
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py
index ac8d0e15a..058ef823b 100644
--- a/src/pydvl/influence/base_influence_function_model.py
+++ b/src/pydvl/influence/base_influence_function_model.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from functools import wraps
-from typing import Generic, Optional, Type, cast
+from typing import Generic, Iterable, Optional, Type, cast
 
 from ..utils.progress import log_duration
 from .types import BatchType, BlockMapperType, DataLoaderType, InfluenceMode, TensorType
@@ -433,10 +433,19 @@ def influences_from_factors_by_block(
         )
 
     def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType:
-        transformed_grads = self.block_mapper.transformed_grads(
-            self._create_batch(x, y)
+        transformed_grads = self.influence_factors_by_block(x, y)
+        transformed_grads = (
+            self._flatten_trailing_dim(t) for t in transformed_grads.values()
         )
-        return cast(TensorType, sum(transformed_grads.values()))
+        return cast(TensorType, self._concat(transformed_grads, dim=-1))
+
+    @abstractmethod
+    def _concat(self, tensors: Iterable[TensorType], dim: int):
+        """Implement this to concat tensors at a specified dimension"""
+
+    @abstractmethod
+    def _flatten_trailing_dim(self, tensor: TensorType):
+        """Implement this to flatten all but the first dimension"""
 
     def _influences(
         self,
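The torch side of this fix (next file) implements the two new hooks with `reshape` and `torch.cat`. A toy illustration of how block-wise factors are now combined, with invented block names and shapes:

    import torch

    # hypothetical per-block factors for a batch of N = 4 points
    factors = {"layer1": torch.randn(4, 2, 3), "layer2": torch.randn(4, 5)}
    flat = (t.reshape((t.shape[0], -1)) for t in factors.values())  # (4, 6) and (4, 5)
    combined = torch.cat(list(flat), dim=-1)                        # shape (4, 11)

Previously the per-block tensors were summed (see the removed lines above); flattening and concatenating keeps each block's components separate instead.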
diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py
index 01efe605a..65b6d4f8b 100644
--- a/src/pydvl/influence/torch/base.py
+++ b/src/pydvl/influence/torch/base.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast
+from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast
 
 import torch
 from torch.func import functional_call
@@ -17,6 +17,7 @@
     GradientProvider,
     Operator,
     OperatorGradientComposition,
+    TensorType,
 )
 from .util import (
     BlockMode,
@@ -635,6 +636,12 @@ def __init__(
 
         super().__init__(model)
 
+    def _concat(self, tensors: Iterable[torch.Tensor], dim: int):
+        return torch.cat(list(tensors), dim=dim)
+
+    def _flatten_trailing_dim(self, tensor: torch.Tensor):
+        return tensor.reshape((tensor.shape[0], -1))
+
     @property
     def block_names(self) -> List[str]:
         return list(self.parameter_dict.keys())

From a8caf139511b2b41c1bcd0430fa921e9998b7700 Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Wed, 5 Jun 2024 16:04:22 +0200
Subject: [PATCH 39/43] Fix issue in types module, due to OrderedDict from
 typing and collections

---
 src/pydvl/influence/types.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py
index 45f93fe4a..2d5f6be73 100644
--- a/src/pydvl/influence/types.py
+++ b/src/pydvl/influence/types.py
@@ -42,8 +42,8 @@
 
 from __future__ import annotations
 
+import collections
 from abc import ABC, abstractmethod
-from collections import OrderedDict
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
@@ -53,6 +53,7 @@
     Generic,
     Iterable,
     Optional,
+    OrderedDict,
     TypeVar,
     Union,
     cast,
@@ -547,7 +548,7 @@ def items(self):
     def _to_ordered_dict(
         self, tensor_generator: Generator[TensorType, None, None]
     ) -> OrderedDict[str, TensorType]:
-        tensor_dict = OrderedDict()
+        tensor_dict = collections.OrderedDict()
         for k, t in zip(self.composable_block_dict.keys(), tensor_generator):
             tensor_dict[k] = t
         return tensor_dict

From a69ae3d048a1e0c7710ec77450c6a410c54946dd Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Wed, 5 Jun 2024 16:45:57 +0200
Subject: [PATCH 40/43] Refactor ModelParameterDictBuilder to be more readable

---
 src/pydvl/influence/torch/util.py | 72 ++++++++++++++++++++-----------
 1 file changed, 46 insertions(+), 26 deletions(-)

diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py
index 34c8af56e..9700f047f 100644
--- a/src/pydvl/influence/torch/util.py
+++ b/src/pydvl/influence/torch/util.py
@@ -11,7 +11,6 @@
     Callable,
     Collection,
     Dict,
-    Generator,
     Iterable,
     Iterator,
     List,
@@ -647,7 +646,16 @@ def _optional_detach(self, p: torch.nn.Parameter):
             return p.detach()
         return p
 
-    def build(self, block_structure: OrderedDict[str, List[str]]):
+    def _extract_parameter_by_name(self, name: str) -> torch.nn.Parameter:
+        for k, p in self.model.named_parameters():
+            if k == name:
+                return p
+        else:
+            raise ValueError(f"Parameter {name} not found in the model.")
+
+    def build(
+        self, block_structure: OrderedDict[str, List[str]]
+    ) -> Dict[str, Dict[str, torch.nn.Parameter]]:
         """
         Builds an ordered dictionary of model parameters based on the specified block
         structure represented by an ordered dictionary, where the keys are block
@@ -667,7 +675,7 @@ def build(self, block_structure: OrderedDict[str, List[str]]):
         for block_name, parameter_names in block_structure.items():
             inner_ordered_dict = {}
             for parameter_name in parameter_names:
-                parameter = self.model.state_dict()[parameter_name]
+                parameter = self._extract_parameter_by_name(parameter_name)
                 if parameter.requires_grad:
                     inner_ordered_dict[parameter_name] = self._optional_detach(
                         parameter
@@ -699,31 +707,43 @@ def build_from_block_mode(
             keys are block identifiers and the inner dictionaries map parameter names
             to parameters.
         """
-        parameter_dict = {}
-        if block_mode is BlockMode.FULL:
-            inner_ordered_dict = {}
-            for k, v in self.model.named_parameters():
-                if v.requires_grad:
-                    inner_ordered_dict[k] = self._optional_detach(v)
-            parameter_dict[""] = inner_ordered_dict
-
-        elif block_mode is BlockMode.PARAMETER_WISE:
-            for k, v in self.model.named_parameters():
-                if v.requires_grad:
-                    parameter_dict[k] = {k: self._optional_detach(v)}
-
-        if block_mode is BlockMode.LAYER_WISE:
-            for name, submodule in self.model.named_children():
-                inner_ordered_dict = {}
-                for param_name, param in submodule.named_parameters():
-                    if param.requires_grad:
-                        inner_ordered_dict[
-                            f"{name}.{param_name}"
-                        ] = self._optional_detach(param)
-                if inner_ordered_dict:
-                    parameter_dict[name] = inner_ordered_dict
+        block_mode_mapping = {
+            BlockMode.FULL: self._build_full,
+            BlockMode.PARAMETER_WISE: self._build_parameter_wise,
+            BlockMode.LAYER_WISE: self._build_layer_wise,
+        }
+
+        parameter_dict_func = block_mode_mapping.get(block_mode, None)
+
+        if parameter_dict_func is None:
+            raise ValueError(f"Unknown block mode {block_mode}.")
+
+        return self.build(parameter_dict_func())
+
+    def _build_full(self):
+        parameter_dict = OrderedDict()
+        parameter_dict[""] = [
+            n for n, p in self.model.named_parameters() if p.requires_grad
+        ]
+        return parameter_dict
+
+    def _build_parameter_wise(self):
+        parameter_dict = OrderedDict()
+        for k, v in self.model.named_parameters():
+            if v.requires_grad:
+                parameter_dict[k] = [k]
+        return parameter_dict
+
+    def _build_layer_wise(self):
+        parameter_dict = OrderedDict()
+        for name, submodule in self.model.named_children():
+            layer_parameter_names = []
+            for param_name, param in submodule.named_parameters():
+                if param.requires_grad:
+                    layer_parameter_names.append(f"{name}.{param_name}")
+            if layer_parameter_names:
+                parameter_dict[name] = layer_parameter_names
 
         return parameter_dict

From 23442e49db90f20fe009eaa506c9bb8bc097bff8 Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Wed, 5 Jun 2024 18:02:17 +0200
Subject: [PATCH 41/43] Refactor tests in test_batch_operation to be more
 readable

---
 tests/influence/torch/test_batch_operation.py | 392 +++++++++++-------
 1 file changed, 237 insertions(+), 155 deletions(-)

diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py
index f0838aaf1..cf3ea58d6 100644
--- a/tests/influence/torch/test_batch_operation.py
+++ b/tests/influence/torch/test_batch_operation.py
@@ -15,190 +15,272 @@
 
 @pytest.mark.torch
-@pytest.mark.parametrize(
-    "model_data, tol",
-    [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
-    indirect=["model_data"],
-)
-def test_hessian_batch_operation(model_data, tol: float, pytorch_seed):
-    torch_model, x, y, vec, h_analytical = model_data
-
-    params = {k: p.detach() for k, p in torch_model.named_parameters()}
+class TestHessianBatchOperation:
+    @pytest.fixture(autouse=True)
+    def setup(self, model_data):
+        self.torch_model, self.x, self.y, self.vec, self.h_analytical = model_data
+        self.params = {k: p.detach() for k, p in self.torch_model.named_parameters()}
+        self.hessian_op = HessianBatchOperation(
+            self.torch_model, torch.nn.functional.mse_loss, restrict_to=self.params
+        )
 
     @pytest.mark.parametrize(
         "model_data, tol",
         [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
         indirect=["model_data"],
     )
-    batch_size = 10
-    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
-    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
-    hvp_autograd_mat_dict = hessian_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict)
-
-    hvp_autograd = hessian_op.apply(TorchBatch(x, y), vec)
-    hvp_autograd_dict = hessian_op.apply_to_dict(
-        TorchBatch(x, y), align_structure(params, vec)
-    )
-    hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values())
+    def test_analytical_comparison(self, model_data, tol, pytorch_seed):
+        hvp_autograd = self.hessian_op.apply(TorchBatch(self.x, self.y), self.vec)
+        hvp_autograd_dict = self.hessian_op.apply_to_dict(
+            TorchBatch(self.x, self.y), align_structure(self.params, self.vec)
+        )
+        hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values())
 
-    assert torch.allclose(hvp_autograd, h_analytical @ vec, rtol=tol)
-    assert torch.allclose(hvp_autograd_dict_flat, h_analytical @ vec, rtol=tol)
+        assert torch.allclose(hvp_autograd, self.h_analytical @ self.vec, rtol=tol)
+        assert torch.allclose(
+            hvp_autograd_dict_flat, self.h_analytical @ self.vec, rtol=tol
+        )
 
-    op_then_flat = flatten_dimensions(
-        hvp_autograd_mat_dict.values(), shape=(batch_size, -1)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    flat_then_op_analytical = torch.einsum("ik, jk -> ji", h_analytical, flat_rand_mat)
+    def test_flattening_commutation(self, model_data, tol, pytorch_seed):
+        batch_size = 10
+        rand_mat_dict = {
+            k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()
+        }
+        flat_rand_mat = flatten_dimensions(
+            rand_mat_dict.values(), shape=(batch_size, -1)
+        )
+        hvp_autograd_mat_dict = self.hessian_op.apply_to_dict(
+            TorchBatch(self.x, self.y), rand_mat_dict
+        )
+        op_then_flat = flatten_dimensions(
+            hvp_autograd_mat_dict.values(), shape=(batch_size, -1)
+        )
+        flat_then_op_analytical = torch.einsum(
+            "ik, jk -> ji", self.h_analytical, flat_rand_mat
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op_analytical,
-        atol=1e-5,
-        rtol=tol,
-    )
-    assert torch.allclose(
-        hessian_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat), op_then_flat
-    )
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op_analytical,
+            atol=1e-5,
+            rtol=tol,
+        )
+        assert torch.allclose(
+            self.hessian_op._apply_to_mat(TorchBatch(self.x, self.y), flat_rand_mat),
+            op_then_flat,
+        )
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize(
-    "model_data, tol",
-    [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
-    indirect=["model_data"],
-)
-def test_gauss_newton_batch_operation(model_data, tol: float):
-    torch_model, x, y, vec, _ = model_data
-
-    y_pred = torch_model(x)
-    out_features = y_pred.shape[1]
-    dl_dw = torch.vmap(
-        lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1)
-    )(x, y_pred, y)
-    dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y)
-    grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1)
-    gn_mat_analytical = (
-        torch.sum(
-            torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())(
-                grad_analytical
-            ),
-            dim=0,
-        )
-        / x.shape[0]
-    )
-
-    params = dict(torch_model.named_parameters())
-
-    gn_op = GaussNewtonBatchOperation(
-        torch_model, torch.nn.functional.mse_loss, restrict_to=params
-    )
-    batch_size = 10
+class TestGaussNewtonBatchOperation:
+    @pytest.fixture(autouse=True)
+    def setup(self, model_data):
+        self.torch_model, self.x, self.y, self.vec, _ = model_data
+        self.params = dict(self.torch_model.named_parameters())
+        self.gn_op = GaussNewtonBatchOperation(
+            self.torch_model, torch.nn.functional.mse_loss, restrict_to=self.params
+        )
+        self.out_features = self.torch_model(self.x).shape[1]
+        self.grad_analytical = self.compute_grad_analytical()
+        self.gn_mat_analytical = self.compute_gn_mat_analytical()
+
+    def compute_grad_analytical(self):
+        y_pred = self.torch_model(self.x)
+        dl_dw = torch.vmap(
+            lambda r, s, t: 2
+            / float(self.out_features)
+            * (s - t).view(-1, 1)
+            @ r.view(1, -1)
+        )(self.x, y_pred, self.y)
+        dl_db = torch.vmap(lambda s, t: 2 / float(self.out_features) * (s - t))(
+            y_pred, self.y
+        )
+        return torch.cat([dl_dw.reshape(self.x.shape[0], -1), dl_db], dim=-1)
+
+    def compute_gn_mat_analytical(self):
+        return (
+            torch.sum(
+                torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())(
+                    self.grad_analytical
+                ),
+                dim=0,
+            )
+            / self.x.shape[0]
+        )
 
-    gn_autograd = gn_op.apply(TorchBatch(x, y), vec)
-    gn_autograd_dict = gn_op.apply_to_dict(
-        TorchBatch(x, y), align_structure(params, vec)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
-    analytical_vec = gn_mat_analytical @ vec
-    assert torch.allclose(gn_autograd, analytical_vec, atol=1e-4, rtol=tol)
-    assert torch.allclose(gn_autograd_dict_flat, analytical_vec, atol=1e-4, rtol=tol)
+    def test_analytical_comparison(self, tol):
+        gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec)
+        gn_autograd_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), align_structure(self.params, self.vec)
+        )
+        gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
+        analytical_vec = self.gn_mat_analytical @ self.vec
 
-    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
-    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
-    gn_autograd_mat_dict = gn_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict)
+        assert torch.allclose(gn_autograd, analytical_vec, atol=1e-4, rtol=tol)
+        assert torch.allclose(
+            gn_autograd_dict_flat, analytical_vec, atol=1e-4, rtol=tol
+        )
 
-    op_then_flat = flatten_dimensions(
-        gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    flat_then_op = gn_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat)
+    def test_flattening_commutation(self, tol):
+        batch_size = 10
+        rand_mat_dict = {
+            k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()
+        }
+        flat_rand_mat = flatten_dimensions(
+            rand_mat_dict.values(), shape=(batch_size, -1)
+        )
+        gn_autograd_mat_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), rand_mat_dict
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op,
-        atol=1e-4,
-        rtol=tol,
-    )
+        op_then_flat = flatten_dimensions(
+            gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+        )
+        flat_then_op = self.gn_op._apply_to_mat(
+            TorchBatch(self.x, self.y), flat_rand_mat
+        )
 
-    flat_then_op_analytical = torch.einsum(
-        "ik, jk -> ji", gn_mat_analytical, flat_rand_mat
-    )
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op,
+            atol=1e-4,
+            rtol=tol,
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op_analytical,
-        atol=1e-4,
-        rtol=1e-2,
-    )
+        flat_then_op_analytical = torch.einsum(
+            "ik, jk -> ji", self.gn_mat_analytical, flat_rand_mat
+        )
+
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op_analytical,
+            atol=1e-4,
+            rtol=1e-2,
+        )
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize(
-    "model_data, tol",
-    [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
-    indirect=["model_data"],
-)
-@pytest.mark.parametrize("reg", [1.0, 10, 100])
-def test_inverse_harmonic_mean_batch_operation(model_data, tol: float, reg):
-    torch_model, x, y, vec, _ = model_data
-    y_pred = torch_model(x)
-    out_features = y_pred.shape[1]
-    dl_dw = torch.vmap(
-        lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1)
-    )(x, y_pred, y)
-    dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y)
-    grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1)
-    params = {
-        k: p.detach() for k, p in torch_model.named_parameters() if p.requires_grad
-    }
-
-    ihm_mat_analytical = torch.sum(
-        torch.func.vmap(
-            lambda z: torch.linalg.inv(
-                z.unsqueeze(-1) * z.unsqueeze(-1).t() + reg * torch.eye(len(z))
-            )
-        )(grad_analytical),
-        dim=0,
-    )
-    ihm_mat_analytical /= x.shape[0]
+class TestInverseHarmonicMeanBatchOperation:
+    @pytest.fixture(autouse=True)
+    def setup(self, model_data, reg):
+        self.torch_model, self.x, self.y, self.vec, _ = model_data
+        self.reg = reg
+        self.params = {
+            k: p.detach()
+            for k, p in self.torch_model.named_parameters()
+            if p.requires_grad
+        }
+        self.grad_analytical = self.compute_grad_analytical()
+        self.ihm_mat_analytical = self.compute_ihm_mat_analytical()
+        self.gn_op = InverseHarmonicMeanBatchOperation(
+            self.torch_model,
+            torch.nn.functional.mse_loss,
+            self.reg,
+            restrict_to=self.params,
+        )
 
-    gn_op = InverseHarmonicMeanBatchOperation(
-        torch_model, torch.nn.functional.mse_loss, reg, restrict_to=params
-    )
-    batch_size = 10
+    def compute_grad_analytical(self):
+        y_pred = self.torch_model(self.x)
+        out_features = y_pred.shape[1]
+        dl_dw = torch.vmap(
+            lambda r, s, t: 2
+            / float(out_features)
+            * (s - t).view(-1, 1)
+            @ r.view(1, -1)
+        )(self.x, y_pred, self.y)
+        dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(
+            y_pred, self.y
+        )
+        return torch.cat([dl_dw.reshape(self.x.shape[0], -1), dl_db], dim=-1)
+
+    def compute_ihm_mat_analytical(self):
+        return (
+            torch.sum(
+                torch.func.vmap(
+                    lambda z: torch.linalg.inv(
+                        z.unsqueeze(-1) * z.unsqueeze(-1).t()
+                        + self.reg * torch.eye(len(z))
+                    )
+                )(self.grad_analytical),
+                dim=0,
+            )
+            / self.x.shape[0]
+        )
 
-    gn_autograd = gn_op.apply(TorchBatch(x, y), vec)
-    gn_autograd_dict = gn_op.apply_to_dict(
-        TorchBatch(x, y), align_structure(params, vec)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
-    analytical = ihm_mat_analytical @ vec
-
-    assert torch.allclose(gn_autograd, analytical, atol=1e-4, rtol=tol)
-    assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-4, rtol=tol)
+    @pytest.mark.parametrize("reg", [1.0, 10, 100])
+    def test_analytical_comparison(self, model_data, tol, reg):
+        gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec)
+        gn_autograd_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), align_structure(self.params, self.vec)
+        )
+        gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
+        analytical = self.ihm_mat_analytical @ self.vec
 
-    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
-    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
-    gn_autograd_mat_dict = gn_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict)
+        assert torch.allclose(gn_autograd, analytical, atol=1e-4, rtol=tol)
+        assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-4, rtol=tol)
 
-    op_then_flat = flatten_dimensions(
-        gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    flat_then_op = gn_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat)
+    @pytest.mark.parametrize("reg", [1.0, 10, 100])
+    def test_flattening_commutation(self, model_data, tol, reg):
+        batch_size = 10
+        rand_mat_dict = {
+            k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()
+        }
+        flat_rand_mat = flatten_dimensions(
+            rand_mat_dict.values(), shape=(batch_size, -1)
+        )
+        gn_autograd_mat_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), rand_mat_dict
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op,
-        atol=1e-4,
-        rtol=tol,
-    )
+        op_then_flat = flatten_dimensions(
+            gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+        )
+        flat_then_op = self.gn_op._apply_to_mat(
+            TorchBatch(self.x, self.y), flat_rand_mat
+        )
 
-    flat_then_op_analytical = torch.einsum(
-        "ik, jk -> ji", ihm_mat_analytical, flat_rand_mat
-    )
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op,
+            atol=1e-4,
+            rtol=tol,
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op_analytical,
-        atol=1e-4,
-        rtol=1e-2,
-    )
+        flat_then_op_analytical = torch.einsum(
+            "ik, jk -> ji", self.ihm_mat_analytical, flat_rand_mat
+        )
+
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op_analytical,
+            atol=1e-4,
+            rtol=1e-2,
+        )
 
 
 @pytest.mark.torch

From f25f2891e6d97b0ff382b55928923b87eeb5fddc Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Wed, 5 Jun 2024 18:17:35 +0200
Subject: [PATCH 42/43] Remove abstract methods, which compute gradient
 dictionary outputs from interface GradientProvider

---
 src/pydvl/influence/types.py | 46 ------------------------------------
 1 file changed, 46 deletions(-)

diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py
index 2d5f6be73..8300768de 100644
--- a/src/pydvl/influence/types.py
+++ b/src/pydvl/influence/types.py
@@ -120,52 +120,6 @@ class GradientProvider(Generic[BatchType, TensorType], ABC):
 
     """
 
-    @abstractmethod
-    def grads(self, batch: BatchType) -> Dict[str, TensorType]:
-        r"""
-        Computes and returns a dictionary mapping gradient names to their respective
-        per-sample gradients. Given the example in the class docstring, this means
-
-        $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2,
-        \text{batch.x}, \text{batch.y}), $$
-
-        where the first dimension of the resulting tensors is always considered to be
-        the batch dimension, so the shape of the resulting tensors are $(N, d_i)$,
-        where $N$ is the number of samples in the batch.
-
-        Args:
-            batch: The batch of data for which to compute gradients.
-
-        Returns:
-            A dictionary where keys are gradient identifiers and values are the
-            gradients computed per sample.
- """ - - @abstractmethod - def mixed_grads(self, batch: BatchType) -> Dict[str, TensorType]: - r""" - Computes and returns a dictionary mapping gradient names to their respective - per-sample mixed gradients. In this context, mixed gradients refer to computing - gradients with respect to the instance definition in addition to - compute derivatives with respect to the input batch. - Given the example in the class docstring, this means - - $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, - \omega_2, \text{batch.x}, \text{batch.y}), $$ - - where the first dimension of the resulting tensors is always considered to be - the batch dimension and the last to be the non-batch input related derivatives. - So the shape of the resulting tensors are $(N, n, d_i)$, - where $N$ is the number of samples in the batch. - - Args: - batch: The batch of data for which to compute mixed gradients. - - Returns: - A dictionary where keys are gradient identifiers and values are the - mixed gradients computed per sample. - """ - @abstractmethod def jacobian_prod( self, From 5b34406b64d8b244e1359e6193253b956ae67910 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Wed, 5 Jun 2024 18:24:42 +0200 Subject: [PATCH 43/43] Fix fixture usage in test --- tests/influence/torch/test_batch_operation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py index cf3ea58d6..b04f9b19b 100644 --- a/tests/influence/torch/test_batch_operation.py +++ b/tests/influence/torch/test_batch_operation.py @@ -118,7 +118,7 @@ def compute_gn_mat_analytical(self): [(astuple(tp.model_params), 1e-3) for tp in test_parameters], indirect=["model_data"], ) - def test_analytical_comparison(self, tol): + def test_analytical_comparison(self, model_data, tol): gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec) gn_autograd_dict = self.gn_op.apply_to_dict( TorchBatch(self.x, self.y), align_structure(self.params, self.vec) @@ -136,7 +136,7 @@ def test_analytical_comparison(self, tol): [(astuple(tp.model_params), 1e-3) for tp in test_parameters], indirect=["model_data"], ) - def test_flattening_commutation(self, tol): + def test_flattening_commutation(self, model_data, tol): batch_size = 10 rand_mat_dict = { k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()