Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 3, 2024
1 parent 29ebc9e commit 0b61bc7
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 28 deletions.
33 changes: 22 additions & 11 deletions polars_xdt/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ def ewma_by_time(
halflife: timedelta,
adjust: bool = True,
) -> pl.Expr:
"""
Calculated time-based exponentially weighted moving average.
r"""
Calculate time-based exponentially weighted moving average.
Given observations :math:`x_0, x_1, \ldots, x_n` at times
:math:`t_0, t_1, \ldots, t_n`, the **unadjusted** EWMA is calculated as
Expand All @@ -846,11 +846,11 @@ def ewma_by_time(
y_0 &= x_0
\\alpha_i &= exp(-\\lambda(t_i - t_{i-1}))
\alpha_i &= \exp(-\lambda(t_i - t_{i-1}))
y_i &= \alpha_i x_i + (1 - \alpha_i) y_{i-1}; \quad i > 0
y_i &= \\alpha_i x_i + (1 - \\alpha_i) y_{i-1}; \\quad i > 0
where :math:`\\lambda` equals :math:`\\ln(2) / \\text{halflife}`.
where :math:`\lambda` equals :math:`\ln(2) / \text{halflife}`.
Parameters
----------
Expand All @@ -863,23 +863,33 @@ def ewma_by_time(
adjust
Whether to adjust the result to account for the bias towards the
initial value. Defaults to True.
Returns
-------
pl.Expr
Float64
Examples
--------
>>> import polars as pl
>>> import polars_xdt as xdt
>>> from datetime import date, timedelta
>>> df = pl.DataFrame(
... {
... 'values': [0, 1, 2, None, 4],
... 'times': [date(2020, 1, 1), date(2020, 1, 3), date(2020, 1, 10), date(2020, 1, 15), date(2020, 1, 17)]})
... "values": [0, 1, 2, None, 4],
... "times": [
... date(2020, 1, 1),
... date(2020, 1, 3),
... date(2020, 1, 10),
... date(2020, 1, 15),
... date(2020, 1, 17),
... ],
... }
... )
>>> df.with_columns(
... ewma = xdt.ewma_by_time("values", times="times", halflife=timedelta(days=4)),
... ewma=xdt.ewma_by_time(
... "values", times="times", halflife=timedelta(days=4)
... ),
... )
shape: (5, 3)
┌────────┬────────────┬──────────┐
Expand All @@ -893,6 +903,7 @@ def ewma_by_time(
│ null ┆ 2020-01-15 ┆ null │
│ 4 ┆ 2020-01-17 ┆ 3.233686 │
└────────┴────────────┴──────────┘
"""
times = parse_into_expr(times)
halflife_us = (
Expand Down
14 changes: 6 additions & 8 deletions src/ewma_by_time.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use polars::prelude::*;
use polars_arrow::array::PrimitiveArray;
use pyo3_polars::export::polars_core::export::num::{Float, Pow};
use pyo3_polars::export::polars_core::export::num::Pow;

pub(crate) fn impl_ewma_by_time_float(
times: &Int64Chunked,
values: &Float64Chunked,
halflife: i64,
adjust: bool,
time_unit: TimeUnit,
) -> Float64Chunked
{
) -> Float64Chunked {
let mut out = Vec::with_capacity(times.len());
if values.len() == 0 {
if values.is_empty() {
return Float64Chunked::full_null("", times.len());
}

Expand All @@ -36,7 +35,7 @@ pub(crate) fn impl_ewma_by_time_float(
let result: f64;
if adjust {
alpha *= Pow::pow(0.5, delta_time as f64 / halflife as f64);
result = (value + alpha * prev_result) / ((1. + alpha));
result = (value + alpha * prev_result) / (1. + alpha);
alpha += 1.;
} else {
// equivalent to:
Expand All @@ -56,7 +55,6 @@ pub(crate) fn impl_ewma_by_time_float(
Float64Chunked::from(arr)
}


pub(crate) fn impl_ewma_by_time(
times: &Int64Chunked,
values: &Series,
Expand All @@ -72,13 +70,13 @@ pub(crate) fn impl_ewma_by_time(
DataType::Int64 | DataType::Int32 => {
let values = values.cast(&DataType::Float64).unwrap();
let values = values.f64().unwrap();
impl_ewma_by_time_float(times, &values, halflife, adjust, time_unit).into_series()
impl_ewma_by_time_float(times, values, halflife, adjust, time_unit).into_series()
}
DataType::Float32 => {
// todo: preserve Float32 in this case
let values = values.cast(&DataType::Float64).unwrap();
let values = values.f64().unwrap();
impl_ewma_by_time_float(times, &values, halflife, adjust, time_unit).into_series()
impl_ewma_by_time_float(times, values, halflife, adjust, time_unit).into_series()
}
dt => panic!("Expected values to be signed numeric, got {:?}", dt),
}
Expand Down
20 changes: 17 additions & 3 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,27 @@ fn ewma_by_time(inputs: &[Series], kwargs: EwmTimeKwargs) -> PolarsResult<Series
match &inputs[0].dtype() {
DataType::Datetime(_, _) => {
let time = &inputs[0].datetime().unwrap();
Ok(impl_ewma_by_time(&time.0, values, kwargs.halflife, kwargs.adjust, time.time_unit()).into_series())
Ok(impl_ewma_by_time(
&time.0,
values,
kwargs.halflife,
kwargs.adjust,
time.time_unit(),
)
.into_series())
}
DataType::Date => {
let binding = &inputs[0].cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?;
let time = binding.datetime().unwrap();
Ok(impl_ewma_by_time(&time.0, values, kwargs.halflife, kwargs.adjust, time.time_unit()).into_series())
Ok(impl_ewma_by_time(
&time.0,
values,
kwargs.halflife,
kwargs.adjust,
time.time_unit(),
)
.into_series())
}
_ => polars_bail!(InvalidOperation: "First argument should be a date or datetime type.")
_ => polars_bail!(InvalidOperation: "First argument should be a date or datetime type."),
}
}
30 changes: 24 additions & 6 deletions tests/test_ewma_by_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,34 @@
import polars_xdt as xdt
from datetime import datetime, timedelta


def test_ewma_by_time():
df = pl.DataFrame(
{
'values': [0., 1, 2, None, 4],
'times': [datetime(2020, 1, 1), datetime(2020, 1, 3), datetime(2020, 1, 10), datetime(2020, 1, 15), datetime(2020, 1, 17)]})
"values": [0.0, 1, 2, None, 4],
"times": [
datetime(2020, 1, 1),
datetime(2020, 1, 3),
datetime(2020, 1, 10),
datetime(2020, 1, 15),
datetime(2020, 1, 17),
],
}
)
result = df.select(
ewma = xdt.ewma_by_time("values", times="times", halflife=timedelta(days=4)),
ewma=xdt.ewma_by_time(
"values", times="times", halflife=timedelta(days=4)
),
)
expected = pl.DataFrame(
{
'ewma': [0.0, 0.585786437626905, 1.52388878049859, None, 3.2336858398518338]
})
assert_frame_equal(result, expected)
"ewma": [
0.0,
0.585786437626905,
1.52388878049859,
None,
3.2336858398518338,
]
}
)
assert_frame_equal(result, expected)

0 comments on commit 0b61bc7

Please sign in to comment.