-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Towards #1009: Seasonal decomposition: Part 1/2: backend, unit test plots (#1016)
Towards #1009 Seasonal decomposition - Add real crypto data: 10000 five-min epochs for Binance BTC/USDT & ETH/USDT, in `./pdr_backend/lake/test/merged_ohlcv_df_BTC-ETH_2024-02-01_to_2024-03-08.csv` - Add `aimodel/seasonal.py`, which has: - `pdr_seasonal_decompose()` - Wraps statsmodels' seasonal_decompose() with predictoor-specific inputs - `SeasonalPlotdata` - Simple class to manage many inputs going to plot_seasonal. - `plot_seasonal(d: SeasonalPlotdata)` - Plot seasonal decomposition of the feed, via 4 figures - Add `aimodel/test/test_seasonal.py`, which, via the methods added, pulls in real BTC data, decomposes it, and plots it
- Loading branch information
Showing
4 changed files
with
10,590 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
from datetime import datetime | ||
from typing import List | ||
|
||
from enforce_typing import enforce_types | ||
import plotly.graph_objects as go | ||
from plotly.subplots import make_subplots | ||
from statsmodels.tsa.seasonal import DecomposeResult, seasonal_decompose | ||
|
||
from pdr_backend.cli.arg_timeframe import ArgTimeframe | ||
from pdr_backend.util.time_types import UnixTimeMs | ||
|
||
|
||
@enforce_types
def pdr_seasonal_decompose(timeframe: ArgTimeframe, y_values) -> DecomposeResult:
    """
    @description
      Wraps statsmodels' seasonal_decompose() with predictoor-specific inputs.
      The seasonal period passed to statsmodels is one day's worth of epochs.

    @arguments
      timeframe -- time interval between x-values. ArgTimeframe('5m')
      y_values -- array-like -- [sample_i] : y_value_float

    @return
      result -- DecomposeResult -- the value returned by seasonal_decompose()

    @raises
      AssertionError -- if y_values is not 1-dimensional
      ValueError -- if timeframe is neither '5m' nor '1h'
    """
    # Local import: numpy is not among this module's top-level imports.
    import numpy as np  # pylint: disable=import-outside-toplevel

    # preconditions
    # np.ndim() accepts any array-like (list, tuple, ndarray, pandas array),
    # whereas reading y_values.shape only works on array objects.
    assert np.ndim(y_values) == 1, "y_values must be 1d array"

    # The period is given explicitly; see:
    # https://stackoverflow.com/questions/60017052/decompose-for-time-series-valueerror-you-must-specify-a-period-or-x-must-be
    epochs_per_day = {
        "5m": 288,  # 288 5min epochs per day
        "1h": 24,  # 24 1h epochs per day
    }
    s = timeframe.timeframe_str
    if s not in epochs_per_day:
        raise ValueError(
            f"Unsupported timeframe '{s}'; expected one of {sorted(epochs_per_day)}"
        )

    # y_values is passed through untouched, so statsmodels sees exactly what
    # the caller supplied.
    return seasonal_decompose(y_values, period=epochs_per_day[s])
|
||
|
||
class SeasonalPlotdata:
    """Bundles the inputs of plot_seasonal() and derives its x-axis values."""

    @enforce_types
    def __init__(
        self,
        start_time: UnixTimeMs,
        timeframe: ArgTimeframe,
        decompose_result: DecomposeResult,
    ):
        """
        @arguments
          start_time -- x-value #0
          timeframe -- time interval between x-values. ArgTimeframe('5m')
          decompose_result -- has attributes (all array-like)
            observed - The data series that has been decomposed = y_values
            seasonal - The seasonal component of the data series.
            trend - The trend component of the data series.
            resid - The residual component of the data series.
            (weights - The weights used to reduce outlier influence.)
        """
        self.start_time = start_time
        self.timeframe = timeframe
        self.decompose_result = decompose_result

    @property
    def dr(self) -> DecomposeResult:
        """@return -- decompose_result, under a shorter name"""
        return self.decompose_result

    @property
    def N(self) -> int:
        """@return -- number of samples, i.e. length of the observed series"""
        n_samples = self.decompose_result.observed.shape[0]
        return n_samples

    @property
    def x_ut(self) -> List[UnixTimeMs]:
        """@return -- x-values in unix time (ms), one per sample"""
        # milliseconds between consecutive x-values, per supported timeframe
        ms_per_epoch = {"5m": 300000, "1h": 3600000}

        tf_str = self.timeframe.timeframe_str
        step_ms = ms_per_epoch.get(tf_str)
        if step_ms is None:
            raise ValueError(tf_str)

        t0 = self.start_time
        return [UnixTimeMs(t0 + k * step_ms) for k in range(self.N)]

    @property
    def x_dt(self) -> List[datetime]:
        """@return -- x-values as datetime objects, one per sample"""
        return [unix_time.to_dt() for unix_time in self.x_ut]
