Skip to content

Commit

Permalink
Support AiiDA 2.6
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhollas committed Dec 9, 2024
1 parent 1eaca43 commit fe97735
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 29 deletions.
4 changes: 3 additions & 1 deletion .aiida-test-cache-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ archive_cache:
# The test archives have version 0.8
- environment_variables_double_quotes # This option was introduced in aiida-core 2.0
- submit_script_filename # This option was introduced in aiida-core 1.2.1 (archive version 0.9)
- metadata_inputs # Added in aiida-core 2.3.0
# metadata_inputs is now ignored automatically since AiiDA v2.6.0
# commit 4626b11f85cd0d95a17d8f5766a90b88ddddd689
#- metadata_inputs # Added in aiida-core 2.3.0
32 changes: 17 additions & 15 deletions aiida_test_cache/archive_cache/_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def mock_objects_to_hash_code(self):
"""
self = get_node_from_hash_objects_caller(self)
# computer names are changed by aiida-core if imported and do not have same uuid.
return [self.base.attributes.get(key='input_plugin')]
return {'input_plugin': self.base.attributes.get(key='input_plugin')}

def mock_objects_to_hash_calcjob(self):
"""
Expand All @@ -263,17 +263,19 @@ def mock_objects_to_hash_calcjob(self):
self._hash_ignored_attributes = tuple(self._hash_ignored_attributes) + \
calcjob_ignored_attributes

objects = [{
key: val
for key, val in self.base.attributes.items()
if key not in self._hash_ignored_attributes and key not in self._updatable_attributes
},
{
entry.link_label: entry.node.base.caching.get_hash()
for entry in self.base.links.get_incoming(
link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)
) if entry.link_label not in hash_ignored_inputs
}]
objects = {
"attributes": {
key: val
for key, val in self.base.attributes.items() if
key not in self._hash_ignored_attributes and key not in self._updatable_attributes
},
"inputs": {
entry.link_label: entry.node.base.caching.compute_hash()
for entry in
self.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK))
if entry.link_label not in hash_ignored_inputs
},
}
return objects

def mock_objects_to_hash(self):
Expand All @@ -286,13 +288,13 @@ def mock_objects_to_hash(self):
self._hash_ignored_attributes = tuple(self._hash_ignored_attributes) + \
node_ignored_attributes.get(class_name, ('version',))

objects = [
{
objects = {
"attributes": {
key: val
for key, val in self.base.attributes.items() if
key not in self._hash_ignored_attributes and key not in self._updatable_attributes
},
]
}
return objects

monkeypatch_hash_objects(monkeypatch, AbstractCode, mock_objects_to_hash_code)
Expand Down
18 changes: 9 additions & 9 deletions aiida_test_cache/archive_cache/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def monkeypatch_hash_objects(
monkeypatch: pytest.MonkeyPatch, node_class: type[Node], hash_objects_func: ty.Callable
) -> None:
"""
Monkeypatch the _get_objects_to_hash method in aiida-core for the given node class
Monkeypatch the get_objects_to_hash method in aiida-core for the given node class
:param monkeypatch: monkeypatch fixture of pytest
:param node_class: Node class to monkeypatch
:param hash_objects_func: function, which should be called instead of the
`_get_objects_to_hash` method
`get_objects_to_hash` method
.. note::
Expand All @@ -47,16 +47,16 @@ def monkeypatch_hash_objects(
"""
try:
monkeypatch.setattr(node_class, "_get_objects_to_hash", hash_objects_func)
monkeypatch.setattr(node_class, "get_objects_to_hash", hash_objects_func)
except AttributeError:
node_caching_class = node_class._CLS_NODE_CACHING

class MockNodeCaching(node_caching_class): #type: ignore
"""
NodeCaching subclass with stripped down _get_objects_to_hash method
NodeCaching subclass with stripped down get_objects_to_hash method
"""

def _get_objects_to_hash(self):
def get_objects_to_hash(self):
return hash_objects_func(self)

monkeypatch.setattr(node_class, "_CLS_NODE_CACHING", MockNodeCaching)
Expand All @@ -65,12 +65,12 @@ def _get_objects_to_hash(self):
def get_node_from_hash_objects_caller(caller: ty.Any) -> Node:
"""
Get the actual node instance from the class calling the
_get_objects_to_hash method
get_objects_to_hash method
:param caller: object holding _get_objects_to_hash
:param caller: object holding get_objects_to_hash
"""
#Case for AiiDA 2.0: The class holding the _get_objects_to_hash method
#is the NodeCaching class not the actual node
# Case for AiiDA 2.0: The class holding the get_objects_to_hash method
# is the NodeCaching class not the actual node
return caller._node #type: ignore[no-any-return]


Expand Down
8 changes: 4 additions & 4 deletions tests/archive_cache/test_archive_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _check_diff_workchain(res, node, should_have_used_cache=True):
cache_src = diffjob.base.caching.get_cache_source()

calc_hash = diffjob.base.caching.get_hash()
assert calc_hash == EXPECTED_HASH, f'Hash mismatch. hashed objects: {diffjob.base.caching._get_objects_to_hash()}'
assert calc_hash == EXPECTED_HASH, f'Hash mismatch. hashed objects: {diffjob.base.caching.get_objects_to_hash()}'

#Make sure that the cache was used if it should have been
if should_have_used_cache:
Expand Down Expand Up @@ -127,16 +127,16 @@ def test_load_node_archive(aiida_profile_clean, absolute_archive_path):


def test_mock_hash_codes(aiida_profile_clean, mock_code_factory, liberal_hash):
"""test if mock of _get_objects_to_hash works for Code and Calcs"""
"""test if mock of get_objects_to_hash() works for Code and Calcs"""

mock_code = mock_code_factory(
label='diff',
data_dir_abspath=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'calc_data'),
entry_point=CALC_ENTRY_POINT,
ignore_paths=('_aiidasubmit.sh', 'file*')
)
objs = mock_code.base.caching._get_objects_to_hash()
assert objs == [mock_code.base.attributes.get(key='input_plugin')]
objs = mock_code.base.caching.get_objects_to_hash()
assert objs == {'input_plugin': mock_code.base.attributes.get(key='input_plugin')}


@pytest.mark.parametrize(
Expand Down

0 comments on commit fe97735

Please sign in to comment.