Skip to content

Commit

Permalink
Fix the NoHeaderTensorSerializer for 1D tensors (other than tokens) (#124)
Browse files Browse the repository at this point in the history

* Fix the NoHeaderTensorSerializer for 1D tensors (other than tokens)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Implement an end-to-end test for the deserialization of a NoHeaderTensor with the PyTreeLoader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* update

---------

Co-authored-by: Enrico Stauss <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
  • Loading branch information
4 people authored May 9, 2024
1 parent efa4ae0 commit 7fe1c26
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import functools
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from time import sleep
from typing import Any, Dict, List, Optional, Tuple

Expand All @@ -33,14 +34,15 @@ class BaseItemLoader(ABC):
def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) -> None:
    """Configure the item loader with the dataset config, chunk list and serializers.

    The ``serializers`` mapping is shallow-copied so the per-data-format
    entries registered below do not leak back into the caller's registry,
    and each serializer is deep-copied before ``setup`` so that per-format
    state is never shared between two data formats using the same base
    serializer.
    """
    # NOTE(review): the scraped diff contained both the pre- and post-change
    # statements for the serializer registry; only the corrected
    # (post-commit) statements are kept here.
    self._config = config
    self._chunks = chunks
    # Shallow copy: extra per-format entries added below stay local.
    self._serializers = {**serializers}
    self._data_format = self._config["data_format"]
    # One uint32 (4 bytes) size prefix per field precedes each item's
    # payload — matches the np.uint32 size header read in deserialize.
    self._shift_idx = len(self._data_format) * 4

    # Set up the serializers on restart: each concrete data format gets its
    # own deep-copied serializer instance, keyed by the full format string
    # so deserialize can look it up directly.
    for data_format in self._data_format:
        serializer = deepcopy(self._serializers[self._data_format_to_key(data_format)])
        serializer.setup(data_format)
        self._serializers[data_format] = serializer
@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
Expand Down Expand Up @@ -128,7 +130,7 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree":
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
data = []
for size, data_format in zip(sizes, self._data_format):
serializer = self._serializers[self._data_format_to_key(data_format)]
serializer = self._serializers[data_format]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
idx += size
Expand Down
30 changes: 29 additions & 1 deletion tests/streaming/test_item_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from unittest.mock import MagicMock

import torch
from litdata.constants import _TORCH_DTYPES_MAPPING
from litdata.streaming import Cache
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import PyTreeLoader


Expand All @@ -9,4 +13,28 @@ def test_serializer_setup():
serializer_mock = MagicMock()
item_loader = PyTreeLoader()
item_loader.setup(config_mock, [], {"fake": serializer_mock})
serializer_mock.setup._mock_mock_calls[0].args[0] == "fake:12"
assert len(item_loader._serializers) == 2
assert item_loader._serializers["fake:12"]


def test_pytreeloader_with_no_header_tensor_serializer(tmpdir):
    """End-to-end round-trip of 1D float and long tensors through the cache
    using the no-header tensor serializer, read back via StreamingDataset."""
    cache = Cache(str(tmpdir), chunk_size=10)
    assert isinstance(cache._reader._item_loader, PyTreeLoader)

    dtype_index_float = 1
    dtype_index_long = 18
    float_dtype = _TORCH_DTYPES_MAPPING[dtype_index_float]
    long_dtype = _TORCH_DTYPES_MAPPING[dtype_index_long]

    # Write 10 samples, each a pair of 1D tensors scaled by the sample index.
    for sample_idx in range(10):
        sample = {
            "float": sample_idx * torch.ones(10).to(float_dtype),
            "long": sample_idx * torch.ones(10).to(long_dtype),
        }
        cache[sample_idx] = sample

    # The writer must have selected the no-header tensor format for both fields.
    expected_format = [f"no_header_tensor:{dtype_index_float}", f"no_header_tensor:{dtype_index_long}"]
    assert cache._writer.get_config()["data_format"] == expected_format
    cache.done()
    cache.merge()

    # Read every sample back by index and compare element-wise.
    dataset = StreamingDataset(input_dir=str(tmpdir))
    for sample_idx in range(len(dataset)):
        sample = dataset[sample_idx]
        assert torch.allclose(sample_idx * torch.ones(10).to(float_dtype), sample["float"])
        assert torch.allclose(sample_idx * torch.ones(10).to(long_dtype), sample["long"])

0 comments on commit 7fe1c26

Please sign in to comment.