Skip to content

Commit

Permalink
Add tests for supercell_energies and forces attributes of API
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Apr 30, 2024
1 parent eecc4b3 commit 1d35ca2
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 25 deletions.
82 changes: 59 additions & 23 deletions phono3py/api_phono3py.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import copy
from collections.abc import Sequence
from typing import Literal, Optional, Union

Expand Down Expand Up @@ -610,7 +611,20 @@ def dataset(self):

@dataset.setter
def dataset(self, dataset):
self._dataset = dataset
if dataset is None:
self._dataset = None
elif "first_atoms" in dataset:
self._dataset = copy.deepcopy(dataset)
elif "displacements" in dataset:
self._dataset = {}
self.displacements = dataset["displacements"]
if "forces" in dataset:
self.forces = dataset["forces"]
if "supercell_energies" in dataset:
self.supercell_energies = dataset["supercell_energies"]
else:
raise RuntimeError("Data format of dataset is wrong.")

self._supercells_with_displacements = None
self._phonon_supercells_with_displacements = None

Expand Down Expand Up @@ -643,7 +657,21 @@ def phonon_dataset(self):

@phonon_dataset.setter
def phonon_dataset(self, dataset):
self._phonon_dataset = dataset
if dataset is None:
self._phonon_dataset = None
elif "first_atoms" in dataset:
self._phonon_dataset = copy.deepcopy(dataset)
elif "displacements" in dataset:
self._phonon_dataset = {}
self.phonon_displacements = dataset["displacements"]
if "forces" in dataset:
self.phonon_forces = dataset["forces"]
if "supercell_energies" in dataset:
self.phonon_supercell_energies = dataset["supercell_energies"]
else:
raise RuntimeError("Data format of dataset is wrong.")

self._phonon_supercells_with_displacements = None

@property
def band_indices(self):
Expand Down Expand Up @@ -929,7 +957,7 @@ def phonon_forces(self):
be the same order of phonon_supercells_with_displacements.
"""
self._get_phonon_forces_energies(target="forces")
return self._get_phonon_forces_energies(target="forces")

@phonon_forces.setter
def phonon_forces(self, values):
Expand All @@ -948,7 +976,7 @@ def phonon_supercell_energies(self):
to be the same order of phonon_supercells_with_displacements.
"""
self._get_phonon_forces_energies(target="supercell_energies")
return self._get_phonon_forces_energies(target="supercell_energies")

@phonon_supercell_energies.setter
def phonon_supercell_energies(self, values):
Expand Down Expand Up @@ -2261,10 +2289,15 @@ def _extract_fc2_fc3_calculators(self, fc_calculator, fc_calculator_options, ord

def _get_forces_energies(
self, target: Literal["forces", "supercell_energies"]
) -> np.ndarray:
if target in self._dataset:
) -> Optional[np.ndarray]:
"""Return fc3 forces and supercell energies.
Return None if tagert data is not found rather than raising exception.
"""
if target in self._dataset: # type-2
return self._dataset[target]
elif "first_atoms" in self._dataset:
elif "first_atoms" in self._dataset: # type-1
num_scells = len(self._dataset["first_atoms"])
for disp1 in self._dataset["first_atoms"]:
num_scells += len(disp1["second_atoms"])
Expand All @@ -2274,30 +2307,31 @@ def _get_forces_energies(
dtype="double",
order="C",
)
type1_target = "forces"
elif target == "supercell_energies":
values = np.zeros(num_scells, dtype="double")
type1_target = "supercell_energy"
count = 0
for disp1 in self._dataset["first_atoms"]:
values[count] = disp1[target]
values[count] = disp1[type1_target]
count += 1
for disp1 in self._dataset["first_atoms"]:
for disp2 in disp1["second_atoms"]:
values[count] = disp2[target]
values[count] = disp2[type1_target]
count += 1
return values
else:
raise RuntimeError("FC3 displacement dataset is in wrong format.")
return None

def _set_forces_energies(
self, values, target: Literal["forces", "supercell_energies"]
):
if "first_atoms" in self._dataset:
if "first_atoms" in self._dataset: # type-1
count = 0
for disp1 in self._dataset["first_atoms"]:
if target == "forces":
disp1[target] = np.array(values[count], dtype="double", order="C")
elif target == "supercell_energies":
disp1[target] = float(values[count])
disp1["supercell_energy"] = float(values[count])
count += 1
for disp1 in self._dataset["first_atoms"]:
for disp2 in disp1["second_atoms"]:
Expand All @@ -2306,22 +2340,27 @@ def _set_forces_energies(
values[count], dtype="double", order="C"
)
elif target == "supercell_energies":
disp2[target] = float(values[count])
disp2["supercell_energy"] = float(values[count])
count += 1
elif "displacements" in self._dataset or "forces" in self._dataset:
elif "displacements" in self._dataset or "forces" in self._dataset: # type-2
self._dataset[target] = np.array(values, dtype="double", order="C")
else:
raise RuntimeError("Set of FC3 displacements is not available.")

