-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transfer calculation for density restart
- Loading branch information
1 parent
1b50690
commit 81a91a1
Showing
3 changed files
with
315 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# -*- coding: utf-8 -*- | ||
############################################################################################################## | ||
"""Return process builders ready for transferring Quantum ESPRESSO density restart files.""" | ||
|
||
import warnings | ||
|
||
from aiida.orm import FolderData, RemoteData, Dict | ||
from aiida.engine import calcfunction | ||
from aiida.plugins import CalculationFactory | ||
|
||
|
||
def get_transfer_builder(data_source, computer=None, track=False):
    """Build a `ProcessBuilder` for the `core.transfer` calculation from a density data source.

    The direction of the transfer is decided by the type of ``data_source``:

        - `FolderData`: the density restart files stored in the node are sent to ``computer``, which is
          mandatory in this case.
        - `RemoteData`: the density restart files are taken from the remote folder of the node; the computer
          of the node is used and any ``computer`` passed explicitly is ignored (with a warning).

    :param data_source: the `FolderData` or `RemoteData` node holding the density restart files
    :param computer: target computer of the transfer; required when ``data_source`` is a `FolderData`
    :param track: if True, the instructions are produced by the ``generate_instructions`` calcfunction (so
        their creation is recorded in the provenance); otherwise they are created by a plain function call
    :return: the builder with ``source_nodes``, ``metadata.computer`` and ``instructions`` already set
    :raises ValueError: if ``data_source`` is a `FolderData` and no ``computer`` is given
    """
    transfer_builder = CalculationFactory('core.transfer').get_builder()
    transfer_builder.source_nodes = {'source_node': data_source}

    if isinstance(data_source, FolderData):
        # Local repository -> remote machine: the destination has to be given by the caller.
        if computer is None:
            raise ValueError('No computer was provided for setting up a transfer to a remote.')
        transfer_builder.metadata['computer'] = computer

    elif isinstance(data_source, RemoteData):
        # Remote machine -> local repository: the node already determines the computer.
        if computer is not None:
            ignored_notice = (
                f'Computer `{computer}` provided will be ignored '
                f'(using `{data_source.computer}` from the RemoteData input `{data_source}`)'
            )
            warnings.warn(ignored_notice)
        transfer_builder.metadata['computer'] = data_source.computer

    # The tracked variant returns a dictionary of outputs, the untracked one the node itself.
    transfer_builder.instructions = (
        generate_instructions(data_source)['instructions'] if track else generate_instructions_untracked(data_source)
    )

    return transfer_builder
|
||
|
||
############################################################################################################## | ||
def generate_instructions_untracked(source_folder):
    """Generate the instruction node to be used for copying the density restart files.

    Each file entry is a tuple ``(node label, path in the source, path in the destination)``, where the label
    ``'source_node'`` is the key under which the source is expected in the ``source_nodes`` input.

    :param source_folder: the `FolderData` (local -> remote) or `RemoteData` (remote -> local) source node
    :return: a new `Dict` node with the instructions (it is not stored by this function)
    :raises TypeError: if ``source_folder`` is neither a `FolderData` nor a `RemoteData`
    """
    # Paths in the QE run folder
    schema_qepath = 'out/aiida.save/data-file-schema.xml'
    charge_qepath = 'out/aiida.save/charge-density.dat'
    pawtxt_qepath = 'out/aiida.save/paw.txt'

    # Paths in the local node
    schema_dbpath = 'data-file-schema.xml'
    charge_dbpath = 'charge-density.dat'
    pawtxt_dbpath = 'paw.txt'

    # Transfer from local to remote
    if isinstance(source_folder, FolderData):
        local_files = [
            ('source_node', schema_dbpath, schema_qepath),
            ('source_node', charge_dbpath, charge_qepath),
        ]

        # The PAW file is only sent when the node actually contains it.
        if pawtxt_dbpath in source_folder.list_object_names():
            local_files.append(('source_node', pawtxt_dbpath, pawtxt_qepath))

        instructions = {'retrieve_files': False, 'local_files': local_files}

    # Transfer from remote to local
    elif isinstance(source_folder, RemoteData):
        # NOTE(review): `paw.txt` is listed unconditionally here, unlike in the local branch above; confirm
        # that the transfer plugin tolerates its absence on the remote.
        instructions = {
            'retrieve_files': True,
            'symlink_files': [
                ('source_node', schema_qepath, schema_dbpath),
                ('source_node', charge_qepath, charge_dbpath),
                ('source_node', pawtxt_qepath, pawtxt_dbpath),
            ],
        }

    else:
        # Without this branch an unsupported input would only surface as an `UnboundLocalError` below.
        raise TypeError(
            f'`source_folder` must be a `FolderData` or `RemoteData` node, got `{type(source_folder).__name__}`'
        )

    return Dict(dict=instructions)
