From ce710bb36960fa15a338c7efb58cc8b39ca731af Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Mon, 4 Nov 2024 11:11:29 +0000 Subject: [PATCH] Fix tests and respond to comments --- janus_core/processing/post_process.py | 26 ++++++++++++++++++++++++-- tests/test_post_process.py | 11 +++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/janus_core/processing/post_process.py b/janus_core/processing/post_process.py index 7eba7f23..9fe7edf0 100644 --- a/janus_core/processing/post_process.py +++ b/janus_core/processing/post_process.py @@ -191,7 +191,7 @@ def compute_vaf( index: SliceLike = (0, None, 1), filter_atoms: MaybeSequence[MaybeSequence[Optional[int]]] = ((None),), time_step: float = 1.0, -) -> NDArray[float64]: +) -> tuple[NDArray[float64], list[NDArray[float64]]]: """ Compute the velocity autocorrelation function (VAF) of `data`. @@ -219,8 +219,30 @@ def compute_vaf( Returns ------- - MaybeSequence[NDArray[float64]] + lags : numpy.ndarray + Lags at which the VAFs have been computed. + vafs : list[numpy.ndarray] Computed VAF(s). + + Notes + ----- + `filter_atoms` is given as a series of sequences of atoms, where + each element in the series denotes a VAF subset to calculate and + each sequence determines the atoms (by index) to be included in that VAF. + + E.g. + + .. code-block: Python + + # Species indices in cell + na = (1, 3, 5, 7) + cl = (2, 4, 6, 8) + + compute_vaf(..., filter_atoms=(na, cl)) + + Would compute separate VAFs for each species. + + By default, one VAF will be computed for all atoms in the structure. """ # Ensure if passed scalars they are turned into correct dimensionality if not isinstance(filter_atoms, Sequence): diff --git a/tests/test_post_process.py b/tests/test_post_process.py index 29765845..b1a3a600 100644 --- a/tests/test_post_process.py +++ b/tests/test_post_process.py @@ -179,7 +179,7 @@ def test_vaf(tmp_path): vaf_filter = ((3, 4), (1, 2, 3)) data = read(DATA_PATH / "lj-traj.xyz", index=":") - vaf = post_process.compute_vaf(data) + lags, vaf = post_process.compute_vaf(data) expected = np.loadtxt(DATA_PATH / "vaf-lj.dat") assert isinstance(vaf, list) @@ -187,13 +187,13 @@ def test_vaf(tmp_path): assert isinstance(vaf[0], np.ndarray) assert vaf[0] == approx(expected, rel=1e-9) - vaf = post_process.compute_vaf(data, fft=True) + lags, vaf = post_process.compute_vaf(data, fft=True) assert isinstance(vaf, list) assert len(vaf) == 1 assert isinstance(vaf[0], np.ndarray) - vaf = post_process.compute_vaf( + lags, vaf = post_process.compute_vaf( data, filter_atoms=vaf_filter, filenames=[tmp_path / name for name in vaf_names] ) @@ -205,5 +205,8 @@ def test_vaf(tmp_path): assert (tmp_path / name).exists() expected = np.loadtxt(DATA_PATH / name) written = np.loadtxt(tmp_path / name) + w_lag, w_vaf = written[:, 0], written[:, 1] + assert vaf[i] == approx(expected, rel=1e-9) - assert vaf[i] == approx(written, rel=1e-9) + assert lags == approx(w_lag, rel=1e-9) + assert vaf[i] == approx(w_vaf, rel=1e-9)