Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arg_prev_greater #61

Merged
merged 4 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_xdt"
version = "0.12.11"
version = "0.13.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 2 additions & 0 deletions polars_xdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import polars_xdt.namespace # noqa: F401
from polars_xdt.functions import (
arg_previous_greater,
ceil,
day_name,
format_localized,
Expand Down Expand Up @@ -29,5 +30,6 @@
"to_julian_date",
"to_local_datetime",
"workday_count",
"arg_previous_greater",
"__version__",
]
96 changes: 96 additions & 0 deletions polars_xdt/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,99 @@ def workday_count(
"holidays": holidays_int,
},
)


def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
"""
Find the row count of the previous value greater than the current one.

Parameters
----------
expr
Expression.

Returns
-------
Expr
UInt64 or UInt32 type, depending on the platform.

Examples
--------
>>> import polars as pl
>>> import polars_xdt as xdt
>>> df = pl.DataFrame({"value": [1, 9, 6, 7, 3]})
>>> df.with_columns(result=xdt.arg_previous_greater("value"))
shape: (5, 2)
┌───────┬────────┐
│ value ┆ result │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞═══════╪════════╡
│ 1 ┆ null │
│ 9 ┆ 1 │
│ 6 ┆ 1 │
│ 7 ┆ 1 │
│ 3 ┆ 3 │
└───────┴────────┘

This can be useful when working with time series. For example,
if you a dataset like this:

>>> df = pl.DataFrame(
... {
... "date": [
... "2024-02-01",
... "2024-02-02",
... "2024-02-03",
... "2024-02-04",
... "2024-02-05",
... "2024-02-06",
... "2024-02-07",
... "2024-02-08",
... "2024-02-09",
... "2024-02-10",
... ],
... "group": ["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"],
... "value": [1, 9, None, 7, 3, 2, 4, 5, 1, 9],
... }
... )
>>> df = df.with_columns(pl.col("date").str.to_date())

and want find out, for each day and each item, how many days it's
been since `'value'` was higher than it currently is, you could do

>>> df.with_columns(
... result=(
... (
... pl.col("date")
... - pl.col("date")
... .gather(xdt.arg_previous_greater("value"))
... .over("group")
... ).dt.total_days()
... ),
... )
shape: (10, 4)
┌────────────┬───────┬───────┬────────┐
│ date ┆ group ┆ value ┆ result │
│ --- ┆ --- ┆ --- ┆ --- │
│ date ┆ str ┆ i64 ┆ i64 │
╞════════════╪═══════╪═══════╪════════╡
│ 2024-02-01 ┆ A ┆ 1 ┆ null │
│ 2024-02-02 ┆ A ┆ 9 ┆ 0 │
│ 2024-02-03 ┆ A ┆ null ┆ null │
│ 2024-02-04 ┆ A ┆ 7 ┆ 2 │
│ 2024-02-05 ┆ A ┆ 3 ┆ 1 │
│ 2024-02-06 ┆ B ┆ 2 ┆ null │
│ 2024-02-07 ┆ B ┆ 4 ┆ 0 │
│ 2024-02-08 ┆ B ┆ 5 ┆ 0 │
│ 2024-02-09 ┆ B ┆ 1 ┆ 1 │
│ 2024-02-10 ┆ B ┆ 9 ┆ 0 │
└────────────┴───────┴───────┴────────┘

"""
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
symbol="arg_previous_greater",
is_elementwise=False,
)
38 changes: 38 additions & 0 deletions src/arg_previous_greater.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use polars::prelude::*;

pub(crate) fn impl_arg_previous_greater<T>(ca: &ChunkedArray<T>) -> IdxCa
where
T: PolarsNumericType,
{
let mut idx: Vec<Option<i32>> = Vec::with_capacity(ca.len());
let out: IdxCa = ca
.into_iter()
.enumerate()
.map(|(i, opt_val)| {
if opt_val.is_none() {
idx.push(None);
return None;
}
let i_curr = i;
let mut i = Some((i as i32) - 1); // look at previous element
while i >= Some(0) && ca.get(i.unwrap() as usize).is_none() {
// find previous non-null value
i = Some(i.unwrap() - 1)
}
if i < Some(0) {
idx.push(None);
return None;
}
while i.is_some() && opt_val >= ca.get(i.unwrap() as usize) {
i = idx[i.unwrap() as usize];
}
if i.is_none() {
idx.push(None);
return Some(i_curr as IdxSize);
}
idx.push(i);
i.map(|x| x as IdxSize)
})
.collect();
out
}
20 changes: 20 additions & 0 deletions src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(clippy::unit_arg, clippy::unused_unit)]
use crate::arg_previous_greater::*;
use crate::business_days::*;
use crate::format_localized::*;
use crate::is_workday::*;
Expand Down Expand Up @@ -146,3 +147,22 @@ fn dst_offset(inputs: &[Series]) -> PolarsResult<Series> {
_ => polars_bail!(InvalidOperation: "base_utc_offset only works on Datetime type."),
}
}

fn list_idx_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
let field = Field::new(input_fields[0].name(), DataType::List(Box::new(IDX_DTYPE)));
Ok(field.clone())
}

#[polars_expr(output_type_func=list_idx_dtype)]
fn arg_previous_greater(inputs: &[Series]) -> PolarsResult<Series> {
let ser = &inputs[0];
match ser.dtype() {
DataType::Int64 => Ok(impl_arg_previous_greater(ser.i64().unwrap()).into_series()),
DataType::Int32 => Ok(impl_arg_previous_greater(ser.i32().unwrap()).into_series()),
DataType::UInt64 => Ok(impl_arg_previous_greater(ser.u64().unwrap()).into_series()),
DataType::UInt32 => Ok(impl_arg_previous_greater(ser.u32().unwrap()).into_series()),
DataType::Float64 => Ok(impl_arg_previous_greater(ser.f64().unwrap()).into_series()),
DataType::Float32 => Ok(impl_arg_previous_greater(ser.f32().unwrap()).into_series()),
dt => polars_bail!(ComputeError:"Expected numeric data type, got: {}", dt),
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod arg_previous_greater;
mod business_days;
mod expressions;
mod format_localized;
Expand Down
Loading