Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
ElliottKasoar and oerc0122 committed Nov 8, 2024
1 parent 30fcace commit 15dfa87
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 26 deletions.
20 changes: 10 additions & 10 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ class Phonons(BaseCalculation):
Keyword arguments to pass to geometry optimizer. Default is {}.
n_qpoints : int
Number of q-points to sample along generated path, including end points.
Unused if `paths` is specified. Default is 51.
paths : PathLike | None
Unused if `qpoint_file` is specified. Default is 51.
qpoint_file : PathLike | None
Path to yaml file with info to generate a path of q-points for band structure.
Default is None.
dos_kwargs : dict[str, Any] | None
Expand Down Expand Up @@ -176,7 +176,7 @@ def __init__(
minimize: bool = False,
minimize_kwargs: dict[str, Any] | None = None,
n_qpoints: int = 51,
paths: PathLike | None = None,
qpoint_file: PathLike | None = None,
dos_kwargs: dict[str, Any] | None = None,
pdos_kwargs: dict[str, Any] | None = None,
temp_min: float = 0.0,
Expand Down Expand Up @@ -248,8 +248,8 @@ def __init__(
Keyword arguments to pass to geometry optimizer. Default is {}.
n_qpoints : int
Number of q-points to sample along generated path, including end points.
Unused if `paths` is specified. Default is 51.
paths : PathLike | None
Unused if `qpoint_file` is specified. Default is 51.
qpoint_file : PathLike | None
Path to yaml file with info to generate a path of q-points for band
structure. Default is None.
dos_kwargs : dict[str, Any] | None
Expand Down Expand Up @@ -298,7 +298,7 @@ def __init__(
self.minimize = minimize
self.minimize_kwargs = minimize_kwargs
self.n_qpoints = n_qpoints
self.paths = paths
self.qpoint_file = qpoint_file
self.dos_kwargs = dos_kwargs
self.pdos_kwargs = pdos_kwargs
self.temp_min = temp_min
Expand Down Expand Up @@ -584,14 +584,14 @@ def write_bands(
if save_plots is None:
save_plots = self.plot_to_file

if self.paths:
if self.qpoint_file:
bands_file = self._build_filename("bands.yml.xz", filename=bands_file)

with open(self.paths, encoding="utf8") as paths_file:
paths_info = safe_load(paths_file)
with open(self.qpoint_file, encoding="utf8") as file:
paths_info = safe_load(file)

labels = paths_info["labels"]
num_q_points = sum([len(q) for q in paths_info["paths"]])
num_q_points = sum(len(q) for q in paths_info["paths"])
num_labels = len(labels)
assert (
num_q_points == num_labels
Expand Down
10 changes: 5 additions & 5 deletions janus_core/cli/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def phonons(
Option(
help=(
"Number of q-points to sample along generated path, including end "
"points. Unused if `paths` is specified"
"points. Unused if `qpoint_file` is specified"
)
),
] = 51,
paths: Annotated[
qpoint_file: Annotated[
Optional[Path],
Option(
help=(
Expand Down Expand Up @@ -166,8 +166,8 @@ def phonons(
Whether to calculate and save the band structure. Default is False.
n_qpoints : int
Number of q-points to sample along generated path, including end points.
Unused if `paths` is specified. Default is 51.
paths : Optional[PathLike]
Unused if `qpoint_file` is specified. Default is 51.
qpoint_file : Optional[PathLike]
Path to yaml file with info to generate a path of q-points for band structure.
Default is None.
dos : bool
Expand Down Expand Up @@ -322,7 +322,7 @@ def phonons(
"minimize": minimize,
"minimize_kwargs": minimize_kwargs,
"n_qpoints": n_qpoints,
"paths": paths,
"qpoint_file": qpoint_file,
"dos_kwargs": dos_kwargs,
"pdos_kwargs": pdos_kwargs,
"temp_min": temp_min,
Expand Down
21 changes: 10 additions & 11 deletions tests/test_phonons_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ def test_dos(tmp_path):
)
assert result.exit_code == 0
assert dos_results.exists()
with open(dos_results, encoding="utf8") as file:
lines = file.readlines()
lines = dos_results.read_text().splitlines()
assert lines[1].split()[0] == "-1.0000000000"
assert lines[-1].split()[0] == "0.0000000000"

Expand Down Expand Up @@ -462,7 +461,7 @@ def test_no_carbon(tmp_path):


def test_displacement_kwargs(tmp_path):
"""Test displacment_kwargs can be set."""
"""Test displacement_kwargs can be set."""
file_prefix_1 = tmp_path / "NaCl_1"
file_prefix_2 = tmp_path / "NaCl_2"
displacement_file_1 = tmp_path / "NaCl_1-phonopy.yml"
Expand Down Expand Up @@ -501,21 +500,21 @@ def test_displacement_kwargs(tmp_path):
# Check parameters
with open(displacement_file_1, encoding="utf8") as file:
params = yaml.safe_load(file)
n_displacments_1 = len(params["displacements"])
n_displacements_1 = len(params["displacements"])

assert n_displacments_1 == 4
assert n_displacements_1 == 4

with open(displacement_file_2, encoding="utf8") as file:
params = yaml.safe_load(file)
n_displacments_2 = len(params["displacements"])
n_displacements_2 = len(params["displacements"])

assert n_displacments_2 == 2
assert n_displacements_2 == 2


def test_paths(tmp_path):
"""Test displacment_kwargs can be set."""
"""Test displacement_kwargs can be set."""
file_prefix = tmp_path / "NaCl"
paths = DATA_PATH / "paths.yml"
qpoint_file = DATA_PATH / "paths.yml"
band_results = tmp_path / "NaCl-bands.yml.xz"

result = runner.invoke(
Expand All @@ -526,8 +525,8 @@ def test_paths(tmp_path):
DATA_PATH / "NaCl.cif",
"--no-hdf5",
"--bands",
"--paths",
paths,
"--qpoint-file",
qpoint_file,
"--file-prefix",
file_prefix,
],
Expand Down

0 comments on commit 15dfa87

Please sign in to comment.