Merge pull request #520 from aai-institute/feature/msr-banzhaf
MSR method for Banzhaf
mdbenito authored Apr 12, 2024
2 parents be14b2b + 426b867 commit 67a5b06
Showing 22 changed files with 2,189 additions and 73 deletions.
19 changes: 10 additions & 9 deletions .notebook_test_durations
@@ -1,11 +1,12 @@
{
"notebooks/data_oob.ipynb::": 14.608769827000287,
"notebooks/influence_imagenet.ipynb::": 13.570316236000508,
"notebooks/influence_sentiment_analysis.ipynb::": 20.546479973001624,
"notebooks/influence_synthetic.ipynb::": 5.9324631089984905,
"notebooks/influence_wine.ipynb::": 16.114133220999065,
"notebooks/least_core_basic.ipynb::": 14.312467472000208,
"notebooks/shapley_basic_spotify.ipynb::": 15.608795123000164,
"notebooks/shapley_knn_flowers.ipynb::": 3.9430189769991557,
"notebooks/shapley_utility_learning.ipynb::": 26.96671833400069
"notebooks/data_oob.ipynb::": 14.514983271001256,
"notebooks/influence_imagenet.ipynb::": 15.937124550999215,
"notebooks/influence_sentiment_analysis.ipynb::": 26.479645616000198,
"notebooks/influence_synthetic.ipynb::": 6.61773010700017,
"notebooks/influence_wine.ipynb::": 16.312171267998565,
"notebooks/least_core_basic.ipynb::": 14.375480750999486,
"notebooks/msr_banzhaf_digits.ipynb::": 106.6507187110019,
"notebooks/shapley_basic_spotify.ipynb::": 15.657225806997303,
"notebooks/shapley_knn_flowers.ipynb::": 3.9943819290019746,
"notebooks/shapley_utility_learning.ipynb::": 25.939783253001224
}
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,9 @@

- New method: `NystroemSketchInfluence`
[PR #504](https://github.com/aai-institute/pyDVL/pull/504)
- New method `MSR Banzhaf` with accompanying notebook, and new stopping
criterion `RankCorrelation`
[PR #520](https://github.com/aai-institute/pyDVL/pull/520)
- New preconditioned block variant of conjugate gradient
[PR #507](https://github.com/aai-institute/pyDVL/pull/507)
- Improvements to documentation: fixes, links, text, example gallery, LFS and
2 changes: 1 addition & 1 deletion docs/assets/pydvl.bib
@@ -450,7 +450,7 @@ @book{trefethen_numerical_1997
langid = {english}
}

@inproceedings{wang_data_2022,
@inproceedings{wang_data_2023,
title = {Data {{Banzhaf}}: {{A Robust Data Valuation Framework}} for {{Machine Learning}}},
shorttitle = {Data {{Banzhaf}}},
booktitle = {Proceedings of {{The}} 26th {{International Conference}} on {{Artificial Intelligence}} and {{Statistics}}},
3 changes: 3 additions & 0 deletions docs/examples/img/msr_banzhaf_digits.png
9 changes: 9 additions & 0 deletions docs/examples/index.md
@@ -53,6 +53,15 @@ alias:

[![](img/data_oob.png)](data_oob/)

- [__Faster Banzhaf values__](msr_banzhaf_digits/)

---

Using Banzhaf values to estimate the value of data points in the scikit-learn
handwritten digits dataset, and evaluating the convergence speed of MSR
sampling.

[![](img/msr_banzhaf_digits.png)](msr_banzhaf_digits/)

</div>


9 changes: 9 additions & 0 deletions docs/getting-started/glossary.md
@@ -147,6 +147,15 @@ performance when that point is removed from the training set.
* [Implementation][pydvl.value.loo.loo.compute_loo]
* [Documentation][leave-one-out-values]

### Maximum Sample Reuse

MSR is a sampling method for data valuation that reuses every utility
evaluation to update the value of all data points, which can lead to much
faster convergence. Introduced in [@wang_data_2023].

* [Implementation][pydvl.value.sampler.MSRSampler]


### Monte Carlo Least Core

MCLC is a variation of the Least Core that uses a reduced amount of
37 changes: 34 additions & 3 deletions docs/value/semi-values.md
@@ -21,7 +21,7 @@ the set $D_{-i}^{(k)}$ contains all subsets of $D$ of size $k$ that do not
include sample $x_i$, $S_{+i}$ is the set $S$ with $x_i$ added, and $u$ is the
utility function.

Two instances of this are **Banzhaf indices** [@wang_data_2022],
Two instances of this are **Banzhaf indices** [@wang_data_2023],
and **Beta Shapley** [@kwon_beta_2022], with better numerical and
rank stability in certain situations.

Expand Down Expand Up @@ -84,7 +84,7 @@ any choice of weight function $w$, one can always construct a utility with
higher variance where $w$ is greater. Therefore, in a worst-case sense, the best
one can do is to pick a constant weight.

The authors of [@wang_data_2022] show that Banzhaf indices are more robust to
The authors of [@wang_data_2023] show that Banzhaf indices are more robust to
variance in the utility function than Shapley and Beta Shapley values. They are
available in pyDVL through
[compute_banzhaf_semivalues][pydvl.value.semivalues.compute_banzhaf_semivalues]:
@@ -98,6 +98,37 @@ values = compute_banzhaf_semivalues(
)
```

### Banzhaf semi-values with MSR sampling

Wang et al. propose a more sample-efficient method for computing Banzhaf
semivalues in their paper *Data Banzhaf: A Robust Data Valuation Framework
for Machine Learning* [@wang_data_2023]. This method updates all semivalues
with each evaluation of the utility (i.e. with each model trained), based on
whether a given data point was included in the sampled subset. The estimator
is

$$\hat{\phi}_{MSR}(i) = \frac{1}{|\mathbf{S}_{\ni i}|} \sum_{S \in
\mathbf{S}_{\ni i}} U(S) - \frac{1}{|\mathbf{S}_{\not\ni i}|}
\sum_{S \in \mathbf{S}_{\not\ni i}} U(S)$$

where $\mathbf{S}_{\ni i}$ is the collection of sampled subsets that contain
the index $i$, and $\mathbf{S}_{\not\ni i}$ the collection of those that do not.
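
To make the estimator concrete, here is a minimal, self-contained sketch of
the Monte Carlo computation (not pyDVL's implementation; the utility callable
and the additive toy example are illustrative assumptions):

```python
import numpy as np


def msr_banzhaf(u, n_indices: int, n_samples: int, seed: int = 42) -> np.ndarray:
    """MSR estimate: every utility evaluation updates *all* indices at once."""
    rng = np.random.default_rng(seed)
    sums = np.zeros((2, n_indices))  # row 1: subsets containing i; row 0: the rest
    counts = np.zeros((2, n_indices))
    cols = np.arange(n_indices)
    for _ in range(n_samples):
        mask = rng.random(n_indices) < 0.5  # uniformly random subset
        utility = u(np.flatnonzero(mask))  # one model training / evaluation
        sums[mask.astype(int), cols] += utility
        counts[mask.astype(int), cols] += 1
    # Mean over subsets containing i minus mean over those that don't;
    # np.maximum guards against division by zero for very few samples
    return sums[1] / np.maximum(counts[1], 1) - sums[0] / np.maximum(counts[0], 1)


# Toy check with an additive utility: estimates approach the per-point terms
contrib = np.array([0.1, 0.3, 0.6])
print(msr_banzhaf(lambda s: contrib[s].sum(), n_indices=3, n_samples=5000))
```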

The function implementing this method is
[compute_msr_banzhaf_semivalues][pydvl.value.semivalues.compute_msr_banzhaf_semivalues].

```python
from pydvl.value import compute_msr_banzhaf_semivalues, RankCorrelation, Utility

utility = Utility(model, data)
values = compute_msr_banzhaf_semivalues(
u=utility, done=RankCorrelation(rtol=0.001),
)
```
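
The returned object can be converted to a DataFrame and plotted. A short
sketch using two additions from this PR, the `*_updates` column written by
`to_dataframe` and the new `prefix` argument of `plot_shapley` (the column
name is an arbitrary choice):

```python
from pydvl.reporting.plots import plot_shapley

# to_dataframe also writes "<column>_stderr" and "<column>_updates" columns
df = values.to_dataframe(column="msr_banzhaf")
plot_shapley(df, prefix="msr_banzhaf")  # prefix selects the value/stderr columns
```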
For further details on how to use this method and a comparison of its sample
efficiency, take a look at the example notebook
[msr_banzhaf_digits](../../examples/msr_banzhaf_digits).


## General semi-values

As explained above, both Beta Shapley and Banzhaf indices are special cases of
@@ -130,7 +161,7 @@ values = compute_generic_semivalues(
Allowing any coefficient can help when experimenting with models which are more
sensitive to changes in training set size. However, Data Banzhaf indices are
proven to be the most robust to variance in the utility function, in the sense
of rank stability, across a range of models and datasets [@wang_data_2022].
of rank stability, across a range of models and datasets [@wang_data_2023].

!!! warning "Careful with permutation sampling"
This generic implementation of semi-values allowing for any combination of
1 change: 1 addition & 0 deletions docs_includes/abbreviations.md
@@ -19,6 +19,7 @@
*[MLRC]: Machine Learning Reproducibility Challenge
*[ML]: Machine Learning
*[MSE]: Mean Squared Error
*[MSR]: Maximum Sample Reuse
*[NLRA]: Nyström Low-Rank Approximation
*[OOB]: Out-of-Bag
*[PCA]: Principal Component Analysis
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -34,6 +34,7 @@ nav:
- Data utility learning: examples/shapley_utility_learning.ipynb
- Least Core: examples/least_core_basic.ipynb
- Data OOB: examples/data_oob.ipynb
- Banzhaf Semivalues: examples/msr_banzhaf_digits.ipynb
- Influence Function:
- For CNNs: examples/influence_imagenet.ipynb
- For mislabeled data: examples/influence_synthetic.ipynb
18 changes: 9 additions & 9 deletions notebooks/least_core_basic.ipynb


1,532 changes: 1,532 additions & 0 deletions notebooks/msr_banzhaf_digits.ipynb


114 changes: 114 additions & 0 deletions notebooks/support/banzhaf.py
@@ -0,0 +1,114 @@
from typing import Optional

import numpy as np
from numpy.typing import NDArray
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from pydvl.utils.types import SupervisedModel

try:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
except ImportError as e:
raise RuntimeError("PyTorch is required to run the Banzhaf MSR notebook") from e


def load_digits_dataset(
test_size: float, val_size: float = 0.0, random_state: Optional[int] = None
):
"""Loads the sklearn handwritten digits dataset. More info can be found at
https://scikit-learn.org/stable/datasets/toy_dataset.html#optical-recognition-of-handwritten-digits-dataset.
:param test_size: fraction of points used for test dataset
:param val_size: fraction of points used for training dataset
:param random_state: fix random seed. If None, no random seed is set.
:return: A tuple of three elements with the first three being input and
target values in the form of matrices of shape (N,8,8) the first
and (N,) the second.
"""

digits_bunch = load_digits(as_frame=True)
x, x_test, y, y_test = train_test_split(
digits_bunch.data.values / 16.0,
digits_bunch.target.values,
train_size=1 - test_size,
random_state=random_state,
)
    if val_size > 0:
        # val_size is a fraction of the full dataset; rescale it to the non-test part
        x_train, x_val, y_train, y_val = train_test_split(
            x, y, train_size=1 - val_size / (1 - test_size), random_state=random_state
        )
else:
x_train, y_train = x, y
x_val, y_val = None, None

return ((x_train, y_train), (x_val, y_val), (x_test, y_test))


class TorchCNNModel(SupervisedModel):
def __init__(
self,
lr: float = 0.001,
epochs: int = 40,
batch_size: int = 32,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
self.lr = lr
self.batch_size = batch_size
self.epochs = epochs
self.device = device
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), padding="same"),
            nn.Conv2d(in_channels=8, out_channels=4, kernel_size=(3, 3), padding="same"),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=64, out_features=32),
            # No final Softmax: nn.CrossEntropyLoss expects raw logits, and the
            # argmax in predict() is unchanged by the omission
            nn.Linear(in_features=32, out_features=10),
        )
self.loss = nn.CrossEntropyLoss()
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.model.to(device)

def fit(self, x: NDArray, y: NDArray) -> None:
torch_dataset = TensorDataset(
torch.tensor(
np.reshape(x, (x.shape[0], 1, 8, 8)),
dtype=torch.float,
device=self.device,
),
torch.tensor(y, device=self.device),
)
torch_dataloader = DataLoader(torch_dataset, batch_size=self.batch_size)
        for epoch in range(self.epochs):
            for features, labels in torch_dataloader:
                self.optimizer.zero_grad()  # otherwise gradients accumulate across steps
                pred = self.model(features)
                loss = self.loss(pred, labels)
                loss.backward()
                self.optimizer.step()

def predict(self, x: NDArray) -> NDArray:
pred = self.model(
torch.tensor(
np.reshape(x, (x.shape[0], 1, 8, 8)),
dtype=torch.float,
device=self.device,
)
)
pred = torch.argmax(pred, dim=1)
return pred.cpu().numpy()

    def score(self, x: NDArray, y: NDArray) -> float:
        pred = self.predict(x)
        return accuracy_score(y, pred)  # (y_true, y_pred) argument order

    def get_params(self, deep: bool = False):
        return {
            "lr": self.lr,
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "device": self.device,
        }
5 changes: 3 additions & 2 deletions src/pydvl/reporting/plots.py
@@ -240,6 +240,7 @@ def plot_shapley(
title: Optional[str] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
prefix: Optional[str] = "data_value",
) -> plt.Axes:
r"""Plots the shapley values, as returned from
[compute_shapley_values][pydvl.value.shapley.common.compute_shapley_values],
@@ -260,9 +261,9 @@
if ax is None:
_, ax = plt.subplots()

yerr = norm.ppf(1 - level / 2) * df["data_value_stderr"]
yerr = norm.ppf(1 - level / 2) * df[f"{prefix}_stderr"]

ax.errorbar(x=df.index, y=df["data_value"], yerr=yerr, fmt="o", capsize=6)
ax.errorbar(x=df.index, y=df[prefix], yerr=yerr, fmt="o", capsize=6)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
10 changes: 7 additions & 3 deletions src/pydvl/value/result.py
@@ -41,6 +41,7 @@
samples random values uniformly.
"""

from __future__ import annotations

import collections.abc
@@ -676,12 +677,15 @@ def to_dataframe(
column = column or self._algorithm
df = pd.DataFrame(
self._values[self._sort_positions],
index=self._names[self._sort_positions]
if use_names
else self._indices[self._sort_positions],
index=(
self._names[self._sort_positions]
if use_names
else self._indices[self._sort_positions]
),
columns=[column],
)
df[column + "_stderr"] = self.stderr[self._sort_positions]
df[column + "_updates"] = self.counts[self._sort_positions]
return df

@classmethod
31 changes: 29 additions & 2 deletions src/pydvl/value/sampler.py
@@ -44,6 +44,9 @@
Frank, and Geoffrey Holmes. [Sampling Permutations for Shapley Value
Estimation](http://jmlr.org/papers/v23/21-0439.html). Journal of Machine
Learning Research 23, no. 43 (2022): 1–46.
[^2]: <a name="wang_data_2023"></a>Wang, J.T. and Jia, R., 2023.
[Data Banzhaf: A Robust Data Valuation Framework for Machine Learning](https://proceedings.mlr.press/v206/wang23e.html).
In: Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, pp. 6388-6421.
"""

@@ -72,13 +75,16 @@

__all__ = [
"AntitheticSampler",
"DeterministicUniformSampler",
"DeterministicPermutationSampler",
"DeterministicUniformSampler",
"MSRSampler",
"PermutationSampler",
"PowersetSampler",
"RandomHierarchicalSampler",
"UniformSampler",
"SampleT",
"StochasticSampler",
"StochasticSamplerMixin",
"UniformSampler",
]

SampleT = Tuple[IndexT, NDArray[IndexT]]
@@ -312,6 +318,26 @@ def weight(cls, n: int, subset_len: int) -> float:
return float(2 ** (n - 1)) if n > 0 else 1.0


class MSRSampler(StochasticSamplerMixin, PowersetSampler[IndexT]):
"""An iterator to perform sampling of random subsets.
This sampler does not return any index, it only returns subsets of the data.
This sampler is used in (Wang et. al.)<sup><a href="wang_data_2023">2</a></sup>.
"""

def __iter__(self) -> Iterator[SampleT]:
if len(self) == 0:
return
while True:
subset = random_subset(self.indices, seed=self._rng)
yield None, subset
self._n_samples += 1

@classmethod
def weight(cls, n: int, subset_len: int) -> float:
return 1.0


class AntitheticSampler(StochasticSamplerMixin, PowersetSampler[IndexT]):
"""An iterator to perform uniform random sampling of subsets, and their
complements.
@@ -450,4 +476,5 @@ def weight(cls, n: int, subset_len: int) -> float:
PermutationSampler[IndexT],
AntitheticSampler[IndexT],
RandomHierarchicalSampler[IndexT],
MSRSampler[IndexT],
]
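
A hypothetical snippet illustrating the sampler's contract (assuming the
`MSRSampler(indices, seed=...)` constructor provided by the mixins): every
draw yields `(None, subset)`, since MSR holds no index out.

```python
from itertools import islice

import numpy as np

from pydvl.value.sampler import MSRSampler

sampler = MSRSampler(np.arange(5), seed=42)
for idx, subset in islice(sampler, 3):
    assert idx is None  # MSR samples carry no outer index
    print(subset)  # a uniformly random subset of {0, ..., 4}
```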