Skip to content

Commit

Permalink
Refactor executor pattern.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Sep 1, 2023
1 parent 0bd79d0 commit b13bbd5
Showing 1 changed file with 5 additions and 23 deletions.
28 changes: 5 additions & 23 deletions src/pydvl/value/shapley/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@ def compute_classwise_shapley_values(

parallel_backend = init_parallel_backend(config)
u_ref = parallel_backend.put(u)
# This represents the number of jobs that are running
n_jobs = effective_n_jobs(n_jobs, config)
# This determines the total number of submitted jobs
# including the ones that are running
n_submitted_jobs = 2 * n_jobs

pbar = tqdm(disable=not progress, position=0, total=100, unit="%")
Expand All @@ -87,19 +84,7 @@ def compute_classwise_shapley_values(
terminate_exec = False
with init_executor(max_workers=n_jobs, config=config) as executor:
futures = set()
# Initial batch of computations
for _ in range(n_submitted_jobs):
future = executor.submit(
_classwise_shapley_one_step,
u_ref,
truncation=truncation,
n_resample_complement_sets=n_resample_complement_sets,
use_default_scorer_value=use_default_scorer_value,
min_elements_per_label=min_elements_per_label,
)
futures.add(future)
while futures:
# Wait for the next futures to complete.
while True:
completed_futures, futures = wait(
futures, timeout=60, return_when=FIRST_COMPLETED
)
Expand All @@ -114,12 +99,9 @@ def compute_classwise_shapley_values(
if terminate_exec:
break

# Submit more computations
# The goal is to always have `n_jobs`
# computations running
for _ in range(n_submitted_jobs - len(futures)):
future = executor.submit(
_classwise_shapley_one_step,
_permutation_montecarlo_classwise_shapley,
u_ref,
truncation=truncation,
n_resample_complement_sets=n_resample_complement_sets,
Expand All @@ -135,7 +117,7 @@ def compute_classwise_shapley_values(
return result


def _classwise_shapley_one_step(
def _permutation_montecarlo_classwise_shapley(
u: Utility,
*,
truncation: TruncationPolicy,
Expand Down Expand Up @@ -170,7 +152,7 @@ def _classwise_shapley_one_step(

for label in unique_labels:
u.scorer.label = label
result += _permutation_montecarlo_classwise_shapley(
result += _permutation_montecarlo_classwise_shapley_for_label(
u,
label,
done=MaxChecks(n_resample_complement_sets),
Expand Down Expand Up @@ -387,7 +369,7 @@ def estimate_in_cls_and_out_of_cls_score(
return in_cls_score, out_of_cls_score


def _permutation_montecarlo_classwise_shapley(
def _permutation_montecarlo_classwise_shapley_for_label(
u: Utility,
label: int,
*,
Expand Down

0 comments on commit b13bbd5

Please sign in to comment.