Skip to content

Commit

Permalink
Merge pull request #70 from akmalsoliev/fix_month_delta
Browse files Browse the repository at this point in the history
FIX: `month_delta` with optimisations
  • Loading branch information
MarcoGorelli authored Mar 21, 2024
2 parents 2929f3a + 62cc30d commit 6e86dc5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 102 deletions.
31 changes: 16 additions & 15 deletions src/month_delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,22 @@ fn add_month(ts: NaiveDate, n_months: i64) -> NaiveDate {
/// let end_date = NaiveDate::from_ymd(2023, 4, 1);
/// assert_eq!(get_m_diff(start_date, end_date), 3);
/// ```
fn get_m_diff(mut left: NaiveDate, right: NaiveDate) -> i32 {
fn get_m_diff(left: NaiveDate, right: NaiveDate) -> i32 {
let mut n = 0;
if left.year() + 2 < right.year() {
n = (right.year() - left.year() - 1) * 12;
left = add_month(left, n.into());
}
while left < right {
left = add_month(left, 1);
if left <= right {
if right >= left {
if right.year() + 1 > left.year() {
n = (right.year() - left.year() - 1) * 12;
}
while add_month(left, (n + 1).into()) <= right {
n += 1;
}
} else {
if left.year() + 1 > right.year() {
n = -(left.year() - right.year() - 1) * 12;
}
while add_month(left, (n - 1).into()) >= right {
n -= 1;
}
}
n
}
Expand Down Expand Up @@ -123,13 +128,9 @@ pub(crate) fn impl_month_delta(start_dates: &Series, end_dates: &Series) -> Pola
.as_date_iter()
.zip(end_dates.as_date_iter())
.map(|(s_arr, e_arr)| {
s_arr.zip(e_arr).map(|(start_date, end_date)| {
if start_date > end_date {
-get_m_diff(end_date, start_date)
} else {
get_m_diff(start_date, end_date)
}
})
s_arr
.zip(e_arr)
.map(|(start_date, end_date)| get_m_diff(start_date, end_date))
})
.collect();

Expand Down
105 changes: 18 additions & 87 deletions tests/test_month_delta.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,29 @@
import polars as pl
import polars_xdt as xdt
from datetime import date
from dateutil.relativedelta import relativedelta

from hypothesis import given, strategies as st, assume

import polars as pl
from dateutil.relativedelta import relativedelta
from hypothesis import example, given, settings
from hypothesis import strategies as st

def test_month_delta():
df = pl.DataFrame(
{
"start_date": [
date(2024, 1, 1),
date(2024, 1, 1),
date(2023, 9, 1),
date(2023, 1, 4),
date(2022, 6, 4),
date(2023, 1, 1),
date(2023, 1, 1),
date(2022, 2, 1),
date(2022, 2, 1),
date(2024, 3, 1),
date(2024, 3, 31),
date(2022, 2, 28),
date(2023, 1, 31),
date(2019, 12, 31),
date(2024, 1, 31),
date(1970, 1, 2),
],
"end_date": [
date(2024, 1, 4),
date(2024, 1, 31),
date(2023, 11, 1),
date(2022, 1, 4),
date(2022, 1, 4),
date(2022, 12, 31),
date(2021, 12, 31),
date(2022, 3, 1),
date(2023, 3, 1),
date(2023, 2, 28),
date(2023, 2, 28),
date(2023, 1, 31),
date(2022, 2, 28),
date(2023, 1, 1),
date(2024, 4, 30),
date(1971, 1, 1),
],
},
)

assert_month_diff = [
0, # 2024-01-01 to 2024-01-04
0, # 2024-01-01 to 2024-01-31
2, # 2023-09-01 to 2023-11-01
-12, # 2023-01-04 to 2022-01-04
-5, # 2022-06-04 to 2022-01-04
0, # 2023-01-01 to 2022-12-31
-12, # 2023-01-01 to 2021-12-31
1, # 2022-02-01 to 2022-03-01
13, # 2022-02-01 to 2023-03-01
-12, # 2024-03-01 to 2023-02-28
-13, # 2024-03-31 to 2023-02-28
11, # 2022-02-28 to 2023-01-31
-11, # 2023-01-31 to 2022-02-28
36, # 2019-12-31 to 2023-01-01
3, # 2024-01-31 to 2024-04-30
11, # 1970-01-02 to 1971-01-01
]
df = df.with_columns(
# For easier visual debugging purposes
pl.Series(name="assert_month_delta", values=assert_month_diff),
month_delta=xdt.month_delta("start_date", "end_date"),
)
# pl.Config.set_tbl_rows(50)
# print(df)
month_diff_list = df.get_column("month_delta").to_list()
assert assert_month_diff == month_diff_list, (
"The month difference list did not match the expected values.\n"
"Please check the function: 'month_diff.rs' for discrepancies."
)
import polars_xdt as xdt


@given(
start_date=st.dates(
min_value=date(1960, 1, 1), max_value=date(2024, 12, 31)
min_value=date(1924, 1, 1), max_value=date(2024, 12, 31)
),
end_date=st.dates(min_value=date(1960, 1, 1), max_value=date(2024, 12, 31)),
)
@example(start_date=date(2022, 2, 28), end_date=date(2024, 2, 29)) # Leap year
@example(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) # Same month
@example(start_date=date(1973, 1, 1), end_date=date(1973, 1, 1)) # Same date
@example(start_date=date(2019, 12, 31), end_date=date(2020, 1, 1)) # Border
@example(start_date=date(2018, 12, 1), end_date=date(2020, 1, 1)) # End of year
@example(start_date=date(2022, 12, 1), end_date=date(2020, 1, 1)) # Negative
@example(
start_date=date(2000, 3, 29), end_date=date(2003, 1, 28)
) # Failed test
@settings(max_examples=1000)
def test_month_delta_hypothesis(start_date: date, end_date: date) -> None:
df = pl.DataFrame(
{
Expand All @@ -99,16 +37,9 @@ def test_month_delta_hypothesis(start_date: date, end_date: date) -> None:

expected = 0
if start_date <= end_date:
while True:
start_date = start_date + relativedelta(months=1)
if start_date > end_date:
break
while start_date + relativedelta(months=expected + 1) <= end_date:
expected += 1
else:
while True:
end_date = end_date + relativedelta(months=1)
if end_date > start_date:
break
while start_date + relativedelta(months=expected - 1) >= end_date:
expected -= 1

assert result == expected

0 comments on commit 6e86dc5

Please sign in to comment.