Feat: adds support for reading mosaic mds written dataset (#210)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deependu Jha <[email protected]>
3 people authored Jul 8, 2024
1 parent 70dc4a4 commit 14a47f2
Showing 10 changed files with 231 additions and 27 deletions.
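
As context before the file-by-file diff, here is a minimal sketch of the workflow this commit enables, condensed from the test added in `tests/streaming/test_dataset.py` below. The `data/` output directory is illustrative, and the sketch assumes `mosaicml-streaming` (added to the test requirements here) plus a litdata build containing this change.

```python
import numpy as np
from PIL import Image
from streaming import MDSWriter  # mosaicml-streaming

from litdata import StreamingDataset

# Write a tiny dataset in Mosaic MDS format.
columns = {"image": "jpeg", "class": "int"}
with MDSWriter(out="data/", columns=columns, compression="zstd") as out:
    for i in range(10):
        out.write({"image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)), "class": i})

# Read it back with litdata; the MDS index.json is adapted to the chunk format on the fly.
dataset = StreamingDataset(input_dir="data/")
assert len(dataset) == 10
assert dataset[3]["class"] == 3
```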
4 changes: 2 additions & 2 deletions README.md
@@ -174,7 +174,7 @@ ld.map(

**Key benefits:**

Paralellize processing: Reduce processing time by transforming data across multiple machines simultaneously.
Parallelize processing: Reduce processing time by transforming data across multiple machines simultaneously.
✅ Scale to large data: Increase the size of datasets you can efficiently handle.
✅ Flexible usecases: Resize images, create embeddings, scrape the internet, etc...
✅ Run local or cloud: Run on your own machines or auto-scale to 1000s of cloud GPUs with Lightning Studios.
@@ -638,7 +638,7 @@ Time to optimize 1.2 million ImageNet images (Faster is better):

----

# Paralellize transforms and data optimization on cloud machines
# Parallelize transforms and data optimization on cloud machines
<div align="center">
<img alt="Lightning" src="https://pl-flash-data.s3.amazonaws.com/data-prep.jpg" width="700px">
</div>
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -1,4 +1,5 @@
coverage ==7.5.3
mosaicml-streaming==0.7.6
pytest ==8.2.*
pytest-cov ==5.0.0
pytest-timeout ==2.3.1
4 changes: 2 additions & 2 deletions src/litdata/processing/data_processor.py
@@ -51,6 +51,7 @@
from litdata.streaming.resolver import _resolve_dir
from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.packing import _pack_greedily

if _TQDM_AVAILABLE:
@@ -788,8 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

if num_nodes == node_rank + 1:
with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
config = json.load(f)
config = load_index_file(cache_dir)

size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
22 changes: 11 additions & 11 deletions src/litdata/streaming/config.py
@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Any, Dict, List, Optional, Tuple

@@ -22,6 +21,7 @@
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import tree_unflatten, treespec_loads
from litdata.utilities.dataset_utilities import load_index_file


class ChunksConfig:
@@ -53,18 +53,18 @@ def __init__(
self._remote_dir = remote_dir
self._item_loader = item_loader or PyTreeLoader()

with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
data = json.load(f)
_original_chunks = data["chunks"]
self._config = data["config"]
self._validate_item_loader()
# load data from `index.json` file
data = load_index_file(self._cache_dir)
_original_chunks = data["chunks"]
self._config = data["config"]
self._validate_item_loader()

assert _original_chunks is not None
assert _original_chunks is not None

if subsampled_files is None:
self._chunks = _original_chunks
else:
self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
if subsampled_files is None:
self._chunks = _original_chunks
else:
self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)

self._config["data_spec"] = treespec_loads(self._config["data_spec"])

24 changes: 24 additions & 0 deletions src/litdata/streaming/item_loader.py
@@ -142,8 +142,32 @@ def load_item_from_chunk(
fp.seek(begin)
data = fp.read(end - begin)

# check for mosaic mds format
if "format" in self._config and self._config["format"] == "mds":
return self.mds_deserialize(data, chunk_index)
return self.deserialize(data)

def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree":
"""Deserialize the mds raw bytes into their python equivalent."""
idx = 0
sizes = []
column_sizes = self._chunks[chunk_index]["column_sizes"]
# adapted from: MDSReader.deserialize : https://github.com/mosaicml/streaming/blob/main/streaming/base/format/mds/reader.py
for size in column_sizes:
if size:
sizes.append(size)
else:
(size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32)
sizes.append(size)
idx += 4
data = []
for size, data_format in zip(sizes, self._data_format):
serializer = self._serializers[data_format]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
idx += size
return tree_unflatten(data, self._config["data_spec"])

def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = self._shift_idx
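
To make the new `mds_deserialize` path above easier to follow, here is a small self-contained sketch of the per-sample byte layout it assumes, based on the MDS reader linked in the code comment: the uint32 lengths of the variable-size columns come first, in column order, followed by every column's raw bytes; fixed-size columns contribute no length prefix. Column names and payloads are made up, and the final serializer step (e.g. JPEG decoding) is omitted.

```python
import numpy as np

# Two columns, as in the tests below: "class" is a fixed 8-byte int, "image" is variable-length.
column_sizes = [8, None]  # a fixed byte size, or None for a variable-length column
payloads = [np.int64(7).tobytes(), b"\xff\xd8 not a real jpeg"]

# Encode: uint32 lengths of the variable-size columns first, then all column data concatenated.
size_prefix = b"".join(
    np.uint32(len(p)).tobytes() for p, s in zip(payloads, column_sizes) if s is None
)
raw_item_data = size_prefix + b"".join(payloads)

# Decode, mirroring mds_deserialize: recover each column's size, then slice the data region.
idx, sizes = 0, []
for size in column_sizes:
    if size:
        sizes.append(size)
    else:
        (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32)
        sizes.append(int(size))
        idx += 4

decoded = []
for size in sizes:
    decoded.append(raw_item_data[idx : idx + size])
    idx += size

assert np.frombuffer(decoded[0], np.int64)[0] == 7
assert decoded[1] == b"\xff\xd8 not a real jpeg"
```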
78 changes: 75 additions & 3 deletions src/litdata/utilities/dataset_utilities.py
@@ -51,9 +51,8 @@ def subsample_streaming_dataset(

if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
# load chunks from `index.json` file
with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
data = json.load(f)
original_chunks = data["chunks"]
data = load_index_file(input_dir.path)
original_chunks = data["chunks"]
else:
raise ValueError(
f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file."
@@ -115,3 +114,76 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa
roi.append((0, end))

return roi


def load_index_file(input_dir: str) -> Dict[str, Any]:
"""Load index file from the specified input directory.
This function supports loading both chunk-based and mds shard-based index files.
For shard-based files, it adapts the format to be compatible with chunk-based processing.
Args:
input_dir (str): The directory containing the index file.
Returns:
Dict[str, Any]: The loaded and possibly adapted index data.
Raises:
FileNotFoundError: If the index file does not exist in the input directory.
"""
index_filepath = os.path.join(input_dir, _INDEX_FILENAME)
try:
with open(index_filepath) as f:
data = json.load(f)

if "chunks" not in data and "shards" in data:
# load mds shard-based index file and adapt to chunks format
return adapt_mds_shards_to_chunks(data)

return data
except FileNotFoundError:
raise FileNotFoundError(f"Index file not found at {index_filepath}.")


def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]:
"""Adapt mds shard-based index data to chunk-based format for compatibility.
For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming
Args:
data (Dict[str, Any]): The original index data containing shards.
Returns:
Dict[str, Any]: Adapted index data with chunks format.
"""
chunks = []
shards = data["shards"]
for shard in shards:
chunks.append(
{
"chunk_bytes": shard["zip_data"]["bytes"],
"chunk_size": shard["samples"],
"column_sizes": shard["column_sizes"],
"dim": None,
"filename": shard["zip_data"]["basename"],
}
)
data["chunks"] = chunks

data_spec = [
1,
{
"type": "builtins.dict",
"context": json.dumps(shards[0]["column_names"]),
"children_spec": [{"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"]],
},
]
data["config"] = {
"chunk_bytes": sum(shard["zip_data"]["bytes"] for shard in shards),
"chunk_size": sum(shard["samples"] for shard in shards),
"compression": shards[0]["compression"],
"data_format": shards[0]["column_encodings"],
"format": shards[0]["format"],
"data_spec": json.dumps(data_spec),
}
return data
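
For a concrete sense of the mapping performed by `adapt_mds_shards_to_chunks` above, here is a short sketch; the shard entry mirrors the `mosaic_mds_index_data` fixture added in `tests/conftest.py` below, and it assumes a litdata build containing this commit.

```python
from litdata.utilities.dataset_utilities import adapt_mds_shards_to_chunks

# A single-shard MDS index, reduced to the fields the adapter actually reads.
mds_index = {
    "version": 2,
    "shards": [
        {
            "column_encodings": ["int", "jpeg"],
            "column_names": ["class", "image"],
            "column_sizes": [8, None],
            "compression": "zstd",
            "format": "mds",
            "samples": 100,
            "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}},
        }
    ],
}

adapted = adapt_mds_shards_to_chunks(mds_index)
chunk = adapted["chunks"][0]
assert chunk["filename"] == "shard.00000.mds.zstd"  # zip_data["basename"]
assert chunk["chunk_size"] == 100                   # samples per shard
assert chunk["chunk_bytes"] == 63407                # zip_data["bytes"]
assert adapted["config"]["data_format"] == ["int", "jpeg"]  # column_encodings
assert adapted["config"]["format"] == "mds"
```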
16 changes: 7 additions & 9 deletions src/litdata/utilities/train_test_split.py
@@ -1,4 +1,3 @@
import json
import logging
import os
from copy import deepcopy
@@ -8,6 +7,7 @@

from litdata import StreamingDataset
from litdata.constants import _INDEX_FILENAME
from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi


@@ -55,14 +55,12 @@ def train_test_split(

if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
# load chunks from `index.json` file
with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
data = json.load(f)
original_chunks = data["chunks"]
subsampled_chunks = [
_org_chunk
for _org_chunk in original_chunks
if _org_chunk["filename"] in dummy_subsampled_chunk_filename
]
data = load_index_file(input_dir.path)

original_chunks = data["chunks"]
subsampled_chunks = [
_org_chunk for _org_chunk in original_chunks if _org_chunk["filename"] in dummy_subsampled_chunk_filename
]
else:
raise ValueError("Couldn't load original chunk file.")

22 changes: 22 additions & 0 deletions tests/conftest.py
@@ -8,3 +8,25 @@ def teardown_process_group(): # noqa: PT004
yield
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()


@pytest.fixture()
def mosaic_mds_index_data():
return {
"shards": [
{
"column_encodings": ["int", "jpeg"],
"column_names": ["class", "image"],
"column_sizes": [8, None],
"compression": "zstd",
"format": "mds",
"hashes": [],
"raw_data": {"basename": "shard.00000.mds", "bytes": 125824, "hashes": {}},
"samples": 100,
"size_limit": 67108864,
"version": 2,
"zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}},
}
],
"version": 2,
}
67 changes: 67 additions & 0 deletions tests/streaming/test_dataset.py
@@ -21,6 +21,7 @@
import numpy as np
import pytest
import torch
from litdata import train_test_split
from litdata.constants import _ZSTD_AVAILABLE
from litdata.processing import functions
from litdata.streaming import Cache
@@ -995,3 +996,69 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch):
)

assert len(dataset2) == int(len(dataset1) * 0.4)


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
def test_dataset_with_mosaic_mds_data(tmpdir):
from PIL import Image
from streaming import MDSWriter
# example taken from: https://github.com/mosaicml/streaming

# A dictionary mapping input fields to their data types
columns = {"image": "jpeg", "class": "int"}
# Shard compression, if any
compression = "zstd"
# Save the samples as shards using MDSWriter
with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out:
for i in range(10):
sample = {
"image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
"class": i,
}
out.write(sample)
dataset = StreamingDataset(input_dir=str(tmpdir))
assert len(dataset) == 10
for i in range(10):
sample = dataset[i]
assert sample["class"] == i

assert [sample["class"] for sample in dataset[:]] == list(range(10)) # test slicing

# -------------- train_test_split --------------

train_ds, test_ds, val_ds = train_test_split(dataset, splits=[0.7, 0.2, 0.1])

assert len(train_ds) == 7
assert len(test_ds) == 2
assert len(val_ds) == 1

# -------------- subsample --------------

dataset = StreamingDataset(input_dir=str(tmpdir), subsample=0.4)
assert len(dataset) == 4
assert [sample["class"] for sample in dataset[:]] == [0, 1, 2, 3]

# -------------- works with dataloader --------------

dataset = StreamingDataset(input_dir=str(tmpdir))
dataloader = DataLoader(dataset, batch_size=4, drop_last=True)
i = 0
for batch in dataloader:
assert len(batch["class"]) == 4
assert len(batch["image"]) == 4
assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
i += 1

dataloader = DataLoader(dataset, batch_size=4, drop_last=False)
i = 0
for batch in dataloader:
if i == 2:
# last batch is smaller than batch_size
assert len(batch["class"]) == 2
assert len(batch["image"]) == 2
assert list(batch["class"]) == [4 * i, 4 * i + 1]
break
assert len(batch["class"]) == 4
assert len(batch["image"]) == 4
assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
i += 1
20 changes: 20 additions & 0 deletions tests/utilities/test_dataset_utilities.py
@@ -1,10 +1,14 @@
import json
import os
from unittest import mock

from litdata.constants import _INDEX_FILENAME
from litdata.utilities.dataset_utilities import (
_should_replace_path,
_try_create_cache_dir,
adapt_mds_shards_to_chunks,
generate_roi,
load_index_file,
)


@@ -44,3 +48,19 @@ def test_generate_roi():
my_roi = generate_roi(my_chunks)

assert my_roi == [(0, 30), (0, 50), (0, 20), (0, 10)]


def test_load_index_file(tmpdir, mosaic_mds_index_data):
with open(os.path.join(tmpdir, _INDEX_FILENAME), "w") as f:
f.write(json.dumps(mosaic_mds_index_data))
index_data = load_index_file(tmpdir)
assert "chunks" in index_data
assert "config" in index_data
assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"])


def test_adapt_mds_shards_to_chunks(mosaic_mds_index_data):
adapted_data = adapt_mds_shards_to_chunks(mosaic_mds_index_data)
assert "chunks" in adapted_data
assert "config" in adapted_data
assert len(mosaic_mds_index_data["shards"]) == len(adapted_data["chunks"])