def _get_phonon_forces_energies(
self, target: Literal["forces", "supercell_energies"]
):
) -> Optional[np.ndarray]:
"""Return fc2 forces and supercell energies.
Return None if tagert data is not found rather than raising exception.
"""
if self._phonon_dataset is None:
raise RuntimeError("Dataset for fc2does not exist.")

if target in self._phonon_dataset:
if target in self._phonon_dataset: # type-2
return self._phonon_dataset[target]
elif "first_atoms" in self._phonon_dataset:
elif "first_atoms" in self._phonon_dataset: # type-1
values = []
for disp in self._phonon_dataset["first_atoms"]:
if target == "forces":
Expand All @@ -2332,10 +2371,7 @@ def _get_phonon_forces_energies(
values.append(disp["supercell_energy"])
if values:
return np.array(values, dtype="double", order="C")
else:
None
else:
raise RuntimeError("FC2 displacement dataset is in wrong format.")
return None

def _set_phonon_forces_energies(
self, values, target: Literal["forces", "supercell_energies"]
Expand All @@ -2351,7 +2387,7 @@ def _set_phonon_forces_energies(
disp["supercell_energy"] = float(v)
elif "displacements" in self._phonon_dataset:
_values = np.array(values, dtype="double", order="C")
natom = len(self._supercell)
natom = len(self._phonon_supercell)
ndisps = len(self._phonon_dataset["displacements"])
if target == "forces" and (
_values.ndim != 3 or _values.shape != (ndisps, natom, 3)
Expand Down
126 changes: 126 additions & 0 deletions test/api/test_api_phono3py.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Tests of Phono3py API."""

from __future__ import annotations

from pathlib import Path

import numpy as np

from phono3py import Phono3py

cwd = Path(__file__).parent
Expand Down Expand Up @@ -40,3 +44,125 @@ def test_displacements_setter_Si(si_pbesol_111_222_fd: Phono3py):
)
ph3.displacements = displacements
ph3.phonon_displacements = phonon_displacements


def test_type1_forces_energies_setter_Si(si_111_222_fd: Phono3py):
"""Test type1 supercell_energies, phonon_supercell_energies attributes."""
ph3_in = si_111_222_fd
ref_ph_supercell_energies = [-346.85204143]
ref_supercell_energies = [
-43.3509760,
-43.33608775,
-43.35352904,
-43.34370672,
-43.34590849,
-43.34540162,
-43.34421408,
-43.34481089,
-43.34703607,
-43.34241924,
-43.34786243,
-43.34168203,
-43.34274245,
-43.34703607,
-43.34786243,
-43.34184454,
]
np.testing.assert_allclose(ph3_in.supercell_energies, ref_supercell_energies)
np.testing.assert_allclose(
ph3_in.phonon_supercell_energies, ref_ph_supercell_energies
)

ref_force00 = [-0.4109520800000000, 0.0000000100000000, 0.0000000300000000]
ref_force_last = [0.1521426300000000, 0.0715600600000000, -0.0715600700000000]
ref_ph_force00 = [-0.4027479600000000, 0.0000000200000000, 0.0000001000000000]
np.testing.assert_allclose(ph3_in.forces[0, 0], ref_force00)
np.testing.assert_allclose(ph3_in.forces[-1, -1], ref_force_last)
np.testing.assert_allclose(ph3_in.phonon_forces[0, 0], ref_ph_force00)

ph3 = Phono3py(
ph3_in.unitcell,
supercell_matrix=ph3_in.supercell_matrix,
phonon_supercell_matrix=ph3_in.phonon_supercell_matrix,
primitive_matrix=ph3_in.primitive_matrix,
)
ph3.dataset = ph3_in.dataset
ph3.phonon_dataset = ph3_in.phonon_dataset

ph3.supercell_energies = ph3_in.supercell_energies + 1
ph3.phonon_supercell_energies = ph3_in.phonon_supercell_energies + 1
np.testing.assert_allclose(ph3_in.supercell_energies + 1, ph3.supercell_energies)
np.testing.assert_allclose(
ph3_in.phonon_supercell_energies + 1, ph3.phonon_supercell_energies
)

ph3.forces = ph3_in.forces + 1
ph3.phonon_forces = ph3_in.phonon_forces + 1
np.testing.assert_allclose(ph3_in.forces + 1, ph3.forces)
np.testing.assert_allclose(ph3_in.phonon_forces + 1, ph3.phonon_forces)


def test_type2_forces_energies_setter_Si(si_111_222_rd: Phono3py):
"""Test type2 supercell_energies, phonon_supercell_energies attributes."""
ph3_in = si_111_222_rd
ref_ph_supercell_energies = [
-346.81061270, # 1
-346.81263617, # 2
]
ref_supercell_energies = [
-43.35270268, # 1
-43.35211687, # 2
-43.35122776, # 3
-43.35226673, # 4
-43.35146358, # 5
-43.35133209, # 6
-43.35042212, # 7
-43.35008442, # 8
-43.34968796, # 9
-43.35348999, # 10
-43.35134937, # 11
-43.35335251, # 12
-43.35160892, # 13
-43.35009115, # 14
-43.35202797, # 15
-43.35076370, # 16
-43.35174477, # 17
-43.35107001, # 18
-43.35037949, # 19
-43.35126123, # 20
]
np.testing.assert_allclose(ph3_in.supercell_energies, ref_supercell_energies)
np.testing.assert_allclose(
ph3_in.phonon_supercell_energies, ref_ph_supercell_energies
)

ref_force00 = [0.0445647800000000, 0.1702929900000000, 0.0913398200000000]
ref_force_last = [-0.1749668700000000, 0.0146997300000000, -0.1336066300000000]
ref_ph_force00 = [-0.0161598900000000, -0.1161657500000000, 0.1399128100000000]
ref_ph_force_last = [0.1049486700000000, 0.0795870900000000, 0.1062164600000000]

np.testing.assert_allclose(ph3_in.forces[0, 0], ref_force00)
np.testing.assert_allclose(ph3_in.forces[-1, -1], ref_force_last)
np.testing.assert_allclose(ph3_in.phonon_forces[0, 0], ref_ph_force00)
np.testing.assert_allclose(ph3_in.phonon_forces[-1, -1], ref_ph_force_last)

ph3 = Phono3py(
ph3_in.unitcell,
supercell_matrix=ph3_in.supercell_matrix,
phonon_supercell_matrix=ph3_in.phonon_supercell_matrix,
primitive_matrix=ph3_in.primitive_matrix,
)
ph3.dataset = ph3_in.dataset
ph3.phonon_dataset = ph3_in.phonon_dataset
ph3.supercell_energies = ph3_in.supercell_energies + 1
ph3.phonon_supercell_energies = ph3_in.phonon_supercell_energies + 1

np.testing.assert_allclose(ph3_in.supercell_energies + 1, ph3.supercell_energies)
np.testing.assert_allclose(
ph3_in.phonon_supercell_energies + 1, ph3.phonon_supercell_energies
)

ph3.forces = ph3_in.forces + 1
ph3.phonon_forces = ph3_in.phonon_forces + 1
np.testing.assert_allclose(ph3_in.forces + 1, ph3.forces)
np.testing.assert_allclose(ph3_in.phonon_forces + 1, ph3.phonon_forces)
18 changes: 16 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,20 @@ def aln_lda(request) -> Phono3py:
)


@pytest.fixture(scope="session")
def si_111_222_fd() -> Phono3py:
"""Return Phono3py class instance of Si-1x1x1-2x2x2 FD."""
yaml_filename = cwd / "phono3py_params_Si-111-222-fd.yaml.xz"
return phono3py.load(yaml_filename, produce_fc=False, log_level=1)


@pytest.fixture(scope="session")
def si_111_222_rd() -> Phono3py:
"""Return Phono3py class instance of Si-1x1x1-2x2x2 RD."""
yaml_filename = cwd / "phono3py_params_Si-111-222-rd.yaml.xz"
return phono3py.load(yaml_filename, produce_fc=False, log_level=1)


@pytest.fixture(scope="session")
def ph_nacl() -> Phonopy:
"""Return Phonopy class instance of NaCl 2x2x2."""
Expand Down Expand Up @@ -512,7 +526,7 @@ def ph_si() -> Phonopy:


@pytest.fixture(scope="session")
def si_111_222_fd() -> tarfile.TarFile:
def si_111_222_fd_raw_data() -> tarfile.TarFile:
"""Return Si fc3 111 fc2 222 vasp inputs.
tar.getnames()
Expand Down Expand Up @@ -546,7 +560,7 @@ def si_111_222_fd() -> tarfile.TarFile:


@pytest.fixture(scope="session")
def si_111_222_rd() -> tarfile.TarFile:
def si_111_222_rd_raw_data() -> tarfile.TarFile:
"""Return Si fc3 111 fc2 222 vasp inputs.
tar.getnames()
Expand Down

0 comments on commit 1d35ca2

Please sign in to comment.