
Commit

update test
awaelchli committed Jul 10, 2024
1 parent ea41a03 commit 14db351
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions tests/streaming/test_dataset.py
@@ -799,29 +799,35 @@ def _simple_preprocess(_):
         yield torch.randint(0, 100, size=(10,), dtype=torch.int64)


-def _get_simulated_s3_dataloader(tmpdir):
+def _get_simulated_s3_dataloader(cache_dir, data_dir):
     dataset = EmulateS3StreamingDataset(
-        input_dir=Dir(str(tmpdir / "s3cache"), str(tmpdir / "optimized")),
+        input_dir=Dir(cache_dir, data_dir),
         item_loader=TokensLoader(block_size=10),
     )
     return StreamingDataLoader(dataset, batch_size=2, num_workers=1)


 @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
+@mock.patch.dict(os.environ, {}, clear=True)
 def test_dataset_resume_on_future_chunks(tmpdir):
     """This test is constructed to test resuming from a chunk past the first chunk, when
     subsequent chunks don't have the same size."""
+    s3_cache_dir = str(tmpdir / "s3cache")
+    data_dir = str(tmpdir / "optimized")
+
     optimize(
         fn=_simple_preprocess,
         inputs=list(range(8)),
         output_dir=str(tmpdir / "optimized"),
         chunk_size=190,
         num_workers=4,
     )
-    assert len(os.listdir(tmpdir / "optimized")) == 9  # 8 chunks + 1 index file
+    assert len(os.listdir(tmpdir / "optimized")) > 1

-    os.mkdir(tmpdir / "s3cache")
-    shutil.rmtree("/cache/chunks", ignore_errors=True)  # TODO
+    os.mkdir(s3_cache_dir)
+    shutil.rmtree("/cache/chunks", ignore_errors=True)

-    train_dataloader = _get_simulated_s3_dataloader(tmpdir)
+    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
     batches_to_fetch = 16
     batch_to_resume_from = None
     for i, batch in enumerate(train_dataloader):
@@ -831,11 +837,12 @@ def test_dataset_resume_on_future_chunks(tmpdir):
             batch_to_resume_from = batch
             break

-    shutil.rmtree(tmpdir / "s3cache")
-    os.mkdir(tmpdir / "s3cache")
+    shutil.rmtree(s3_cache_dir)
+    os.mkdir(s3_cache_dir)
+    shutil.rmtree("/cache/chunks", ignore_errors=True)
-    train_dataloader = _get_simulated_s3_dataloader(tmpdir)
+    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
     train_dataloader.load_state_dict(dataloader_state)
     # The next batch after resuming must match what we should have gotten next in the initial loop
     assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
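For context, here is a minimal sketch (not part of this commit) of the save/resume pattern the test exercises: write optimized chunks, consume part of an epoch, capture the dataloader state, then recreate the dataloader and resume. The import paths, output directory name, and helper function are assumptions for illustration and may need adjusting to the actual package layout.

# Hypothetical sketch of the checkpoint/resume flow exercised by the test above.
# Assumes top-level litdata exports; adjust imports to the package under test.
import torch
from litdata import StreamingDataLoader, StreamingDataset, TokensLoader, optimize


def _sample_preprocess(_):
    # Mirrors _simple_preprocess: yield several token tensors per input.
    for _ in range(10):
        yield torch.randint(0, 100, size=(10,), dtype=torch.int64)


def run(output_dir="optimized_data"):
    # Write optimized chunks to disk (8 inputs with a small chunk size -> several chunks).
    optimize(
        fn=_sample_preprocess,
        inputs=list(range(8)),
        output_dir=output_dir,
        chunk_size=190,
        num_workers=4,
    )

    def make_loader():
        dataset = StreamingDataset(output_dir, item_loader=TokensLoader(block_size=10))
        return StreamingDataLoader(dataset, batch_size=2, num_workers=1)

    # Consume part of an epoch and capture the loader state mid-stream.
    loader = make_loader()
    state = None
    for i, _batch in enumerate(loader):
        if i == 16:
            state = loader.state_dict()
            break

    # Recreate the loader (e.g. after an interruption) and resume from the saved state;
    # the first batch yielded now is the one that would have followed batch 16.
    loader = make_loader()
    loader.load_state_dict(state)
    return next(iter(loader))


if __name__ == "__main__":
    resumed_batch = run()
    print(resumed_batch.shape)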
