Skip to content

Commit

Permalink
chore: Add tests for CombinedStreamingDataset in test_dataloader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Sep 3, 2024
1 parent 3d31c24 commit 0ba50c5
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/streaming/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,44 @@ def test_resume_dataloader_with_new_dataset(tmpdir):
dataloader.load_state_dict(dataloader_state)
for _ in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"


@pytest.mark.timeout(120)
def test_combined_dataset_dataloader_states(tmpdir):
    """Exercise StreamingDataLoader state handling over a CombinedStreamingDataset.

    Two 50-item datasets are written through ``Cache``, wrapped in shuffled
    ``StreamingDataset`` objects and combined. With ``batch_size=4`` the test
    checks the loader length, the size of the first batch, the ``restore``
    flag before and after ``load_state_dict`` and the epoch counter during a
    partial pass over the data.
    """
    length_message = "Dataloader length should be 25 (50+50 items / batch size 4)"

    # Materialise two datasets on disk, each holding the integers 0..49.
    dataset_dirs = [str(tmpdir.join(f"dataset_{index}")) for index in range(2)]
    for dataset_dir in dataset_dirs:
        cache = Cache(input_dir=dataset_dir, chunk_bytes="64MB")
        for value in range(50):
            cache[value] = value
        cache.done()
        cache.merge()

    streams = [StreamingDataset(dataset_dir, shuffle=True) for dataset_dir in dataset_dirs]
    combined_dataset = CombinedStreamingDataset(datasets=streams)

    # Test dataloader without explicit num workers
    loader = StreamingDataLoader(combined_dataset, batch_size=4)
    assert not loader.restore
    # Round-trip the loader's own state; the restore flag is expected to be set afterwards.
    initial_state = loader.state_dict()
    loader.load_state_dict(initial_state)
    assert loader.restore
    first_batch = next(iter(loader))
    assert len(first_batch) == 4, "Batch size should be 4"
    assert len(loader) == 25, length_message

    # Test dataloader with num workers
    loader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=2)
    assert len(loader) == 25, length_message

    # Verify dataloader state after partial iteration: stop once the batch at index 10 is reached.
    for step, _ in enumerate(loader):
        assert loader.current_epoch == 1, "Current epoch should be 1"
        if step == 10:
            break

    partial_state = loader.state_dict()
    loader.load_state_dict(partial_state)
    assert loader.restore

    # Verify remaining batches in the first epoch

    # TODO: Add more conditions to check the state of the dataloader

0 comments on commit 0ba50c5

Please sign in to comment.