diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index 9d1e1dcce..d4e62ed51 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -17,16 +17,15 @@ UniformSampler, ) from pydvl.value.semivalues import ( - MSRMarginal, + DefaultMarginal, SVCoefficient, - _marginal, banzhaf_coefficient, beta_coefficient, compute_generic_semivalues, - msr_banzhaf, + compute_msr_banzhaf_semivalues, shapley_coefficient, ) -from pydvl.value.stopping import HistoryDeviation, MaxUpdates, RankStability +from pydvl.value.stopping import HistoryDeviation, MaxChecks, MaxUpdates from . import check_values from .utils import timed @@ -54,10 +53,12 @@ def test_marginal_batch_size(test_game, sampler, coefficient, batch_size, seed): marginals_single = [] for sample in samples: marginals_single.extend( - _marginal(test_game.u, coefficient=coefficient, samples=[sample]) + DefaultMarginal()(test_game.u, coefficient=coefficient, samples=[sample]) ) - marginals_batch = _marginal(test_game.u, coefficient=coefficient, samples=samples) + marginals_batch = DefaultMarginal()( + test_game.u, coefficient=coefficient, samples=samples + ) assert len(marginals_single) == len(marginals_batch) assert set(marginals_single) == set(marginals_batch) @@ -68,20 +69,11 @@ def test_msr_banzhaf( num_samples: int, analytic_banzhaf, parallel_config, n_jobs, seed: Seed ): u, exact_values = analytic_banzhaf - sampler = MSRSampler() - marginal = MSRMarginal() - values = compute_generic_semivalues( - sampler(u.data.indices, seed=seed), - u=u, - coefficient=coefficient, - marginal=marginal, - criterion=RankStability(rtol=0.1) | MaxUpdates(100), - skip_converged=False, - n_jobs=n_jobs, - config=parallel_config, - progress=True, + values = compute_msr_banzhaf_semivalues( + u=u, done=MaxChecks(1000), config=parallel_config, n_jobs=n_jobs, seed=seed ) - check_values(values, exact_values, rtol=0.1) + # Need to use atol because msr banzhaf is quite noisy. + check_values(values, exact_values, atol=0.1) @pytest.mark.parametrize("n", [10, 100])