Skip to content

Commit

Permalink
XspectraCrystalWorkChain: Enable Symmetry Data Inputs
Browse files Browse the repository at this point in the history
Adds an input namespace for the `XspectraCrystalWorkChain` which
allows the user to define the spacegroup and equivalent sites data
for the incoming structure, thus instructing the WorkChain to generate
structures and run calculations for only the sites specified.

Changes:
* Adds the `symmetry_data` input namespace to `XspectraCrystalWorkChain`,
  which the `WorkChain` will use to generate structures and set the list
of polarisation vectors to calculate.
* Adds input validation steps for the symmetry data to check for
  required information and for entries which may cause a crash, though
does not check for issues beyond this in order to maximise flexibility
of use.
* Fixes an oversight in `get_xspectra_structures` where the `supercell`
  entry was not returned to the outputs when external symmetry data were
provided by the user.
  • Loading branch information
PNOGillespie committed May 15, 2024
1 parent 210c40b commit 0f06292
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st
new_supercell = get_supercell_result['new_supercell']
output_params['supercell_factors'] = multiples

result['supercell'] = new_supercell
output_params['supercell_num_sites'] = len(new_supercell.sites)
output_params['supercell_cell_matrix'] = new_supercell.cell
output_params['supercell_cell_lengths'] = new_supercell.cell_lengths
Expand Down
71 changes: 71 additions & 0 deletions src/aiida_quantumespresso/workflows/xspectra/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,19 @@ def define(cls, spec):
help=('Input namespace to provide core wavefunction inputs for each element. Must follow the format: '
'``core_wfc_data__{symbol} = {node}``')
)
spec.input_namespace(
'symmetry_data',
valid_type=(orm.Dict, orm.Int),
dynamic=True,
required=False,
help=(
'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will '
+ 'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known'
+ 'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be'
+ 'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_<site_index>".'
+ 'See docstring of `get_xspectra_structures` for more information about inputs.'
)
)
spec.inputs.validator = cls.validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -369,6 +382,7 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements
return builder


# pylint: disable=too-many-statements
@staticmethod
def validate_inputs(inputs, _):
"""Validate the inputs before launching the WorkChain."""
Expand Down Expand Up @@ -429,7 +443,58 @@ def validate_inputs(inputs, _):
'any wavefunction data.'
)

if 'symmetry_data' in inputs:
spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value
equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict()
if spacegroup_number <= 0 or spacegroup_number >= 231:
raise ValidationError(
f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).'
)

input_elements = []
required_keys = sorted(['symbol', 'multiplicity', 'kind_name', 'site_index'])
invalid_entries = []
# We check three things here: (1) are there any site indices which are outside of the possible
# range of site indices (2) do we have all the required keys for each entry,
# and (3) is there a mismatch between `absorbing_elements_list` and the elements specified
# in the entries of `equivalent_sites_data`. These checks are intended only to avoid a crash.
# We assume otherwise that the user knows what they're doing and has set everything else
# to their preferences correctly.
for site_label, value in equivalent_sites_data.items():
required_keys_found = []

if value['site_index'] < 0:
raise ValidationError(
f'The site index for {site_label} ({value["site_index"]}) is below the range of '
+ f'sites within the structure (0-{len(structure.sites) -1}).'
)
if value['site_index'] >= len(structure.sites):
raise ValidationError(
f'The site index for {site_label} ({value["site_index"]}) is above the range of '
+ f'sites within the structure (0-{len(structure.sites) -1}).'
)
for key in value:
if key in required_keys:
required_keys_found.append(key)
if sorted(required_keys_found) != required_keys:
invalid_entries.append(site_label)
elif value['symbol'] not in input_elements:
input_elements.append(value['symbol'])

if len(invalid_entries) != 0:
raise ValidationError(
f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}'
)

sorted_input_elements = sorted(input_elements)
if sorted_input_elements != absorbing_elements_list:
raise ValidationError(
f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) do not match the'
+ f'list of absorbing elements ({absorbing_elements_list})'
)


# pylint: enable=too-many-statements
def setup(self):
"""Set required context variables."""
if 'core_wfc_data' in self.inputs.keys():
Expand Down Expand Up @@ -489,6 +554,12 @@ def get_xspectra_structures(self):
if 'spglib_settings' in self.inputs:
inputs['spglib_settings'] = self.inputs.spglib_settings

if 'symmetry_data' in self.inputs:
inputs['parse_symmetry'] = orm.Bool(False)
input_sym_data = self.inputs.symmetry_data
inputs['equivalent_sites_data'] = input_sym_data['equivalent_sites_data']
inputs['spacegroup_number'] = input_sym_data['spacegroup_number']

if 'relax' in self.inputs:
result = get_xspectra_structures(self.ctx.optimized_structure, **inputs)
else:
Expand Down

0 comments on commit 0f06292

Please sign in to comment.