Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common Force Sets workflow #256

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
133 changes: 133 additions & 0 deletions aiida_common_workflows/workflows/phonons/common_force_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""Equation of state workflow that can use any code plugin implementing the common relax workflow."""
import inspect

from aiida import orm
from aiida.common import exceptions
from aiida.engine import WorkChain, append_, calcfunction
from aiida.plugins import WorkflowFactory

from aiida_common_workflows.workflows.relax.generator import RelaxType, SpinType, ElectronicType
from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain

ForceSetsWorkChain = WorkflowFactory('phonopy.force_sets')

def validate_common_inputs(value, _):
"""Validate the entire input namespace."""
# Validate that the provided ``generator_inputs`` are valid for the associated input generator.
process_class = WorkflowFactory(value['sub_process_class'])
generator = process_class.get_input_generator()

try:
generator.get_builder(value['structure'], **value['generator_inputs'])
except Exception as exc: # pylint: disable=broad-except
return f'`{generator.__class__.__name__}.get_builder()` fails for the provided `generator_inputs`: {exc}'

def validate_sub_process_class(value, _):
"""Validate the sub process class."""
try:
process_class = WorkflowFactory(value)
except exceptions.EntryPointError:
return f'`{value}` is not a valid or registered workflow entry point.'

if not inspect.isclass(process_class) or not issubclass(process_class, CommonRelaxWorkChain):
return f'`{value}` is not a subclass of the `CommonRelaxWorkChain` common workflow.'


class CommonForceSetsWorkChain(ForceSetsWorkChain):
"""
Workflow to compute automatically the force set of a given structure
using the frozen phonons approach.

Phonopy is used to produce structures with displacements,
while the forces are calculated with a quantum engine of choice.
"""

_RUN_PREFIX = 'force_calc'

@classmethod
def define(cls, spec):
# yapf: disable
super().define(spec)
spec.input_namespace('generator_inputs',
help='The inputs that will be passed to the input generator of the specified `sub_process`.')
spec.input('generator_inputs.engines', valid_type=dict, non_db=True)
spec.input('generator_inputs.protocol', valid_type=str, non_db=True,
help='The protocol to use when determining the workchain inputs.')
spec.input('generator_inputs.spin_type', valid_type=(SpinType, str), required=False, non_db=True,
help='The type of spin for the calculation.')
spec.input('generator_inputs.electronic_type', valid_type=(ElectronicType, str), required=False, non_db=True,
help='The type of electronics (insulator/metal) for the calculation.')
spec.input('generator_inputs.magnetization_per_site', valid_type=(list, tuple), required=False, non_db=True,
help='List containing the initial magnetization per atomic site.')
spec.input_namespace('sub_process', dynamic=True, populate_defaults=False)
spec.input('sub_process_class', non_db=True, validator=validate_sub_process_class)
spec.inputs.validator = validate_common_inputs

spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED', # can't we say exactly which are not finished ok?
message='At least one of the `{cls}` sub processes did not finish successfully.')


def get_sub_workchain_builder(self, structure):
"""Return the builder for the scf workchain."""
process_class = WorkflowFactory(self.inputs.sub_process_class)

relax_type = {'relax_type':RelaxType.NONE} # scf type

builder = process_class.get_input_generator().get_builder(
structure,
**self.inputs.generator_inputs,
**relax_type,
)
builder._update(**self.inputs.get('sub_process', {})) # pylint: disable=protected-access

return builder

def collect_forces_and_energies(self):
"""Collect forces and energies from calculation outputs."""
forces_dict = {}

for key, workchain in self.ctx.items(): # key: e.g. "supercell_001"
if key.startswith(self._RUN_PREFIX):
num = key.split("_")[-1] # e.g. "001"

output = workchain.outputs

forces_dict[f'forces_{num}'] = output["forces"]
forces_dict[f'energy_{num}'] = output["total_energy"]

return forces_dict

def run_forces(self):
"""Run supercell force calculations."""
for key, supercell in self.ctx.supercells.items():
num = key.split("_")[-1]
if num == key:
num = 0
label = self._RUN_PREFIX + "_%s" % num
builder = self.get_sub_workchain_builder(supercell)
builder.metadata.label = label # very necessary?
future = self.submit(builder)
self.report(f"submitting `{builder.process_class.__name__}` <PK={future.pk}> with {key} as structure")
self.to_context(**{label: future})

def inspect_forces(self):
"""Inspect all children workflows to make sure they finished successfully."""
failed_runs = []

for label, workchain in self.ctx.items():
if label.startswith(self._RUN_PREFIX):
if workchain.is_finished_ok:
forces = workchain.outputs.forces
self.out(f'supercells_forces.{label}', forces)
else:
failed_runs.append(workchain.pk)

if failed_runs:
self.report("workchain(s) with <PK={}> did not finish correctly".format(failed_runs))
return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.sub_process_class) # pylint: disable=no-member

self.ctx.forces = self.collect_forces_and_energies()



1 change: 1 addition & 0 deletions setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"aiida.workflows": [
"common_workflows.dissociation_curve = aiida_common_workflows.workflows.dissociation:DissociationCurveWorkChain",
"common_workflows.eos = aiida_common_workflows.workflows.eos:EquationOfStateWorkChain",
"common_workflows.phonons.force_sets = aiida_common_workflows.workflows.phonons.common_force_sets:CommonForceSetsWorkChain",
"common_workflows.relax.abinit = aiida_common_workflows.workflows.relax.abinit.workchain:AbinitCommonRelaxWorkChain",
"common_workflows.relax.bigdft = aiida_common_workflows.workflows.relax.bigdft.workchain:BigDftCommonRelaxWorkChain",
"common_workflows.relax.castep = aiida_common_workflows.workflows.relax.castep.workchain:CastepCommonRelaxWorkChain",
Expand Down
Empty file.
Empty file added tests/workflows/eos/__init__.py
Empty file.
Empty file.
180 changes: 180 additions & 0 deletions tests/workflows/phonons/test_workchain_force_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name
"""Tests for the :mod:`aiida_common_workflows.workflows.common_force_sets` module."""
import pytest

from aiida.engine import WorkChain
from aiida.plugins import WorkflowFactory

from aiida_common_workflows.plugins import get_workflow_entry_point_names
from aiida_common_workflows.workflows.phonons import common_force_sets
from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain


@pytest.fixture
def ctx():
"""Return the context for a port validator."""
return None


@pytest.fixture(scope='function', params=get_workflow_entry_point_names('relax'))
def common_relax_workchain(request) -> CommonRelaxWorkChain:
"""Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``."""
return WorkflowFactory(request.param)


@pytest.fixture
def generate_workchain():
"""Generate an instance of a `WorkChain`."""

def _generate_workchain(entry_point, inputs):
"""Generate an instance of a `WorkChain` with the given entry point and inputs.

