Skip to content

Commit

Permalink
update replay test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jul 17, 2024
1 parent e621185 commit 0641666
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/streaming/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from litdata.streaming.shuffle import FullShuffle, NoShuffle
from litdata.utilities import dataset_utilities as dataset_utilities_module
from litdata.utilities.env import _DistributedEnv, _WorkerEnv
from litdata.utilities.shuffle import _associate_chunks_and_internals_to_workers
from torch.utils.data import DataLoader


Expand Down Expand Up @@ -985,14 +986,15 @@ def test_replay_sampling():
def test_replay_chunks_sampling():
chunks_replica = range(10)
intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)]
workers_chunks, workers_intervals = _associate_chunks_to_workers(
_WorkerEnv(2, 0), chunks_replica, intervals_replica
workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers(
_DistributedEnv(2, 0, 1), chunks_replica, intervals_replica
)
assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
assert workers_intervals == {
0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)],
1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)],
}
assert workers_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
assert workers_intervals == [
[[0, 0, 5, 5], [5, 5, 10, 10], [10, 10, 15, 15], [15, 15, 20, 20], [20, 20, 25, 25]],
[[25, 25, 30, 30], [30, 30, 35, 35], [35, 35, 40, 40], [40, 40, 45, 45], [45, 45, 50, 50]]
]
workers_intervals = {i: workers_intervals[i] for i in range(len(workers_intervals))}
assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})
Expand Down

0 comments on commit 0641666

Please sign in to comment.