Skip to content

Commit

Permalink
WIP: MSR method for Banzhaf
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbenito authored and jakobkruse1 committed Mar 18, 2024
1 parent 531c4d3 commit 837cbeb
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/pydvl/value/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,19 @@ def weight(cls, n: int, subset_len: int) -> float:
return float(2 ** (n - 1)) if n > 0 else 1.0


class MSRSampler(PowersetSampler[IndexT]):
    """Sampler for the Maximum-Sample-Reuse (MSR) approximation.

    Yields an endless stream of random subsets of the whole index set. The
    samples are not tied to any particular index, so the first element of each
    yielded tuple is the placeholder ``-1``.

    NOTE(review): subsets are drawn with ``random_subset`` without passing a
    seed or generator, so runs are presumably not reproducible. Confirm
    whether this class should also use ``StochasticSamplerMixin`` like
    ``AntitheticSampler`` does.
    """

    def __iter__(self) -> Iterator[SampleT]:
        """Yield ``(-1, subset)`` pairs forever, or nothing for an empty sampler."""
        if len(self) == 0:
            return
        while True:
            subset = random_subset(self.indices)
            yield -1, subset
            # Only count a sample once the consumer has come back for more.
            self._n_samples += 1

    def weight(self, subset: NDArray[IndexT]) -> float:
        """Return the factor ``2 ** (n - 1)``, or ``1.0`` when ``n`` is 0.

        NOTE(review): other samplers in this module declare
        ``weight(cls, n, subset_len)``; this signature differs and ``subset``
        is unused. Confirm which interface callers expect.

        :param subset: The sampled subset (currently ignored).
        :return: The weight as a float.
        """
        return float(2 ** (self._n - 1)) if self._n > 0 else 1.0