:param entry_point: entry point name of the work chain subclass.
:param inputs: inputs to be passed to process construction.
:return: a `WorkChain` instance.
"""
from aiida.engine.utils import instantiate_process
from aiida.manage.manager import get_manager
from aiida.plugins import WorkflowFactory

process_class = WorkflowFactory(entry_point)
runner = get_manager().get_runner()
process = instantiate_process(runner, process_class, **inputs)

return process

return _generate_workchain

@pytest.fixture
def generate_workchain_force_sets(generate_workchain, generate_structure, generate_code):
"""Generate an instance of a `ForceSetsWorkChain`."""

def _generate_workchain_force_sets(append_inputs=None, return_inputs=False):
from aiida.orm import List
entry_point = 'common_workflows.phonons.force_sets'

inputs = {
'structure': generate_structure(symbols=('Si',)),
'supercell_matrix': List(list=[1,1,1]),
'sub_process_class': 'common_workflows.relax.quantum_espresso',
'generator_inputs': {
'engines': {
'relax': {
'code': generate_code('quantumespresso.pw').store(),
'options': {
'resources': {
'num_machines': 1
}
}
}
},
'electronic_type': 'insulator',
'protocol': 'moderate',
},
}

if return_inputs:
return inputs

if append_inputs is not None:
inputs.update(append_inputs)

process = generate_workchain(entry_point, inputs)

return process

return _generate_workchain_force_sets


def test_validate_sub_process_class(ctx):
"""Test the `validate_sub_process_class` validator."""
for value in [None, WorkChain]:
message = f'`{value}` is not a valid or registered workflow entry point.'
assert common_force_sets.validate_sub_process_class(value, ctx) == message


def test_validate_sub_process_class_plugins(ctx, common_relax_workchain):
"""Test the `validate_sub_process_class` validator."""
from aiida_common_workflows.plugins import get_entry_point_name_from_class
assert common_force_sets.validate_sub_process_class(get_entry_point_name_from_class(common_relax_workchain).name, ctx) is None


@pytest.mark.usefixtures('sssp')
def test_run_forces(generate_workchain_force_sets):
"""Test `CommonForceSetsWorkChain.run_forces`."""
process = generate_workchain_force_sets()
process.setup()
process.run_forces()

for key in ['cells', 'primitive_matrix', 'displacement_dataset']:
assert key in process.outputs

for key in ['primitive', 'supercell', 'supercell_1']:
assert key in process.outputs['cells']

# Double check for the `setup` method (already tested in `aiida-phonopy`).
assert 'primitive' not in process.ctx.supercells
assert 'supercell' not in process.ctx.supercells
assert 'supercell_1' in process.ctx.supercells

# Check for
assert 'force_calc_1' in process.ctx

@pytest.mark.usefixtures('sssp')
def test_outline(generate_workchain_force_sets):
"""Test `CommonForceSetsWorkChain` outline."""
from plumpy.process_states import ProcessState
from aiida.common import LinkType
from aiida.orm import WorkflowNode, ArrayData, Float
import numpy as np

process = generate_workchain_force_sets()

node = WorkflowNode().store()
node.label = 'force_calc_1'
forces = ArrayData()
forces.set_array('forces', np.array([[0.,0.,0.],[0.,0.,0.]]))
forces.store()
forces.add_incoming(node, link_type=LinkType.RETURN, link_label='forces')
energy = Float(0.).store()
energy.add_incoming(node, link_type=LinkType.RETURN, link_label='total_energy')

node.set_process_state(ProcessState.FINISHED)
node.set_exit_status(0)

process.ctx.force_calc_1 = node

process.inspect_forces()

assert 'force_calc_1' in process.outputs['supercells_forces']
assert 'forces' in process.ctx
assert 'forces_1' in process.ctx.forces

process.run_results()

assert 'force_sets' in process.outputs


@pytest.mark.usefixtures('sssp')
def test_run_outline_with_subtracting_residual_forces(generate_workchain_force_sets):
"""Test `CommonForceSetsWorkChain.run_forces`."""
from aiida.orm import Bool
process = generate_workchain_force_sets(append_inputs={'subtract_residual_forces':Bool(True)})
process.setup()
process.run_forces()

for key in ['cells', 'primitive_matrix', 'displacement_dataset']:
assert key in process.outputs

for key in ['primitive', 'supercell', 'supercell_1']:
assert key in process.outputs['cells']

# Double check for the `setup` method (already tested in `aiida-phonopy`).
assert 'primitive' not in process.ctx.supercells
assert 'supercell' in process.ctx.supercells
assert 'supercell_1' in process.ctx.supercells

# Check for
assert 'force_calc_0' in process.ctx
assert 'force_calc_1' in process.ctx