-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8fbd591
commit 6ec750c
Showing
2 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
""" | ||
Serializer classes for Prefect. | ||
""" | ||
from typing import Any, Literal | ||
|
||
# polars is a hard requirement of this module: fail at import time with an
# actionable message instead of a bare "No module named 'polars'".
try:
    import polars as pl
except ImportError as err:
    # Chain explicitly so the traceback reports the original failure as the
    # direct cause of this friendlier error.
    raise ImportError(
        "Polars is required for the polars serializer.\nInstall"
        " with `pip install polars`."
    ) from err
|
||
from .core import ExtendedSerializer, Method, method | ||
|
||
|
||
@method
class Parquet(Method):
    """Polars read/write pair for the Parquet format.

    Selected on a serializer by passing ``method="polars.parquet"``.
    """

    # Key under which this method is looked up by a serializer's ``method``.
    discriminator: str = "polars.parquet"
    # Reader: the module-level polars function.
    __read__ = pl.read_parquet
    # Writer: the unbound ``DataFrame`` method, so the frame is presumably
    # supplied as its first argument by the caller -- confirm in ``.core``.
    __write__ = pl.DataFrame.write_parquet
|
||
|
||
@method
class CSV(Method):
    """Polars read/write pair for comma-separated text.

    Selected on a serializer by passing ``method="polars.csv"``.
    """

    # Key under which this method is looked up by a serializer's ``method``.
    discriminator: str = "polars.csv"
    # No extra keyword arguments are passed to the writer by default.
    # NOTE(review): there is no matching ``default_read_kwargs`` here;
    # presumably the ``Method`` base supplies one -- confirm in ``.core``.
    default_write_kwargs: dict[str, Any] = {}
    # Reader: the module-level polars function.
    __read__ = pl.read_csv
    # Writer: the unbound ``DataFrame`` method.
    __write__ = pl.DataFrame.write_csv
|
||
|
||
@method
class JSON(Method):
    """Polars read/write pair for JSON documents.

    Selected on a serializer by passing ``method="polars.json"``.
    """

    # Key under which this method is looked up by a serializer's ``method``.
    discriminator: str = "polars.json"
    # Reader: the module-level polars function.
    __read__ = pl.read_json
    # Writer: the unbound ``DataFrame`` method.
    __write__ = pl.DataFrame.write_json
|
||
|
||
@method
class NDJSON(Method):
    """Polars read/write pair for newline-delimited JSON.

    Selected on a serializer by passing ``method="polars.ndjson"``.
    """

    # Key under which this method is looked up by a serializer's ``method``.
    discriminator: str = "polars.ndjson"
    # Reader: the module-level polars function.
    __read__ = pl.read_ndjson
    # Writer: the unbound ``DataFrame`` method.
    __write__ = pl.DataFrame.write_ndjson
|
||
|
||
@method
class TSV(Method):
    """Polars read/write pair for tab-separated text.

    Uses the same polars CSV reader and writer as ``CSV``, with the
    separator defaulted to a tab on both sides. Selected on a serializer by
    passing ``method="polars.tsv"``.
    """

    # Key under which this method is looked up by a serializer's ``method``.
    discriminator: str = "polars.tsv"
    # Both directions must agree on the delimiter for a round trip to work.
    default_read_kwargs: dict[str, Any] = {"separator": "\t"}
    default_write_kwargs: dict[str, Any] = {"separator": "\t"}
    # Reader and writer are the CSV ones; only the default kwargs differ.
    __read__ = pl.read_csv
    __write__ = pl.DataFrame.write_csv
|
||
|
||
@method
class Excel(Method):
    """Polars read/write pair for Excel workbooks.

    Selected on a serializer by passing ``method="polars.excel"``.
    """

    # Key under which this method is looked up by a serializer's ``method``.
    discriminator: str = "polars.excel"
    # NOTE(review): polars appears to delegate Excel I/O to optional
    # third-party engines -- confirm they are installed wherever this method
    # is used.
    # Reader: the module-level polars function.
    __read__ = pl.read_excel
    # Writer: the unbound ``DataFrame`` method.
    __write__ = pl.DataFrame.write_excel
