diff --git a/src/prefecto/serializers/pandas.py b/src/prefecto/serializers/pandas.py index 0e730df..8282bb2 100644 --- a/src/prefecto/serializers/pandas.py +++ b/src/prefecto/serializers/pandas.py @@ -1,7 +1,7 @@ """ Pandas IO `Method`s and `ExtendedSerializer`. """ -from typing import Any +from typing import Any, Literal try: import pandas as pd @@ -87,10 +87,12 @@ class Excel(Method): """Method for reading and writing Excel files.""" discriminator: str = "pandas.excel" + default_write_kwargs: dict[str, Any] = {"index": False} __read__ = pd.read_excel __write__ = pd.DataFrame.to_excel +@ExtendedSerializer.register class PandasSerializer(ExtendedSerializer): """Serializer for `pandas.DataFrame` objects. @@ -136,4 +138,5 @@ class PandasSerializer(ExtendedSerializer): True """ + type: Literal["pandas"] = "pandas" method = "pandas.tsv" diff --git a/tests/serializers/test_core.py b/tests/serializers/test_core.py new file mode 100644 index 0000000..f42c4ed --- /dev/null +++ b/tests/serializers/test_core.py @@ -0,0 +1,105 @@ +""" +Tests for the `serializers.core` module. +""" +import io +from typing import Any + +import pytest +from _pytest.monkeypatch import MonkeyPatch + +from prefecto.serializers import core + + +def dummyread() -> str: + """Returns "abc".""" + return "abc" + + +def dummywrite() -> bytes: + """Returns b"abc".""" + return b"abc" + + +@pytest.fixture +def method_registry(monkeypatch: MonkeyPatch): + """Fixture to clear the method registry.""" + try: + __registry__ = core.__registry__ + monkeypatch.setattr(core, "__registry__", __registry__.copy()) + yield + finally: + monkeypatch.setattr(core, "__registry__", __registry__) + + +def test_method(method_registry): + """Tests the method decorator.""" + + @core.method + class Method(core.Method): + """Dummy method.""" + + discriminator: str = "test.method" + default_read_kwargs: dict[str, Any] = {} + default_write_kwargs: dict[str, Any] = {} + __read__ = dummyread + __write__ = dummywrite + + assert core.get_method("test.method") == Method + assert Method.read() == "abc" + assert Method.write() == b"abc" + + +def test_method_missing_attr(method_registry): + """Tests the method decorator with missing attributes.""" + with pytest.raises(AttributeError): + + @core.method + class _(core.Method): + """Dummy method.""" + + __read__ = dummyread + __write__ = dummywrite + + with pytest.raises(AttributeError): + + @core.method + class _(core.Method): + """Dummy method.""" + + discriminator: str = "test.method.a" + __write__ = dummywrite + + with pytest.raises(AttributeError): + + @core.method + class _(core.Method): + """Dummy method.""" + + discriminator: str = "test.method.b" + default_read_kwargs: dict[str, Any] = {} + __read__ = dummyread + + +def test_extended_serializer_basic(method_registry): + """Tests the serializer class with simple inputs.""" + + @core.method + class _(core.Method): + """Dummy method.""" + + discriminator: str = "test" + default_read_kwargs: dict[str, Any] = {} + default_write_kwargs: dict[str, Any] = {} + + def __read__(buff: io.BytesIO) -> str: + """Reads the string from the buffer.""" + return buff.read().decode() + + def __write__(value: str, buff: io.BytesIO) -> None: + """Writes the string to the buffer.""" + buff.write(value.encode()) + + s = core.ExtendedSerializer(method="test") + string = "abc" + assert s.dumps(string) == b"abc" + assert s.loads(b"abc") == "abc" diff --git a/tests/serializers/test_pandas.py b/tests/serializers/test_pandas.py new file mode 100644 index 0000000..3a0143d --- /dev/null +++ b/tests/serializers/test_pandas.py @@ -0,0 +1,92 @@ +""" +Tests for the `serializers.pandas` module. +""" +import pandas as pd +import pytest + +from prefecto.serializers import pandas as pds + + +@pytest.fixture +def df(): + """Returns a simple DataFrame.""" + return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +def test_parquet(df: pd.DataFrame): + """Tests the parquet method.""" + s = pds.PandasSerializer(method=pds.Parquet.discriminator) + assert s.get_method() == pds.Parquet + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_csv(df: pd.DataFrame): + """Tests the csv method.""" + s = pds.PandasSerializer(method=pds.CSV.discriminator) + assert s.get_method() == pds.CSV + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_json(df: pd.DataFrame): + """Tests the json method.""" + s = pds.PandasSerializer(method=pds.JSON.discriminator) + assert s.get_method() == pds.JSON + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_jsonl(df: pd.DataFrame): + """Tests the jsonl method.""" + s = pds.PandasSerializer(method=pds.JSONL.discriminator) + assert s.get_method() == pds.JSONL + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_feather(df: pd.DataFrame): + """Tests the feather method.""" + s = pds.PandasSerializer(method=pds.Feather.discriminator) + assert s.get_method() == pds.Feather + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_pickle(df: pd.DataFrame): + """Tests the pickle method.""" + s = pds.PandasSerializer(method=pds.Pickle.discriminator) + assert s.get_method() == pds.Pickle + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_tsv(df: pd.DataFrame): + """Tests the tsv method.""" + s = pds.PandasSerializer(method=pds.TSV.discriminator) + assert s.get_method() == pds.TSV + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) + + +def test_excel(df: pd.DataFrame): + """Tests the excel method.""" + s = pds.PandasSerializer(method=pds.Excel.discriminator) + assert s.get_method() == pds.Excel + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.equals(df2) diff --git a/tests/test_serializers.py b/tests/test_serializers.py deleted file mode 100644 index be8be5e..0000000 --- a/tests/test_serializers.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Tests for the serializers library. -""" -import io -from typing import Any - -import pytest -from _pytest.monkeypatch import MonkeyPatch - -from prefecto.serializers import core - - -class TestCore: - """Tests for the core module.""" - - @staticmethod - def dummyread() -> str: - """Returns "abc".""" - return "abc" - - @staticmethod - def dummywrite() -> bytes: - """Returns b"abc".""" - return b"abc" - - @pytest.fixture - def method_registry(self, monkeypatch: MonkeyPatch): - """Fixture to clear the method registry.""" - try: - __registry__ = core.__registry__ - monkeypatch.setattr(core, "__registry__", __registry__.copy()) - yield - finally: - monkeypatch.setattr(core, "__registry__", __registry__) - - def test_method(self, method_registry): - """Tests the method decorator.""" - - @core.method - class Method(core.Method): - """Dummy method.""" - - discriminator: str = "test.method" - default_read_kwargs: dict[str, Any] = {} - default_write_kwargs: dict[str, Any] = {} - __read__ = self.dummyread - __write__ = self.dummywrite - - assert core.__registry__ == {"test.method": Method} - assert core.get_method("test.method") == Method - assert Method.read() == "abc" - assert Method.write() == b"abc" - - def test_method_missing_attr(self, method_registry): - """Tests the method decorator with missing attributes.""" - with pytest.raises(AttributeError): - - @core.method - class _(core.Method): - """Dummy method.""" - - __read__ = self.dummyread - __write__ = self.dummywrite - - with pytest.raises(AttributeError): - - @core.method - class _(core.Method): - """Dummy method.""" - - discriminator: str = "test.method.a" - __write__ = self.dummywrite - - with pytest.raises(AttributeError): - - @core.method - class _(core.Method): - """Dummy method.""" - - discriminator: str = "test.method.b" - default_read_kwargs: dict[str, Any] = {} - __read__ = self.dummyread - - def test_extended_serializer_basic(self, method_registry): - """Tests the serializer class with simple inputs.""" - - @core.method - class _(core.Method): - """Dummy method.""" - - discriminator: str = "test" - default_read_kwargs: dict[str, Any] = {} - default_write_kwargs: dict[str, Any] = {} - - def __read__(buff: io.BytesIO) -> str: - """Reads the string from the buffer.""" - return buff.read().decode() - - def __write__(value: str, buff: io.BytesIO) -> None: - """Writes the string to the buffer.""" - buff.write(value.encode()) - - s = core.ExtendedSerializer(method="test") - string = "abc" - assert s.dumps(string) == b"abc" - assert s.loads(b"abc") == "abc"