diff --git a/src/aiida_quantumespresso/workflows/pdos.py b/src/aiida_quantumespresso/workflows/pdos.py index 888379f63..0e570ce46 100644 --- a/src/aiida_quantumespresso/workflows/pdos.py +++ b/src/aiida_quantumespresso/workflows/pdos.py @@ -97,7 +97,6 @@ def validate_inputs(value, _): - Check that either the `scf` or `nscf.pw.parent_folder` inputs is provided. - Check that the `Emin`, `Emax` and `DeltaE` inputs are the same for the `dos` and `projwfc` namespaces. - - Check that `Emin` and `Emax` are provided in case `align_to_fermi` is set to `True`. """ # Check that either the `scf` input or `nscf.pw.parent_folder` is provided. import warnings @@ -113,10 +112,13 @@ def validate_inputs(value, _): if value['dos']['parameters']['DOS'].get(par, None) != value['projwfc']['parameters']['PROJWFC'].get(par, None): return f'The `{par}`` parameter has to be equal for the `dos` and `projwfc` inputs.' - if value.get('align_to_fermi', False): + if value.get('energy_range_vs_fermi', False): for par in ['Emin', 'Emax']: - if value['dos']['parameters']['DOS'].get(par, None) is None: - return f'The `{par}`` parameter must be set in case `align_to_fermi` is set to `True`.' + if value['dos']['parameters']['DOS'].get(par, None): + warnings.warn( + f'The `{par}` parameter and `energy_range_vs_fermi` were specified.' + 'The value in `energy_range_vs_fermi` will be used.' + ) def validate_scf(value, _): @@ -157,6 +159,17 @@ def validate_projwfc(value, _): jsonschema.validate(value['parameters'].get_dict()['PROJWFC'], get_parameter_schema()) +def validate_energy_range_vs_fermi(value, _): + """Validate specified energy_range_vs_fermi. + + - List needs to consist of two float values. + """ + if len(value) != 2: + return f'`energy_range_vs_fermi` should be a `List` of length two, but got: {value}' + if not all(isinstance(val, (float, int)) for val in value): + return f'`energy_range_vs_fermi` should be a `List` of floats, but got: {value}' + + def clean_calcjob_remote(node): """Clean the remote directory of a ``CalcJobNode``.""" cleaned = False @@ -217,14 +230,15 @@ def define(cls, spec): help='Terminate workchain steps before submitting calculations (test purposes only).' ) spec.input( - 'align_to_fermi', - valid_type=orm.Bool, + 'energy_range_vs_fermi', + valid_type=orm.List, + required=False, serializer=to_aiida_type, - default=lambda: orm.Bool(False), + validator=validate_energy_range_vs_fermi, help=( - 'If true, Emin=>Emin-Efermi & Emax=>Emax-Efermi, where Efermi is taken from the `nscf` calculation. ' - 'Note that it only makes sense to align `Emax` and `Emin` to the fermi level in case they are actually ' - 'provided by in the `dos` and `projwfc` inputs, since otherwise the ' + 'Energy range with respect to the Fermi level that should be covered in DOS and PROJWFC calculation.' + 'If not specified but Emin and Emax are specified in the input parameters, these values will be used.' + 'Otherwise, the default values are extracted from the NSCF calculation.' ) ) spec.expose_inputs( @@ -375,6 +389,9 @@ def setup(self): """Initialize context variables that are used during the logical flow of the workchain.""" self.ctx.serial_clean = 'serial_clean' in self.inputs and self.inputs.serial_clean.value self.ctx.dry_run = 'dry_run' in self.inputs and self.inputs.dry_run.value + self.ctx.energy_range_vs_fermi = ( + self.inputs.energy_range_vs_fermi if 'energy_range_vs_fermi' in self.inputs else None + ) def serial_clean(self): """Return whether dos and projwfc calculations should be run in serial. @@ -466,9 +483,12 @@ def _generate_dos_inputs(self): dos_inputs.parent_folder = self.ctx.nscf_parent_folder dos_parameters = self.inputs.dos.parameters.get_dict() - if dos_parameters.pop('align_to_fermi', False): - dos_parameters['DOS']['Emin'] = dos_parameters['Emin'] + self.ctx.nscf_fermi - dos_parameters['DOS']['Emax'] = dos_parameters['Emax'] + self.ctx.nscf_fermi + if self.ctx.energy_range_vs_fermi: + dos_parameters['DOS']['Emin'] = self.ctx.energy_range_vs_fermi[0] + self.ctx.nscf_fermi + dos_parameters['DOS']['Emax'] = self.ctx.energy_range_vs_fermi[1] + self.ctx.nscf_fermi + else: + dos_parameters['DOS'].setdefault('Emin', self.ctx.nscf_emin) + dos_parameters['DOS'].setdefault('Emax', self.ctx.nscf_emax) dos_inputs.parameters = orm.Dict(dos_parameters) dos_inputs['metadata']['call_link_label'] = 'dos' @@ -480,9 +500,12 @@ def _generate_projwfc_inputs(self): projwfc_inputs.parent_folder = self.ctx.nscf_parent_folder projwfc_parameters = self.inputs.projwfc.parameters.get_dict() - if projwfc_parameters.pop('align_to_fermi', False): - projwfc_parameters['PROJWFC']['Emin'] = projwfc_parameters['Emin'] + self.ctx.nscf_fermi - projwfc_parameters['PROJWFC']['Emax'] = projwfc_parameters['Emax'] + self.ctx.nscf_fermi + if self.ctx.energy_range_vs_fermi: + projwfc_parameters['PROJWFC']['Emin'] = self.ctx.energy_range_vs_fermi[0] + self.ctx.nscf_fermi + projwfc_parameters['PROJWFC']['Emax'] = self.ctx.energy_range_vs_fermi[1] + self.ctx.nscf_fermi + else: + projwfc_parameters['PROJWFC'].setdefault('Emin', self.ctx.nscf_emin) + projwfc_parameters['PROJWFC'].setdefault('Emax', self.ctx.nscf_emax) projwfc_inputs.parameters = orm.Dict(projwfc_parameters) projwfc_inputs['metadata']['call_link_label'] = 'projwfc' diff --git a/tests/conftest.py b/tests/conftest.py index e0b08b57e..1de5d2c94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -806,8 +806,8 @@ def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False): def generate_workchain_pdos(generate_workchain, generate_inputs_pw, fixture_code): """Generate an instance of a `PdosWorkChain`.""" - def _generate_workchain_pdos(): - from aiida.orm import Bool, Dict + def _generate_workchain_pdos(emin=None, emax=None, energy_range_vs_fermi=None): + from aiida.orm import Bool, Dict, List from aiida_quantumespresso.utils.resources import get_default_options @@ -829,12 +829,15 @@ def _generate_workchain_pdos(): dos_params = { 'DOS': { - 'Emin': -10, - 'Emax': 10, 'DeltaE': 0.01, } } - projwfc_params = {'PROJWFC': {'Emin': -10, 'Emax': 10, 'DeltaE': 0.01, 'ngauss': 0, 'degauss': 0.01}} + projwfc_params = {'PROJWFC': {'DeltaE': 0.01, 'ngauss': 0, 'degauss': 0.01}} + + if emin and emax: + dos_params['DOS'].update({'Emin': emin, 'Emax': emax}) + projwfc_params['PROJWFC'].update({'Emin': emin, 'Emax': emax}) + dos = { 'code': fixture_code('quantumespresso.dos'), 'parameters': Dict(dos_params), @@ -855,9 +858,10 @@ def _generate_workchain_pdos(): 'nscf': nscf, 'dos': dos, 'projwfc': projwfc, - 'align_to_fermi': Bool(True), 'dry_run': Bool(True) } + if energy_range_vs_fermi: + inputs.update({'energy_range_vs_fermi': List(energy_range_vs_fermi)}) return generate_workchain(entry_point, inputs) diff --git a/tests/workflows/test_pdos.py b/tests/workflows/test_pdos.py index beb8b3ddc..0fdfd6b7a 100644 --- a/tests/workflows/test_pdos.py +++ b/tests/workflows/test_pdos.py @@ -8,6 +8,7 @@ from aiida.engine.utils import instantiate_process from aiida.manage.manager import get_manager from plumpy import ProcessState +import pytest from aiida_quantumespresso.calculations.helpers import pw_input_helper @@ -26,6 +27,24 @@ def instantiate_process_cls(process_cls, inputs): return instantiate_process(runner, process_cls, **inputs) +def check_pdos_energy_range(dos_inputs, projwfc_inputs, expected_p_dos_inputs): + """Check the energy range of the pdos calculation.""" + # check generated inputs + dos_params = dos_inputs.parameters.get_dict() + projwfc_params = projwfc_inputs.parameters.get_dict() + + assert dos_params['DOS']['Emin'] == expected_p_dos_inputs[0] + assert dos_params['DOS']['Emax'] == expected_p_dos_inputs[1] + assert projwfc_params['PROJWFC']['Emin'] == expected_p_dos_inputs[0] + assert projwfc_params['PROJWFC']['Emax'] == expected_p_dos_inputs[1] + + +@pytest.mark.parametrize( + 'energy_range_inputs,expected_p_dos_inputs', [((-10, 10, None), (-10, 10)), + ((None, None, [-10, 10]), (-3.0970404109571996, 16.9029595890428)), + ((None, None, None), (-5.64024889, 8.91047649))] +) +@pytest.mark.usefixtures('aiida_profile_clean') def test_default( generate_workchain_pdos, generate_workchain_pw, @@ -35,10 +54,12 @@ def test_default( generate_calc_job_node, fixture_sandbox, generate_bands_data, + energy_range_inputs, + expected_p_dos_inputs, ): """Test instantiating the WorkChain, then mock its process, by calling methods in the ``spec.outline``.""" + wkchain = generate_workchain_pdos(*energy_range_inputs) - wkchain = generate_workchain_pdos() assert wkchain.setup() is None assert wkchain.serial_clean() is False @@ -76,12 +97,10 @@ def test_default( remote.store() remote.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='remote_folder') - result = orm.Dict({'fermi_energy': 6.9029595890428}) - result.store() + result = orm.Dict({'fermi_energy': 6.9029595890428}).store() result.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_parameters') - bands_data = generate_bands_data() - bands_data.store() + bands_data = generate_bands_data().store() bands_data.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_band') wkchain.ctx.workchain_nscf = mock_wknode @@ -90,6 +109,9 @@ def test_default( # mock run dos and projwfc, and check that their inputs are acceptable dos_inputs, projwfc_inputs = wkchain.run_pdos_parallel() + + check_pdos_energy_range(dos_inputs, projwfc_inputs, expected_p_dos_inputs) + generate_calc_job(fixture_sandbox, 'quantumespresso.dos', dos_inputs) generate_calc_job(fixture_sandbox, 'quantumespresso.projwfc', projwfc_inputs)