Skip to content

Commit

Permalink
use public plugins functions [skip-ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 5, 2024
1 parent 111b51f commit 4309674
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 42 deletions.
78 changes: 37 additions & 41 deletions polars_xdt/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Literal, Sequence

import polars as pl
from polars.utils.udfs import _get_shared_lib_location
from polars.plugins import register_plugin

from polars_xdt.utils import parse_into_expr

Expand All @@ -21,8 +21,6 @@
RollStrategy: TypeAlias = Literal["raise", "forward", "backward"]


lib = _get_shared_lib_location(__file__)

mapping = {"Mon": 1, "Tue": 2, "Wed": 3, "Thu": 4, "Fri": 5, "Sat": 6, "Sun": 7}
reverse_mapping = {value: key for key, value in mapping.items()}

Expand Down Expand Up @@ -148,7 +146,6 @@ def offset_by(
└────────────┴──────┴──────────────┘
"""
expr = parse_into_expr(expr)
if (
isinstance(by, str)
and (match := re.search(r"(\d+bd)", by)) is not None
Expand All @@ -174,11 +171,12 @@ def offset_by(
)
weekmask = get_weekmask(weekend)

result = expr.register_plugin(
lib=lib,
result = register_plugin(
expr,
n,
plugin_location=__file__,
symbol="advance_n_days",
is_elementwise=True,
args=[n],
kwargs={
"holidays": holidays_int,
"weekmask": weekmask,
Expand Down Expand Up @@ -240,19 +238,18 @@ def is_workday(
└────────────┴────────────┘
"""
expr = parse_into_expr(expr)
weekmask = get_weekmask(weekend)
if not holidays:
holidays_int = []
else:
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
return expr.register_plugin(
lib=lib,
return register_plugin(
expr,
plugin_location=__file__,
symbol="is_workday",
is_elementwise=True,
args=[],
kwargs={
"weekmask": weekmask,
"holidays": holidays_int,
Expand Down Expand Up @@ -327,13 +324,14 @@ def from_local_datetime(
└─────────────────────┴──────────────────┴─────────────────────────┘
"""
expr = parse_into_expr(expr)
from_tz = parse_into_expr(from_tz, str_as_lit=True)
return expr.register_plugin(
lib=lib,
if isinstance(from_tz, str):
from_tz = pl.lit(from_tz)
return register_plugin(
expr,
from_tz,
plugin_location=__file__,
symbol="from_local_datetime",
is_elementwise=True,
args=[from_tz],
kwargs={
"to_tz": to_tz,
"ambiguous": ambiguous,
Expand Down Expand Up @@ -394,13 +392,14 @@ def to_local_datetime(
└─────────────────────────┴──────────────────┴─────────────────────┘
"""
expr = parse_into_expr(expr)
time_zone = parse_into_expr(time_zone, str_as_lit=True)
return expr.register_plugin(
lib=lib,
if isinstance(time_zone, str):
time_zone = pl.lit(time_zone)
return register_plugin(
expr,
time_zone,
plugin_location=__file__,
symbol="to_local_datetime",
is_elementwise=True,
args=[time_zone],
)


Expand Down Expand Up @@ -453,12 +452,11 @@ def format_localized(
└─────────────────────┴──────────────────────────┘
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
return register_plugin(
expr,
plugin_location=__file__,
symbol="format_localized",
is_elementwise=True,
args=[],
kwargs={"format": format, "locale": locale},
)

Expand Down Expand Up @@ -492,12 +490,11 @@ def to_julian_date(expr: str | pl.Expr) -> pl.Expr:
└─────────────────────┴────────────────────┘
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
return register_plugin(
expr,
plugin_location=__file__,
symbol="to_julian_date",
is_elementwise=True,
args=[],
)


Expand Down Expand Up @@ -712,20 +709,19 @@ def workday_count(
└────────────┴────────────┴─────────────────┘
"""
start_dates = parse_into_expr(start_dates)
end_dates = parse_into_expr(end_dates)
weekmask = get_weekmask(weekend)
if not holidays:
holidays_int = []
else:
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
return start_dates.register_plugin(
lib=lib,
return register_plugin(
start_dates,
end_dates,
plugin_location=__file__,
symbol="workday_count",
is_elementwise=True,
args=[end_dates],
kwargs={
"weekmask": weekmask,
"holidays": holidays_int,
Expand Down Expand Up @@ -821,9 +817,9 @@ def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
└────────────┴───────┴───────┴────────┘
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
return register_plugin(
expr,
plugin_location=__file__,
symbol="arg_previous_greater",
is_elementwise=False,
)
Expand Down Expand Up @@ -915,14 +911,14 @@ def ewma_by_time(
└────────┴────────────┴──────────┘
"""
times = parse_into_expr(times)
half_life_us = (
int(half_life.total_seconds()) * 1_000_000 + half_life.microseconds
)
return times.register_plugin(
lib=lib,
return register_plugin(
times,
values,
plugin_location=__file__,
symbol="ewma_by_time",
is_elementwise=False,
args=[values],
kwargs={"half_life": half_life_us, "adjust": adjust},
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=1.0,<2.0", "polars>=0.20.6"]
requires = ["maturin>=1.0,<2.0", "polars>=0.20.14"]
build-backend = "maturin"

[project]
Expand Down

0 comments on commit 4309674

Please sign in to comment.