Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PDOSWorkChain - align energy range to fermi level #764

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
55 changes: 39 additions & 16 deletions src/aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def validate_inputs(value, _):

- Check that either the `scf` or `nscf.pw.parent_folder` inputs is provided.
- Check that the `Emin`, `Emax` and `DeltaE` inputs are the same for the `dos` and `projwfc` namespaces.
- Check that `Emin` and `Emax` are provided in case `align_to_fermi` is set to `True`.
"""
# Check that either the `scf` input or `nscf.pw.parent_folder` is provided.
import warnings
Expand All @@ -113,10 +112,13 @@ def validate_inputs(value, _):
if value['dos']['parameters']['DOS'].get(par, None) != value['projwfc']['parameters']['PROJWFC'].get(par, None):
return f'The `{par}`` parameter has to be equal for the `dos` and `projwfc` inputs.'

if value.get('align_to_fermi', False):
if value.get('energy_range_vs_fermi', False):
for par in ['Emin', 'Emax']:
if value['dos']['parameters']['DOS'].get(par, None) is None:
return f'The `{par}`` parameter must be set in case `align_to_fermi` is set to `True`.'
if value['dos']['parameters']['DOS'].get(par, None):
warnings.warn(
f'The `{par}` parameter and `energy_range_vs_fermi` were specified.'
'The value in `energy_range_vs_fermi` will be used.'
)


def validate_scf(value, _):
Expand Down Expand Up @@ -157,6 +159,17 @@ def validate_projwfc(value, _):
jsonschema.validate(value['parameters'].get_dict()['PROJWFC'], get_parameter_schema())


def validate_energy_range_vs_fermi(value, _):
"""Validate specified energy_range_vs_fermi.

- List needs to consist of two float values.
"""
if len(value) != 2:
return f'`energy_range_vs_fermi` should be a `List` of length two, but got: {value}'
if not all(isinstance(val, (float, int)) for val in value):
return f'`energy_range_vs_fermi` should be a `List` of floats, but got: {value}'


def clean_calcjob_remote(node):
"""Clean the remote directory of a ``CalcJobNode``."""
cleaned = False
Expand Down Expand Up @@ -217,14 +230,15 @@ def define(cls, spec):
help='Terminate workchain steps before submitting calculations (test purposes only).'
)
spec.input(
'align_to_fermi',
valid_type=orm.Bool,
'energy_range_vs_fermi',
valid_type=orm.List,
required=False,
serializer=to_aiida_type,
default=lambda: orm.Bool(False),
validator=validate_energy_range_vs_fermi,
help=(
'If true, Emin=>Emin-Efermi & Emax=>Emax-Efermi, where Efermi is taken from the `nscf` calculation. '
'Note that it only makes sense to align `Emax` and `Emin` to the fermi level in case they are actually '
'provided by in the `dos` and `projwfc` inputs, since otherwise the '
'Energy range with respect to the Fermi level that should be covered in DOS and PROJWFC calculation.'
'If not specified but Emin and Emax are specified in the input parameters, these values will be used.'
'Otherwise, the default values are extracted from the NSCF calculation.'
)
)
spec.expose_inputs(
Expand Down Expand Up @@ -375,6 +389,9 @@ def setup(self):
"""Initialize context variables that are used during the logical flow of the workchain."""
self.ctx.serial_clean = 'serial_clean' in self.inputs and self.inputs.serial_clean.value
self.ctx.dry_run = 'dry_run' in self.inputs and self.inputs.dry_run.value
self.ctx.energy_range_vs_fermi = (
self.inputs.energy_range_vs_fermi if 'energy_range_vs_fermi' in self.inputs else None
)

def serial_clean(self):
"""Return whether dos and projwfc calculations should be run in serial.
Expand Down Expand Up @@ -466,9 +483,12 @@ def _generate_dos_inputs(self):
dos_inputs.parent_folder = self.ctx.nscf_parent_folder
dos_parameters = self.inputs.dos.parameters.get_dict()

if dos_parameters.pop('align_to_fermi', False):
dos_parameters['DOS']['Emin'] = dos_parameters['Emin'] + self.ctx.nscf_fermi
dos_parameters['DOS']['Emax'] = dos_parameters['Emax'] + self.ctx.nscf_fermi
if self.ctx.energy_range_vs_fermi:
dos_parameters['DOS']['Emin'] = self.ctx.energy_range_vs_fermi[0] + self.ctx.nscf_fermi
dos_parameters['DOS']['Emax'] = self.ctx.energy_range_vs_fermi[1] + self.ctx.nscf_fermi
else:
dos_parameters['DOS'].setdefault('Emin', self.ctx.nscf_emin)
dos_parameters['DOS'].setdefault('Emax', self.ctx.nscf_emax)

dos_inputs.parameters = orm.Dict(dos_parameters)
dos_inputs['metadata']['call_link_label'] = 'dos'
Expand All @@ -480,9 +500,12 @@ def _generate_projwfc_inputs(self):
projwfc_inputs.parent_folder = self.ctx.nscf_parent_folder
projwfc_parameters = self.inputs.projwfc.parameters.get_dict()

