Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jul 17, 2024
1 parent f8fe74c commit 0eec395
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions tests/streaming/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,20 +505,22 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir):
assert len(dataloader) == 10

expected = [
[0, 10],
[40, 50],
[20, 30],
[60, 70],
[80, 90],
[120, 130],
[100, 110],
[140, 150],
[160, 170],
[0, 10],
[100, 110],
[20, 30],
[120, 130],
[40, 50],
[140, 150],
[60, 70],
[160, 170],
[80, 90],
[180, 190],
]

for result, batch in zip(expected, dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == result
result = []
for batch in dataloader:
result.append(batch[:, 0].tolist())
assert result == expected


def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
Expand Down Expand Up @@ -623,7 +625,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk

expected = [[0, 10], [40, 50], [20, 30], [60, 70]]
returned = []
for batch_idx, batch in enumerate(dataloader):
for batch in dataloader:
returned.append(batch[:, 0].tolist())
assert returned == expected

Expand All @@ -638,7 +640,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk

expected = [[80, 90], [120, 130], [100, 110], [140, 150]]
returned = []
for batch_idx, batch in enumerate(dataloader):
for batch in dataloader:
returned.append(batch[:, 0].tolist())
assert returned == expected

Expand Down

0 comments on commit 0eec395

Please sign in to comment.