Skip to content

Commit

Permalink
fix batch size update unintentional change
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobkruse1 committed Mar 27, 2024
1 parent bfe5637 commit 4542263
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def compute_values(self) -> np.ndarray:

def __call__(
self, future_result: List[Tuple[List[IndexT], float]]
) -> Tuple[MarginalT, ...]:
) -> List[List[MarginalT, ...]]:
"""Computation of marginal utility using Maximum Sample Reuse.
This processor requires the Marginal Function to be set to RawUtility.
Expand All @@ -259,8 +259,8 @@ def __call__(
A collection of marginals. Each marginal is a tuple with index and its marginal
utility.
"""
marginals: List[MarginalT] = []
for s, evaluation in future_result:
marginals: List[List[MarginalT]] = []
for batch_id, (s, evaluation) in enumerate(future_result):
previous_values = self.compute_values()
self.total_evaluations += 1
self.point_in_subset[s] += 1
Expand All @@ -273,9 +273,12 @@ def __call__(
self.total_evaluations * new_values
- (self.total_evaluations - 1) * previous_values
)
marginals.append([])
for data_index in range(self.n):
marginals.append((data_index, float(marginal_vals[data_index])))
return tuple(marginals)
marginals[batch_id].append(
(data_index, float(marginal_vals[data_index]))
)
return marginals


# @deprecated(
Expand Down Expand Up @@ -373,11 +376,18 @@ def compute_generic_semivalues(

completed, pending = wait(pending, timeout=1, return_when=FIRST_COMPLETED)
for future in completed:
processed_future = future_processor(future.result())
for idx, marginal_val in processed_future:
result.update(idx, marginal_val)
if done(result):
return result
processed_future = future_processor(
future.result()
) # List of tuples or
for batch_future in processed_future:
if isinstance(batch_future, list): # Case when batch size is > 1
for idx, marginal_val in batch_future:
result.update(idx, marginal_val)
else: # Batch size 1
idx, marginal_val = batch_future
result.update(idx, marginal_val)
if done(result):
return result

# Ensure that we always have n_submitted_jobs running
try:
Expand Down

0 comments on commit 4542263

Please sign in to comment.