diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 92def8f9..fb2dbc47 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -161,12 +161,14 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( counter = 0 while True: - - # TODO: Add comment + + # PART 1: Consume as many batches all at once for every worker and their respective current chunk if num_of_samples_to_carry_to_next_chunk is None: sizes = [size for size in workers_interval_sizes_for_this_rank if len(size)] min_interval_size = min(size[0] for size in sizes) - num_batches = max(0, (min_interval_size // batch_size) - 1) + # -1 here because we need the logic in PART 2 to .pop() the list for the last batch + num_batches = (min_interval_size // batch_size) - 1 + num_batches = max(num_batches, 0) for i in range(len(workers_interval_sizes_for_this_rank)): if workers_interval_sizes_for_this_rank[i]: workers_interval_sizes_for_this_rank[i][0] -= num_batches * batch_size @@ -178,8 +180,10 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( worker_tracker_idx += 1 continue + # PART 2: We have leftover samples to consume + # We consume them one by one because we're at the end of a chunk and may have to handle + # a remainder from the previous iteration num_samples_left_for_this_worker_chunk = interval_size_of_current_worker[0] - # To consume a batch, we want to subtract `batch_size` from the size we have left, # unless we had a remainder (< batch size) from the previous iteration/chunk remover = (