From b496cbda178e0e4939f6920a996e56a047a0081e Mon Sep 17 00:00:00 2001 From: Jirka B Date: Wed, 18 Sep 2024 14:03:28 +0200 Subject: [PATCH 1/3] fix import & asignement issue --- src/litdata/streaming/serializers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 57d3e85e..35ef7a30 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -27,9 +27,6 @@ from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING -if TYPE_CHECKING: - from PIL.JpegImagePlugin import JpegImageFile - _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") @@ -70,6 +67,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: @classmethod def deserialize(cls, data: bytes) -> Any: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") from PIL import Image idx = 3 * 4 @@ -94,6 +93,9 @@ class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") + from PIL import Image from PIL.GifImagePlugin import GifImageFile from PIL.JpegImagePlugin import JpegImageFile @@ -121,7 +123,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: buff.seek(0) return buff.read(), None - raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.") + raise TypeError(f"The provided item should be of type `JpegImageFile`. Found {item}.") def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]: if _TORCH_VISION_AVAILABLE: @@ -183,7 +185,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: shape = [] for shape_idx in range(shape_size): shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) - idx_start = 8 + 4 * (shape_idx + 1) + idx_start = 8 + 4 * shape_size idx_end = len(data) if idx_end > idx_start: tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype) @@ -249,7 +251,7 @@ def deserialize(self, data: bytes) -> np.ndarray: shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) # deserialize the numpy array bytes - tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype) if tensor.shape == shape: return tensor return np.reshape(tensor, shape) From 9d6d6d276e89cb6f7c35f37df9fabfa49a33c818 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:04:47 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 35ef7a30..d837c9bc 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -19,7 +19,7 @@ from collections import OrderedDict from contextlib import suppress from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch From 8009f3f73f73508fcd3ece6c74f0714bd7ca7a33 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 18 Sep 2024 14:10:17 +0200 Subject: [PATCH 3/3] Apply suggestions from code review --- src/litdata/streaming/serializers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index d837c9bc..444c56ad 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -19,7 +19,7 @@ from collections import OrderedDict from contextlib import suppress from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -27,6 +27,8 @@ from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING +if TYPE_CHECKING: + from PIL.JpegImagePlugin import JpegImageFile _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av")