Skip to content

Commit

Permalink
Moves simple is_decision_node method into utilities and reuses it in …
Browse files Browse the repository at this point in the history
…another utility. Adds test.
  • Loading branch information
MoseleyS committed Nov 1, 2023
1 parent 6c8a616 commit d3173e5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 21 deletions.
21 changes: 2 additions & 19 deletions improver/categorical/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
day_night_map,
expand_nested_lists,
get_parameter_names,
is_decision_node,
is_variable,
update_daynight,
update_tree_thresholds,
Expand Down Expand Up @@ -164,24 +165,6 @@ def __repr__(self) -> str:
"""Represent the configured plugin instance as a string."""
return "<ApplyDecisionTree start_node={}>".format(self.start_node)

@staticmethod
def _is_decision_node(key: str, query: Dict[str, Dict[str, Union[str, List]]]) -> bool:
"""
Determine whether a given node is a decision node.
The meta node has a key of "meta", leaf nodes have a query key of "leaf", everything
else is a decision node.
Args:
key:
Decision name ("meta" indicates a non-decision node)
query:
Dict where key "leaf" indicates a non-decision node
Returns:
True if query represents a decision node
"""
return key != "meta" and "leaf" not in query.keys()

def prepare_input_cubes(
self, cubes: CubeList
) -> Tuple[CubeList, Optional[List[str]]]:
Expand Down Expand Up @@ -215,7 +198,7 @@ def prepare_input_cubes(
optional_node_data_missing = []
missing_data = []
for key, query in self.queries.items():
if not self._is_decision_node(key, query):
if not is_decision_node(key, query):
continue
diagnostics = get_parameter_names(
expand_nested_lists(query, "diagnostic_fields")
Expand Down
22 changes: 20 additions & 2 deletions improver/categorical/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,24 @@ def update_daynight(cube: Cube, day_night: Dict) -> Cube:
return cube_day_night


def is_decision_node(key: str, query: Dict[str, Any]) -> bool:
"""
Determine whether a given node is a decision node.
The meta node has a key of "meta", leaf nodes have a query key of "leaf", everything
else is a decision node.
Args:
key:
Decision name ("meta" indicates a non-decision node)
query:
Dict where key "leaf" indicates a non-decision node
Returns:
True if query represents a decision node
"""
return key != "meta" and "leaf" not in query.keys()


def interrogate_decision_tree(decision_tree: Dict[str, Dict[str, Any]]) -> str:
"""
Obtain a list of necessary inputs from the decision tree as it is currently
Expand All @@ -223,8 +241,8 @@ def interrogate_decision_tree(decision_tree: Dict[str, Dict[str, Any]]) -> str:
"""
# Diagnostic names and threshold values.
requirements = {}
for query in decision_tree.values():
if "diagnostic_fields" not in query.keys():
for key, query in decision_tree.items():
if not is_decision_node(key, query):
continue
diagnostics = get_parameter_names(
expand_nested_lists(query, "diagnostic_fields")
Expand Down
11 changes: 11 additions & 0 deletions improver_tests/categorical/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
expand_nested_lists,
get_parameter_names,
interrogate_decision_tree,
is_decision_node,
update_daynight,
update_tree_thresholds,
)
Expand Down Expand Up @@ -718,5 +719,15 @@ def test_day_night_map():
assert expected == result


@pytest.mark.parametrize(
"name, node, expected",
(("anything", {}, True), ("meta", {}, False), ("a_leaf", {"leaf": 0}, False)),
)
def test_is_decision_node(name, node, expected):
"""Tests that we can correctly distinguish between decision nodes and other nodes"""
result = is_decision_node(name, node)
assert result == expected


if __name__ == "__main__":
unittest.main()

0 comments on commit d3173e5

Please sign in to comment.