Skip to content

Commit

Permalink
Update some logic in WorkChains.
Browse files Browse the repository at this point in the history
The old logic confused spin_orbit_coupling with spin_non_collinear,
which did not affect the operation workflows because the
spin_non_collinear in not implemented in the code.
By handling this confused logic with the correct logic, it will
facilitate future code extension (for non-collinear).
  • Loading branch information
jiang-yuha0 committed Jun 20, 2024
1 parent 05df5c0 commit 8f03829
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
34 changes: 26 additions & 8 deletions src/aiida_wannier90_workflows/utils/pseudo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def get_wannier_number_of_bands(
factor=1.2,
only_valence=False,
spin_polarized=False,
spin_non_collinear: bool = False,
spin_orbit_coupling: bool = False,
):
"""Estimate number of bands for a Wannier90 calculation.
Expand All @@ -220,6 +221,8 @@ def get_wannier_number_of_bands(
:type only_valence: bool
:param spin_polarized: magnetic calculation?
:type spin_polarized: bool
:param spin_non_collinear: non-collinear or spin-orbit-coupling
:type spin_non_collinear: bool
:param spin_orbit_coupling: spin orbit coupling calculation?
:type spin_orbit_coupling: bool
:return: number of bands for Wannier90 SCDM
Expand All @@ -236,8 +239,10 @@ def get_wannier_number_of_bands(
raise ValueError("Should use SOC pseudo for SOC calculation")

num_electrons = get_number_of_electrons(structure, pseudos)
num_projections = get_number_of_projections(structure, pseudos, spin_orbit_coupling)
nspin = 2 if (spin_polarized or spin_orbit_coupling) else 1
num_projections = get_number_of_projections(
structure, pseudos, spin_non_collinear, spin_orbit_coupling
)
nspin = 2 if (spin_polarized or spin_non_collinear) else 1
# TODO check nospin, spin, soc # pylint: disable=fixme
if only_valence:
num_bands = int(0.5 * num_electrons * nspin)
Expand All @@ -259,6 +264,7 @@ def get_wannier_number_of_bands_ext(
factor=1.2,
only_valence=False,
spin_polarized=False,
spin_non_collinear: bool = False,
spin_orbit_coupling: bool = False,
):
"""Estimate number of bands for a Wannier90 calculation.
Expand All @@ -272,6 +278,8 @@ def get_wannier_number_of_bands_ext(
:type only_valence: bool
:param spin_polarized: magnetic calculation?
:type spin_polarized: bool
:param spin_non_collinear: non-collinear or spin-orbit-coupling
:type spin_non_collinear: bool
:param spin_orbit_coupling: spin orbit coupling calculation?
:type spin_orbit_coupling: bool
:return: number of bands for Wannier90 SCDM
Expand All @@ -289,9 +297,9 @@ def get_wannier_number_of_bands_ext(

num_electrons = get_number_of_electrons(structure, pseudos)
num_projections = get_number_of_projections_ext(
structure, external_projectors, spin_orbit_coupling
structure, external_projectors, spin_non_collinear, spin_orbit_coupling
)
nspin = 2 if (spin_polarized or spin_orbit_coupling) else 1
nspin = 2 if (spin_polarized or spin_non_collinear) else 1
# TODO check nospin, spin, soc # pylint: disable=fixme
if only_valence:
num_bands = int(0.5 * num_electrons * nspin)
Expand All @@ -309,6 +317,7 @@ def get_wannier_number_of_bands_ext(
def get_number_of_projections(
structure: orm.StructureData,
pseudos: ty.Mapping[str, orm.UpfData],
spin_non_collinear: bool,
spin_orbit_coupling: ty.Optional[bool] = None,
) -> int:
"""Get number of projections for the structure with the given pseudopotential files.
Expand All @@ -320,6 +329,8 @@ def get_number_of_projections(
:type structure: aiida.orm.StructureData
:param pseudos: a dictionary contains orm.UpfData of the structure
:type pseudos: dict
:param spin_non_collinear: non-collinear or spin-orbit-coupling
:type spin_non_collinear: bool
:return: number of projections
:rtype: int
"""
Expand Down Expand Up @@ -357,12 +368,14 @@ def get_number_of_projections(
upf = pseudos[kind]
nprojs = get_number_of_projections_from_upf(upf)
soc = is_soc_pseudo(get_upf_content(pseudos[kind]))
if spin_orbit_coupling and not soc:
# For SOC calculation with non-SOC pseudo, QE will generate
if spin_non_collinear and not soc:
# For magnetic calculation with non-SOC pseudo, QE will generate
# 2 PSWFCs from each one PSWFC in the pseudo
# For collinear-magnetic calculation, spin up and down will calc
# seperately, so nprojs do not times 2
nprojs *= 2
elif not spin_orbit_coupling and soc:
# For non-SOC calculation with SOC pseudo, QE will average
elif not spin_non_collinear and soc:
# For non-magnetic calculation with SOC pseudo, QE will average
# the 2 PSWFCs into one
nprojs //= 2
tot_nprojs += nprojs * composition[kind]
Expand All @@ -373,6 +386,7 @@ def get_number_of_projections(
def get_number_of_projections_ext(
structure: orm.StructureData,
external_projectors: dict,
spin_non_collinear: bool,
spin_orbit_coupling: bool = False,
) -> int:
"""Get number of projections for the structure with the given projector dict.
Expand All @@ -381,6 +395,8 @@ def get_number_of_projections_ext(
:type structure: aiida.orm.StructureData
:param projectors: a dictionary contains projector list of the structure
:type pseudos: dict
:param spin_non_collinear: non-collinear or spin-orbit-coupling
:type spin_non_collinear: bool
:return: number of projections
:rtype: int
"""
Expand All @@ -407,6 +423,8 @@ def get_number_of_projections_ext(
nprojs += round(2 * orb["j"]) + 1
else:
nprojs += 2 * orb["l"] + 1
if spin_non_collinear and not spin_orbit_coupling:
nprojs *= 2
tot_nprojs += nprojs * composition[kind]

return tot_nprojs
Expand Down
4 changes: 4 additions & 0 deletions src/aiida_wannier90_workflows/workflows/base/wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,13 @@ def get_builder_from_protocol(
factor=meta_parameters["num_bands_factor"],
only_valence=only_valence,
spin_polarized=spin_polarized,
spin_non_collinear=spin_non_collinear,
spin_orbit_coupling=spin_orbit_coupling,
)
num_projs = get_number_of_projections_ext(
structure=structure,
external_projectors=external_projectors,
spin_non_collinear=spin_non_collinear,
spin_orbit_coupling=spin_orbit_coupling,
)
else:
Expand All @@ -344,11 +346,13 @@ def get_builder_from_protocol(
factor=meta_parameters["num_bands_factor"],
only_valence=only_valence,
spin_polarized=spin_polarized,
spin_non_collinear=spin_non_collinear,
spin_orbit_coupling=spin_orbit_coupling,
)
num_projs = get_number_of_projections(
structure=structure,
pseudos=pseudos,
spin_non_collinear=spin_non_collinear,
spin_orbit_coupling=spin_orbit_coupling,
)

Expand Down
19 changes: 16 additions & 3 deletions src/aiida_wannier90_workflows/workflows/wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,11 +995,22 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements
check_num_projs = True
if self.should_run_scf():
pseudos = self.inputs["scf"]["pw"]["pseudos"]
spin_orbit_coupling = (
self.inputs["scf"]["pw"]["parameters"]
.get_dict()["SYSTEM"]
.get("SYSTEM", False)
)
elif self.should_run_nscf():
pseudos = self.inputs["nscf"]["pw"]["pseudos"]
spin_orbit_coupling = (
self.inputs["nscf"]["pw"]["parameters"]
.get_dict()["SYSTEM"]
.get("SYSTEM", False)
)
else:
check_num_projs = False
pseudos = None # to avoid pylint errors
pseudos = None
spin_orbit_coupling = None
if check_num_projs:
args = {
"structure": self.ctx.current_structure,
Expand All @@ -1014,9 +1025,11 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements
params = self.ctx.workchain_wannier90.inputs["wannier90"][
"parameters"
].get_dict()
spin_orbit_coupling = params.get("spinors", False)
spin_non_collinear = params.get("spinors", False)
number_of_projections = get_number_of_projections(
**args, spin_orbit_coupling=spin_orbit_coupling
**args,
spin_non_collinear=spin_non_collinear,
spin_orbit_coupling=spin_orbit_coupling,
)
if number_of_projections != num_proj:
self.report(
Expand Down

0 comments on commit 8f03829

Please sign in to comment.