diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 0183f3ae..b42162ca 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -39,4 +39,5 @@
 from litdata.streaming.shuffle import FullShuffle, NoShuffle
 from litdata.utilities import dataset_utilities as dataset_utilities_module
 from litdata.utilities.env import _DistributedEnv, _WorkerEnv
+from litdata.utilities.shuffle import _associate_chunks_and_internals_to_workers
 from torch.utils.data import DataLoader
@@ -985,14 +986,15 @@ def test_replay_sampling():
 def test_replay_chunks_sampling():
     chunks_replica = range(10)
     intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)]
-    workers_chunks, workers_intervals = _associate_chunks_to_workers(
-        _WorkerEnv(2, 0), chunks_replica, intervals_replica
+    workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers(
+        _DistributedEnv(2, 0, 1), chunks_replica, intervals_replica
     )
-    assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
-    assert workers_intervals == {
-        0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)],
-        1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)],
-    }
+    assert workers_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+    assert workers_intervals == [
+        [[0, 0, 5, 5], [5, 5, 10, 10], [10, 10, 15, 15], [15, 15, 20, 20], [20, 20, 25, 25]],
+        [[25, 25, 30, 30], [30, 30, 35, 35], [35, 35, 40, 40], [40, 40, 45, 45], [45, 45, 50, 50]]
+    ]
+    workers_intervals = {i: workers_intervals[i] for i in range(len(workers_intervals))}
     assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
     assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
     assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})