diff --git a/docs/API.rst b/docs/API.rst index 405d5af..f1da729 100644 --- a/docs/API.rst +++ b/docs/API.rst @@ -13,6 +13,7 @@ API polars_xdt.from_local_datetime polars_xdt.is_workday polars_xdt.month_name + polars_xdt.month_delta polars_xdt.offset_by polars_xdt.to_local_datetime polars_xdt.to_julian_date diff --git a/polars_xdt/__init__.py b/polars_xdt/__init__.py index f5ea8ff..00036fa 100644 --- a/polars_xdt/__init__.py +++ b/polars_xdt/__init__.py @@ -9,6 +9,7 @@ format_localized, from_local_datetime, is_workday, + month_delta, month_name, offset_by, to_julian_date, @@ -31,6 +32,7 @@ "to_julian_date", "to_local_datetime", "workday_count", + "month_delta", "arg_previous_greater", "ewma_by_time", "__version__", diff --git a/polars_xdt/functions.py b/polars_xdt/functions.py index 3b9f2aa..f6e462d 100644 --- a/polars_xdt/functions.py +++ b/polars_xdt/functions.py @@ -733,6 +733,75 @@ def workday_count( ) +def month_delta( + start_dates: IntoExpr, + end_dates: IntoExpr, +) -> pl.Expr: + """ + Calculate the number of months between two Series. + + Parameters + ---------- + start_dates + A Series object containing the start dates. + end_dates + A Series object containing the end dates. + + Returns + ------- + polars.Expr + + Examples + -------- + >>> from datetime import date + >>> import polars as pl + >>> import polars_xdt as xdt + >>> df = pl.DataFrame( + ... { + ... "start_date": [ + ... date(2024, 3, 1), + ... date(2024, 3, 31), + ... date(2022, 2, 28), + ... date(2023, 1, 31), + ... date(2019, 12, 31), + ... ], + ... "end_date": [ + ... date(2023, 2, 28), + ... date(2023, 2, 28), + ... date(2023, 2, 28), + ... date(2023, 1, 31), + ... date(2023, 1, 1), + ... ], + ... }, + ... ) + >>> df.with_columns( + ... xdt.month_delta("start_date", "end_date").alias("month_delta") + ... ) + shape: (5, 3) + ┌────────────┬────────────┬─────────────┐ + │ start_date ┆ end_date ┆ month_delta │ + │ --- ┆ --- ┆ --- │ + │ date ┆ date ┆ i32 │ + ╞════════════╪════════════╪═════════════╡ + │ 2024-03-01 ┆ 2023-02-28 ┆ -12 │ + │ 2024-03-31 ┆ 2023-02-28 ┆ -13 │ + │ 2022-02-28 ┆ 2023-02-28 ┆ 12 │ + │ 2023-01-31 ┆ 2023-01-31 ┆ 0 │ + │ 2019-12-31 ┆ 2023-01-01 ┆ 36 │ + └────────────┴────────────┴─────────────┘ + + """ + start_dates = parse_into_expr(start_dates) + end_dates = parse_into_expr(end_dates) + + return start_dates.register_plugin( + lib=lib, + symbol="month_delta", + is_elementwise=True, + args=[end_dates], + ) + + def arg_previous_greater(expr: IntoExpr) -> pl.Expr: """ Find the row count of the previous value greater than the current one. diff --git a/polars_xdt/ranges.py b/polars_xdt/ranges.py index 1a3d852..1d820c2 100644 --- a/polars_xdt/ranges.py +++ b/polars_xdt/ranges.py @@ -25,7 +25,8 @@ def date_range( eager: Literal[False] = ..., weekend: Sequence[str] = ..., holidays: Sequence[date] | None = ..., -) -> pl.Expr: ... +) -> pl.Expr: + ... @overload @@ -40,7 +41,8 @@ def date_range( eager: Literal[True], weekend: Sequence[str] = ..., holidays: Sequence[date] | None = ..., -) -> pl.Series: ... +) -> pl.Series: + ... @overload @@ -55,7 +57,8 @@ def date_range( eager: bool = ..., weekend: Sequence[str] = ..., holidays: Sequence[date] | None = ..., -) -> pl.Series | pl.Expr: ... +) -> pl.Series | pl.Expr: + ... def date_range( # noqa: PLR0913 diff --git a/pyproject.toml b/pyproject.toml index 0284020..5fe715a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ line-length = 80 [[tool.mypy.overrides]] module = [ - "pandas" + "pandas", + "dateutil.*", ] ignore_missing_imports = true diff --git a/requirements.txt b/requirements.txt index 3802661..0b7e3bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ maturin +python-dateutil polars hypothesis numpy diff --git a/src/expressions.rs b/src/expressions.rs index 09608db..2635ac1 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -4,6 +4,7 @@ use crate::business_days::*; use crate::ewma_by_time::*; use crate::format_localized::*; use crate::is_workday::*; +use crate::month_delta::*; use crate::sub::*; use crate::timezone::*; use crate::to_julian::*; @@ -87,6 +88,13 @@ fn workday_count(inputs: &[Series], kwargs: BusinessDayKwargs) -> PolarsResult PolarsResult { + let start_dates = &inputs[0]; + let end_dates = &inputs[1]; + impl_month_delta(start_dates, end_dates) +} + #[polars_expr(output_type=Boolean)] fn is_workday(inputs: &[Series], kwargs: BusinessDayKwargs) -> PolarsResult { let dates = &inputs[0]; diff --git a/src/lib.rs b/src/lib.rs index 2f43f9b..77531f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ mod ewma_by_time; mod expressions; mod format_localized; mod is_workday; +mod month_delta; mod sub; mod timezone; mod to_julian; diff --git a/src/month_delta.rs b/src/month_delta.rs new file mode 100644 index 0000000..3284f39 --- /dev/null +++ b/src/month_delta.rs @@ -0,0 +1,133 @@ +use chrono::Datelike; +use chrono::NaiveDate; +use polars::prelude::*; + +// Copied from https://docs.pola.rs/docs/rust/dev/src/polars_time/windows/duration.rs.html#398 +// `add_month` is a private function. +fn add_month(ts: NaiveDate, n_months: i64) -> NaiveDate { + // Have to define, because it is hidden + const DAYS_PER_MONTH: [[i64; 12]; 2] = [ + //J F M A M J J A S O N D + [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31], // non-leap year + [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31], // leap year + ]; + let months = n_months; + + // Retrieve the current date and increment the values + // based on the number of months + + let mut year = ts.year(); + let mut month = ts.month() as i32; + let mut day = ts.day(); + year += (months / 12) as i32; + month += (months % 12) as i32; + + // if the month overflowed or underflowed, adjust the year + // accordingly. Because we add the modulo for the months + // the year will only adjust by one + if month > 12 { + year += 1; + month -= 12; + } else if month <= 0 { + year -= 1; + month += 12; + } + + // Adding this not to import copy pasta again + let leap_year = year % 400 == 0 || (year % 4 == 0 && year % 100 != 0); + // Normalize the day if we are past the end of the month. + let last_day_of_month = DAYS_PER_MONTH[leap_year as usize][(month - 1) as usize] as u32; + + if day > last_day_of_month { + day = last_day_of_month + } + + NaiveDate::from_ymd_opt(year, month as u32, day).unwrap() +} + +/// Calculates the difference in months between two dates. +/// +/// The difference is expressed as the number of whole months between the two dates. +/// If `right` is before `left`, the return value will be negative. +/// +/// # Arguments +/// +/// * `left`: `NaiveDate` - The start date. +/// * `right`: `NaiveDate` - The end date. +/// +/// # Returns +/// +/// * `i32` - The number of whole months between `left` and `right`. +/// +/// # Examples +/// +/// ``` +/// let start_date = NaiveDate::from_ymd(2023, 1, 1); +/// let end_date = NaiveDate::from_ymd(2023, 4, 1); +/// assert_eq!(get_m_diff(start_date, end_date), 3); +/// ``` +fn get_m_diff(mut left: NaiveDate, right: NaiveDate) -> i32 { + let mut n = 0; + while left < right { + left = add_month(left, 1); + if left <= right { + n += 1; + } + } + n +} + +/// Implements the month delta operation for Polars series containing dates. +/// +/// This function calculates the difference in months between two series of dates. +/// The operation is pairwise: it computes the month difference for each pair +/// of start and end dates in the input series. +/// +/// # Arguments +/// +/// * `start_dates`: `&Series` - A series of start dates. +/// * `end_dates`: `&Series` - A series of end dates. +/// +/// # Returns +/// +/// * `PolarsResult` - A new series containing the month differences as `i32` values. +/// +/// # Errors +/// +/// Returns an error if the input series are not of the `Date` type. +/// +/// # Examples +/// +/// ``` +/// use polars::prelude::*; +/// let date1 = NaiveDate::from_ymd(2023, 1, 1); // January 1, 2023 +/// let date2 = NaiveDate::from_ymd(2023, 3, 1); // March 1, 2023 +/// let date3 = NaiveDate::from_ymd(2023, 4, 1); // April 1, 2023 +/// let date4 = NaiveDate::from_ymd(2023, 6, 1); // June 1, 2023 +/// let start_dates = Series::new("start_dates", &[date1, date2]); +/// let end_dates = Series::new("end_dates", &[date3, date4]); +/// let month_deltas = impl_month_delta(&start_dates, &end_dates).unwrap(); +/// ``` +pub(crate) fn impl_month_delta(start_dates: &Series, end_dates: &Series) -> PolarsResult { + if (start_dates.dtype() != &DataType::Date) || (end_dates.dtype() != &DataType::Date) { + polars_bail!(InvalidOperation: "polars_xdt.month_delta only works on Date type. Please cast to Date first."); + } + let start_dates = start_dates.date()?; + let end_dates = end_dates.date()?; + + let month_diff: Int32Chunked = start_dates + .as_date_iter() + .zip(end_dates.as_date_iter()) + .map(|(s_arr, e_arr)| { + s_arr.zip(e_arr).map(|(start_date, end_date)| { + if start_date > end_date { + -get_m_diff(end_date, start_date) + } else { + get_m_diff(start_date, end_date) + } + }) + }) + .collect(); + + Ok(month_diff.into_series()) +} diff --git a/tests/test_month_delta.py b/tests/test_month_delta.py new file mode 100644 index 0000000..0647137 --- /dev/null +++ b/tests/test_month_delta.py @@ -0,0 +1,114 @@ +import polars as pl +import polars_xdt as xdt +from datetime import date +from dateutil.relativedelta import relativedelta + +from hypothesis import given, strategies as st, assume + + +def test_month_delta(): + df = pl.DataFrame( + { + "start_date": [ + date(2024, 1, 1), + date(2024, 1, 1), + date(2023, 9, 1), + date(2023, 1, 4), + date(2022, 6, 4), + date(2023, 1, 1), + date(2023, 1, 1), + date(2022, 2, 1), + date(2022, 2, 1), + date(2024, 3, 1), + date(2024, 3, 31), + date(2022, 2, 28), + date(2023, 1, 31), + date(2019, 12, 31), + date(2024, 1, 31), + date(1970, 1, 2), + ], + "end_date": [ + date(2024, 1, 4), + date(2024, 1, 31), + date(2023, 11, 1), + date(2022, 1, 4), + date(2022, 1, 4), + date(2022, 12, 31), + date(2021, 12, 31), + date(2022, 3, 1), + date(2023, 3, 1), + date(2023, 2, 28), + date(2023, 2, 28), + date(2023, 1, 31), + date(2022, 2, 28), + date(2023, 1, 1), + date(2024, 4, 30), + date(1971, 1, 1), + ], + }, + ) + + assert_month_diff = [ + 0, # 2024-01-01 to 2024-01-04 + 0, # 2024-01-01 to 2024-01-31 + 2, # 2023-09-01 to 2023-11-01 + -12, # 2023-01-04 to 2022-01-04 + -5, # 2022-06-04 to 2022-01-04 + 0, # 2023-01-01 to 2022-12-31 + -12, # 2023-01-01 to 2021-12-31 + 1, # 2022-02-01 to 2022-03-01 + 13, # 2022-02-01 to 2023-03-01 + -12, # 2024-03-01 to 2023-02-28 + -13, # 2024-03-31 to 2023-02-28 + 11, # 2022-02-28 to 2023-01-31 + -11, # 2023-01-31 to 2022-02-28 + 36, # 2019-12-31 to 2023-01-01 + 3, # 2024-01-31 to 2024-04-30 + 11, # 1970-01-02 to 1971-01-01 + ] + df = df.with_columns( + # For easier visual debugging purposes + pl.Series(name="assert_month_delta", values=assert_month_diff), + month_delta=xdt.month_delta("start_date", "end_date"), + ) + # pl.Config.set_tbl_rows(50) + # print(df) + month_diff_list = df.get_column("month_delta").to_list() + assert assert_month_diff == month_diff_list, ( + "The month difference list did not match the expected values.\n" + "Please check the function: 'month_diff.rs' for discrepancies." + ) + + +@given( + start_date=st.dates( + min_value=date(1960, 1, 1), max_value=date(2024, 12, 31) + ), + end_date=st.dates(min_value=date(1960, 1, 1), max_value=date(2024, 12, 31)), +) +def test_month_delta_hypothesis(start_date: date, end_date: date) -> None: + df = pl.DataFrame( + { + "start_date": [start_date], + "end_date": [end_date], + } + ) + result = df.select(result=xdt.month_delta("start_date", "end_date"))[ + "result" + ].item() + + expected = 0 + if start_date <= end_date: + while True: + start_date = start_date + relativedelta(months=1) + if start_date > end_date: + break + expected += 1 + else: + while True: + end_date = end_date + relativedelta(months=1) + if end_date > start_date: + break + expected -= 1 + + assert result == expected