Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: PhBaseWorkChain.set_qpoints doesn't work as expected #1005

Merged
merged 3 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/aiida_quantumespresso/workflows/ph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,27 +177,26 @@ def set_qpoints(self):
the case of the latter, the `KpointsData` will be constructed for the input `StructureData`
from the parent_folder using the `create_kpoints_from_distance` calculation function.
"""

try:
qpoints = self.inputs.qpoints
except AttributeError:

try:
structure = self.ctx.inputs.ph.parent_folder.creator.output.output_structure
structure = self.ctx.inputs.parent_folder.creator.output.output_structure
except AttributeError:
structure = self.ctx.inputs.ph.parent_folder.creator.inputs.structure
structure = self.ctx.inputs.parent_folder.creator.inputs.structure

inputs = {
'structure': structure,
'distance': self.inputs.qpoints_distance,
'force_parity': self.inputs.qpoints_force_parity,
'force_parity': self.inputs.get('qpoints_force_parity', orm.Bool(False)),
'metadata': {
'call_link_label': 'create_qpoints_from_distance'
}
}
qpoints = create_kpoints_from_distance(**inputs)

self.ctx.inputs.ph['qpoints'] = qpoints
self.ctx.inputs['qpoints'] = qpoints

def set_max_seconds(self, max_wallclock_seconds: None):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.
Expand Down
35 changes: 29 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,18 +591,41 @@ def _generate_inputs_q2r():


@pytest.fixture
def generate_inputs_ph(fixture_sandbox, fixture_localhost, fixture_code, generate_remote_data, generate_kpoints_mesh):
def generate_inputs_ph(
generate_calc_job_node, generate_structure, fixture_localhost, fixture_code, generate_kpoints_mesh
):
"""Generate default inputs for a `PhCalculation."""

def _generate_inputs_ph():
"""Generate default inputs for a `PhCalculation."""
from aiida.orm import Dict
def _generate_inputs_ph(with_output_structure=False):
bastonero marked this conversation as resolved.
Show resolved Hide resolved
"""Generate default inputs for a `PhCalculation.

:param with_output_structure: whether the PwCalculation has a StructureData in its outputs.
This is needed to test some PhBaseWorkChain logics.
"""
from aiida.common import LinkType
from aiida.orm import Dict, RemoteData

from aiida_quantumespresso.utils.resources import get_default_options

pw_node = generate_calc_job_node(
entry_point_name='quantumespresso.pw', inputs={
'parameters': Dict(),
'structure': generate_structure()
}
)
remote_folder = RemoteData(computer=fixture_localhost, remote_path='/tmp')
remote_folder.base.links.add_incoming(pw_node, link_type=LinkType.CREATE, link_label='remote_folder')
remote_folder.store()
parent_folder = pw_node.outputs.remote_folder

if with_output_structure:
structure = generate_structure()
structure.base.links.add_incoming(pw_node, link_type=LinkType.CREATE, link_label='output_structure')
structure.store()

inputs = {
'code': fixture_code('quantumespresso.ph'),
'parent_folder': generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw'),
'parent_folder': parent_folder,
'qpoints': generate_kpoints_mesh(2),
'parameters': Dict({'INPUTPH': {}}),
'metadata': {
Expand Down Expand Up @@ -806,7 +829,7 @@ def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False):

if inputs is None:
ph_inputs = generate_inputs_ph()
qpoints = ph_inputs.get('qpoints')
qpoints = ph_inputs.pop('qpoints')
inputs = {'ph': ph_inputs, 'qpoints': qpoints}

if return_inputs:
Expand Down
42 changes: 33 additions & 9 deletions tests/workflows/ph/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
from aiida_quantumespresso.workflows.ph.base import PhBaseWorkChain


@pytest.mark.usefixtures('aiida_profile')
def test_invalid_inputs(generate_workchain_ph, generate_inputs_ph):
"""Test `PhBaseWorkChain` validation methods."""
inputs = {'ph': generate_inputs_ph()}
message = r'Neither `qpoints` nor `qpoints_distance` were specified.'
with pytest.raises(ValueError, match=message):
generate_workchain_ph(inputs=inputs)


@pytest.fixture
def generate_ph_calc_job_node(generate_calc_job_node, fixture_localhost):
"""Generate a ``CalcJobNode`` that would have been created by a ``PhCalculation``."""
Expand All @@ -43,6 +34,15 @@ def _generate_ph_calc_job_node():
return _generate_ph_calc_job_node


@pytest.mark.usefixtures('aiida_profile')
def test_invalid_inputs(generate_workchain_ph, generate_inputs_ph):
"""Test `PhBaseWorkChain` validation methods."""
inputs = {'ph': generate_inputs_ph()}
message = r'Neither `qpoints` nor `qpoints_distance` were specified.'
with pytest.raises(ValueError, match=message):
generate_workchain_ph(inputs=inputs)


def test_setup(generate_workchain_ph):
"""Test `PhBaseWorkChain.setup`."""
process = generate_workchain_ph()
Expand All @@ -52,6 +52,30 @@ def test_setup(generate_workchain_ph):
assert isinstance(process.ctx.inputs, AttributeDict)


@pytest.mark.parametrize(
('with_output_structure', 'with_qpoints_distance'),
((False, False), (False, True), (True, True)),
)
def test_set_qpoints(generate_workchain_ph, generate_inputs_ph, with_output_structure, with_qpoints_distance):
"""Test `PhBaseWorkChain.set_qpoints`."""
inputs = {'ph': generate_inputs_ph(with_output_structure=with_output_structure)}
inputs['qpoints'] = inputs['ph'].pop('qpoints')

if with_qpoints_distance:
inputs.pop('qpoints')
inputs['qpoints_distance'] = orm.Float(0.5)

process = generate_workchain_ph(inputs=inputs)
process.setup()
process.set_qpoints()

assert 'qpoints' in process.ctx.inputs
assert isinstance(process.ctx.inputs['qpoints'], orm.KpointsData)

if not with_qpoints_distance:
assert process.ctx.inputs['qpoints'] == inputs['qpoints']


def test_handle_unrecoverable_failure(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_unrecoverable_failure`."""
process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_NO_RETRIEVED_FOLDER)
Expand Down
Loading