From 14a47f21575abd045115ad3c43a766d08b102f02 Mon Sep 17 00:00:00 2001 From: Bhimraj Yadav Date: Mon, 8 Jul 2024 12:59:22 +0545 Subject: [PATCH] Feat: adds support for reading mosaic mds written dataset (#210) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu Jha --- README.md | 4 +- requirements/test.txt | 1 + src/litdata/processing/data_processor.py | 4 +- src/litdata/streaming/config.py | 22 +++--- src/litdata/streaming/item_loader.py | 24 +++++++ src/litdata/utilities/dataset_utilities.py | 78 +++++++++++++++++++++- src/litdata/utilities/train_test_split.py | 16 ++--- tests/conftest.py | 22 ++++++ tests/streaming/test_dataset.py | 67 +++++++++++++++++++ tests/utilities/test_dataset_utilities.py | 20 ++++++ 10 files changed, 231 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 0b6e6575..003eac48 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ ld.map( **Key benefits:** -✅ Paralellize processing: Reduce processing time by transforming data across multiple machines simultaneously. +✅ Parallelize processing: Reduce processing time by transforming data across multiple machines simultaneously. ✅ Scale to large data: Increase the size of datasets you can efficiently handle. ✅ Flexible usecases: Resize images, create embeddings, scrape the internet, etc... ✅ Run local or cloud: Run on your own machines or auto-scale to 1000s of cloud GPUs with Lightning Studios. @@ -638,7 +638,7 @@ Time to optimize 1.2 million ImageNet images (Faster is better): ---- -# Paralellize transforms and data optimization on cloud machines +# Parallelize transforms and data optimization on cloud machines
Lightning
diff --git a/requirements/test.txt b/requirements/test.txt index d8240755..71de5307 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,5 @@ coverage ==7.5.3 +mosaicml-streaming==0.7.6 pytest ==8.2.* pytest-cov ==5.0.0 pytest-timeout ==2.3.1 diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index d44662ae..d2220a00 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -51,6 +51,7 @@ from litdata.streaming.resolver import _resolve_dir from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads from litdata.utilities.broadcast import broadcast_object +from litdata.utilities.dataset_utilities import load_index_file from litdata.utilities.packing import _pack_greedily if _TQDM_AVAILABLE: @@ -788,8 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul self._upload_index(output_dir, cache_dir, num_nodes, node_rank) if num_nodes == node_rank + 1: - with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f: - config = json.load(f) + config = load_index_file(cache_dir) size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]]) num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]]) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index c5966552..b9744f0a 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os from typing import Any, Dict, List, Optional, Tuple @@ -22,6 +21,7 @@ from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import tree_unflatten, treespec_loads +from litdata.utilities.dataset_utilities import load_index_file class ChunksConfig: @@ -53,18 +53,18 @@ def __init__( self._remote_dir = remote_dir self._item_loader = item_loader or PyTreeLoader() - with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f: - data = json.load(f) - _original_chunks = data["chunks"] - self._config = data["config"] - self._validate_item_loader() + # load data from `index.json` file + data = load_index_file(self._cache_dir) + _original_chunks = data["chunks"] + self._config = data["config"] + self._validate_item_loader() - assert _original_chunks is not None + assert _original_chunks is not None - if subsampled_files is None: - self._chunks = _original_chunks - else: - self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks) + if subsampled_files is None: + self._chunks = _original_chunks + else: + self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks) self._config["data_spec"] = treespec_loads(self._config["data_spec"]) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index d302214d..2213aaca 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -142,8 +142,32 @@ def load_item_from_chunk( fp.seek(begin) data = fp.read(end - begin) + # check for mosaic mds format + if "format" in self._config and self._config["format"] == "mds": + return self.mds_deserialize(data, chunk_index) return self.deserialize(data) + def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree": + """Deserialize the mds raw bytes into their python equivalent.""" + idx = 0 + sizes = [] + column_sizes = self._chunks[chunk_index]["column_sizes"] + # adapted from: MDSReader.deserialize : https://github.com/mosaicml/streaming/blob/main/streaming/base/format/mds/reader.py + for size in column_sizes: + if size: + sizes.append(size) + else: + (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32) + sizes.append(size) + idx += 4 + data = [] + for size, data_format in zip(sizes, self._data_format): + serializer = self._serializers[data_format] + data_bytes = raw_item_data[idx : idx + size] + data.append(serializer.deserialize(data_bytes)) + idx += size + return tree_unflatten(data, self._config["data_spec"]) + def deserialize(self, raw_item_data: bytes) -> "PyTree": """Deserialize the raw bytes into their python equivalent.""" idx = self._shift_idx diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 65978e66..7e0c4838 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -51,9 +51,8 @@ def subsample_streaming_dataset( if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)): # load chunks from `index.json` file - with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f: - data = json.load(f) - original_chunks = data["chunks"] + data = load_index_file(input_dir.path) + original_chunks = data["chunks"] else: raise ValueError( f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file." @@ -115,3 +114,76 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa roi.append((0, end)) return roi + + +def load_index_file(input_dir: str) -> Dict[str, Any]: + """Load index file from the specified input directory. + + This function supports loading both chunk-based and mds shard-based index files. + For shard-based files, it adapts the format to be compatible with chunk-based processing. + + Args: + input_dir (str): The directory containing the index file. + + Returns: + Dict[str, Any]: The loaded and possibly adapted index data. + + Raises: + FileNotFoundError: If the index file does not exist in the input directory. + + """ + index_filepath = os.path.join(input_dir, _INDEX_FILENAME) + try: + with open(index_filepath) as f: + data = json.load(f) + + if "chunks" not in data and "shards" in data: + # load mds shard-based index file and adapt to chunks format + return adapt_mds_shards_to_chunks(data) + + return data + except FileNotFoundError: + raise FileNotFoundError(f"Index file not found at {index_filepath}.") + + +def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]: + """Adapt mds shard-based index data to chunk-based format for compatibility. + For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming + + Args: + data (Dict[str, Any]): The original index data containing shards. + + Returns: + Dict[str, Any]: Adapted index data with chunks format. + """ + chunks = [] + shards = data["shards"] + for shard in shards: + chunks.append( + { + "chunk_bytes": shard["zip_data"]["bytes"], + "chunk_size": shard["samples"], + "column_sizes": shard["column_sizes"], + "dim": None, + "filename": shard["zip_data"]["basename"], + } + ) + data["chunks"] = chunks + + data_spec = [ + 1, + { + "type": "builtins.dict", + "context": json.dumps(shards[0]["column_names"]), + "children_spec": [{"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"]], + }, + ] + data["config"] = { + "chunk_bytes": sum(shard["zip_data"]["bytes"] for shard in shards), + "chunk_size": sum(shard["samples"] for shard in shards), + "compression": shards[0]["compression"], + "data_format": shards[0]["column_encodings"], + "format": shards[0]["format"], + "data_spec": json.dumps(data_spec), + } + return data diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index 7f8fbef9..d31fb076 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -1,4 +1,3 @@ -import json import logging import os from copy import deepcopy @@ -8,6 +7,7 @@ from litdata import StreamingDataset from litdata.constants import _INDEX_FILENAME +from litdata.utilities.dataset_utilities import load_index_file from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi @@ -55,14 +55,12 @@ def train_test_split( if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)): # load chunks from `index.json` file - with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f: - data = json.load(f) - original_chunks = data["chunks"] - subsampled_chunks = [ - _org_chunk - for _org_chunk in original_chunks - if _org_chunk["filename"] in dummy_subsampled_chunk_filename - ] + data = load_index_file(input_dir.path) + + original_chunks = data["chunks"] + subsampled_chunks = [ + _org_chunk for _org_chunk in original_chunks if _org_chunk["filename"] in dummy_subsampled_chunk_filename + ] else: raise ValueError("Couldn't load original chunk file.") diff --git a/tests/conftest.py b/tests/conftest.py index 4133e793..538d0bcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,3 +8,25 @@ def teardown_process_group(): # noqa: PT004 yield if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.destroy_process_group() + + +@pytest.fixture() +def mosaic_mds_index_data(): + return { + "shards": [ + { + "column_encodings": ["int", "jpeg"], + "column_names": ["class", "image"], + "column_sizes": [8, None], + "compression": "zstd", + "format": "mds", + "hashes": [], + "raw_data": {"basename": "shard.00000.mds", "bytes": 125824, "hashes": {}}, + "samples": 100, + "size_limit": 67108864, + "version": 2, + "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}}, + } + ], + "version": 2, + } diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 193b3202..91de1a95 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -21,6 +21,7 @@ import numpy as np import pytest import torch +from litdata import train_test_split from litdata.constants import _ZSTD_AVAILABLE from litdata.processing import functions from litdata.streaming import Cache @@ -995,3 +996,69 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch): ) assert len(dataset2) == int(len(dataset1) * 0.4) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") +def test_dataset_with_mosaic_mds_data(tmpdir): + from PIL import Image + from streaming import MDSWriter + # example taken from: https://github.com/mosaicml/streaming + + # A dictionary mapping input fields to their data types + columns = {"image": "jpeg", "class": "int"} + # Shard compression, if any + compression = "zstd" + # Save the samples as shards using MDSWriter + with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out: + for i in range(10): + sample = { + "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)), + "class": i, + } + out.write(sample) + dataset = StreamingDataset(input_dir=str(tmpdir)) + assert len(dataset) == 10 + for i in range(10): + sample = dataset[i] + assert sample["class"] == i + + assert [sample["class"] for sample in dataset[:]] == list(range(10)) # test slicing + + # -------------- train_test_split -------------- + + train_ds, test_ds, val_ds = train_test_split(dataset, splits=[0.7, 0.2, 0.1]) + + assert len(train_ds) == 7 + assert len(test_ds) == 2 + assert len(val_ds) == 1 + + # -------------- subsample -------------- + + dataset = StreamingDataset(input_dir=str(tmpdir), subsample=0.4) + assert len(dataset) == 4 + assert [sample["class"] for sample in dataset[:]] == [0, 1, 2, 3] + + # -------------- works with dataloader -------------- + + dataset = StreamingDataset(input_dir=str(tmpdir)) + dataloader = DataLoader(dataset, batch_size=4, drop_last=True) + i = 0 + for batch in dataloader: + assert len(batch["class"]) == 4 + assert len(batch["image"]) == 4 + assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + i += 1 + + dataloader = DataLoader(dataset, batch_size=4, drop_last=False) + i = 0 + for batch in dataloader: + if i == 2: + # last batch is smaller than batch_size + assert len(batch["class"]) == 2 + assert len(batch["image"]) == 2 + assert list(batch["class"]) == [4 * i, 4 * i + 1] + break + assert len(batch["class"]) == 4 + assert len(batch["image"]) == 4 + assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + i += 1 diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py index 40b3cfce..bb952fe9 100644 --- a/tests/utilities/test_dataset_utilities.py +++ b/tests/utilities/test_dataset_utilities.py @@ -1,10 +1,14 @@ +import json import os from unittest import mock +from litdata.constants import _INDEX_FILENAME from litdata.utilities.dataset_utilities import ( _should_replace_path, _try_create_cache_dir, + adapt_mds_shards_to_chunks, generate_roi, + load_index_file, ) @@ -44,3 +48,19 @@ def test_generate_roi(): my_roi = generate_roi(my_chunks) assert my_roi == [(0, 30), (0, 50), (0, 20), (0, 10)] + + +def test_load_index_file(tmpdir, mosaic_mds_index_data): + with open(os.path.join(tmpdir, _INDEX_FILENAME), "w") as f: + f.write(json.dumps(mosaic_mds_index_data)) + index_data = load_index_file(tmpdir) + assert "chunks" in index_data + assert "config" in index_data + assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"]) + + +def test_adapt_mds_shards_to_chunks(mosaic_mds_index_data): + adapted_data = adapt_mds_shards_to_chunks(mosaic_mds_index_data) + assert "chunks" in adapted_data + assert "config" in adapted_data + assert len(mosaic_mds_index_data["shards"]) == len(adapted_data["chunks"])