if projwfc_parameters.pop('align_to_fermi', False):
projwfc_parameters['PROJWFC']['Emin'] = projwfc_parameters['Emin'] + self.ctx.nscf_fermi
projwfc_parameters['PROJWFC']['Emax'] = projwfc_parameters['Emax'] + self.ctx.nscf_fermi
if self.ctx.energy_range_vs_fermi:
projwfc_parameters['PROJWFC']['Emin'] = self.ctx.energy_range_vs_fermi[0] + self.ctx.nscf_fermi
projwfc_parameters['PROJWFC']['Emax'] = self.ctx.energy_range_vs_fermi[1] + self.ctx.nscf_fermi
else:
projwfc_parameters['PROJWFC'].setdefault('Emin', self.ctx.nscf_emin)
projwfc_parameters['PROJWFC'].setdefault('Emax', self.ctx.nscf_emax)

projwfc_inputs.parameters = orm.Dict(projwfc_parameters)
projwfc_inputs['metadata']['call_link_label'] = 'projwfc'
Expand Down
16 changes: 10 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,8 @@ def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False):
def generate_workchain_pdos(generate_workchain, generate_inputs_pw, fixture_code):
"""Generate an instance of a `PdosWorkChain`."""

def _generate_workchain_pdos():
from aiida.orm import Bool, Dict
def _generate_workchain_pdos(emin=None, emax=None, energy_range_vs_fermi=None):
from aiida.orm import Bool, Dict, List

from aiida_quantumespresso.utils.resources import get_default_options

Expand All @@ -829,12 +829,15 @@ def _generate_workchain_pdos():

dos_params = {
'DOS': {
'Emin': -10,
'Emax': 10,
'DeltaE': 0.01,
}
}
projwfc_params = {'PROJWFC': {'Emin': -10, 'Emax': 10, 'DeltaE': 0.01, 'ngauss': 0, 'degauss': 0.01}}
projwfc_params = {'PROJWFC': {'DeltaE': 0.01, 'ngauss': 0, 'degauss': 0.01}}

if emin and emax:
dos_params['DOS'].update({'Emin': emin, 'Emax': emax})
projwfc_params['PROJWFC'].update({'Emin': emin, 'Emax': emax})

dos = {
'code': fixture_code('quantumespresso.dos'),
'parameters': Dict(dos_params),
Expand All @@ -855,9 +858,10 @@ def _generate_workchain_pdos():
'nscf': nscf,
'dos': dos,
'projwfc': projwfc,
'align_to_fermi': Bool(True),
'dry_run': Bool(True)
}
if energy_range_vs_fermi:
inputs.update({'energy_range_vs_fermi': List(energy_range_vs_fermi)})

return generate_workchain(entry_point, inputs)

Expand Down
32 changes: 27 additions & 5 deletions tests/workflows/test_pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiida.engine.utils import instantiate_process
from aiida.manage.manager import get_manager
from plumpy import ProcessState
import pytest

from aiida_quantumespresso.calculations.helpers import pw_input_helper

Expand All @@ -26,6 +27,24 @@ def instantiate_process_cls(process_cls, inputs):
return instantiate_process(runner, process_cls, **inputs)


def check_pdos_energy_range(dos_inputs, projwfc_inputs, expected_p_dos_inputs):
"""Check the energy range of the pdos calculation."""
# check generated inputs
dos_params = dos_inputs.parameters.get_dict()
projwfc_params = projwfc_inputs.parameters.get_dict()

assert dos_params['DOS']['Emin'] == expected_p_dos_inputs[0]
assert dos_params['DOS']['Emax'] == expected_p_dos_inputs[1]
assert projwfc_params['PROJWFC']['Emin'] == expected_p_dos_inputs[0]
assert projwfc_params['PROJWFC']['Emax'] == expected_p_dos_inputs[1]


@pytest.mark.parametrize(
'energy_range_inputs,expected_p_dos_inputs', [((-10, 10, None), (-10, 10)),
((None, None, [-10, 10]), (-3.0970404109571996, 16.9029595890428)),
((None, None, None), (-5.64024889, 8.91047649))]
)
@pytest.mark.usefixtures('aiida_profile_clean')
def test_default(
generate_workchain_pdos,
generate_workchain_pw,
Expand All @@ -35,10 +54,12 @@ def test_default(
generate_calc_job_node,
fixture_sandbox,
generate_bands_data,
energy_range_inputs,
expected_p_dos_inputs,
):
"""Test instantiating the WorkChain, then mock its process, by calling methods in the ``spec.outline``."""
wkchain = generate_workchain_pdos(*energy_range_inputs)

wkchain = generate_workchain_pdos()
assert wkchain.setup() is None
assert wkchain.serial_clean() is False

Expand Down Expand Up @@ -76,12 +97,10 @@ def test_default(
remote.store()
remote.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='remote_folder')

result = orm.Dict({'fermi_energy': 6.9029595890428})
result.store()
result = orm.Dict({'fermi_energy': 6.9029595890428}).store()
result.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_parameters')

bands_data = generate_bands_data()
bands_data.store()
bands_data = generate_bands_data().store()
bands_data.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_band')

wkchain.ctx.workchain_nscf = mock_wknode
Expand All @@ -90,6 +109,9 @@ def test_default(

# mock run dos and projwfc, and check that their inputs are acceptable
dos_inputs, projwfc_inputs = wkchain.run_pdos_parallel()

check_pdos_energy_range(dos_inputs, projwfc_inputs, expected_p_dos_inputs)

generate_calc_job(fixture_sandbox, 'quantumespresso.dos', dos_inputs)
generate_calc_job(fixture_sandbox, 'quantumespresso.projwfc', projwfc_inputs)

Expand Down