Skip to content

Commit

Permalink
test msr banzhaf methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobkruse1 committed Mar 27, 2024
1 parent a34f3aa commit bfe5637
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions tests/value/test_semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@
UniformSampler,
)
from pydvl.value.semivalues import (
MSRMarginal,
DefaultMarginal,
SVCoefficient,
_marginal,
banzhaf_coefficient,
beta_coefficient,
compute_generic_semivalues,
msr_banzhaf,
compute_msr_banzhaf_semivalues,
shapley_coefficient,
)
from pydvl.value.stopping import HistoryDeviation, MaxUpdates, RankStability
from pydvl.value.stopping import HistoryDeviation, MaxChecks, MaxUpdates

from . import check_values
from .utils import timed
Expand Down Expand Up @@ -54,10 +53,12 @@ def test_marginal_batch_size(test_game, sampler, coefficient, batch_size, seed):
marginals_single = []
for sample in samples:
marginals_single.extend(
_marginal(test_game.u, coefficient=coefficient, samples=[sample])
DefaultMarginal()(test_game.u, coefficient=coefficient, samples=[sample])
)

marginals_batch = _marginal(test_game.u, coefficient=coefficient, samples=samples)
marginals_batch = DefaultMarginal()(
test_game.u, coefficient=coefficient, samples=samples
)

assert len(marginals_single) == len(marginals_batch)
assert set(marginals_single) == set(marginals_batch)
Expand All @@ -68,20 +69,11 @@ def test_msr_banzhaf(
num_samples: int, analytic_banzhaf, parallel_config, n_jobs, seed: Seed
):
u, exact_values = analytic_banzhaf
sampler = MSRSampler()
marginal = MSRMarginal()
values = compute_generic_semivalues(
sampler(u.data.indices, seed=seed),
u=u,
coefficient=coefficient,
marginal=marginal,
criterion=RankStability(rtol=0.1) | MaxUpdates(100),
skip_converged=False,
n_jobs=n_jobs,
config=parallel_config,
progress=True,
values = compute_msr_banzhaf_semivalues(
u=u, done=MaxChecks(1000), config=parallel_config, n_jobs=n_jobs, seed=seed
)
check_values(values, exact_values, rtol=0.1)
# Need to use atol because msr banzhaf is quite noisy.
check_values(values, exact_values, atol=0.1)


@pytest.mark.parametrize("n", [10, 100])
Expand Down

0 comments on commit bfe5637

Please sign in to comment.