add tests for resuming dataloader from a checkpoint
deependujha committed Jul 10, 2024
1 parent d259cd9 commit e3b565b
Showing 1 changed file with 97 additions and 1 deletion.
98 changes: 97 additions & 1 deletion tests/streaming/test_dataloader.py
@@ -1,13 +1,22 @@
+import json
 import os
 
 import pytest
 import torch
 from litdata.constants import _VIZ_TRACKER_AVAILABLE
-from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader
+from litdata.processing.functions import optimize
+from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
 from litdata.streaming import dataloader as streaming_dataloader_module
 from torch import tensor
 
 
+def compress(index):
+    return (index, index ** 2)
+
+
+def another_compress(index):
+    return (index, index * 2)
+
 class TestStatefulDataset:
     def __init__(self, size, step):
         self.size = size
@@ -180,3 +189,90 @@ def test_custom_collate_multiworker():
 
     # Try calling the state_dict. No error should follow
     _state_dict = dataloader.state_dict()
+
+
+def test_resume_single_dataset_dataloader_from_checkpoint(tmpdir):
+    # create an optimized dataset
+    output_dir = os.path.join(tmpdir, "output_dir")
+    optimize(
+        fn=compress,
+        inputs=list(range(10)),
+        num_workers=2,
+        output_dir=output_dir,
+        chunk_size=3,
+    )
+
+    ds = StreamingDataset(output_dir)
+    dataloader = StreamingDataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
+
+    for i, batch in enumerate(dataloader):
+        if i == 2:
+            curr_state_dict = dataloader.state_dict()
+            with open(os.path.join(tmpdir, "state_dict.json"), "w") as f:
+                json.dump(curr_state_dict, f)
+            break
+
+    # load the state dict
+    with open(os.path.join(tmpdir, "state_dict.json"), "r") as f:
+        state_dict = json.load(f)
+
+    # create a new dataloader
+    ds = StreamingDataset(output_dir)
+    dataloader = StreamingDataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
+    dataloader.load_state_dict(state_dict)
+
+    count = 0
+    for _batch in dataloader:
+        count += 1
+
+    # 5 batches total: 3 were consumed before checkpointing, so 2 remain on resume
+    assert count == 2
+
+
+def test_resume_combined_dataset_dataloader_from_checkpoint(tmpdir):
+    # create two optimized datasets
+    output_dir_1 = os.path.join(tmpdir, "output_dir_1")
+    output_dir_2 = os.path.join(tmpdir, "output_dir_2")
+    optimize(
+        fn=compress,
+        inputs=list(range(10)),
+        num_workers=2,
+        output_dir=output_dir_1,
+        chunk_size=3,
+    )
+    optimize(
+        fn=another_compress,
+        inputs=list(range(10, 20)),
+        num_workers=2,
+        output_dir=output_dir_2,
+        chunk_size=3,
+    )
+
+    ds = CombinedStreamingDataset(
+        [StreamingDataset(output_dir_1), StreamingDataset(output_dir_2)], seed=42)
+
+    dataloader = StreamingDataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
+
+    for i, batch in enumerate(dataloader):
+        if i == 2:
+            curr_state_dict = dataloader.state_dict()
+            with open(os.path.join(tmpdir, "state_dict.json"), "w") as f:
+                json.dump(curr_state_dict, f)
+            break
+
+    # load the state dict
+    with open(os.path.join(tmpdir, "state_dict.json"), "r") as f:
+        state_dict = json.load(f)
+
+    # create a new dataloader
+    ds = CombinedStreamingDataset(
+        [StreamingDataset(output_dir_1), StreamingDataset(output_dir_2)], seed=42)
+    dataloader = StreamingDataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
+    dataloader.load_state_dict(state_dict)
+
+    count = 0
+    for _batch in dataloader:
+        count += 1
+
+    # 3 batches were consumed before checkpointing; 7 remain on resume
+    assert count == 7
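
Both tests exercise the same save/resume pattern end to end: serialize `StreamingDataLoader.state_dict()` to JSON mid-epoch, rebuild the loader, and restore it with `load_state_dict()`. Below is a minimal sketch of that pattern outside of pytest, using only the APIs shown in the diff; the `square` helper, the paths, and the cutoff point are illustrative stand-ins for the fixtures the tests use.

```python
import json

from litdata.processing.functions import optimize
from litdata.streaming import StreamingDataLoader, StreamingDataset

OUTPUT_DIR = "output_dir"          # illustrative path (the tests use pytest's tmpdir)
CHECKPOINT = "state_dict.json"     # illustrative path


def square(index):  # hypothetical stand-in for `compress` in the tests above
    return (index, index ** 2)


if __name__ == "__main__":
    # Write an optimized (chunked) dataset once.
    optimize(fn=square, inputs=list(range(10)), num_workers=2, output_dir=OUTPUT_DIR, chunk_size=3)

    # First run: consume a few batches, then snapshot the loader state to JSON.
    dataloader = StreamingDataLoader(StreamingDataset(OUTPUT_DIR), batch_size=2, num_workers=2)
    for i, _batch in enumerate(dataloader):
        if i == 2:  # pretend the job is interrupted after the third batch
            with open(CHECKPOINT, "w") as f:
                json.dump(dataloader.state_dict(), f)
            break

    # Second run: rebuild the loader and restore the saved state; iteration
    # resumes with the remaining batches (2 of 5 here) instead of restarting.
    dataloader = StreamingDataLoader(StreamingDataset(OUTPUT_DIR), batch_size=2, num_workers=2)
    with open(CHECKPOINT) as f:
        dataloader.load_state_dict(json.load(f))
    remaining = sum(1 for _ in dataloader)
    print(f"batches after resume: {remaining}")  # expected: 2
```

The new tests themselves can be run with `pytest tests/streaming/test_dataloader.py -k resume`.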
