diff --git a/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py b/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py index 369bfa45..7dbcbaf0 100644 --- a/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py +++ b/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py @@ -360,6 +360,7 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st new_supercell = get_supercell_result['new_supercell'] output_params['supercell_factors'] = multiples + result['supercell'] = new_supercell output_params['supercell_num_sites'] = len(new_supercell.sites) output_params['supercell_cell_matrix'] = new_supercell.cell output_params['supercell_cell_lengths'] = new_supercell.cell_lengths diff --git a/src/aiida_quantumespresso/workflows/xspectra/crystal.py b/src/aiida_quantumespresso/workflows/xspectra/crystal.py index 1d56a50f..76c866fc 100644 --- a/src/aiida_quantumespresso/workflows/xspectra/crystal.py +++ b/src/aiida_quantumespresso/workflows/xspectra/crystal.py @@ -4,7 +4,7 @@ Uses QuantumESPRESSO pw.x and xspectra.x. """ from aiida import orm -from aiida.common import AttributeDict, ValidationError +from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain, if_ from aiida.orm import UpfData as aiida_core_upf from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory @@ -173,6 +173,19 @@ def define(cls, spec): help=('Input namespace to provide core wavefunction inputs for each element. Must follow the format: ' '``core_wfc_data__{symbol} = {node}``') ) + spec.input_namespace( + 'symmetry_data', + valid_type=(orm.Dict, orm.Int), + dynamic=True, + required=False, + help=( + 'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will ' + 'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known ' + 'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be ' + 'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_". ' + 'See docstring of `get_xspectra_structures` for more information about inputs.' + ) + ) spec.inputs.validator = cls.validate_inputs spec.outline( cls.setup, @@ -370,7 +383,7 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements @staticmethod - def validate_inputs(inputs, _): + def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements """Validate the inputs before launching the WorkChain.""" structure = inputs['structure'] kinds_present = [kind.name for kind in structure.kinds] @@ -382,54 +395,92 @@ def validate_inputs(inputs, _): if element not in elements_present: extra_elements.append(element) if len(extra_elements) > 0: - raise ValidationError( + return ( f'Some elements in ``elements_list`` {extra_elements} do not exist in the' f' structure provided {elements_present}.' ) abs_atom_marker = inputs['abs_atom_marker'].value if abs_atom_marker in kinds_present: - raise ValidationError( + return ( f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the ' f'input structure ({kinds_present}).' ) if not inputs['core']['get_powder_spectrum'].value: - raise ValidationError( + return ( 'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.' ) if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs: - raise ValidationError( + return ( 'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.' ) if 'core_wfc_data' in inputs: core_wfc_data_list = sorted(inputs['core_wfc_data'].keys()) if core_wfc_data_list != absorbing_elements_list: - raise ValidationError( + return ( f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of' f' absorbing elements ({absorbing_elements_list})' ) - else: - empty_core_wfc_data = [] - for key, value in inputs['core_wfc_data'].items(): - header_line = value.get_content()[:40] - try: - num_core_states = int(header_line.split(' ')[5]) - except Exception as exc: - raise ValidationError( - 'The core wavefunction data file is not of the correct format' - ) from exc - if num_core_states == 0: - empty_core_wfc_data.append(key) - if len(empty_core_wfc_data) > 0: - raise ValidationError( - f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain ' - 'any wavefunction data.' - ) + empty_core_wfc_data = [] + for key, value in inputs['core_wfc_data'].items(): + header_line = value.get_content()[:40] + try: + num_core_states = int(header_line.split(' ')[5]) + except: # pylint: disable=bare-except + return ( + 'The core wavefunction data file is not of the correct format' + ) # pylint: enable=bare-except + if num_core_states == 0: + empty_core_wfc_data.append(key) + if len(empty_core_wfc_data) > 0: + return ( + f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain ' + 'any wavefunction data.' + ) + + if 'symmetry_data' in inputs: + spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value + equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict() + if spacegroup_number <= 0 or spacegroup_number >= 231: + return ( + f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).' + ) + input_elements = [] + required_keys = sorted(['symbol', 'multiplicity', 'kind_name', 'site_index']) + invalid_entries = [] + # We check three things here: (1) are there any site indices which are outside of the possible + # range of site indices (2) do we have all the required keys for each entry, + # and (3) is there a mismatch between `absorbing_elements_list` and the elements specified + # in the entries of `equivalent_sites_data`. These checks are intended only to avoid a crash. + # We assume otherwise that the user knows what they're doing and has set everything else + # to their preferences correctly. + for site_label, value in equivalent_sites_data.items(): + if not set(required_keys).issubset(set(value.keys())) : + invalid_entries.append(site_label) + elif value['symbol'] not in input_elements: + input_elements.append(value['symbol']) + if value['site_index'] < 0 or value['site_index'] >= len(structure.sites): + return ( + f'The site index for {site_label} ({value["site_index"]}) is outside the range of ' + + f'sites within the structure (0-{len(structure.sites) -1}).' + ) + + if len(invalid_entries) != 0: + return ( + f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}' + ) + + sorted_input_elements = sorted(input_elements) + if sorted_input_elements != absorbing_elements_list: + return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) ' + f'do not match the list of absorbing elements ({absorbing_elements_list})') + + # pylint: enable=too-many-return-statements def setup(self): """Set required context variables.""" if 'core_wfc_data' in self.inputs.keys(): @@ -489,6 +540,12 @@ def get_xspectra_structures(self): if 'spglib_settings' in self.inputs: inputs['spglib_settings'] = self.inputs.spglib_settings + if 'symmetry_data' in self.inputs: + inputs['parse_symmetry'] = orm.Bool(False) + input_sym_data = self.inputs.symmetry_data + inputs['equivalent_sites_data'] = input_sym_data['equivalent_sites_data'] + inputs['spacegroup_number'] = input_sym_data['spacegroup_number'] + if 'relax' in self.inputs: result = get_xspectra_structures(self.ctx.optimized_structure, **inputs) else: