diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 71d6fa90..f24ec159 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -161,13 +161,28 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( counter = 0 while True: + # PART 1: Consume as many batches all at once for every worker and their respective current chunk + if num_of_samples_to_carry_to_next_chunk is None: + sizes = [size for size in workers_interval_sizes_for_this_rank if len(size)] + min_interval_size = min(size[0] for size in sizes) + # -1 here because we need the logic in PART 2 to .pop() the list for the last batch + num_batches = (min_interval_size // batch_size) - 1 + num_batches = max(num_batches, 0) + for i in range(len(workers_interval_sizes_for_this_rank)): + if workers_interval_sizes_for_this_rank[i]: + workers_interval_sizes_for_this_rank[i][0] -= num_batches * batch_size + worker_tracker_idx += num_batches * len(sizes) + counter += num_batches * batch_size * len(sizes) + interval_size_of_current_worker = workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers] if len(interval_size_of_current_worker) == 0: worker_tracker_idx += 1 continue + # PART 2: We have leftover samples to consume + # We consume them one by one because we're at the end of a chunk and may have to handle + # a remainder from the previous iteration num_samples_left_for_this_worker_chunk = interval_size_of_current_worker[0] - # To consume a batch, we want to subtract `batch_size` from the size we have left, # unless we had a remainder (< batch size) from the previous iteration/chunk remover = ( diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 4bf4fd9c..70192da2 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -222,6 +222,15 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {1: [1]} + # world size = 1, 2 workers sharing one chunk, different sizes with remainders to next chunk + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=25, + workers_chunks=[[0, 1], [1, 2]], + workers_intervals=[[(0, 0, 70, 100), (0, 0, 55, 100)], [(0, 0, 105, 50), (0, 0, 55, 100)]], + ) + assert chunks_to_disable == {1: [0]} + # world size = 1, 4 workers sharing one chunk chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, @@ -241,7 +250,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): assert chunks_to_disable == {0: [0, 1, 3]} # world size 2, 2 workers per rank, varying batch size - for batch_size in range(1, 6): + for batch_size in range(1, 7): chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=2, batch_size=batch_size,