class AntitheticSampler(StochasticSamplerMixin, PowersetSampler[IndexT]):
"""An iterator to perform uniform random sampling of subsets, and their
complements.
Expand Down
79 changes: 79 additions & 0 deletions src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from pydvl.utils.types import IndexT, Seed
from pydvl.value import ValuationResult
from pydvl.value.sampler import (
MSRSampler,
PermutationSampler,
PowersetSampler,
SampleT,
Expand Down Expand Up @@ -140,6 +141,84 @@ def __call__(self, n: int, k: int) -> float:
...


def _msr_banzhaf(
    sampler: MSRSampler,
    u: Utility,
    done: StoppingCriterion,
    *,
    job_id: int = 1,
    progress: bool = False,
) -> ValuationResult:
    """Inner loop of the MSR-Banzhaf approximation.

    Every sampled subset is evaluated once and the resulting utility is reused
    to update the estimate of *every* index.

    See :func:`msr_banzhaf` for details.

    :param sampler: The subset sampler to use for utility computations.
    :param u: Utility object with model, data, and scoring function.
    :param done: Stopping criterion.
    :param progress: Whether to display progress bars for each job.
    :param job_id: id to use for reporting progress.
    :return: Object with the results.
    :raises TypeError: If ``sampler`` is not an :class:`MSRSampler`.
    """

    if not isinstance(sampler, MSRSampler):
        raise TypeError("MSR-Banzhaf requires a MSRSampler.")

    result = ValuationResult.zeros(
        algorithm="msr-banzhaf",
        indices=sampler.indices,
        data_names=[u.data.data_names[i] for i in sampler.indices],
    )

    samples = takewhile(lambda _: not done(result), sampler)
    # The context manager guarantees the bar is closed, also on exceptions.
    with tqdm(disable=not progress, position=job_id, total=100, unit="%") as pbar:
        for _, s in samples:
            pbar.n = 100 * done.completion()
            pbar.refresh()
            _u = u(s)
            # Hoisted membership structure: avoids scanning `s` once per index.
            members = set(s)
            for idx in sampler.indices:
                sign = 1 if idx in members else -1
                # With S drawn uniformly from the power set, E[2 * sign * u(S)]
                # equals E[u(S) | idx in S] - E[u(S) | idx not in S], i.e. the
                # Banzhaf value. Without the factor 2 the running mean kept by
                # `result.update` converges to half of it.
                # NOTE(review): assumes `random_subset` includes each index
                # independently with probability 1/2 -- confirm.
                result.update(idx, 2 * sign * _u)

    return result


def msr_banzhaf(
    u: Utility,
    done: StoppingCriterion,
    *,
    n_jobs: int = 1,
    config: ParallelConfig = ParallelConfig(),
    progress: bool = False,
) -> ValuationResult:
    r"""Maximum-Sample-Reuse Monte Carlo approximation to Banzhaf index.

    Following :footcite:t:`wang_data_2022`, the MSR-Banzhaf Monte Carlo
    approximation uses each sample $S$ of the whole dataset $D$ for each index.
    This is made possible by writing the Banzhaf value of an index $i$ as a
    difference of two expectations over subsets drawn uniformly from the
    power set of $D$:

    $$v(i) = \mathbb{E}_{S \ni i}[u(S)] - \mathbb{E}_{S \not\ni i}[u(S)],$$

    so that one evaluation of $u(S)$ contributes to the estimate of every
    index. The sampling loop itself lives in :func:`_msr_banzhaf`; this
    function distributes it over ``n_jobs`` jobs and adds up their results.

    :param u: Utility object with model, data, and scoring function.
    :param done: Stopping criterion.
    :param n_jobs: Number of parallel jobs to run.
    :param config: Configuration for parallel jobs.
    :param progress: Whether to display progress bars for each job.
    :return: Object with the results.
    """

    def _combine(partials):
        # Fold the per-job results together with `+`.
        return reduce(operator.add, partials)

    inner_loop_kwargs = {"u": u, "done": done, "progress": progress}
    job: MapReduceJob[PowersetSampler, ValuationResult] = MapReduceJob(
        MSRSampler(u.data.indices),
        map_func=_msr_banzhaf,
        reduce_func=_combine,
        map_kwargs=inner_loop_kwargs,
        config=config,
        n_jobs=n_jobs,
    )
    return job()


MarginalT = Tuple[IndexT, float]


Expand Down
33 changes: 33 additions & 0 deletions src/pydvl/value/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@

import numpy as np
from numpy.typing import NDArray
from scipy.stats import spearmanr

from pydvl.utils import Status
from pydvl.value import ValuationResult
Expand Down Expand Up @@ -626,3 +627,35 @@ def reset(self):

def __str__(self):
return f"HistoryDeviation(n_steps={self.n_steps}, rtol={self.rtol})"


class RankStability(StoppingCriterion):
    r"""A check for stability of Spearman correlation between checks.

    At every check the Spearman rank correlation between the current values
    and the values seen at the previous check is computed. When that
    correlation is close (in the sense of :func:`numpy.isclose` with relative
    tolerance ``rtol``) to the correlation obtained at the previous check, the
    computation is terminated.

    This criterion is used in :footcite:t:`wang_data_2022`.

    :param rtol: Relative tolerance for the change in correlation. Must lie
        strictly between 0 and 1.
    :param modify_result: Forwarded to the base class constructor.
    :raises ValueError: If ``rtol`` is not in the open interval (0, 1).
    """

    def __init__(self, rtol: float, modify_result: bool = True):
        super().__init__(modify_result=modify_result)
        if rtol <= 0 or rtol >= 1:
            raise ValueError("rtol must be in (0, 1)")
        self.rtol = rtol
        # Values seen at the previous check; None until the first check.
        self._memory = None  # type: ignore
        # Correlation obtained at the previous check.
        self._corr = 0.0
        # NOTE(review): sibling criteria define `reset()`; this class keeps
        # state in `_memory` and `_corr` but does not override it -- confirm
        # whether a reset is needed here.

    def _check(self, r: ValuationResult) -> Status:
        """Compare the current rank correlation with the previous one.

        :param r: The result whose values are checked.
        :return: ``Status.Converged`` if the correlation barely changed since
            the previous check, ``Status.Pending`` otherwise (including on
            the very first check, which only records the values).
        """
        if self._memory is None:
            # First call: nothing to compare against yet.
            self._memory = r.values.copy()
            self._converged = np.full(len(r), False)
            return Status.Pending

        corr = spearmanr(self._memory, r.values)[0]
        self._memory = r.values.copy()
        # A NaN correlation never compares as close, so it yields Pending.
        if np.isclose(corr, self._corr, rtol=self.rtol):
            self._converged = np.full(len(r), True)
            return Status.Converged
        self._corr = corr
        return Status.Pending

    def __str__(self):
        return f"RankStability(rtol={self.rtol})"
8 changes: 8 additions & 0 deletions tests/value/test_semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
banzhaf_coefficient,
beta_coefficient,
compute_generic_semivalues,
msr_banzhaf,
shapley_coefficient,
)
from pydvl.value.stopping import HistoryDeviation, MaxUpdates
Expand Down Expand Up @@ -61,6 +62,13 @@ def test_marginal_batch_size(test_game, sampler, coefficient, batch_size, seed):
assert set(marginals_single) == set(marginals_batch)


@pytest.mark.parametrize("num_samples", [5])
def test_msr_banzhaf(num_samples: int, analytic_banzhaf):
    """Check MSR-Banzhaf estimates against the analytic values of the fixture.

    ``num_samples`` is not used in the body; presumably it parametrizes the
    fixtures behind ``analytic_banzhaf`` -- confirm.
    """
    # Imported locally: the module-level import from pydvl.value.stopping only
    # brings in HistoryDeviation and MaxUpdates.
    from pydvl.value.stopping import AbsoluteStandardError

    u, exact_values = analytic_banzhaf
    values = msr_banzhaf(u, AbsoluteStandardError(0.02, 1.0) | MaxUpdates(300))
    check_values(values, exact_values, rtol=0.15)


@pytest.mark.parametrize("n", [10, 100])
@pytest.mark.parametrize(
"coefficient",
Expand Down

0 comments on commit 837cbeb

Please sign in to comment.