diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 41c58179..77d5c2bf 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -202,7 +202,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: return torch.reshape(tensor, shape) def can_serialize(self, item: torch.Tensor) -> bool: - return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) > 1 + return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) != 1 class NoHeaderTensorSerializer(Serializer): diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py index c3998e90..6228e52f 100644 --- a/tests/streaming/test_serializer.py +++ b/tests/streaming/test_serializer.py @@ -256,6 +256,14 @@ def test_deserialize_empty_tensor(): assert torch.equal(t, new_t) +def test_deserialize_scalar_tensor(): + serializer = TensorSerializer() + t = torch.tensor(0) + data, _ = serializer.serialize(t) + new_t = serializer.deserialize(data) + assert torch.equal(t, new_t) + + def test_deserialize_empty_no_header_tensor(): serializer = NoHeaderTensorSerializer() t = torch.ones((0,)).int() @@ -271,6 +279,15 @@ def test_deserialize_empty_no_header_tensor(): assert torch.equal(t, new_t) +def test_can_serialize_tensor(): + serializer = TensorSerializer() + # Check that the TensorSerializer can serialize scalar valued tensors as well as higher order (>1) Tensors + assert serializer.can_serialize(torch.tensor(0)) + assert serializer.can_serialize(torch.tensor([[0, 0]])) + # Check that it does not serialize Tensors of order 1, those are treated by the dedicated NoHeaderTensorSerializer + assert not serializer.can_serialize(torch.tensor([0, 0])) + + @pytest.mark.skipif(not _TIFFFILE_AVAILABLE, reason="Requires: ['tifffile']") def test_tiff_serializer(): serializer = TIFFSerializer()