From 6ec750cb82d3edd8db419f5acf80ec97840a3dd4 Mon Sep 17 00:00:00 2001 From: Dominic Tarro Date: Wed, 14 Jun 2023 00:34:33 -0400 Subject: [PATCH] polars serializer --- src/prefecto/serializers/polars.py | 121 +++++++++++++++++++++++++++++ tests/serializers/test_polars.py | 72 +++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 src/prefecto/serializers/polars.py create mode 100644 tests/serializers/test_polars.py diff --git a/src/prefecto/serializers/polars.py b/src/prefecto/serializers/polars.py new file mode 100644 index 0000000..e24798d --- /dev/null +++ b/src/prefecto/serializers/polars.py @@ -0,0 +1,121 @@ +""" +Serializer classes for Prefect. +""" +from typing import Any, Literal + +try: + import polars as pl +except ImportError: + raise ImportError( + "Polars is required for the polars serializer.\nInstall" + " with `pip install polars`." + ) + +from .core import ExtendedSerializer, Method, method + + +@method +class Parquet(Method): + """Method for reading and writing Parquet files.""" + + discriminator: str = "polars.parquet" + __read__ = pl.read_parquet + __write__ = pl.DataFrame.write_parquet + + +@method +class CSV(Method): + """Method for reading and writing CSV files.""" + + discriminator: str = "polars.csv" + default_write_kwargs: dict[str, Any] = {} + __read__ = pl.read_csv + __write__ = pl.DataFrame.write_csv + + +@method +class JSON(Method): + """Method for reading and writing JSON files.""" + + discriminator: str = "polars.json" + __read__ = pl.read_json + __write__ = pl.DataFrame.write_json + + +@method +class NDJSON(Method): + """Method for reading and writing NDJSON files.""" + + discriminator: str = "polars.ndjson" + __read__ = pl.read_ndjson + __write__ = pl.DataFrame.write_ndjson + + +@method +class TSV(Method): + """Method for reading and writing TSV files.""" + + discriminator: str = "polars.tsv" + default_read_kwargs: dict[str, Any] = {"separator": "\t"} + default_write_kwargs: dict[str, Any] = {"separator": "\t"} + __read__ = pl.read_csv + __write__ = pl.DataFrame.write_csv + + +@method +class Excel(Method): + """Method for reading and writing Excel files.""" + + discriminator: str = "polars.excel" + __read__ = pl.read_excel + __write__ = pl.DataFrame.write_excel + + +@ExtendedSerializer.register +class PolarsSerializer(ExtendedSerializer): + """Serializer for `polars.DataFrame` objects. + + Parameters + ---------- + method : str + The method to use for reading and writing. Must be a registered + `Method`. Defaults to "polars.tsv". + read_kwargs : dict[str, Any], optional + Keyword arguments for the read method. Overrides default arguments for + the method. + write_kwargs : dict[str, Any], optional + Keyword arguments for the write method. Overrides default arguments + for the method. + + Examples + -------- + Simple read and write. + >>> import polars as pl + >>> from prefecto.serializers.polars import PolarsSerializer + >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> blob = PolarsSerializer().dumps(df) + >>> blob + b'a\\tb\\n1\\t4\\n2\\t5\\n3\\t6\\n' + >>> df2 = PolarsSerializer().loads(blob) + >>> df2.frame_equal(df) + True + + Using a different method. + >>> blob = PolarsSerializer(method="polars.csv").dumps(df) + >>> blob + b'a,b\\n1,4\\n2,5\\n3,6\\n' + >>> df2 = PolarsSerializer(method="polars.csv").loads(blob) + >>> df2.frame_equal(df) + True + + Using custom read and write kwargs. + >>> blob = PolarsSerializer(write_kwargs={"index": True}).dumps(df) + >>> blob + b'index\\ta\\tb\\n0\\t1\\t4\\n1\\t2\\t5\\n2\\t3\\t6\\n' + >>> df2 = PolarsSerializer(read_kwargs={"index_col": 0}).loads(blob) + >>> df2.frame_equal(df) + True + """ + + type: Literal["polars"] = "polars" + method = "polars.parquet" diff --git a/tests/serializers/test_polars.py b/tests/serializers/test_polars.py new file mode 100644 index 0000000..86266d6 --- /dev/null +++ b/tests/serializers/test_polars.py @@ -0,0 +1,72 @@ +""" +Tests for the `serializers.polars` module. +""" +import polars as pl +import pytest + +from prefecto.serializers import polars as pls + + +@pytest.fixture +def df(): + """Returns a simple DataFrame.""" + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +def test_parquet(df: pl.DataFrame): + """Tests the parquet method.""" + s = pls.PolarsSerializer(method=pls.Parquet.discriminator) + assert s.get_method() == pls.Parquet + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.frame_equal(df2) + + +def test_csv(df: pl.DataFrame): + """Tests the csv method.""" + s = pls.PolarsSerializer(method=pls.CSV.discriminator) + assert s.get_method() == pls.CSV + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.frame_equal(df2) + + +def test_json(df: pl.DataFrame): + """Tests the json method.""" + s = pls.PolarsSerializer(method=pls.JSON.discriminator) + assert s.get_method() == pls.JSON + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.frame_equal(df2) + + +def test_ndjson(df: pl.DataFrame): + """Tests the jsonl method.""" + s = pls.PolarsSerializer(method=pls.NDJSON.discriminator) + assert s.get_method() == pls.NDJSON + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.frame_equal(df2) + + +def test_tsv(df: pl.DataFrame): + """Tests the tsv method.""" + s = pls.PolarsSerializer(method=pls.TSV.discriminator) + assert s.get_method() == pls.TSV + + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.frame_equal(df2) + + +def test_excel(df: pl.DataFrame): + """Tests the excel method.""" + s = pls.PolarsSerializer(method=pls.Excel.discriminator) + assert s.get_method() == pls.Excel + blob = s.dumps(df) + df2 = s.loads(blob) + assert df.frame_equal(df2)