From 298e1a533f7b89bb468ba784d9f49500765c0367 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 15:57:42 +0545 Subject: [PATCH 01/25] chore: update test.txt with mosaicml-streaming dependency --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index d8240755..71de5307 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,5 @@ coverage ==7.5.3 +mosaicml-streaming==0.7.6 pytest ==8.2.* pytest-cov ==5.0.0 pytest-timeout ==2.3.1 From 3f15c363647d6db3b0d42c7748240c8e5dc6de37 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:02:14 +0545 Subject: [PATCH 02/25] feat: add load_index_file function with supports for mds config --- src/litdata/utilities/dataset_utilities.py | 58 ++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 65978e66..f64b725c 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -51,9 +51,8 @@ def subsample_streaming_dataset( if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)): # load chunks from `index.json` file - with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f: - data = json.load(f) - original_chunks = data["chunks"] + data = load_index_file(input_dir.path) + original_chunks = data["chunks"] else: raise ValueError( f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file." @@ -115,3 +114,56 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa roi.append((0, end)) return roi + + +def load_index_file(input_dir: str) -> Dict[str, Any]: + """Load index file from the input_dir.""" + + index_filepath = os.path.join(input_dir, _INDEX_FILENAME) + try: + # load index.json file + with open(index_filepath, "r") as f: + data = json.load(f) + if "chunks" not in data: + raise KeyError(f"'chunks' not found in the index file at {index_filepath}.") + except KeyError as e: + # Verify the presence of MDS shards + # For more details, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming + if "shards" in data: + # adapt mosiac index to litdata index + chunks = [] + shards = data["shards"] + for shard in shards: + chunks.append( + { + "chunk_bytes": shard["zip_data"]["bytes"], + "chunk_size": shard["samples"], + "column_sizes": shard["column_sizes"], + "dim": None, + "filename": shard["zip_data"]["basename"], + } + ) + data["chunks"] = chunks + # TODO: create a robust data_spec + data_spec = [ + 1, + { + "type": "builtins.dict", + "context": json.dumps(shards[0]["column_names"]), + "children_spec": [ + {"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"] + ], + }, + ] + data["config"] = { + "chunk_bytes": sum([shard["zip_data"]["bytes"] for shard in shards]), + "chunk_size": sum([shard["samples"] for shard in shards]), + "compression": shards[0]["compression"], + "data_format": shards[0]["column_encodings"], + "format": shards[0]["format"], + "data_spec": json.dumps(data_spec), + } + return data + raise e + except FileNotFoundError: + raise FileNotFoundError(f"Index file not found at {index_filepath}.") From eb05ee748a2eb1cd094b007097a0f7e9711959cc Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:03:48 +0545 Subject: [PATCH 03/25] chore: replaces indexfile loading with reusable fn --- src/litdata/processing/data_processor.py | 4 ++-- 
src/litdata/utilities/train_test_split.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index d44662ae..eff73c6f 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -52,6 +52,7 @@ from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads from litdata.utilities.broadcast import broadcast_object from litdata.utilities.packing import _pack_greedily +from litdata.utilities.streaming import load_index_file if _TQDM_AVAILABLE: from tqdm.auto import tqdm as _tqdm @@ -788,8 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul self._upload_index(output_dir, cache_dir, num_nodes, node_rank) if num_nodes == node_rank + 1: - with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f: - config = json.load(f) + config = load_index_file(self._cache_dir) size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]]) num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]]) diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index 7f8fbef9..9bd415dc 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -1,4 +1,3 @@ -import json import logging import os from copy import deepcopy @@ -8,6 +7,7 @@ from litdata import StreamingDataset from litdata.constants import _INDEX_FILENAME +from litdata.utilities.streaming import load_index_file from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi @@ -55,14 +55,14 @@ def train_test_split( if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)): # load chunks from `index.json` file - with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f: - data = json.load(f) - original_chunks = data["chunks"] - subsampled_chunks = [ - _org_chunk - for _org_chunk in original_chunks - if _org_chunk["filename"] in dummy_subsampled_chunk_filename - ] + data = load_index_file(input_dir.path) + + original_chunks = data["chunks"] + subsampled_chunks = [ + _org_chunk + for _org_chunk in original_chunks + if _org_chunk["filename"] in dummy_subsampled_chunk_filename + ] else: raise ValueError("Couldn't load original chunk file.") From 66da1e5aba6fdb298ee9a0babdc48c611c672ec4 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:04:24 +0545 Subject: [PATCH 04/25] feat: updates config to load indexfile --- src/litdata/streaming/config.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index c5966552..b9744f0a 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json import os from typing import Any, Dict, List, Optional, Tuple @@ -22,6 +21,7 @@ from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import tree_unflatten, treespec_loads +from litdata.utilities.dataset_utilities import load_index_file class ChunksConfig: @@ -53,18 +53,18 @@ def __init__( self._remote_dir = remote_dir self._item_loader = item_loader or PyTreeLoader() - with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f: - data = json.load(f) - _original_chunks = data["chunks"] - self._config = data["config"] - self._validate_item_loader() + # load data from `index.json` file + data = load_index_file(self._cache_dir) + _original_chunks = data["chunks"] + self._config = data["config"] + self._validate_item_loader() - assert _original_chunks is not None + assert _original_chunks is not None - if subsampled_files is None: - self._chunks = _original_chunks - else: - self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks) + if subsampled_files is None: + self._chunks = _original_chunks + else: + self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks) self._config["data_spec"] = treespec_loads(self._config["data_spec"]) From 615c894041b3974dee7fb94ebdd42127f95fccd0 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:05:03 +0545 Subject: [PATCH 05/25] feat: adds fn to deserialize mds written bytes data --- src/litdata/streaming/item_loader.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index d302214d..2213aaca 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -142,8 +142,32 @@ def load_item_from_chunk( fp.seek(begin) data = fp.read(end - begin) + # check for mosaic mds format + if "format" in self._config and self._config["format"] == "mds": + return self.mds_deserialize(data, chunk_index) return self.deserialize(data) + def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree": + """Deserialize the mds raw bytes into their python equivalent.""" + idx = 0 + sizes = [] + column_sizes = self._chunks[chunk_index]["column_sizes"] + # adapted from: MDSReader.deserialize : https://github.com/mosaicml/streaming/blob/main/streaming/base/format/mds/reader.py + for size in column_sizes: + if size: + sizes.append(size) + else: + (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32) + sizes.append(size) + idx += 4 + data = [] + for size, data_format in zip(sizes, self._data_format): + serializer = self._serializers[data_format] + data_bytes = raw_item_data[idx : idx + size] + data.append(serializer.deserialize(data_bytes)) + idx += size + return tree_unflatten(data, self._config["data_spec"]) + def deserialize(self, raw_item_data: bytes) -> "PyTree": """Deserialize the raw bytes into their python equivalent.""" idx = self._shift_idx From 0e690f652cdb4ab3c8b45cae8e6684e8dcbb95eb Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:05:43 +0545 Subject: [PATCH 06/25] feat: adds tests to test the functionality to read mds writer dataset --- tests/conftest.py | 22 +++++++++++++++++++++ tests/streaming/test_dataset.py | 24 +++++++++++++++++++++++ tests/utilities/test_dataset_utilities.py | 12 ++++++++++++ 3 files changed, 58 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 4133e793..538d0bcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py 
@@ -8,3 +8,25 @@ def teardown_process_group(): # noqa: PT004 yield if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.destroy_process_group() + + +@pytest.fixture() +def mosaic_mds_index_data(): + return { + "shards": [ + { + "column_encodings": ["int", "jpeg"], + "column_names": ["class", "image"], + "column_sizes": [8, None], + "compression": "zstd", + "format": "mds", + "hashes": [], + "raw_data": {"basename": "shard.00000.mds", "bytes": 125824, "hashes": {}}, + "samples": 100, + "size_limit": 67108864, + "version": 2, + "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}}, + } + ], + "version": 2, + } diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 193b3202..98293c4b 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -995,3 +995,27 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch): ) assert len(dataset2) == int(len(dataset1) * 0.4) + + +def test_dataset_with_mosaic_mds_data(tmpdir): + from PIL import Image + from streaming import MDSWriter + # example taken from: https://github.com/mosaicml/streaming + + # A dictionary mapping input fields to their data types + columns = {"image": "jpeg", "class": "int"} + # Shard compression, if any + compression = "zstd" + # Save the samples as shards using MDSWriter + with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out: + for i in range(100): + sample = { + "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)), + "class": i, + } + out.write(sample) + dataset = StreamingDataset(input_dir=str(tmpdir)) + assert len(dataset) == 100 + for i in range(100): + sample = dataset[i] + assert sample["class"] == i diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py index 40b3cfce..3fa2149c 100644 --- a/tests/utilities/test_dataset_utilities.py +++ b/tests/utilities/test_dataset_utilities.py @@ -1,10 +1,13 @@ +import json import os from unittest import mock +from litdata.constants import _INDEX_FILENAME from litdata.utilities.dataset_utilities import ( _should_replace_path, _try_create_cache_dir, generate_roi, + load_index_file, ) @@ -44,3 +47,12 @@ def test_generate_roi(): my_roi = generate_roi(my_chunks) assert my_roi == [(0, 30), (0, 50), (0, 20), (0, 10)] + + +def test_load_index_file(tmpdir, mosaic_mds_index_data): + with open(os.path.join(tmpdir, _INDEX_FILENAME), "w") as f: + f.write(json.dumps(mosaic_mds_index_data)) + index_data = load_index_file(tmpdir) + assert "chunks" in index_data + assert "config" in index_data + assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"]) From 158453261ebe047706c25e0d4a291d2c1b0fab27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Jul 2024 10:23:08 +0000 Subject: [PATCH 07/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/dataset_utilities.py | 2 +- src/litdata/utilities/train_test_split.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index f64b725c..0105c0a0 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -122,7 +122,7 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: index_filepath = os.path.join(input_dir, 
_INDEX_FILENAME) try: # load index.json file - with open(index_filepath, "r") as f: + with open(index_filepath) as f: data = json.load(f) if "chunks" not in data: raise KeyError(f"'chunks' not found in the index file at {index_filepath}.") diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index 9bd415dc..b888729a 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -59,9 +59,7 @@ def train_test_split( original_chunks = data["chunks"] subsampled_chunks = [ - _org_chunk - for _org_chunk in original_chunks - if _org_chunk["filename"] in dummy_subsampled_chunk_filename + _org_chunk for _org_chunk in original_chunks if _org_chunk["filename"] in dummy_subsampled_chunk_filename ] else: raise ValueError("Couldn't load original chunk file.") From 13428fff8e7c80d567ed571a84449fca18afd625 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:13:06 +0545 Subject: [PATCH 08/25] fix: import path for `load_index_file` fn --- src/litdata/processing/data_processor.py | 2 +- src/litdata/utilities/train_test_split.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index eff73c6f..07bf9295 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -51,8 +51,8 @@ from litdata.streaming.resolver import _resolve_dir from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads from litdata.utilities.broadcast import broadcast_object +from litdata.utilities.dataset_utilities import load_index_file from litdata.utilities.packing import _pack_greedily -from litdata.utilities.streaming import load_index_file if _TQDM_AVAILABLE: from tqdm.auto import tqdm as _tqdm diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index b888729a..d31fb076 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -7,7 +7,7 @@ from litdata import StreamingDataset from litdata.constants import _INDEX_FILENAME -from litdata.utilities.streaming import load_index_file +from litdata.utilities.dataset_utilities import load_index_file from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi From 3fe0b56fa43cc4b2d9106f3b9f4972ce2fef42de Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:21:05 +0545 Subject: [PATCH 09/25] fixes:type --- src/litdata/utilities/dataset_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 0105c0a0..f7d39bc6 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -116,7 +116,7 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa return roi -def load_index_file(input_dir: str) -> Dict[str, Any]: +def load_index_file(input_dir: str) -> Dict[str, Any] or None: """Load index file from the input_dir.""" index_filepath = os.path.join(input_dir, _INDEX_FILENAME) From e3773b0046a584338dd2196a322f1b275400d355 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:30:04 +0545 Subject: [PATCH 10/25] fix: `load_index_file` input dir --- src/litdata/processing/data_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/processing/data_processor.py 
b/src/litdata/processing/data_processor.py index 07bf9295..d2220a00 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -789,7 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul self._upload_index(output_dir, cache_dir, num_nodes, node_rank) if num_nodes == node_rank + 1: - config = load_index_file(self._cache_dir) + config = load_index_file(cache_dir) size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]]) num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]]) From c8c88b943adaa7700089ef7a024232e3fde71df4 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:34:07 +0545 Subject: [PATCH 11/25] fix: return type --- src/litdata/utilities/dataset_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index f7d39bc6..0105c0a0 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -116,7 +116,7 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa return roi -def load_index_file(input_dir: str) -> Dict[str, Any] or None: +def load_index_file(input_dir: str) -> Dict[str, Any]: """Load index file from the input_dir.""" index_filepath = os.path.join(input_dir, _INDEX_FILENAME) From 82484105265bcc06bf996a1a336e8ade28fd9fa3 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sat, 6 Jul 2024 16:53:35 +0545 Subject: [PATCH 12/25] feat: adds default missing return case --- src/litdata/utilities/dataset_utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 0105c0a0..139a3b1b 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -126,6 +126,7 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: data = json.load(f) if "chunks" not in data: raise KeyError(f"'chunks' not found in the index file at {index_filepath}.") + return data except KeyError as e: # Verify the presence of MDS shards # For more details, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming From 5a388006dedda155a615e33fa5f5be8bceb5aef1 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 7 Jul 2024 22:45:27 +0545 Subject: [PATCH 13/25] Update README.md: fix typo in parallelize --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0b6e6575..003eac48 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ ld.map( **Key benefits:** -✅ Paralellize processing: Reduce processing time by transforming data across multiple machines simultaneously. +✅ Parallelize processing: Reduce processing time by transforming data across multiple machines simultaneously. ✅ Scale to large data: Increase the size of datasets you can efficiently handle. ✅ Flexible usecases: Resize images, create embeddings, scrape the internet, etc... ✅ Run local or cloud: Run on your own machines or auto-scale to 1000s of cloud GPUs with Lightning Studios. @@ -638,7 +638,7 @@ Time to optimize 1.2 million ImageNet images (Faster is better): ---- -# Paralellize transforms and data optimization on cloud machines +# Parallelize transforms and data optimization on cloud machines
Lightning
From bfe93a8e4969739e31455fcda5201d3dc6e0b755 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 7 Jul 2024 22:47:21 +0545 Subject: [PATCH 14/25] chore: updates test for mds dataset --- tests/streaming/test_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 98293c4b..2d93a516 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1008,14 +1008,14 @@ def test_dataset_with_mosaic_mds_data(tmpdir): compression = "zstd" # Save the samples as shards using MDSWriter with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out: - for i in range(100): + for i in range(10): sample = { "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)), "class": i, } out.write(sample) dataset = StreamingDataset(input_dir=str(tmpdir)) - assert len(dataset) == 100 - for i in range(100): + assert len(dataset) == 10 + for i in range(10): sample = dataset[i] assert sample["class"] == i From f9dfe2562ec85fc2dc134046c14513027d1289f6 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 7 Jul 2024 23:11:31 +0545 Subject: [PATCH 15/25] refactor: Improve index file loading and adapt MDS shards to chunks format --- src/litdata/utilities/dataset_utilities.py | 106 ++++++++++++--------- 1 file changed, 63 insertions(+), 43 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 139a3b1b..45a1bd6c 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -117,54 +117,74 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa def load_index_file(input_dir: str) -> Dict[str, Any]: - """Load index file from the input_dir.""" + """Load index file from the specified input directory. + This function supports loading both chunk-based and mds shard-based index files. + For shard-based files, it adapts the format to be compatible with chunk-based processing. + + Args: + input_dir (str): The directory containing the index file. + + Returns: + Dict[str, Any]: The loaded and possibly adapted index data. + + Raises: + ValueError: If the index file format is invalid. + FileNotFoundError: If the index file does not exist in the input directory. 
+ """ index_filepath = os.path.join(input_dir, _INDEX_FILENAME) try: - # load index.json file with open(index_filepath) as f: data = json.load(f) - if "chunks" not in data: - raise KeyError(f"'chunks' not found in the index file at {index_filepath}.") - return data - except KeyError as e: - # Verify the presence of MDS shards - # For more details, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming - if "shards" in data: - # adapt mosiac index to litdata index - chunks = [] - shards = data["shards"] - for shard in shards: - chunks.append( - { - "chunk_bytes": shard["zip_data"]["bytes"], - "chunk_size": shard["samples"], - "column_sizes": shard["column_sizes"], - "dim": None, - "filename": shard["zip_data"]["basename"], - } - ) - data["chunks"] = chunks - # TODO: create a robust data_spec - data_spec = [ - 1, - { - "type": "builtins.dict", - "context": json.dumps(shards[0]["column_names"]), - "children_spec": [ - {"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"] - ], - }, - ] - data["config"] = { - "chunk_bytes": sum([shard["zip_data"]["bytes"] for shard in shards]), - "chunk_size": sum([shard["samples"] for shard in shards]), - "compression": shards[0]["compression"], - "data_format": shards[0]["column_encodings"], - "format": shards[0]["format"], - "data_spec": json.dumps(data_spec), - } + + if "chunks" in data: return data - raise e + elif "shards" in data: + return adapt_mds_shards_to_chunks(data) + else: + raise ValueError(f"Invalid index file format at {index_filepath}.") except FileNotFoundError: raise FileNotFoundError(f"Index file not found at {index_filepath}.") + + +def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]: + """Adapt mds shard-based index data to chunk-based format for compatibility. + For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming + + Args: + data (Dict[str, Any]): The original index data containing shards. + + Returns: + Dict[str, Any]: Adapted index data with chunks format. 
+ """ + chunks = [] + shards = data["shards"] + for shard in shards: + chunks.append( + { + "chunk_bytes": shard["zip_data"]["bytes"], + "chunk_size": shard["samples"], + "column_sizes": shard["column_sizes"], + "dim": None, + "filename": shard["zip_data"]["basename"], + } + ) + data["chunks"] = chunks + + data_spec = [ + 1, + { + "type": "builtins.dict", + "context": json.dumps(shards[0]["column_names"]), + "children_spec": [{"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"]], + }, + ] + data["config"] = { + "chunk_bytes": sum(shard["zip_data"]["bytes"] for shard in shards), + "chunk_size": sum(shard["samples"] for shard in shards), + "compression": shards[0]["compression"], + "data_format": shards[0]["column_encodings"], + "format": shards[0]["format"], + "data_spec": json.dumps(data_spec), + } + return data From 1f76ffdc0215d2fd7cbf831348cd687cd29afdf0 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 7 Jul 2024 23:11:37 +0545 Subject: [PATCH 16/25] chore: Add unit test for adapting MDS shards to chunks format --- tests/utilities/test_dataset_utilities.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py index 3fa2149c..bb952fe9 100644 --- a/tests/utilities/test_dataset_utilities.py +++ b/tests/utilities/test_dataset_utilities.py @@ -6,6 +6,7 @@ from litdata.utilities.dataset_utilities import ( _should_replace_path, _try_create_cache_dir, + adapt_mds_shards_to_chunks, generate_roi, load_index_file, ) @@ -56,3 +57,10 @@ def test_load_index_file(tmpdir, mosaic_mds_index_data): assert "chunks" in index_data assert "config" in index_data assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"]) + + +def test_adapt_mds_shards_to_chunks(mosaic_mds_index_data): + adapted_data = adapt_mds_shards_to_chunks(mosaic_mds_index_data) + assert "chunks" in adapted_data + assert "config" in adapted_data + assert len(mosaic_mds_index_data["shards"]) == len(adapted_data["chunks"]) From f26dc32bf2c8ac53d03fdbc36919eebc13685b7a Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 7 Jul 2024 23:19:36 +0545 Subject: [PATCH 17/25] refactor: Skip test_dataset_with_mosaic_mds_data on Windows --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 2d93a516..e37cc111 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -996,7 +996,7 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch): assert len(dataset2) == int(len(dataset1) * 0.4) - +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") def test_dataset_with_mosaic_mds_data(tmpdir): from PIL import Image from streaming import MDSWriter From 930eb1d10b3ff5731623e67a298de0037a5651d1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 7 Jul 2024 17:34:58 +0000 Subject: [PATCH 18/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/dataset_utilities.py | 6 +++--- tests/streaming/test_dataset.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 45a1bd6c..3abe1b41 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -131,6 
+131,7 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: Raises: ValueError: If the index file format is invalid. FileNotFoundError: If the index file does not exist in the input directory. + """ index_filepath = os.path.join(input_dir, _INDEX_FILENAME) try: @@ -139,10 +140,9 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: if "chunks" in data: return data - elif "shards" in data: + if "shards" in data: return adapt_mds_shards_to_chunks(data) - else: - raise ValueError(f"Invalid index file format at {index_filepath}.") + raise ValueError(f"Invalid index file format at {index_filepath}.") except FileNotFoundError: raise FileNotFoundError(f"Index file not found at {index_filepath}.") diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index e37cc111..2384dd81 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -996,6 +996,7 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch): assert len(dataset2) == int(len(dataset1) * 0.4) + @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") def test_dataset_with_mosaic_mds_data(tmpdir): from PIL import Image From 54c32d526a9826fb4acabafbea6a65aa45eb3fe2 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 7 Jul 2024 23:31:03 +0545 Subject: [PATCH 19/25] refactor: Improve index file loading and adapt MDS shards to chunks format --- src/litdata/utilities/dataset_utilities.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 3abe1b41..c63677df 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -138,11 +138,11 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: with open(index_filepath) as f: data = json.load(f) - if "chunks" in data: - return data - if "shards" in data: + if "chunks" not in data and "shards" in data: + # load mds shard-based index file and adapt to chunks format return adapt_mds_shards_to_chunks(data) - raise ValueError(f"Invalid index file format at {index_filepath}.") + + return data except FileNotFoundError: raise FileNotFoundError(f"Index file not found at {index_filepath}.") From 23e6fdf4fb033dd01c00609629954e1c7d91190b Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 8 Jul 2024 02:58:53 +0000 Subject: [PATCH 20/25] test streamingDataset features for mosaic mds --- tests/streaming/test_dataset.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 2384dd81..daccf74f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -21,6 +21,7 @@ import numpy as np import pytest import torch +from litdata import train_test_split from litdata.constants import _ZSTD_AVAILABLE from litdata.processing import functions from litdata.streaming import Cache @@ -1020,3 +1021,44 @@ def test_dataset_with_mosaic_mds_data(tmpdir): for i in range(10): sample = dataset[i] assert sample["class"] == i + + assert [sample["class"] for sample in dataset[:]] == list(range(10)) # test slicing + + # -------------- train_test_split -------------- + + train_ds, test_ds, val_ds = train_test_split(dataset, splits=[0.7, 0.2, 0.1]) + + assert len(train_ds) == 7 + assert len(test_ds) == 2 + assert len(val_ds) == 1 + + # -------------- subsample -------------- + + dataset = StreamingDataset(input_dir=str(tmpdir), subsample=0.4) + assert 
len(dataset) == 4 + assert [sample["class"] for sample in dataset[:]] == [0, 1, 2, 3] + + # -------------- works with dataloader -------------- + + dataset = StreamingDataset(input_dir=str(tmpdir)) + dataloader = DataLoader(dataset, batch_size=4, drop_last=True) + i = 0 + for batch in dataloader: + assert len(batch["class"]) == 4 + assert len(batch["image"]) == 4 + assert [_class for _class in batch["class"]] == [4*i, 4*i+1, 4*i+2, 4*i+3] + i += 1 + + dataloader = DataLoader(dataset, batch_size=4, drop_last=False) + i = 0 + for batch in dataloader: + if (i == 2): + # last batch is smaller than batch_size + assert len(batch["class"]) == 2 + assert len(batch["image"]) == 2 + assert [_class for _class in batch["class"]] == [4*i, 4*i+1] + break + assert len(batch["class"]) == 4 + assert len(batch["image"]) == 4 + assert [_class for _class in batch["class"]] == [4*i, 4*i+1, 4*i+2, 4*i+3] + i += 1 From ce6db095ed1cb75042457d828266874016f15074 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Jul 2024 03:01:56 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index daccf74f..7f758349 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1021,8 +1021,8 @@ def test_dataset_with_mosaic_mds_data(tmpdir): for i in range(10): sample = dataset[i] assert sample["class"] == i - - assert [sample["class"] for sample in dataset[:]] == list(range(10)) # test slicing + + assert [sample["class"] for sample in dataset[:]] == list(range(10)) # test slicing # -------------- train_test_split -------------- @@ -1046,19 +1046,19 @@ def test_dataset_with_mosaic_mds_data(tmpdir): for batch in dataloader: assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert [_class for _class in batch["class"]] == [4*i, 4*i+1, 4*i+2, 4*i+3] + assert [_class for _class in batch["class"]] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 - + dataloader = DataLoader(dataset, batch_size=4, drop_last=False) i = 0 for batch in dataloader: - if (i == 2): + if i == 2: # last batch is smaller than batch_size assert len(batch["class"]) == 2 assert len(batch["image"]) == 2 - assert [_class for _class in batch["class"]] == [4*i, 4*i+1] + assert [_class for _class in batch["class"]] == [4 * i, 4 * i + 1] break assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert [_class for _class in batch["class"]] == [4*i, 4*i+1, 4*i+2, 4*i+3] + assert [_class for _class in batch["class"]] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 From b25eeff45d705c0a4068189806e50734da6c5cfb Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 8 Jul 2024 03:05:22 +0000 Subject: [PATCH 22/25] fix pre-commit-ci errors --- tests/streaming/test_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 7f758349..a8534d5d 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1046,7 +1046,7 @@ def test_dataset_with_mosaic_mds_data(tmpdir): for batch in dataloader: assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert [_class for _class in batch["class"]] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + assert list(_class for _class 
in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 dataloader = DataLoader(dataset, batch_size=4, drop_last=False) @@ -1056,9 +1056,9 @@ def test_dataset_with_mosaic_mds_data(tmpdir): # last batch is smaller than batch_size assert len(batch["class"]) == 2 assert len(batch["image"]) == 2 - assert [_class for _class in batch["class"]] == [4 * i, 4 * i + 1] + assert list(_class for _class in batch["class"]) == [4 * i, 4 * i + 1] break assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert [_class for _class in batch["class"]] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + assert list(_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 From 30a6ce89e69a6676fd4ceebb9145ced7bf6cf316 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 8 Jul 2024 03:06:30 +0000 Subject: [PATCH 23/25] fix pre-commit-ci list comprehension with yield --- tests/streaming/test_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index a8534d5d..acdd0a5d 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1046,7 +1046,7 @@ def test_dataset_with_mosaic_mds_data(tmpdir): for batch in dataloader: assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert list(_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + assert (_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 dataloader = DataLoader(dataset, batch_size=4, drop_last=False) @@ -1056,9 +1056,9 @@ def test_dataset_with_mosaic_mds_data(tmpdir): # last batch is smaller than batch_size assert len(batch["class"]) == 2 assert len(batch["image"]) == 2 - assert list(_class for _class in batch["class"]) == [4 * i, 4 * i + 1] + assert (_class for _class in batch["class"]) == [4 * i, 4 * i + 1] break assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert list(_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + assert (_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 From 575615c3184e5dfa5894df99f5a9629d89f2120e Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 8 Jul 2024 03:08:54 +0000 Subject: [PATCH 24/25] fix failing tests bcoz of generators --- tests/streaming/test_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index acdd0a5d..91de1a95 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1046,7 +1046,7 @@ def test_dataset_with_mosaic_mds_data(tmpdir): for batch in dataloader: assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert (_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] i += 1 dataloader = DataLoader(dataset, batch_size=4, drop_last=False) @@ -1056,9 +1056,9 @@ def test_dataset_with_mosaic_mds_data(tmpdir): # last batch is smaller than batch_size assert len(batch["class"]) == 2 assert len(batch["image"]) == 2 - assert (_class for _class in batch["class"]) == [4 * i, 4 * i + 1] + assert list(batch["class"]) == [4 * i, 4 * i + 1] break assert len(batch["class"]) == 4 assert len(batch["image"]) == 4 - assert (_class for _class in batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] + assert list(batch["class"]) == [4 * i, 4 * i + 1, 
4 * i + 2, 4 * i + 3] i += 1 From 6b09b80d912643525a78ea0f2244401af534845d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 8 Jul 2024 11:10:23 +0545 Subject: [PATCH 25/25] fix: docs for the fn --- src/litdata/utilities/dataset_utilities.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index c63677df..7e0c4838 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -129,7 +129,6 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: Dict[str, Any]: The loaded and possibly adapted index data. Raises: - ValueError: If the index file format is invalid. FileNotFoundError: If the index file does not exist in the input directory. """
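
Taken together, the series lets a StreamingDataset read a dataset written by MosaicML's MDSWriter directly. The sketch below distills the tests added in patches 06 and 20 into a standalone script; it is illustrative only, assuming this series is applied and that mosaicml-streaming, numpy, and Pillow are installed. The local directory name mds_demo is hypothetical.

    # Write a small dataset with MosaicML's MDSWriter, then read it back with litdata.
    import numpy as np
    from PIL import Image
    from streaming import MDSWriter

    from litdata import StreamingDataset
    from litdata.utilities.dataset_utilities import load_index_file

    data_dir = "mds_demo"  # hypothetical output directory

    # MDSWriter produces compressed .mds shards plus an index.json with a "shards" key.
    columns = {"image": "jpeg", "class": "int"}
    with MDSWriter(out=data_dir, columns=columns, compression="zstd") as out:
        for i in range(10):
            out.write({
                "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
                "class": i,
            })

    # load_index_file detects the "shards" layout and adapts it to litdata's chunk format.
    index = load_index_file(data_dir)
    assert "chunks" in index and index["config"]["format"] == "mds"

    # StreamingDataset then deserializes the MDS samples transparently (see mds_deserialize).
    dataset = StreamingDataset(input_dir=data_dir)
    assert len(dataset) == 10
    assert [sample["class"] for sample in dataset[:]] == list(range(10))

Because load_index_file only adapts index.json files that contain a "shards" key and no "chunks" key (patch 19), index files written by litdata itself keep loading exactly as before.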
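
The per-sample decoding added in patch 05 is easier to follow outside the item loader. The standalone sketch below builds one sample region in the layout that mds_deserialize expects, with uint32 length prefixes for variable-size columns first and the column payloads following in column order, and walks it with the same two loops. The payload bytes are placeholders rather than real MDS encodings, and real shard files also carry a header and sample offsets, which litdata's chunk reader uses to slice out raw_item_data before this logic runs.

    # Standalone sketch of the sample layout walked by mds_deserialize (patch 05).
    # Payload bytes here are placeholders, not real MDS-encoded values.
    import numpy as np

    column_sizes = [8, None]                 # "class": fixed 8 bytes, "image": variable size
    class_payload = b"\x07" + b"\x00" * 7    # placeholder 8-byte payload
    image_payload = b"<jpeg bytes>"          # placeholder variable-size payload

    # Variable-size columns contribute a uint32 length prefix at the front of the sample,
    # followed by all column payloads in column order.
    raw_item_data = np.uint32(len(image_payload)).tobytes() + class_payload + image_payload

    # First pass: resolve every column's size, reading a uint32 prefix where the size is None.
    idx, sizes = 0, []
    for size in column_sizes:
        if size:
            sizes.append(size)
        else:
            (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32)
            sizes.append(int(size))
            idx += 4

    # Second pass: slice out each column's payload; litdata then hands each slice to the
    # serializer matching the column encoding (e.g. "int", "jpeg").
    payloads = []
    for size in sizes:
        payloads.append(raw_item_data[idx : idx + size])
        idx += size

    assert payloads == [class_payload, image_payload]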