Skip to content

Commit

Permalink
Merge pull request #235 from dos-group/fix_#234
Browse files Browse the repository at this point in the history
Fix #234
  • Loading branch information
marvin-steinke authored Sep 9, 2024
2 parents 63fc53b + ea8611d commit 5bdcb81
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
8 changes: 2 additions & 6 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def hist_signal_single(self) -> vs.HistoricalSignal:
@pytest.fixture
def hist_signal_forecast(self) -> vs.HistoricalSignal:
index = pd.date_range("2023-01-01T00:00:00", "2023-01-01T01:00:00", freq="20T")
actual = pd.DataFrame(
{"a": [1, 5, 3, 2], "b": [0, 1, 2, 3], "c": [4, 3, 2, 7]}, index=index
)
actual = pd.DataFrame({"a": [1, 5, 3, 2], "b": [0, 1, 2, 3]}, index=index)

forecast_data = [
["2023-01-01T00:00:00", "2023-01-01T00:10:00", 2, 2.5],
Expand Down Expand Up @@ -294,12 +292,10 @@ def test_forecast_with_frequency(
# Complicated because np.nan == np.nan is False
assert forecast.keys() == expected.keys()
assert all(
np.isnan(expected[k]) if np.isnan(forecast[k])
else forecast[k] == expected[k]
np.isnan(expected[k]) if np.isnan(forecast[k]) else forecast[k] == expected[k]
for k in forecast.keys()
)


def test_forecast_fails_if_column_not_specified(self, hist_signal):
with pytest.raises(ValueError):
hist_signal.forecast(
Expand Down
16 changes: 15 additions & 1 deletion vessim/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ def __init__(
fill_method: Literal["ffill", "bfill"] = "ffill",
column: Optional[str] = None,
):
if isinstance(actual, pd.DataFrame) and forecast is not None:
if isinstance(forecast, pd.DataFrame):
if not actual.columns.equals(forecast.columns):
raise ValueError("Column names in actual and forecast do not match.")
else:
raise ValueError("Forecast has to be a DataFrame if actual is a DataFrame.")

self.default_column = column
self._fill_method = fill_method
# Unpack index of actual dataframe
Expand Down Expand Up @@ -356,7 +363,14 @@ def _resample_to_frequency(
else:
# Actual value is used for interpolation
times = np.insert(times, 0, start_time)
data = np.insert(data, 0, self.now(start_time, column))
# https://github.com/dos-group/vessim/issues/234
# Use the length of the actual data to determine the column:
# self._actual is a dict[str, tuple[np.ndarray, np.ndarray]]
# -> every key is a column name
# -> if len(self._actual) == 1, _actual is based on pd.Series and column is None
data = np.insert(
data, 0, self.now(start_time, None if len(self._actual) == 1 else column)
)
if resample_method == "ffill":
new_data = data[np.searchsorted(times, new_times, side="right") - 1]
elif resample_method == "nearest":
Expand Down

0 comments on commit 5bdcb81

Please sign in to comment.