Skip to content

Commit

Permalink
Added call to setup function of serializer class to set data format (#96
Browse files Browse the repository at this point in the history
)
  • Loading branch information
vgurev authored Apr 11, 2024
1 parent c87662a commit 58f7aeb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
self._config = config
self._chunks = chunks
self._serializers = serializers
self._data_format = self._config["data_format"]
self._shift_idx = len(self._data_format) * 4

# setup the serializers on restart
for data_format in self._data_format:
serializer = self._serializers[self._data_format_to_key(data_format)]
serializer.setup(data_format)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
serialier, serializer_sub_type = data_format.split(":")
if serializer_sub_type in self._serializers:
return serializer_sub_type
return serialier
return data_format

def state_dict(self) -> Dict:
return {}
Expand Down Expand Up @@ -109,21 +125,12 @@ def load_item_from_chunk(

return self.deserialize(data)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
serialier, serializer_sub_type = data_format.split(":")
if serializer_sub_type in self._serializers:
return serializer_sub_type
return serialier
return data_format

def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = len(self._config["data_format"]) * 4
idx = self._shift_idx
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
data = []
for size, data_format in zip(sizes, self._config["data_format"]):
for size, data_format in zip(sizes, self._data_format):
serializer = self._serializers[self._data_format_to_key(data_format)]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
Expand Down
12 changes: 12 additions & 0 deletions tests/streaming/test_item_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from unittest.mock import MagicMock

from litdata.streaming.item_loader import PyTreeLoader


def test_serializer_setup():
config_mock = MagicMock()
config_mock.__getitem__.return_value = ["fake:12"]
serializer_mock = MagicMock()
item_loader = PyTreeLoader()
item_loader.setup(config_mock, [], {"fake": serializer_mock})
serializer_mock.setup._mock_mock_calls[0].args[0] == "fake:12"

0 comments on commit 58f7aeb

Please sign in to comment.