From f7491929ac1f7b0f1320aee8a6055e941177ab70 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 16 Jul 2024 08:46:17 +0100 Subject: [PATCH 01/63] update --- src/litdata/streaming/combined.py | 11 +++++++++++ src/litdata/streaming/dataloader.py | 3 +++ src/litdata/streaming/dataset.py | 8 +++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index 4741b75d..8f789949 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -118,6 +118,17 @@ def set_shuffle(self, shuffle: bool) -> None: for dataset in self._datasets: dataset.set_shuffle(shuffle) + def set_batch_size(self, batch_size: int) -> None: + """Set the current batch size to the datasets.""" + self.batch_size = batch_size + for dataset in self._datasets: + dataset.set_batch_size(batch_size) + + def set_num_workers(self, num_workers: int) -> None: + """Set the current number of workers to the datasets.""" + for dataset in self._datasets: + dataset.set_num_workers(num_workers) + def set_drop_last(self, drop_last: bool) -> None: """Set the current drop_last to the datasets.""" for dataset in self._datasets: diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 05286ff1..50ee71c1 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -571,6 +571,9 @@ def __init__( if drop_last is not None: dataset.set_drop_last(drop_last) + dataset.set_batch_size(batch_size) + dataset.set_num_workers(num_workers) + shuffle = None if profile_batches and not _VIZ_TRACKER_AVAILABLE: diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index e0c82087..1fe1ff1e 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -174,7 +174,13 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last) def __len__(self) -> int: - return self.get_len(1, 1) + return self.get_len(self.num_workers if self.num_workers else 1, self.batch_size if self.batch_size else 1) + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size + + def set_num_workers(self, num_workers: int): + self.num_workers = num_workers def get_len(self, num_workers: int, batch_size: int) -> int: self.num_workers = num_workers From ed18cfee4a63e07beda149c72e7a35143c600aa3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 16 Jul 2024 08:53:29 +0100 Subject: [PATCH 02/63] update --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 92d1068b..f41bc7ce 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -807,7 +807,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir): return StreamingDataLoader(dataset, batch_size=2, num_workers=1) -@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on windows and MacOs") +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): From c77821bf6c861aeceaad0e7b03cecd6d4b3d79ff Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 16 Jul 2024 11:29:13 +0100 Subject: [PATCH 03/63] update --- tests/streaming/test_combined.py | 6 ++++++ 1 file changed, 6 insertions(+) 
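The first two commits wire `batch_size` and `num_workers` from `StreamingDataLoader` down into the datasets so that `__len__` no longer assumes a single worker and a batch size of one. The sketch below is a toy illustration of that pattern only; it is not litdata's code, and the class names and the drop-last rounding inside `get_len` are invented for the example.

```python
# Toy sketch (assumed names, not litdata's implementation) of the pattern set up
# by the first commits: the loader forwards batch_size/num_workers to the dataset
# so len(dataset) can be computed per worker instead of assuming a single worker.
class ToyStreamingDataset:
    def __init__(self, num_items: int) -> None:
        self.num_items = num_items
        self.batch_size = 1
        self.num_workers = 1

    def set_batch_size(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def set_num_workers(self, num_workers: int) -> None:
        self.num_workers = num_workers

    def get_len(self, num_workers: int, batch_size: int) -> int:
        # each worker's share, rounded down to whole batches (drop_last-style behaviour)
        per_worker = self.num_items // max(num_workers, 1)
        return num_workers * batch_size * (per_worker // max(batch_size, 1))

    def __len__(self) -> int:
        return self.get_len(self.num_workers or 1, self.batch_size or 1)


class ToyStreamingDataLoader:
    def __init__(self, dataset: ToyStreamingDataset, batch_size: int, num_workers: int) -> None:
        # mirrors the dataloader change above: push the loader settings into the dataset
        dataset.set_batch_size(batch_size)
        dataset.set_num_workers(num_workers)
        self.dataset = dataset


loader = ToyStreamingDataLoader(ToyStreamingDataset(103), batch_size=4, num_workers=2)
print(len(loader.dataset))  # 96: 2 workers x 12 full batches of 4, instead of 103
```

Once the values are forwarded, `len(dataset)` can account for how items are split across workers and batches, which is what the `get_len(num_workers, batch_size)` call in the dataset relies on.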
diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 5630d042..8871af9c 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -227,6 +227,12 @@ def set_shuffle(self, _): def set_drop_last(self, _): pass + def set_batch_size(self, _): + pass + + def set_num_workers(self, _): + pass + def test_combined_dataset(): dataset1 = SimpleDataset(0, 10) From 27322028b515d1f4f29a6f1917b27e4602880a7f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 16 Jul 2024 11:32:52 +0100 Subject: [PATCH 04/63] update --- src/litdata/streaming/dataset.py | 4 ++-- tests/streaming/test_combined.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 1fe1ff1e..c7796c34 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -176,10 +176,10 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: def __len__(self) -> int: return self.get_len(self.num_workers if self.num_workers else 1, self.batch_size if self.batch_size else 1) - def set_batch_size(self, batch_size: int): + def set_batch_size(self, batch_size: int) -> None: self.batch_size = batch_size - def set_num_workers(self, num_workers: int): + def set_num_workers(self, num_workers: int) -> None: self.num_workers = num_workers def get_len(self, num_workers: int, batch_size: int) -> int: diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 8871af9c..66bb3a4d 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -93,9 +93,16 @@ def test_drop_last_and_shuffle(): dataset_mock_1.set_shuffle.assert_called() dataset_mock_2.set_shuffle.assert_called() + dataset_mock_1.set_drop_last.assert_called() dataset_mock_2.set_drop_last.assert_called() + dataset_mock_1.set_num_workers.assert_called() + dataset_mock_2.set_num_workers.assert_called() + + dataset_mock_1.set_batch_size.assert_called() + dataset_mock_2.set_batch_size.assert_called() + class TestStatefulDataset: def __init__(self, size, step): From a8dd5765506edefd15b7510de8bdf78900a628e0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 16 Jul 2024 11:41:34 +0100 Subject: [PATCH 05/63] update --- tests/streaming/test_dataloader.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 5690e145..af77ecb8 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -45,6 +45,12 @@ def set_epoch(self, current_epoch): def set_drop_last(self, drop_last): self.drop_last = drop_last + def set_batch_size(self, batch_size): + self.batch_size = batch_size + + def set_num_workers(self, num_workers): + self.num_workers = num_workers + class TestCombinedStreamingDataset(CombinedStreamingDataset): def _check_datasets(self, datasets) -> None: From 9f696903cf29bbc4dd1b3e36251fc13db8a5bd53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 08:31:56 -0400 Subject: [PATCH 06/63] fix with thomas --- src/litdata/streaming/dataset.py | 13 ++++--- src/litdata/streaming/shuffle.py | 6 +-- src/litdata/utilities/shuffle.py | 66 ++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index c7796c34..c0e1de1f 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -206,11 +206,14 @@ def __iter__(self) -> 
"StreamingDataset": state: Dict[str, Any] = self._state_dict self.current_epoch = state["current_epoch"] - chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks( + workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_ranks( self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch ) - chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] - intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] + worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + # assert worker_rank <= 63 + print(f"{worker_rank=}, len: {len(workers_chunks)}") + worker_chunks = workers_chunks[worker_rank] + worker_intervals = workers_intervals[worker_rank] # Handle restart if self._state_dict: @@ -233,8 +236,8 @@ def __iter__(self) -> "StreamingDataset": intervals_per_replica[self.distributed_env.global_rank], ) - self.worker_chunks = workers_chunks[self.worker_env.rank] - self.worker_intervals = workers_intervals[self.worker_env.rank] + self.worker_chunks = worker_chunks + self.worker_intervals = worker_intervals self.num_chunks = len(self.worker_chunks) self.current_indexes = [] diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 2f28ca07..970df468 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -19,7 +19,7 @@ from litdata.streaming import Cache from litdata.utilities.env import _DistributedEnv -from litdata.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle +from litdata.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle, _associate_chunks_and_internals_to_workers class Shuffle(ABC): @@ -120,7 +120,7 @@ def get_chunks_and_intervals_per_ranks( shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() # 3. 
Compute the items budget of each rank - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) @@ -133,7 +133,7 @@ def get_chunks_and_intervals_per_ranks( shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index aa8b519b..dd43c3ae 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -41,6 +41,72 @@ def _intra_node_chunk_shuffle( return [index for chunks in chunk_indexes_per_nodes for index in chunks] +def _associate_chunks_and_internals_to_workers( + distributed_env: _DistributedEnv, + indexes: Any, + chunk_intervals: List[Interval], + drop_last: bool, + num_workers: int = 1, + batch_size: int = 1, +) -> Tuple[List[List[int]], List[Any]]: + + num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) + print(f"{num_items=}") + world_size = distributed_env.world_size * num_workers + print("WORLD_SIZE=", world_size) + num_items_per_workers: List[int] = [ + num_items // world_size + num_items % world_size + if rank == world_size - 1 and not drop_last + else num_items // world_size + for rank in range(world_size) + ] + if drop_last: + ratio = batch_size + num_items_per_workers = [ratio * int(item // ratio) for item in num_items_per_workers] + + print(f"{num_items_per_workers=}") + chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)] + intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)] + + # 4. 
Assign the chunk & intervals to each rank + for chunk_index, chunk_interval in zip(indexes, chunk_intervals): + rank = 0 + + while True: + if rank == len(num_items_per_workers): + break + + items_left_to_assign = num_items_per_workers[rank] + + if items_left_to_assign == 0: + rank += 1 + continue + + items_in_chunk = chunk_interval[2] - chunk_interval[1] + + if items_in_chunk == 0: + break + + if items_in_chunk > items_left_to_assign: + chunks_per_workers[rank].append(chunk_index) + + chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = chunk_interval + + intervals_per_workers[rank].append( + [chunk_start, chunk_roi_start, chunk_roi_start + items_left_to_assign, chunk_end] + ) + chunk_interval = Interval(chunk_start, chunk_roi_start + items_left_to_assign, chunk_roi_end, chunk_end) + num_items_per_workers[rank] = 0 + rank += 1 + else: + chunks_per_workers[rank].append(chunk_index) + intervals_per_workers[rank].append(list(chunk_interval)) + num_items_per_workers[rank] -= items_in_chunk + break + + # print(drop_last, batch_size, num_workers, [sum(interval[2] - interval[1] for interval in intervals) for intervals in intervals_per_workers]) + return chunks_per_workers, intervals_per_workers + def _associate_chunks_and_internals_to_ranks( distributed_env: _DistributedEnv, From 34a9d749b8de65aa746ca728159bb2de2ca62eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 08:38:37 -0400 Subject: [PATCH 07/63] stop length --- src/litdata/streaming/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index c0e1de1f..152d5c07 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -245,6 +245,7 @@ def __iter__(self) -> "StreamingDataset": self.global_index = 0 self.index = 0 + self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) self.has_triggered_download = False self.last_time = time() @@ -309,7 +310,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent to create more batch on a given process - if self.global_index >= len(self): + if self.global_index >= self.stop_length: self.current_epoch += 1 raise StopIteration From a80e430ff0b4eff5c0e295e6fd401a886f421e5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 09:34:14 -0400 Subject: [PATCH 08/63] remove redundant drop_last code --- src/litdata/streaming/shuffle.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 970df468..55d11182 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -35,16 +35,6 @@ def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size _, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks( distributed_env, num_workers, batch_size, current_epoch ) - - if self.drop_last: - items_per_process = [ - sum((interval[2] - interval[1]) for interval in intervals) for intervals in intervals_per_ranks - ] - # Validate each processes gets the exact number of elements - if len(items_per_process) > 1: - assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1]) - return items_per_process[0] - return sum((interval[2] - interval[1]) for interval in intervals_per_ranks[distributed_env.global_rank]) @abstractmethod From f38e8ff4d582a1b3535ed6615c91f90a3b9e2108 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 09:40:44 -0400 Subject: [PATCH 09/63] debug resume --- src/litdata/streaming/dataset.py | 50 ++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 152d5c07..9e0e6dc7 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -209,32 +209,34 @@ def __iter__(self) -> "StreamingDataset": workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_ranks( self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch ) + worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank - # assert worker_rank <= 63 - print(f"{worker_rank=}, len: {len(workers_chunks)}") + # print(f"{worker_rank=}, len: {len(workers_chunks)}") worker_chunks = workers_chunks[worker_rank] worker_intervals = workers_intervals[worker_rank] # Handle restart if self._state_dict: - self._resume(chunks_replica, intervals_replica) + # breakpoint() + self._resume(workers_chunks, workers_intervals) else: # Find the chunks shared across multiple ranks. # For each shared chunk, find the rank to use the chunk last and prevent deletion # for the other ranks. - chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion( - self.worker_env.world_size, chunks_per_replica, intervals_per_replica - ) - if self.distributed_env.global_rank in chunks_indexes_skip_deletion: - self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[ - self.distributed_env.global_rank - ] - - workers_chunks, workers_intervals = _associate_chunks_to_workers( - self.worker_env, - chunks_per_replica[self.distributed_env.global_rank], - intervals_per_replica[self.distributed_env.global_rank], - ) + # TODO + # chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion( + # self.worker_env.world_size, worker_chunks, worker_intervals + # ) + # if self.distributed_env.global_rank in chunks_indexes_skip_deletion: + # self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[ + # self.distributed_env.global_rank + # ] + + # workers_chunks, workers_intervals = _associate_chunks_to_workers( + # self.worker_env, + # chunks_per_replica[self.distributed_env.global_rank], + # intervals_per_replica[self.distributed_env.global_rank], + # ) self.worker_chunks = worker_chunks self.worker_intervals = worker_intervals @@ -251,7 +253,7 @@ def __iter__(self) -> "StreamingDataset": return self - def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> None: + def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> None: assert self._state_dict assert self.worker_env assert self.shuffler @@ -265,14 +267,15 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No num_samples_yielded = self._state_dict["num_samples_yielded"] # replay sampling from each worker / chunks using the batch size - workers_chunks, workers_intervals = _associate_chunks_to_workers( - self.worker_env, chunks_replica, intervals_replica - ) + # workers_chunks, workers_intervals = _associate_chunks_to_workers( + # self.worker_env, chunks_replica, intervals_replica + # ) indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) - chunks_index, indexes = _replay_chunks_sampling(workers_intervals, indexes) + # TODO: Change 
_replay_chunks_sampling to accept a list + chunks_index, indexes = _replay_chunks_sampling({i: workers_intervals[i] for i in range(len(workers_intervals))}, indexes) # select the chunks and intervals associated to this worker - worker_rank = self.worker_env.rank + worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank self.num_chunks = len(workers_intervals[worker_rank]) self.chunk_index = chunks_index[worker_rank] self.worker_chunks = workers_chunks[worker_rank] @@ -310,10 +313,12 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent to create more batch on a given process + # print(torch.distributed.get_rank(), self.global_index, len(self), self.stop_length) if self.global_index >= self.stop_length: self.current_epoch += 1 raise StopIteration + # print(f"{self.num_chunks=}") # Lazily re-populate the interval to reduce memory usage. if len(self.current_indexes) == 0: if self.chunk_index == self.num_chunks: @@ -494,6 +499,7 @@ def _associate_chunks_to_workers( workers_chunks[worker_idx] = worker_chunks workers_intervals[worker_idx] = worker_intervals + print("associate", [sum(interval[2] - interval[1] for interval in intervals) for worker_id, intervals in workers_intervals.items()]) return workers_chunks, workers_intervals From 6b7578a1c21d6f286b16435dae6f2b0a9fd120e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 13:41:24 +0000 Subject: [PATCH 10/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataset.py | 13 ++++++++++--- src/litdata/streaming/shuffle.py | 6 +++++- src/litdata/utilities/shuffle.py | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 9e0e6dc7..0e8a9a56 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -31,7 +31,6 @@ from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv -from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion logger = Logger(__name__) @@ -272,7 +271,9 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No # ) indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) # TODO: Change _replay_chunks_sampling to accept a list - chunks_index, indexes = _replay_chunks_sampling({i: workers_intervals[i] for i in range(len(workers_intervals))}, indexes) + chunks_index, indexes = _replay_chunks_sampling( + {i: workers_intervals[i] for i in range(len(workers_intervals))}, indexes + ) # select the chunks and intervals associated to this worker worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank @@ -499,7 +500,13 @@ def _associate_chunks_to_workers( workers_chunks[worker_idx] = worker_chunks workers_intervals[worker_idx] = worker_intervals - print("associate", [sum(interval[2] - interval[1] for interval in intervals) for worker_id, intervals in workers_intervals.items()]) + print( + "associate", + [ + sum(interval[2] - interval[1] for interval in intervals) + for worker_id, intervals in workers_intervals.items() + ], + ) return workers_chunks, 
workers_intervals diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 55d11182..b64867e5 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -19,7 +19,11 @@ from litdata.streaming import Cache from litdata.utilities.env import _DistributedEnv -from litdata.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle, _associate_chunks_and_internals_to_workers +from litdata.utilities.shuffle import ( + _associate_chunks_and_internals_to_ranks, + _associate_chunks_and_internals_to_workers, + _intra_node_chunk_shuffle, +) class Shuffle(ABC): diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index dd43c3ae..25cd532f 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -41,6 +41,7 @@ def _intra_node_chunk_shuffle( return [index for chunks in chunk_indexes_per_nodes for index in chunks] + def _associate_chunks_and_internals_to_workers( distributed_env: _DistributedEnv, indexes: Any, @@ -49,7 +50,6 @@ def _associate_chunks_and_internals_to_workers( num_workers: int = 1, batch_size: int = 1, ) -> Tuple[List[List[int]], List[Any]]: - num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) print(f"{num_items=}") world_size = distributed_env.world_size * num_workers From 062368098f7c5c530eb28c03b230f65882924322 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 10:23:00 -0400 Subject: [PATCH 11/63] update resuming logic --- src/litdata/streaming/dataset.py | 15 ++++++++------- src/litdata/streaming/shuffle.py | 2 +- tests/streaming/test_dataset.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 0e8a9a56..cc3a2197 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -481,6 +481,7 @@ def is_integer(value: str) -> bool: return False +# TODO: remove def _associate_chunks_to_workers( worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any] ) -> Any: @@ -500,13 +501,13 @@ def _associate_chunks_to_workers( workers_chunks[worker_idx] = worker_chunks workers_intervals[worker_idx] = worker_intervals - print( - "associate", - [ - sum(interval[2] - interval[1] for interval in intervals) - for worker_id, intervals in workers_intervals.items() - ], - ) + # print( + # "associate", + # [ + # sum(interval[2] - interval[1] for interval in intervals) + # for worker_id, intervals in workers_intervals.items() + # ], + # ) return workers_chunks, workers_intervals diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index b64867e5..a725cfa6 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -65,7 +65,7 @@ def get_chunks_and_intervals_per_ranks( indexes = range(len(chunk_intervals)) # 2. 
Compute the items budget of each rank - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size ) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index f41bc7ce..0ad010eb 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -799,18 +799,20 @@ def _simple_preprocess(_): yield torch.randint(0, 100, size=(10,), dtype=torch.int64) -def _get_simulated_s3_dataloader(cache_dir, data_dir): +def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): dataset = EmulateS3StreamingDataset( input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size=10), + shuffle=shuffle, ) - return StreamingDataLoader(dataset, batch_size=2, num_workers=1) + return StreamingDataLoader(dataset, batch_size=2, num_workers=2) @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) -def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): +@pytest.mark.parametrize("shuffle", [True, False]) +def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") @@ -828,10 +830,11 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): assert len(os.listdir(tmpdir / "optimized")) > 0 os.mkdir(s3_cache_dir) - train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) batches_to_fetch = 16 batch_to_resume_from = None dataloader_state = None + # assert len(train_dataloader) == 5 # TODO: This length is wrong for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: dataloader_state = train_dataloader.state_dict() @@ -841,7 +844,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) - train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) assert dataloader_state is not None assert batch_to_resume_from is not None train_dataloader.load_state_dict(dataloader_state) From 265c4e9063b7917950e574e4bf807d6420858eb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 10:54:49 -0400 Subject: [PATCH 12/63] update --- src/litdata/streaming/dataset.py | 9 ++++----- tests/streaming/test_dataset.py | 16 +++++++++++++--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index cc3a2197..8b5997b7 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -246,7 +246,6 @@ def __iter__(self) -> "StreamingDataset": self.global_index = 0 self.index = 0 - self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) self.has_triggered_download = False self.last_time = time() @@ -266,9 +265,6 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No num_samples_yielded = self._state_dict["num_samples_yielded"] # replay sampling from each worker / chunks using the batch size - # workers_chunks, workers_intervals = 
_associate_chunks_to_workers( - # self.worker_env, chunks_replica, intervals_replica - # ) indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) # TODO: Change _replay_chunks_sampling to accept a list chunks_index, indexes = _replay_chunks_sampling( @@ -315,7 +311,10 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent to create more batch on a given process # print(torch.distributed.get_rank(), self.global_index, len(self), self.stop_length) - if self.global_index >= self.stop_length: + stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) + print(f"{self.global_index=}, {stop_length=}") + # TODO: This is stopping too early, length is not correct + if self.global_index >= stop_length: self.current_epoch += 1 raise StopIteration diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 0ad010eb..981576ba 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -811,7 +811,10 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) -@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("shuffle", [ + # True, + False, +]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" @@ -822,7 +825,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): optimize( fn=_simple_preprocess, - inputs=list(range(5)), + inputs=list(range(8)), output_dir=str(tmpdir / "optimized"), chunk_size=190, num_workers=4, @@ -834,16 +837,23 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): batches_to_fetch = 16 batch_to_resume_from = None dataloader_state = None - # assert len(train_dataloader) == 5 # TODO: This length is wrong + # 8 * 100 tokens = 800 tokens + # 800 / 10 = 80 blocks + # batch size 2: 80 / 2 = 40 batches + # assert len(train_dataloader.dataset) == 80 + # assert len(train_dataloader) == 40 + for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: dataloader_state = train_dataloader.state_dict() + print("saved") if i == batches_to_fetch + 1: batch_to_resume_from = batch break shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) + print("resume") train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) assert dataloader_state is not None assert batch_to_resume_from is not None From c9ecec7eff714bb2f10e6713998a572b613161f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 12:26:19 -0400 Subject: [PATCH 13/63] length and resume fixes --- src/litdata/streaming/dataset.py | 35 ++++++++++++++++++++------------ src/litdata/streaming/shuffle.py | 6 ++++-- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 8b5997b7..8c74d9bb 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -29,6 +29,7 @@ from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle +from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion from 
litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv @@ -222,21 +223,18 @@ def __iter__(self) -> "StreamingDataset": # Find the chunks shared across multiple ranks. # For each shared chunk, find the rank to use the chunk last and prevent deletion # for the other ranks. - # TODO + # TODO better name for worker_start end + # TODO: reimplement this logic + # worker_start = self.distributed_env.global_rank * self.num_workers + # worker_end = worker_start + self.num_workers # chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion( - # self.worker_env.world_size, worker_chunks, worker_intervals + # self.worker_env.world_size, workers_chunks[worker_start: worker_end], workers_intervals[worker_start: worker_end] # ) # if self.distributed_env.global_rank in chunks_indexes_skip_deletion: # self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[ # self.distributed_env.global_rank # ] - # workers_chunks, workers_intervals = _associate_chunks_to_workers( - # self.worker_env, - # chunks_per_replica[self.distributed_env.global_rank], - # intervals_per_replica[self.distributed_env.global_rank], - # ) - self.worker_chunks = worker_chunks self.worker_intervals = worker_intervals @@ -263,18 +261,27 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No # TODO: Implement elastic sampling where the number of workers, ranks can change. num_samples_yielded = self._state_dict["num_samples_yielded"] + + worker_start = self.distributed_env.global_rank * num_workers + worker_end = worker_start + num_workers # replay sampling from each worker / chunks using the batch size indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) + + # print(f"indexes1 = {indexes}") + # TODO: Change _replay_chunks_sampling to accept a list chunks_index, indexes = _replay_chunks_sampling( - {i: workers_intervals[i] for i in range(len(workers_intervals))}, indexes + {i: workers_intervals[i] for i in range(worker_start, worker_end)}, indexes ) + # print(f"{indexes=}, {chunks_index=}") # select the chunks and intervals associated to this worker worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + worker_local_rank = self.worker_env.rank + self.num_chunks = len(workers_intervals[worker_rank]) - self.chunk_index = chunks_index[worker_rank] + self.chunk_index = chunks_index[worker_local_rank] self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] @@ -286,10 +293,11 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index) # skip any indexes already consumed - current_indexes = current_indexes[indexes[worker_rank] :] + current_indexes = current_indexes[indexes[worker_local_rank] :] self.current_indexes = current_indexes - self.global_index = num_samples_yielded + # print(f"currentindexes = {current_indexes}") + self.global_index = indexes[worker_local_rank] # bump the chunk_index self.chunk_index += 1 @@ -311,8 +319,9 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent to create more batch on a given process # print(torch.distributed.get_rank(), self.global_index, len(self), self.stop_length) + worker_rank = 
self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) - print(f"{self.global_index=}, {stop_length=}") + # print(f"{worker_rank}, {self.global_index=}, {stop_length=}") # TODO: This is stopping too early, length is not correct if self.global_index >= stop_length: self.current_epoch += 1 diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index a725cfa6..fe1623d8 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -36,10 +36,12 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool): @lru_cache(maxsize=10) def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int: - _, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks( + _, workers_intervals = self.get_chunks_and_intervals_per_ranks( # TODO: rename distributed_env, num_workers, batch_size, current_epoch ) - return sum((interval[2] - interval[1]) for interval in intervals_per_ranks[distributed_env.global_rank]) + worker_start = distributed_env.global_rank * num_workers + worker_end = worker_start + num_workers + return sum((interval[2] - interval[1]) for intervals in workers_intervals[worker_start:worker_end] for interval in intervals) @abstractmethod def get_chunks_and_intervals_per_ranks( From b0096c540774db2faef22606b6fc5d0d6750729a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 12:27:06 -0400 Subject: [PATCH 14/63] assert length in test --- tests/streaming/test_dataset.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 981576ba..45023ddb 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -812,7 +812,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) @pytest.mark.parametrize("shuffle", [ - # True, + True, False, ]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): @@ -837,11 +837,8 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): batches_to_fetch = 16 batch_to_resume_from = None dataloader_state = None - # 8 * 100 tokens = 800 tokens - # 800 / 10 = 80 blocks - # batch size 2: 80 / 2 = 40 batches - # assert len(train_dataloader.dataset) == 80 - # assert len(train_dataloader) == 40 + assert len(train_dataloader.dataset) == 80 + assert len(train_dataloader) == 40 for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: From 24653b2f4a414d79dc4ab7ca3153264fd6ee29c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 16:27:22 +0000 Subject: [PATCH 15/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataset.py | 11 +++++------ src/litdata/streaming/shuffle.py | 7 +++++-- tests/streaming/test_dataset.py | 17 ++++++++++------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 8c74d9bb..e58eb939 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -29,7 +29,6 @@ from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer from litdata.streaming.shuffle 
import FullShuffle, NoShuffle, Shuffle -from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv @@ -261,15 +260,15 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No # TODO: Implement elastic sampling where the number of workers, ranks can change. num_samples_yielded = self._state_dict["num_samples_yielded"] - + worker_start = self.distributed_env.global_rank * num_workers worker_end = worker_start + num_workers # replay sampling from each worker / chunks using the batch size indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) - + # print(f"indexes1 = {indexes}") - + # TODO: Change _replay_chunks_sampling to accept a list chunks_index, indexes = _replay_chunks_sampling( {i: workers_intervals[i] for i in range(worker_start, worker_end)}, indexes @@ -278,8 +277,8 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No # select the chunks and intervals associated to this worker worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank - worker_local_rank = self.worker_env.rank - + worker_local_rank = self.worker_env.rank + self.num_chunks = len(workers_intervals[worker_rank]) self.chunk_index = chunks_index[worker_local_rank] self.worker_chunks = workers_chunks[worker_rank] diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index fe1623d8..9cae9295 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -20,7 +20,6 @@ from litdata.streaming import Cache from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( - _associate_chunks_and_internals_to_ranks, _associate_chunks_and_internals_to_workers, _intra_node_chunk_shuffle, ) @@ -41,7 +40,11 @@ def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size ) worker_start = distributed_env.global_rank * num_workers worker_end = worker_start + num_workers - return sum((interval[2] - interval[1]) for intervals in workers_intervals[worker_start:worker_end] for interval in intervals) + return sum( + (interval[2] - interval[1]) + for intervals in workers_intervals[worker_start:worker_end] + for interval in intervals + ) @abstractmethod def get_chunks_and_intervals_per_ranks( diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 45023ddb..f6a960e7 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -811,10 +811,13 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) -@pytest.mark.parametrize("shuffle", [ - True, - False, -]) +@pytest.mark.parametrize( + "shuffle", + [ + True, + False, + ], +) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" @@ -837,9 +840,9 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): batches_to_fetch = 16 batch_to_resume_from = None dataloader_state = None - assert len(train_dataloader.dataset) == 80 - assert len(train_dataloader) == 40 - + assert 
len(train_dataloader.dataset) == 80 + assert len(train_dataloader) == 40 + for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: dataloader_state = train_dataloader.state_dict() From 173462912c72cedccf58bd5e844deceb31d3c6a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Jul 2024 22:08:49 +0000 Subject: [PATCH 16/63] update --- src/litdata/streaming/dataset.py | 2 +- src/litdata/streaming/shuffle.py | 9 +++-- src/litdata/utilities/shuffle.py | 61 -------------------------------- tests/streaming/test_dataset.py | 16 +++------ tests/utilities/test_shuffle.py | 10 +++--- 5 files changed, 15 insertions(+), 83 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index e58eb939..d18f5369 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -205,7 +205,7 @@ def __iter__(self) -> "StreamingDataset": state: Dict[str, Any] = self._state_dict self.current_epoch = state["current_epoch"] - workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_ranks( + workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers( self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch ) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 9cae9295..0b61c6ca 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -35,7 +35,7 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool): @lru_cache(maxsize=10) def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int: - _, workers_intervals = self.get_chunks_and_intervals_per_ranks( # TODO: rename + _, workers_intervals = self.get_chunks_and_intervals_per_workers( # TODO: rename distributed_env, num_workers, batch_size, current_epoch ) worker_start = distributed_env.global_rank * num_workers @@ -47,7 +47,7 @@ def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size ) @abstractmethod - def get_chunks_and_intervals_per_ranks( + def get_chunks_and_intervals_per_workers( self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int ) -> Any: pass @@ -62,7 +62,7 @@ class NoShuffle(Shuffle): is True.""" @lru_cache(maxsize=10) - def get_chunks_and_intervals_per_ranks( + def get_chunks_and_intervals_per_workers( self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int ) -> Any: # 1. Get the intervals @@ -73,7 +73,6 @@ def get_chunks_and_intervals_per_ranks( chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size ) - return chunks_per_ranks, intervals_per_ranks def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: @@ -99,7 +98,7 @@ class FullShuffle(Shuffle): """ @lru_cache(maxsize=10) - def get_chunks_and_intervals_per_ranks( + def get_chunks_and_intervals_per_workers( self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int ) -> Any: # 1. 
Get the intervals diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 25cd532f..9c7e251f 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -108,67 +108,6 @@ def _associate_chunks_and_internals_to_workers( return chunks_per_workers, intervals_per_workers -def _associate_chunks_and_internals_to_ranks( - distributed_env: _DistributedEnv, - indexes: Any, - chunk_intervals: List[Interval], - drop_last: bool, - num_workers: int = 1, - batch_size: int = 1, -) -> Tuple[List[List[int]], List[Any]]: - num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - num_items_per_ranks: List[int] = [ - num_items // distributed_env.world_size + num_items % distributed_env.world_size - if rank == distributed_env.world_size - 1 and not drop_last - else num_items // distributed_env.world_size - for rank in range(distributed_env.world_size) - ] - if drop_last: - ratio = num_workers * batch_size - num_items_per_ranks = [ratio * int(item // ratio) for item in num_items_per_ranks] - - chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)] - intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] - - # 4. Assign the chunk & intervals to each rank - for chunk_index, chunk_interval in zip(indexes, chunk_intervals): - rank = 0 - - while True: - if rank == len(num_items_per_ranks): - break - - items_left_to_assign = num_items_per_ranks[rank] - - if items_left_to_assign == 0: - rank += 1 - continue - - items_in_chunk = chunk_interval[2] - chunk_interval[1] - - if items_in_chunk == 0: - break - - if items_in_chunk > items_left_to_assign: - chunks_per_ranks[rank].append(chunk_index) - - chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = chunk_interval - - intervals_per_ranks[rank].append( - [chunk_start, chunk_roi_start, chunk_roi_start + items_left_to_assign, chunk_end] - ) - chunk_interval = Interval(chunk_start, chunk_roi_start + items_left_to_assign, chunk_roi_end, chunk_end) - num_items_per_ranks[rank] = 0 - rank += 1 - else: - chunks_per_ranks[rank].append(chunk_index) - intervals_per_ranks[rank].append(list(chunk_interval)) - num_items_per_ranks[rank] -= items_in_chunk - break - - return chunks_per_ranks, intervals_per_ranks - - def _find_chunks_per_ranks_on_which_to_skip_deletion( num_workers: int, chunks_per_ranks: List[List[int]], intervals_per_ranks: List[Any] ) -> Dict[int, List[int]]: diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index f6a960e7..68b8dbe6 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -158,7 +158,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression assert len(process_2_2) == 50 + int(not drop_last) - _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks( + _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_workers( dataset.distributed_env, 1, 1, dataset.current_epoch ) @@ -618,7 +618,9 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]] for batch_idx, batch in enumerate(dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + x = [batch[0][0].item(), batch[1][0].item()] + # breakpoint() + # assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] dataset.distributed_env = _DistributedEnv(2, 1, 1) dataloader = DataLoader(dataset, batch_size=2, 
shuffle=False) @@ -811,13 +813,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) -@pytest.mark.parametrize( - "shuffle", - [ - True, - False, - ], -) +@pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" @@ -846,14 +842,12 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: dataloader_state = train_dataloader.state_dict() - print("saved") if i == batches_to_fetch + 1: batch_to_resume_from = batch break shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) - print("resume") train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) assert dataloader_state is not None assert batch_to_resume_from is not None diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 2645f487..5c3c6709 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -1,7 +1,7 @@ from litdata.streaming.item_loader import Interval from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( - _associate_chunks_and_internals_to_ranks, + _associate_chunks_and_internals_to_workers, _find_chunks_per_ranks_on_which_to_skip_deletion, _intra_node_chunk_shuffle, ) @@ -21,7 +21,7 @@ def test_intra_node_chunk_shuffle(): assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] -def test_associate_chunks_and_internals_to_ranks(): +def test_associate_chunks_and_internals_to_workers(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] chunk_intervals = [ Interval(0, 0, 50, 50), @@ -34,7 +34,7 @@ def test_associate_chunks_and_internals_to_ranks(): Interval(0, 0, 50, 50), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, @@ -60,7 +60,7 @@ def test_associate_chunks_and_internals_to_ranks(): Interval(0, 0, 33, 33), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, @@ -91,7 +91,7 @@ def test_associate_chunks_and_internals_to_ranks(): Interval(0, 0, 1, 1), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, From c3edbb4f4405dffb0cadecff3c12950a1d5e5f19 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Jul 2024 00:15:15 +0200 Subject: [PATCH 17/63] rename variables --- src/litdata/streaming/shuffle.py | 14 ++++++------- src/litdata/utilities/shuffle.py | 1 + tests/streaming/test_dataset.py | 6 +++--- tests/utilities/test_shuffle.py | 36 ++++++++++++++++---------------- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 0b61c6ca..357fa985 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -70,10 +70,10 @@ def 
get_chunks_and_intervals_per_workers( indexes = range(len(chunk_intervals)) # 2. Compute the items budget of each rank - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size ) - return chunks_per_ranks, intervals_per_ranks + return workers_chunks, workers_intervals def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: return array.tolist() @@ -118,24 +118,24 @@ def get_chunks_and_intervals_per_workers( shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() # 3. Compute the items budget of each rank - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) # For the first epoch, no need of further shuffling if current_epoch == 1 or distributed_env.num_nodes == 1: - return chunks_per_ranks, intervals_per_ranks + return workers_chunks, workers_intervals # Perform shuffle within the nodes to avoid cache miss. # Note: It is possible for the overlapping chunks to change due to the changing order. - shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch) + shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) - return chunks_per_ranks, intervals_per_ranks + return workers_chunks, workers_intervals def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist() diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 9c7e251f..cbd9e662 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -19,6 +19,7 @@ from litdata.utilities.env import _DistributedEnv +# TODO: Logic needs to be updated? 
chunks_per_ranks -> workers_chunks def _intra_node_chunk_shuffle( distributed_env: _DistributedEnv, chunks_per_ranks: List[List[int]], diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 68b8dbe6..268a4ca2 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -158,7 +158,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression assert len(process_2_2) == 50 + int(not drop_last) - _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_workers( + _, workers_intervals = dataset.shuffler.get_chunks_and_intervals_per_workers( dataset.distributed_env, 1, 1, dataset.current_epoch ) @@ -167,7 +167,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression found_list = [] for i in process_1_1: found = False - for interval in intervals_per_ranks[0]: + for interval in workers_intervals[0]: if interval[1] <= i <= interval[2]: found = True break @@ -178,7 +178,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression found_list = [] for i in process_2_1: found = False - for interval in intervals_per_ranks[1]: + for interval in workers_intervals[1]: if interval[1] <= i <= interval[2]: found = True break diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 5c3c6709..4f602145 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -34,15 +34,15 @@ def test_associate_chunks_and_internals_to_workers(): Interval(0, 0, 50, 50), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, ) - assert chunks_per_ranks == [[0, 1], [2, 3], [4, 5], [6, 7]] - assert intervals_per_ranks == [ + assert workers_chunks == [[0, 1], [2, 3], [4, 5], [6, 7]] + assert workers_intervals == [ [[0, 0, 50, 50], [0, 0, 50, 50]], [[0, 0, 50, 50], [0, 0, 50, 50]], [[0, 0, 50, 50], [0, 0, 50, 50]], @@ -60,20 +60,20 @@ def test_associate_chunks_and_internals_to_workers(): Interval(0, 0, 33, 33), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, ) - assert chunks_per_ranks == [[0, 1], [1, 2], [2, 3, 4, 5], [5, 6, 7]] - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[0]]) == 105 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[1]]) == 105 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[2]]) == 105 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[3]]) == 105 + assert workers_chunks == [[0, 1], [1, 2], [2, 3, 4, 5], [5, 6, 7]] + assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 105 - assert intervals_per_ranks == [ + assert workers_intervals == [ [[0, 0, 50, 50], [0, 0, 55, 150]], [[0, 55, 150, 150], [0, 0, 10, 50]], [[0, 10, 50, 50], [0, 0, 12, 12], [0, 0, 50, 50], [0, 0, 3, 27]], @@ -91,24 +91,24 @@ def test_associate_chunks_and_internals_to_workers(): Interval(0, 0, 
1, 1), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, ) - assert chunks_per_ranks == [[0, 1], [1], [1, 2, 3, 4, 5], [5, 6, 7]] - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[0]]) == 64 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[1]]) == 64 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[2]]) == 64 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[3]]) == 64 - assert intervals_per_ranks == [ + assert workers_chunks == [[0, 1], [1], [1, 2, 3, 4, 5], [5, 6, 7]] + assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 64 + assert workers_intervals == [ [[0, 0, 5, 5], [0, 0, 59, 150]], [[0, 59, 123, 150]], [[0, 123, 150, 150], [0, 0, 7, 7], [0, 0, 12, 12], [0, 0, 4, 4], [0, 0, 14, 27]], [[0, 14, 27, 27], [0, 0, 50, 50], [0, 0, 1, 1]], ] - disable_deletion_ranks = _find_chunks_per_ranks_on_which_to_skip_deletion(1, chunks_per_ranks, intervals_per_ranks) + disable_deletion_ranks = _find_chunks_per_ranks_on_which_to_skip_deletion(1, workers_chunks, workers_intervals) assert disable_deletion_ranks == {1: [1], 2: [1], 3: [5]} From 99ca28030e9b44dd4234746ca477d4ef4ad8a7d5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Jul 2024 00:27:30 +0200 Subject: [PATCH 18/63] clean up dataset.py --- src/litdata/streaming/dataset.py | 63 +++++--------------------------- tests/streaming/test_dataset.py | 1 - 2 files changed, 9 insertions(+), 55 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index d18f5369..090adba6 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -210,20 +210,20 @@ def __iter__(self) -> "StreamingDataset": ) worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank - # print(f"{worker_rank=}, len: {len(workers_chunks)}") - worker_chunks = workers_chunks[worker_rank] - worker_intervals = workers_intervals[worker_rank] + self.worker_chunks = workers_chunks[worker_rank] + self.worker_intervals = workers_intervals[worker_rank] + + # The max number of samples to return from `__next__` (in worker) + self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) # Handle restart if self._state_dict: - # breakpoint() self._resume(workers_chunks, workers_intervals) else: + # TODO: Reimplement this logic # Find the chunks shared across multiple ranks. # For each shared chunk, find the rank to use the chunk last and prevent deletion # for the other ranks. 
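As an aside on the slicing a few lines above: worker slots are laid out rank-major, so a given dataloader worker selects its chunk list with `global_rank * num_workers + worker_id`, and `stop_length` is simply the total sample count of its intervals. The sketch below uses made-up numbers (not taken from the repository's tests) purely to illustrate that indexing.

# Illustrative only: rank-major worker layout and the stop_length computation,
# with 2 ranks x 2 dataloader workers -> 4 flat worker slots.
workers_chunks = [[0], [1], [2], [3]]
workers_intervals = [
    [[0, 0, 50, 50]],   # rank 0, worker 0 -> 50 samples
    [[0, 0, 40, 50]],   # rank 0, worker 1 -> 40 samples
    [[0, 10, 50, 50]],  # rank 1, worker 0 -> 40 samples
    [[0, 0, 50, 50]],   # rank 1, worker 1 -> 50 samples
]

global_rank, num_workers, worker_id = 1, 2, 0
worker_rank = global_rank * num_workers + worker_id
stop_length = sum(interval[2] - interval[1] for interval in workers_intervals[worker_rank])
assert (worker_rank, stop_length) == (2, 40)  # this worker reads samples 10..50 of chunk 2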
- # TODO better name for worker_start end - # TODO: reimplement this logic # worker_start = self.distributed_env.global_rank * self.num_workers # worker_end = worker_start + self.num_workers # chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion( @@ -234,9 +234,6 @@ def __iter__(self) -> "StreamingDataset": # self.distributed_env.global_rank # ] - self.worker_chunks = worker_chunks - self.worker_intervals = worker_intervals - self.num_chunks = len(self.worker_chunks) self.current_indexes = [] self.chunk_index = 0 @@ -266,14 +263,10 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No # replay sampling from each worker / chunks using the batch size indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) - - # print(f"indexes1 = {indexes}") - - # TODO: Change _replay_chunks_sampling to accept a list chunks_index, indexes = _replay_chunks_sampling( - {i: workers_intervals[i] for i in range(worker_start, worker_end)}, indexes + workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)}, + indexes=indexes, ) - # print(f"{indexes=}, {chunks_index=}") # select the chunks and intervals associated to this worker worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank @@ -295,7 +288,6 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No current_indexes = current_indexes[indexes[worker_local_rank] :] self.current_indexes = current_indexes - # print(f"currentindexes = {current_indexes}") self.global_index = indexes[worker_local_rank] # bump the chunk_index @@ -317,16 +309,10 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent to create more batch on a given process - # print(torch.distributed.get_rank(), self.global_index, len(self), self.stop_length) - worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank - stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) - # print(f"{worker_rank}, {self.global_index=}, {stop_length=}") - # TODO: This is stopping too early, length is not correct - if self.global_index >= stop_length: + if self.global_index >= self.stop_length: self.current_epoch += 1 raise StopIteration - # print(f"{self.num_chunks=}") # Lazily re-populate the interval to reduce memory usage. 
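Stepping back from this hunk for a moment: the `_resume` path above relies on `_replay_sampling` to map the dataloader-level `num_samples_yielded` back onto per-worker counts. A rough, self-contained equivalent is sketched below; it assumes the default round-robin rotation over workers with a fixed batch size, and `replay_sampling_sketch` is a name invented here, not the library function.

# Rough sketch (not the library code): hand out full batches to workers in
# round-robin order until the yielded sample count is exhausted, then report
# how many samples each worker must have produced.
def replay_sampling_sketch(num_samples_yielded: int, batch_size: int, num_workers: int) -> dict:
    indexes = {worker: 0 for worker in range(num_workers)}
    worker, remaining = 0, num_samples_yielded
    while remaining > 0:
        taken = min(batch_size, remaining)
        indexes[worker] += taken
        remaining -= taken
        worker = (worker + 1) % num_workers
    return indexes

# 23 samples yielded with batch_size=5 and num_workers=2: batches go to
# worker 0, 1, 0, 1, then a partial batch of 3 samples to worker 0 again.
assert replay_sampling_sketch(23, 5, 2) == {0: 13, 1: 10}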
if len(self.current_indexes) == 0: if self.chunk_index == self.num_chunks: @@ -487,37 +473,6 @@ def is_integer(value: str) -> bool: except Exception: return False - -# TODO: remove -def _associate_chunks_to_workers( - worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any] -) -> Any: - workers_chunks = {} - workers_intervals = {} - - for worker_idx in range(worker_env.world_size): - worker_chunks = [] - worker_intervals = [] - for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)): - if i % worker_env.world_size != worker_idx: - continue - - worker_chunks.append(chunk_index) - worker_intervals.append(chunk_interval) - - workers_chunks[worker_idx] = worker_chunks - workers_intervals[worker_idx] = worker_intervals - - # print( - # "associate", - # [ - # sum(interval[2] - interval[1] for interval in intervals) - # for worker_id, intervals in workers_intervals.items() - # ], - # ) - return workers_chunks, workers_intervals - - def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]: """This function replays the sampling from the dataloader.""" divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 268a4ca2..9486e9a9 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -32,7 +32,6 @@ _INDEX_FILENAME, Dir, StreamingDataset, - _associate_chunks_to_workers, _replay_chunks_sampling, _replay_sampling, ) From 4130f50e07151329c3fe2186db10a1f90a3b2b89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 22:27:42 +0000 Subject: [PATCH 19/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 090adba6..62c4228c 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -473,6 +473,7 @@ def is_integer(value: str) -> bool: except Exception: return False + def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]: """This function replays the sampling from the dataloader.""" divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size) From 0053486c8ec9a09377a4971e569a1272d831b4f2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Jul 2024 00:31:31 +0200 Subject: [PATCH 20/63] clean up shuffle --- src/litdata/streaming/shuffle.py | 2 +- src/litdata/utilities/shuffle.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 357fa985..fb5ed25b 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -35,7 +35,7 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool): @lru_cache(maxsize=10) def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int: - _, workers_intervals = self.get_chunks_and_intervals_per_workers( # TODO: rename + _, workers_intervals = self.get_chunks_and_intervals_per_workers( distributed_env, num_workers, batch_size, current_epoch ) worker_start = distributed_env.global_rank * num_workers diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 
cbd9e662..4684838a 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -52,9 +52,7 @@ def _associate_chunks_and_internals_to_workers( batch_size: int = 1, ) -> Tuple[List[List[int]], List[Any]]: num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - print(f"{num_items=}") world_size = distributed_env.world_size * num_workers - print("WORLD_SIZE=", world_size) num_items_per_workers: List[int] = [ num_items // world_size + num_items % world_size if rank == world_size - 1 and not drop_last @@ -65,7 +63,6 @@ def _associate_chunks_and_internals_to_workers( ratio = batch_size num_items_per_workers = [ratio * int(item // ratio) for item in num_items_per_workers] - print(f"{num_items_per_workers=}") chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)] intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)] @@ -105,7 +102,6 @@ def _associate_chunks_and_internals_to_workers( num_items_per_workers[rank] -= items_in_chunk break - # print(drop_last, batch_size, num_workers, [sum(interval[2] - interval[1] for interval in intervals) for intervals in intervals_per_workers]) return chunks_per_workers, intervals_per_workers From 3e94cd66dbfe0bccec921b1169b4e8eeb534fed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 12:40:07 +0000 Subject: [PATCH 21/63] Fix set_drop_last and test --- src/litdata/streaming/dataloader.py | 2 +- tests/streaming/test_dataset.py | 44 ++++++++++++++++++----------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 50ee71c1..4ad656db 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -555,7 +555,7 @@ def __init__( profile_dir: Optional[str] = None, prefetch_factor: Optional[int] = None, shuffle: Optional[bool] = None, - drop_last: Optional[bool] = False, + drop_last: Optional[bool] = None, collate_fn: Optional[Callable] = None, **kwargs: Any, ) -> None: # pyright: ignore diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 9486e9a9..3986a526 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -600,36 +600,46 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) L = len(dataset) - assert len(dataset) == L + assert L == 20 for i in range(L): sequence = dataset[i] assert sequence[0].item() == i * block_size assert sequence[-1].item() == (i + 1) * block_size - 1 + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "0") + monkeypatch.setenv("NNODES", "1") dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + assert dataset.drop_last # in distributed setting, this is forced automatically - dataset.distributed_env = _DistributedEnv(2, 0, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) - - assert len(dataloader) == 5 - - expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]] + # L = 20, world size 2, num workers 2 + # L / (2 * 2) = 5 items per worker + # drop last -> 4 items per worker + # batch size = 2 -> 2 batches per worker -> len(dataloader) = 4 + assert len(dataloader) == 4 + expected = [[0, 10], [40, 50], [20, 30], 
[60, 70]] + returned = [] for batch_idx, batch in enumerate(dataloader): - x = [batch[0][0].item(), batch[1][0].item()] - # breakpoint() - # assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + returned.append(batch[:, 0].tolist()) + assert returned == expected - dataset.distributed_env = _DistributedEnv(2, 1, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False) - - assert len(dataloader) == 5 - - expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]] + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "1") + monkeypatch.setenv("NNODES", "1") + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + assert dataset.drop_last # in distributed setting, this is forced automatically + assert len(dataloader) == 4 + + expected = [[80, 90], [120, 130], [100, 110], [140, 150]] + returned = [] for batch_idx, batch in enumerate(dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + returned.append(batch[:, 0].tolist()) + assert returned == expected @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") From a3b9457f67e2b44ff37c4ba9aab8324d8c6af62b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 12:40:36 +0000 Subject: [PATCH 22/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 3986a526..a32d6601 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -634,7 +634,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk assert dataset.drop_last # in distributed setting, this is forced automatically assert len(dataloader) == 4 - + expected = [[80, 90], [120, 130], [100, 110], [140, 150]] returned = [] for batch_idx, batch in enumerate(dataloader): From f52a501440091e826feef03b983a426074d8f72a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 14:15:05 +0000 Subject: [PATCH 23/63] fix epoch reshuffling test --- tests/streaming/test_dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index a32d6601..0183f3ae 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -684,7 +684,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") -def test_resumable_dataset_two_workers(tmpdir): +def test_dataset_reshuffling_every_epoch(tmpdir): seed_everything(42) data_dir = os.path.join(tmpdir, "data") @@ -786,7 +786,6 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=True ) - dataset.current_epoch = 1 dataloader = StreamingDataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1, persistent_workers=True) batches_epoch_1 = [] @@ -800,9 +799,7 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): batches_epoch_2.append(batch) assert len(os.listdir(cache_dir)) == 51 - - for batch_1, batch_2 in zip(batches_epoch_1, 
batches_epoch_2): - assert not torch.equal(batch_1, batch_2) + assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2)) def _simple_preprocess(_): From e621185f6476451be5f5e84ed87b328b3b1747c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 14:36:55 +0000 Subject: [PATCH 24/63] update combined test --- tests/streaming/test_combined.py | 90 +++----------------------------- 1 file changed, 8 insertions(+), 82 deletions(-) diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 66bb3a4d..7079f564 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -568,43 +568,6 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "latest_worker_idx": 1, "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]}, }, - { - "dataset": { - "0": { - "num_samples_yielded": 8, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - "1": { - "num_samples_yielded": 3, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - }, - "current_epoch": 0, - "latest_worker_idx": 2, - "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]}, - }, { "dataset": { "0": { @@ -639,8 +602,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 0, - "latest_worker_idx": 0, - "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]}, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, }, { "dataset": { @@ -676,8 +639,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 0, - "latest_worker_idx": 1, - "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]}, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]}, }, ] @@ -867,43 +830,6 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "latest_worker_idx": 1, "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]}, }, - { - "dataset": { - "0": { - "num_samples_yielded": 8, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 2, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - "1": { - "num_samples_yielded": 3, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 2, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - }, - "current_epoch": 1, - "latest_worker_idx": 2, - "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]}, - }, { "dataset": { "0": { @@ -938,8 +864,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 1, - "latest_worker_idx": 0, - "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]}, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, }, { "dataset": { @@ -975,8 +901,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 1, - 
"latest_worker_idx": 1, - "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]}, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]}, }, ] From 064166693233ab6d55bbd7fcae88869f57d53690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 14:52:34 +0000 Subject: [PATCH 25/63] update replay test --- tests/streaming/test_dataset.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 0183f3ae..b42162ca 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -39,6 +39,7 @@ from litdata.streaming.shuffle import FullShuffle, NoShuffle from litdata.utilities import dataset_utilities as dataset_utilities_module from litdata.utilities.env import _DistributedEnv, _WorkerEnv +from litdata.utilities.shuffle import _associate_chunks_and_internals_to_workers from torch.utils.data import DataLoader @@ -985,14 +986,15 @@ def test_replay_sampling(): def test_replay_chunks_sampling(): chunks_replica = range(10) intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)] - workers_chunks, workers_intervals = _associate_chunks_to_workers( - _WorkerEnv(2, 0), chunks_replica, intervals_replica + workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + _DistributedEnv(2, 0, 1), chunks_replica, intervals_replica ) - assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]} - assert workers_intervals == { - 0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)], - 1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)], - } + assert workers_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + assert workers_intervals == [ + [[0, 0, 5, 5], [5, 5, 10, 10], [10, 10, 15, 15], [15, 15, 20, 20], [20, 20, 25, 25]], + [[25, 25, 30, 30], [30, 30, 35, 35], [35, 35, 40, 40], [40, 40, 45, 45], [45, 45, 50, 50]] + ] + workers_intervals = {i: workers_intervals[i] for i in range(len(workers_intervals))} assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1}) assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3}) assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2}) From 9b4b807be864e67bffb6ff6108044601d35046db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 14:52:57 +0000 Subject: [PATCH 26/63] set default --- src/litdata/utilities/shuffle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 4684838a..cdeb50d5 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -47,7 +47,7 @@ def _associate_chunks_and_internals_to_workers( distributed_env: _DistributedEnv, indexes: Any, chunk_intervals: List[Interval], - drop_last: bool, + drop_last: bool = False, num_workers: int = 1, batch_size: int = 1, ) -> Tuple[List[List[int]], List[Any]]: From 4aa15da8b17e4e743b1cb3d4e595c73c26297266 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 14:55:17 +0000 Subject: [PATCH 27/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index b42162ca..c23df3fd 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -992,7 +992,7 @@ def test_replay_chunks_sampling(): assert workers_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] assert workers_intervals == [ [[0, 0, 5, 5], [5, 5, 10, 10], [10, 10, 15, 15], [15, 15, 20, 20], [20, 20, 25, 25]], - [[25, 25, 30, 30], [30, 30, 35, 35], [35, 35, 40, 40], [40, 40, 45, 45], [45, 45, 50, 50]] + [[25, 25, 30, 30], [30, 30, 35, 35], [35, 35, 40, 40], [40, 40, 45, 45], [45, 45, 50, 50]], ] workers_intervals = {i: workers_intervals[i] for i in range(len(workers_intervals))} assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1}) From f8fe74cd67371e7cee94454842232093938b11d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 15:07:48 +0000 Subject: [PATCH 28/63] clean up --- src/litdata/utilities/shuffle.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index cdeb50d5..df848998 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -19,7 +19,6 @@ from litdata.utilities.env import _DistributedEnv -# TODO: Logic needs to be updated? chunks_per_ranks -> workers_chunks def _intra_node_chunk_shuffle( distributed_env: _DistributedEnv, chunks_per_ranks: List[List[int]], @@ -60,8 +59,7 @@ def _associate_chunks_and_internals_to_workers( for rank in range(world_size) ] if drop_last: - ratio = batch_size - num_items_per_workers = [ratio * int(item // ratio) for item in num_items_per_workers] + num_items_per_workers = [batch_size * int(item // batch_size) for item in num_items_per_workers] chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)] intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)] From 0eec395db86c574bb6e3975930148d33f5a8c758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 15:12:47 +0000 Subject: [PATCH 29/63] update test --- tests/streaming/test_dataset.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c23df3fd..cf91420f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -505,20 +505,22 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir): assert len(dataloader) == 10 expected = [ - [0, 10], - [40, 50], - [20, 30], - [60, 70], - [80, 90], - [120, 130], - [100, 110], - [140, 150], - [160, 170], + [0, 10], + [100, 110], + [20, 30], + [120, 130], + [40, 50], + [140, 150], + [60, 70], + [160, 170], + [80, 90], [180, 190], ] - for result, batch in zip(expected, dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == result + result = [] + for batch in dataloader: + result.append(batch[:, 0].tolist()) + assert result == expected def test_dataset_for_text_tokens_distributed_num_workers(tmpdir): @@ -623,7 +625,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk expected = [[0, 10], [40, 50], [20, 30], [60, 70]] returned = [] - for batch_idx, batch in enumerate(dataloader): + for batch in dataloader: returned.append(batch[:, 0].tolist()) assert returned == expected @@ -638,7 +640,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk expected = [[80, 90], [120, 130], [100, 110], [140, 
150]] returned = [] - for batch_idx, batch in enumerate(dataloader): + for batch in dataloader: returned.append(batch[:, 0].tolist()) assert returned == expected From 4d9befdc5d9e7282d6ffdd71537491d6dfd649d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 16:16:06 +0000 Subject: [PATCH 30/63] disable profiler test for now --- tests/streaming/test_dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index af77ecb8..f0a5e138 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -94,6 +94,7 @@ def test_streaming_dataloader(): } +@pytest.mark.skip(reason="Profiling patches torch which leads to undesired test interactions") @pytest.mark.skipif(not _VIZ_TRACKER_AVAILABLE, reason="viz tracker required") @pytest.mark.parametrize("profile", [2, True]) def test_dataloader_profiling(profile, tmpdir, monkeypatch): From fa4237b49fff7377a85d956a6f462936870dc249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 16:33:06 +0000 Subject: [PATCH 31/63] update --- tests/streaming/test_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index cf91420f..dee381f0 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -845,8 +845,6 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): batches_to_fetch = 16 batch_to_resume_from = None dataloader_state = None - assert len(train_dataloader.dataset) == 80 - assert len(train_dataloader) == 40 for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: From 92d52d31cdf7007925fc7601d4d138914f627dfd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 16:35:01 +0000 Subject: [PATCH 32/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index dee381f0..3f07e237 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -505,15 +505,15 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir): assert len(dataloader) == 10 expected = [ - [0, 10], - [100, 110], - [20, 30], - [120, 130], - [40, 50], - [140, 150], - [60, 70], - [160, 170], - [80, 90], + [0, 10], + [100, 110], + [20, 30], + [120, 130], + [40, 50], + [140, 150], + [60, 70], + [160, 170], + [80, 90], [180, 190], ] From ace58fbde53e5a990fffc0163318d0927ed24158 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 20:27:40 +0000 Subject: [PATCH 33/63] fix type --- src/litdata/streaming/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 62c4228c..69cf4f93 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -245,7 +245,7 @@ def __iter__(self) -> "StreamingDataset": return self - def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> None: + def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None: assert self._state_dict assert self.worker_env assert self.shuffler From 53c2cf4aa6d92302f11ed5d5130f2c8b35e6498d Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 09:46:55 +0000 Subject: [PATCH 34/63] fix with thomas --- src/litdata/streaming/dataset.py | 35 ++++--- src/litdata/utilities/shuffle.py | 164 +++++++++++++++++++++---------- tests/utilities/test_shuffle.py | 84 +++++++++++++++- 3 files changed, 217 insertions(+), 66 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 69cf4f93..805da033 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -31,6 +31,7 @@ from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv +from litdata.utilities.shuffle import _find_chunks_per_workers_on_which_to_skip_deletion, _map_node_worker_rank_to_chunk_indexes_to_not_delete logger = Logger(__name__) @@ -220,19 +221,27 @@ def __iter__(self) -> "StreamingDataset": if self._state_dict: self._resume(workers_chunks, workers_intervals) else: - # TODO: Reimplement this logic - # Find the chunks shared across multiple ranks. - # For each shared chunk, find the rank to use the chunk last and prevent deletion - # for the other ranks. - # worker_start = self.distributed_env.global_rank * self.num_workers - # worker_end = worker_start + self.num_workers - # chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion( - # self.worker_env.world_size, workers_chunks[worker_start: worker_end], workers_intervals[worker_start: worker_end] - # ) - # if self.distributed_env.global_rank in chunks_indexes_skip_deletion: - # self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[ - # self.distributed_env.global_rank - # ] + # Find the chunks shared across all workers of the current node. + # For each shared chunk, find the rank and worker to use the chunk last and prevent + # premature deletion for the other workers. 
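The first half of that bookkeeping, detecting which chunks are shared at all, reduces to the small `_get_shared_chunks` helper added further down in this patch. A self-contained sketch of the same idea on toy worker assignments (the data below mirrors one of the new unit tests):

# A chunk is "shared" when more than one worker on the node reads from it.
workers_chunks = [[2], [0, 1], [2, 3]]

shared = {}
for worker_id, chunks in enumerate(workers_chunks):
    for chunk in chunks:
        shared.setdefault(chunk, []).append(worker_id)
shared = {chunk: workers for chunk, workers in shared.items() if len(workers) > 1}

# Only chunk 2 is shared (workers 0 and 2); each remaining chunk has a single
# reader, so only the last reader of chunk 2 may delete it from the cache.
assert shared == {2: [0, 2]}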
+ node_size = self.distributed_env.world_size // self.distributed_env.num_nodes + first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size + num_workers_per_node = node_size * self.num_workers + worker_start = first_rank_this_node * num_workers_per_node + worker_end = worker_start + num_workers_per_node + local_rank = self.distributed_env.global_rank % node_size + + chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( + self.num_workers, + self.batch_size, + workers_chunks[worker_start: worker_end], workers_intervals[worker_start: worker_end], + ) + worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(chunks_indexes_skip_deletion) + + worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank + + if worker_rank_local_node in worker_node_rank_to_chunk_indexes: + self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[worker_rank_local_node] self.num_chunks = len(self.worker_chunks) self.current_indexes = [] diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index df848998..71068a60 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -13,6 +13,7 @@ from typing import Any, Dict, List, Tuple +import copy import numpy as np from litdata.streaming.item_loader import Interval @@ -103,64 +104,127 @@ def _associate_chunks_and_internals_to_workers( return chunks_per_workers, intervals_per_workers -def _find_chunks_per_ranks_on_which_to_skip_deletion( - num_workers: int, chunks_per_ranks: List[List[int]], intervals_per_ranks: List[Any] +def _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers: int, + batch_size: int, + workers_chunks: List[List[int]], + workers_intervals: List[List[int]], ) -> Dict[int, List[int]]: - # TODO: Add support for the real batch size - batch_size = 1 - shared_chunks = {} - for rank, chunks in enumerate(chunks_per_ranks): - for c in chunks: - if c not in shared_chunks: - shared_chunks[c] = [rank] - else: - shared_chunks[c].append(rank) - - shared_chunks = {c: ranks for c, ranks in shared_chunks.items() if len(ranks) > 1} - disable_deletion_ranks = {} + # {1: [2, 3, 4, 5]} + # [2, 3] belongs to rank 0 + # [4, 5] belongs to rank 1 + shared_chunks = _get_shared_chunks(workers_chunks) + + # workers_index_sharing_chunks + # {1: (0, [2, 3], (1, [4, 5]))} + shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) + + # breakpoint() + + max_trackers = {} + to_disable = {} + for chunk_index, map_local_rank_to_worker_ids in shared_chunks_aggregated_by_rank.items(): + for local_rank, workers_index_sharing_chunks_for_this_rank in map_local_rank_to_worker_ids.items(): + + # get all the worker chunks and intervals for this distributed rank + workers_slice = slice(local_rank * num_workers, (local_rank + 1) * num_workers) + workers_chunks_for_this_rank = copy.deepcopy(workers_chunks[workers_slice]) + workers_intervals_for_this_rank = copy.deepcopy( # TODO: rename + [[interval[2] - interval[1] for interval in worker_intervals] for worker_intervals in workers_intervals[workers_slice]] + ) + + num_shared_workers_for_this_rank = len(workers_index_sharing_chunks_for_this_rank) + worker_tracker_idx = 0 + num_of_samples_to_carry_to_next_chunk = None + counter = 0 - for shared_chunk, ranks in shared_chunks.items(): - counters = [] - for rank in ranks: - chunks = chunks_per_ranks[rank] - intervals = [interval[2] - interval[1] for interval 
in intervals_per_ranks[rank]] + while True: + chunks_of_currently_loaded_worker = workers_chunks_for_this_rank[worker_tracker_idx % num_workers] + intervals_of_currently_loaded_worker = workers_intervals_for_this_rank[worker_tracker_idx % num_workers] + if len(intervals_of_currently_loaded_worker) == 0: + worker_tracker_idx += 1 + continue + + num_samples_left_for_this_worker_chunk = intervals_of_currently_loaded_worker[0] + + remover = batch_size if num_of_samples_to_carry_to_next_chunk is None else num_of_samples_to_carry_to_next_chunk + + if num_samples_left_for_this_worker_chunk > remover: + + # We have consumed a batch, going to the next worker + workers_intervals_for_this_rank[worker_tracker_idx % num_workers][0] -= remover + counter += remover + num_of_samples_to_carry_to_next_chunk = None + else: + # We have consumed a batch, going to the next worker + current_worker_chunk_index = workers_chunks_for_this_rank[worker_tracker_idx % num_workers].pop(0) + workers_intervals_for_this_rank[worker_tracker_idx % num_workers].pop(0) + counter += remover + + if current_worker_chunk_index == chunk_index: + num_shared_workers_for_this_rank -= 1 + # breakpoint() + + # We consumed entirely the chunk of the worker we were tracking, let's break + # TODO: Maybe, we can prevent loading over and over for each worker + if num_shared_workers_for_this_rank == 0 and current_worker_chunk_index == chunk_index: + if chunk_index not in max_trackers: + max_trackers[chunk_index] = (local_rank * num_workers + worker_tracker_idx % num_workers, counter) + else: + if max_trackers[chunk_index][1] < counter: + max_trackers[chunk_index] = (local_rank * num_workers + worker_tracker_idx % num_workers, counter) + + break - workers_chunks: Any = [[] for _ in range(num_workers)] - workers_intervals: Any = [[] for _ in range(num_workers)] - for interval_idx, (c, i) in enumerate(zip(chunks, intervals)): - workers_chunks[interval_idx % num_workers].append(c) - workers_intervals[interval_idx % num_workers].append(i) + if num_samples_left_for_this_worker_chunk != batch_size: + num_of_samples_to_carry_to_next_chunk = batch_size - num_samples_left_for_this_worker_chunk - counter = 0 - worker_idx = 0 # reset the worker_idx - while True: - current_chunks = workers_chunks[worker_idx] - current_intervals = workers_intervals[worker_idx] + if remover != batch_size: + num_of_samples_to_carry_to_next_chunk = None - if len(current_intervals) == 0: - break + if num_of_samples_to_carry_to_next_chunk is None: + worker_tracker_idx += 1 - if current_intervals[0] > batch_size: - current_intervals[0] -= batch_size - counter += batch_size - worker_idx = (worker_idx + 1) % num_workers - else: - counter += current_intervals[0] - current_intervals.pop(0) - current_chunk = current_chunks.pop(0) - worker_idx = (worker_idx + 1) % num_workers + # else: + # # I don't know if this is possible + # break + - if current_chunk == shared_chunk: - break + for chunk_index, worker_ids in shared_chunks.items(): + last_worker_idx = max_trackers[chunk_index][0] + to_disable[chunk_index] = [worker_idx for worker_idx in worker_ids if worker_idx != last_worker_idx] + return to_disable - counters.append(counter) - max_counter = np.argmax(counters) - disable_ranks = [rank for rank in ranks if rank != ranks[max_counter]] - for rank in disable_ranks: - if rank not in disable_deletion_ranks: - disable_deletion_ranks[rank] = [shared_chunk] +def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: + shared_chunks = {} + for worker, chunks in 
enumerate(workers_chunks): + for chunk in chunks: + if chunk not in shared_chunks: + shared_chunks[chunk] = [worker] else: - disable_deletion_ranks[rank].append(shared_chunk) - return disable_deletion_ranks + shared_chunks[chunk].append(worker) + return {chunk: workers for chunk, workers in shared_chunks.items() if len(workers) > 1} + + +def _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) -> Dict[int, List[int]]: + aggregated_shared_chunks_per_rank = {} + for chunk_index, workers_ids in shared_chunks.items(): + aggregated_shared_chunks_per_rank[chunk_index] = {} + for worker_idx in workers_ids: + if (worker_idx // num_workers) not in aggregated_shared_chunks_per_rank[chunk_index]: + aggregated_shared_chunks_per_rank[chunk_index][worker_idx // num_workers] = [] + aggregated_shared_chunks_per_rank[chunk_index][worker_idx // num_workers].append(worker_idx) + return aggregated_shared_chunks_per_rank + + +def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable): + map_node_worker_rank_to_chunk_indexes = {} + for chunk_index, worker_ids in to_disable.items(): + for worker_idx in worker_ids: + if worker_idx not in map_node_worker_rank_to_chunk_indexes: + map_node_worker_rank_to_chunk_indexes[worker_idx] = [] + map_node_worker_rank_to_chunk_indexes[worker_idx].append(chunk_index) + return map_node_worker_rank_to_chunk_indexes + \ No newline at end of file diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 4f602145..35bd89c1 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -2,8 +2,9 @@ from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( _associate_chunks_and_internals_to_workers, - _find_chunks_per_ranks_on_which_to_skip_deletion, + _find_chunks_per_workers_on_which_to_skip_deletion, _intra_node_chunk_shuffle, + _get_shared_chunks, ) @@ -110,5 +111,82 @@ def test_associate_chunks_and_internals_to_workers(): [[0, 14, 27, 27], [0, 0, 50, 50], [0, 0, 1, 1]], ] - disable_deletion_ranks = _find_chunks_per_ranks_on_which_to_skip_deletion(1, workers_chunks, workers_intervals) - assert disable_deletion_ranks == {1: [1], 2: [1], 3: [5]} + +def test_get_shared_chunks(): + assert _get_shared_chunks([]) == {} + assert _get_shared_chunks([[0]]) == {} + assert _get_shared_chunks([[0], [1]]) == {} + assert _get_shared_chunks([[0], [0, 1]]) == {0: [0, 1]} # chunk 0 is shared by worker 0 and 1 + assert _get_shared_chunks([[0, 1], [1]]) == {1: [0, 1]} # chunk 1 is shared by worker 0 and 1 + assert _get_shared_chunks([[2], [0, 1], [2, 3]]) == {2: [0, 2]} + assert _get_shared_chunks([[2], [0, 1], [2, 3], [1, 4], [1]]) == {1: [1, 3, 4], 2: [0, 2]} + + +def test_find_chunks_per_workers_on_which_to_skip_deletion(): + # world size = 1, no shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=1, + workers_chunks=[[0], [1]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 0, 50, 50)]], + ) + assert chunks_to_disable == {} + + # world size = 1, batch size 5, no shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=5, + workers_chunks=[[0], [1]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 0, 50, 50)]], + ) + assert chunks_to_disable == {} + + # world size = 1, batch size 5, shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=5, + workers_chunks=[[0, 1], [1, 2]], + workers_intervals=[[(0, 0, 50, 50), (0, 0, 
25, 50)], [(0, 25, 50, 50), (0, 0, 50, 50)]], + ) + assert chunks_to_disable == {1: [1]} + + # world size = 1, batch size 5, shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=4, + batch_size=5, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 100, 50)], [(0, 100, 150, 50)], [(0, 150, 200, 50)]], + ) + assert chunks_to_disable == {0: [0, 1, 2]} + + # world size = 1, batch size 5, shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=4, + batch_size=5, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 95, 50)], [(0, 95, 150, 50)], [(0, 150, 200, 50)]], + ) + assert chunks_to_disable == {0: [0, 1, 3]} + + + # world size = 1, batch size 5, shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=4, + batch_size=5, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 95, 50)], [(0, 95, 150, 50)], [(0, 150, 200, 50)]], + ) + assert chunks_to_disable == {0: [0, 1, 3]} + + for batch_size in range(1, 6): + # world size = 1, batch size 5, shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=batch_size, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[ + [(0, 0, 50, 50)], [(0, 50, 95, 50)], # local_rank 0 + [(0, 95, 145, 50)], [(0, 145, 205, 50)] # local_rank 1 + ], + ) + assert chunks_to_disable == {0: [0, 1, 2]} \ No newline at end of file From 322a69717663630014a6e70b217115f025aee15f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 09:47:18 +0000 Subject: [PATCH 35/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataset.py | 24 +++++++++----- src/litdata/utilities/shuffle.py | 54 ++++++++++++++++++-------------- tests/utilities/test_shuffle.py | 17 +++++----- 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 805da033..f5a5b22e 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -31,7 +31,10 @@ from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv -from litdata.utilities.shuffle import _find_chunks_per_workers_on_which_to_skip_deletion, _map_node_worker_rank_to_chunk_indexes_to_not_delete +from litdata.utilities.shuffle import ( + _find_chunks_per_workers_on_which_to_skip_deletion, + _map_node_worker_rank_to_chunk_indexes_to_not_delete, +) logger = Logger(__name__) @@ -222,7 +225,7 @@ def __iter__(self) -> "StreamingDataset": self._resume(workers_chunks, workers_intervals) else: # Find the chunks shared across all workers of the current node. - # For each shared chunk, find the rank and worker to use the chunk last and prevent + # For each shared chunk, find the rank and worker to use the chunk last and prevent # premature deletion for the other workers. 
node_size = self.distributed_env.world_size // self.distributed_env.num_nodes first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size @@ -230,18 +233,23 @@ def __iter__(self) -> "StreamingDataset": worker_start = first_rank_this_node * num_workers_per_node worker_end = worker_start + num_workers_per_node local_rank = self.distributed_env.global_rank % node_size - + chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( - self.num_workers, - self.batch_size, - workers_chunks[worker_start: worker_end], workers_intervals[worker_start: worker_end], + self.num_workers, + self.batch_size, + workers_chunks[worker_start:worker_end], + workers_intervals[worker_start:worker_end], + ) + worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( + chunks_indexes_skip_deletion ) - worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(chunks_indexes_skip_deletion) worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank if worker_rank_local_node in worker_node_rank_to_chunk_indexes: - self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[worker_rank_local_node] + self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ + worker_rank_local_node + ] self.num_chunks = len(self.worker_chunks) self.current_indexes = [] diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 71068a60..b0ec91eb 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -11,9 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from typing import Any, Dict, List, Tuple -import copy import numpy as np from litdata.streaming.item_loader import Interval @@ -105,35 +105,36 @@ def _associate_chunks_and_internals_to_workers( def _find_chunks_per_workers_on_which_to_skip_deletion( - num_workers: int, - batch_size: int, - workers_chunks: List[List[int]], + num_workers: int, + batch_size: int, + workers_chunks: List[List[int]], workers_intervals: List[List[int]], ) -> Dict[int, List[int]]: - # {1: [2, 3, 4, 5]} - # [2, 3] belongs to rank 0 + # [2, 3] belongs to rank 0 # [4, 5] belongs to rank 1 shared_chunks = _get_shared_chunks(workers_chunks) - - # workers_index_sharing_chunks + + # workers_index_sharing_chunks # {1: (0, [2, 3], (1, [4, 5]))} shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) - + # breakpoint() - + max_trackers = {} to_disable = {} for chunk_index, map_local_rank_to_worker_ids in shared_chunks_aggregated_by_rank.items(): for local_rank, workers_index_sharing_chunks_for_this_rank in map_local_rank_to_worker_ids.items(): - # get all the worker chunks and intervals for this distributed rank workers_slice = slice(local_rank * num_workers, (local_rank + 1) * num_workers) workers_chunks_for_this_rank = copy.deepcopy(workers_chunks[workers_slice]) workers_intervals_for_this_rank = copy.deepcopy( # TODO: rename - [[interval[2] - interval[1] for interval in worker_intervals] for worker_intervals in workers_intervals[workers_slice]] + [ + [interval[2] - interval[1] for interval in worker_intervals] + for worker_intervals in workers_intervals[workers_slice] + ] ) - + num_shared_workers_for_this_rank = len(workers_index_sharing_chunks_for_this_rank) worker_tracker_idx = 0 num_of_samples_to_carry_to_next_chunk = None @@ -145,13 +146,16 @@ def 
_find_chunks_per_workers_on_which_to_skip_deletion( if len(intervals_of_currently_loaded_worker) == 0: worker_tracker_idx += 1 continue - + num_samples_left_for_this_worker_chunk = intervals_of_currently_loaded_worker[0] - - remover = batch_size if num_of_samples_to_carry_to_next_chunk is None else num_of_samples_to_carry_to_next_chunk - + + remover = ( + batch_size + if num_of_samples_to_carry_to_next_chunk is None + else num_of_samples_to_carry_to_next_chunk + ) + if num_samples_left_for_this_worker_chunk > remover: - # We have consumed a batch, going to the next worker workers_intervals_for_this_rank[worker_tracker_idx % num_workers][0] -= remover counter += remover @@ -170,11 +174,17 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( # TODO: Maybe, we can prevent loading over and over for each worker if num_shared_workers_for_this_rank == 0 and current_worker_chunk_index == chunk_index: if chunk_index not in max_trackers: - max_trackers[chunk_index] = (local_rank * num_workers + worker_tracker_idx % num_workers, counter) + max_trackers[chunk_index] = ( + local_rank * num_workers + worker_tracker_idx % num_workers, + counter, + ) else: if max_trackers[chunk_index][1] < counter: - max_trackers[chunk_index] = (local_rank * num_workers + worker_tracker_idx % num_workers, counter) - + max_trackers[chunk_index] = ( + local_rank * num_workers + worker_tracker_idx % num_workers, + counter, + ) + break if num_samples_left_for_this_worker_chunk != batch_size: @@ -189,7 +199,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( # else: # # I don't know if this is possible # break - for chunk_index, worker_ids in shared_chunks.items(): last_worker_idx = max_trackers[chunk_index][0] @@ -227,4 +236,3 @@ def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable): map_node_worker_rank_to_chunk_indexes[worker_idx] = [] map_node_worker_rank_to_chunk_indexes[worker_idx].append(chunk_index) return map_node_worker_rank_to_chunk_indexes - \ No newline at end of file diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 35bd89c1..bb42af14 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -3,8 +3,8 @@ from litdata.utilities.shuffle import ( _associate_chunks_and_internals_to_workers, _find_chunks_per_workers_on_which_to_skip_deletion, - _intra_node_chunk_shuffle, _get_shared_chunks, + _intra_node_chunk_shuffle, ) @@ -149,7 +149,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): workers_intervals=[[(0, 0, 50, 50), (0, 0, 25, 50)], [(0, 25, 50, 50), (0, 0, 50, 50)]], ) assert chunks_to_disable == {1: [1]} - + # world size = 1, batch size 5, shared chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, @@ -158,7 +158,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 100, 50)], [(0, 100, 150, 50)], [(0, 150, 200, 50)]], ) assert chunks_to_disable == {0: [0, 1, 2]} - + # world size = 1, batch size 5, shared chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, @@ -168,7 +168,6 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {0: [0, 1, 3]} - # world size = 1, batch size 5, shared chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, @@ -177,7 +176,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 95, 50)], [(0, 95, 150, 
50)], [(0, 150, 200, 50)]], ) assert chunks_to_disable == {0: [0, 1, 3]} - + for batch_size in range(1, 6): # world size = 1, batch size 5, shared chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( @@ -185,8 +184,10 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): batch_size=batch_size, workers_chunks=[[0], [0], [0], [0]], workers_intervals=[ - [(0, 0, 50, 50)], [(0, 50, 95, 50)], # local_rank 0 - [(0, 95, 145, 50)], [(0, 145, 205, 50)] # local_rank 1 + [(0, 0, 50, 50)], + [(0, 50, 95, 50)], # local_rank 0 + [(0, 95, 145, 50)], + [(0, 145, 205, 50)], # local_rank 1 ], ) - assert chunks_to_disable == {0: [0, 1, 2]} \ No newline at end of file + assert chunks_to_disable == {0: [0, 1, 2]} From 28117a78f1471dee8591870c8a74b4b276d6d223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 09:54:17 +0000 Subject: [PATCH 36/63] num_workers_or_1 --- src/litdata/streaming/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index f5a5b22e..d51423a7 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -229,13 +229,13 @@ def __iter__(self) -> "StreamingDataset": # premature deletion for the other workers. node_size = self.distributed_env.world_size // self.distributed_env.num_nodes first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size - num_workers_per_node = node_size * self.num_workers + num_workers_per_node = node_size * (self.num_workers or 1) worker_start = first_rank_this_node * num_workers_per_node worker_end = worker_start + num_workers_per_node local_rank = self.distributed_env.global_rank % node_size chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( - self.num_workers, + (self.num_workers or 1), self.batch_size, workers_chunks[worker_start:worker_end], workers_intervals[worker_start:worker_end], @@ -244,7 +244,7 @@ def __iter__(self) -> "StreamingDataset": chunks_indexes_skip_deletion ) - worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank + worker_rank_local_node = local_rank * (self.num_workers or 1) + self.worker_env.rank if worker_rank_local_node in worker_node_rank_to_chunk_indexes: self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ From 641ecae0574d5135439550bf8203ccb137434d16 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Jul 2024 11:29:53 +0100 Subject: [PATCH 37/63] update --- src/litdata/processing/data_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index f7d2ce89..6f5d1eb9 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -254,7 +254,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints os.makedirs(os.path.dirname(output_filepath), exist_ok=True) - shutil.move(local_filepath, output_filepath) + os.symlink(local_filepath, output_filepath) else: raise ValueError(f"The provided {output_dir.path} isn't supported.") From 73f376f9e059cc1923e2b2bf5ace1fa2749cb96a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Jul 2024 11:40:07 +0100 Subject: [PATCH 38/63] update --- src/litdata/processing/data_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 6f5d1eb9..f37b3541 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -254,7 +254,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints os.makedirs(os.path.dirname(output_filepath), exist_ok=True) - os.symlink(local_filepath, output_filepath) + shutil.copy(local_filepath, output_filepath) else: raise ValueError(f"The provided {output_dir.path} isn't supported.") From d7d6dfa303e797af6a7c847fe782938c8cb8e910 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 11:20:27 +0000 Subject: [PATCH 39/63] extend test and delete duplicated test --- src/litdata/utilities/shuffle.py | 11 -------- tests/utilities/test_shuffle.py | 45 ++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index b0ec91eb..31c9d9e8 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -114,13 +114,8 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( # [2, 3] belongs to rank 0 # [4, 5] belongs to rank 1 shared_chunks = _get_shared_chunks(workers_chunks) - - # workers_index_sharing_chunks - # {1: (0, [2, 3], (1, [4, 5]))} shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) - # breakpoint() - max_trackers = {} to_disable = {} for chunk_index, map_local_rank_to_worker_ids in shared_chunks_aggregated_by_rank.items(): @@ -168,7 +163,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( if current_worker_chunk_index == chunk_index: num_shared_workers_for_this_rank -= 1 - # breakpoint() # We consumed entirely the chunk of the worker we were tracking, let's break # TODO: Maybe, we can prevent loading over and over for each worker @@ -184,7 +178,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( local_rank * num_workers + worker_tracker_idx % num_workers, counter, ) - break if num_samples_left_for_this_worker_chunk != batch_size: @@ -196,10 +189,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( if num_of_samples_to_carry_to_next_chunk is None: worker_tracker_idx += 1 - # else: - # # I don't know if this is possible - # break - for chunk_index, worker_ids in shared_chunks.items(): last_worker_idx = max_trackers[chunk_index][0] to_disable[chunk_index] = [worker_idx for worker_idx in worker_ids if worker_idx != last_worker_idx] diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index bb42af14..9e67c592 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -123,16 +123,16 @@ def test_get_shared_chunks(): def test_find_chunks_per_workers_on_which_to_skip_deletion(): - # world size = 1, no shared chunks + # world size = 1, single worker chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( - num_workers=2, + num_workers=1, batch_size=1, - workers_chunks=[[0], [1]], - workers_intervals=[[(0, 0, 50, 50)], [(0, 0, 50, 50)]], + workers_chunks=[[0]], + workers_intervals=[[(0, 0, 50, 50)]], ) assert chunks_to_disable == {} - # world size = 1, batch size 5, no shared chunks + # world size = 1, multiple workers, no shared chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=2, 
batch_size=5, @@ -141,7 +141,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {} - # world size = 1, batch size 5, shared chunks + # world size = 1, 2 workers sharing one chunk chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=2, batch_size=5, @@ -150,7 +150,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {1: [1]} - # world size = 1, batch size 5, shared chunks + # world size = 1, 4 workers sharing one chunk chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, batch_size=5, @@ -159,16 +159,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {0: [0, 1, 2]} - # world size = 1, batch size 5, shared chunks - chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( - num_workers=4, - batch_size=5, - workers_chunks=[[0], [0], [0], [0]], - workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 95, 50)], [(0, 95, 150, 50)], [(0, 150, 200, 50)]], - ) - assert chunks_to_disable == {0: [0, 1, 3]} - - # world size = 1, batch size 5, shared chunks + # world size = 1, 4 workers sharing one chunk, different size chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=4, batch_size=5, @@ -177,17 +168,31 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ) assert chunks_to_disable == {0: [0, 1, 3]} + # world size 2, 2 workers per rank, varying batch size for batch_size in range(1, 6): - # world size = 1, batch size 5, shared chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=2, batch_size=batch_size, workers_chunks=[[0], [0], [0], [0]], workers_intervals=[ [(0, 0, 50, 50)], - [(0, 50, 95, 50)], # local_rank 0 + [(0, 50, 95, 50)], [(0, 95, 145, 50)], - [(0, 145, 205, 50)], # local_rank 1 + [(0, 145, 205, 50)], # last to access chunk 0 ], ) assert chunks_to_disable == {0: [0, 1, 2]} + + # world size 2, 2 workers per rank, sharing multiple chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=5, + workers_chunks=[[0, 1], [3, 4], [1, 2], [4, 5]], + workers_intervals=[ + [(0, 0, 50, 50), (0, 0, 50, 50)], + [(0, 0, 50, 50), (0, 0, 50, 50)], + [(0, 50, 100, 100), (0, 0, 50, 50)], + [(0, 50, 100, 100), (0, 0, 50, 50)], + ], + ) + assert chunks_to_disable == {1: [2], 4: [3]} From af74a3bc9ef79e8b7d8d0c6f8cedfdfa30f9aa4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 11:30:40 +0000 Subject: [PATCH 40/63] simplify `num_workers or 1` logic --- src/litdata/streaming/dataset.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index d51423a7..17ae2bbe 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -120,7 +120,9 @@ def __init__( self.shuffler: Optional[Shuffle] = None self.serializers = serializers self._state_dict: Optional[Dict[str, Any]] = None - self.num_workers: Optional[int] = None + # Has slightly different meaning in the context of the dataset + # We consider `num_workers = 0` from `torch.utils.DataLoader` still as 1 worker (the main process) + self.num_workers: int = 1 self.batch_size: Optional[int] = None def set_shuffle(self, shuffle: bool) -> None: @@ -177,13 +179,13 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: return FullShuffle(cache, seed, drop_last) if self.shuffle 
else NoShuffle(cache, seed, drop_last) def __len__(self) -> int: - return self.get_len(self.num_workers if self.num_workers else 1, self.batch_size if self.batch_size else 1) + return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) def set_batch_size(self, batch_size: int) -> None: self.batch_size = batch_size def set_num_workers(self, num_workers: int) -> None: - self.num_workers = num_workers + self.num_workers = num_workers or 1 def get_len(self, num_workers: int, batch_size: int) -> int: self.num_workers = num_workers @@ -229,13 +231,13 @@ def __iter__(self) -> "StreamingDataset": # premature deletion for the other workers. node_size = self.distributed_env.world_size // self.distributed_env.num_nodes first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size - num_workers_per_node = node_size * (self.num_workers or 1) + num_workers_per_node = node_size * self.num_workers worker_start = first_rank_this_node * num_workers_per_node worker_end = worker_start + num_workers_per_node local_rank = self.distributed_env.global_rank % node_size chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( - (self.num_workers or 1), + self.num_workers, self.batch_size, workers_chunks[worker_start:worker_end], workers_intervals[worker_start:worker_end], @@ -244,8 +246,7 @@ def __iter__(self) -> "StreamingDataset": chunks_indexes_skip_deletion ) - worker_rank_local_node = local_rank * (self.num_workers or 1) + self.worker_env.rank - + worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank if worker_rank_local_node in worker_node_rank_to_chunk_indexes: self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ worker_rank_local_node From 5446fd66bff9e3e9ead5e855318ec4c6a3970164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 11:53:05 +0000 Subject: [PATCH 41/63] mypy --- src/litdata/utilities/shuffle.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 31c9d9e8..6c0fdea6 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -117,7 +117,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) max_trackers = {} - to_disable = {} for chunk_index, map_local_rank_to_worker_ids in shared_chunks_aggregated_by_rank.items(): for local_rank, workers_index_sharing_chunks_for_this_rank in map_local_rank_to_worker_ids.items(): # get all the worker chunks and intervals for this distributed rank @@ -189,6 +188,7 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( if num_of_samples_to_carry_to_next_chunk is None: worker_tracker_idx += 1 + to_disable = {} for chunk_index, worker_ids in shared_chunks.items(): last_worker_idx = max_trackers[chunk_index][0] to_disable[chunk_index] = [worker_idx for worker_idx in worker_ids if worker_idx != last_worker_idx] @@ -206,7 +206,7 @@ def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: return {chunk: workers for chunk, workers in shared_chunks.items() if len(workers) > 1} -def _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) -> Dict[int, List[int]]: +def _aggregate_shared_chunks_per_rank(shared_chunks: Dict[int, List[int]], num_workers: int) -> Dict[int, List[int]]: aggregated_shared_chunks_per_rank = {} for chunk_index, workers_ids in 
shared_chunks.items(): aggregated_shared_chunks_per_rank[chunk_index] = {} @@ -217,7 +217,7 @@ def _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) -> Dict[int, L return aggregated_shared_chunks_per_rank -def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable): +def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable: Dict[int, List[int]]) -> Dict[int, List[int]]: map_node_worker_rank_to_chunk_indexes = {} for chunk_index, worker_ids in to_disable.items(): for worker_idx in worker_ids: From 44572cfee52210b04e69d6a19e43ced95a4e434a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 11:54:59 +0000 Subject: [PATCH 42/63] todo rename --- src/litdata/utilities/shuffle.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 6c0fdea6..50dd388a 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -122,7 +122,7 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( # get all the worker chunks and intervals for this distributed rank workers_slice = slice(local_rank * num_workers, (local_rank + 1) * num_workers) workers_chunks_for_this_rank = copy.deepcopy(workers_chunks[workers_slice]) - workers_intervals_for_this_rank = copy.deepcopy( # TODO: rename + workers_interval_sizes_for_this_rank = copy.deepcopy( [ [interval[2] - interval[1] for interval in worker_intervals] for worker_intervals in workers_intervals[workers_slice] @@ -136,12 +136,12 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( while True: chunks_of_currently_loaded_worker = workers_chunks_for_this_rank[worker_tracker_idx % num_workers] - intervals_of_currently_loaded_worker = workers_intervals_for_this_rank[worker_tracker_idx % num_workers] - if len(intervals_of_currently_loaded_worker) == 0: + interval_size_of_current_worker = workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers] + if len(interval_size_of_current_worker) == 0: worker_tracker_idx += 1 continue - num_samples_left_for_this_worker_chunk = intervals_of_currently_loaded_worker[0] + num_samples_left_for_this_worker_chunk = interval_size_of_current_worker[0] remover = ( batch_size @@ -151,13 +151,13 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( if num_samples_left_for_this_worker_chunk > remover: # We have consumed a batch, going to the next worker - workers_intervals_for_this_rank[worker_tracker_idx % num_workers][0] -= remover + workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers][0] -= remover counter += remover num_of_samples_to_carry_to_next_chunk = None else: # We have consumed a batch, going to the next worker current_worker_chunk_index = workers_chunks_for_this_rank[worker_tracker_idx % num_workers].pop(0) - workers_intervals_for_this_rank[worker_tracker_idx % num_workers].pop(0) + workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers].pop(0) counter += remover if current_worker_chunk_index == chunk_index: From 533ba42c214b8131c81c3385aaf6db0842fb8554 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 12:03:28 +0000 Subject: [PATCH 43/63] mypy --- src/litdata/streaming/dataset.py | 4 ++-- src/litdata/utilities/shuffle.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 17ae2bbe..333720a9 100644 --- a/src/litdata/streaming/dataset.py +++ 
b/src/litdata/streaming/dataset.py @@ -123,7 +123,7 @@ def __init__( # Has slightly different meaning in the context of the dataset # We consider `num_workers = 0` from `torch.utils.DataLoader` still as 1 worker (the main process) self.num_workers: int = 1 - self.batch_size: Optional[int] = None + self.batch_size: int = 1 def set_shuffle(self, shuffle: bool) -> None: self.shuffle = shuffle @@ -212,7 +212,7 @@ def __iter__(self) -> "StreamingDataset": self.current_epoch = state["current_epoch"] workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers( - self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch + self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch ) worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 50dd388a..4b48a714 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -206,8 +206,8 @@ def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: return {chunk: workers for chunk, workers in shared_chunks.items() if len(workers) > 1} -def _aggregate_shared_chunks_per_rank(shared_chunks: Dict[int, List[int]], num_workers: int) -> Dict[int, List[int]]: - aggregated_shared_chunks_per_rank = {} +def _aggregate_shared_chunks_per_rank(shared_chunks: Dict[int, List[int]], num_workers: int) -> Dict[int, Dict[int, List[int]]]: + aggregated_shared_chunks_per_rank: Dict[int, Dict[int, List[int]]] = {} for chunk_index, workers_ids in shared_chunks.items(): aggregated_shared_chunks_per_rank[chunk_index] = {} for worker_idx in workers_ids: @@ -218,7 +218,7 @@ def _aggregate_shared_chunks_per_rank(shared_chunks: Dict[int, List[int]], num_w def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable: Dict[int, List[int]]) -> Dict[int, List[int]]: - map_node_worker_rank_to_chunk_indexes = {} + map_node_worker_rank_to_chunk_indexes: Dict[int, List[int]] = {} for chunk_index, worker_ids in to_disable.items(): for worker_idx in worker_ids: if worker_idx not in map_node_worker_rank_to_chunk_indexes: From dba782eb493ae06475d3792021c269af1b1e11ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 12:11:35 +0000 Subject: [PATCH 44/63] Fix typeerror --- src/litdata/streaming/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 333720a9..63a43766 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -476,8 +476,8 @@ def reset(self) -> None: "random_state": None, "shuffler": None, "_state_dict": None, - "num_workers": None, - "batch_size": None, + "num_workers": 1, + "batch_size": 1, } for prop, value in default_properties.items(): From 1ea878d07c02761d489f2679cc7ca45b82510ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 12:21:22 +0000 Subject: [PATCH 45/63] debug --- tests/streaming/test_dataset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 3f07e237..0868442b 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -838,7 +838,17 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): chunk_size=190, num_workers=4, ) - assert len(os.listdir(tmpdir 
/ "optimized")) > 0 + assert set(os.listdir(tmpdir / "optimized")) == { + "chunk-0-0.bin", + "chunk-0-1.bin", + "chunk-1-0.bin", + "chunk-1-1.bin", + "chunk-2-0.bin", + "chunk-2-1.bin", + "chunk-3-0.bin", + "chunk-3-1.bin", + "index.json", + } os.mkdir(s3_cache_dir) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) From 59d870f35e8f03a20995575854d29df15e9c3864 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 12:21:56 +0000 Subject: [PATCH 46/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/shuffle.py | 4 +++- tests/streaming/test_dataset.py | 16 ++++++++-------- tests/utilities/test_shuffle.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 4b48a714..3d3ed6d2 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -206,7 +206,9 @@ def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: return {chunk: workers for chunk, workers in shared_chunks.items() if len(workers) > 1} -def _aggregate_shared_chunks_per_rank(shared_chunks: Dict[int, List[int]], num_workers: int) -> Dict[int, Dict[int, List[int]]]: +def _aggregate_shared_chunks_per_rank( + shared_chunks: Dict[int, List[int]], num_workers: int +) -> Dict[int, Dict[int, List[int]]]: aggregated_shared_chunks_per_rank: Dict[int, Dict[int, List[int]]] = {} for chunk_index, workers_ids in shared_chunks.items(): aggregated_shared_chunks_per_rank[chunk_index] = {} diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 0868442b..53cfff74 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -839,14 +839,14 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): num_workers=4, ) assert set(os.listdir(tmpdir / "optimized")) == { - "chunk-0-0.bin", - "chunk-0-1.bin", - "chunk-1-0.bin", - "chunk-1-1.bin", - "chunk-2-0.bin", - "chunk-2-1.bin", - "chunk-3-0.bin", - "chunk-3-1.bin", + "chunk-0-0.bin", + "chunk-0-1.bin", + "chunk-1-0.bin", + "chunk-1-1.bin", + "chunk-2-0.bin", + "chunk-2-1.bin", + "chunk-3-0.bin", + "chunk-3-1.bin", "index.json", } diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 9e67c592..a604cdd0 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -182,7 +182,7 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ], ) assert chunks_to_disable == {0: [0, 1, 2]} - + # world size 2, 2 workers per rank, sharing multiple chunks chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( num_workers=2, From 65903112684040b01d5848db524b9ce2fa058c0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 12:40:13 +0000 Subject: [PATCH 47/63] mypy --- src/litdata/utilities/shuffle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 3d3ed6d2..eb3819c9 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -108,7 +108,7 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( num_workers: int, batch_size: int, workers_chunks: List[List[int]], - workers_intervals: List[List[int]], + workers_intervals: List[List[Interval]], ) -> Dict[int, List[int]]: # {1: [2, 3, 4, 5]} # 
[2, 3] belongs to rank 0 @@ -135,7 +135,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( counter = 0 while True: - chunks_of_currently_loaded_worker = workers_chunks_for_this_rank[worker_tracker_idx % num_workers] interval_size_of_current_worker = workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers] if len(interval_size_of_current_worker) == 0: worker_tracker_idx += 1 From a6cf041016a6a3cbaa07ab1a1db564c2f52663df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 12:43:46 +0000 Subject: [PATCH 48/63] debug --- tests/streaming/test_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 53cfff74..d0f25f5e 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -837,6 +837,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): output_dir=str(tmpdir / "optimized"), chunk_size=190, num_workers=4, + num_uploaders=1, ) assert set(os.listdir(tmpdir / "optimized")) == { "chunk-0-0.bin", From dbdeb8a77d44838efd2cdffce6699f80a20203e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 13:49:44 +0000 Subject: [PATCH 49/63] debug --- src/litdata/processing/data_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index ba100794..02bbca42 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -255,12 +255,14 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + print(f"copying {local_filepath} to {output_filepath}") shutil.copy(local_filepath, output_filepath) else: raise ValueError(f"The provided {output_dir.path} isn't supported.") # Inform the remover to delete the file if remove_queue and os.path.exists(local_filepath): + print(f"putting {local_filepath} on the remove queue") remove_queue.put([local_filepath]) From e0720d7aa404079cbf125f1ccfd2bb907487af03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 13:59:42 +0000 Subject: [PATCH 50/63] debug --- tests/streaming/test_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index d0f25f5e..1ebf1d9e 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -839,6 +839,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): num_workers=4, num_uploaders=1, ) + sleep(5) # wait for copier/remover threads to complete assert set(os.listdir(tmpdir / "optimized")) == { "chunk-0-0.bin", "chunk-0-1.bin", From c20a0ec69e11225d05fa9f2d14a038ac108dc04a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 14:11:27 +0000 Subject: [PATCH 51/63] debug --- .github/workflows/ci-testing.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index b6f0d707..d75fed0c 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -76,6 +76,10 @@ jobs: pip install -e . 
-r requirements/test.txt -U -q --find-links $TORCH_URL pip list + # FIXME: REMOVE + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 + - name: Tests run: coverage run --source litdata -m pytest tests -v From 57559988c038d1197498b3e5493562d810c65710 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 14:40:08 +0000 Subject: [PATCH 52/63] debug --- .github/workflows/ci-testing.yml | 4 ++-- tests/conftest.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d75fed0c..b02f3423 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -77,8 +77,8 @@ jobs: pip list # FIXME: REMOVE - - name: Setup tmate session - uses: mxschmitt/action-tmate@v3 + # - name: Setup tmate session + # uses: mxschmitt/action-tmate@v3 - name: Tests run: coverage run --source litdata -m pytest tests -v diff --git a/tests/conftest.py b/tests/conftest.py index 17e47ed8..c575f29c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import sys +import threading from types import ModuleType from unittest.mock import Mock @@ -65,3 +66,26 @@ def lightning_sdk_mock(monkeypatch): lightning_sdk = ModuleType("lightning_sdk") monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk) return lightning_sdk + + +@pytest.fixture(autouse=True) +def thread_police(): + """Attempts to stop left-over threads to avoid test interactions. + + Adapted from PyTorch Lightning. + """ + active_threads_before = set(threading.enumerate()) + yield + active_threads_after = set(threading.enumerate()) + + for thread in active_threads_after - active_threads_before: + stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) + if thread.daemon and callable(stop): + # A daemon thread would anyway be stopped at the end of a program + # We do it preemptively here to reduce the risk of interactions with other tests that run after + stop() + assert not thread.is_alive() + elif thread.name == "QueueFeederThread": + thread.join(timeout=20) + else: + raise AssertionError(f"Test left zombie thread: {thread}") From a9c688fc37aae6d18c51baf0af8a737a3da3401b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:41:11 +0000 Subject: [PATCH 53/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index c575f29c..7bdf2106 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,7 @@ def thread_police(): """Attempts to stop left-over threads to avoid test interactions. Adapted from PyTorch Lightning. 
+ """ active_threads_before = set(threading.enumerate()) yield From 18afb5e3ca1f6fd4dde7ea15a0701b277841a4c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 14:49:53 +0000 Subject: [PATCH 54/63] debug --- tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 7bdf2106..58522008 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,5 +88,7 @@ def thread_police(): assert not thread.is_alive() elif thread.name == "QueueFeederThread": thread.join(timeout=20) + elif thread.name == "PrepareChunksThread": + thread.force_stop() else: raise AssertionError(f"Test left zombie thread: {thread}") From 6b04c22e0328add402b34d23486a73db36fc4ca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 14:57:50 +0000 Subject: [PATCH 55/63] debug --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 58522008..a5fb4d21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,7 +69,7 @@ def lightning_sdk_mock(monkeypatch): @pytest.fixture(autouse=True) -def thread_police(): +def _thread_police(): """Attempts to stop left-over threads to avoid test interactions. Adapted from PyTorch Lightning. @@ -88,7 +88,7 @@ def thread_police(): assert not thread.is_alive() elif thread.name == "QueueFeederThread": thread.join(timeout=20) - elif thread.name == "PrepareChunksThread": + elif "PrepareChunksThread" in thread.name: thread.force_stop() else: raise AssertionError(f"Test left zombie thread: {thread}") From 46daa792644ae525f70c889684d94d2c80064933 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 15:04:42 +0000 Subject: [PATCH 56/63] debug --- tests/conftest.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a5fb4d21..cc9bfdeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import pytest import torch.distributed +from litdata.streaming.reader import PrepareChunksThread @pytest.fixture(autouse=True) @@ -80,6 +81,10 @@ def _thread_police(): active_threads_after = set(threading.enumerate()) for thread in active_threads_after - active_threads_before: + if isinstance(thread, PrepareChunksThread): + thread.force_stop() + continue + stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) if thread.daemon and callable(stop): # A daemon thread would anyway be stopped at the end of a program @@ -88,7 +93,5 @@ def _thread_police(): assert not thread.is_alive() elif thread.name == "QueueFeederThread": thread.join(timeout=20) - elif "PrepareChunksThread" in thread.name: - thread.force_stop() else: raise AssertionError(f"Test left zombie thread: {thread}") From aafab96b818f254ca776fc550d249b985849318f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:06:06 +0000 Subject: [PATCH 57/63] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index cc9bfdeb..d12ff2f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -84,7 +84,7 @@ def _thread_police(): if isinstance(thread, PrepareChunksThread): thread.force_stop() continue - + stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) if thread.daemon 
and callable(stop): # A daemon thread would anyway be stopped at the end of a program From d27fb340c4b81d00f2a5ab2c9ca074e5f39fa368 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 15:24:08 +0000 Subject: [PATCH 58/63] debug --- tests/streaming/test_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 1ebf1d9e..5a971069 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -827,20 +827,22 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") + optimize_data_cache_dir = str(tmpdir / "optimize_data_cache") optimize_cache_dir = str(tmpdir / "optimize_cache") data_dir = str(tmpdir / "optimized") + monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", optimize_data_cache_dir) monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir) optimize( fn=_simple_preprocess, inputs=list(range(8)), - output_dir=str(tmpdir / "optimized"), + output_dir=data_dir, chunk_size=190, num_workers=4, num_uploaders=1, ) sleep(5) # wait for copier/remover threads to complete - assert set(os.listdir(tmpdir / "optimized")) == { + assert set(os.listdir(data_dir)) == { "chunk-0-0.bin", "chunk-0-1.bin", "chunk-1-0.bin", From 06bf41413968060e295b18f7b8adb6f68970f9f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 15:33:06 +0000 Subject: [PATCH 59/63] debug --- tests/streaming/test_dataset.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 5a971069..fba51389 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -838,21 +838,21 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): inputs=list(range(8)), output_dir=data_dir, chunk_size=190, - num_workers=4, + num_workers=1, num_uploaders=1, ) sleep(5) # wait for copier/remover threads to complete - assert set(os.listdir(data_dir)) == { - "chunk-0-0.bin", - "chunk-0-1.bin", - "chunk-1-0.bin", - "chunk-1-1.bin", - "chunk-2-0.bin", - "chunk-2-1.bin", - "chunk-3-0.bin", - "chunk-3-1.bin", - "index.json", - } + # assert set(os.listdir(data_dir)) == { + # "chunk-0-0.bin", + # "chunk-0-1.bin", + # "chunk-1-0.bin", + # "chunk-1-1.bin", + # "chunk-2-0.bin", + # "chunk-2-1.bin", + # "chunk-3-0.bin", + # "chunk-3-1.bin", + # "index.json", + # } os.mkdir(s3_cache_dir) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) From bc64b7728d704b1de8786d404ccf88618b931410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 15:44:19 +0000 Subject: [PATCH 60/63] debug --- .github/workflows/ci-testing.yml | 4 ---- src/litdata/processing/data_processor.py | 2 -- tests/conftest.py | 30 ------------------------ tests/streaming/test_dataset.py | 2 +- 4 files changed, 1 insertion(+), 37 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 1acab9fb..080e70b5 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -78,10 +78,6 @@ jobs: pip install -e . 
-r requirements/test.txt -U -q --find-links $TORCH_URL pip list - # FIXME: REMOVE - # - name: Setup tmate session - # uses: mxschmitt/action-tmate@v3 - - name: Tests run: coverage run --source litdata -m pytest tests -v diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 02bbca42..ba100794 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -255,14 +255,12 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints os.makedirs(os.path.dirname(output_filepath), exist_ok=True) - print(f"copying {local_filepath} to {output_filepath}") shutil.copy(local_filepath, output_filepath) else: raise ValueError(f"The provided {output_dir.path} isn't supported.") # Inform the remover to delete the file if remove_queue and os.path.exists(local_filepath): - print(f"putting {local_filepath} on the remove queue") remove_queue.put([local_filepath]) diff --git a/tests/conftest.py b/tests/conftest.py index d12ff2f6..17e47ed8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,9 @@ import sys -import threading from types import ModuleType from unittest.mock import Mock import pytest import torch.distributed -from litdata.streaming.reader import PrepareChunksThread @pytest.fixture(autouse=True) @@ -67,31 +65,3 @@ def lightning_sdk_mock(monkeypatch): lightning_sdk = ModuleType("lightning_sdk") monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk) return lightning_sdk - - -@pytest.fixture(autouse=True) -def _thread_police(): - """Attempts to stop left-over threads to avoid test interactions. - - Adapted from PyTorch Lightning. - - """ - active_threads_before = set(threading.enumerate()) - yield - active_threads_after = set(threading.enumerate()) - - for thread in active_threads_after - active_threads_before: - if isinstance(thread, PrepareChunksThread): - thread.force_stop() - continue - - stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) - if thread.daemon and callable(stop): - # A daemon thread would anyway be stopped at the end of a program - # We do it preemptively here to reduce the risk of interactions with other tests that run after - stop() - assert not thread.is_alive() - elif thread.name == "QueueFeederThread": - thread.join(timeout=20) - else: - raise AssertionError(f"Test left zombie thread: {thread}") diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index fba51389..2299b3eb 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -838,7 +838,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): inputs=list(range(8)), output_dir=data_dir, chunk_size=190, - num_workers=1, + num_workers=1, # TODO: Want 4 here, but optimize() has deletion race condition num_uploaders=1, ) sleep(5) # wait for copier/remover threads to complete From ac17f3e124768c4030fcfb6ccaff93de7bc68e3c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 19 Jul 2024 12:28:38 -0400 Subject: [PATCH 61/63] Update src/litdata/utilities/shuffle.py Co-authored-by: thomas chaton --- src/litdata/utilities/shuffle.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index eb3819c9..0ea50727 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -110,9 +110,6 @@ def 
_find_chunks_per_workers_on_which_to_skip_deletion( workers_chunks: List[List[int]], workers_intervals: List[List[Interval]], ) -> Dict[int, List[int]]: - # {1: [2, 3, 4, 5]} - # [2, 3] belongs to rank 0 - # [4, 5] belongs to rank 1 shared_chunks = _get_shared_chunks(workers_chunks) shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) From 66017e8de3f2482098e8c50ac83fe160af3d1185 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 19 Jul 2024 19:36:41 +0200 Subject: [PATCH 62/63] comments and test --- src/litdata/utilities/shuffle.py | 39 ++++++++++++++++++++++++++++---- tests/utilities/test_shuffle.py | 28 +++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 0ea50727..0526906d 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -110,13 +110,24 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( workers_chunks: List[List[int]], workers_intervals: List[List[Interval]], ) -> Dict[int, List[int]]: + """Returns a dictionary mapping a chunk index to a list of workers that should not delete that chunk. + + If a worker is included in this list, it should not delete the chunk after fully reading it, because another worker + will still have items left to read and therefore needs the chunk to be present. This mapping is used in the dataset + to only let the worker delete a chunk when that worker is the last to read from it. + + """ + + # Shared chunks across all workers and ranks shared_chunks = _get_shared_chunks(workers_chunks) + + # Shared chunks grouped together by rank shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) max_trackers = {} for chunk_index, map_local_rank_to_worker_ids in shared_chunks_aggregated_by_rank.items(): for local_rank, workers_index_sharing_chunks_for_this_rank in map_local_rank_to_worker_ids.items(): - # get all the worker chunks and intervals for this distributed rank + # Get all the worker chunks and intervals for this distributed rank workers_slice = slice(local_rank * num_workers, (local_rank + 1) * num_workers) workers_chunks_for_this_rank = copy.deepcopy(workers_chunks[workers_slice]) workers_interval_sizes_for_this_rank = copy.deepcopy( @@ -139,6 +150,8 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( num_samples_left_for_this_worker_chunk = interval_size_of_current_worker[0] + # To consume a batch, we want to subtract `batch_size` from the size we have left, + # unless we had a remainder (< batch size) from the previous iteration/chunk remover = ( batch_size if num_of_samples_to_carry_to_next_chunk is None @@ -146,12 +159,13 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( ) if num_samples_left_for_this_worker_chunk > remover: - # We have consumed a batch, going to the next worker + # There are samples left to consume, so we subtract the batch size (or a remainder) workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers][0] -= remover counter += remover num_of_samples_to_carry_to_next_chunk = None else: - # We have consumed a batch, going to the next worker + # There are fewer samples left in this chunk than we would like to consume for a full batch + # So we take what's left from the chunk and move to the next chunk to complete the batch current_worker_chunk_index = workers_chunks_for_this_rank[worker_tracker_idx % num_workers].pop(0) workers_interval_sizes_for_this_rank[worker_tracker_idx % 
num_workers].pop(0) counter += remover @@ -159,9 +173,11 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( if current_worker_chunk_index == chunk_index: num_shared_workers_for_this_rank -= 1 - # We consumed entirely the chunk of the worker we were tracking, let's break # TODO: Maybe, we can prevent loading over and over for each worker if num_shared_workers_for_this_rank == 0 and current_worker_chunk_index == chunk_index: + # We consumed entirely the chunk of the worker we were tracking + # Keep track of how many samples this worker consumed for this chunk and which worker + # has consumed the most samples for this chunk if chunk_index not in max_trackers: max_trackers[chunk_index] = ( local_rank * num_workers + worker_tracker_idx % num_workers, @@ -176,12 +192,18 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( break if num_samples_left_for_this_worker_chunk != batch_size: + # If a batch was not assembled completely because we're at the end of a chunk, + # we need to complete the assembly from samples in the next chunk and carry + # over that remainder to the next loop iteration num_of_samples_to_carry_to_next_chunk = batch_size - num_samples_left_for_this_worker_chunk if remover != batch_size: + # We've handled the remainder, reset it. Next iteration will start a fresh batch. num_of_samples_to_carry_to_next_chunk = None if num_of_samples_to_carry_to_next_chunk is None: + # Only go to the next worker if we assembled a full batch. If we have a remainder, + # we need to go to the next chunk with the same worker and complete the batch. worker_tracker_idx += 1 to_disable = {} @@ -192,6 +214,7 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: + """Returns a dictionary mapping a chunk index to a list of workers that share that same chunk.""" shared_chunks = {} for worker, chunks in enumerate(workers_chunks): for chunk in chunks: @@ -199,12 +222,18 @@ def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: shared_chunks[chunk] = [worker] else: shared_chunks[chunk].append(worker) + # Remove chunk indexes that are only read by a single worker (and thus not shared) return {chunk: workers for chunk, workers in shared_chunks.items() if len(workers) > 1} def _aggregate_shared_chunks_per_rank( shared_chunks: Dict[int, List[int]], num_workers: int ) -> Dict[int, Dict[int, List[int]]]: + """Groups together shared chunks by rank. + + The output is a dictionary mapping a chunk index to a dictionary that maps a rank to a list of workers. 
+ + """ aggregated_shared_chunks_per_rank: Dict[int, Dict[int, List[int]]] = {} for chunk_index, workers_ids in shared_chunks.items(): aggregated_shared_chunks_per_rank[chunk_index] = {} @@ -216,6 +245,8 @@ def _aggregate_shared_chunks_per_rank( def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable: Dict[int, List[int]]) -> Dict[int, List[int]]: + """Takes a dictionary mapping a chunk index to a list of workers and inverts the map such that it returns a + dictionary mapping a worker to a list of chunk indexes (that should not be deleted by that worker).""" map_node_worker_rank_to_chunk_indexes: Dict[int, List[int]] = {} for chunk_index, worker_ids in to_disable.items(): for worker_idx in worker_ids: diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index a604cdd0..2d26e7eb 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -1,10 +1,12 @@ from litdata.streaming.item_loader import Interval from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( + _aggregate_shared_chunks_per_rank, _associate_chunks_and_internals_to_workers, _find_chunks_per_workers_on_which_to_skip_deletion, _get_shared_chunks, _intra_node_chunk_shuffle, + _map_node_worker_rank_to_chunk_indexes_to_not_delete, ) @@ -196,3 +198,29 @@ def test_find_chunks_per_workers_on_which_to_skip_deletion(): ], ) assert chunks_to_disable == {1: [2], 4: [3]} + + +def test_aggregate_shared_chunks_per_rank(): + # world size = 1, num workers per rank = 1 + num_workers = 1 + shared_chunks = {0: [0], 1: [0], 2: [0]} # 3 chunks shared by 1 worker + expected = {0: {0: [0]}, 1: {0: [0]}, 2: {0: [0]}} + assert _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) == expected + + # world size = 1, num workers per rank = 2 + num_workers = 2 + shared_chunks = {0: [0, 1], 1: [0, 1], 2: [0]} # 3 chunks shared by 2 workers + expected = {0: {0: [0, 1]}, 1: {0: [0, 1]}, 2: {0: [0]}} + assert _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) == expected + + # world size = 4, num workers per rank = 2 + num_workers = 2 + shared_chunks = {0: [0, 2], 1: [1, 3], 2: [2, 3]} # 3 chunks distributed among 2 * 2 workers + expected = {0: {0: [0], 1: [2]}, 1: {0: [1], 1: [3]}, 2: {1: [2, 3]}} + assert _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) == expected + + +def test_map_node_worker_rank_to_chunk_indexes_to_not_delete(): + chunks_to_workers = {10: [2, 3, 4], 20: [1, 2, 3], 30: [3, 4], 40: [4, 5, 6]} + workers_to_chunks = _map_node_worker_rank_to_chunk_indexes_to_not_delete(chunks_to_workers) + assert workers_to_chunks == {1: [20], 2: [10, 20], 3: [10, 20, 30], 4: [10, 30, 40], 5: [40], 6: [40]} From e2e9ff8e29bc2aa8a91d1a04f794a97fdf02d189 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Jul 2024 17:42:29 +0000 Subject: [PATCH 63/63] internals -> intervals --- src/litdata/streaming/shuffle.py | 8 ++++---- src/litdata/utilities/shuffle.py | 2 +- tests/streaming/test_dataset.py | 4 ++-- tests/utilities/test_shuffle.py | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index fb5ed25b..a61e2f24 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -20,7 +20,7 @@ from litdata.streaming import Cache from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( - _associate_chunks_and_internals_to_workers, + 
_associate_chunks_and_intervals_to_workers, _intra_node_chunk_shuffle, ) @@ -70,7 +70,7 @@ def get_chunks_and_intervals_per_workers( indexes = range(len(chunk_intervals)) # 2. Compute the items budget of each rank - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size ) return workers_chunks, workers_intervals @@ -118,7 +118,7 @@ def get_chunks_and_intervals_per_workers( shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() # 3. Compute the items budget of each rank - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) @@ -131,7 +131,7 @@ def get_chunks_and_intervals_per_workers( shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 0526906d..a8ab3930 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -43,7 +43,7 @@ def _intra_node_chunk_shuffle( return [index for chunks in chunk_indexes_per_nodes for index in chunks] -def _associate_chunks_and_internals_to_workers( +def _associate_chunks_and_intervals_to_workers( distributed_env: _DistributedEnv, indexes: Any, chunk_intervals: List[Interval], diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 2299b3eb..7b64f863 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -39,7 +39,7 @@ from litdata.streaming.shuffle import FullShuffle, NoShuffle from litdata.utilities import dataset_utilities as dataset_utilities_module from litdata.utilities.env import _DistributedEnv, _WorkerEnv -from litdata.utilities.shuffle import _associate_chunks_and_internals_to_workers +from litdata.utilities.shuffle import _associate_chunks_and_intervals_to_workers from torch.utils.data import DataLoader @@ -1000,7 +1000,7 @@ def test_replay_sampling(): def test_replay_chunks_sampling(): chunks_replica = range(10) intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)] - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(2, 0, 1), chunks_replica, intervals_replica ) assert workers_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 2d26e7eb..15532a21 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -2,7 +2,7 @@ from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( _aggregate_shared_chunks_per_rank, - _associate_chunks_and_internals_to_workers, + _associate_chunks_and_intervals_to_workers, _find_chunks_per_workers_on_which_to_skip_deletion, _get_shared_chunks, 
_intra_node_chunk_shuffle, @@ -24,7 +24,7 @@ def test_intra_node_chunk_shuffle(): assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] -def test_associate_chunks_and_internals_to_workers(): +def test_associate_chunks_and_intervals_to_workers(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] chunk_intervals = [ Interval(0, 0, 50, 50), @@ -37,7 +37,7 @@ def test_associate_chunks_and_internals_to_workers(): Interval(0, 0, 50, 50), ] - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, @@ -63,7 +63,7 @@ def test_associate_chunks_and_internals_to_workers(): Interval(0, 0, 33, 33), ] - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, @@ -94,7 +94,7 @@ def test_associate_chunks_and_internals_to_workers(): Interval(0, 0, 1, 1), ] - workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals,