From cfb60dcbb2dc103dad4d685f9f354f290a44a4d7 Mon Sep 17 00:00:00 2001 From: Enrico Stauss Date: Tue, 3 Dec 2024 13:14:11 +0100 Subject: [PATCH] Fix the serialization of scalar valued tensors --- src/litdata/streaming/serializers.py | 2 +- tests/streaming/test_serializer.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 41c58179..77d5c2bf 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -202,7 +202,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: return torch.reshape(tensor, shape) def can_serialize(self, item: torch.Tensor) -> bool: - return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) > 1 + return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) != 1 class NoHeaderTensorSerializer(Serializer): diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py index c3998e90..7afd6b47 100644 --- a/tests/streaming/test_serializer.py +++ b/tests/streaming/test_serializer.py @@ -23,6 +23,7 @@ import tifffile import torch from lightning_utilities.core.imports import RequirementCache + from litdata.streaming.serializers import ( _AV_AVAILABLE, _NUMPY_DTYPES_MAPPING, @@ -256,6 +257,14 @@ def test_deserialize_empty_tensor(): assert torch.equal(t, new_t) +def test_deserialize_scalar_tensor(): + serializer = TensorSerializer() + t = torch.tensor(0) + data, _ = serializer.serialize(t) + new_t = serializer.deserialize(data) + assert torch.equal(t, new_t) + + def test_deserialize_empty_no_header_tensor(): serializer = NoHeaderTensorSerializer() t = torch.ones((0,)).int() @@ -271,6 +280,15 @@ def test_deserialize_empty_no_header_tensor(): assert torch.equal(t, new_t) +def test_can_serialize_tensor(): + serializer = TensorSerializer() + # Check that the TensorSerializer can serialize scalar valued tensors as well as higher order (>1) Tensors + assert serializer.can_serialize(torch.tensor(0)) + assert serializer.can_serialize(torch.tensor([[0, 0]])) + # Check that it does not serialize Tensors of order 1, those are treated by the dedicated NoHeaderTensorSerializer + assert not serializer.can_serialize(torch.tensor([0, 0])) + + @pytest.mark.skipif(not _TIFFFILE_AVAILABLE, reason="Requires: ['tifffile']") def test_tiff_serializer(): serializer = TIFFSerializer()