|
||
|
||
@calcfunction
def generate_instructions(source_folder):
    """Create the transfer instructions through a calcfunction, so that their generation is kept in the provenance.

    :param source_folder: the node handed over to ``generate_instructions_untracked``
    :return: dictionary with the instruction node under the key ``instructions``
    """
    return {'instructions': generate_instructions_untracked(source_folder)}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# -*- coding: utf-8 -*- | ||
############################################################################################################## | ||
"""Workchain to generate a restart remote folder for a Quantum ESPRESSO calculation.""" | ||
|
||
from aiida.orm import RemoteData, FolderData | ||
from aiida.engine import WorkChain, ToContext | ||
from aiida.plugins import WorkflowFactory, CalculationFactory | ||
|
||
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin | ||
from aiida_quantumespresso.utils.transfer import get_transfer_builder | ||
|
||
TransferCalcjob = CalculationFactory('core.transfer') | ||
PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base') | ||
|
||
# pylint: disable=f-string-without-interpolation | ||
|
||
|
||
############################################################################################################## | ||
class RestartSetupWorkChain(ProtocolMixin, WorkChain):
    """Workchain to generate a restart remote folder for a Quantum ESPRESSO calculation.

    It consists of two steps:

        1. TransferCalcjob: takes the content of a ``FolderData`` node and copies it into the remote computer
           into a ``RemoteData`` folder with the correct folder structure. The original ``FolderData`` needs
           to have the necessary files in the right internal path (check the ``get_transfer_builder`` utility
           function and/or use it to retrieve the densities to have this already taken care of)

        2. PwBaseWorkChain: it runs an NSCF calculation using the previously created ``RemoteData`` as its
           ``parent_folder``. This will re-generate all the wavefunctions in the running directory, which
           are necessary for launching any other kind of QE calculation in restart mode.
    """

    @classmethod
    def define(cls, spec):
        """Define the process specification."""
        super().define(spec)

        spec.expose_inputs(
            TransferCalcjob,
            namespace='transfer',
            namespace_options={
                'validator': validate_transfer,
                'populate_defaults': False,
                'help': 'Inputs for the `TransferCalcjob` to put the data on the cluster.',
            }
        )

        # `clean_workdir` and `pw.parent_folder` are not exposed: the parent folder is set by the workchain
        # itself (see `run_nscf`) from the output of the transfer step.
        spec.expose_inputs(
            PwBaseWorkChain,
            namespace='nscf',
            exclude=('clean_workdir', 'pw.parent_folder'),
            namespace_options={
                'validator': validate_nscf,
                'populate_defaults': False,
                'help': 'Inputs for the `PwBaseWorkChain` for the NSCF calculation.',
            }
        )

        spec.inputs.validator = validate_inputs

        spec.outline(
            cls.run_transfer,
            cls.inspect_transfer,
            cls.run_nscf,
            cls.inspect_nscf,
            cls.results,
        )

        spec.output(
            'remote_data',
            valid_type=RemoteData,
            help='The output node with the folder to be used as parent folder for other calculations.',
        )

        spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_TRANSFER', message='The TransferCalcjob sub process failed.')
        spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_NSCF', message='The nscf PwBaseWorkChain sub process failed.')

    @classmethod
    def get_protocol_filepath(cls):
        """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.

        :raises NotImplementedError: always, since no protocol file is defined for this workchain yet.
        """
        raise NotImplementedError('`get_protocol_filepath` method not yet implemented in `RestartSetupWorkChain`')

    @classmethod
    def get_builder_from_protocol(cls, folder_data, structure, code, protocol=None, overrides=None, **kwargs):
        """Return a builder prepopulated with inputs selected according to the chosen protocol.

        :param folder_data: the ``FolderData`` node containing the density (and the rest of the restart data).
        :param structure: the ``StructureData`` instance required to run the NSCF calculation.
        :param code: the ``Code`` instance configured for the ``quantumespresso.pw`` plugin, required to
            run the NSCF calculation. Its computer is used as the target of the transfer.
        :param protocol: protocol to use, if not specified, the default will be used.
        :param overrides: optional dictionary of inputs to override the defaults of the protocol. Only the
            content of its ``nscf`` key is used; the dictionary passed in is not modified.
        :param kwargs: additional keyword arguments. ``track`` (default ``False``) is consumed here and handed
            to ``get_transfer_builder``; everything else is passed to
            ``PwBaseWorkChain.get_builder_from_protocol``.
        :return: a process builder instance with all inputs defined ready for launch.
        :raises ValueError: if the overrides set ``nscf.pw.parameters.CONTROL.calculation`` to anything other
            than ``nscf``.
        """
        # NOTE: no protocol inputs are loaded for this workchain itself, since `get_protocol_filepath` is
        # not implemented; `protocol` is only forwarded to the `PwBaseWorkChain`.
        builder = cls.get_builder()

        # `track` only concerns the transfer step, so it is removed before forwarding the remaining kwargs.
        track = kwargs.pop('track', False)
        transfer = get_transfer_builder(folder_data, computer=code.computer, track=track)
        # NOTE(review): presumably the transfer job needs no scheduler resources - confirm.
        transfer['metadata']['options']['resources'] = {}
        builder.transfer = transfer

        # Copy every level of [overrides.nscf].pw.parameters.CONTROL that gets touched, so that the default
        # for `calculation` can be set without modifying the dictionary owned by the caller. A missing (or
        # `None`) level is treated as an empty dictionary.
        nscf_overrides = dict((overrides or {}).get('nscf') or {})
        pw_overrides = dict(nscf_overrides.get('pw') or {})
        parameters = dict(pw_overrides.get('parameters') or {})
        control = dict(parameters.get('CONTROL') or {})

        if control.setdefault('calculation', 'nscf') != 'nscf':
            bad_value = control['calculation']
            raise ValueError(
                f'The internal PwBaseWorkChain is for running an NSCF calculation, '
                f'this should not be overriden. '
                f'(Found overrides.nscf.pw.parameters.CONTROL.calculation=`{bad_value}`)'
            )

        parameters['CONTROL'] = control
        pw_overrides['parameters'] = parameters
        nscf_overrides['pw'] = pw_overrides

        nscf = PwBaseWorkChain.get_builder_from_protocol(
            code, structure, protocol, overrides=nscf_overrides, **kwargs
        )
        # These two ports are excluded from the exposed `nscf` namespace (see `define`).
        nscf['pw'].pop('parent_folder', None)
        nscf.pop('clean_workdir', None)
        builder.nscf = nscf

        return builder

    def run_transfer(self):
        """Run the TransferCalcjob to put the data in the remote computer."""
        inputs = self.exposed_inputs(TransferCalcjob, namespace='transfer')
        running = self.submit(TransferCalcjob, **inputs)
        self.report(f'launching TransferCalcjob<{running.pk}> for put the data into the remote computer')
        return ToContext(transfer_calcjob=running)

    def inspect_transfer(self):
        """Verify that the TransferCalcjob finished successfully and keep its remote folder in the context."""
        calcjob_node = self.ctx.transfer_calcjob

        if not calcjob_node.is_finished_ok:
            self.report(f'TransferCalcjob failed with exit status {calcjob_node.exit_status}')
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_TRANSFER

        # This folder becomes the `parent_folder` of the NSCF run.
        self.ctx.remote_parent = calcjob_node.outputs.remote_folder

    def run_nscf(self):
        """Run the PwBaseWorkChain in nscf mode on the restart folder."""
        inputs = self.exposed_inputs(PwBaseWorkChain, namespace='nscf')
        inputs['metadata']['call_link_label'] = 'nscf'
        inputs['pw']['parent_folder'] = self.ctx.remote_parent

        running = self.submit(PwBaseWorkChain, **inputs)
        self.report(f'launching PwBaseWorkChain<{running.pk}> in nscf mode')
        return ToContext(workchain_nscf=running)

    def inspect_nscf(self):
        """Verify that the PwBaseWorkChain for the nscf run finished successfully."""
        workchain = self.ctx.workchain_nscf

        if not workchain.is_finished_ok:
            self.report(f'nscf PwBaseWorkChain failed with exit status {workchain.exit_status}')
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_NSCF

        self.ctx.remote_data = workchain.outputs.remote_folder

    def results(self):
        """Attach the desired output nodes directly as outputs of the workchain."""
        self.report('workchain succesfully completed')
        self.out('remote_data', self.ctx.remote_data)
