Skip to content

Commit

Permalink
Added possibility of having static forecasts
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianp14 committed Oct 16, 2023
1 parent 5a224e6 commit c0d25d9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
45 changes: 33 additions & 12 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def time_series_forecast(self) -> TimeSeriesApi:
forecast.set_index(["request_time", "forecast_time"], inplace=True)
return TimeSeriesApi(actual, forecast)

@pytest.fixture
def time_series_global_forecast(self) -> TimeSeriesApi:
    """Provide a TimeSeriesApi whose forecast uses a single datetime index.

    Actual and forecast data share the same four hourly timestamps; the
    forecast has no separate request-time index level.
    """
    hourly_index = pd.date_range(
        "2023-01-01T00:00:00", "2023-01-01T03:00:00", freq="1H"
    )
    observed = pd.DataFrame([3, 2, 4, 0], index=hourly_index)
    predicted = pd.DataFrame([4, 2, 4, 1], index=hourly_index)
    return TimeSeriesApi(observed, predicted)

@pytest.mark.parametrize(
"dt, expected",
[
Expand Down Expand Up @@ -67,7 +74,7 @@ def test_next_update(self, time_series, dt, expected):
def test_zones(self, time_series):
    """zones() on the time_series fixture must return exactly ["a", "b"]."""
    expected_zones = ["a", "b"]
    assert time_series.zones() == expected_zones

def test_trace_actual_at_single_zone(self, time_series_single):
def test_trace_actual_single_zone(self, time_series_single):
assert time_series_single.actual("2023-01-01T00:45:00") == 3

@pytest.mark.parametrize(
Expand All @@ -84,23 +91,23 @@ def test_trace_actual_at_single_zone(self, time_series_single):
def test_actual_at(self, time_series, dt, zone, expected):
assert time_series.actual(dt, zone) == expected

def test_actual_at_fails_if_zone_not_specified(self, time_series):
def test_actual_fails_if_zone_not_specified(self, time_series):
with pytest.raises(ValueError):
time_series.actual(pd.to_datetime("2023-01-01T00:00:00"))

def test_actual_at_fails_if_zone_does_not_exist(self, time_series):
def test_actual_fails_if_zone_does_not_exist(self, time_series):
with pytest.raises(ValueError):
time_series.actual(pd.to_datetime("2023-01-01T00:00:00"), "c")

def test_actual_at_fails_if_now_too_early(self, time_series):
def test_actual_fails_if_now_too_early(self, time_series):
with pytest.raises(ValueError):
time_series.actual(pd.to_datetime("2022-12-30T23:59:59"), "a")

def test_actual_at_fails_if_now_too_late(self, time_series_single):
def test_actual_fails_if_now_too_late(self, time_series_single):
with pytest.raises(ValueError):
time_series_single.actual(pd.to_datetime("2023-01-01T01:00:01"))

def test_forecast_at_single_zone(self, time_series_single):
def test_forecast_single_zone(self, time_series_single):
assert time_series_single.forecast(
start_time="2023-01-01T00:00:00",
end_time="2023-01-01T01:00:00",
Expand All @@ -114,6 +121,20 @@ def test_forecast_at_single_zone(self, time_series_single):
)
)

def test_forecast_global(self, time_series_global_forecast):
    """forecast() on a single-index forecast returns the expected values.

    For the window 00:00-02:00 the result must hold the forecast values at
    01:00 and 02:00.
    """
    result = time_series_global_forecast.forecast(
        start_time="2023-01-01T00:00:00",
        end_time="2023-01-01T02:00:00",
    )
    expected_index = [
        pd.to_datetime(stamp)
        for stamp in ("2023-01-01T01:00:00", "2023-01-01T02:00:00")
    ]
    expected = pd.Series([2, 4], index=expected_index)
    assert result.equals(expected)

@pytest.mark.parametrize(
"start, end, zone, expected",
[
Expand Down Expand Up @@ -232,7 +253,7 @@ def test_forecast_at(self, time_series_forecast, start, end, zone, expected):
),
],
)
def test_forecast_at_with_frequency(
def test_forecast_with_frequency(
self, time_series_forecast, start, end, zone, frequency, method, expected
):
assert time_series_forecast.forecast(
Expand All @@ -243,30 +264,30 @@ def test_forecast_at_with_frequency(
resample_method=method,
).equals(expected)

def test_forecast_at_fails_if_zone_not_specified(self, time_series):
def test_forecast_fails_if_zone_not_specified(self, time_series):
with pytest.raises(ValueError):
time_series.forecast(
pd.to_datetime("2023-01-01T00:00:00"),
pd.to_datetime("2023-01-01T01:00:00"),
)

def test_forecast_at_fails_if_zone_does_not_exist(self, time_series):
def test_forecast_fails_if_zone_does_not_exist(self, time_series):
with pytest.raises(ValueError):
time_series.forecast(
pd.to_datetime("2023-01-01T00:00:00"),
pd.to_datetime("2023-01-01T01:00:00"),
zone="c",
)

def test_forecast_at_fails_if_start_too_early(self, time_series):
def test_forecast_fails_if_start_too_early(self, time_series):
with pytest.raises(ValueError):
time_series.forecast(
pd.to_datetime("2022-12-31T23:59:59"),
pd.to_datetime("2023-01-01T01:00:00"),
zone="a",
)

def test_forecast_at_fails_with_invalid_frequency(self, time_series):
def test_forecast_fails_with_invalid_frequency(self, time_series):
with pytest.raises(ValueError):
time_series.forecast(
time_series.forecast(
Expand All @@ -277,7 +298,7 @@ def test_forecast_at_fails_with_invalid_frequency(self, time_series):
)
)

def test_forecast_at_fails_if_not_enough_data_for_frequency(self, time_series):
def test_forecast_fails_if_not_enough_data_for_frequency(self, time_series):
with pytest.raises(ValueError):
time_series.forecast(
pd.to_datetime("2023-01-01T00:00:00"),
Expand Down
22 changes: 13 additions & 9 deletions vessim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class TimeSeriesApi:
If you wish a different behavior, you have to change your actual data
beforehand (e.g. by resampling into a different frequency).
forecast: An optional time-series dataset representing forecasted values.
The data should contain two indices. One is the "Request Timestamp", marking
the time when the forecast was made. One is the "Forecast Timestamp",
indicating the time the forecast is made for.
The data should contain two datetime-like indices. One is the
"Request Timestamp", marking the time when the forecast was made. One is the
"Forecast Timestamp", indicating the time the forecast is made for.
- If data does not include a "Request Timestamp", it is treated as a global
forecast that does not change over time.
- If `forecast` is not provided, predictions are derived from the
Expand Down Expand Up @@ -77,12 +77,13 @@ def __init__(

if isinstance(forecast, (pd.Series, pd.DataFrame)):
# Convert all index levels (either one or two) to datetime
new_levels = [
forecast.index.get_level_values(i).map(pd.to_datetime)
for i in range(forecast.index.nlevels)
]
new_index = pd.MultiIndex.from_arrays(new_levels, names=forecast.index.names)
forecast.index = new_index
if isinstance(forecast.index, pd.MultiIndex):
for level in range(forecast.index.nlevels):
forecast.index = forecast.index.set_levels(
pd.to_datetime(forecast.index.levels[level]), level=level
)
else:
forecast.index = pd.to_datetime(forecast.index)

forecast.sort_index(inplace=True)
if isinstance(forecast, pd.Series):
Expand Down Expand Up @@ -202,6 +203,9 @@ def _get_forecast_data_source(self, start_time: DatetimeLike):
# Get all data points beginning at the nearest existing timestamp
# at or before start time from actual data
data_src = self._actual.loc[self._actual.index.asof(start_time) :]
elif self._forecast.index.nlevels == 1:
# Forecast data does not include request_timestamp
data_src = self._forecast.loc[self._forecast.index.asof(start_time) :]
else:
# Get the nearest existing timestamp lower than start time from forecasts
first_index = self._forecast.index.get_level_values(0)
Expand Down

0 comments on commit c0d25d9

Please sign in to comment.