Skip to content

Commit

Permalink
[NEW] month_diff function
Browse files Browse the repository at this point in the history
The month_diff function computes the integer representation of the month difference between two dates.
The result can be positive or negative: it is positive when the end date falls in a later month than the start date (a shift into the future) and negative when it falls in an earlier month (a shift into the past).
  • Loading branch information
akmalsoliev committed Jan 20, 2024
1 parent 8a4c1c1 commit 7062510
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 0 deletions.
14 changes: 14 additions & 0 deletions polars_xdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,20 @@ def sub(
)
return cast(XDTExpr, result)

def month_diff(
    self,
    end_dates: str | pl.Expr,
) -> XDTExpr:
    """
    Build the ``month_diff`` plugin expression against ``end_dates``.

    Parameters
    ----------
    end_dates
        The expression to compare with. A string is interpreted as a
        column name and wrapped with ``pl.col``.

    Returns
    -------
    XDTExpr
        The expression produced by registering the ``month_diff`` plugin
        symbol, with this expression as the first input and ``end_dates``
        as the second.
    """
    # Accept a bare column name for convenience.
    end_expr = pl.col(end_dates) if isinstance(end_dates, str) else end_dates
    plugin_expr = self._expr.register_plugin(
        lib=lib,
        symbol="month_diff",
        is_elementwise=True,
        args=[end_expr],
    )
    return cast(XDTExpr, plugin_expr)

def is_workday(
self,
*,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod sub;
mod timezone;
mod to_julian;
mod utc_offsets;
mod month_diff;

use pyo3::types::PyModule;
use pyo3::{pymodule, PyResult, Python};
Expand Down
32 changes: 32 additions & 0 deletions src/month_diff.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use chrono::Datelike;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

/// Computes the calendar-month difference between two `Date` columns.
///
/// `dates[0]` holds the start dates and `dates[1]` the end dates. Each output
/// value is `(end.year - start.year) * 12 + (end.month - start.month)`, so it is
/// positive when the end date lies in a later month than the start date and
/// negative when it lies in an earlier one. The day of the month is not taken
/// into account.
///
/// A row is null when either of its inputs is null.
///
/// # Errors
///
/// Returns `InvalidOperation` if either input is not of dtype `Date`.
#[polars_expr(output_type=Int32)]
pub fn month_diff(dates: &[Series]) -> PolarsResult<Series> {
    let (start, end) = (&dates[0], &dates[1]);

    // Validate the dtypes *before* unpacking: `Series::date` fails on its own
    // for non-Date input, so a check placed after it could never be reached and
    // the user would not see this more helpful message.
    if start.dtype() != &DataType::Date || end.dtype() != &DataType::Date {
        polars_bail!(InvalidOperation: "polars_xdt.month_diff only works on Date type. Please cast to Date first.");
    }

    let start_dates = start.date()?;
    let end_dates = end.date()?;

    let month_diff: Int32Chunked = start_dates
        .as_date_iter()
        .zip(end_dates.as_date_iter())
        .map(|(start, end)| {
            // `Option::zip` yields `None` as soon as either side is null.
            start.zip(end).map(|(left, right)| {
                // `month()` is a small unsigned value; cast to i32 so the
                // subtraction can go negative.
                let months = right.month() as i32 - left.month() as i32;
                let years = (right.year() - left.year()) * 12;
                years + months
            })
        })
        .collect();

    Ok(month_diff.into_series())
}
41 changes: 41 additions & 0 deletions tests/test_month_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import polars as pl
import polars_xdt as xdt
from datetime import date

def test_month_diff():
    """Check ``month_diff`` on forward, zero and backward date shifts."""
    start_dates = [
        date(2023, 4, 3),
        date(2023, 9, 1),
        date(2024, 1, 4),
        date(2023, 1, 4),
    ]
    end_dates = [
        date(2024, 5, 3),
        date(2023, 11, 1),
        date(2024, 1, 4),
        date(2022, 1, 4),
    ]
    frame = pl.DataFrame({"start_date": start_dates, "end_date": end_dates})

    # Month difference of each (start, end) pair, in row order.
    expected = [13, 2, 0, -12]

    with_diff = frame.with_columns(
        xdt.col("start_date").xdt.month_diff("end_date").alias("month_diff")
    )
    actual = with_diff.get_column("month_diff").to_list()

    assert expected == actual, (
        "The month difference list did not match the expected values.\n"
        "Please check the function: 'month_diff.rs' for discrepancies."
    )

0 comments on commit 7062510

Please sign in to comment.