Skip to content

Commit

Permalink
Merge pull request #232 from dos-group/hist_signal_fix
Browse files Browse the repository at this point in the history
Fixed bug in HistoricalSignal
  • Loading branch information
marvin-steinke authored Aug 30, 2024
2 parents 9cbe52b + ddee5fc commit 6f7779b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 14 deletions.
54 changes: 43 additions & 11 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,27 @@ def test_forecast(self, hist_signal_forecast, start, end, column, expected):
),
(
"2023-01-01T01:00:00",
"2023-01-01T03:00:00",
"2023-01-01T04:00:00",
"a",
timedelta(minutes=45),
"ffill",
{
np.datetime64("2023-01-01T01:45:00.000000000"): 2.0, # type: ignore
np.datetime64("2023-01-01T02:30:00.000000000"): 3.0, # type: ignore
np.datetime64("2023-01-01T03:15:00.000000000"): 1.5, # type: ignore
np.datetime64("2023-01-01T04:00:00.000000000"): 1.5, # type: ignore
},
),
(
"2023-01-01T00:00:00",
"2023-01-01T03:00:00",
"2023-01-01T05:00:00",
"a",
timedelta(hours=1, minutes=30),
"linear",
{
np.datetime64("2023-01-01T01:30:00.000000000"): 2.0, # type: ignore
np.datetime64("2023-01-01T03:00:00.000000000"): 2.0, # type: ignore
np.datetime64("2023-01-01T04:30:00.000000000"): 2.0, # type: ignore
},
),
(
Expand Down Expand Up @@ -252,21 +255,50 @@ def test_forecast(self, hist_signal_forecast, start, end, column, expected):
np.datetime64("2023-01-01T00:55:00.000000000"): 2.5, # type: ignore
},
),
(
"2023-01-01T01:00:00",
"2023-01-01T04:00:00",
"b",
"1H",
"bfill",
{
np.datetime64("2023-01-01T02:00:00.000000000"): 2.5, # type: ignore
np.datetime64("2023-01-01T03:00:00.000000000"): 1.5, # type: ignore
np.datetime64("2023-01-01T04:00:00.000000000"): np.nan, # type: ignore
},
),
(
"2023-01-01T01:00:00",
"2023-01-01T04:00:00",
"b",
"1H",
"nearest",
{
np.datetime64("2023-01-01T02:00:00.000000000"): 2.5, # type: ignore
np.datetime64("2023-01-01T03:00:00.000000000"): 1.5, # type: ignore
np.datetime64("2023-01-01T04:00:00.000000000"): 1.5, # type: ignore
},
),
],
)
def test_forecast_with_frequency(
self, hist_signal_forecast, start, end, column, frequency, method, expected
):
assert (
hist_signal_forecast.forecast(
start,
end,
column=column,
frequency=frequency,
resample_method=method,
)
== expected
forecast = hist_signal_forecast.forecast(
start,
end,
column=column,
frequency=frequency,
resample_method=method,
)
# Complicated because np.nan == np.nan is False
assert forecast.keys() == expected.keys()
assert all(
np.isnan(expected[k]) if np.isnan(forecast[k])
else forecast[k] == expected[k]
for k in forecast.keys()
)


def test_forecast_fails_if_column_not_specified(self, hist_signal):
with pytest.raises(ValueError):
Expand Down
14 changes: 11 additions & 3 deletions vessim/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,17 @@ def _resample_to_frequency(
)

new_times_indices = np.searchsorted(times, new_times, side="left")
if not np.array_equal(new_times, times[new_times_indices]) and resample_method != "bfill":
if np.all(new_times_indices < times.size) and np.array_equal(
new_times, times[new_times_indices]
):
# No resampling necessary
new_data = data[new_times_indices]
elif resample_method == "bfill":
# Perform backward-fill whereas values outside range are filled with NaN
new_data = np.full(new_times_indices.shape, np.nan)
valid_mask = new_times_indices < len(data)
new_data[valid_mask] = data[new_times_indices[valid_mask]]
else:
# Actual value is used for interpolation
times = np.insert(times, 0, start_time)
data = np.insert(data, 0, self.now(start_time, column))
Expand All @@ -361,8 +371,6 @@ def _resample_to_frequency(
raise ValueError(f"Unknown resample_method '{resample_method}'.")
else:
raise ValueError(f"Not enough data at frequency '{freq}' without resampling.")
else:
new_data = data[new_times_indices]

return dict(zip(new_times, new_data))

Expand Down

0 comments on commit 6f7779b

Please sign in to comment.