Skip to content

Commit

Permalink
polars serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
dominictarro committed Jun 14, 2023
1 parent 8fbd591 commit 6ec750c
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 0 deletions.
121 changes: 121 additions & 0 deletions src/prefecto/serializers/polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Serializer classes for Prefect.
"""
from typing import Any, Literal

try:
import polars as pl
except ImportError:
raise ImportError(
"Polars is required for the polars serializer.\nInstall"
" with `pip install polars`."
)

from .core import ExtendedSerializer, Method, method


@method
class Parquet(Method):
"""Method for reading and writing Parquet files."""

discriminator: str = "polars.parquet"
__read__ = pl.read_parquet
__write__ = pl.DataFrame.write_parquet


@method
class CSV(Method):
"""Method for reading and writing CSV files."""

discriminator: str = "polars.csv"
default_write_kwargs: dict[str, Any] = {}
__read__ = pl.read_csv
__write__ = pl.DataFrame.write_csv


@method
class JSON(Method):
"""Method for reading and writing JSON files."""

discriminator: str = "polars.json"
__read__ = pl.read_json
__write__ = pl.DataFrame.write_json


@method
class NDJSON(Method):
"""Method for reading and writing NDJSON files."""

discriminator: str = "polars.ndjson"
__read__ = pl.read_ndjson
__write__ = pl.DataFrame.write_ndjson


@method
class TSV(Method):
"""Method for reading and writing TSV files."""

discriminator: str = "polars.tsv"
default_read_kwargs: dict[str, Any] = {"separator": "\t"}
default_write_kwargs: dict[str, Any] = {"separator": "\t"}
__read__ = pl.read_csv
__write__ = pl.DataFrame.write_csv


@method
class Excel(Method):
"""Method for reading and writing Excel files."""

discriminator: str = "polars.excel"
__read__ = pl.read_excel
__write__ = pl.DataFrame.write_excel


@ExtendedSerializer.register
class PolarsSerializer(ExtendedSerializer):
"""Serializer for `polars.DataFrame` objects.
Parameters
----------
method : str
The method to use for reading and writing. Must be a registered
`Method`. Defaults to "polars.tsv".
read_kwargs : dict[str, Any], optional
Keyword arguments for the read method. Overrides default arguments for
the method.
write_kwargs : dict[str, Any], optional
Keyword arguments for the write method. Overrides default arguments
for the method.
Examples
--------
Simple read and write.
>>> import polars as pl
>>> from prefecto.serializers.polars import PolarsSerializer
>>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> blob = PolarsSerializer().dumps(df)
>>> blob
b'a\\tb\\n1\\t4\\n2\\t5\\n3\\t6\\n'
>>> df2 = PolarsSerializer().loads(blob)
>>> df2.frame_equal(df)
True
Using a different method.
>>> blob = PolarsSerializer(method="polars.csv").dumps(df)
>>> blob
b'a,b\\n1,4\\n2,5\\n3,6\\n'
>>> df2 = PolarsSerializer(method="polars.csv").loads(blob)
>>> df2.frame_equal(df)
True
Using custom read and write kwargs.
>>> blob = PolarsSerializer(write_kwargs={"index": True}).dumps(df)
>>> blob
b'index\\ta\\tb\\n0\\t1\\t4\\n1\\t2\\t5\\n2\\t3\\t6\\n'
>>> df2 = PolarsSerializer(read_kwargs={"index_col": 0}).loads(blob)
>>> df2.frame_equal(df)
True
"""

type: Literal["polars"] = "polars"
method = "polars.parquet"
72 changes: 72 additions & 0 deletions tests/serializers/test_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Tests for the `serializers.polars` module.
"""
import polars as pl
import pytest

from prefecto.serializers import polars as pls


@pytest.fixture
def df():
"""Returns a simple DataFrame."""
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


def test_parquet(df: pl.DataFrame):
"""Tests the parquet method."""
s = pls.PolarsSerializer(method=pls.Parquet.discriminator)
assert s.get_method() == pls.Parquet

blob = s.dumps(df)
df2 = s.loads(blob)
assert df.frame_equal(df2)


def test_csv(df: pl.DataFrame):
"""Tests the csv method."""
s = pls.PolarsSerializer(method=pls.CSV.discriminator)
assert s.get_method() == pls.CSV

blob = s.dumps(df)
df2 = s.loads(blob)
assert df.frame_equal(df2)


def test_json(df: pl.DataFrame):
"""Tests the json method."""
s = pls.PolarsSerializer(method=pls.JSON.discriminator)
assert s.get_method() == pls.JSON

blob = s.dumps(df)
df2 = s.loads(blob)
assert df.frame_equal(df2)


def test_ndjson(df: pl.DataFrame):
"""Tests the jsonl method."""
s = pls.PolarsSerializer(method=pls.NDJSON.discriminator)
assert s.get_method() == pls.NDJSON

blob = s.dumps(df)
df2 = s.loads(blob)
assert df.frame_equal(df2)


def test_tsv(df: pl.DataFrame):
"""Tests the tsv method."""
s = pls.PolarsSerializer(method=pls.TSV.discriminator)
assert s.get_method() == pls.TSV

blob = s.dumps(df)
df2 = s.loads(blob)
assert df.frame_equal(df2)


def test_excel(df: pl.DataFrame):
"""Tests the excel method."""
s = pls.PolarsSerializer(method=pls.Excel.discriminator)
assert s.get_method() == pls.Excel
blob = s.dumps(df)
df2 = s.loads(blob)
assert df.frame_equal(df2)

0 comments on commit 6ec750c

Please sign in to comment.