From de7914ea43453b3d860f0b1418c16fb662658105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 6 Aug 2024 17:22:10 +0000 Subject: [PATCH 1/4] speed up skip --- src/litdata/utilities/shuffle.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 71d6fa90..92def8f9 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -161,6 +161,18 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( counter = 0 while True: + + # TODO: Add comment + if num_of_samples_to_carry_to_next_chunk is None: + sizes = [size for size in workers_interval_sizes_for_this_rank if len(size)] + min_interval_size = min(size[0] for size in sizes) + num_batches = max(0, (min_interval_size // batch_size) - 1) + for i in range(len(workers_interval_sizes_for_this_rank)): + if workers_interval_sizes_for_this_rank[i]: + workers_interval_sizes_for_this_rank[i][0] -= num_batches * batch_size + worker_tracker_idx += num_batches * len(sizes) + counter += num_batches * batch_size * len(sizes) + interval_size_of_current_worker = workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers] if len(interval_size_of_current_worker) == 0: worker_tracker_idx += 1 From 8b60be20fd2455e75f31006b4ea048bb34624341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 6 Aug 2024 19:09:03 +0000 Subject: [PATCH 2/4] add comments --- src/litdata/utilities/shuffle.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 92def8f9..fb2dbc47 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -161,12 +161,14 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( counter = 0 while True: - - # TODO: Add comment + + # PART 1: Consume as many batches all at once for every worker and their respective current chunk if num_of_samples_to_carry_to_next_chunk is None: sizes = [size for size in workers_interval_sizes_for_this_rank if len(size)] min_interval_size = min(size[0] for size in sizes) - num_batches = max(0, (min_interval_size // batch_size) - 1) + # -1 here because we need the logic in PART 2 to .pop() the list for the last batch + num_batches = (min_interval_size // batch_size) - 1 + num_batches = max(num_batches, 0) for i in range(len(workers_interval_sizes_for_this_rank)): if workers_interval_sizes_for_this_rank[i]: workers_interval_sizes_for_this_rank[i][0] -= num_batches * batch_size @@ -178,8 +180,10 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( worker_tracker_idx += 1 continue + # PART 2: We have leftover samples to consume + # We consume them one by one because we're at the end of a chunk and may have to handle + # a remainder from the previous iteration num_samples_left_for_this_worker_chunk = interval_size_of_current_worker[0] - # To consume a batch, we want to subtract `batch_size` from the size we have left, # unless we had a remainder (< batch size) from the previous iteration/chunk remover = ( From 6db1f00dda4431beb61d9f00082381c1357bf46e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Aug 2024 00:31:13 +0000 Subject: [PATCH 3/4] add test --- tests/utilities/test_shuffle.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 4bf4fd9c..70192da2 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -222,6 +222,15 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {1: [1]} + # world size = 1, 2 workers sharing one chunk, different sizes with remainders to next chunk + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=25, + workers_chunks=[[0, 1], [1, 2]], + workers_intervals=[[(0, 0, 70, 100), (0, 0, 55, 100)], [(0, 0, 105, 50), (0, 0, 55, 100)]], + ) + assert chunks_to_disable == {1: [0]} + # world size = 1, 4 workers sharing one chunk chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, @@ -241,7 +250,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): assert chunks_to_disable == {0: [0, 1, 3]} # world size 2, 2 workers per rank, varying batch size - for batch_size in range(1, 6): + for batch_size in range(1, 7): chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=2, batch_size=batch_size, From 1ebd857a4638aaeb50f99848bd1bf3e995fd1aa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 00:31:32 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/shuffle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index fb2dbc47..f24ec159 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -161,7 +161,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( counter = 0 while True: - # PART 1: Consume as many batches all at once for every worker and their respective current chunk if num_of_samples_to_carry_to_next_chunk is None: sizes = [size for size in workers_interval_sizes_for_this_rank if len(size)]