Skip to content

Commit

Permalink
Merge pull request #72 from wbeardall/main
Browse files Browse the repository at this point in the history
Make `ewma_by_time` robust to leading `null` values in the input series.
  • Loading branch information
MarcoGorelli authored Mar 25, 2024
2 parents 14ed739 + 5f287e1 commit f9e083d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 20 deletions.
26 changes: 20 additions & 6 deletions src/ewma_by_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub(crate) fn impl_ewma_by_time_float(
times: &Int64Chunked,
values: &Float64Chunked,
half_life: i64,
time_unit: TimeUnit,
time_unit: TimeUnit
) -> Float64Chunked {
let mut out = Vec::with_capacity(times.len());
if values.is_empty() {
Expand All @@ -18,13 +18,27 @@ pub(crate) fn impl_ewma_by_time_float(
TimeUnit::Nanoseconds => half_life * 1_000,
};

let mut prev_time: i64 = times.get(0).unwrap();
let mut prev_result = values.get(0).unwrap();
out.push(Some(prev_result));
let mut skip_rows: usize = 0;
let mut prev_time: i64 = 0;
let mut prev_result: f64 = 0.;
for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() {
match (time, value) {
(Some(time), Some(value)) => {
prev_time = time;
prev_result = value;
out.push(Some(prev_result));
skip_rows = idx + 1;
break;
},
_ => {
out.push(None);
}
};
}
values
.iter()
.zip(times.iter())
.skip(1)
.skip(skip_rows)
.for_each(|(value, time)| {
match (time, value) {
(Some(time), Some(value)) => {
Expand All @@ -36,7 +50,7 @@ pub(crate) fn impl_ewma_by_time_float(
prev_time = time;
prev_result = result;
out.push(Some(result));
}
},
_ => out.push(None),
}
});
Expand Down
2 changes: 1 addition & 1 deletion src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ fn arg_previous_greater(inputs: &[Series]) -> PolarsResult<Series> {

#[derive(Deserialize)]
struct EwmTimeKwargs {
half_life: i64,
half_life: i64
}

#[polars_expr(output_type=Float64)]
Expand Down
40 changes: 27 additions & 13 deletions tests/test_ewma_by_time.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,48 @@
import polars as pl
from polars.testing import assert_frame_equal
import polars_xdt as xdt
import pytest
from datetime import datetime, timedelta

import os

def test_ewma_by_time():
os.environ["POLARS_VERBOSE"] = "1"

@pytest.mark.parametrize("start_null", [True, False])
def test_ewma_by_time(start_null):
if start_null:
values = [None]
times = [datetime(2020, 1, 1)]
else:
values = []
times = []

df = pl.DataFrame(
{
"values": [0.0, 1, 2, None, 4],
"times": [
datetime(2020, 1, 1),
datetime(2020, 1, 3),
datetime(2020, 1, 10),
datetime(2020, 1, 15),
datetime(2020, 1, 17),
"values": values + [0.0, 1., 2., None, 4.],
"times": times + [
datetime(2020, 1, 2),
datetime(2020, 1, 4),
datetime(2020, 1, 11),
datetime(2020, 1, 16),
datetime(2020, 1, 18),
],
}
)
result = df.select(
xdt.ewma_by_time("values", times="times", half_life=timedelta(days=4)),
xdt.ewma_by_time("values", times="times", half_life=timedelta(days=2)),
)

expected = pl.DataFrame(
{
"values": [
"values": values + [
0.0,
0.2928932188134524,
1.4924741174358913,
0.5,
1.8674174785275222,
None,
3.2545080948503213,
3.811504554703363,
]
}
)

assert_frame_equal(result, expected)

0 comments on commit f9e083d

Please sign in to comment.