From ec3bd32ccbba6243a58afef6535dd9398c36b201 Mon Sep 17 00:00:00 2001 From: Bryn Pickering <17178478+brynpickering@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:49:45 +0100 Subject: [PATCH] Add helper function test to cover conditional --- src/calliope/backend/helper_functions.py | 5 +--- tests/test_backend_expression_parser.py | 29 +++++++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/calliope/backend/helper_functions.py b/src/calliope/backend/helper_functions.py index 073ef8e3..dc024796 100644 --- a/src/calliope/backend/helper_functions.py +++ b/src/calliope/backend/helper_functions.py @@ -809,10 +809,7 @@ def as_array(self, array: xr.DataArray, where_array: xr.DataArray) -> xr.DataArr * cap_node_groups (cap_node_groups) object 24B 'group_1' 'group_2' 'group_3' ``` """ - if ( - self._backend_interface is not None - and where_array.name in self._backend_interface._dataset - ): + if self._backend_interface is not None: where_array = self._input_data[where_array.name] return array.where(where_array.fillna(False).astype(bool)) diff --git a/tests/test_backend_expression_parser.py b/tests/test_backend_expression_parser.py index ded1feb9..b8baf7ec 100644 --- a/tests/test_backend_expression_parser.py +++ b/tests/test_backend_expression_parser.py @@ -39,7 +39,16 @@ def as_array(self, x, y): @pytest.fixture def valid_component_names(): - return ["foo", "with_inf", "only_techs", "no_dims", "multi_dim_var", "no_dim_var"] + return [ + "foo", + "with_inf", + "only_techs", + "no_dims", + "multi_dim_var", + "no_dim_var", + "all_true", + "only_techs_as_bool", + ] @pytest.fixture @@ -567,6 +576,24 @@ def test_function_one_arg_allowed_invalid_string( assert check_error_or_warning(excinfo, "Expected") +class TestEquationParserHelper: + @pytest.mark.parametrize( + ("where", "expected_notnull"), + [ + ("all_true", [[True, True, True, True], [True, True, True, True]]), + ("only_techs_as_bool", [False, True, True, True]), + ], + ) + def test_helper_function_where( + self, helper_function, eval_kwargs, where, expected_notnull + ): + """Test that `where` helper function works as expected when passed a backend interface object.""" + string_ = f"where(no_dims, {where})" + parsed_ = helper_function.parse_string(string_, parse_all=True) + evaluated_ = parsed_[0].eval(**eval_kwargs) + np.testing.assert_array_equal(evaluated_.notnull(), expected_notnull) + + class TestEquationParserArithmetic: numbers = [2, 100, 0.02, "1e2", "2e-2", "inf"]