|
||
|
||
@ExtendedSerializer.register
class PolarsSerializer(ExtendedSerializer):
    """Serializer for `polars.DataFrame` objects.

    Parameters
    ----------
    method : str
        The method to use for reading and writing. Must be a registered
        `Method`. Defaults to "polars.parquet".
    read_kwargs : dict[str, Any], optional
        Keyword arguments for the read method. Overrides default arguments for
        the method.
    write_kwargs : dict[str, Any], optional
        Keyword arguments for the write method. Overrides default arguments
        for the method.

    Examples
    --------
    Round trip with the default (Parquet) method.

    >>> import polars as pl
    >>> from prefecto.serializers.polars import PolarsSerializer
    >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    >>> s = PolarsSerializer()
    >>> s.loads(s.dumps(df)).frame_equal(df)
    True

    Using a text-based method.

    >>> blob = PolarsSerializer(method="polars.tsv").dumps(df)
    >>> blob
    b'a\\tb\\n1\\t4\\n2\\t5\\n3\\t6\\n'
    >>> df2 = PolarsSerializer(method="polars.tsv").loads(blob)
    >>> df2.frame_equal(df)
    True

    >>> blob = PolarsSerializer(method="polars.csv").dumps(df)
    >>> blob
    b'a,b\\n1,4\\n2,5\\n3,6\\n'
    >>> df2 = PolarsSerializer(method="polars.csv").loads(blob)
    >>> df2.frame_equal(df)
    True

    Using custom read and write kwargs. The same separator must be given to
    both sides for the round trip to succeed.

    >>> s = PolarsSerializer(
    ...     method="polars.csv",
    ...     write_kwargs={"separator": ";"},
    ...     read_kwargs={"separator": ";"},
    ... )
    >>> blob = s.dumps(df)
    >>> blob
    b'a;b\\n1;4\\n2;5\\n3;6\\n'
    >>> s.loads(blob).frame_equal(df)
    True
    """

    # Tag identifying this serializer type.
    type: Literal["polars"] = "polars"
    # Default method discriminator; instances may override it via ``method=``.
    # NOTE(review): this override carries no type annotation. Pydantic v2
    # rejects un-annotated overrides of inherited fields -- confirm how
    # ``ExtendedSerializer`` declares ``method`` before upgrading.
    method = "polars.parquet"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
Tests for the `serializers.polars` module. | ||
""" | ||
import polars as pl | ||
import pytest | ||
|
||
from prefecto.serializers import polars as pls | ||
|
||
|
||
@pytest.fixture
def df():
    """Provide a small two-column integer frame shared by the tests."""
    columns = {"a": [1, 2, 3], "b": [4, 5, 6]}
    return pl.DataFrame(columns)
|
||
|
||
def test_parquet(df: pl.DataFrame):
    """Round-trips a frame through the Parquet method."""
    serializer = pls.PolarsSerializer(method=pls.Parquet.discriminator)
    # The discriminator string must resolve back to the method class.
    assert serializer.get_method() == pls.Parquet

    restored = serializer.loads(serializer.dumps(df))
    assert df.frame_equal(restored)
|
||
|
||
def test_csv(df: pl.DataFrame):
    """Round-trips a frame through the CSV method."""
    serializer = pls.PolarsSerializer(method=pls.CSV.discriminator)
    # The discriminator string must resolve back to the method class.
    assert serializer.get_method() == pls.CSV

    restored = serializer.loads(serializer.dumps(df))
    assert df.frame_equal(restored)
|
||
|
||
def test_json(df: pl.DataFrame):
    """Round-trips a frame through the JSON method."""
    serializer = pls.PolarsSerializer(method=pls.JSON.discriminator)
    # The discriminator string must resolve back to the method class.
    assert serializer.get_method() == pls.JSON

    restored = serializer.loads(serializer.dumps(df))
    assert df.frame_equal(restored)
|
||
|
||
def test_ndjson(df: pl.DataFrame):
    """Round-trips a frame through the NDJSON method."""
    serializer = pls.PolarsSerializer(method=pls.NDJSON.discriminator)
    # The discriminator string must resolve back to the method class.
    assert serializer.get_method() == pls.NDJSON

    restored = serializer.loads(serializer.dumps(df))
    assert df.frame_equal(restored)
|
||
|
||
def test_tsv(df: pl.DataFrame):
    """Round-trips a frame through the TSV method."""
    serializer = pls.PolarsSerializer(method=pls.TSV.discriminator)
    # The discriminator string must resolve back to the method class.
    assert serializer.get_method() == pls.TSV

    restored = serializer.loads(serializer.dumps(df))
    assert df.frame_equal(restored)
|
||
|
||
def test_excel(df: pl.DataFrame):
    """Round-trips a frame through the Excel method."""
    serializer = pls.PolarsSerializer(method=pls.Excel.discriminator)
    # The discriminator string must resolve back to the method class.
    assert serializer.get_method() == pls.Excel

    restored = serializer.loads(serializer.dumps(df))
    assert df.frame_equal(restored)