From 5db7d99691da5b7de3c181074f52abc42c3866bc Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 3 Dec 2024 13:51:08 -0600 Subject: [PATCH 01/22] Cleanup dependency tree tests --- satpy/tests/test_dependency_tree.py | 360 +++++++++++++--------------- 1 file changed, 171 insertions(+), 189 deletions(-) diff --git a/satpy/tests/test_dependency_tree.py b/satpy/tests/test_dependency_tree.py index 40433c0032..dd18338c3a 100644 --- a/satpy/tests/test_dependency_tree.py +++ b/satpy/tests/test_dependency_tree.py @@ -16,139 +16,16 @@ """Unit tests for the dependency tree class and dependencies.""" import os -import unittest + +import pytest from satpy.dependency_tree import DependencyTree from satpy.tests.utils import make_cid, make_dataid -class TestDependencyTree(unittest.TestCase): - """Test the dependency tree. - - This is what we are working with:: - - None (No Data) - +DataID(name='comp19') - + +DataID(name='ds5', resolution=250, modifiers=('res_change',)) - + + +DataID(name='ds5', resolution=250, modifiers=()) - + + +__EMPTY_LEAF_SENTINEL__ (No Data) - + +DataID(name='comp13') - + + +DataID(name='ds5', resolution=250, modifiers=('res_change',)) - + + + +DataID(name='ds5', resolution=250, modifiers=()) - + + + +__EMPTY_LEAF_SENTINEL__ (No Data) - + +DataID(name='ds2', resolution=250, calibration=, modifiers=()) - - """ - - def setUp(self): - """Set up the test tree.""" - self.dependency_tree = DependencyTree(None, None, None) - - composite_1 = make_cid(name="comp19") - dependency_1 = make_dataid(name="ds5", resolution=250, modifiers=("res_change",)) - dependency_1_1 = make_dataid(name="ds5", resolution=250, modifiers=tuple()) - node_composite_1 = self.dependency_tree.add_leaf(composite_1) - node_dependency_1 = self.dependency_tree.add_leaf(dependency_1, node_composite_1) - self.dependency_tree.add_leaf(dependency_1_1, node_dependency_1) - # ToDo: do we really want then empty node to be at the same level as the unmodified data? - node_dependency_1.add_child(self.dependency_tree.empty_node) - - dependency_2 = make_cid(name="comp13") - dependency_2_1 = dependency_1 - node_dependency_2 = self.dependency_tree.add_leaf(dependency_2, node_composite_1) - self.dependency_tree.add_leaf(dependency_2_1, node_dependency_2) - # We don't need to add the unmodified dependency a second time. 
- - dependency_3 = make_dataid(name="ds2", resolution=250, calibration="reflectance", modifiers=tuple()) - self.dependency_tree.add_leaf(dependency_3, node_composite_1) - - @staticmethod - def _nodes_equal(node_list1, node_list2): - names1 = [node.name for node in node_list1] - names2 = [node.name for node in node_list2] - return sorted(names1) == sorted(names2) - - def test_copy_preserves_all_nodes(self): - """Test that dependency tree copy preserves all nodes.""" - new_dependency_tree = self.dependency_tree.copy() - assert self.dependency_tree.empty_node is new_dependency_tree.empty_node - assert self._nodes_equal(self.dependency_tree.leaves(), - new_dependency_tree.leaves()) - assert self._nodes_equal(self.dependency_tree.trunk(), - new_dependency_tree.trunk()) - - # make sure that we can get access to sub-nodes - c13_id = make_cid(name="comp13") - assert self._nodes_equal(self.dependency_tree.trunk(limit_nodes_to=[c13_id]), - new_dependency_tree.trunk(limit_nodes_to=[c13_id])) - - def test_copy_preserves_unique_empty_node(self): - """Test that dependency tree copy preserves the uniqueness of the empty node.""" - new_dependency_tree = self.dependency_tree.copy() - assert self.dependency_tree.empty_node is new_dependency_tree.empty_node - - assert self.dependency_tree._root.children[0].children[0].children[1] is self.dependency_tree.empty_node - assert new_dependency_tree._root.children[0].children[0].children[1] is self.dependency_tree.empty_node - - def test_new_dependency_tree_preserves_unique_empty_node(self): - """Test that dependency tree instantiation preserves the uniqueness of the empty node.""" - new_dependency_tree = DependencyTree(None, None, None) - assert self.dependency_tree.empty_node is new_dependency_tree.empty_node - - -class TestMissingDependencies(unittest.TestCase): - """Test the MissingDependencies exception.""" - - def test_new_missing_dependencies(self): - """Test new MissingDependencies.""" - from satpy.node import MissingDependencies - error = MissingDependencies("bla") - assert error.missing_dependencies == "bla" - - def test_new_missing_dependencies_with_message(self): - """Test new MissingDependencies with a message.""" - from satpy.node import MissingDependencies - error = MissingDependencies("bla", "This is a message") - assert "This is a message" in str(error) - - -class TestMultipleResolutionSameChannelDependency(unittest.TestCase): - """Test that MODIS situations where the same channel is available at multiple resolution works.""" - - def test_modis_overview_1000m(self): - """Test a modis overview dependency calculation with resolution fixed to 1000m.""" - from satpy import DataQuery - from satpy._config import PACKAGE_CONFIG_PATH - from satpy.composites import GenericCompositor - from satpy.dataset import DatasetDict - from satpy.modifiers.geometry import SunZenithCorrector - from satpy.readers.yaml_reader import FileYAMLReader - - config_file = os.path.join(PACKAGE_CONFIG_PATH, "readers", "modis_l1b.yaml") - self.reader_instance = FileYAMLReader.from_config_files(config_file) - - overview = {"_satpy_id": make_dataid(name="overview"), - "name": "overview", - "optional_prerequisites": [], - "prerequisites": [DataQuery(name="1", modifiers=("sunz_corrected",)), - DataQuery(name="2", modifiers=("sunz_corrected",)), - DataQuery(name="31")], - "standard_name": "overview"} - compositors = {"modis": DatasetDict()} - compositors["modis"]["overview"] = GenericCompositor(**overview) - - modifiers = {"modis": {"sunz_corrected": (SunZenithCorrector, - 
{"optional_prerequisites": ["solar_zenith_angle"], - "name": "sunz_corrected", - "prerequisites": []})}} - dep_tree = DependencyTree({"modis_l1b": self.reader_instance}, compositors, modifiers) - dep_tree.populate_with_keys({"overview"}, DataQuery(resolution=1000)) - for key in dep_tree._all_nodes.keys(): - assert key.get("resolution", 1000) == 1000 - - -class TestMultipleSensors(unittest.TestCase): - """Test cases where multiple sensors are available. +@pytest.fixture +def dep_tree1(): + """Fake dependency tree with two composites and one regular dataset. This is what we are working with:: @@ -164,64 +41,169 @@ class TestMultipleSensors(unittest.TestCase): + +DataID(name='ds2', resolution=250, calibration=, modifiers=()) """ - - def setUp(self): - """Set up the test tree.""" - from satpy.composites import CompositeBase - from satpy.dataset.data_dict import DatasetDict - from satpy.modifiers import ModifierBase - - class _FakeCompositor(CompositeBase): - def __init__(self, ret_val, *args, **kwargs): - self.ret_val = ret_val - super().__init__(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.ret_val - - class _FakeModifier(ModifierBase): - def __init__(self, ret_val, *args, **kwargs): - self.ret_val = ret_val - super().__init__(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.ret_val - - comp1_sensor1 = _FakeCompositor(1, "comp1") - comp1_sensor2 = _FakeCompositor(2, "comp1") - # create the dictionary one element at a time to force "incorrect" order - # (sensor2 comes before sensor1, but results should be alphabetical order) - compositors = {} - compositors["sensor2"] = s2_comps = DatasetDict() - compositors["sensor1"] = s1_comps = DatasetDict() - c1_s2_id = make_cid(name="comp1", resolution=1000) - c1_s1_id = make_cid(name="comp1", resolution=500) - s2_comps[c1_s2_id] = comp1_sensor2 - s1_comps[c1_s1_id] = comp1_sensor1 - - modifiers = {} - modifiers["sensor2"] = s2_mods = {} - modifiers["sensor1"] = s1_mods = {} - s2_mods["mod1"] = (_FakeModifier, {"ret_val": 2}) - s1_mods["mod1"] = (_FakeModifier, {"ret_val": 1}) - - self.dependency_tree = DependencyTree({}, compositors, modifiers) - # manually add a leaf so we don't have to mock a reader - ds5 = make_dataid(name="ds5", resolution=250, modifiers=tuple()) - self.dependency_tree.add_leaf(ds5) - - def test_compositor_loaded_sensor_order(self): - """Test that a compositor is loaded from the first alphabetical sensor.""" - self.dependency_tree.populate_with_keys({"comp1"}) - comp_nodes = self.dependency_tree.trunk() - assert len(comp_nodes) == 1 - assert comp_nodes[0].name["resolution"] == 500 - - def test_modifier_loaded_sensor_order(self): - """Test that a modifier is loaded from the first alphabetical sensor.""" - from satpy import DataQuery - dq = DataQuery(name="ds5", modifiers=("mod1",)) - self.dependency_tree.populate_with_keys({dq}) - comp_nodes = self.dependency_tree.trunk() - assert len(comp_nodes) == 1 - assert comp_nodes[0].data[0].ret_val == 1 + dependency_tree = DependencyTree(None, None, None) + + composite_1 = make_cid(name="comp19") + dependency_1 = make_dataid(name="ds5", resolution=250, modifiers=("res_change",)) + dependency_1_1 = make_dataid(name="ds5", resolution=250, modifiers=tuple()) + node_composite_1 = dependency_tree.add_leaf(composite_1) + node_dependency_1 = dependency_tree.add_leaf(dependency_1, node_composite_1) + dependency_tree.add_leaf(dependency_1_1, node_dependency_1) + # ToDo: do we really want the empty node to be at the same level as the unmodified data? 
+ node_dependency_1.add_child(dependency_tree.empty_node) + + dependency_2 = make_cid(name="comp13") + dependency_2_1 = dependency_1 + node_dependency_2 = dependency_tree.add_leaf(dependency_2, node_composite_1) + dependency_tree.add_leaf(dependency_2_1, node_dependency_2) + # We don't need to add the unmodified dependency a second time. + + dependency_3 = make_dataid(name="ds2", resolution=250, calibration="reflectance", modifiers=tuple()) + dependency_tree.add_leaf(dependency_3, node_composite_1) + return dependency_tree + + +@pytest.fixture +def dep_tree2(): + """Fake dependency tree with multiple sensors available.""" + from satpy.composites import CompositeBase + from satpy.dataset.data_dict import DatasetDict + from satpy.modifiers import ModifierBase + + class _FakeCompositor(CompositeBase): + def __init__(self, ret_val, *args, **kwargs): + self.ret_val = ret_val + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.ret_val + + class _FakeModifier(ModifierBase): + def __init__(self, ret_val, *args, **kwargs): + self.ret_val = ret_val + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.ret_val + + comp1_sensor1 = _FakeCompositor(1, "comp1") + comp1_sensor2 = _FakeCompositor(2, "comp1") + # create the dictionary one element at a time to force "incorrect" order + # (sensor2 comes before sensor1, but results should be alphabetical order) + compositors = {} + compositors["sensor2"] = s2_comps = DatasetDict() + compositors["sensor1"] = s1_comps = DatasetDict() + c1_s2_id = make_cid(name="comp1", resolution=1000) + c1_s1_id = make_cid(name="comp1", resolution=500) + s2_comps[c1_s2_id] = comp1_sensor2 + s1_comps[c1_s1_id] = comp1_sensor1 + + modifiers = {} + modifiers["sensor2"] = s2_mods = {} + modifiers["sensor1"] = s1_mods = {} + s2_mods["mod1"] = (_FakeModifier, {"ret_val": 2}) + s1_mods["mod1"] = (_FakeModifier, {"ret_val": 1}) + + dependency_tree = DependencyTree({}, compositors, modifiers) + # manually add a leaf so we don't have to mock a reader + ds5 = make_dataid(name="ds5", resolution=250, modifiers=tuple()) + dependency_tree.add_leaf(ds5) + return dependency_tree + + +def _nodes_equal(node_list1, node_list2): + names1 = [node.name for node in node_list1] + names2 = [node.name for node in node_list2] + return sorted(names1) == sorted(names2) + + +def test_copy_preserves_all_nodes(dep_tree1): + """Test that dependency tree copy preserves all nodes.""" + new_dependency_tree = dep_tree1.copy() + assert dep_tree1.empty_node is new_dependency_tree.empty_node + assert _nodes_equal(dep_tree1.leaves(), new_dependency_tree.leaves()) + assert _nodes_equal(dep_tree1.trunk(), new_dependency_tree.trunk()) + + # make sure that we can get access to sub-nodes + c13_id = make_cid(name="comp13") + assert _nodes_equal(dep_tree1.trunk(limit_nodes_to=[c13_id]), + new_dependency_tree.trunk(limit_nodes_to=[c13_id])) + + +def test_copy_preserves_unique_empty_node(dep_tree1): + """Test that dependency tree copy preserves the uniqueness of the empty node.""" + new_dependency_tree = dep_tree1.copy() + assert dep_tree1.empty_node is new_dependency_tree.empty_node + + assert dep_tree1._root.children[0].children[0].children[1] is dep_tree1.empty_node + assert new_dependency_tree._root.children[0].children[0].children[1] is dep_tree1.empty_node + + +def test_new_dependency_tree_preserves_unique_empty_node(dep_tree1): + """Test that dependency tree instantiation preserves the uniqueness of the empty node.""" + new_dependency_tree = 
DependencyTree(None, None, None) + assert dep_tree1.empty_node is new_dependency_tree.empty_node + + +def test_new_missing_dependencies(): + """Test new MissingDependencies.""" + from satpy.node import MissingDependencies + error = MissingDependencies("bla") + assert error.missing_dependencies == "bla" + + +def test_new_missing_dependencies_with_message(): + """Test new MissingDependencies with a message.""" + from satpy.node import MissingDependencies + error = MissingDependencies("bla", "This is a message") + assert "This is a message" in str(error) + + +def test_modis_overview_1000m(): + """Test a modis overview dependency calculation with resolution fixed to 1000m.""" + from satpy import DataQuery + from satpy._config import PACKAGE_CONFIG_PATH + from satpy.composites import GenericCompositor + from satpy.dataset import DatasetDict + from satpy.modifiers.geometry import SunZenithCorrector + from satpy.readers.yaml_reader import FileYAMLReader + + config_file = os.path.join(PACKAGE_CONFIG_PATH, "readers", "modis_l1b.yaml") + reader_instance = FileYAMLReader.from_config_files(config_file) + + overview = {"_satpy_id": make_dataid(name="overview"), + "name": "overview", + "optional_prerequisites": [], + "prerequisites": [DataQuery(name="1", modifiers=("sunz_corrected",)), + DataQuery(name="2", modifiers=("sunz_corrected",)), + DataQuery(name="31")], + "standard_name": "overview"} + compositors = {"modis": DatasetDict()} + compositors["modis"]["overview"] = GenericCompositor(**overview) + + modifiers = {"modis": {"sunz_corrected": (SunZenithCorrector, + {"optional_prerequisites": ["solar_zenith_angle"], + "name": "sunz_corrected", + "prerequisites": []})}} + dep_tree = DependencyTree({"modis_l1b": reader_instance}, compositors, modifiers) + dep_tree.populate_with_keys({"overview"}, DataQuery(resolution=1000)) + for key in dep_tree._all_nodes.keys(): + assert key.get("resolution", 1000) == 1000 + + +def test_compositor_loaded_sensor_order(dep_tree2): + """Test that a compositor is loaded from the first alphabetical sensor.""" + dep_tree2.populate_with_keys({"comp1"}) + comp_nodes = dep_tree2.trunk() + assert len(comp_nodes) == 1 + assert comp_nodes[0].name["resolution"] == 500 + + +def test_modifier_loaded_sensor_order(dep_tree2): + """Test that a modifier is loaded from the first alphabetical sensor.""" + from satpy import DataQuery + dq = DataQuery(name="ds5", modifiers=("mod1",)) + dep_tree2.populate_with_keys({dq}) + comp_nodes = dep_tree2.trunk() + assert len(comp_nodes) == 1 + assert comp_nodes[0].data[0].ret_val == 1 From 66b714ae650c396cf18a2a2329c3b27c667f12bb Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 3 Dec 2024 14:07:31 -0600 Subject: [PATCH 02/22] Update dep tree tests to be more realistic --- satpy/tests/test_dependency_tree.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/satpy/tests/test_dependency_tree.py b/satpy/tests/test_dependency_tree.py index dd18338c3a..29afbc7661 100644 --- a/satpy/tests/test_dependency_tree.py +++ b/satpy/tests/test_dependency_tree.py @@ -30,15 +30,15 @@ def dep_tree1(): This is what we are working with:: None (No Data) - +DataID(name='comp19') - + +DataID(name='ds5', resolution=250, modifiers=('res_change',)) - + + +DataID(name='ds5', resolution=250, modifiers=()) + +DataID(name='comp19') (No Data) + + +DataID(name='ds5', resolution=250, modifiers=('res_change',)) (No Data) + + + +DataID(name='ds5', resolution=250, modifiers=()) (No Data) + + +DataID(name='comp13') (No Data) + + + 
+DataID(name='ds5', resolution=250, modifiers=('res_change',)) (No Data) + + + + +DataID(name='ds5', resolution=250, modifiers=()) (No Data) + + +DataID(name='ds2', resolution=250, calibration=<1>, modifiers=()) (No Data) + + +DataID(name='no_deps_comp') (No Data) + + +__EMPTY_LEAF_SENTINEL__ (No Data) - + +DataID(name='comp13') - + + +DataID(name='ds5', resolution=250, modifiers=('res_change',)) - + + + +DataID(name='ds5', resolution=250, modifiers=()) - + + + +__EMPTY_LEAF_SENTINEL__ (No Data) - + +DataID(name='ds2', resolution=250, calibration=, modifiers=()) """ dependency_tree = DependencyTree(None, None, None) @@ -49,8 +49,6 @@ def dep_tree1(): node_composite_1 = dependency_tree.add_leaf(composite_1) node_dependency_1 = dependency_tree.add_leaf(dependency_1, node_composite_1) dependency_tree.add_leaf(dependency_1_1, node_dependency_1) - # ToDo: do we really want the empty node to be at the same level as the unmodified data? - node_dependency_1.add_child(dependency_tree.empty_node) dependency_2 = make_cid(name="comp13") dependency_2_1 = dependency_1 @@ -60,6 +58,10 @@ def dep_tree1(): dependency_3 = make_dataid(name="ds2", resolution=250, calibration="reflectance", modifiers=tuple()) dependency_tree.add_leaf(dependency_3, node_composite_1) + + dependency_4 = make_cid(name="no_deps_comp") + node_dependency_4 = dependency_tree.add_leaf(dependency_4, node_composite_1) + node_dependency_4.add_child(dependency_tree.empty_node) return dependency_tree @@ -135,8 +137,8 @@ def test_copy_preserves_unique_empty_node(dep_tree1): new_dependency_tree = dep_tree1.copy() assert dep_tree1.empty_node is new_dependency_tree.empty_node - assert dep_tree1._root.children[0].children[0].children[1] is dep_tree1.empty_node - assert new_dependency_tree._root.children[0].children[0].children[1] is dep_tree1.empty_node + assert dep_tree1._root.children[0].children[3].children[0] is dep_tree1.empty_node + assert new_dependency_tree._root.children[0].children[3].children[0] is dep_tree1.empty_node def test_new_dependency_tree_preserves_unique_empty_node(dep_tree1): From cab16b9d2d011c5513ba72f60e4c62339ed53aca Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 3 Dec 2024 14:16:52 -0600 Subject: [PATCH 03/22] Convert combine metadata tests to pytest --- satpy/tests/test_dataset.py | 115 +++++++++++++++--------------------- 1 file changed, 49 insertions(+), 66 deletions(-) diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index 6ca3b25d72..e1d9c98bb3 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -23,11 +23,12 @@ import pytest from satpy.dataset.dataid import DataID, DataQuery, ModifierTuple, WavelengthRange, minimal_default_keys_config +from satpy.dataset.metadata import combine_metadata from satpy.readers.pmw_channels_definitions import FrequencyDoubleSideBand, FrequencyQuadrupleSideBand, FrequencyRange from satpy.tests.utils import make_cid, make_dataid, make_dsq -class TestDataID(unittest.TestCase): +class TestDataID: """Test DataID object creation and other methods.""" def test_basic_init(self): @@ -97,97 +98,89 @@ def test_create_less_modified_query(self): assert not d2.create_less_modified_query()["modifiers"] -class TestCombineMetadata(unittest.TestCase): +class TestCombineMetadata: """Test how metadata is combined.""" - def setUp(self): - """Set up the test case.""" + def test_average_datetimes(self): + """Test the average_datetimes helper function.""" + from satpy.dataset.metadata import average_datetimes + dts = ( + dt.datetime(2018, 2, 1, 11, 58, 
0), + dt.datetime(2018, 2, 1, 11, 59, 0), + dt.datetime(2018, 2, 1, 12, 0, 0), + dt.datetime(2018, 2, 1, 12, 1, 0), + dt.datetime(2018, 2, 1, 12, 2, 0), + ) + ret = average_datetimes(dts) + assert dts[2] == ret + + def test_combine_start_times(self): + """Test the combine_metadata with start times.""" # The times need to be in ascending order (oldest first) - self.start_time_dts = ( + start_time_dts = ( {"start_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, {"start_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, {"start_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, {"start_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, {"start_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, ) - self.end_time_dts = ( + ret = combine_metadata(*start_time_dts) + assert ret["start_time"] == start_time_dts[0]["start_time"] + + def test_combine_end_times(self): + """Test the combine_metadata with end times.""" + # The times need to be in ascending order (oldest first) + end_time_dts = ( {"end_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, {"end_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, {"end_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, {"end_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, {"end_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, ) - self.other_time_dts = ( - {"other_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, - ) - self.start_time_dts_with_none = ( + ret = combine_metadata(*end_time_dts) + assert ret["end_time"] == end_time_dts[-1]["end_time"] + + def test_combine_start_times_with_none(self): + """Test the combine_metadata with start times when there's a None included.""" + start_time_dts_with_none = ( {"start_time": None}, {"start_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, {"start_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, {"start_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, {"start_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, ) - self.end_time_dts_with_none = ( + ret = combine_metadata(*start_time_dts_with_none) + assert ret["start_time"] == start_time_dts_with_none[1]["start_time"] + + def test_combine_end_times_with_none(self): + """Test the combine_metadata with end times when there's a None included.""" + end_time_dts_with_none = ( {"end_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, {"end_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, {"end_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, {"end_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, {"end_time": None}, ) - - def test_average_datetimes(self): - """Test the average_datetimes helper function.""" - from satpy.dataset.metadata import average_datetimes - dts = ( - dt.datetime(2018, 2, 1, 11, 58, 0), - dt.datetime(2018, 2, 1, 11, 59, 0), - dt.datetime(2018, 2, 1, 12, 0, 0), - dt.datetime(2018, 2, 1, 12, 1, 0), - dt.datetime(2018, 2, 1, 12, 2, 0), - ) - ret = average_datetimes(dts) - assert dts[2] == ret - - def test_combine_start_times(self): - """Test the combine_metadata with start times.""" - from satpy.dataset.metadata import combine_metadata - ret = combine_metadata(*self.start_time_dts) - assert ret["start_time"] == self.start_time_dts[0]["start_time"] - - def test_combine_end_times(self): - """Test the combine_metadata with end times.""" - from satpy.dataset.metadata import combine_metadata - ret = combine_metadata(*self.end_time_dts) - assert ret["end_time"] == self.end_time_dts[-1]["end_time"] - - def test_combine_start_times_with_none(self): - """Test the combine_metadata 
with start times when there's a None included.""" - from satpy.dataset.metadata import combine_metadata - ret = combine_metadata(*self.start_time_dts_with_none) - assert ret["start_time"] == self.start_time_dts_with_none[1]["start_time"] - - def test_combine_end_times_with_none(self): - """Test the combine_metadata with end times when there's a None included.""" - from satpy.dataset.metadata import combine_metadata - ret = combine_metadata(*self.end_time_dts_with_none) - assert ret["end_time"] == self.end_time_dts_with_none[-2]["end_time"] + ret = combine_metadata(*end_time_dts_with_none) + assert ret["end_time"] == end_time_dts_with_none[-2]["end_time"] def test_combine_other_times(self): """Test the combine_metadata with other time values than start or end times.""" - from satpy.dataset.metadata import combine_metadata - ret = combine_metadata(*self.other_time_dts) - assert ret["other_time"] == self.other_time_dts[2]["other_time"] + other_time_dts = ( + {"other_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, + ) + ret = combine_metadata(*other_time_dts) + assert ret["other_time"] == other_time_dts[2]["other_time"] def test_combine_arrays(self): """Test the combine_metadata with arrays.""" from numpy import arange, ones from xarray import DataArray - from satpy.dataset.metadata import combine_metadata dts = [ {"quality": (arange(25) % 2).reshape(5, 5).astype("?")}, {"quality": (arange(1, 26) % 3).reshape(5, 5).astype("?")}, @@ -221,7 +214,6 @@ def test_combine_arrays(self): def test_combine_lists_identical(self): """Test combine metadata with identical lists.""" - from satpy.dataset.metadata import combine_metadata metadatas = [ {"prerequisites": [1, 2, 3, 4]}, {"prerequisites": [1, 2, 3, 4]}, @@ -231,7 +223,6 @@ def test_combine_lists_identical(self): def test_combine_lists_same_size_diff_values(self): """Test combine metadata with lists with different values.""" - from satpy.dataset.metadata import combine_metadata metadatas = [ {"prerequisites": [1, 2, 3, 4]}, {"prerequisites": [1, 2, 3, 5]}, @@ -241,7 +232,6 @@ def test_combine_lists_same_size_diff_values(self): def test_combine_lists_different_size(self): """Test combine metadata with different size lists.""" - from satpy.dataset.metadata import combine_metadata metadatas = [ {"prerequisites": [1, 2, 3, 4]}, {"prerequisites": []}, @@ -258,25 +248,21 @@ def test_combine_lists_different_size(self): def test_combine_identical_numpy_scalars(self): """Test combining identical fill values.""" - from satpy.dataset.metadata import combine_metadata test_metadata = [{"_FillValue": np.uint16(42)}, {"_FillValue": np.uint16(42)}] assert combine_metadata(*test_metadata) == {"_FillValue": 42} def test_combine_empty_metadata(self): """Test combining empty metadata.""" - from satpy.dataset.metadata import combine_metadata test_metadata = [{}, {}] assert combine_metadata(*test_metadata) == {} def test_combine_nans(self): """Test combining nan fill values.""" - from satpy.dataset.metadata import combine_metadata test_metadata = [{"_FillValue": np.nan}, {"_FillValue": np.nan}] assert combine_metadata(*test_metadata) == {"_FillValue": np.nan} def test_combine_numpy_arrays(self): """Test combining values that are numpy arrays.""" - from satpy.dataset.metadata import combine_metadata test_metadata = [{"valid_range": np.array([0., 0.00032], 
dtype=np.float32)}, {"valid_range": np.array([0., 0.00032], dtype=np.float32)}, {"valid_range": np.array([0., 0.00032], dtype=np.float32)}] @@ -287,7 +273,6 @@ def test_combine_dask_arrays(self): """Test combining values that are dask arrays.""" import dask.array as da - from satpy.dataset.metadata import combine_metadata test_metadata = [{"valid_range": da.from_array(np.array([0., 0.00032], dtype=np.float32))}, {"valid_range": da.from_array(np.array([0., 0.00032], dtype=np.float32))}] result = combine_metadata(*test_metadata) @@ -327,7 +312,6 @@ def test_combine_real_world_mda(self): "sensor": {"viirs"}, "raw_metadata": {"foo": {"bar": np.array([1, 2, 3])}}} - from satpy.dataset.metadata import combine_metadata result = combine_metadata(*mda_objects) assert np.allclose(result.pop("_FillValue"), expected.pop("_FillValue"), equal_nan=True) assert np.allclose(result.pop("valid_range"), expected.pop("valid_range")) @@ -357,7 +341,6 @@ def test_combine_one_metadata_object(self): "platform_name": "NOAA-20", "sensor": {"viirs"}} - from satpy.dataset.metadata import combine_metadata result = combine_metadata(*mda_objects) assert np.allclose(result.pop("_FillValue"), expected.pop("_FillValue"), equal_nan=True) assert np.allclose(result.pop("valid_range"), expected.pop("valid_range")) From 1e8826955b7dd44b63045e11c71d4bfe92eea8c8 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Sat, 7 Dec 2024 10:10:22 -0600 Subject: [PATCH 04/22] Fix inconsistency with DataID "resolution" transitive property --- satpy/dataset/dataid.py | 71 +++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index d8301bc453..e2913aac7e 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -239,48 +239,43 @@ def __hash__(self): #: Default ID keys DataArrays. -default_id_keys_config = {"name": { - "required": True, - }, - "wavelength": { - "type": WavelengthRange, - }, - "resolution": { - "transitive": False, - }, - "calibration": { - "enum": [ - "reflectance", - "brightness_temperature", - "radiance", - "radiance_wavenumber", - "counts" - ], - "transitive": True, - }, - "modifiers": { - "default": ModifierTuple(), - "type": ModifierTuple, - }, - } +default_id_keys_config = { + "name": { + "required": True, + }, + "wavelength": { + "type": WavelengthRange, + }, + "resolution": { + "transitive": False, + }, + "calibration": { + "enum": [ + "reflectance", + "brightness_temperature", + "radiance", + "radiance_wavenumber", + "counts", + ], + "transitive": True, + }, + "modifiers": { + "default": ModifierTuple(), + "type": ModifierTuple, + }, +} #: Default ID keys for coordinate DataArrays. -default_co_keys_config = {"name": { - "required": True, - }, - "resolution": { - "transitive": True, - } - } +default_co_keys_config = { + "name": default_id_keys_config["name"], + "resolution": default_id_keys_config["resolution"], +} #: Minimal ID keys for DataArrays, for example composites. 
-minimal_default_keys_config = {"name": { - "required": True, - }, - "resolution": { - "transitive": True, - } - } +minimal_default_keys_config = { + "name": default_id_keys_config["name"], + "resolution": default_id_keys_config["resolution"], +} class DataID(dict): From dd46a4e94cd4999322c2f39925074ea123b44758 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 9 Dec 2024 10:10:33 -0600 Subject: [PATCH 05/22] Convert dataquery tests to pytest --- satpy/tests/test_dataset.py | 47 +++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index e1d9c98bb3..634fc2f13a 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -17,7 +17,6 @@ """Test objects and functions in the dataset module.""" import datetime as dt -import unittest import numpy as np import pytest @@ -584,32 +583,30 @@ def test_create_less_modified_query(self): assert not d2.create_less_modified_query()["modifiers"] -class TestIDQueryInteractions(unittest.TestCase): +class TestIDQueryInteractions: """Test the interactions between DataIDs and DataQuerys.""" - def setUp(self) -> None: - """Set up the test case.""" - self.default_id_keys_config = { - "name": { - "required": True, - }, - "wavelength": { - "type": WavelengthRange, - }, - "resolution": None, - "calibration": { - "enum": [ - "reflectance", - "brightness_temperature", - "radiance", - "counts" - ] - }, - "modifiers": { - "default": ModifierTuple(), - "type": ModifierTuple, - }, - } + default_id_keys_config = { + "name": { + "required": True, + }, + "wavelength": { + "type": WavelengthRange, + }, + "resolution": None, + "calibration": { + "enum": [ + "reflectance", + "brightness_temperature", + "radiance", + "counts" + ] + }, + "modifiers": { + "default": ModifierTuple(), + "type": ModifierTuple, + }, + } def test_hash_equality(self): """Test hash equality.""" From 251524b3e1d206e5ef6a96f92d4759f49db7e4b6 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 9 Dec 2024 10:45:44 -0600 Subject: [PATCH 06/22] Fix DataQuery equality to require query keys to match --- satpy/dataset/dataid.py | 14 ++++++++------ satpy/tests/test_dataset.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index e2913aac7e..a08900ae34 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -491,7 +491,8 @@ def _generalize_value_for_comparison(val): class DataQuery: """The data query object. - A DataQuery can be used in Satpy to query for a Dataset. This way + A DataQuery can be used in Satpy to query a dict using ``DataID`` objects + as keys. This way a fully qualified DataID can be found even if some DataID elements are unknown. In this case a `*` signifies something that is unknown or not applicable to the requested Dataset. @@ -513,17 +514,18 @@ def __eq__(self, other): A DataQuery is considered equal to another DataQuery or DataID if they have common keys that have equal values. 
""" - sdict = self._asdict() + sdict = self._to_trimmed_dict() try: odict = other._asdict() except AttributeError: return False common_keys = False for key, val in sdict.items(): - if key in odict: - common_keys = True - if odict[key] != val and val is not None: - return False + if key not in odict: + return False + common_keys = True + if odict[key] != val: + return False return common_keys def __hash__(self): diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index 634fc2f13a..c5f7876ebb 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -639,7 +639,17 @@ def test_id_filtering(self): did = make_cid(name="static_image") assert len(dq.filter_dataids([did])) == 0 - def test_inequality(self): + def test_equality_no_modifiers(self): + """Test that a query finds unmodified ID when not specified.""" + data_id = DataID(self.default_id_keys_config, name="1", resolution=500) + assert data_id["modifiers"] == tuple() + assert DataQuery(name="1", resolution=500) == data_id + + def test_inequality_missing_keys(self): + """Check inequality against a DataID missing a query parameter.""" + assert DataQuery(name="1", resolution=500) != DataID(self.default_id_keys_config, name="1") + + def test_inequality_diff_required_keys(self): """Check (in)equality.""" assert DataQuery(wavelength=10) != DataID(self.default_id_keys_config, name="VIS006") From ef07c5540b2d597266dc8e2e50c1f2de8ecd99cb Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 9 Dec 2024 11:00:09 -0600 Subject: [PATCH 07/22] Fix inconsistency with DataID "resolution" transitive property --- satpy/dataset/dataid.py | 17 +++++++++++------ satpy/tests/test_dataset.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index a08900ae34..2bbd5368d8 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -532,12 +532,17 @@ def __hash__(self): """Hash.""" fields = [] values = [] - for field, value in sorted(self._dict.items()): - if value != "*": - fields.append(field) - if isinstance(value, (list, set)): - value = tuple(value) - values.append(value) + for field, value in sorted(self._to_trimmed_dict().items()): + if value == "*": + continue + fields.append(field) + if isinstance(value, list): + # list or tuple is ordered (ex. 
modifiers) + value = tuple(value) + elif isinstance(value, set): + # a set is unordered, but must be sorted for consistent hashing + value = tuple(sorted(value)) + values.append(value) return hash(tuple(zip(fields, values))) def get(self, key, default=None): diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index c5f7876ebb..ca39bd50d9 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -614,6 +614,27 @@ def test_hash_equality(self): did = DataID(self.default_id_keys_config, name="cheese_shops") assert hash(dq) == hash(did) + def test_hash_wildcard_equality(self): + """Test hashes are equal with or without wildcards.""" + assert DataQuery(name="1", resolution="*") == DataQuery(name="1") + + @pytest.mark.parametrize( + "modifiers", + [ + ("a", "b", "c"), + ["a", "b", "c"], + ], + ) + def test_hash_list_equality(self, modifiers): + """Test hashes are equal regardless of list type.""" + assert hash(DataQuery(name="1", modifiers=("a", "b", "c"))) == hash(DataQuery(name="1", modifiers=modifiers)) + + def test_hash_set_equality(self): + """Test hashes are equal regardless of set type.""" + the_set = {"c", "b", "a"} + the_tuple = ("a", "b", "c") + assert hash(DataQuery(name="1", some_set=the_set)) == hash(DataQuery(name="1", some_set=the_tuple)) + def test_id_filtering(self): """Check did filtering.""" dq = DataQuery(modifiers=tuple(), name="cheese_shops") From 438de2c7a3a7fd280bc97ed3421c6546670e7052 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 9 Dec 2024 12:56:50 -0600 Subject: [PATCH 08/22] Split data query tests to be more explicit --- satpy/tests/test_dataset.py | 39 +++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index ca39bd50d9..68c8abb32e 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -635,7 +635,7 @@ def test_hash_set_equality(self): the_tuple = ("a", "b", "c") assert hash(DataQuery(name="1", some_set=the_set)) == hash(DataQuery(name="1", some_set=the_tuple)) - def test_id_filtering(self): + def test_id_filtering_name(self): """Check did filtering.""" dq = DataQuery(modifiers=tuple(), name="cheese_shops") did = DataID(self.default_id_keys_config, name="cheese_shops") @@ -644,18 +644,37 @@ def test_id_filtering(self): assert len(res) == 1 assert res[0] == did - dataid_container = [DataID(self.default_id_keys_config, - name="ds1", - resolution=250, - calibration="reflectance", - modifiers=tuple())] - dq = DataQuery(wavelength=0.22, modifiers=tuple()) - assert len(dq.filter_dataids(dataid_container)) == 0 - dataid_container = [DataID(minimal_default_keys_config, - name="natural_color")] + @pytest.mark.parametrize( + ("id_kwargs", "query_kwargs", "exp_match"), + [ + ({}, {}, 0), + ({"wavelength": (0.1, 0.2, 0.3)}, {}, 1), + ], + ) + def test_id_filtering_wavelength(self, id_kwargs, query_kwargs, exp_match): + """Test that a query on wavelength doesn't match ID without.""" + dataid_container = [ + DataID(self.default_id_keys_config, + name="ds1", + resolution=250, + calibration="reflectance", + modifiers=tuple(), + **id_kwargs, + ), + ] + dq = DataQuery(wavelength=0.22, modifiers=tuple(), **query_kwargs) + assert len(dq.filter_dataids(dataid_container)) == exp_match + + def test_id_filtering_composite_resolution(self): + """Test that a query for a composite with resolution still finds the composite.""" + dataid_container = [ + DataID(minimal_default_keys_config, name="natural_color"), + ] dq = 
DataQuery(name="natural_color", resolution=250) assert len(dq.filter_dataids(dataid_container)) == 1 + def test_id_filtering_wavelength_unrelated(self): + """Test that no name query doesn't match name-only ID.""" dq = make_dsq(wavelength=0.22, modifiers=("mod1",)) did = make_cid(name="static_image") assert len(dq.filter_dataids([did])) == 0 From ce01aa6df42ee7fc1780e0b9ed49aa7325acd361 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 9 Dec 2024 13:07:40 -0600 Subject: [PATCH 09/22] Cleanup satpy internals documentation regarding DataQuery equality --- doc/source/dev_guide/satpy_internals.rst | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/doc/source/dev_guide/satpy_internals.rst b/doc/source/dev_guide/satpy_internals.rst index 566aaf5f67..0766b8466e 100644 --- a/doc/source/dev_guide/satpy_internals.rst +++ b/doc/source/dev_guide/satpy_internals.rst @@ -137,8 +137,9 @@ DataID and DataQuery interactions Different DataIDs and DataQuerys can have different metadata items defined. As such we define equality between different instances of these classes, and across the classes as equality between the sorted key/value pairs shared between the instances. -If a DataQuery has one or more values set to `'*'`, the corresponding key/value pair will be omitted from the comparison. -Instances sharing no keys will no be equal. +If a DataQuery has one or more values set to ``'*'``, the corresponding key/value pair will be omitted from the comparison. +If any key in a DataQuery (not equal to ``"*"``) is missing from a DataID they +are considered not equal. Breaking changes from DatasetIDs @@ -151,8 +152,9 @@ Breaking changes from DatasetIDs Creating DataID for tests ========================= -Sometimes, it is useful to create `DataID` instances for testing purposes. For these cases, the `satpy.tests.utils` module -now has a `make_dsid` function that can be used just for this:: +Sometimes, it is useful to create ``DataID`` instances for testing purposes. +For these cases, the ``satpy.tests.utils`` module now has a ``make_dataid`` +function that can be used just for this:: from satpy.tests.utils import make_dataid did = make_dataid(name='camembert', modifiers=('runny',)) From ed2867a81ce2cf5519731ef373406539c2f34a20 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Fri, 13 Dec 2024 15:30:45 -0600 Subject: [PATCH 10/22] Change DataQuery equality to require all query keys to be equal Includes changes to loading compositors to a DataID with all query parameters --- satpy/dataset/data_dict.py | 20 +++-- satpy/dataset/dataid.py | 124 ++++++++++++++++----------- satpy/dependency_tree.py | 66 ++++++++++++-- satpy/node.py | 6 +- satpy/tests/scene_tests/test_load.py | 1 - satpy/tests/test_dataset.py | 91 +++++++++++++++++++- 6 files changed, 236 insertions(+), 72 deletions(-) diff --git a/satpy/dataset/data_dict.py b/satpy/dataset/data_dict.py index 783ddc4487..846a65e8b0 100644 --- a/satpy/dataset/data_dict.py +++ b/satpy/dataset/data_dict.py @@ -143,16 +143,18 @@ def get_key(self, match_key, num_results=1, best=True, **dfilter): # noqa: D417 """Get multiple fully-specified keys that match the provided query. Args: - key (DataID): DataID of query parameters to use for - searching. Any parameter that is `None` - is considered a wild card and any match is - accepted. Can also be a string representing the - dataset name or a number representing the dataset - wavelength. + match_key (DataID): DataID of query parameters to use for + searching. 
Any parameter that is `None` + is considered a wild card and any match is + accepted. Can also be a string representing the + dataset name or a number representing the dataset + wavelength. num_results (int): Number of results to return. If `0` return all, - if `1` return only that element, otherwise - return a list of matching keys. - **dfilter (dict): See `get_key` function for more information. + if `1` return only that element, otherwise + return a list of matching keys. + best (bool): Sort results to get "best" result first + (default: True). See `get_best_dataset_key` for details. + **dfilter (dict): See :func:`get_key` function for more information. """ return get_key(match_key, self.keys(), num_results=num_results, diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index 2bbd5368d8..a1e580b671 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -21,7 +21,7 @@ from contextlib import suppress from copy import copy, deepcopy from enum import Enum, IntEnum -from typing import NoReturn +from typing import Any, NoReturn import numpy as np @@ -492,10 +492,13 @@ class DataQuery: """The data query object. A DataQuery can be used in Satpy to query a dict using ``DataID`` objects - as keys. This way - a fully qualified DataID can be found even if some DataID - elements are unknown. In this case a `*` signifies something that is - unknown or not applicable to the requested Dataset. + as keys. In a plain Python builtin ``dict`` object a fully matching + ``DataQuery`` can be used to access the value of the matching ``DataID``. + Using Satpy's special :class:``~satpy.dataid.data_dict.DatasetDict`` a + ``DataQuery`` will match the closest matching ``DataID``. In this case a + ``"*"`` in the query signifies something that is unknown or not applicable + to the requested Dataset. See the ``DatasetDict`` class for more information + including retrieving all items matching a ``DataQuery``. """ def __init__(self, **kwargs): @@ -511,30 +514,74 @@ def __getitem__(self, key): def __eq__(self, other): """Compare the DataQuerys. - A DataQuery is considered equal to another DataQuery or DataID - if they have common keys that have equal values. + A DataQuery is considered equal to another DataQuery if all keys + are shared between them and are equal. A DataQuery is considered + equal to a DataID if all elements in the query are equal to those + elements in the DataID. The DataID is still considered equal if it + contains additional elements. Any DataQuery elements with the value + ``"*"`` are ignored. 
+ """ - sdict = self._to_trimmed_dict() + sdict = self._asdict() try: odict = other._asdict() except AttributeError: return False - common_keys = False - for key, val in sdict.items(): - if key not in odict: - return False - common_keys = True - if odict[key] != val: + + if not sdict and not odict: + return True + + # if other is a DataID then must match this query exactly + keys_to_match = set(sdict.keys()) + o_is_id = hasattr(other, "id_keys") + if not o_is_id: + # if another DataQuery, then compare both sets of keys + keys_to_match |= set(odict.keys()) + if not keys_to_match: + return False + + for key in keys_to_match: + if not self._compare_key_equality(sdict, odict, key, o_is_id): return False - return common_keys + return True + + @staticmethod + def _compare_key_equality(sdict: dict, odict: dict, key: str, o_is_id: bool) -> bool: + if key not in sdict: + return False + sval = sdict[key] + if sval == "*": + return True + + if key not in odict: + return False + oval = odict[key] + if oval == "*": + # Gotcha: if a DataID contains a "*" this could cause + # unexpected matches. A DataID is not expected to use "*" + return True + + if isinstance(sval, list) or isinstance(oval, list): + # multiple options to match + if not isinstance(sval, list): + # query to query comparison, make a list to iterate over + sval = [sval] + if o_is_id: + return oval in sval + + # we're matching against a DataQuery who could have its own list + if not isinstance(oval, list): + oval = [oval] + s_in_o = any(_sval in oval for _sval in sval) + o_in_s = any(_oval in sval for _oval in oval) + return s_in_o or o_in_s + return oval == sval def __hash__(self): """Hash.""" fields = [] values = [] for field, value in sorted(self._to_trimmed_dict().items()): - if value == "*": - continue fields.append(field) if isinstance(value, list): # list or tuple is ordered (ex. 
modifiers) @@ -579,31 +626,9 @@ def __repr__(self): def filter_dataids(self, dataid_container): """Filter DataIDs based on this query.""" - keys = list(filter(self._match_dataid, dataid_container)) - + keys = list(filter(self.__eq__, dataid_container)) return keys - def _match_dataid(self, dataid): - """Match the dataid with the current query.""" - if self._shares_required_keys(dataid): - keys_to_check = set(dataid.keys()) & set(self._fields) - else: - keys_to_check = set(dataid._id_keys.keys()) & set(self._fields) - if not keys_to_check: - return False - return all(self._match_query_value(key, dataid.get(key)) for key in keys_to_check) - - def _shares_required_keys(self, dataid): - """Check if dataid shares required keys with the current query.""" - for key, val in dataid._id_keys.items(): - try: - if val.get("required", False): - if key in self._fields: - return True - except AttributeError: - continue - return False - def _match_query_value(self, key, id_val): val = self._dict[key] if val == "*": @@ -734,21 +759,20 @@ def create_filtered_query(dataset_key, filter_query): return DataQuery.from_dict(ds_dict) -def _update_dict_with_filter_query(ds_dict, filter_query): +def _update_dict_with_filter_query(ds_dict: dict[str, Any], filter_query: dict[str, Any]) -> None: if filter_query is not None: for key, value in filter_query.items(): if value != "*": ds_dict.setdefault(key, value) -def _create_id_dict_from_any_key(dataset_key): - try: +def _create_id_dict_from_any_key(dataset_key: DataQuery | DataID | str | numbers.Number) -> dict[str, Any]: + if hasattr(dataset_key, "to_dict"): ds_dict = dataset_key.to_dict() - except AttributeError: - if isinstance(dataset_key, str): - ds_dict = {"name": dataset_key} - elif isinstance(dataset_key, numbers.Number): - ds_dict = {"wavelength": dataset_key} - else: - raise TypeError("Don't know how to interpret a dataset_key of type {}".format(type(dataset_key))) + elif isinstance(dataset_key, str): + ds_dict = {"name": dataset_key} + elif isinstance(dataset_key, numbers.Number): + ds_dict = {"wavelength": dataset_key} + else: + raise TypeError("Don't know how to interpret a dataset_key of type {}".format(type(dataset_key))) return ds_dict diff --git a/satpy/dependency_tree.py b/satpy/dependency_tree.py index 7c2b65a6c5..f6af1f1418 100644 --- a/satpy/dependency_tree.py +++ b/satpy/dependency_tree.py @@ -19,13 +19,16 @@ from __future__ import annotations +import warnings from typing import Container, Iterable, Optional import numpy as np +from holoviews.core.options import Compositor -from satpy import DataID, DatasetDict +from satpy import DataID, DataQuery, DatasetDict from satpy.dataset import ModifierTuple, create_filtered_query from satpy.dataset.data_dict import TooManyResults, get_key +from satpy.dataset.dataid import default_id_keys_config from satpy.node import EMPTY_LEAF_NAME, LOG, CompositorNode, MissingDependencies, Node, ReaderNode @@ -245,8 +248,9 @@ def populate_with_keys(self, dataset_keys: set, query=None): unknown_datasets = list() known_nodes = list() for key in dataset_keys.copy(): + dsq = create_filtered_query(key, query) + try: - dsq = create_filtered_query(key, query) node = self._create_subtree_for_key(dsq, query) except MissingDependencies as unknown: unknown_datasets.append(unknown.missing_dependencies) @@ -405,7 +409,7 @@ def _get_subtree_for_existing_name(self, dsq): LOG.trace("Composite already loaded:\n\tRequested: {}\n\tFound: {}".format(dsq, node.name)) return node except KeyError: - # composite hasn't been loaded yet, let's 
load it below + # composite hasn't been loaded yet, let's load it next LOG.trace("Composite hasn't been loaded yet, will load: {}".format(dsq)) raise MissingDependencies({dsq}) @@ -424,6 +428,7 @@ def _find_compositor(self, dataset_key, query): # one or more modifications if it has modifiers see if we can find # the unmodified version first + orig_query = dataset_key if dataset_key.is_modified(): implicit_dependency_node = self._create_implicit_dependency_subtree(dataset_key, query) dataset_key = self._promote_query_to_modified_dataid(dataset_key, implicit_dependency_node.name) @@ -438,10 +443,21 @@ def _find_compositor(self, dataset_key, query): except KeyError: raise KeyError("Can't find anything called {}".format(str(dataset_key))) - root = CompositorNode(compositor) + new_id_dict = compositor.id.to_dict() + new_id = None + # TODO: dataset_key could include ID parameters from composite YAML, is this different from load kwargs? + if compositor.id.to_dict() != dataset_key._asdict(): + id_keys = default_id_keys_config # minimal_default_keys_config + for query_key, query_val in dataset_key.to_dict().items(): + # XXX: What if the query_val is a list? + if new_id_dict.get(query_key) is None and query_key in id_keys and query_val != "*": + new_id_dict[query_key] = query_val + new_id = DataID(id_keys, **new_id_dict) + + root = CompositorNode(compositor, new_id=new_id) composite_id = root.name - prerequisite_filter = composite_id.create_filter_query_without_required_fields(dataset_key) + prerequisite_filter = composite_id.create_filter_query_without_required_fields(orig_query) # Get the prerequisites LOG.trace("Looking for composite prerequisites for: {}".format(dataset_key)) @@ -488,7 +504,7 @@ def _promote_query_to_modified_dataid(self, query, dep_key): orig_dict[key] = dep_val return dep_key.from_dict(orig_dict) - def get_compositor(self, key): + def get_compositor(self, key: DataQuery): """Get a compositor.""" for sensor_name in sorted(self.compositors): try: @@ -496,6 +512,44 @@ def get_compositor(self, key): except KeyError: continue + if key.get("name", default="*") != "*" and len(key.to_dict()) == 1: + # the query key is just the name and still couldn't be found + raise KeyError("Could not find compositor '{}'".format(key)) + + # Get the generic version of the compositor (by name only) + # then save our new version under the new name + + return self._get_compositor_by_name(key) + + def _get_compositor_by_name(self, key: DataQuery) -> Compositor | None: + name_query = DataQuery(name=key["name"]) + for sensor_name in sorted(self.compositors): + sensor_data_dict = self.compositors[sensor_name] + try: + # get all IDs that have the minimum "distance" for our composite name + all_comp_ids = sensor_data_dict.get_key(name_query, num_results=0) + # Filter to those that don't disagree with the original query + matching_comp_ids = [] + for comp_id in all_comp_ids: + for query_key, query_val in key.to_dict().items(): + # TODO: Handle query_vals that are lists + if comp_id.get(query_key, query_val) != query_val: + break + else: + # all query keys match + matching_comp_ids.append(comp_id) + if len(matching_comp_ids) > 1: + warnings.warn("Multiple compositors matching {name_query} to create {key} variant. 
" + "Going to use the name-only 'base' compositor definition.") + matching_comp_ids = matching_comp_ids[:1] + except KeyError: + continue + + if len(matching_comp_ids) != 1: + raise KeyError("Can't find compositor {key['name']} by name only.") + comp_id = matching_comp_ids[0] + # should use the "short-circuit" path and find the exact name-only compositor by DataID + return sensor_data_dict[comp_id] raise KeyError("Could not find compositor '{}'".format(key)) def get_modifier(self, comp_id): diff --git a/satpy/node.py b/satpy/node.py index 191ec0bbcf..855af3fe05 100644 --- a/satpy/node.py +++ b/satpy/node.py @@ -158,9 +158,11 @@ def trunk(self, unique=True, limit_children_to=None): class CompositorNode(Node): """Implementation of a compositor-specific node.""" - def __init__(self, compositor): + def __init__(self, compositor, new_id=None): """Set up the node.""" - super().__init__(compositor.id, data=(compositor, [], [])) + if new_id is None: + new_id = compositor.id + super().__init__(new_id, data=(compositor, [], [])) def add_required_nodes(self, children): """Add nodes to the required field.""" diff --git a/satpy/tests/scene_tests/test_load.py b/satpy/tests/scene_tests/test_load.py index 889d9e2cbe..7adf237c57 100644 --- a/satpy/tests/scene_tests/test_load.py +++ b/satpy/tests/scene_tests/test_load.py @@ -14,7 +14,6 @@ # You should have received a copy of the GNU General Public License along with # satpy. If not, see . """Unit tests for loading-related functionality in scene.py.""" - from unittest import mock import pytest diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index 68c8abb32e..f4b838590e 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -20,6 +20,7 @@ import numpy as np import pytest +import xarray as xr from satpy.dataset.dataid import DataID, DataQuery, ModifierTuple, WavelengthRange, minimal_default_keys_config from satpy.dataset.metadata import combine_metadata @@ -616,7 +617,7 @@ def test_hash_equality(self): def test_hash_wildcard_equality(self): """Test hashes are equal with or without wildcards.""" - assert DataQuery(name="1", resolution="*") == DataQuery(name="1") + assert hash(DataQuery(name="1", resolution="*")) == hash(DataQuery(name="1")) @pytest.mark.parametrize( "modifiers", @@ -649,6 +650,7 @@ def test_id_filtering_name(self): [ ({}, {}, 0), ({"wavelength": (0.1, 0.2, 0.3)}, {}, 1), + ({}, {"name": "ds1"}, 0), ], ) def test_id_filtering_wavelength(self, id_kwargs, query_kwargs, exp_match): @@ -671,10 +673,10 @@ def test_id_filtering_composite_resolution(self): DataID(minimal_default_keys_config, name="natural_color"), ] dq = DataQuery(name="natural_color", resolution=250) - assert len(dq.filter_dataids(dataid_container)) == 1 + assert len(dq.filter_dataids(dataid_container)) == 0 - def test_id_filtering_wavelength_unrelated(self): - """Test that no name query doesn't match name-only ID.""" + def test_id_filtering_unrelated(self): + """Test that a query doesn't match an ID with no matching keys.""" dq = make_dsq(wavelength=0.22, modifiers=("mod1",)) did = make_cid(name="static_image") assert len(dq.filter_dataids([did])) == 0 @@ -685,6 +687,66 @@ def test_equality_no_modifiers(self): assert data_id["modifiers"] == tuple() assert DataQuery(name="1", resolution=500) == data_id + @pytest.mark.parametrize( + ("dq1", "dq2"), + [ + (DataQuery(name="1"), DataQuery(name="1")), + (DataQuery(name="1", resolution="*"), DataQuery(name="1")), + (DataQuery(name="1", resolution="*"), DataQuery(name="1", resolution=500)), + 
(DataQuery(name="1", resolution=500), DataQuery(name="1", resolution="*")), # opposite order + (DataQuery(), DataQuery()), + (DataQuery(name="1", resolution=[250, 500]), DataQuery(name="1", resolution=[500, 750])), # opposite order + (DataQuery(name="1", resolution=500), DataQuery(name="1", resolution=[500, 750])), # opposite order + (DataQuery(name="1", resolution=[250, 500]), DataQuery(name="1", resolution=500)), # opposite order + ], + ) + def test_equality_queries(self, dq1, dq2): + """Test various query to query comparisons.""" + assert dq1 == dq2 + + @pytest.mark.parametrize( + ("dq1", "dq2"), + [ + (DataQuery(name="1", resolution=[250, 500]), DataQuery(name="1", resolution=[750, 1000])), # opposite order + (DataQuery(name="1"), DataQuery(name="1", resolution=750)), # opposite order + ], + ) + def test_inequality_queries(self, dq1, dq2): + """Test various query to query inequality cases.""" + assert dq1 != dq2 + + @pytest.mark.parametrize( + ("dq", "id_dict"), + [ + (DataQuery(name="1"), dict(name="1")), + (DataQuery(name="1", resolution="*"), dict(name="1")), + (DataQuery(name="1", resolution="*"), dict(name="1", resolution=500)), + # DataID shouldn't use * but we still test it: + (DataQuery(name="1", resolution=500), dict(name="1", resolution="*")), + (DataQuery(), dict()), # probably not useful, but it is a case + (DataQuery(name="1", resolution=[250, 500]), dict(name="1", resolution=500)), + ], + ) + def test_equality_ids(self, dq, id_dict): + """Test various query to DataID equality cases.""" + assert dq == DataID(self.default_id_keys_config, **id_dict) + + @pytest.mark.parametrize( + ("dq", "id_dict"), + [ + (DataQuery(name="1", resolution=500), dict(name="1", resolution=None)), + (DataQuery(), dict(name="1", resolution=1000)), + (DataQuery(name="1", resolution=[250, 500]), dict(name="1", resolution=1000)), + ], + ) + def test_inequality_ids(self, dq, id_dict): + """Test various query to DataID inequality cases.""" + assert dq != DataID(self.default_id_keys_config, **id_dict) + + def test_inequality_unknown_type(self): + """Test equality against non ID/query type.""" + assert DataQuery(name="1") != "1" + def test_inequality_missing_keys(self): """Check inequality against a DataID missing a query parameter.""" assert DataQuery(name="1", resolution=500) != DataID(self.default_id_keys_config, name="1") @@ -1029,3 +1091,24 @@ def test_wavelength_range_cf_roundtrip(): assert WavelengthRange.from_cf(wr.to_cf()) == wr assert WavelengthRange.from_cf([str(item) for item in wr]) == wr + + +def test_dataset_dict_contains_inexact_match(): + """Test that DatasetDict does not match inexact keys to existing keys. + + Specifically, check that additional DataID properties aren't ignored + when querying the DatasetDict. + + See https://github.com/pytroll/satpy/issues/2331. 
+ """ + from satpy.dataset.data_dict import DatasetDict + + dd = DatasetDict() + name = "1" + item = xr.DataArray(()) + dd[name] = item + exact_id = DataQuery(name=name) + assert exact_id in dd + + inexact_id = DataQuery(name=name, resolution=2000) + assert inexact_id not in dd From d223f101388020f22473ee2aaeae644585042eda Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 16 Dec 2024 10:52:40 -0600 Subject: [PATCH 11/22] Add test querying for a wavelength on no-wavelength DataIDs Closes #1806 Closes #1807 --- satpy/tests/test_dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index f4b838590e..eb693039e1 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -755,6 +755,19 @@ def test_inequality_diff_required_keys(self): """Check (in)equality.""" assert DataQuery(wavelength=10) != DataID(self.default_id_keys_config, name="VIS006") + def test_id_filtering_no_id_wavelength(self): + """Test that a DataID with no wavelength doesn't match a query for a wavelength.""" + did_keys = { + "name": {"required": True}, + "level": {}, + "modifiers": {"default": [], "type": ModifierTuple} + } + did1 = DataID(did_keys, name="test1") + did2 = DataID(did_keys, name="test2") + dq = DataQuery(wavelength=1.8, modifiers=()) + matched_ids = dq.filter_dataids([did1, did2]) + assert len(matched_ids) == 0 + def test_sort_dataids(self): """Check dataid sorting.""" dq = DataQuery(name="cheese_shops", wavelength=2, modifiers="*") From aba5dc558110f1b7eed9f281d1976df811df953f Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 16 Dec 2024 12:39:44 -0600 Subject: [PATCH 12/22] Merge combine times test cases --- satpy/tests/test_dataset.py | 121 ++++++++++++++++++------------------ 1 file changed, 60 insertions(+), 61 deletions(-) diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index eb693039e1..f6fb1c6ebb 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -114,67 +114,66 @@ def test_average_datetimes(self): ret = average_datetimes(dts) assert dts[2] == ret - def test_combine_start_times(self): - """Test the combine_metadata with start times.""" - # The times need to be in ascending order (oldest first) - start_time_dts = ( - {"start_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, - {"start_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, - {"start_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, - {"start_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, - {"start_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, - ) - ret = combine_metadata(*start_time_dts) - assert ret["start_time"] == start_time_dts[0]["start_time"] - - def test_combine_end_times(self): - """Test the combine_metadata with end times.""" - # The times need to be in ascending order (oldest first) - end_time_dts = ( - {"end_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, - ) - ret = combine_metadata(*end_time_dts) - assert ret["end_time"] == end_time_dts[-1]["end_time"] - - def test_combine_start_times_with_none(self): - """Test the combine_metadata with start times when there's a None included.""" - start_time_dts_with_none = ( - {"start_time": None}, - {"start_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, - {"start_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, - {"start_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, - {"start_time": 
dt.datetime(2018, 2, 1, 12, 2, 0)}, - ) - ret = combine_metadata(*start_time_dts_with_none) - assert ret["start_time"] == start_time_dts_with_none[1]["start_time"] - - def test_combine_end_times_with_none(self): - """Test the combine_metadata with end times when there's a None included.""" - end_time_dts_with_none = ( - {"end_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, - {"end_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, - {"end_time": None}, - ) - ret = combine_metadata(*end_time_dts_with_none) - assert ret["end_time"] == end_time_dts_with_none[-2]["end_time"] - - def test_combine_other_times(self): - """Test the combine_metadata with other time values than start or end times.""" - other_time_dts = ( - {"other_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, - {"other_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, - ) - ret = combine_metadata(*other_time_dts) - assert ret["other_time"] == other_time_dts[2]["other_time"] + @pytest.mark.parametrize( + ("meta_dicts", "key", "result_idx"), + [ + ( + # The times need to be in ascending order (oldest first) + ({"start_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, + ), + "start_time", + 0, + ), + ( + ({"end_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, + ), + "end_time", + -1, + ), + ( + ({"start_time": None}, + {"start_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, + {"start_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, + ), + "start_time", + 1, + ), + ( + ({"end_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, + {"end_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, + {"end_time": None}, + ), + "end_time", + -2, + ), + ( + ({"other_time": dt.datetime(2018, 2, 1, 11, 58, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 11, 59, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 12, 0, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 12, 1, 0)}, + {"other_time": dt.datetime(2018, 2, 1, 12, 2, 0)}, + ), + "other_time", + 2, + ), + ], + ) + def test_combine_times(self, meta_dicts, key, result_idx): + """Test the combine_metadata with times.""" + ret = combine_metadata(*meta_dicts) + assert ret[key] == meta_dicts[result_idx][key] def test_combine_arrays(self): """Test the combine_metadata with arrays.""" From 24a1068889b97efbe7fa1bb75eb21b3b55e69af6 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 16 Dec 2024 13:43:11 -0600 Subject: [PATCH 13/22] Refactor ID key types to separate module --- satpy/composites/__init__.py | 2 +- satpy/composites/config_loader.py | 2 +- satpy/dataset/__init__.py | 3 +- satpy/dataset/anc_vars.py | 3 +- satpy/dataset/data_dict.py | 3 +- satpy/dataset/dataid.py | 256 +----------------- satpy/dataset/id_keys.py | 259 +++++++++++++++++++ satpy/dependency_tree.py | 2 +- 
satpy/etc/readers/msi_safe_l2a.yaml | 2 +- satpy/etc/readers/sgli_l1b.yaml | 2 +- satpy/etc/readers/slstr_l1b.yaml | 2 +- satpy/modifiers/_crefl_utils.py | 2 +- satpy/readers/satpy_cf_nc.py | 2 +- satpy/readers/yaml_reader.py | 2 +- satpy/tests/modifier_tests/test_parallax.py | 2 +- satpy/tests/multiscene_tests/test_utils.py | 2 +- satpy/tests/reader_tests/test_msi_safe.py | 3 +- satpy/tests/reader_tests/test_satpy_cf_nc.py | 11 +- satpy/tests/reader_tests/test_slstr_l1b.py | 3 +- satpy/tests/scene_tests/test_data_access.py | 2 +- satpy/tests/scene_tests/test_load.py | 2 +- satpy/tests/scene_tests/test_resampling.py | 2 +- satpy/tests/test_dataset.py | 56 +--- satpy/tests/test_readers.py | 3 +- satpy/tests/test_yaml_reader.py | 3 +- satpy/tests/utils.py | 2 +- 26 files changed, 311 insertions(+), 322 deletions(-) create mode 100644 satpy/dataset/id_keys.py diff --git a/satpy/composites/__init__.py b/satpy/composites/__init__.py index d7518be91d..7bd6ea3d0e 100644 --- a/satpy/composites/__init__.py +++ b/satpy/composites/__init__.py @@ -29,7 +29,7 @@ import satpy from satpy.aux_download import DataDownloadMixin from satpy.dataset import DataID, combine_metadata -from satpy.dataset.dataid import minimal_default_keys_config +from satpy.dataset.id_keys import minimal_default_keys_config from satpy.utils import unify_chunks from satpy.writers import get_enhanced_image diff --git a/satpy/composites/config_loader.py b/satpy/composites/config_loader.py index bffbee8a13..f1e8d9b821 100644 --- a/satpy/composites/config_loader.py +++ b/satpy/composites/config_loader.py @@ -30,7 +30,7 @@ import satpy from satpy import DataID, DataQuery from satpy._config import config_search_paths, get_entry_points_config_dirs, glob_config -from satpy.dataset.dataid import minimal_default_keys_config +from satpy.dataset.id_keys import minimal_default_keys_config from satpy.utils import recursive_dict_update logger = logging.getLogger(__name__) diff --git a/satpy/dataset/__init__.py b/satpy/dataset/__init__.py index 33978048b2..1493c652b2 100644 --- a/satpy/dataset/__init__.py +++ b/satpy/dataset/__init__.py @@ -19,5 +19,6 @@ from .anc_vars import dataset_walker, replace_anc # noqa from .data_dict import DatasetDict, get_key # noqa -from .dataid import DataID, DataQuery, ModifierTuple, WavelengthRange, create_filtered_query # noqa +from .dataid import DataID, DataQuery, create_filtered_query # noqa +from .id_keys import ModifierTuple, WavelengthRange # noqa from .metadata import combine_metadata # noqa diff --git a/satpy/dataset/anc_vars.py b/satpy/dataset/anc_vars.py index 90b2d7bd3c..4a092137d9 100644 --- a/satpy/dataset/anc_vars.py +++ b/satpy/dataset/anc_vars.py @@ -17,7 +17,8 @@ # satpy. If not, see . 
"""Utilities for dealing with ancillary variables.""" -from .dataid import DataID, default_id_keys_config +from .dataid import DataID +from .id_keys import default_id_keys_config def dataset_walker(datasets): diff --git a/satpy/dataset/data_dict.py b/satpy/dataset/data_dict.py index 846a65e8b0..964df8d7e0 100644 --- a/satpy/dataset/data_dict.py +++ b/satpy/dataset/data_dict.py @@ -19,7 +19,8 @@ import numpy as np -from .dataid import DataID, create_filtered_query, minimal_default_keys_config +from .dataid import DataID, create_filtered_query +from .id_keys import minimal_default_keys_config class TooManyResults(KeyError): diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index a1e580b671..1f17568ad1 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -17,265 +17,15 @@ import logging import numbers -from collections import namedtuple -from contextlib import suppress from copy import copy, deepcopy -from enum import Enum, IntEnum +from enum import Enum from typing import Any, NoReturn import numpy as np -logger = logging.getLogger(__name__) - - -def get_keys_from_config(common_id_keys, config): - """Gather keys for a new DataID from the ones available in configured dataset.""" - id_keys = {} - for key, val in common_id_keys.items(): - if key in config: - id_keys[key] = val - elif val is not None and (val.get("required") is True or val.get("default") is not None): - id_keys[key] = val - if not id_keys: - raise ValueError("Metadata does not contain enough information to create a DataID.") - return id_keys - - -class ValueList(IntEnum): - """A static value list. - - This class is meant to be used for dynamically created Enums. Due to this - it should not be used as a normal Enum class or there may be some - unexpected behavior. For example, this class contains custom pickling and - unpickling handling that may break in subclasses. - - """ - - @classmethod - def convert(cls, value): - """Convert value to an instance of this class.""" - try: - return cls[value] - except KeyError: - raise ValueError("{} invalid value for {}".format(value, cls)) - - @classmethod - def _unpickle(cls, enum_name, enum_members, enum_member): - """Create dynamic class that was previously pickled. - - See :meth:`__reduce_ex__` for implementation details. - - """ - enum_cls = cls(enum_name, enum_members) - return enum_cls[enum_member] - - def __reduce_ex__(self, proto): - """Reduce the object for pickling.""" - return (ValueList._unpickle, - (self.__class__.__name__, list(self.__class__.__members__.keys()), self.name)) - - def __eq__(self, other): - """Check equality.""" - return self.name == other - - def __ne__(self, other): - """Check non-equality.""" - return self.name != other - - def __hash__(self): - """Hash the object.""" - return hash(self.name) - - def __repr__(self): - """Represent the values.""" - return "<" + str(self) + ">" - - -wlklass = namedtuple("WavelengthRange", "min central max unit", defaults=("µm",)) # type: ignore - - -class WavelengthRange(wlklass): - """A named tuple for wavelength ranges. - - The elements of the range are min, central and max values, and optionally a unit - (defaults to µm). No clever unit conversion is done here, it's just used for checking - that two ranges are comparable. - """ - - def __eq__(self, other): - """Return if two wavelengths are equal. 
- - Args: - other (tuple or scalar): (min wl, nominal wl, max wl) or scalar wl +from satpy.dataset.id_keys import ModifierTuple, ValueList, minimal_default_keys_config - Return: - True if other is a scalar and min <= other <= max, or if other is - a tuple equal to self, False otherwise. - - """ - if other is None: - return False - if isinstance(other, numbers.Number): - return other in self - if isinstance(other, (tuple, list)) and len(other) == 3: - return self[:3] == other - return super().__eq__(other) - - def __ne__(self, other): - """Return the opposite of `__eq__`.""" - return not self == other - - def __lt__(self, other): - """Compare to another wavelength.""" - if other is None: - return False - return super().__lt__(other) - - def __gt__(self, other): - """Compare to another wavelength.""" - if other is None: - return True - return super().__gt__(other) - - def __hash__(self): - """Hash this tuple.""" - return tuple.__hash__(self) - - def __str__(self): - """Format for print out.""" - return "{0.central} {0.unit} ({0.min}-{0.max} {0.unit})".format(self) - - def __contains__(self, other): - """Check if this range contains *other*.""" - if other is None: - return False - if isinstance(other, numbers.Number): - return self.min <= other <= self.max - with suppress(AttributeError): - if self.unit != other.unit: - raise NotImplementedError("Can't compare wavelength ranges with different units.") - return self.min <= other.min and self.max >= other.max - return False - - def distance(self, value): - """Get the distance from value.""" - if self == value: - try: - return abs(value.central - self.central) - except AttributeError: - if isinstance(value, (tuple, list)): - return abs(value[1] - self.central) - return abs(value - self.central) - else: - return np.inf - - @classmethod - def convert(cls, wl): - """Convert `wl` to this type if possible.""" - if isinstance(wl, (tuple, list)): - return cls(*wl) - return wl - - def to_cf(self): - """Serialize for cf export.""" - return str(self) - - @classmethod - def from_cf(cls, blob): - """Return a WavelengthRange from a cf blob.""" - try: - obj = cls._read_cf_from_string_export(blob) - except TypeError: - obj = cls._read_cf_from_string_list(blob) - return obj - - @classmethod - def _read_cf_from_string_export(cls, blob): - """Read blob as a string created by `to_cf`.""" - pattern = "{central:f} {unit:s} ({min:f}-{max:f} {unit2:s})" - from trollsift import Parser - parser = Parser(pattern) - res_dict = parser.parse(blob) - res_dict.pop("unit2") - obj = cls(**res_dict) - return obj - - @classmethod - def _read_cf_from_string_list(cls, blob): - """Read blob as a list of strings (legacy formatting).""" - min_wl, central_wl, max_wl, unit = blob - obj = cls(float(min_wl), float(central_wl), float(max_wl), unit) - return obj - - -class ModifierTuple(tuple): - """A tuple holder for modifiers.""" - - @classmethod - def convert(cls, modifiers): - """Convert `modifiers` to this type if possible.""" - if modifiers is None: - return None - if not isinstance(modifiers, (cls, tuple, list)): - raise TypeError("'DataID' modifiers must be a tuple or None, " - "not {}".format(type(modifiers))) - return cls(modifiers) - - def __eq__(self, other): - """Check equality.""" - if isinstance(other, list): - other = tuple(other) - return super().__eq__(other) - - def __ne__(self, other): - """Check non-equality.""" - if isinstance(other, list): - other = tuple(other) - return super().__ne__(other) - - def __hash__(self): - """Hash this tuple.""" - return 
tuple.__hash__(self) - - -#: Default ID keys DataArrays. -default_id_keys_config = { - "name": { - "required": True, - }, - "wavelength": { - "type": WavelengthRange, - }, - "resolution": { - "transitive": False, - }, - "calibration": { - "enum": [ - "reflectance", - "brightness_temperature", - "radiance", - "radiance_wavenumber", - "counts", - ], - "transitive": True, - }, - "modifiers": { - "default": ModifierTuple(), - "type": ModifierTuple, - }, -} - -#: Default ID keys for coordinate DataArrays. -default_co_keys_config = { - "name": default_id_keys_config["name"], - "resolution": default_id_keys_config["resolution"], -} - -#: Minimal ID keys for DataArrays, for example composites. -minimal_default_keys_config = { - "name": default_id_keys_config["name"], - "resolution": default_id_keys_config["resolution"], -} +logger = logging.getLogger(__name__) class DataID(dict): diff --git a/satpy/dataset/id_keys.py b/satpy/dataset/id_keys.py new file mode 100644 index 0000000000..46c16ea620 --- /dev/null +++ b/satpy/dataset/id_keys.py @@ -0,0 +1,259 @@ +"""Default ID key sets and types for DataID keys.""" +import numbers +from collections import namedtuple +from contextlib import suppress +from enum import IntEnum + +import numpy as np + + +def get_keys_from_config(common_id_keys, config): + """Gather keys for a new DataID from the ones available in configured dataset.""" + id_keys = {} + for key, val in common_id_keys.items(): + if key in config: + id_keys[key] = val + elif val is not None and (val.get("required") is True or val.get("default") is not None): + id_keys[key] = val + if not id_keys: + raise ValueError("Metadata does not contain enough information to create a DataID.") + return id_keys + + +class ValueList(IntEnum): + """A static value list. + + This class is meant to be used for dynamically created Enums. Due to this + it should not be used as a normal Enum class or there may be some + unexpected behavior. For example, this class contains custom pickling and + unpickling handling that may break in subclasses. + + """ + + @classmethod + def convert(cls, value): + """Convert value to an instance of this class.""" + try: + return cls[value] + except KeyError: + raise ValueError("{} invalid value for {}".format(value, cls)) + + @classmethod + def _unpickle(cls, enum_name, enum_members, enum_member): + """Create dynamic class that was previously pickled. + + See :meth:`__reduce_ex__` for implementation details. + + """ + enum_cls = cls(enum_name, enum_members) + return enum_cls[enum_member] + + def __reduce_ex__(self, proto): + """Reduce the object for pickling.""" + return (ValueList._unpickle, + (self.__class__.__name__, list(self.__class__.__members__.keys()), self.name)) + + def __eq__(self, other): + """Check equality.""" + return self.name == other + + def __ne__(self, other): + """Check non-equality.""" + return self.name != other + + def __hash__(self): + """Hash the object.""" + return hash(self.name) + + def __repr__(self): + """Represent the values.""" + return "<" + str(self) + ">" + + +wlklass = namedtuple("WavelengthRange", "min central max unit", defaults=("µm",)) # type: ignore + + +class WavelengthRange(wlklass): + """A named tuple for wavelength ranges. + + The elements of the range are min, central and max values, and optionally a unit + (defaults to µm). No clever unit conversion is done here, it's just used for checking + that two ranges are comparable. + """ + + def __eq__(self, other): + """Return if two wavelengths are equal. 
+ + Args: + other (tuple or scalar): (min wl, nominal wl, max wl) or scalar wl + + Return: + True if other is a scalar and min <= other <= max, or if other is + a tuple equal to self, False otherwise. + + """ + if other is None: + return False + if isinstance(other, numbers.Number): + return other in self + if isinstance(other, (tuple, list)) and len(other) == 3: + return self[:3] == other + return super().__eq__(other) + + def __ne__(self, other): + """Return the opposite of `__eq__`.""" + return not self == other + + def __lt__(self, other): + """Compare to another wavelength.""" + if other is None: + return False + return super().__lt__(other) + + def __gt__(self, other): + """Compare to another wavelength.""" + if other is None: + return True + return super().__gt__(other) + + def __hash__(self): + """Hash this tuple.""" + return tuple.__hash__(self) + + def __str__(self): + """Format for print out.""" + return "{0.central} {0.unit} ({0.min}-{0.max} {0.unit})".format(self) + + def __contains__(self, other): + """Check if this range contains *other*.""" + if other is None: + return False + if isinstance(other, numbers.Number): + return self.min <= other <= self.max + with suppress(AttributeError): + if self.unit != other.unit: + raise NotImplementedError("Can't compare wavelength ranges with different units.") + return self.min <= other.min and self.max >= other.max + return False + + def distance(self, value): + """Get the distance from value.""" + if self == value: + try: + return abs(value.central - self.central) + except AttributeError: + if isinstance(value, (tuple, list)): + return abs(value[1] - self.central) + return abs(value - self.central) + else: + return np.inf + + @classmethod + def convert(cls, wl): + """Convert `wl` to this type if possible.""" + if isinstance(wl, (tuple, list)): + return cls(*wl) + return wl + + def to_cf(self): + """Serialize for cf export.""" + return str(self) + + @classmethod + def from_cf(cls, blob): + """Return a WavelengthRange from a cf blob.""" + try: + obj = cls._read_cf_from_string_export(blob) + except TypeError: + obj = cls._read_cf_from_string_list(blob) + return obj + + @classmethod + def _read_cf_from_string_export(cls, blob): + """Read blob as a string created by `to_cf`.""" + pattern = "{central:f} {unit:s} ({min:f}-{max:f} {unit2:s})" + from trollsift import Parser + parser = Parser(pattern) + res_dict = parser.parse(blob) + res_dict.pop("unit2") + obj = cls(**res_dict) + return obj + + @classmethod + def _read_cf_from_string_list(cls, blob): + """Read blob as a list of strings (legacy formatting).""" + min_wl, central_wl, max_wl, unit = blob + obj = cls(float(min_wl), float(central_wl), float(max_wl), unit) + return obj + + +class ModifierTuple(tuple): + """A tuple holder for modifiers.""" + + @classmethod + def convert(cls, modifiers): + """Convert `modifiers` to this type if possible.""" + if modifiers is None: + return None + if not isinstance(modifiers, (cls, tuple, list)): + raise TypeError("'DataID' modifiers must be a tuple or None, " + "not {}".format(type(modifiers))) + return cls(modifiers) + + def __eq__(self, other): + """Check equality.""" + if isinstance(other, list): + other = tuple(other) + return super().__eq__(other) + + def __ne__(self, other): + """Check non-equality.""" + if isinstance(other, list): + other = tuple(other) + return super().__ne__(other) + + def __hash__(self): + """Hash this tuple.""" + return tuple.__hash__(self) + + +#: Default ID keys DataArrays. 
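+#: Each key maps to its configuration options (required, type, default, enum, transitive).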
+default_id_keys_config = { + "name": { + "required": True, + }, + "wavelength": { + "type": WavelengthRange, + }, + "resolution": { + "transitive": False, + }, + "calibration": { + "enum": [ + "reflectance", + "brightness_temperature", + "radiance", + "radiance_wavenumber", + "counts", + ], + "transitive": True, + }, + "modifiers": { + "default": ModifierTuple(), + "type": ModifierTuple, + }, +} + + +#: Default ID keys for coordinate DataArrays. +default_co_keys_config = { + "name": default_id_keys_config["name"], + "resolution": default_id_keys_config["resolution"], +} + + +#: Minimal ID keys for DataArrays, for example composites. +minimal_default_keys_config = { + "name": default_id_keys_config["name"], + "resolution": default_id_keys_config["resolution"], +} diff --git a/satpy/dependency_tree.py b/satpy/dependency_tree.py index f6af1f1418..7e0ba442b4 100644 --- a/satpy/dependency_tree.py +++ b/satpy/dependency_tree.py @@ -28,7 +28,7 @@ from satpy import DataID, DataQuery, DatasetDict from satpy.dataset import ModifierTuple, create_filtered_query from satpy.dataset.data_dict import TooManyResults, get_key -from satpy.dataset.dataid import default_id_keys_config +from satpy.dataset.id_keys import default_id_keys_config from satpy.node import EMPTY_LEAF_NAME, LOG, CompositorNode, MissingDependencies, Node, ReaderNode diff --git a/satpy/etc/readers/msi_safe_l2a.yaml b/satpy/etc/readers/msi_safe_l2a.yaml index f4c6e4221a..cadb303f6a 100644 --- a/satpy/etc/readers/msi_safe_l2a.yaml +++ b/satpy/etc/readers/msi_safe_l2a.yaml @@ -12,7 +12,7 @@ reader: name: required: true wavelength: - type: !!python/name:satpy.dataset.dataid.WavelengthRange + type: !!python/name:satpy.dataset.id_keys.WavelengthRange resolution: transitive: false calibration: diff --git a/satpy/etc/readers/sgli_l1b.yaml b/satpy/etc/readers/sgli_l1b.yaml index 4cb86890c4..295ba55f72 100644 --- a/satpy/etc/readers/sgli_l1b.yaml +++ b/satpy/etc/readers/sgli_l1b.yaml @@ -14,7 +14,7 @@ reader: name: required: true wavelength: - type: !!python/name:satpy.dataset.dataid.WavelengthRange + type: !!python/name:satpy.dataset.id_keys.WavelengthRange polarization: transitive: true resolution: diff --git a/satpy/etc/readers/slstr_l1b.yaml b/satpy/etc/readers/slstr_l1b.yaml index 85c875cca2..b3b3369ba0 100644 --- a/satpy/etc/readers/slstr_l1b.yaml +++ b/satpy/etc/readers/slstr_l1b.yaml @@ -13,7 +13,7 @@ reader: name: required: true wavelength: - type: !!python/name:satpy.dataset.dataid.WavelengthRange + type: !!python/name:satpy.dataset.id_keys.WavelengthRange resolution: transitive: false calibration: diff --git a/satpy/modifiers/_crefl_utils.py b/satpy/modifiers/_crefl_utils.py index b8a1d52a4b..5d547e6b0a 100644 --- a/satpy/modifiers/_crefl_utils.py +++ b/satpy/modifiers/_crefl_utils.py @@ -69,7 +69,7 @@ import numpy as np import xarray as xr -from satpy.dataset.dataid import WavelengthRange +from satpy.dataset import WavelengthRange LOG = logging.getLogger(__name__) diff --git a/satpy/readers/satpy_cf_nc.py b/satpy/readers/satpy_cf_nc.py index 9f742272a1..9928361301 100644 --- a/satpy/readers/satpy_cf_nc.py +++ b/satpy/readers/satpy_cf_nc.py @@ -183,7 +183,7 @@ from pyresample import AreaDefinition import satpy.cf.decoding -from satpy.dataset.dataid import WavelengthRange +from satpy.dataset import WavelengthRange from satpy.readers.file_handlers import BaseFileHandler from satpy.utils import get_legacy_chunk_size diff --git a/satpy/readers/yaml_reader.py b/satpy/readers/yaml_reader.py index 5bbaba4a6c..35ccca5863 100644 --- 
a/satpy/readers/yaml_reader.py +++ b/satpy/readers/yaml_reader.py @@ -44,7 +44,7 @@ from satpy._compat import cache from satpy.aux_download import DataDownloadMixin from satpy.dataset import DataID, DataQuery, get_key -from satpy.dataset.dataid import default_co_keys_config, default_id_keys_config, get_keys_from_config +from satpy.dataset.id_keys import default_co_keys_config, default_id_keys_config, get_keys_from_config from satpy.resample import add_crs_xy_coords, get_area_def from satpy.utils import recursive_dict_update diff --git a/satpy/tests/modifier_tests/test_parallax.py b/satpy/tests/modifier_tests/test_parallax.py index 63ddbd8caf..70714058ef 100644 --- a/satpy/tests/modifier_tests/test_parallax.py +++ b/satpy/tests/modifier_tests/test_parallax.py @@ -731,7 +731,7 @@ def conf_file(self, yaml_code, tmp_path): def fake_scene(self, yaml_code): """Produce fake scene and prepare fake composite config.""" from satpy import Scene - from satpy.dataset.dataid import WavelengthRange + from satpy.dataset import WavelengthRange from satpy.tests.utils import make_dataid area = _get_fake_areas((0, 0), [5], 1)[0] diff --git a/satpy/tests/multiscene_tests/test_utils.py b/satpy/tests/multiscene_tests/test_utils.py index 310d68c215..6829d6f624 100644 --- a/satpy/tests/multiscene_tests/test_utils.py +++ b/satpy/tests/multiscene_tests/test_utils.py @@ -26,7 +26,7 @@ import xarray as xr from pyresample.geometry import AreaDefinition -from satpy.dataset.dataid import ModifierTuple, WavelengthRange +from satpy.dataset import ModifierTuple, WavelengthRange DEFAULT_SHAPE = (5, 10) diff --git a/satpy/tests/reader_tests/test_msi_safe.py b/satpy/tests/reader_tests/test_msi_safe.py index 1f2e603ee2..487be47b59 100644 --- a/satpy/tests/reader_tests/test_msi_safe.py +++ b/satpy/tests/reader_tests/test_msi_safe.py @@ -1456,7 +1456,8 @@ def jp2_builder(process_level, band_name, mask_saturated=True, test_l1b=False): def make_alt_dataid(**items): """Make a DataID with modified keys.""" - from satpy.dataset.dataid import DataID, ModifierTuple, WavelengthRange + from satpy.dataset import ModifierTuple, WavelengthRange + from satpy.dataset.dataid import DataID modified_id_keys_config = { "name": { "required": True, diff --git a/satpy/tests/reader_tests/test_satpy_cf_nc.py b/satpy/tests/reader_tests/test_satpy_cf_nc.py index 56acccdeb9..94b16e3faa 100644 --- a/satpy/tests/reader_tests/test_satpy_cf_nc.py +++ b/satpy/tests/reader_tests/test_satpy_cf_nc.py @@ -27,7 +27,7 @@ from pyresample import AreaDefinition, SwathDefinition from satpy import Scene -from satpy.dataset.dataid import WavelengthRange +from satpy.dataset import WavelengthRange from satpy.readers.satpy_cf_nc import SatpyCFFileHandler # NOTE: @@ -477,7 +477,8 @@ def test_write_and_read_from_two_files(self, nc_filename, nc_filename_i): def test_dataid_attrs_equal_matching_dataset(self, cf_scene, nc_filename): """Check that get_dataset returns valid dataset when keys matches.""" - from satpy.dataset.dataid import DataID, default_id_keys_config + from satpy.dataset.dataid import DataID + from satpy.dataset.id_keys import default_id_keys_config _create_test_netcdf(nc_filename, resolution=742) reader = SatpyCFFileHandler(nc_filename, {}, {"filetype": "info"}) ds_id = DataID(default_id_keys_config, name="solar_zenith_angle", resolution=742, modifiers=()) @@ -486,7 +487,8 @@ def test_dataid_attrs_equal_matching_dataset(self, cf_scene, nc_filename): def test_dataid_attrs_equal_not_matching_dataset(self, cf_scene, nc_filename): """Check that get_dataset returns 
None when key(s) are not matching.""" - from satpy.dataset.dataid import DataID, default_id_keys_config + from satpy.dataset.dataid import DataID + from satpy.dataset.id_keys import default_id_keys_config _create_test_netcdf(nc_filename, resolution=742) reader = SatpyCFFileHandler(nc_filename, {}, {"filetype": "info"}) not_existing_resolution = 9999999 @@ -496,7 +498,8 @@ def test_dataid_attrs_equal_not_matching_dataset(self, cf_scene, nc_filename): def test_dataid_attrs_equal_contains_not_matching_key(self, cf_scene, nc_filename): """Check that get_dataset returns valid dataset when dataid have key(s) not existing in data.""" - from satpy.dataset.dataid import DataID, default_id_keys_config + from satpy.dataset.dataid import DataID + from satpy.dataset.id_keys import default_id_keys_config _create_test_netcdf(nc_filename, resolution=742) reader = SatpyCFFileHandler(nc_filename, {}, {"filetype": "info"}) ds_id = DataID(default_id_keys_config, name="solar_zenith_angle", resolution=742, diff --git a/satpy/tests/reader_tests/test_slstr_l1b.py b/satpy/tests/reader_tests/test_slstr_l1b.py index becc1455b2..aa115529f0 100644 --- a/satpy/tests/reader_tests/test_slstr_l1b.py +++ b/satpy/tests/reader_tests/test_slstr_l1b.py @@ -26,7 +26,8 @@ import pytest import xarray as xr -from satpy.dataset.dataid import DataID, ModifierTuple, WavelengthRange +from satpy.dataset import ModifierTuple, WavelengthRange +from satpy.dataset.dataid import DataID from satpy.readers.slstr_l1b import NCSLSTR1B, NCSLSTRAngles, NCSLSTRFlag, NCSLSTRGeo local_id_keys_config = {"name": { diff --git a/satpy/tests/scene_tests/test_data_access.py b/satpy/tests/scene_tests/test_data_access.py index 66129ad8bb..b32a8a1f85 100644 --- a/satpy/tests/scene_tests/test_data_access.py +++ b/satpy/tests/scene_tests/test_data_access.py @@ -22,7 +22,7 @@ from dask import array as da from satpy import Scene -from satpy.dataset.dataid import default_id_keys_config +from satpy.dataset.id_keys import default_id_keys_config from satpy.tests.utils import FAKE_FILEHANDLER_END, FAKE_FILEHANDLER_START, make_cid, make_dataid # NOTE: diff --git a/satpy/tests/scene_tests/test_load.py b/satpy/tests/scene_tests/test_load.py index 7adf237c57..4df8b69573 100644 --- a/satpy/tests/scene_tests/test_load.py +++ b/satpy/tests/scene_tests/test_load.py @@ -587,7 +587,7 @@ def test_modified_with_wl_dep(self): nodes are unique and that DataIDs. 
""" - from satpy.dataset.dataid import WavelengthRange + from satpy.dataset import WavelengthRange # Check dependency tree nodes # initialize the dep tree without loading the data diff --git a/satpy/tests/scene_tests/test_resampling.py b/satpy/tests/scene_tests/test_resampling.py index d59019e3f7..d1da9c62b4 100644 --- a/satpy/tests/scene_tests/test_resampling.py +++ b/satpy/tests/scene_tests/test_resampling.py @@ -22,7 +22,7 @@ from dask import array as da from satpy import Scene -from satpy.dataset.dataid import default_id_keys_config +from satpy.dataset.id_keys import default_id_keys_config from satpy.tests.utils import make_cid, make_dataid # NOTE: diff --git a/satpy/tests/test_dataset.py b/satpy/tests/test_dataset.py index f6fb1c6ebb..909d80e8b9 100644 --- a/satpy/tests/test_dataset.py +++ b/satpy/tests/test_dataset.py @@ -18,11 +18,16 @@ import datetime as dt +import dask.array as da import numpy as np import pytest import xarray as xr -from satpy.dataset.dataid import DataID, DataQuery, ModifierTuple, WavelengthRange, minimal_default_keys_config +from satpy.dataset.data_dict import DatasetDict +from satpy.dataset.dataid import DataID, DataQuery +from satpy.dataset.id_keys import ModifierTuple, ValueList, WavelengthRange +from satpy.dataset.id_keys import default_id_keys_config as dikc +from satpy.dataset.id_keys import minimal_default_keys_config as mdkc from satpy.dataset.metadata import combine_metadata from satpy.readers.pmw_channels_definitions import FrequencyDoubleSideBand, FrequencyQuadrupleSideBand, FrequencyRange from satpy.tests.utils import make_cid, make_dataid, make_dsq @@ -33,10 +38,6 @@ class TestDataID: def test_basic_init(self): """Test basic ways of creating a DataID.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc - from satpy.dataset.dataid import minimal_default_keys_config as mdkc - did = DataID(dikc, name="a") assert did["name"] == "a" assert did["modifiers"] == tuple() @@ -54,15 +55,11 @@ def test_basic_init(self): def test_init_bad_modifiers(self): """Test that modifiers are a tuple.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc with pytest.raises(TypeError): DataID(dikc, name="a", modifiers="str") def test_compare_no_wl(self): """Compare fully qualified wavelength ID to no wavelength ID.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc d1 = DataID(dikc, name="a", wavelength=(0.1, 0.2, 0.3)) d2 = DataID(dikc, name="a", wavelength=None) @@ -72,15 +69,11 @@ def test_compare_no_wl(self): def test_bad_calibration(self): """Test that asking for a bad calibration fails.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc with pytest.raises(ValueError, match="_bad_ invalid value for "): DataID(dikc, name="C05", calibration="_bad_") def test_is_modified(self): """Test that modifications are detected properly.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc d1 = DataID(dikc, name="a", wavelength=(0.1, 0.2, 0.3), modifiers=("hej",)) d2 = DataID(dikc, name="a", wavelength=(0.1, 0.2, 0.3), modifiers=tuple()) @@ -89,8 +82,6 @@ def test_is_modified(self): def test_create_less_modified_query(self): """Test that modifications are popped correctly.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc d1 = 
DataID(dikc, name="a", wavelength=(0.1, 0.2, 0.3), modifiers=("hej",)) d2 = DataID(dikc, name="a", wavelength=(0.1, 0.2, 0.3), modifiers=tuple()) @@ -177,16 +168,13 @@ def test_combine_times(self, meta_dicts, key, result_idx): def test_combine_arrays(self): """Test the combine_metadata with arrays.""" - from numpy import arange, ones - from xarray import DataArray - dts = [ - {"quality": (arange(25) % 2).reshape(5, 5).astype("?")}, - {"quality": (arange(1, 26) % 3).reshape(5, 5).astype("?")}, - {"quality": ones((5, 5,), "?")}, + {"quality": (np.arange(25) % 2).reshape(5, 5).astype("?")}, + {"quality": (np.arange(1, 26) % 3).reshape(5, 5).astype("?")}, + {"quality": np.ones((5, 5,), "?")}, ] assert "quality" not in combine_metadata(*dts) - dts2 = [{"quality": DataArray(d["quality"])} for d in dts] + dts2 = [{"quality": xr.DataArray(d["quality"])} for d in dts] assert "quality" not in combine_metadata(*dts2) # the ancillary_variables attribute is actually a list of data arrays dts3 = [{"quality": [d["quality"]]} for d in dts] @@ -204,9 +192,9 @@ def test_combine_arrays(self): assert "quality" in combine_metadata(*dts5) # check with other types dts6 = [ - DataArray(arange(5), attrs=dts[0]), - DataArray(arange(5), attrs=dts[0]), - DataArray(arange(5), attrs=dts[1]), + xr.DataArray(np.arange(5), attrs=dts[0]), + xr.DataArray(np.arange(5), attrs=dts[0]), + xr.DataArray(np.arange(5), attrs=dts[1]), object() ] assert "quality" not in combine_metadata(*dts6) @@ -270,8 +258,6 @@ def test_combine_numpy_arrays(self): def test_combine_dask_arrays(self): """Test combining values that are dask arrays.""" - import dask.array as da - test_metadata = [{"valid_range": da.from_array(np.array([0., 0.00032], dtype=np.float32))}, {"valid_range": da.from_array(np.array([0., 0.00032], dtype=np.float32))}] result = combine_metadata(*test_metadata) @@ -406,8 +392,6 @@ def test_combine_dicts_different(test_mda): def test_dataid(): """Test the DataID object.""" - from satpy.dataset.dataid import DataID, ModifierTuple, ValueList, WavelengthRange - # Check that enum is translated to type. 
did = make_dataid() assert issubclass(did._id_keys["calibration"]["type"], ValueList) @@ -468,7 +452,6 @@ def test_dataid(): def test_dataid_equal_if_enums_different(): """Check that dataids with different enums but same items are equal.""" - from satpy.dataset.dataid import DataID, ModifierTuple, WavelengthRange id_keys_config1 = {"name": None, "wavelength": { "type": WavelengthRange, @@ -513,9 +496,6 @@ def test_dataid_copy(): """Test copying a DataID.""" from copy import deepcopy - from satpy.dataset.dataid import DataID - from satpy.dataset.dataid import default_id_keys_config as dikc - did = DataID(dikc, name="a", resolution=1000) did2 = deepcopy(did) assert did2 == did @@ -526,7 +506,6 @@ def test_dataid_pickle(): """Test dataid pickling roundtrip.""" import pickle - from satpy.tests.utils import make_dataid did = make_dataid(name="hi", wavelength=(10, 11, 12), resolution=1000, calibration="radiance") assert did == pickle.loads(pickle.dumps(did)) @@ -541,7 +520,6 @@ def test_dataid_elements_picklable(): """ import pickle - from satpy.tests.utils import make_dataid did = make_dataid(name="hi", wavelength=(10, 11, 12), resolution=1000, calibration="radiance") for value in did.values(): pickled_value = pickle.loads(pickle.dumps(value)) @@ -553,8 +531,6 @@ class TestDataQuery: def test_dataquery(self): """Test DataQuery objects.""" - from satpy.dataset import DataQuery - DataQuery(name="cheese_shops") # Check repr @@ -566,7 +542,6 @@ def test_dataquery(self): def test_is_modified(self): """Test that modifications are detected properly.""" - from satpy.dataset import DataQuery d1 = DataQuery(name="a", wavelength=0.2, modifiers=("hej",)) d2 = DataQuery(name="a", wavelength=0.2, modifiers=tuple()) @@ -575,7 +550,6 @@ def test_is_modified(self): def test_create_less_modified_query(self): """Test that modifications are popped correctly.""" - from satpy.dataset import DataQuery d1 = DataQuery(name="a", wavelength=0.2, modifiers=("hej",)) d2 = DataQuery(name="a", wavelength=0.2, modifiers=tuple()) @@ -669,7 +643,7 @@ def test_id_filtering_wavelength(self, id_kwargs, query_kwargs, exp_match): def test_id_filtering_composite_resolution(self): """Test that a query for a composite with resolution still finds the composite.""" dataid_container = [ - DataID(minimal_default_keys_config, name="natural_color"), + DataID(mdkc, name="natural_color"), ] dq = DataQuery(name="natural_color", resolution=250) assert len(dq.filter_dataids(dataid_container)) == 0 @@ -1113,8 +1087,6 @@ def test_dataset_dict_contains_inexact_match(): See https://github.com/pytroll/satpy/issues/2331. 
""" - from satpy.dataset.data_dict import DatasetDict - dd = DatasetDict() name = "1" item = xr.DataArray(()) diff --git a/satpy/tests/test_readers.py b/satpy/tests/test_readers.py index f11d181833..769ff80c9b 100644 --- a/satpy/tests/test_readers.py +++ b/satpy/tests/test_readers.py @@ -34,8 +34,9 @@ import xarray as xr from pytest_lazy_fixtures import lf as lazy_fixture +from satpy.dataset import ModifierTuple, WavelengthRange from satpy.dataset.data_dict import get_key -from satpy.dataset.dataid import DataID, ModifierTuple, WavelengthRange +from satpy.dataset.dataid import DataID from satpy.readers import FSFile, find_files_and_readers, open_file_or_filename # NOTE: diff --git a/satpy/tests/test_yaml_reader.py b/satpy/tests/test_yaml_reader.py index 9e9bc7fa00..10d03f83ce 100644 --- a/satpy/tests/test_yaml_reader.py +++ b/satpy/tests/test_yaml_reader.py @@ -31,8 +31,7 @@ import satpy.readers.yaml_reader as yr from satpy._compat import cache -from satpy.dataset import DataQuery -from satpy.dataset.dataid import ModifierTuple +from satpy.dataset import DataQuery, ModifierTuple from satpy.readers.file_handlers import BaseFileHandler from satpy.readers.pmw_channels_definitions import FrequencyDoubleSideBand, FrequencyRange from satpy.tests.utils import make_dataid diff --git a/satpy/tests/utils.py b/satpy/tests/utils.py index ea8b01c0df..1fb8291085 100644 --- a/satpy/tests/utils.py +++ b/satpy/tests/utils.py @@ -31,7 +31,7 @@ from satpy import Scene from satpy.composites import GenericCompositor, IncompatibleAreas from satpy.dataset import DataID, DataQuery -from satpy.dataset.dataid import default_id_keys_config, minimal_default_keys_config +from satpy.dataset.id_keys import default_id_keys_config, minimal_default_keys_config from satpy.modifiers import ModifierBase from satpy.readers.file_handlers import BaseFileHandler From 2085b500de20066d08b6004682164f630145193d Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 16 Dec 2024 14:33:03 -0600 Subject: [PATCH 14/22] Extract ID update with query function --- satpy/dataset/dataid.py | 20 +++++++++++++++++++- satpy/dependency_tree.py | 14 ++------------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index 1f17568ad1..b90623efed 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -23,7 +23,7 @@ import numpy as np -from satpy.dataset.id_keys import ModifierTuple, ValueList, minimal_default_keys_config +from satpy.dataset.id_keys import ModifierTuple, ValueList, default_id_keys_config, minimal_default_keys_config logger = logging.getLogger(__name__) @@ -526,3 +526,21 @@ def _create_id_dict_from_any_key(dataset_key: DataQuery | DataID | str | numbers else: raise TypeError("Don't know how to interpret a dataset_key of type {}".format(type(dataset_key))) return ds_dict + + +def update_id_with_query(orig_id: DataID, query: DataQuery) -> DataID: + """Update a DataID with additional info from a query used to find it.""" + query_dict = query.to_dict() + if not query_dict: + return orig_id + + new_id_dict = orig_id.to_dict() + orig_id_keys = orig_id.id_keys + for query_key, query_val in query_dict.items(): + # XXX: What if the query_val is a list? 
+ if new_id_dict.get(query_key) is None: + new_id_dict[query_key] = query_val + # don't replace ID key information if we don't have to + id_keys = orig_id_keys if all(key in orig_id_keys for key in new_id_dict) else default_id_keys_config + new_id = DataID(id_keys, **new_id_dict) + return new_id diff --git a/satpy/dependency_tree.py b/satpy/dependency_tree.py index 7e0ba442b4..66624ed6be 100644 --- a/satpy/dependency_tree.py +++ b/satpy/dependency_tree.py @@ -28,7 +28,7 @@ from satpy import DataID, DataQuery, DatasetDict from satpy.dataset import ModifierTuple, create_filtered_query from satpy.dataset.data_dict import TooManyResults, get_key -from satpy.dataset.id_keys import default_id_keys_config +from satpy.dataset.dataid import update_id_with_query from satpy.node import EMPTY_LEAF_NAME, LOG, CompositorNode, MissingDependencies, Node, ReaderNode @@ -443,17 +443,7 @@ def _find_compositor(self, dataset_key, query): except KeyError: raise KeyError("Can't find anything called {}".format(str(dataset_key))) - new_id_dict = compositor.id.to_dict() - new_id = None - # TODO: dataset_key could include ID parameters from composite YAML, is this different from load kwargs? - if compositor.id.to_dict() != dataset_key._asdict(): - id_keys = default_id_keys_config # minimal_default_keys_config - for query_key, query_val in dataset_key.to_dict().items(): - # XXX: What if the query_val is a list? - if new_id_dict.get(query_key) is None and query_key in id_keys and query_val != "*": - new_id_dict[query_key] = query_val - new_id = DataID(id_keys, **new_id_dict) - + new_id = update_id_with_query(compositor.id, dataset_key) root = CompositorNode(compositor, new_id=new_id) composite_id = root.name From 6afc26d111d95c18105e4871ba7c26de0db3d8c9 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Mon, 16 Dec 2024 14:33:54 -0600 Subject: [PATCH 15/22] Remove unused _match_query_value method --- satpy/dataset/dataid.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index b90623efed..b2281c2eaa 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -379,16 +379,6 @@ def filter_dataids(self, dataid_container): keys = list(filter(self.__eq__, dataid_container)) return keys - def _match_query_value(self, key, id_val): - val = self._dict[key] - if val == "*": - return True - if isinstance(id_val, tuple) and isinstance(val, (tuple, list)): - return tuple(val) == id_val - if not isinstance(val, list): - val = [val] - return id_val in val - def sort_dataids_with_preference(self, all_ids, preference): """Sort `all_ids` given a sorting `preference` (DataQuery or None).""" try: From 633802f1142bb3c8b0ecf2c21297b835e7d1883b Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 12:41:29 -0600 Subject: [PATCH 16/22] Add shared_key option to DataQuery equality checks --- satpy/dataset/dataid.py | 26 +++++++++++++++++++++++--- satpy/dependency_tree.py | 27 +++++++++------------------ 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index b2281c2eaa..588f01edcf 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -14,11 +14,13 @@ # You should have received a copy of the GNU General Public License along with # satpy. If not, see . 
"""Dataset identifying objects.""" +from __future__ import annotations import logging import numbers from copy import copy, deepcopy from enum import Enum +from functools import partial from typing import Any, NoReturn import numpy as np @@ -261,7 +263,7 @@ def __getitem__(self, key): """Get an item.""" return self._dict[key] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Compare the DataQuerys. A DataQuery is considered equal to another DataQuery if all keys @@ -271,6 +273,20 @@ def __eq__(self, other): contains additional elements. Any DataQuery elements with the value ``"*"`` are ignored. + """ + return self.equal(other, shared_keys=False) + + def equal(self, other: Any, shared_keys: bool = False) -> bool: + """Compare this DataQuery to another DataQuery or a DataID. + + Args: + other: Other DataQuery or DataID to compare against. + shared_keys: Limit keys being compared to those shared + by both objects. If False (default), then all of the + current query's keys are used when compared against + a DataID. If compared against another DataQuery then + all keys are compared between the two queries. + """ sdict = self._asdict() try: @@ -287,6 +303,9 @@ def __eq__(self, other): if not o_is_id: # if another DataQuery, then compare both sets of keys keys_to_match |= set(odict.keys()) + if shared_keys: + # only compare with the keys that both objects share + keys_to_match &= set(odict.keys()) if not keys_to_match: return False @@ -374,9 +393,10 @@ def __repr__(self): items = ("{}={}".format(key, repr(val)) for key, val in zip(self._fields, self._values)) return self.__class__.__name__ + "(" + ", ".join(items) + ")" - def filter_dataids(self, dataid_container): + def filter_dataids(self, dataid_container, shared_keys: bool = False): """Filter DataIDs based on this query.""" - keys = list(filter(self.__eq__, dataid_container)) + func = partial(self.equal, shared_keys=shared_keys) + keys = list(filter(func, dataid_container)) return keys def sort_dataids_with_preference(self, all_ids, preference): diff --git a/satpy/dependency_tree.py b/satpy/dependency_tree.py index 66624ed6be..58638f834a 100644 --- a/satpy/dependency_tree.py +++ b/satpy/dependency_tree.py @@ -515,26 +515,17 @@ def _get_compositor_by_name(self, key: DataQuery) -> Compositor | None: name_query = DataQuery(name=key["name"]) for sensor_name in sorted(self.compositors): sensor_data_dict = self.compositors[sensor_name] - try: - # get all IDs that have the minimum "distance" for our composite name - all_comp_ids = sensor_data_dict.get_key(name_query, num_results=0) - # Filter to those that don't disagree with the original query - matching_comp_ids = [] - for comp_id in all_comp_ids: - for query_key, query_val in key.to_dict().items(): - # TODO: Handle query_vals that are lists - if comp_id.get(query_key, query_val) != query_val: - break - else: - # all query keys match - matching_comp_ids.append(comp_id) - if len(matching_comp_ids) > 1: - warnings.warn("Multiple compositors matching {name_query} to create {key} variant. 
" - "Going to use the name-only 'base' compositor definition.") - matching_comp_ids = matching_comp_ids[:1] - except KeyError: + # get all IDs that have the minimum "distance" for our composite name + all_comp_ids = sensor_data_dict.get_key(name_query, num_results=0) + if len(all_comp_ids) == 0: continue + # Filter to those that don't disagree with the original query + matching_comp_ids = key.filter_dataids(all_comp_ids, shared_keys=True) + if len(matching_comp_ids) > 1: + warnings.warn("Multiple compositors matching {name_query} to create {key} variant. " + "Going to use the name-only 'base' compositor definition.") + matching_comp_ids = matching_comp_ids[:1] if len(matching_comp_ids) != 1: raise KeyError("Can't find compositor {key['name']} by name only.") comp_id = matching_comp_ids[0] From 0f77b4e6195e436ee7558cb9006b7129ebcd19d5 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 13:12:50 -0600 Subject: [PATCH 17/22] Refactor get_keys_from_config for simpler conditionals --- satpy/dataset/id_keys.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/satpy/dataset/id_keys.py b/satpy/dataset/id_keys.py index 46c16ea620..49699013d3 100644 --- a/satpy/dataset/id_keys.py +++ b/satpy/dataset/id_keys.py @@ -7,13 +7,13 @@ import numpy as np -def get_keys_from_config(common_id_keys, config): +def get_keys_from_config(common_id_keys: dict, config: dict) -> dict: """Gather keys for a new DataID from the ones available in configured dataset.""" id_keys = {} for key, val in common_id_keys.items(): - if key in config: - id_keys[key] = val - elif val is not None and (val.get("required") is True or val.get("default") is not None): + has_key = key in config + is_required_or_default = val is not None and (val.get("required") is True or val.get("default") is not None) + if has_key or is_required_or_default: id_keys[key] = val if not id_keys: raise ValueError("Metadata does not contain enough information to create a DataID.") From 67cb550b1ecece7bc10271ad133a63840c02970a Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 13:24:42 -0600 Subject: [PATCH 18/22] Attempt to refactor DataQuery equality --- satpy/dataset/dataid.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index 588f01edcf..42911a0eb8 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -298,14 +298,8 @@ def equal(self, other: Any, shared_keys: bool = False) -> bool: return True # if other is a DataID then must match this query exactly - keys_to_match = set(sdict.keys()) o_is_id = hasattr(other, "id_keys") - if not o_is_id: - # if another DataQuery, then compare both sets of keys - keys_to_match |= set(odict.keys()) - if shared_keys: - # only compare with the keys that both objects share - keys_to_match &= set(odict.keys()) + keys_to_match = self._keys_to_compare(sdict, odict, o_is_id, shared_keys) if not keys_to_match: return False @@ -314,6 +308,17 @@ def equal(self, other: Any, shared_keys: bool = False) -> bool: return False return True + @staticmethod + def _keys_to_compare(sdict: dict, odict: dict, o_is_id: bool, shared_keys: bool) -> set: + keys_to_match = set(sdict.keys()) + if not o_is_id: + # if another DataQuery, then compare both sets of keys + keys_to_match |= set(odict.keys()) + if shared_keys: + # only compare with the keys that both objects share + keys_to_match &= set(odict.keys()) + return keys_to_match + @staticmethod def _compare_key_equality(sdict: dict, odict: 
dict, key: str, o_is_id: bool) -> bool: if key not in sdict: From 52469d493150f0a919c9471b701e5b928127a269 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 13:36:32 -0600 Subject: [PATCH 19/22] Refactor DataQuery equality checks --- satpy/dataset/dataid.py | 94 +++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/satpy/dataset/dataid.py b/satpy/dataset/dataid.py index 42911a0eb8..a9be87fffa 100644 --- a/satpy/dataset/dataid.py +++ b/satpy/dataset/dataid.py @@ -299,58 +299,15 @@ def equal(self, other: Any, shared_keys: bool = False) -> bool: # if other is a DataID then must match this query exactly o_is_id = hasattr(other, "id_keys") - keys_to_match = self._keys_to_compare(sdict, odict, o_is_id, shared_keys) + keys_to_match = _keys_to_compare(sdict, odict, o_is_id, shared_keys) if not keys_to_match: return False for key in keys_to_match: - if not self._compare_key_equality(sdict, odict, key, o_is_id): + if not _compare_key_equality(sdict, odict, key, o_is_id): return False return True - @staticmethod - def _keys_to_compare(sdict: dict, odict: dict, o_is_id: bool, shared_keys: bool) -> set: - keys_to_match = set(sdict.keys()) - if not o_is_id: - # if another DataQuery, then compare both sets of keys - keys_to_match |= set(odict.keys()) - if shared_keys: - # only compare with the keys that both objects share - keys_to_match &= set(odict.keys()) - return keys_to_match - - @staticmethod - def _compare_key_equality(sdict: dict, odict: dict, key: str, o_is_id: bool) -> bool: - if key not in sdict: - return False - sval = sdict[key] - if sval == "*": - return True - - if key not in odict: - return False - oval = odict[key] - if oval == "*": - # Gotcha: if a DataID contains a "*" this could cause - # unexpected matches. A DataID is not expected to use "*" - return True - - if isinstance(sval, list) or isinstance(oval, list): - # multiple options to match - if not isinstance(sval, list): - # query to query comparison, make a list to iterate over - sval = [sval] - if o_is_id: - return oval in sval - - # we're matching against a DataQuery who could have its own list - if not isinstance(oval, list): - oval = [oval] - s_in_o = any(_sval in oval for _sval in sval) - o_in_s = any(_oval in sval for _oval in oval) - return s_in_o or o_in_s - return oval == sval - def __hash__(self): """Hash.""" fields = [] @@ -559,3 +516,50 @@ def update_id_with_query(orig_id: DataID, query: DataQuery) -> DataID: id_keys = orig_id_keys if all(key in orig_id_keys for key in new_id_dict) else default_id_keys_config new_id = DataID(id_keys, **new_id_dict) return new_id + + +def _keys_to_compare(sdict: dict, odict: dict, o_is_id: bool, shared_keys: bool) -> set: + keys_to_match = set(sdict.keys()) + if not o_is_id: + # if another DataQuery, then compare both sets of keys + keys_to_match |= set(odict.keys()) + if shared_keys: + # only compare with the keys that both objects share + keys_to_match &= set(odict.keys()) + return keys_to_match + + +def _compare_key_equality(sdict: dict, odict: dict, key: str, o_is_id: bool) -> bool: + if key not in sdict: + return False + sval = sdict[key] + if sval == "*": + return True + + if key not in odict: + return False + oval = odict[key] + if oval == "*": + # Gotcha: if a DataID contains a "*" this could cause + # unexpected matches. 
A DataID is not expected to use "*" + return True + + return _compare_values(sval, oval, o_is_id) + + +def _compare_values(sval: Any, oval: Any, o_is_id: bool) -> bool: + if isinstance(sval, list) or isinstance(oval, list): + # multiple options to match + if not isinstance(sval, list): + # query to query comparison, make a list to iterate over + sval = [sval] + if o_is_id: + return oval in sval + + # we're matching against a DataQuery who could have its own list + if not isinstance(oval, list): + oval = [oval] + s_in_o = any(_sval in oval for _sval in sval) + o_in_s = any(_oval in sval for _oval in oval) + return s_in_o or o_in_s + return oval == sval From 68880a41f21a8e79dc189f092039f13967aab3f6 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 14:08:25 -0600 Subject: [PATCH 20/22] Another try at making CodeScene happy with get_keys_from_config --- satpy/dataset/id_keys.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/satpy/dataset/id_keys.py b/satpy/dataset/id_keys.py index 49699013d3..504b4a0a49 100644 --- a/satpy/dataset/id_keys.py +++ b/satpy/dataset/id_keys.py @@ -11,9 +11,11 @@ def get_keys_from_config(common_id_keys: dict, config: dict) -> dict: """Gather keys for a new DataID from the ones available in configured dataset.""" id_keys = {} for key, val in common_id_keys.items(): - has_key = key in config - is_required_or_default = val is not None and (val.get("required") is True or val.get("default") is not None) - if has_key or is_required_or_default: + if key in config: + id_keys[key] = val + if val is None: + continue + if val.get("required") is True or val.get("default") is not None: id_keys[key] = val if not id_keys: raise ValueError("Metadata does not contain enough information to create a DataID.") From de36fb80f2c27e7238bc3af90dcb77e825290a05 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 14:08:40 -0600 Subject: [PATCH 21/22] Remove duplicated test code in test_satpy_cf_nc.py --- satpy/tests/reader_tests/test_satpy_cf_nc.py | 41 +++++++------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/satpy/tests/reader_tests/test_satpy_cf_nc.py b/satpy/tests/reader_tests/test_satpy_cf_nc.py index 94b16e3faa..bac29b0f54 100644 --- a/satpy/tests/reader_tests/test_satpy_cf_nc.py +++ b/satpy/tests/reader_tests/test_satpy_cf_nc.py @@ -475,34 +475,23 @@ def test_write_and_read_from_two_files(self, nc_filename, nc_filename_i): scn_.load(["solar_zenith_angle"], resolution=371) assert scn_["solar_zenith_angle"].attrs["resolution"] == 371 - def test_dataid_attrs_equal_matching_dataset(self, cf_scene, nc_filename): - """Check that get_dataset returns valid dataset when keys matches.""" + @pytest.mark.parametrize( + ("data_id_kwargs", "exp_match"), + [ + ({"name": "solar_zenith_angle", "resolution": 742, "modifiers": ()}, True), + ({"name": "solar_zenith_angle", "resolution": 9999999, "modifiers": ()}, False), + ({"name": "solar_zenith_angle", "resolution": 742, "modifiers": (), "calibration": "counts"}, True), + ], + ) + def test_dataid_attrs_equal_matching_dataset(self, cf_scene, nc_filename, data_id_kwargs, exp_match): + """Check that get_dataset returns valid dataset when keys match.""" from satpy.dataset.dataid import DataID from satpy.dataset.id_keys import default_id_keys_config _create_test_netcdf(nc_filename, resolution=742) reader = SatpyCFFileHandler(nc_filename, {}, {"filetype": "info"}) - ds_id = DataID(default_id_keys_config, name="solar_zenith_angle", resolution=742, modifiers=()) + ds_id = 
DataID(default_id_keys_config, **data_id_kwargs) res = reader.get_dataset(ds_id, {}) - assert res.attrs["resolution"] == 742 - - def test_dataid_attrs_equal_not_matching_dataset(self, cf_scene, nc_filename): - """Check that get_dataset returns None when key(s) are not matching.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.id_keys import default_id_keys_config - _create_test_netcdf(nc_filename, resolution=742) - reader = SatpyCFFileHandler(nc_filename, {}, {"filetype": "info"}) - not_existing_resolution = 9999999 - ds_id = DataID(default_id_keys_config, name="solar_zenith_angle", resolution=not_existing_resolution, - modifiers=()) - assert reader.get_dataset(ds_id, {}) is None - - def test_dataid_attrs_equal_contains_not_matching_key(self, cf_scene, nc_filename): - """Check that get_dataset returns valid dataset when dataid have key(s) not existing in data.""" - from satpy.dataset.dataid import DataID - from satpy.dataset.id_keys import default_id_keys_config - _create_test_netcdf(nc_filename, resolution=742) - reader = SatpyCFFileHandler(nc_filename, {}, {"filetype": "info"}) - ds_id = DataID(default_id_keys_config, name="solar_zenith_angle", resolution=742, - modifiers=(), calibration="counts") - res = reader.get_dataset(ds_id, {}) - assert res.attrs["resolution"] == 742 + if not exp_match: + assert res is None + else: + assert res.attrs["resolution"] == 742 From c45ed8d831284405826d53bc919b5f09ccfa2d00 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 17 Dec 2024 15:34:17 -0600 Subject: [PATCH 22/22] Remove accidental holoviews import in dependency tree --- satpy/dependency_tree.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/satpy/dependency_tree.py b/satpy/dependency_tree.py index 58638f834a..249a14dada 100644 --- a/satpy/dependency_tree.py +++ b/satpy/dependency_tree.py @@ -20,10 +20,9 @@ from __future__ import annotations import warnings -from typing import Container, Iterable, Optional +from typing import TYPE_CHECKING, Container, Iterable, Optional import numpy as np -from holoviews.core.options import Compositor from satpy import DataID, DataQuery, DatasetDict from satpy.dataset import ModifierTuple, create_filtered_query @@ -31,6 +30,9 @@ from satpy.dataset.dataid import update_id_with_query from satpy.node import EMPTY_LEAF_NAME, LOG, CompositorNode, MissingDependencies, Node, ReaderNode +if TYPE_CHECKING: + from satpy.composites import CompositeBase + class Tree: """A tree implementation.""" @@ -511,7 +513,7 @@ def get_compositor(self, key: DataQuery): return self._get_compositor_by_name(key) - def _get_compositor_by_name(self, key: DataQuery) -> Compositor | None: + def _get_compositor_by_name(self, key: DataQuery) -> CompositeBase | None: name_query = DataQuery(name=key["name"]) for sensor_name in sorted(self.compositors): sensor_data_dict = self.compositors[sensor_name]
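
The shared_keys matching introduced by the DataQuery.equal()/filter_dataids() changes above can be summarized with a small standalone sketch. The functions below are illustrative stand-ins, not satpy's actual implementation; they only mirror the intent of _keys_to_compare and _compare_key_equality. The trailing asserts show why _get_compositor_by_name passes shared_keys=True: a name-only compositor DataID should not be rejected just because the query also carries a resolution.

from typing import Any


def keys_to_compare(query: dict, other: dict, other_is_id: bool, shared_keys: bool) -> set:
    """Pick which keys take part in the comparison."""
    keys = set(query)
    if not other_is_id:
        # query-to-query comparison looks at every key either side mentions
        keys |= set(other)
    if shared_keys:
        # restrict the comparison to keys both sides actually define
        keys &= set(other)
    return keys


def values_match(qval: Any, oval: Any) -> bool:
    """Treat '*' as a wildcard and any overlap between list values as a match."""
    if "*" in (qval, oval):
        return True
    qvals = qval if isinstance(qval, list) else [qval]
    ovals = oval if isinstance(oval, list) else [oval]
    return any(v in ovals for v in qvals) or any(v in qvals for v in ovals)


def query_matches(query: dict, other: dict, other_is_id: bool = True,
                  shared_keys: bool = False) -> bool:
    """Return True if the query accepts the other ID/query on the selected keys."""
    keys = keys_to_compare(query, other, other_is_id, shared_keys)
    if not keys:
        return False
    return all(key in query and key in other and values_match(query[key], other[key])
               for key in keys)


# A resolution-specific query does not match a name-only compositor ID with the
# default strict comparison, but it does when only the shared keys are compared.
query = {"name": "overview", "resolution": 500}
comp_id = {"name": "overview"}
assert not query_matches(query, comp_id, shared_keys=False)
assert query_matches(query, comp_id, shared_keys=True)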
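
Both get_keys_from_config refactors keep the same selection rule, which is easy to state: a common key survives if the dataset's config mentions it, or if its descriptor marks it as required or supplies a default. A short restatement with a made-up config follows; the dictionaries are illustrative only, not real satpy configuration.

def select_id_keys(common_id_keys: dict, config: dict) -> dict:
    """Keep the common keys that the config mentions or that are required/defaulted."""
    id_keys = {}
    for key, val in common_id_keys.items():
        if key in config:
            id_keys[key] = val
        elif val is not None and (val.get("required") is True or val.get("default") is not None):
            id_keys[key] = val
    if not id_keys:
        raise ValueError("Not enough metadata to build an ID.")
    return id_keys


common = {
    "name": {"required": True},
    "resolution": None,            # optional with no default: dropped unless configured
    "modifiers": {"default": ()},  # has a default: always kept
}
kept = select_id_keys(common, {"name": "ds1", "sensor": "modis"})
assert set(kept) == {"name", "modifiers"}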
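
The final patch relies on the standard typing-only import pattern to drop the accidental holoviews dependency while keeping the CompositeBase annotation. A minimal sketch of that pattern, with a placeholder module name (some_heavy_module and HeavyClass are not real dependencies):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # only evaluated by type checkers, never imported at runtime
    from some_heavy_module import HeavyClass


def find_thing(name: str) -> HeavyClass | None:
    # with postponed annotation evaluation the return type stays a plain string
    # at runtime, so importing this module never touches some_heavy_module
    return None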