Feat/add support for numpy datatypes in tokensloader (#401)
* wip: adds support for numpy datatypes in tokensloader

* updated itemloader

* update

* removed numpy array to tensor conversion in tokens loader

* update

* fix: add type ignore for dtype mapping in TokensLoader

* feat: add test for numpy array in tokens loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Nov 5, 2024
1 parent 73f767f commit a247307
Showing 2 changed files with 48 additions and 7 deletions.
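Each entry in a chunk's data_format encodes a serializer name and a dtype index, e.g. "no_header_tensor:<dtype_index>" or "no_header_numpy:<dtype_index>"; the loader now dispatches on that name to pick either a torch or a numpy dtype table. A minimal standalone sketch of that dispatch, using toy mappings that merely stand in for litdata's _TORCH_DTYPES_MAPPING and _NUMPY_DTYPES_MAPPING (their real index-to-dtype contents are not shown in this diff):

# Illustrative sketch only -- the toy tables below are stand-ins for litdata's
# _TORCH_DTYPES_MAPPING / _NUMPY_DTYPES_MAPPING, whose actual contents are not
# part of this diff.
import numpy as np
import torch

TOY_TORCH_DTYPES = {0: torch.float32, 1: torch.int32, 2: torch.int64}
TOY_NUMPY_DTYPES = {0: np.dtype("float32"), 1: np.dtype("int32"), 2: np.dtype("int64")}


def resolve_dtype(data_format_entry: str):
    """Split a 'serializer:index' entry and look the index up in the matching toy table."""
    serializer_name, dtype_index = data_format_entry.split(":")
    if serializer_name not in ("no_header_tensor", "no_header_numpy"):
        raise ValueError("The provided data format isn't supported.")
    table = TOY_TORCH_DTYPES if serializer_name == "no_header_tensor" else TOY_NUMPY_DTYPES
    return serializer_name, table[int(dtype_index)]


print(resolve_dtype("no_header_tensor:2"))  # ('no_header_tensor', torch.int64)
print(resolve_dtype("no_header_numpy:1"))   # ('no_header_numpy', dtype('int32'))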
23 changes: 18 additions & 5 deletions src/litdata/streaming/item_loader.py
@@ -23,9 +23,7 @@
 import numpy as np
 import torch

-from litdata.constants import (
-    _TORCH_DTYPES_MAPPING,
-)
+from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
 from litdata.streaming.serializers import Serializer
 from litdata.utilities._pytree import PyTree, tree_unflatten
 from litdata.utilities.encryption import Encryption, EncryptionLevel
@@ -281,7 +279,17 @@ def setup(
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
     ) -> None:
         super().setup(config, chunks, serializers, region_of_interest)
-        self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
+
+        serializer_name, dtype_index = self._data_format[0].split(":")
+        if serializer_name not in ["no_header_numpy", "no_header_tensor"]:
+            raise ValueError("The provided data format isn't supported.")
+
+        self._serializer_name = serializer_name
+        self._dtype = (
+            _TORCH_DTYPES_MAPPING[int(dtype_index)]  # type: ignore
+            if serializer_name == "no_header_tensor"
+            else _NUMPY_DTYPES_MAPPING[int(dtype_index)]
+        )
         if all(chunk["dim"] is None for chunk in self._chunks):
             raise ValueError("The provided chunks isn't properly setup.")

@@ -350,7 +358,12 @@ def load_item_from_chunk(

         buffer: bytes = self._buffers[chunk_index]
         offset = self._dtype.itemsize * (index - begin) * self._block_size
-        return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+
+        if self._serializer_name == "no_header_tensor":
+            data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+        else:
+            data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)  # type: ignore
+        return data

     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath):
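The read path above is a zero-copy view into the chunk's byte buffer: within a chunk, a block's starting byte is simply itemsize * block_index * block_size, and np.frombuffer / torch.frombuffer then slice out count elements from that offset. A small self-contained sketch of the same offset arithmetic (not litdata code):

# Standalone sketch: reading one fixed-size token block back out of a flat byte
# buffer without copying, mirroring the offset math in TokensLoader.load_item_from_chunk.
import numpy as np
import torch

block_size = 4
dtype = np.int32
# Pretend chunk: three blocks of 4 tokens each, stored back to back.
tokens = np.arange(12, dtype=dtype)
buffer = tokens.tobytes()

index = 2  # read the third block
offset = np.dtype(dtype).itemsize * index * block_size

as_numpy = np.frombuffer(buffer, dtype=dtype, count=block_size, offset=offset)
as_tensor = torch.frombuffer(bytearray(buffer), dtype=torch.int32, count=block_size, offset=offset)

print(as_numpy)   # [ 8  9 10 11]
print(as_tensor)  # tensor([ 8,  9, 10, 11], dtype=torch.int32)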
32 changes: 30 additions & 2 deletions tests/streaming/test_item_loader.py
@@ -1,10 +1,11 @@
 from unittest.mock import MagicMock

+import numpy as np
 import torch
-from litdata.constants import _TORCH_DTYPES_MAPPING
+from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
 from litdata.streaming import Cache
 from litdata.streaming.dataset import StreamingDataset
-from litdata.streaming.item_loader import PyTreeLoader
+from litdata.streaming.item_loader import PyTreeLoader, TokensLoader


 def test_serializer_setup():
@@ -38,3 +39,30 @@ def test_pytreeloader_with_no_header_tensor_serializer(tmpdir):
         item = dataset[i]
         assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_float]), item["float"])
         assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_long]), item["long"])
+
+
+def test_tokensloader_with_no_header_numpy_serializer(tmpdir):
+    cache = Cache(str(tmpdir), chunk_size=512, item_loader=TokensLoader())
+    assert isinstance(cache._reader._item_loader, TokensLoader)
+
+    dtype_index_int32 = 3
+    dtype = _NUMPY_DTYPES_MAPPING[dtype_index_int32]
+
+    for i in range(10):
+        data = np.random.randint(0, 100, size=(256), dtype=dtype)
+        cache._add_item(i, data)
+
+    data_format = [f"no_header_numpy:{dtype_index_int32}"]
+    assert cache._writer.get_config()["data_format"] == data_format
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(
+        input_dir=str(tmpdir),
+        drop_last=True,
+        item_loader=TokensLoader(block_size=256),
+    )
+
+    for data in dataset:
+        assert data.shape == (256,)
+        assert data.dtype == dtype
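The new test writes raw int32 token blocks with Cache and reads them back through StreamingDataset with a TokensLoader. The same dataset also streams like any iterable dataset; a hedged sketch of batching the 256-token blocks with a standard PyTorch DataLoader, assuming chunks have already been written as in the test above to a hypothetical cache_dir:

# Sketch under assumptions: `cache_dir` is a hypothetical directory that already
# holds chunks written as in the test above (256-token int32 blocks).
from torch.utils.data import DataLoader

from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader

cache_dir = "/path/to/cache_dir"
dataset = StreamingDataset(
    input_dir=cache_dir,
    drop_last=True,
    item_loader=TokensLoader(block_size=256),
)
loader = DataLoader(dataset, batch_size=8)

for batch in loader:
    # Default collation stacks the numpy blocks into an (8, 256) int32 tensor.
    print(batch.shape, batch.dtype)
    break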
