From b29b5614b505c78786a27b7b995da7b26e065768 Mon Sep 17 00:00:00 2001 From: AndresOrtegaGuerrero Date: Fri, 2 Feb 2024 12:49:43 +0100 Subject: [PATCH] compatible --- .../functions/seekpath_structure_analysis.py | 47 +++++++++- .../functions/test_seekpath_analysis.py | 89 +++++++++++++++++++ 2 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 tests/calculations/functions/test_seekpath_analysis.py diff --git a/src/aiida_quantumespresso/calculations/functions/seekpath_structure_analysis.py b/src/aiida_quantumespresso/calculations/functions/seekpath_structure_analysis.py index 4237342be..e4f5a93b4 100644 --- a/src/aiida_quantumespresso/calculations/functions/seekpath_structure_analysis.py +++ b/src/aiida_quantumespresso/calculations/functions/seekpath_structure_analysis.py @@ -3,6 +3,8 @@ from aiida.engine import calcfunction from aiida.orm import Data +from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData + @calcfunction def seekpath_structure_analysis(structure, **kwargs): @@ -28,4 +30,47 @@ def seekpath_structure_analysis(structure, **kwargs): # All keyword arugments should be `Data` node instances of base type and so should have the `.value` attribute unwrapped_kwargs = {key: node.value for key, node in kwargs.items() if isinstance(node, Data)} - return get_explicit_kpoints_path(structure, **unwrapped_kwargs) + result = get_explicit_kpoints_path(structure, **unwrapped_kwargs) + + # If the input structure was a HubbardStructureData, update the primitive structure with Hubbard parameters + if isinstance(structure, HubbardStructureData): + update_structure_with_hubbard(structure, result) + + return result + + +def update_structure_with_hubbard(structure, result): + """Update the structure based on Hubbard parameters if the input structure is a HubbardStructureData.""" + hubbard_parameters = structure.hubbard.parameters + if not hubbard_parameters: + return + + hubbard_structure = HubbardStructureData.from_structure(result['primitive_structure']) + + for parameter in hubbard_parameters: + atom_index = parameter.atom_index + atom_name = structure.sites[atom_index].kind_name + + if parameter.hubbard_type != 'V': + hubbard_structure.initialize_onsites_hubbard( + atom_name=atom_name, + atom_manifold=parameter.atom_manifold, + value=parameter.value, + hubbard_type=parameter.hubbard_type, + use_kinds=True, + ) + else: + neighbour_index = parameter.neighbour_index + neighbour_name = structure.sites[neighbour_index].kind_name + + hubbard_structure.initialize_intersites_hubbard( + atom_name=atom_name, + atom_manifold=parameter.atom_manifold, + neighbour_name=neighbour_name, + neighbour_manifold=parameter.neighbour_manifold, + value=parameter.value, + hubbard_type=parameter.hubbard_type, + use_kinds=True, + ) + + result['primitive_structure'] = hubbard_structure diff --git a/tests/calculations/functions/test_seekpath_analysis.py b/tests/calculations/functions/test_seekpath_analysis.py new file mode 100644 index 000000000..24ec3bbc1 --- /dev/null +++ b/tests/calculations/functions/test_seekpath_analysis.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +"""Tests for the `seekpath_structure_analysis` function for HubbbardStructureData.""" +import pytest + +from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis +from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData + + +@pytest.fixture +def generate_hubbardstructure_conv(): + """Return a `HubbardStructureData` instance in conventional cell.""" + + def _generate_hubbardstructure_conv(): + cell = [[7.96416, 0.0, 4.87664153e-16], [-4.87664153e-16, 7.96416, 4.87664153e-16], [0.0, 0.0, 7.96416]] + sites = [['Li', 'Li', (2.98656, 0.99552, 6.96864)], ['Li', 'Li', (6.96864, 6.96864, 4.9776)], + ['Li', 'Li', (0.99552, 0.99552, 4.9776)], ['Li', 'Li', (4.9776, 6.96864, 6.96864)], + ['Li', 'Li', (2.98656, 4.9776, 2.98656)], ['Li', 'Li', (6.96864, 2.98656, 0.99552)], + ['Li', 'Li', (0.99552, 4.9776, 0.99552)], ['Li', 'Li', (4.9776, 2.98656, 2.98656)], + ['Li', 'Li', (6.96864, 0.99552, 2.98656)], ['Li', 'Li', (2.98656, 6.96864, 0.99552)], + ['Li', 'Li', (4.9776, 0.99552, 0.99552)], ['Li', 'Li', (0.99552, 6.96864, 2.98656)], + ['Li', 'Li', (6.96864, 4.9776, 6.96864)], ['Li', 'Li', (2.98656, 2.98656, 4.9776)], + ['Li', 'Li', (4.9776, 4.9776, 4.9776)], ['Li', 'Li', (0.99552, 2.98656, 6.96864)], + ['Co', 'Co', (2.98656, 6.96864, 4.9776)], ['Co', 'Co', (0.99552, 6.96864, 6.96864)], + ['Co', 'Co', (0.99552, 4.9776, 4.9776)], ['Co', 'Co', (2.98656, 4.9776, 6.96864)], + ['Co', 'Co', (2.98656, 2.98656, 0.99552)], ['Co', 'Co', (0.99552, 2.98656, 2.98656)], + ['Co', 'Co', (0.99552, 0.99552, 0.99552)], ['Co', 'Co', (2.98656, 0.99552, 2.98656)], + ['Co', 'Co', (6.96864, 6.96864, 0.99552)], ['Co', 'Co', (4.9776, 6.96864, 2.98656)], + ['Co', 'Co', (4.9776, 4.9776, 0.99552)], ['Co', 'Co', (6.96864, 4.9776, 2.98656)], + ['Co', 'Co', (6.96864, 2.98656, 4.9776)], ['Co', 'Co', (4.9776, 2.98656, 6.96864)], + ['Co', 'Co', (4.9776, 0.99552, 4.9776)], ['Co', 'Co', (6.96864, 0.99552, 6.96864)], + ['O', 'O', (3.06333, 0.91875, 4.90083)], ['O', 'O', (0.91875, 3.06333, 4.90083)], + ['O', 'O', (3.06333, 3.06333, 7.04541)], ['O', 'O', (0.91875, 0.91875, 7.04541)], + ['O', 'O', (5.05437, 1.07229, 6.89187)], ['O', 'O', (2.90979, 6.89187, 6.89187)], + ['O', 'O', (1.07229, 6.89187, 5.05437)], ['O', 'O', (6.89187, 1.07229, 5.05437)], + ['O', 'O', (3.06333, 4.90083, 0.91875)], ['O', 'O', (0.91875, 7.04541, 0.91875)], + ['O', 'O', (3.06333, 7.04541, 3.06333)], ['O', 'O', (0.91875, 4.90083, 3.06333)], + ['O', 'O', (5.05437, 5.05437, 2.90979)], ['O', 'O', (2.90979, 2.90979, 2.90979)], + ['O', 'O', (1.07229, 2.90979, 1.07229)], ['O', 'O', (6.89187, 5.05437, 1.07229)], + ['O', 'O', (7.04541, 0.91875, 0.91875)], ['O', 'O', (4.90083, 3.06333, 0.91875)], + ['O', 'O', (7.04541, 3.06333, 3.06333)], ['O', 'O', (4.90083, 0.91875, 3.06333)], + ['O', 'O', (1.07229, 1.07229, 2.90979)], ['O', 'O', (6.89187, 6.89187, 2.90979)], + ['O', 'O', (5.05437, 6.89187, 1.07229)], ['O', 'O', (2.90979, 1.07229, 1.07229)], + ['O', 'O', (7.04541, 4.90083, 4.90083)], ['O', 'O', (4.90083, 7.04541, 4.90083)], + ['O', 'O', (7.04541, 7.04541, 7.04541)], ['O', 'O', (4.90083, 4.90083, 7.04541)], + ['O', 'O', (1.07229, 5.05437, 6.89187)], ['O', 'O', (6.89187, 2.90979, 6.89187)], + ['O', 'O', (5.05437, 2.90979, 5.05437)], ['O', 'O', (2.90979, 5.05437, 5.05437)]] + structure = HubbardStructureData(cell=cell, sites=sites) + + initializations = [ + ('Co', '3d', 3.0, 'U', True), + ('O', '2p', 2.0, 'U', True), + ('Li', '2s', 1.5, 'U', True), + ] + + for atom_name, atom_manifold, value, hubbard_type, use_kinds in initializations: + structure.initialize_onsites_hubbard( + atom_name=atom_name, + atom_manifold=atom_manifold, + value=value, + hubbard_type=hubbard_type, + use_kinds=use_kinds, + ) + + return structure + + return _generate_hubbardstructure_conv + + +# pylint: disable=W0621 +@pytest.mark.usefixtures('aiida_profile') +def test_seekpath_analysis(generate_hubbardstructure_conv): + """Test the `seekpath_structure_analysis` calculation function for HubbardStructureData.""" + structure = generate_hubbardstructure_conv() + conventional_parameters = structure.hubbard.parameters + + result = seekpath_structure_analysis(structure) + primitive_parameters = result['primitive_structure'].hubbard.parameters + + assert isinstance( + result['primitive_structure'], HubbardStructureData + ), 'Primitive structure should be a HubbardStructureData' + assert conventional_parameters != primitive_parameters, 'Primitive parameters should be different' + assert len(primitive_parameters) == len( + conventional_parameters + ), 'Primitive parameters should have the same length as conventional parameters' + assert all( + conv_param.atom_manifold == prim_param.atom_manifold + for conv_param, prim_param in zip(conventional_parameters, primitive_parameters) + ), 'Atom manifold should match in primitive'