diff --git a/src/aiida_quantumespresso/workflows/pdos.py b/src/aiida_quantumespresso/workflows/pdos.py index 311d8237..09ec4501 100644 --- a/src/aiida_quantumespresso/workflows/pdos.py +++ b/src/aiida_quantumespresso/workflows/pdos.py @@ -118,6 +118,9 @@ def validate_inputs(value, _): if value['dos']['parameters']['DOS'].get(par, None) is None: return f'The `{par}`` parameter must be set in case `align_to_fermi` is set to `True`.' + if 'nbands_factor' in value and 'nbnd' in value['nscf']['pw']['parameters'].base.attributes.get('SYSTEM', {}): + return PdosWorkChain.exit_codes.ERROR_INVALID_INPUT_NUMBER_OF_BANDS.message + def validate_scf(value, _): """Validate the scf parameters.""" @@ -227,6 +230,9 @@ def define(cls, spec): 'provided by in the `dos` and `projwfc` inputs, since otherwise the ' ) ) + spec.input('nbands_factor', valid_type=orm.Float, required=False, + help='The number of bands for the NSCF calculation is that used for the SCF multiplied by this factor.') + spec.expose_inputs( PwBaseWorkChain, namespace='scf', @@ -301,6 +307,8 @@ def define(cls, spec): message='the PROJWFC sub process failed') spec.exit_code(404, 'ERROR_SUB_PROCESS_FAILED_BOTH', message='both the DOS and PROJWFC sub process failed') + spec.exit_code(405, 'ERROR_INVALID_INPUT_NUMBER_OF_BANDS', + message='Cannot specify both `nbands_factor` and `nscf.pw.parameters.SYSTEM.nbnd`.') spec.expose_outputs(PwBaseWorkChain, namespace='nscf') spec.expose_outputs(DosCalculation, namespace='dos') @@ -426,11 +434,23 @@ def run_nscf(self): """ inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, 'nscf')) + if 'scf' in self.inputs: inputs.pw.parent_folder = self.ctx.scf_parent_folder + + if 'nbands_factor' in self.inputs: + inputs.pw.parameters = inputs.pw.parameters.get_dict() + factor = self.inputs.nbands_factor.value + parameters = self.ctx.workchain_scf.outputs.output_parameters.get_dict() + nbands = int(parameters['number_of_bands']) + nelectron = int(parameters['number_of_electrons']) + nbnd = max(int(0.5 * nelectron * factor), int(0.5 * nelectron) + 4, nbands) + inputs.pw.parameters['SYSTEM']['nbnd'] = nbnd + inputs.pw.structure = self.inputs.structure inputs.metadata.call_link_label = 'nscf' + inputs = prepare_process_inputs(PwBaseWorkChain, inputs) if self.ctx.dry_run: