diff --git a/src/month_delta.rs b/src/month_delta.rs index 02d33f6..c6be7f2 100644 --- a/src/month_delta.rs +++ b/src/month_delta.rs @@ -66,17 +66,22 @@ fn add_month(ts: NaiveDate, n_months: i64) -> NaiveDate { /// let end_date = NaiveDate::from_ymd(2023, 4, 1); /// assert_eq!(get_m_diff(start_date, end_date), 3); /// ``` -fn get_m_diff(mut left: NaiveDate, right: NaiveDate) -> i32 { +fn get_m_diff(left: NaiveDate, right: NaiveDate) -> i32 { let mut n = 0; - if left.year() + 2 < right.year() { - n = (right.year() - left.year() - 1) * 12; - left = add_month(left, n.into()); - } - while left < right { - left = add_month(left, 1); - if left <= right { + if right >= left { + if right.year() + 1 > left.year() { + n = (right.year() - left.year() - 1) * 12; + } + while add_month(left, (n + 1).into()) <= right { n += 1; } + } else { + if left.year() + 1 > right.year() { + n = -(left.year() - right.year() - 1) * 12; + } + while add_month(left, (n - 1).into()) >= right { + n -= 1; + } } n } @@ -123,13 +128,9 @@ pub(crate) fn impl_month_delta(start_dates: &Series, end_dates: &Series) -> Pola .as_date_iter() .zip(end_dates.as_date_iter()) .map(|(s_arr, e_arr)| { - s_arr.zip(e_arr).map(|(start_date, end_date)| { - if start_date > end_date { - -get_m_diff(end_date, start_date) - } else { - get_m_diff(start_date, end_date) - } - }) + s_arr + .zip(e_arr) + .map(|(start_date, end_date)| get_m_diff(start_date, end_date)) }) .collect(); diff --git a/tests/test_month_delta.py b/tests/test_month_delta.py index 0647137..0de669c 100644 --- a/tests/test_month_delta.py +++ b/tests/test_month_delta.py @@ -1,91 +1,29 @@ -import polars as pl -import polars_xdt as xdt from datetime import date -from dateutil.relativedelta import relativedelta - -from hypothesis import given, strategies as st, assume +import polars as pl +from dateutil.relativedelta import relativedelta +from hypothesis import example, given, settings +from hypothesis import strategies as st -def test_month_delta(): - df = pl.DataFrame( - { - "start_date": [ - date(2024, 1, 1), - date(2024, 1, 1), - date(2023, 9, 1), - date(2023, 1, 4), - date(2022, 6, 4), - date(2023, 1, 1), - date(2023, 1, 1), - date(2022, 2, 1), - date(2022, 2, 1), - date(2024, 3, 1), - date(2024, 3, 31), - date(2022, 2, 28), - date(2023, 1, 31), - date(2019, 12, 31), - date(2024, 1, 31), - date(1970, 1, 2), - ], - "end_date": [ - date(2024, 1, 4), - date(2024, 1, 31), - date(2023, 11, 1), - date(2022, 1, 4), - date(2022, 1, 4), - date(2022, 12, 31), - date(2021, 12, 31), - date(2022, 3, 1), - date(2023, 3, 1), - date(2023, 2, 28), - date(2023, 2, 28), - date(2023, 1, 31), - date(2022, 2, 28), - date(2023, 1, 1), - date(2024, 4, 30), - date(1971, 1, 1), - ], - }, - ) - - assert_month_diff = [ - 0, # 2024-01-01 to 2024-01-04 - 0, # 2024-01-01 to 2024-01-31 - 2, # 2023-09-01 to 2023-11-01 - -12, # 2023-01-04 to 2022-01-04 - -5, # 2022-06-04 to 2022-01-04 - 0, # 2023-01-01 to 2022-12-31 - -12, # 2023-01-01 to 2021-12-31 - 1, # 2022-02-01 to 2022-03-01 - 13, # 2022-02-01 to 2023-03-01 - -12, # 2024-03-01 to 2023-02-28 - -13, # 2024-03-31 to 2023-02-28 - 11, # 2022-02-28 to 2023-01-31 - -11, # 2023-01-31 to 2022-02-28 - 36, # 2019-12-31 to 2023-01-01 - 3, # 2024-01-31 to 2024-04-30 - 11, # 1970-01-02 to 1971-01-01 - ] - df = df.with_columns( - # For easier visual debugging purposes - pl.Series(name="assert_month_delta", values=assert_month_diff), - month_delta=xdt.month_delta("start_date", "end_date"), - ) - # pl.Config.set_tbl_rows(50) - # print(df) - month_diff_list = df.get_column("month_delta").to_list() - assert assert_month_diff == month_diff_list, ( - "The month difference list did not match the expected values.\n" - "Please check the function: 'month_diff.rs' for discrepancies." - ) +import polars_xdt as xdt @given( start_date=st.dates( - min_value=date(1960, 1, 1), max_value=date(2024, 12, 31) + min_value=date(1924, 1, 1), max_value=date(2024, 12, 31) ), end_date=st.dates(min_value=date(1960, 1, 1), max_value=date(2024, 12, 31)), ) +@example(start_date=date(2022, 2, 28), end_date=date(2024, 2, 29)) # Leap year +@example(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) # Same month +@example(start_date=date(1973, 1, 1), end_date=date(1973, 1, 1)) # Same date +@example(start_date=date(2019, 12, 31), end_date=date(2020, 1, 1)) # Border +@example(start_date=date(2018, 12, 1), end_date=date(2020, 1, 1)) # End of year +@example(start_date=date(2022, 12, 1), end_date=date(2020, 1, 1)) # Negative +@example( + start_date=date(2000, 3, 29), end_date=date(2003, 1, 28) +) # Failed test +@settings(max_examples=1000) def test_month_delta_hypothesis(start_date: date, end_date: date) -> None: df = pl.DataFrame( { @@ -99,16 +37,9 @@ def test_month_delta_hypothesis(start_date: date, end_date: date) -> None: expected = 0 if start_date <= end_date: - while True: - start_date = start_date + relativedelta(months=1) - if start_date > end_date: - break + while start_date + relativedelta(months=expected + 1) <= end_date: expected += 1 else: - while True: - end_date = end_date + relativedelta(months=1) - if end_date > start_date: - break + while start_date + relativedelta(months=expected - 1) >= end_date: expected -= 1 - assert result == expected