Skip to content

Commit

Permalink
fix import & assignment issue
Browse files: browse the repository at this point in the history
  • Loading branch information
Borda committed Sep 18, 2024
1 parent 2f78ec1 commit b496cbd
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@

from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING

if TYPE_CHECKING:
from PIL.JpegImagePlugin import JpegImageFile

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
_AV_AVAILABLE = RequirementCache("av")
Expand Down Expand Up @@ -70,6 +67,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:

@classmethod
def deserialize(cls, data: bytes) -> Any:
if not _PIL_AVAILABLE:
raise ModuleNotFoundError("PIL is required. Run `pip install pillow`")
from PIL import Image

idx = 3 * 4
Expand All @@ -94,6 +93,9 @@ class JPEGSerializer(Serializer):
"""The JPEGSerializer serialize and deserialize JPEG image to and from bytes."""

def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
if not _PIL_AVAILABLE:
raise ModuleNotFoundError("PIL is required. Run `pip install pillow`")

from PIL import Image
from PIL.GifImagePlugin import GifImageFile
from PIL.JpegImagePlugin import JpegImageFile
Expand Down Expand Up @@ -121,7 +123,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
buff.seek(0)
return buff.read(), None

raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.")
raise TypeError(f"The provided item should be of type `JpegImageFile`. Found {item}.")

def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]:
if _TORCH_VISION_AVAILABLE:
Expand Down Expand Up @@ -183,7 +185,7 @@ def deserialize(self, data: bytes) -> torch.Tensor:
shape = []
for shape_idx in range(shape_size):
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
idx_start = 8 + 4 * (shape_idx + 1)
idx_start = 8 + 4 * shape_size
idx_end = len(data)
if idx_end > idx_start:
tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype)
Expand Down Expand Up @@ -249,7 +251,7 @@ def deserialize(self, data: bytes) -> np.ndarray:
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())

# deserialize the numpy array bytes
tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype)
if tensor.shape == shape:
return tensor
return np.reshape(tensor, shape)
Expand Down

0 comments on commit b496cbd

Please sign in to comment.