From c353cc2e4104352ef9b5490adb53a60da47f293d Mon Sep 17 00:00:00 2001 From: Lorenzo <79980269+bastonero@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:22:11 +0100 Subject: [PATCH] `PhBaseWorkChain`: fix `set_qpoints` step The `set_qpoints` step in the outline of the `PhBaseWorkChain` contained several errors incorrectly assuming that the inputs of the `PhCalculation` are found in the `self.ctx.inputs.ph` namespace of the context. These should actually be placed in the `self.ctx.inputs`, which is where the `BaseRestartWorkChain` expects to find the inputs of the process class it wraps. Here we correctly assign the inputs in the context. `Additionally, the `set_qpoints` step would assume that the `qpoints_force_parity` input of the PhBaseWorkChain is always present. However, this is not a required input, and hence we take this in consideration in the `set_qpoints` logic. --- .../workflows/ph/base.py | 9 ++-- tests/conftest.py | 35 +++++++++++++--- tests/workflows/ph/test_base.py | 42 +++++++++++++++---- 3 files changed, 66 insertions(+), 20 deletions(-) diff --git a/src/aiida_quantumespresso/workflows/ph/base.py b/src/aiida_quantumespresso/workflows/ph/base.py index 158d709fc..f26e771b3 100644 --- a/src/aiida_quantumespresso/workflows/ph/base.py +++ b/src/aiida_quantumespresso/workflows/ph/base.py @@ -177,27 +177,26 @@ def set_qpoints(self): the case of the latter, the `KpointsData` will be constructed for the input `StructureData` from the parent_folder using the `create_kpoints_from_distance` calculation function. """ - try: qpoints = self.inputs.qpoints except AttributeError: try: - structure = self.ctx.inputs.ph.parent_folder.creator.output.output_structure + structure = self.ctx.inputs.parent_folder.creator.output.output_structure except AttributeError: - structure = self.ctx.inputs.ph.parent_folder.creator.inputs.structure + structure = self.ctx.inputs.parent_folder.creator.inputs.structure inputs = { 'structure': structure, 'distance': self.inputs.qpoints_distance, - 'force_parity': self.inputs.qpoints_force_parity, + 'force_parity': self.inputs.get('qpoints_force_parity', orm.Bool(False)), 'metadata': { 'call_link_label': 'create_qpoints_from_distance' } } qpoints = create_kpoints_from_distance(**inputs) - self.ctx.inputs.ph['qpoints'] = qpoints + self.ctx.inputs['qpoints'] = qpoints def set_max_seconds(self, max_wallclock_seconds: None): """Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems. diff --git a/tests/conftest.py b/tests/conftest.py index 6f99ea467..c208e09eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -591,18 +591,41 @@ def _generate_inputs_q2r(): @pytest.fixture -def generate_inputs_ph(fixture_sandbox, fixture_localhost, fixture_code, generate_remote_data, generate_kpoints_mesh): +def generate_inputs_ph( + generate_calc_job_node, generate_structure, fixture_localhost, fixture_code, generate_kpoints_mesh +): """Generate default inputs for a `PhCalculation.""" - def _generate_inputs_ph(): - """Generate default inputs for a `PhCalculation.""" - from aiida.orm import Dict + def _generate_inputs_ph(with_output_structure=False): + """Generate default inputs for a `PhCalculation. + + :param with_output_structure: whether the PwCalculation has a StructureData in its outputs. + This is needed to test some PhBaseWorkChain logics. + """ + from aiida.common import LinkType + from aiida.orm import Dict, RemoteData from aiida_quantumespresso.utils.resources import get_default_options + pw_node = generate_calc_job_node( + entry_point_name='quantumespresso.pw', inputs={ + 'parameters': Dict(), + 'structure': generate_structure() + } + ) + remote_folder = RemoteData(computer=fixture_localhost, remote_path='/tmp') + remote_folder.base.links.add_incoming(pw_node, link_type=LinkType.CREATE, link_label='remote_folder') + remote_folder.store() + parent_folder = pw_node.outputs.remote_folder + + if with_output_structure: + structure = generate_structure() + structure.base.links.add_incoming(pw_node, link_type=LinkType.CREATE, link_label='output_structure') + structure.store() + inputs = { 'code': fixture_code('quantumespresso.ph'), - 'parent_folder': generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw'), + 'parent_folder': parent_folder, 'qpoints': generate_kpoints_mesh(2), 'parameters': Dict({'INPUTPH': {}}), 'metadata': { @@ -806,7 +829,7 @@ def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False): if inputs is None: ph_inputs = generate_inputs_ph() - qpoints = ph_inputs.get('qpoints') + qpoints = ph_inputs.pop('qpoints') inputs = {'ph': ph_inputs, 'qpoints': qpoints} if return_inputs: diff --git a/tests/workflows/ph/test_base.py b/tests/workflows/ph/test_base.py index e978f62b1..5db0479e9 100644 --- a/tests/workflows/ph/test_base.py +++ b/tests/workflows/ph/test_base.py @@ -10,15 +10,6 @@ from aiida_quantumespresso.workflows.ph.base import PhBaseWorkChain -@pytest.mark.usefixtures('aiida_profile') -def test_invalid_inputs(generate_workchain_ph, generate_inputs_ph): - """Test `PhBaseWorkChain` validation methods.""" - inputs = {'ph': generate_inputs_ph()} - message = r'Neither `qpoints` nor `qpoints_distance` were specified.' - with pytest.raises(ValueError, match=message): - generate_workchain_ph(inputs=inputs) - - @pytest.fixture def generate_ph_calc_job_node(generate_calc_job_node, fixture_localhost): """Generate a ``CalcJobNode`` that would have been created by a ``PhCalculation``.""" @@ -43,6 +34,15 @@ def _generate_ph_calc_job_node(): return _generate_ph_calc_job_node +@pytest.mark.usefixtures('aiida_profile') +def test_invalid_inputs(generate_workchain_ph, generate_inputs_ph): + """Test `PhBaseWorkChain` validation methods.""" + inputs = {'ph': generate_inputs_ph()} + message = r'Neither `qpoints` nor `qpoints_distance` were specified.' + with pytest.raises(ValueError, match=message): + generate_workchain_ph(inputs=inputs) + + def test_setup(generate_workchain_ph): """Test `PhBaseWorkChain.setup`.""" process = generate_workchain_ph() @@ -52,6 +52,30 @@ def test_setup(generate_workchain_ph): assert isinstance(process.ctx.inputs, AttributeDict) +@pytest.mark.parametrize( + ('with_output_structure', 'with_qpoints_distance'), + ((False, False), (False, True), (True, True)), +) +def test_set_qpoints(generate_workchain_ph, generate_inputs_ph, with_output_structure, with_qpoints_distance): + """Test `PhBaseWorkChain.set_qpoints`.""" + inputs = {'ph': generate_inputs_ph(with_output_structure=with_output_structure)} + inputs['qpoints'] = inputs['ph'].pop('qpoints') + + if with_qpoints_distance: + inputs.pop('qpoints') + inputs['qpoints_distance'] = orm.Float(0.5) + + process = generate_workchain_ph(inputs=inputs) + process.setup() + process.set_qpoints() + + assert 'qpoints' in process.ctx.inputs + assert isinstance(process.ctx.inputs['qpoints'], orm.KpointsData) + + if not with_qpoints_distance: + assert process.ctx.inputs['qpoints'] == inputs['qpoints'] + + def test_handle_unrecoverable_failure(generate_workchain_ph): """Test `PhBaseWorkChain.handle_unrecoverable_failure`.""" process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_NO_RETRIEVED_FOLDER)