Skip to content

Commit

Permalink
Fix import error in serialization checks (#241)
Browse files Browse the repository at this point in the history
* Fix import error

* add test
  • Loading branch information
awaelchli authored Jul 18, 2024
1 parent 36c1f10 commit 0380af8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ def deserialize(cls, data: bytes) -> Any:
return Image.frombytes(mode, size, raw) # pyright: ignore

def can_serialize(self, item: Any) -> bool:
if not _PIL_AVAILABLE:
return False

from PIL import Image
from PIL.JpegImagePlugin import JpegImageFile

return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)


class JPEGSerializer(Serializer):
Expand Down Expand Up @@ -137,9 +140,12 @@ def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]:
return img

def can_serialize(self, item: Any) -> bool:
if not _PIL_AVAILABLE:
return False

from PIL.JpegImagePlugin import JpegImageFile

return bool(_PIL_AVAILABLE) and isinstance(item, JpegImageFile)
return isinstance(item, JpegImageFile)


class BytesSerializer(Serializer):
Expand Down
13 changes: 13 additions & 0 deletions tests/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import random
import sys
from time import time
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -100,6 +101,12 @@ def test_pil_serializer(mode):
assert np.array_equal(np_data, np_dec_data)


def test_pil_serializer_available():
serializer = PILSerializer()
with mock.patch("litdata.streaming.serializers._PIL_AVAILABLE", False):
assert not serializer.can_serialize(None)


@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
def test_jpeg_serializer():
serializer = JPEGSerializer()
Expand All @@ -121,6 +128,12 @@ def test_jpeg_serializer():
assert deserialized_img.shape == torch.Size([3, 28, 28])


def test_jpeg_serializer_available():
serializer = JPEGSerializer()
with mock.patch("litdata.streaming.serializers._PIL_AVAILABLE", False):
assert not serializer.can_serialize(None)


@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
Expand Down

0 comments on commit 0380af8

Please sign in to comment.