diff --git a/invenio_records/dictutils.py b/invenio_records/dictutils.py index f78cf379..ed7897d2 100644 --- a/invenio_records/dictutils.py +++ b/invenio_records/dictutils.py @@ -8,8 +8,6 @@ """Dictionary utilities.""" -from copy import deepcopy - def clear_none(d): """Clear None values and empty dicts from a dict.""" @@ -147,3 +145,37 @@ def dict_merge(dest, source): dict_merge(dest[key], source[key]) else: dest[key] = source[key] + + +def filter_dict_keys(src, keys): + """Filter a dictionary based on a list of key paths.""" + # Split the keys into top-level and nested keys + top_level_keys = [key for key in keys if "." not in key] + nested_keys = [key for key in keys if "." in key] + + # Filter the top-level keys + result = {key: src[key] for key in top_level_keys if key in src} + + # Handle nested keys + for key in nested_keys: + parts = key.split(".") + current_dict = src + for part in parts[:-1]: + if part in current_dict: + current_dict = current_dict[part] + else: + break # Skip this key if the path does not exist + # Update the filtered dictionary with the nested key if it exists + if parts[-2] in result and parts[-1] in current_dict: + if parts[-2] not in result: + result[parts[-2]] = {} + result[parts[-2]][parts[-1]] = current_dict[parts[-1]] + + # Handle specific case for top-level keys that are dictionaries but not explicitly mentioned + for key in src: + if key not in result and isinstance(src[key], dict): + subkeys = [k.split(".", 1)[1] for k in keys if k.startswith(f"{key}.")] + if subkeys: + result[key] = filter_dict_keys(src[key], subkeys) + + return result diff --git a/tests/test_dictutils.py b/tests/test_dictutils.py index 7e7c8ab6..c276e902 100644 --- a/tests/test_dictutils.py +++ b/tests/test_dictutils.py @@ -12,7 +12,12 @@ import pytest -from invenio_records.dictutils import clear_none, dict_lookup, dict_merge +from invenio_records.dictutils import ( + clear_none, + dict_lookup, + dict_merge, + filter_dict_keys, +) def test_clear_none(): @@ -78,3 +83,17 @@ def test_dict_merge(): "foo2": "bar2", "metadata": {"field1": 3, "field2": "test"}, } + + +def test_filter_dict_keys(): + """Test filter dict keys.""" + source = { + "foo1": {"bar1": 1, "bar2": 2}, + "foo2": 1, + "foo3": {"bar1": {"foo4": 0}, "bar2": 1, "bar3": 2}, + "foo4": {}, + } + + assert filter_dict_keys( + source, ["foo1.bar1", "foo2", "foo3.bar1.foo4", "foo3.bar3"] + ) == {"foo1": {"bar1": 1}, "foo2": 1, "foo3": {"bar1": {"foo4": 0}, "bar3": 2}} diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 2f1d61ba..be2df9d4 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -41,9 +41,9 @@ def local_ref_resolver_store_factory(): @pytest.fixture(scope="module") def app_config(app_config): - app_config[ - "RECORDS_REFRESOLVER_CLS" - ] = "invenio_records.resolver.InvenioRefResolver" + app_config["RECORDS_REFRESOLVER_CLS"] = ( + "invenio_records.resolver.InvenioRefResolver" + ) app_config["RECORDS_REFRESOLVER_STORE"] = local_ref_resolver_store_factory() return app_config