|
||
|
||
############################################################################################################## | ||
def validate_transfer(value, _):
    """Check that the transfer namespace holds a ``FolderData`` source containing the mandatory files.

    :param value: the mapping with the inputs of the transfer namespace
    :param _: second argument of the validator signature, not used
    :return: a string describing the problem, or None when no problem was found
    """
    # The source has to be available as `source_nodes.source_node` and be a `FolderData`
    if 'source_nodes' not in value:
        return f'The inputs of the transfer namespace were not set correctly: {value}'

    nodes_namespace = value['source_nodes']
    if 'source_node' not in nodes_namespace:
        return f'The `source_nodes` in the transfer namespace was not set correctly: {nodes_namespace}'

    folder = nodes_namespace['source_node']
    if not isinstance(folder, FolderData):
        return f'The `source_node` in the transfer namespace is not `FolderData`: {folder}'

    # Collect one line per mandatory file that is absent from the node
    problems = []
    for filename in ('data-file-schema.xml', 'charge-density.dat'):
        if filename not in folder.list_object_names():
            problems.append(f'Missing `{filename}` on node PK={folder.pk}\n')

    if problems:
        return ''.join(problems)

    return None
|
||
|
||
def validate_nscf(value, _):
    """Validate the inputs of the nscf input namespace.

    :param value: the mapping with the inputs of the nscf namespace; ``value['pw']['parameters']`` has to
        provide a ``get_dict()`` method.
    :param _: second argument of the validator signature, not used
    :return: an error message if ``CONTROL.calculation`` is not ``nscf``, None otherwise
    """
    parameters = value['pw']['parameters'].get_dict()

    # A missing `CONTROL` namelist or `calculation` key counts as `scf` and is therefore rejected.
    calculation = parameters.get('CONTROL', {}).get('calculation', 'scf')
    if calculation != 'nscf':
        return '`CONTROL.calculation` in `nscf.pw.parameters` is not set to `nscf`.'

    return None
|
||
|
||
def validate_inputs(inputs, _):
    """Check that the transfer step and the nscf step are set up for the same computer.

    :param inputs: the mapping with all the inputs of the workchain
    :param _: second argument of the validator signature, not used
    :return: an error message if the PKs of the two computers differ, None otherwise
    """
    transfer_target = inputs['transfer']['metadata']['computer']
    code_location = inputs['nscf']['pw']['code'].computer

    # The computers are compared through their PK
    if transfer_target.pk == code_location.pk:
        return None

    return (
        f'The computer where the files are being copied ({transfer_target}) '
        f'is not where the code resides ({code_location})'
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters