-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transfer calculation for density restart
- Loading branch information
1 parent
da94161
commit 8fd52c2
Showing
2 changed files
with
346 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,345 @@ | ||
# -*- coding: utf-8 -*- | ||
############################################################################################################## | ||
"""Workchain to compute a band structure for a given structure using Quantum ESPRESSO pw.x.""" | ||
import warnings | ||
|
||
from aiida import orm | ||
from aiida.engine import WorkChain, ToContext, if_, calcfunction | ||
from aiida.plugins import WorkflowFactory, CalculationFactory | ||
|
||
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin | ||
|
||
TransferCalcjob = CalculationFactory('core.transfer') | ||
PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base') | ||
|
||
# pylint: disable=f-string-without-interpolation | ||
|
||
|
||
############################################################################################################## | ||
class TransferDensityWorkChain(ProtocolMixin, WorkChain): | ||
"""Workchain to transfer the charge density (and other restart data) from a Quantum ESPRESSO calculation. | ||
If a `RemoteData` node is provided, this has to point to the folder where a PW calculation was last ran. | ||
This workchain will then take the density (and other restart data files) and copy it to a local node. | ||
If a `FolderData` node is provided, you will also need to specify either a `remote_computer` or all the | ||
inputs for a `PwBaseWorkChain` run. With the first, this workchain will just transfer the density (and | ||
other restart data files) to the remote machine and end. If the inputs for a `PwBaseWorkChain` are | ||
provided instead, the workchain will also perform an NSCF calculation to re-generate the wavefunctions. | ||
""" | ||
|
||
@classmethod | ||
def define(cls, spec): | ||
"""Define the process specification.""" | ||
super().define(spec) | ||
|
||
spec.input( | ||
'data_source', | ||
valid_type=(orm.RemoteData, orm.FolderData), | ||
help='Node containing the density file and restart data (or path to it in a remote).' | ||
) | ||
|
||
spec.input( | ||
'remote_computer', | ||
valid_type=orm.Computer, | ||
non_db=True, | ||
required=False, | ||
help='Computer to which to transfer the information.' | ||
) | ||
|
||
spec.expose_inputs( | ||
PwBaseWorkChain, | ||
namespace='nscf', | ||
exclude=('clean_workdir', 'pw.parent_folder'), | ||
namespace_options={ | ||
'required': False, | ||
'populate_defaults': False, | ||
'help': 'Inputs for the `PwBaseWorkChain` for the NSCF calculation.', | ||
} | ||
) | ||
|
||
spec.inputs.validator = validate_inputs | ||
|
||
spec.outline( | ||
cls.run_transfer, | ||
cls.inspect_transfer, | ||
if_(cls.should_run_nscf)( | ||
cls.run_nscf, | ||
cls.inspect_nscf, | ||
), | ||
cls.results, | ||
) | ||
|
||
spec.output( | ||
'output_data', | ||
valid_type=(orm.RemoteData, orm.FolderData), | ||
help='The output node with the final data.', | ||
) | ||
|
||
spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_TRANSFER', message='The TransferCalcjob sub process failed.') | ||
spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_NSCF', message='The ncf PwBasexWorkChain sub process failed.') | ||
|
||
@classmethod | ||
def get_protocol_filepath(cls): | ||
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" | ||
raise NotImplementedError(f'`get_protocol_filepath` method not yet implemented in `TransferDensityWorkChain`') | ||
|
||
@classmethod | ||
def get_builder_from_protocol( | ||
cls, data_source, remote_computer=None, code=None, structure=None, protocol=None, overrides=None, **kwargs | ||
): | ||
"""Return a builder prepopulated with inputs selected according to the chosen protocol. | ||
:param data_source: the source of the density (and the rest of the restart data). | ||
:param remote_computer: the ``Computer`` to which the data will be transfered. This is necessary | ||
if the ``data_source`` is a local ``FolderData`` and no code is going to be provided for | ||
running an NSCF calculation. | ||
:param code: the ``Code`` instance configured for the ``quantumespresso.pw`` plugin, required to | ||
run the NSF calculation. | ||
:param structure: the ``StructureData`` instance required to run the NSF calculation. | ||
:param protocol: protocol to use, if not specified, the default will be used. | ||
:param overrides: optional dictionary of inputs to override the defaults of the protocol. | ||
:param kwargs: additional keyword arguments that will be passed to the ``get_builder_from_protocol`` | ||
of all the sub processes that are called by this workchain. | ||
:return: a process builder instance with all inputs defined ready for launch. | ||
""" | ||
# inputs = cls.get_protocol_inputs(protocol, overrides) | ||
|
||
builder = cls.get_builder() | ||
builder.data_source = data_source | ||
|
||
if isinstance(data_source, orm.RemoteData): | ||
if (remote_computer is not None) or (code is not None) or (structure is not None): | ||
warnings.warn( | ||
f'\nWhen providing a data source of type `RemoteData`, the other 3 inputs ' | ||
f'(remote_computer, code and structure) will be ignored:' | ||
f'\n - remote_computer: {remote_computer}' | ||
f'\n - structure: {structure}' | ||
f'\n - code: {code}' | ||
) | ||
builder.pop('nscf') | ||
|
||
if isinstance(data_source, orm.FolderData): | ||
|
||
# If the code and structure are given: prepare NSCF and check computer compatibility | ||
# If the code and structure are abscent: check that a computer is provided | ||
# If only one was given: problematic | ||
|
||
if (code is not None) and (structure is not None): | ||
|
||
nscf_args = (code, structure, protocol) | ||
nscf_kwargs = kwargs | ||
|
||
nscf_kwargs['overrides'] = {} | ||
if overrides is not None: | ||
nscf_kwargs['overrides'] = overrides.get('nscf', None) | ||
|
||
# This is for easily setting defaults at each level of: | ||
# [overrides].nscf.pw.parameters.CONTROL.calculation | ||
last_layer = nscf_kwargs['overrides'] | ||
last_layer = last_layer.setdefault('pw', {}) | ||
last_layer = last_layer.setdefault('parameters', {}) | ||
last_layer = last_layer.setdefault('CONTROL', {}) | ||
|
||
if last_layer.setdefault('calculation', 'nscf') != 'nscf': | ||
bad_value = last_layer['calculation'] | ||
raise ValueError( | ||
f'The internal PwBaseWorkChain is for running an NSCF calculation, this should not\n' | ||
f'be overriden.\n' | ||
f'(Found overrides.nscf.pw.parametersCONTROL.calculation=`{bad_value}`)' | ||
) | ||
|
||
nscf = PwBaseWorkChain.get_builder_from_protocol(*nscf_args, **nscf_kwargs) | ||
nscf['pw'].pop('parent_folder', None) | ||
nscf.pop('clean_workdir', None) | ||
builder.nscf = nscf | ||
|
||
builder.remote_computer = code.computer | ||
if remote_computer is not None: | ||
warnings.warn( | ||
f'\nWhen providing a code for running the NSCF, any remote_computer given ' | ||
f'will be ignored:' | ||
f'\n - remote_computer: `{remote_computer}`' | ||
f'\n - code provided: `{code}`' | ||
) | ||
|
||
elif (code is None) and (structure is None): | ||
|
||
if remote_computer is None: | ||
raise ValueError( | ||
f'If the `data_source` is a `FolderData` node, a `remote_computer` must also be\n' | ||
f'specified, or at least inferred from a `code` provided for running the NSCF.\n' | ||
f'(Currently remote_computer=`{remote_computer}` and code=`{code}`)' | ||
) | ||
builder.remote_computer = remote_computer | ||
builder.pop('nscf') | ||
|
||
else: | ||
|
||
raise ValueError( | ||
f'To run the NSCF both the code and structure must be specified.\n' | ||
f'(Currently code=`{code}` and structure=`{structure}`)' | ||
) | ||
|
||
return builder | ||
|
||
def should_run_nscf(self): | ||
"""If the 'nscf' input namespace was specified, we reconstruct the wave functions.""" | ||
return 'nscf' in self.inputs | ||
|
||
def run_transfer(self): | ||
"""Run the TransferCalcjob.""" | ||
source_folder = self.inputs.data_source | ||
inputs = { | ||
'instructions': generate_instructions(source_folder)['instructions'], | ||
'source_nodes': { | ||
'source_node': source_folder | ||
}, | ||
'metadata': {} | ||
} | ||
|
||
if isinstance(source_folder, orm.FolderData): | ||
inputs['metadata']['call_link_label'] = 'transfer_put' | ||
if 'nscf' in self.inputs: | ||
inputs['metadata']['computer'] = self.inputs.nscf.pw.code.computer | ||
else: | ||
inputs['metadata']['computer'] = self.inputs.remote_computer | ||
|
||
elif isinstance(source_folder, orm.RemoteData): | ||
inputs['metadata']['call_link_label'] = 'transfer_get' | ||
inputs['metadata']['computer'] = source_folder.computer | ||
|
||
running = self.submit(TransferCalcjob, **inputs) | ||
self.report(f'launching TransferCalcjob<{running.pk}>') | ||
return ToContext(transfer_calcjob=running) | ||
|
||
def inspect_transfer(self): | ||
"""Verify that the TransferCalcjob finished successfully.""" | ||
source0_node = self.inputs.data_source | ||
calcjob_node = self.ctx.transfer_calcjob | ||
|
||
if not calcjob_node.is_finished_ok: | ||
self.report(f'TransferCalcjob failed with exit status {calcjob_node.exit_status}') | ||
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_TRANSFER | ||
|
||
if isinstance(source0_node, orm.FolderData): | ||
self.ctx.last_output = calcjob_node.outputs.remote_folder | ||
|
||
elif isinstance(source0_node, orm.RemoteData): | ||
self.ctx.last_output = calcjob_node.outputs.retrieved | ||
|
||
def run_nscf(self): | ||
"""Run the PwBaseWorkChain in nscf mode on the restart folder.""" | ||
inputs = self.exposed_inputs(PwBaseWorkChain, namespace='nscf') | ||
inputs['metadata']['call_link_label'] = 'nscf' | ||
inputs['pw']['parent_folder'] = self.ctx.last_output | ||
|
||
running = self.submit(PwBaseWorkChain, **inputs) | ||
self.report(f'launching PwBaseWorkChain<{running.pk}> in nscf mode') | ||
return ToContext(workchain_nscf=running) | ||
|
||
def inspect_nscf(self): | ||
"""Verify that the PwBaseWorkChain for the scf run finished successfully.""" | ||
workchain = self.ctx.workchain_nscf | ||
|
||
if not workchain.is_finished_ok: | ||
self.report(f'scf PwBaseWorkChain failed with exit status {workchain.exit_status}') | ||
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_NSCF | ||
|
||
self.ctx.last_output = workchain.outputs.remote_folder | ||
|
||
def results(self): | ||
"""Attach the desired output nodes directly as outputs of the workchain.""" | ||
self.report('workchain succesfully completed') | ||
self.out('output_data', self.ctx.last_output) | ||
|
||
|
||
############################################################################################################## | ||
def validate_inputs(inputs, _): | ||
"""Validate the inputs of the entire input namespace.""" | ||
source_folder = inputs['data_source'] | ||
|
||
# FolderData: files must be there and code/computer compatibility | ||
if isinstance(source_folder, orm.FolderData): | ||
|
||
error_message = '' | ||
if 'data-file-schema.xml' not in source_folder.list_object_names(): | ||
error_message += f'Missing `data-file-schema.xml` on node {source_folder.pk}\n' | ||
if 'charge-density.dat' not in source_folder.list_object_names(): | ||
error_message += f'Missing `charge-density.dat` on node {source_folder.pk}\n' | ||
if len(error_message) > 0: | ||
return error_message | ||
|
||
if ('nscf' in inputs) and ('remote_computer' in inputs): | ||
computer_remote = inputs['remote_computer'] | ||
computer_pwcode = inputs['nscf']['pw']['code'].computer | ||
if computer_remote.pk != computer_pwcode.pk: | ||
return ( | ||
f'\nSome of the inputs provided are associated to different computers:' | ||
f'\n - remote_computer: {computer_remote}' | ||
f'\n - nscf.pw.code: {computer_pwcode}' | ||
) | ||
|
||
elif ('nscf' not in inputs) and ('remote_computer' not in inputs): | ||
return 'The source is a FolderData and no code or remote_computer was provided.' | ||
|
||
# RemoteData: RemoteData and remote_computer compatibility and warn if NSCF was provided | ||
if isinstance(source_folder, orm.RemoteData): | ||
|
||
if 'remote_computer' in inputs: | ||
computer_remote = inputs['remote_computer'] | ||
computer_source = source_folder.computer | ||
if computer_remote.pk != computer_source.pk: | ||
return ( | ||
f'\nSome of the inputs provided are associated to different computers:' | ||
f'\n - remote_computer: {computer_remote}' | ||
f'\n - source_folder: {computer_source}' | ||
) | ||
|
||
if 'nscf' in inputs: | ||
warnings.warn( | ||
f'\nThe `source_folder` (PK={source_folder.pk}) is a RemoteData node, so the data will be' | ||
f'\nretrieved and the NSCF input will be ignored' | ||
) | ||
|
||
|
||
def validate_nscf(value, _): | ||
"""Validate the inputs of the nscf input namespace.""" | ||
parameters = value['pw']['parameters'].get_dict() | ||
if parameters.get('CONTROL', {}).get('calculation', 'scf') != 'nscf': | ||
return '`CONTOL.calculation` in `nscf.pw.parameters` is not set to `nscf`.' | ||
|
||
|
||
############################################################################################################## | ||
@calcfunction | ||
def generate_instructions(source_folder): | ||
"""Generate the instruction node to be used for copying the files.""" | ||
|
||
# Paths in the QE run folder | ||
schema_qepath = 'out/aiida.save/data-file-schema.xml' | ||
charge_qepath = 'out/aiida.save/charge-density.dat' | ||
pawtxt_qepath = 'out/aiida.save/paw.txt' | ||
|
||
# Paths in the local node | ||
schema_dbpath = 'data-file-schema.xml' | ||
charge_dbpath = 'charge-density.dat' | ||
pawtxt_dbpath = 'paw.txt' | ||
|
||
# Transfer from local to remote | ||
if isinstance(source_folder, orm.FolderData): | ||
instructions = {'retrieve_files': False, 'local_files': []} | ||
instructions['local_files'].append(('source_node', schema_dbpath, schema_qepath)) | ||
instructions['local_files'].append(('source_node', charge_dbpath, charge_qepath)) | ||
|
||
if 'paw.txt' in source_folder.list_object_names(): | ||
instructions['local_files'].append(('source_node', pawtxt_dbpath, pawtxt_qepath)) | ||
|
||
# Transfer from remote to local | ||
elif isinstance(source_folder, orm.RemoteData): | ||
instructions = {'retrieve_files': True, 'symlink_files': []} | ||
instructions['symlink_files'].append(('source_node', schema_qepath, schema_dbpath)) | ||
instructions['symlink_files'].append(('source_node', charge_qepath, charge_dbpath)) | ||
instructions['symlink_files'].append(('source_node', pawtxt_qepath, pawtxt_dbpath)) | ||
|
||
return {'instructions': orm.Dict(dict=instructions)} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters