Feat: checkpoint optimize function to restart after crash (#206)
deependujha authored Jul 5, 2024
1 parent 1616aeb commit 6ebb7b9
Showing 10 changed files with 501 additions and 27 deletions.
273 changes: 262 additions & 11 deletions src/litdata/processing/data_processor.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion src/litdata/processing/functions.py
@@ -311,6 +311,7 @@ def optimize(
reader: Optional[BaseReader] = None,
batch_size: Optional[int] = None,
mode: Optional[Literal["append", "overwrite"]] = None,
use_checkpoint: bool = False,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.
@@ -336,6 +337,8 @@ def optimize(
batch_size: Group the inputs into batches of batch_size length.
mode: The mode to use when writing the data. Accepts either ``append`` or ``overwrite`` or None.
Defaults to None.
use_checkpoint: Whether to create checkpoints while processing the data, which can be used to resume the
processing from the last checkpoint if the process is interrupted. (`Default: False`)
"""
if mode is not None and mode not in ["append", "overwrite"]:
@@ -377,7 +380,7 @@ def optimize(
" HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)

_assert_dir_has_index_file(_output_dir, mode=mode)
_assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint)

if not isinstance(inputs, StreamingDataLoader):
resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs))
@@ -412,6 +415,7 @@ def optimize(
reorder_files=reorder_files,
reader=reader,
state_dict=state_dict,
use_checkpoint=use_checkpoint,
)

with optimize_dns_context(True):
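For context (not part of this diff), here is a minimal sketch of how the new flag is intended to be used; the user function, input range, and output path are hypothetical, and the resume behaviour follows the docstring and tests in this commit:

from litdata import optimize

def to_sample(index):
    # hypothetical user function: turn one input into one optimized sample
    return index, index**2

# First run: if the process crashes partway through, progress is recorded under
# `<output_dir>/.checkpoints` because use_checkpoint=True.
optimize(
    fn=to_sample,
    inputs=list(range(1000)),
    output_dir="my_optimized_dataset",
    chunk_size=100,
    use_checkpoint=True,
)

# Re-running the exact same call after a crash resumes from the last checkpoint
# instead of reprocessing every input; the `.checkpoints` folder is removed once
# the run completes (see the tests below).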
34 changes: 34 additions & 0 deletions src/litdata/processing/utilities.py
@@ -249,3 +249,37 @@ def extract_rank_and_index_from_filename(chunk_filename: str) -> Tuple[int, int]
index = int(chunk_filename[1].split(".")[0])

return rank, index


def remove_uuid_from_filename(filepath: str) -> str:
"""Remove the unique id from the filepath. Expects the filepath to be in the format
`checkpoint-<rank>-<uuid>.json`.
e.g.: `checkpoint-0-9fe2c4e93f654fdbb24c02b15259716c.json`
-> `checkpoint-0.json`
"""

if not filepath.__contains__(".checkpoints"):
return filepath

# the uuid is 32 characters, '.json' is 5, and the separating '-' is 1 (38 characters in total)
return filepath[:-38] + ".json"


def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str:
s3_resource = boto3.resource("s3")
bucket = s3_resource.Bucket(bucket_name)

saved_file_dir = "."

for obj in bucket.objects.filter(Prefix=remote_directory_name):
local_filename = os.path.join(local_directory_name, obj.key)

if not os.path.exists(os.path.dirname(local_filename)):
os.makedirs(os.path.dirname(local_filename))
with open(local_filename, "wb") as f:
s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f)
saved_file_dir = os.path.dirname(local_filename)

return saved_file_dir
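
A quick illustration (not part of the diff) of the filename helper added above; the sample paths mirror the format described in its docstring:

from litdata.processing.utilities import remove_uuid_from_filename

# `checkpoint-<rank>-<uuid>.json` -> `checkpoint-<rank>.json`
# (the 32-character uuid, its separating '-', and '.json' add up to 38 characters).
assert (
    remove_uuid_from_filename(".checkpoints/checkpoint-0-9fe2c4e93f654fdbb24c02b15259716c.json")
    == ".checkpoints/checkpoint-0.json"
)

# Paths that do not contain a `.checkpoints` directory are returned unchanged.
assert remove_uuid_from_filename("output/chunk-0-0.bin") == "output/chunk-0-0.bin"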
4 changes: 4 additions & 0 deletions src/litdata/streaming/cache.py
@@ -152,3 +152,7 @@ def get_chunk_intervals(self) -> List[Interval]:

def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
return self._reader._get_chunk_index_from_index(index)

def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]:
"""Save the current state of the writer to a checkpoint."""
return self._writer.save_checkpoint(checkpoint_dir=checkpoint_dir)
25 changes: 19 additions & 6 deletions src/litdata/streaming/resolver.py
@@ -14,7 +14,9 @@
import datetime
import os
import re
import shutil
import sys
from contextlib import suppress
from dataclasses import dataclass
from pathlib import Path
from time import sleep
@@ -246,7 +248,9 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool
)


def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None) -> None:
def _assert_dir_has_index_file(
output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False
) -> None:
if mode is not None and mode not in ["append", "overwrite"]:
raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.")

@@ -273,10 +277,17 @@ def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append",

# delete index.json file and chunks
if os.path.exists(os.path.join(output_dir.path, "index.json")):
# only possible if mode = "overwrite"
os.remove(os.path.join(output_dir.path, "index.json"))
for file in os.listdir(output_dir.path):
if file.endswith(".bin"):
os.remove(os.path.join(output_dir.path, file))

if mode == "overwrite" or (mode is None and not use_checkpoint):
for file in os.listdir(output_dir.path):
if file.endswith(".bin"):
os.remove(os.path.join(output_dir.path, file))

# delete checkpoints
with suppress(FileNotFoundError):
shutil.rmtree(os.path.join(output_dir.path, ".checkpoints"))

return

@@ -316,8 +327,10 @@ def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append",
# Delete all the files (including the index file in overwrite mode)
bucket_name = obj.netloc
s3 = boto3.resource("s3")
for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix):
s3.Object(bucket_name, obj.key).delete()

if mode == "overwrite" or (mode is None and not use_checkpoint):
for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix):
s3.Object(bucket_name, obj.key).delete()


def _get_lightning_cloud_url() -> str:
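To make the new guard explicit, here is a simplified sketch (not the library code) of when previously written chunk files are cleared from the output directory:

def should_clear_existing_chunks(mode, use_checkpoint):
    # Mirrors the condition above: clear existing `.bin` chunks when overwriting,
    # or when starting a plain run (mode=None) without checkpointing. With
    # mode=None and use_checkpoint=True, existing chunks are kept so an
    # interrupted run can resume from its `.checkpoints` state.
    return mode == "overwrite" or (mode is None and not use_checkpoint)

assert should_clear_existing_chunks("overwrite", use_checkpoint=False) is True
assert should_clear_existing_chunks(None, use_checkpoint=False) is True
assert should_clear_existing_chunks(None, use_checkpoint=True) is False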
28 changes: 26 additions & 2 deletions src/litdata/streaming/writer.py
@@ -13,6 +13,7 @@

import json
import os
import uuid
import warnings
from dataclasses import dataclass
from time import sleep
@@ -64,7 +65,7 @@ def __init__(
"""
self._cache_dir = cache_dir

os.makedirs(self._cache_dir, exist_ok=True)
if (isinstance(self._cache_dir, str) and not os.path.exists(self._cache_dir)) or self._cache_dir is None:
raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.")

@@ -103,6 +104,7 @@ def __init__(

self._per_sample_num_bytes = 0
self._per_sample_num_items = 0
self.last_checkpoint_chunk_info: List[Dict[str, Any]] = []

@property
def filled(self) -> bool:
@@ -458,7 +460,7 @@ def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Option
elif config != data["config"]:
raise Exception(
"The config isn't consistent between chunks. This shouldn't have happened."
f"Found {config} {data['config']}."
f"Found {config}; {data['config']}."
)

chunks_info.extend(data["chunks"])
@@ -494,3 +496,25 @@ def _pretty_serialized_items(self) -> Dict[int, Item]:
data=b"",
)
return out

def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]:
"""Save the current state of the writer to a checkpoint."""
checkpoint_dir = os.path.join(self._cache_dir, checkpoint_dir)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)

if self._chunks_info == self.last_checkpoint_chunk_info:
# to avoid saving the same checkpoint twice
return None

unique_id = uuid.uuid4().hex
done_till_index = sum(chunk_info["chunk_size"] for chunk_info in self._chunks_info)

checkpoint_filepath = os.path.join(checkpoint_dir, f"checkpoint-{self.rank}-{unique_id}.json")

checkpoint = {"chunks": self._chunks_info, "config": self.get_config(), "done_till_index": done_till_index}

with open(checkpoint_filepath, "w") as f:
json.dump(checkpoint, f)

return checkpoint_filepath
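
For reference (not part of the diff), the payload written by `save_checkpoint` has roughly the shape below. The top-level keys come from the code above; the per-chunk fields other than `chunk_size`, and the config values, are illustrative assumptions:

import json

checkpoint_payload = {
    "chunks": [
        {"chunk_size": 5, "filename": "chunk-0-0.bin"},  # fields besides chunk_size are assumed
        {"chunk_size": 5, "filename": "chunk-0-1.bin"},
    ],
    "config": {"chunk_size": 5, "chunk_bytes": None, "compression": None},  # illustrative
    "done_till_index": 10,  # sum of chunk_size over all chunks written so far
}
print(json.dumps(checkpoint_payload, indent=2))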
86 changes: 85 additions & 1 deletion tests/processing/test_functions.py
@@ -54,7 +54,17 @@ def different_compress(index):
return index, index**2, index**3


@pytest.mark.skipif(sys.platform == "win32" and sys.platform == "darwin", reason="too slow")
def fn(i: int):
if i in [1, 2, 4]:
raise ValueError("An error occurred")
return i, i**2


def another_fn(i: int):
return i, i**2


@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow")
def test_optimize_append_overwrite(tmpdir):
output_dir = str(tmpdir / "output_dir")

@@ -157,6 +167,80 @@ def test_optimize_append_overwrite(tmpdir):
assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]


@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow")
def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
output_dir = str(tmpdir / "output_dir")

with pytest.raises(RuntimeError, match="We found the following error"):
optimize(
fn=fn,
inputs=list(range(4)),
output_dir=output_dir,
chunk_size=1,
num_workers=2,
use_checkpoint=True,
)

# check that the checkpoints are created
assert os.path.exists(os.path.join(output_dir, ".checkpoints"))
assert os.path.exists(os.path.join(output_dir, ".checkpoints", "config.json"))

optimize(
fn=another_fn,
inputs=list(range(4)),
output_dir=output_dir,
chunk_size=1,
num_workers=2,
use_checkpoint=True,
)

ds = StreamingDataset(output_dir)

assert len(ds) == 4
assert ds[:] == [(i, i**2) for i in range(4)]
# checkpoints should be deleted
assert not os.path.exists(os.path.join(output_dir, ".checkpoints"))

# --------- now test for append mode ---------

with pytest.raises(RuntimeError, match="We found the following error"):
optimize(
fn=fn,
inputs=list(range(4, 8)),
output_dir=output_dir,
chunk_size=1,
num_workers=2,
use_checkpoint=True,
mode="append",
)

# check that the checkpoints are created
assert os.path.exists(os.path.join(output_dir, ".checkpoints"))
assert os.path.exists(os.path.join(output_dir, ".checkpoints", "config.json"))
print("-" * 80)
# print all the files in the checkpoints folder
for f in os.listdir(os.path.join(output_dir, ".checkpoints")):
print(f)
print("-" * 80)

optimize(
fn=another_fn,
inputs=list(range(4, 8)),
output_dir=output_dir,
chunk_size=1,
num_workers=2,
use_checkpoint=True,
mode="append",
)

ds = StreamingDataset(output_dir)

assert len(ds) == 8
assert ds[:] == [(i, i**2) for i in range(8)]
# checkpoints should be deleted
assert not os.path.exists(os.path.join(output_dir, ".checkpoints"))


def test_merge_datasets(tmpdir):
folder_1 = os.path.join(tmpdir, "folder_1")
folder_2 = os.path.join(tmpdir, "folder_2")
26 changes: 26 additions & 0 deletions tests/processing/test_utilities.py
@@ -6,6 +6,7 @@
extract_rank_and_index_from_filename,
optimize_dns_context,
read_index_file_content,
remove_uuid_from_filename,
)
from litdata.streaming.resolver import _resolve_dir

@@ -84,3 +85,28 @@ def test_read_index_file_content(tmpdir):
json.dump(dummy_dict, f)

assert read_index_file_content(_resolve_dir(str(output_dir))) == dummy_dict


def test_remove_uuid_from_filename():
filepaths = [
"checkpoint-0-9fe2c4e93f654fdbb24c02b15259716c.json",
"checkpoint-1-9fe2c4e93f654fdbb24c02b15259716c.json",
"checkpoint-2-9fe2c4e93f654fdbb24c02b15259716c.json",
"checkpoint-101-9fe2c4e93f654fdbb24c02b15259716c.json",
"checkpoint-12-9fe2c4e93f654fdbb24c02b15259716c.json",
"checkpoint-267-9fe2c4e93f654fdbb24c02b15259716c.json",
]

expected = [
"checkpoint-0.json",
"checkpoint-1.json",
"checkpoint-2.json",
"checkpoint-101.json",
"checkpoint-12.json",
"checkpoint-267.json",
]

for idx, filepath in enumerate(filepaths):
filepath = ".checkpoints/" + filepath
result = remove_uuid_from_filename(filepath)
assert result == ".checkpoints/" + expected[idx]
20 changes: 20 additions & 0 deletions tests/streaming/test_cache.py
@@ -320,3 +320,23 @@ def test_cache_for_text_tokens(tmpdir):

with pytest.raises(ValueError, match="TokensLoader"):
len(Cache(str(tmpdir), chunk_size=block_size * 11))


def test_cache_checkpoint(tmpdir):
cache_dir = os.path.join(tmpdir, "cache_checkpoint")
os.makedirs(cache_dir)

cache = Cache(cache_dir, chunk_bytes=90)

# you encode data
for i in range(100):
cache[i] = i

# I am done, write the index ...
cache.done()
cache.merge()
cache.save_checkpoint()

for file in os.listdir(os.path.join(cache_dir, ".checkpoints")):
assert file.__contains__("checkpoint-0")
assert file.endswith(".json")
26 changes: 20 additions & 6 deletions tests/streaming/test_writer.py
@@ -37,9 +37,6 @@ def seed_everything(random_seed):


def test_binary_writer_with_ints_and_chunk_bytes(tmpdir):
with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
BinaryWriter("dontexists", {})

match = (
"The provided compression something_else isn't available"
if _ZSTD_AVAILABLE
@@ -81,9 +78,6 @@ def test_binary_writer_with_ints_and_chunk_bytes(tmpdir):
def test_binary_writer_with_ints_and_chunk_size(tmpdir):
seed_everything(42)

with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
BinaryWriter("dontexists", {})

match = (
"The provided compression something_else isn't available"
if _ZSTD_AVAILABLE
@@ -252,3 +246,23 @@ def test_writer_unordered_indexes(tmpdir):
assert data["chunks"][0]["chunk_size"] == 5
assert data["chunks"][1]["chunk_size"] == 5
assert data["chunks"][2]["chunk_size"] == 2


def test_writer_save_checkpoint(tmpdir):
cache_dir = os.path.join(tmpdir, "chunks")
os.makedirs(cache_dir, exist_ok=True)

binary_writer = BinaryWriter(cache_dir, chunk_size=5)

arr = [2, 3, 1, 4, 6, 5, 7, 8, 11, 9, 10, 12]

for i in arr:
binary_writer[i] = i - 1

binary_writer.done()
binary_writer.merge()
binary_writer.save_checkpoint()

for file in os.listdir(os.path.join(cache_dir, ".checkpoints")):
assert file.__contains__("checkpoint-0")
assert file.endswith(".json")