|
||
|
||
@enforce_types
def plot_seasonal(seasonal_plotdata: SeasonalPlotdata):
    """
    @description
      Plot seasonal decomposition of the feed, as 4 stacked subplots
      that share one time axis. From top to bottom:
      1. observed feed
      2. trend
      3. seasonal
      4. residual

    @arguments
      seasonal_plotdata -- supplies the x-values and the DecomposeResult

    @return
      fig -- the plotly figure. It is not shown here; that's up to the caller.
    """
    times = seasonal_plotdata.x_dt
    result = seasonal_plotdata.dr

    # (y-axis title, y-values, line color) for each subplot, top to bottom
    panels = [
        ("Observed", result.observed, "black"),
        ("Trend", result.trend, "blue"),
        ("Seasonal", result.seasonal, "green"),
        ("Residual", result.resid, "red"),
    ]
    n_rows = len(panels)

    fig = make_subplots(
        rows=n_rows,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.01,
    )

    # one thin line trace per subplot
    for row, (title, series, color) in enumerate(panels, start=1):
        trace = go.Scatter(
            x=times,
            y=series,
            mode="lines",
            line={"color": color, "width": 1},
        )
        fig.add_trace(trace, row=row, col=1)
        fig.update_yaxes(title_text=title, row=row, col=1)

    # only the bottom subplot gets an x-axis title
    fig.update_xaxes(title_text="Time", row=n_rows, col=1)

    # global
    minor_ticks = {"ticks": "inside", "showgrid": True}
    for row in range(1, n_rows + 1):
        fig.update_yaxes(minor=minor_ticks, row=row, col=1)
        fig.update_xaxes(minor=minor_ticks, row=row, col=1)
    fig.update_layout(title_text="Seasonal decomposition", showlegend=False)
    fig.update_yaxes(nticks=8)
    fig.update_xaxes(nticks=15)

    return fig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from enforce_typing import enforce_types | ||
import pandas as pd | ||
|
||
from pdr_backend.aimodel.seasonal import ( | ||
pdr_seasonal_decompose, | ||
plot_seasonal, | ||
SeasonalPlotdata, | ||
) | ||
from pdr_backend.cli.arg_timeframe import ArgTimeframe | ||
from pdr_backend.util.time_types import UnixTimeMs | ||
|
||
|
||
# CSV with the OHLCV data that the test decomposes
DATA_FILE = (
    "./pdr_backend/lake/test/"
    "merged_ohlcv_df_BTC-ETH_2024-02-01_to_2024-03-08.csv"
)

# column of DATA_FILE that the test uses as y-values
BTC_COL = "binance:BTC/USDT:close"

# only turn on for manual testing
SHOW_PLOT = False
|
||
|
||
@enforce_types
def test_seasonal_SHOW_PLOT():
    """
    Fail if SHOW_PLOT was left on: it should only be set to True
    temporarily, for manual testing on a local machine.
    """
    assert not SHOW_PLOT
|
||
|
||
@enforce_types
def test_seasonal():
    """End-to-end: load the BTC column of DATA_FILE, decompose it, plot it."""
    # inputs: the first timestamp in the file is x-value #0
    ohlcv_df = pd.read_csv(DATA_FILE)
    start_time = UnixTimeMs(ohlcv_df["timestamp"][0])
    timeframe = ArgTimeframe("5m")
    btc_close = ohlcv_df[BTC_COL].array

    # decompose; every component must have the same shape as the input
    decomposed = pdr_seasonal_decompose(timeframe, btc_close)
    components = (
        decomposed.observed,
        decomposed.seasonal,
        decomposed.trend,
        decomposed.resid,
    )
    for component in components:
        assert component.shape == btc_close.shape

    # plot; the figure is only displayed during manual testing
    fig = plot_seasonal(SeasonalPlotdata(start_time, timeframe, decomposed))
    if SHOW_PLOT:
        fig.show()
Oops, something went wrong.