From 58f7aeb5836383ff839b4d966462c568aa6e7435 Mon Sep 17 00:00:00 2001
From: VG
Date: Thu, 11 Apr 2024 04:34:47 -0400
Subject: [PATCH] Added call to setup function of serializer class to set data format (#96)

---
Note: the final line of test_serializer_setup is now an `assert` on the
public `mock_calls` list; as a bare comparison it could never fail. The
blob id on the new test file's `index` line predates this change.

 src/litdata/streaming/item_loader.py | 29 +++++++++++++++++-----------
 tests/streaming/test_item_loader.py  | 12 ++++++++++++
 2 files changed, 30 insertions(+), 11 deletions(-)
 create mode 100644 tests/streaming/test_item_loader.py

diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py
index 6a2bbde5..04b9b2b8 100644
--- a/src/litdata/streaming/item_loader.py
+++ b/src/litdata/streaming/item_loader.py
@@ -37,6 +37,22 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
         self._config = config
         self._chunks = chunks
         self._serializers = serializers
+        self._data_format = self._config["data_format"]
+        self._shift_idx = len(self._data_format) * 4
+
+        # setup the serializers on restart
+        for data_format in self._data_format:
+            serializer = self._serializers[self._data_format_to_key(data_format)]
+            serializer.setup(data_format)
+
+    @functools.lru_cache(maxsize=128)
+    def _data_format_to_key(self, data_format: str) -> str:
+        if ":" in data_format:
+            serialier, serializer_sub_type = data_format.split(":")
+            if serializer_sub_type in self._serializers:
+                return serializer_sub_type
+            return serialier
+        return data_format
 
     def state_dict(self) -> Dict:
         return {}
@@ -109,21 +125,12 @@ def load_item_from_chunk(
 
         return self.deserialize(data)
 
-    @functools.lru_cache(maxsize=128)
-    def _data_format_to_key(self, data_format: str) -> str:
-        if ":" in data_format:
-            serialier, serializer_sub_type = data_format.split(":")
-            if serializer_sub_type in self._serializers:
-                return serializer_sub_type
-            return serialier
-        return data_format
-
     def deserialize(self, raw_item_data: bytes) -> "PyTree":
         """Deserialize the raw bytes into their python equivalent."""
-        idx = len(self._config["data_format"]) * 4
+        idx = self._shift_idx
         sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
         data = []
-        for size, data_format in zip(sizes, self._config["data_format"]):
+        for size, data_format in zip(sizes, self._data_format):
             serializer = self._serializers[self._data_format_to_key(data_format)]
             data_bytes = raw_item_data[idx : idx + size]
             data.append(serializer.deserialize(data_bytes))
diff --git a/tests/streaming/test_item_loader.py b/tests/streaming/test_item_loader.py
new file mode 100644
index 00000000..e17fe705
--- /dev/null
+++ b/tests/streaming/test_item_loader.py
@@ -0,0 +1,12 @@
+from unittest.mock import MagicMock
+
+from litdata.streaming.item_loader import PyTreeLoader
+
+
+def test_serializer_setup():
+    config_mock = MagicMock()
+    config_mock.__getitem__.return_value = ["fake:12"]
+    serializer_mock = MagicMock()
+    item_loader = PyTreeLoader()
+    item_loader.setup(config_mock, [], {"fake": serializer_mock})
+    assert serializer_mock.setup.mock_calls[0].args[0] == "fake:12"