diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 0624e739..f453f538 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: from PIL.JpegImagePlugin import JpegImageFile - _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") @@ -70,6 +69,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: @classmethod def deserialize(cls, data: bytes) -> Any: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") from PIL import Image idx = 3 * 4 @@ -94,6 +95,9 @@ class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") + from PIL import Image from PIL.GifImagePlugin import GifImageFile from PIL.JpegImagePlugin import JpegImageFile @@ -122,7 +126,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: buff.seek(0) return buff.read(), None - raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.") + raise TypeError(f"The provided item should be of type `JpegImageFile`. Found {item}.") def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]: if _TORCH_VISION_AVAILABLE: @@ -184,7 +188,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: shape = [] for shape_idx in range(shape_size): shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) - idx_start = 8 + 4 * (shape_idx + 1) + idx_start = 8 + 4 * shape_size idx_end = len(data) if idx_end > idx_start: tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype) @@ -250,7 +254,7 @@ def deserialize(self, data: bytes) -> np.ndarray: shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) # deserialize the numpy array bytes - tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype) if tensor.shape == shape: return tensor return np.reshape(tensor, shape)