[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 10, 2024
1 parent e3b565b commit b0251a3
Showing 1 changed file with 7 additions and 8 deletions.
tests/streaming/test_dataloader.py: 15 changes (7 additions, 8 deletions)
@@ -11,12 +11,13 @@
 
 
 def compress(index):
-    return (index, index ** 2)
+    return (index, index**2)
 
 
 def another_compress(index):
     return (index, index * 2)
 
+
 class TestStatefulDataset:
     def __init__(self, size, step):
         self.size = size
@@ -213,7 +214,7 @@ def test_resume_single_dataset_dataloader_from_checkpoint(tmpdir):
         break
 
     # load the state dict
-    with open(os.path.join(tmpdir, "state_dict.json"), "r") as f:
+    with open(os.path.join(tmpdir, "state_dict.json")) as f:
         state_dict = json.load(f)
 
     # create a new dataloader
@@ -242,14 +243,13 @@ def test_resume_combined_dataset_dataloader_from_checkpoint(tmpdir):
     )
     optimize(
         fn=another_compress,
-        inputs=list(range(10,20)),
+        inputs=list(range(10, 20)),
         num_workers=2,
         output_dir=output_dir_2,
         chunk_size=3,
     )
 
-    ds = CombinedStreamingDataset(
-        [StreamingDataset(output_dir_1), StreamingDataset(output_dir_2)], seed=42)
+    ds = CombinedStreamingDataset([StreamingDataset(output_dir_1), StreamingDataset(output_dir_2)], seed=42)
 
     dataloader = StreamingDataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
 
@@ -261,12 +261,11 @@ def test_resume_combined_dataset_dataloader_from_checkpoint(tmpdir):
         break
 
     # load the state dict
-    with open(os.path.join(tmpdir, "state_dict.json"), "r") as f:
+    with open(os.path.join(tmpdir, "state_dict.json")) as f:
         state_dict = json.load(f)
 
     # create a new dataloader
-    ds = CombinedStreamingDataset(
-        [StreamingDataset(output_dir_1), StreamingDataset(output_dir_2)], seed=42)
+    ds = CombinedStreamingDataset([StreamingDataset(output_dir_1), StreamingDataset(output_dir_2)], seed=42)
     dataloader = StreamingDataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
     dataloader.load_state_dict(state_dict)
 
