
Feat: adds support for reading mosaic mds written dataset #210

Merged · 28 commits · Jul 8, 2024

Commits
298e1a5
chore: update test.txt with mosaicml-streaming dependency
bhimrazy Jul 6, 2024
3f15c36
feat: add load_index_file function with supports for mds config
bhimrazy Jul 6, 2024
eb05ee7
chore: replaces indexfile loading with reusable fn
bhimrazy Jul 6, 2024
66da1e5
feat: updates config to load indexfile
bhimrazy Jul 6, 2024
615c894
feat: adds fn to deserialize mds written bytes data
bhimrazy Jul 6, 2024
0e690f6
feat: adds tests to test the functionality to read mds writer dataset
bhimrazy Jul 6, 2024
1584532
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2024
13428ff
fix: import path for `load_index_file` fn
bhimrazy Jul 6, 2024
3fe0b56
fixes:type
bhimrazy Jul 6, 2024
e3773b0
fix: `load_index_file` input dir
bhimrazy Jul 6, 2024
c8c88b9
fix: return type
bhimrazy Jul 6, 2024
8248410
feat: adds default missing return case
bhimrazy Jul 6, 2024
2dedd0b
Merge branch 'main' into feat/adds-mosaic-mds-support
bhimrazy Jul 6, 2024
281329f
Merge branch 'main' into feat/adds-mosaic-mds-support
bhimrazy Jul 6, 2024
d0aa2ab
Merge branch 'main' into feat/adds-mosaic-mds-support
bhimrazy Jul 7, 2024
5a38800
Update README.md: fix typo in parallelize
bhimrazy Jul 7, 2024
bfe93a8
chore: updates test for mds dataset
bhimrazy Jul 7, 2024
f9dfe25
refactor: Improve index file loading and adapt MDS shards to chunks f…
bhimrazy Jul 7, 2024
1f76ffd
chore: Add unit test for adapting MDS shards to chunks format
bhimrazy Jul 7, 2024
f26dc32
refactor: Skip test_dataset_with_mosaic_mds_data on Windows
bhimrazy Jul 7, 2024
930eb1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2024
54c32d5
refactor: Improve index file loading and adapt MDS shards to chunks f…
bhimrazy Jul 7, 2024
23e6fdf
test streamingDataset features for mosaic mds
deependujha Jul 8, 2024
ce6db09
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2024
b25eeff
fix pre-commit-ci errors
deependujha Jul 8, 2024
30a6ce8
fix pre-commit-ci list comprehension with yield
deependujha Jul 8, 2024
575615c
fix failing tests bcoz of generators
deependujha Jul 8, 2024
6b09b80
fix: docs for the fn
bhimrazy Jul 8, 2024
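
For orientation before the file diffs: a minimal sketch of the round trip this PR enables, writing a dataset with MosaicML's MDSWriter and reading it back with litdata's StreamingDataset. It mirrors the test added in tests/streaming/test_dataset.py; the output path is hypothetical.

    import numpy as np
    from PIL import Image
    from streaming import MDSWriter  # provided by mosaicml-streaming

    from litdata import StreamingDataset

    out_dir = "/tmp/mds_demo"  # hypothetical location
    columns = {"image": "jpeg", "class": "int"}

    # Write ten samples as a zstd-compressed MDS shard plus an index.json.
    with MDSWriter(out=out_dir, columns=columns, compression="zstd") as writer:
        for i in range(10):
            writer.write({
                "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
                "class": i,
            })

    # With this PR, StreamingDataset recognizes the MDS index and shards directly.
    dataset = StreamingDataset(input_dir=out_dir)
    assert len(dataset) == 10 and dataset[3]["class"] == 3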
4 changes: 2 additions & 2 deletions README.md
@@ -174,7 +174,7 @@ ld.map(

**Key benefits:**

-✅ Paralellize processing: Reduce processing time by transforming data across multiple machines simultaneously.
+✅ Parallelize processing: Reduce processing time by transforming data across multiple machines simultaneously.
✅ Scale to large data: Increase the size of datasets you can efficiently handle.
✅ Flexible usecases: Resize images, create embeddings, scrape the internet, etc...
✅ Run local or cloud: Run on your own machines or auto-scale to 1000s of cloud GPUs with Lightning Studios.
@@ -638,7 +638,7 @@ Time to optimize 1.2 million ImageNet images (Faster is better):

----

-# Paralellize transforms and data optimization on cloud machines
+# Parallelize transforms and data optimization on cloud machines
<div align="center">
<img alt="Lightning" src="https://pl-flash-data.s3.amazonaws.com/data-prep.jpg" width="700px">
</div>
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -1,4 +1,5 @@
coverage ==7.5.3
+mosaicml-streaming==0.7.6
pytest ==8.2.*
pytest-cov ==5.0.0
pytest-timeout ==2.3.1
4 changes: 2 additions & 2 deletions src/litdata/processing/data_processor.py
@@ -51,6 +51,7 @@
from litdata.streaming.resolver import _resolve_dir
from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
from litdata.utilities.broadcast import broadcast_object
+from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.packing import _pack_greedily

if _TQDM_AVAILABLE:
@@ -788,8 +789,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
        self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

        if num_nodes == node_rank + 1:
-            with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
-                config = json.load(f)
+            config = load_index_file(cache_dir)

            size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
            num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
22 changes: 11 additions & 11 deletions src/litdata/streaming/config.py
@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import json
import os
from typing import Any, Dict, List, Optional, Tuple

@@ -22,6 +21,7 @@
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import tree_unflatten, treespec_loads
+from litdata.utilities.dataset_utilities import load_index_file


class ChunksConfig:
@@ -53,18 +53,18 @@ def __init__(
        self._remote_dir = remote_dir
        self._item_loader = item_loader or PyTreeLoader()

-        with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
-            data = json.load(f)
-            _original_chunks = data["chunks"]
-            self._config = data["config"]
-            self._validate_item_loader()
+        # load data from `index.json` file
+        data = load_index_file(self._cache_dir)
+        _original_chunks = data["chunks"]
+        self._config = data["config"]
+        self._validate_item_loader()

-            assert _original_chunks is not None
+        assert _original_chunks is not None

-            if subsampled_files is None:
-                self._chunks = _original_chunks
-            else:
-                self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)
+        if subsampled_files is None:
+            self._chunks = _original_chunks
+        else:
+            self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)

        self._config["data_spec"] = treespec_loads(self._config["data_spec"])

24 changes: 24 additions & 0 deletions src/litdata/streaming/item_loader.py
@@ -142,8 +142,32 @@ def load_item_from_chunk(
            fp.seek(begin)
            data = fp.read(end - begin)

+        # check for mosaic mds format
+        if "format" in self._config and self._config["format"] == "mds":
+            return self.mds_deserialize(data, chunk_index)
        return self.deserialize(data)

+    def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree":
+        """Deserialize the mds raw bytes into their python equivalent."""
+        idx = 0
+        sizes = []
+        column_sizes = self._chunks[chunk_index]["column_sizes"]
+        # adapted from: MDSReader.deserialize : https://github.com/mosaicml/streaming/blob/main/streaming/base/format/mds/reader.py
+        for size in column_sizes:
+            if size:
+                sizes.append(size)
+            else:
+                (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32)
+                sizes.append(size)
+                idx += 4
+        data = []
+        for size, data_format in zip(sizes, self._data_format):
+            serializer = self._serializers[data_format]
+            data_bytes = raw_item_data[idx : idx + size]
+            data.append(serializer.deserialize(data_bytes))
+            idx += size
+        return tree_unflatten(data, self._config["data_spec"])

    def deserialize(self, raw_item_data: bytes) -> "PyTree":
        """Deserialize the raw bytes into their python equivalent."""
        idx = self._shift_idx
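
To make the byte layout that mds_deserialize walks concrete: an MDS sample stores a little-endian uint32 size prefix for each variable-length column (a None/0 entry in column_sizes), followed by each column's payload in column order. Below is a self-contained sketch using the schema from the test fixture (a fixed 8-byte int plus a variable-length jpeg); it is illustrative only and not part of the diff.

    import numpy as np

    # Schema from the fixture: "class" is a fixed 8-byte int,
    # "image" is variable-length, so its size travels with the sample.
    column_sizes = [8, None]

    class_bytes = (7).to_bytes(8, "little")       # stand-in for an encoded int column
    image_bytes = b"\xff\xd8 fake jpeg payload"   # stand-in for a JPEG column

    # Encode: uint32 prefixes for variable columns first, then the payloads.
    sample = np.uint32(len(image_bytes)).tobytes() + class_bytes + image_bytes

    # Decode, mirroring the loop in mds_deserialize above.
    idx, sizes = 0, []
    for size in column_sizes:
        if size:
            sizes.append(size)
        else:
            (size,) = np.frombuffer(sample[idx : idx + 4], np.uint32)
            sizes.append(int(size))
            idx += 4

    payloads = []
    for size in sizes:
        payloads.append(sample[idx : idx + size])
        idx += size

    assert payloads == [class_bytes, image_bytes]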
78 changes: 75 additions & 3 deletions src/litdata/utilities/dataset_utilities.py
@@ -51,9 +51,8 @@ def subsample_streaming_dataset(

    if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
        # load chunks from `index.json` file
-        with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
-            data = json.load(f)
-            original_chunks = data["chunks"]
+        data = load_index_file(input_dir.path)
+        original_chunks = data["chunks"]
    else:
        raise ValueError(
            f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file."
@@ -115,3 +114,76 @@ def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoa
        roi.append((0, end))

    return roi


+def load_index_file(input_dir: str) -> Dict[str, Any]:
+    """Load index file from the specified input directory.
+
+    This function supports loading both chunk-based and mds shard-based index files.
+    For shard-based files, it adapts the format to be compatible with chunk-based processing.
+
+    Args:
+        input_dir (str): The directory containing the index file.
+
+    Returns:
+        Dict[str, Any]: The loaded and possibly adapted index data.
+
+    Raises:
+        FileNotFoundError: If the index file does not exist in the input directory.
+
+    """
+    index_filepath = os.path.join(input_dir, _INDEX_FILENAME)
+    try:
+        with open(index_filepath) as f:
+            data = json.load(f)
+
+        if "chunks" not in data and "shards" in data:
+            # load mds shard-based index file and adapt to chunks format
+            return adapt_mds_shards_to_chunks(data)
+
+        return data
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Index file not found at {index_filepath}.")
+
+
+def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]:
+    """Adapt mds shard-based index data to chunk-based format for compatibility.
+    For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming
+
+    Args:
+        data (Dict[str, Any]): The original index data containing shards.
+
+    Returns:
+        Dict[str, Any]: Adapted index data with chunks format.
+    """
+    chunks = []
+    shards = data["shards"]
+    for shard in shards:
+        chunks.append(
+            {
+                "chunk_bytes": shard["zip_data"]["bytes"],
+                "chunk_size": shard["samples"],
+                "column_sizes": shard["column_sizes"],
+                "dim": None,
+                "filename": shard["zip_data"]["basename"],
+            }
+        )
+    data["chunks"] = chunks
+
+    data_spec = [
+        1,
+        {
+            "type": "builtins.dict",
+            "context": json.dumps(shards[0]["column_names"]),
+            "children_spec": [{"type": None, "context": None, "children_spec": []} for _ in shards[0]["column_names"]],
+        },
+    ]
+    data["config"] = {
+        "chunk_bytes": sum(shard["zip_data"]["bytes"] for shard in shards),
+        "chunk_size": sum(shard["samples"] for shard in shards),
+        "compression": shards[0]["compression"],
+        "data_format": shards[0]["column_encodings"],
+        "format": shards[0]["format"],
+        "data_spec": json.dumps(data_spec),
+    }
+    return data
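
Concretely, each MDS shard becomes one chunk entry. Using the single-shard fixture defined later in tests/conftest.py, the mapping works out as follows (a sketch of the transformation; values copied from that fixture):

    shard = {
        "samples": 100,
        "column_sizes": [8, None],
        "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}},
        # remaining shard fields omitted
    }

    # adapt_mds_shards_to_chunks turns it into:
    chunk = {
        "chunk_bytes": 63407,                # shard["zip_data"]["bytes"]
        "chunk_size": 100,                   # shard["samples"]
        "column_sizes": [8, None],
        "dim": None,
        "filename": "shard.00000.mds.zstd",  # shard["zip_data"]["basename"]
    }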
16 changes: 7 additions & 9 deletions src/litdata/utilities/train_test_split.py
@@ -1,4 +1,3 @@
-import json
import logging
import os
from copy import deepcopy
@@ -8,6 +7,7 @@

from litdata import StreamingDataset
from litdata.constants import _INDEX_FILENAME
+from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.subsample import shuffle_lists_together, subsample_filenames_and_roi


@@ -55,14 +55,12 @@ def train_test_split(

    if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
        # load chunks from `index.json` file
-        with open(os.path.join(input_dir.path, _INDEX_FILENAME)) as f:
-            data = json.load(f)
-            original_chunks = data["chunks"]
-            subsampled_chunks = [
-                _org_chunk
-                for _org_chunk in original_chunks
-                if _org_chunk["filename"] in dummy_subsampled_chunk_filename
-            ]
+        data = load_index_file(input_dir.path)
+
+        original_chunks = data["chunks"]
+        subsampled_chunks = [
+            _org_chunk for _org_chunk in original_chunks if _org_chunk["filename"] in dummy_subsampled_chunk_filename
+        ]
    else:
        raise ValueError("Couldn't load original chunk file.")

22 changes: 22 additions & 0 deletions tests/conftest.py
@@ -8,3 +8,25 @@ def teardown_process_group(): # noqa: PT004
    yield
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
+
+
+@pytest.fixture()
+def mosaic_mds_index_data():
+    return {
+        "shards": [
+            {
+                "column_encodings": ["int", "jpeg"],
+                "column_names": ["class", "image"],
+                "column_sizes": [8, None],
+                "compression": "zstd",
+                "format": "mds",
+                "hashes": [],
+                "raw_data": {"basename": "shard.00000.mds", "bytes": 125824, "hashes": {}},
+                "samples": 100,
+                "size_limit": 67108864,
+                "version": 2,
+                "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 63407, "hashes": {}},
+            }
+        ],
+        "version": 2,
+    }
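
For the same fixture, the adapted index also gains a config block derived from the first shard. The assertions below are what adapt_mds_shards_to_chunks computes from the fixture (a sketch; data_spec is omitted since it is a serialized treespec built from column_names):

    adapted = adapt_mds_shards_to_chunks(mosaic_mds_index_data)
    assert adapted["config"]["chunk_bytes"] == 63407             # sum of zip_data bytes
    assert adapted["config"]["chunk_size"] == 100                # sum of samples
    assert adapted["config"]["compression"] == "zstd"
    assert adapted["config"]["data_format"] == ["int", "jpeg"]   # column_encodings
    assert adapted["config"]["format"] == "mds"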
67 changes: 67 additions & 0 deletions tests/streaming/test_dataset.py
@@ -21,6 +21,7 @@
import numpy as np
import pytest
import torch
+from litdata import train_test_split
from litdata.constants import _ZSTD_AVAILABLE
from litdata.processing import functions
from litdata.streaming import Cache
@@ -995,3 +996,69 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch):
    )

    assert len(dataset2) == int(len(dataset1) * 0.4)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
+def test_dataset_with_mosaic_mds_data(tmpdir):
+    from PIL import Image
+    from streaming import MDSWriter
+
+    # example taken from: https://github.com/mosaicml/streaming
+    # A dictionary mapping input fields to their data types
+    columns = {"image": "jpeg", "class": "int"}
+    # Shard compression, if any
+    compression = "zstd"
+    # Save the samples as shards using MDSWriter
+    with MDSWriter(out=str(tmpdir), columns=columns, compression=compression) as out:
+        for i in range(10):
+            sample = {
+                "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
+                "class": i,
+            }
+            out.write(sample)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
+    assert len(dataset) == 10
+    for i in range(10):
+        sample = dataset[i]
+        assert sample["class"] == i
+
+    assert [sample["class"] for sample in dataset[:]] == list(range(10))  # test slicing
+
+    # -------------- train_test_split --------------
+
+    train_ds, test_ds, val_ds = train_test_split(dataset, splits=[0.7, 0.2, 0.1])
+
+    assert len(train_ds) == 7
+    assert len(test_ds) == 2
+    assert len(val_ds) == 1
+
+    # -------------- subsample --------------
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), subsample=0.4)
+    assert len(dataset) == 4
+    assert [sample["class"] for sample in dataset[:]] == [0, 1, 2, 3]
+
+    # -------------- works with dataloader --------------
+
+    dataset = StreamingDataset(input_dir=str(tmpdir))
+    dataloader = DataLoader(dataset, batch_size=4, drop_last=True)
+    i = 0
+    for batch in dataloader:
+        assert len(batch["class"]) == 4
+        assert len(batch["image"]) == 4
+        assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
+        i += 1
+
+    dataloader = DataLoader(dataset, batch_size=4, drop_last=False)
+    i = 0
+    for batch in dataloader:
+        if i == 2:
+            # last batch is smaller than batch_size
+            assert len(batch["class"]) == 2
+            assert len(batch["image"]) == 2
+            assert list(batch["class"]) == [4 * i, 4 * i + 1]
+            break
+        assert len(batch["class"]) == 4
+        assert len(batch["image"]) == 4
+        assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
+        i += 1
20 changes: 20 additions & 0 deletions tests/utilities/test_dataset_utilities.py
@@ -1,10 +1,14 @@
+import json
import os
from unittest import mock

+from litdata.constants import _INDEX_FILENAME
from litdata.utilities.dataset_utilities import (
    _should_replace_path,
    _try_create_cache_dir,
+    adapt_mds_shards_to_chunks,
    generate_roi,
+    load_index_file,
)


@@ -44,3 +48,19 @@ def test_generate_roi():
    my_roi = generate_roi(my_chunks)

    assert my_roi == [(0, 30), (0, 50), (0, 20), (0, 10)]


+def test_load_index_file(tmpdir, mosaic_mds_index_data):
+    with open(os.path.join(tmpdir, _INDEX_FILENAME), "w") as f:
+        f.write(json.dumps(mosaic_mds_index_data))
+    index_data = load_index_file(tmpdir)
+    assert "chunks" in index_data
+    assert "config" in index_data
+    assert len(mosaic_mds_index_data["shards"]) == len(index_data["chunks"])
+
+
+def test_adapt_mds_shards_to_chunks(mosaic_mds_index_data):
+    adapted_data = adapt_mds_shards_to_chunks(mosaic_mds_index_data)
+    assert "chunks" in adapted_data
+    assert "config" in adapted_data
+    assert len(mosaic_mds_index_data["shards"]) == len(adapted_data["chunks"])