From 4542263a1c0045f1f31083895a0f7c78f297aeed Mon Sep 17 00:00:00 2001
From: Jakob Kruse
Date: Wed, 27 Mar 2024 16:59:19 +0100
Subject: [PATCH] fix unintentional change from batch size update

---
 src/pydvl/value/semivalues.py | 30 ++++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py
index a966cfa12..70b83af6f 100644
--- a/src/pydvl/value/semivalues.py
+++ b/src/pydvl/value/semivalues.py
@@ -241,7 +241,7 @@ def compute_values(self) -> np.ndarray:
 
     def __call__(
         self, future_result: List[Tuple[List[IndexT], float]]
-    ) -> Tuple[MarginalT, ...]:
+    ) -> List[List[MarginalT]]:
         """Computation of marginal utility using Maximum Sample Reuse.
 
         This processor requires the Marginal Function to be set to RawUtility.
@@ -259,8 +259,8 @@
             A collection of marginals. Each marginal is a tuple with index and its
                 marginal utility.
         """
-        marginals: List[MarginalT] = []
-        for s, evaluation in future_result:
+        marginals: List[List[MarginalT]] = []
+        for batch_id, (s, evaluation) in enumerate(future_result):
             previous_values = self.compute_values()
             self.total_evaluations += 1
             self.point_in_subset[s] += 1
@@ -273,9 +273,12 @@
                 self.total_evaluations * new_values
                 - (self.total_evaluations - 1) * previous_values
             )
+            marginals.append([])
             for data_index in range(self.n):
-                marginals.append((data_index, float(marginal_vals[data_index])))
-            return tuple(marginals)
+                marginals[batch_id].append(
+                    (data_index, float(marginal_vals[data_index]))
+                )
+        return marginals
 
 
 # @deprecated(
@@ -373,11 +376,18 @@ def compute_generic_semivalues(
             completed, pending = wait(pending, timeout=1, return_when=FIRST_COMPLETED)
             for future in completed:
-                processed_future = future_processor(future.result())
-                for idx, marginal_val in processed_future:
-                    result.update(idx, marginal_val)
-                if done(result):
-                    return result
+                processed_future = future_processor(
+                    future.result()
+                )  # List of tuples, or one list of tuples per batch element
+                for batch_future in processed_future:
+                    if isinstance(batch_future, list):  # Case when batch size is > 1
+                        for idx, marginal_val in batch_future:
+                            result.update(idx, marginal_val)
+                    else:  # Batch size 1
+                        idx, marginal_val = batch_future
+                        result.update(idx, marginal_val)
+                if done(result):
+                    return result
 
             # Ensure that we always have n_submitted_jobs running
             try:
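
For illustration, a minimal, self-contained sketch (hypothetical helper names, not pyDVL's API) of the data shapes this patch establishes: the MSR processor now returns one list of (index, marginal) pairs per element of the batch, and the consuming loop in compute_generic_semivalues must accept both that batched shape and the flat tuple of pairs produced when the batch size is 1. The marginal values below are dummies; the real processor derives them from its running value estimates.

from typing import Dict, List, Sequence, Tuple, Union

MarginalT = Tuple[int, float]  # (data index, marginal utility)


def make_batched_marginals(
    future_result: Sequence[Tuple[List[int], float]], n: int
) -> List[List[MarginalT]]:
    """Mimic the patched processor: one list of marginals per batch element."""
    marginals: List[List[MarginalT]] = []
    for batch_id, (_subset, evaluation) in enumerate(future_result):
        marginals.append([])
        for data_index in range(n):
            # Dummy marginal value; the real computation uses running estimates.
            marginals[batch_id].append((data_index, evaluation / n))
    return marginals


def consume(
    processed_future: Union[Tuple[MarginalT, ...], List[List[MarginalT]]]
) -> Dict[int, float]:
    """Mirror the branching added to the consuming loop: flatten either shape."""
    totals: Dict[int, float] = {}
    for batch_future in processed_future:
        if isinstance(batch_future, list):  # batch size > 1: a list of marginals
            for idx, marginal_val in batch_future:
                totals[idx] = totals.get(idx, 0.0) + marginal_val
        else:  # batch size 1: a single (index, marginal) tuple
            idx, marginal_val = batch_future
            totals[idx] = totals.get(idx, 0.0) + marginal_val
    return totals


if __name__ == "__main__":
    # A batch of two samples over a dataset of three points.
    batched = make_batched_marginals([([0, 1], 0.9), ([2], 0.3)], n=3)
    print(consume(batched))               # batched shape (list of lists)
    print(consume(((0, 0.5), (1, 0.2))))  # unbatched shape (tuple of tuples)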