diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 57d3e85e..35ef7a30 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -27,9 +27,6 @@ from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING -if TYPE_CHECKING: - from PIL.JpegImagePlugin import JpegImageFile - _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") @@ -70,6 +67,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: @classmethod def deserialize(cls, data: bytes) -> Any: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") from PIL import Image idx = 3 * 4 @@ -94,6 +93,9 @@ class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") + from PIL import Image from PIL.GifImagePlugin import GifImageFile from PIL.JpegImagePlugin import JpegImageFile @@ -121,7 +123,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: buff.seek(0) return buff.read(), None - raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.") + raise TypeError(f"The provided item should be of type `JpegImageFile`. Found {item}.") def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]: if _TORCH_VISION_AVAILABLE: @@ -183,7 +185,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: shape = [] for shape_idx in range(shape_size): shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) - idx_start = 8 + 4 * (shape_idx + 1) + idx_start = 8 + 4 * shape_size idx_end = len(data) if idx_end > idx_start: tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype) @@ -249,7 +251,7 @@ def deserialize(self, data: bytes) -> np.ndarray: shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) # deserialize the numpy array bytes - tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype) if tensor.shape == shape: return tensor return np.reshape(tensor, shape)