diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c23df3fd..cf91420f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -505,20 +505,22 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir): assert len(dataloader) == 10 expected = [ - [0, 10], - [40, 50], - [20, 30], - [60, 70], - [80, 90], - [120, 130], - [100, 110], - [140, 150], - [160, 170], + [0, 10], + [100, 110], + [20, 30], + [120, 130], + [40, 50], + [140, 150], + [60, 70], + [160, 170], + [80, 90], [180, 190], ] - for result, batch in zip(expected, dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == result + result = [] + for batch in dataloader: + result.append(batch[:, 0].tolist()) + assert result == expected def test_dataset_for_text_tokens_distributed_num_workers(tmpdir): @@ -623,7 +625,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk expected = [[0, 10], [40, 50], [20, 30], [60, 70]] returned = [] - for batch_idx, batch in enumerate(dataloader): + for batch in dataloader: returned.append(batch[:, 0].tolist()) assert returned == expected @@ -638,7 +640,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk expected = [[80, 90], [120, 130], [100, 110], [140, 150]] returned = [] - for batch_idx, batch in enumerate(dataloader): + for batch in dataloader: returned.append(batch[:, 0].tolist()) assert returned == expected