From b13bbd5c348025a5fe7a35ea5300ef727c50dfe9 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Fri, 1 Sep 2023 22:16:34 +0200 Subject: [PATCH] Refactor executor pattern. --- src/pydvl/value/shapley/classwise.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/pydvl/value/shapley/classwise.py b/src/pydvl/value/shapley/classwise.py index 9b1c5f5c9..8f7339f78 100644 --- a/src/pydvl/value/shapley/classwise.py +++ b/src/pydvl/value/shapley/classwise.py @@ -72,10 +72,7 @@ def compute_classwise_shapley_values( parallel_backend = init_parallel_backend(config) u_ref = parallel_backend.put(u) - # This represents the number of jobs that are running n_jobs = effective_n_jobs(n_jobs, config) - # This determines the total number of submitted jobs - # including the ones that are running n_submitted_jobs = 2 * n_jobs pbar = tqdm(disable=not progress, position=0, total=100, unit="%") @@ -87,19 +84,7 @@ def compute_classwise_shapley_values( terminate_exec = False with init_executor(max_workers=n_jobs, config=config) as executor: futures = set() - # Initial batch of computations - for _ in range(n_submitted_jobs): - future = executor.submit( - _classwise_shapley_one_step, - u_ref, - truncation=truncation, - n_resample_complement_sets=n_resample_complement_sets, - use_default_scorer_value=use_default_scorer_value, - min_elements_per_label=min_elements_per_label, - ) - futures.add(future) - while futures: - # Wait for the next futures to complete. + while True: completed_futures, futures = wait( futures, timeout=60, return_when=FIRST_COMPLETED ) @@ -114,12 +99,9 @@ def compute_classwise_shapley_values( if terminate_exec: break - # Submit more computations - # The goal is to always have `n_jobs` - # computations running for _ in range(n_submitted_jobs - len(futures)): future = executor.submit( - _classwise_shapley_one_step, + _permutation_montecarlo_classwise_shapley, u_ref, truncation=truncation, n_resample_complement_sets=n_resample_complement_sets, @@ -135,7 +117,7 @@ def compute_classwise_shapley_values( return result -def _classwise_shapley_one_step( +def _permutation_montecarlo_classwise_shapley( u: Utility, *, truncation: TruncationPolicy, @@ -170,7 +152,7 @@ def _classwise_shapley_one_step( for label in unique_labels: u.scorer.label = label - result += _permutation_montecarlo_classwise_shapley( + result += _permutation_montecarlo_classwise_shapley_for_label( u, label, done=MaxChecks(n_resample_complement_sets), @@ -387,7 +369,7 @@ def estimate_in_cls_and_out_of_cls_score( return in_cls_score, out_of_cls_score -def _permutation_montecarlo_classwise_shapley( +def _permutation_montecarlo_classwise_shapley_for_label( u: Utility, label: int, *,