From bad0bd5d5b5cee3946e980dff76cc2d3defe17b0 Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Tue, 22 Oct 2024 16:20:51 +0100 Subject: [PATCH] Create SliceLike utils --- janus_core/helpers/observables.py | 12 ++----- janus_core/helpers/post_process.py | 31 ++---------------- janus_core/helpers/utils.py | 52 ++++++++++++++++++++++++++++++ tests/test_utils.py | 42 +++++++++++++++++++++++- 4 files changed, 98 insertions(+), 39 deletions(-) diff --git a/janus_core/helpers/observables.py b/janus_core/helpers/observables.py index bdccffa4..cd268867 100644 --- a/janus_core/helpers/observables.py +++ b/janus_core/helpers/observables.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from janus_core.helpers.janus_types import SliceLike + from janus_core.helpers.utils import slicelike_len_for # pylint: disable=too-few-public-methods @@ -81,16 +82,7 @@ def atom_count(self, n_atoms: int): if self.atoms: if isinstance(self.atoms, list): return len(self.atoms) - if isinstance(self.atoms, int): - return 1 - - start = self.atoms.start - stop = self.atoms.stop - step = self.atoms.step - start = start if start is None else 0 - stop = stop if stop is None else n_atoms - step = step if step is None else 1 - return len(range(start, stop, step)) + return slicelike_len_for(self.n_atoms) return 0 diff --git a/janus_core/helpers/post_process.py b/janus_core/helpers/post_process.py index a750ad32..44955f22 100644 --- a/janus_core/helpers/post_process.py +++ b/janus_core/helpers/post_process.py @@ -14,33 +14,8 @@ MaybeSequence, PathLike, SliceLike, - StartStopStep, ) - - -def _process_index(index: SliceLike) -> StartStopStep: - """ - Standarize `SliceLike`s into tuple of `start`, `stop`, `step`. - - Parameters - ---------- - index : SliceLike - `SliceLike` to standardize. - - Returns - ------- - StartStopStep - Standardized `SliceLike` as `start`, `stop`, `step` triplet. - """ - if isinstance(index, int): - if index == -1: - return (index, None, 1) - return (index, index + 1, 1) - - if isinstance(index, (slice, range)): - return (index.start, index.stop, index.step) - - return index +from janus_core.helpers.utils import slicelike_to_startstopstep def compute_rdf( @@ -93,7 +68,7 @@ def compute_rdf( If `by_elements` is true returns a `dict` of RDF by element pairs. Otherwise returns RDF of total system filtered by elements. """ - index = _process_index(index) + index = slicelike_to_startstopstep(index) if not isinstance(data, Sequence): data = [data] @@ -234,7 +209,7 @@ def compute_vaf( ) # Extract requested data - index = _process_index(index) + index = slicelike_to_startstopstep(index) data = data[slice(*index)] if use_velocities: diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 3573465c..ee7b4429 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -32,6 +32,8 @@ MaybeSequence, PathLike, Properties, + SliceLike, + StartStopStep, ) from janus_core.helpers.mlip_calculators import choose_calculator @@ -771,3 +773,53 @@ def check_calculator(calc: Calculator, attribute: str) -> None: raise NotImplementedError( f"The attached calculator does not currently support {attribute}" ) + + +def slicelike_to_startstopstep(index: SliceLike) -> StartStopStep: + """ + Standarize `SliceLike`s into tuple of `start`, `stop`, `step`. + + Parameters + ---------- + index : SliceLike + `SliceLike` to standardize. + + Returns + ------- + StartStopStep + Standardized `SliceLike` as `start`, `stop`, `step` triplet. + """ + if isinstance(index, int): + if index == -1: + return (index, None, 1) + return (index, index + 1, 1) + + if isinstance(index, (slice, range)): + return (index.start, index.stop, index.step) + + return index + + +def slicelike_len_for(slc: SliceLike, sliceable_length: int) -> int: + """ + Calculate the length of a SliceLike applied to a sliceable of a given length. + + Parameters + ---------- + slc : SliceLike + The applied SliceLike. + sliceable_length : int + The length of the sliceable object. + + Returns + ------- + int + Length of the result of applying slc. + """ + start, stop, step = slicelike_to_startstopstep(slc) + if stop is None: + stop = sliceable_length + # start = start if start is None else 0 + # stop = stop if stop is None else sliceable_length + # step = step if step is None else 1 + return len(range(start, stop, step)) diff --git a/tests/test_utils.py b/tests/test_utils.py index aa37d41f..a2d8841a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,8 +7,14 @@ import pytest from janus_core.cli.utils import dict_paths_to_strs, dict_remove_hyphens +from janus_core.helpers.janus_types import SliceLike, StartStopStep from janus_core.helpers.mlip_calculators import choose_calculator -from janus_core.helpers.utils import none_to_dict, output_structs +from janus_core.helpers.utils import ( + none_to_dict, + output_structs, + slicelike_len_for, + slicelike_to_startstopstep, +) DATA_PATH = Path(__file__).parent / "data/NaCl.cif" MODEL_PATH = Path(__file__).parent / "models/mace_mp_small.model" @@ -154,3 +160,37 @@ def test_none_to_dict(dicts_in): assert dicts[2] == dicts_in[2] assert dicts[3] == dicts_in[3] assert dicts[4] == {} + + +@pytest.mark.parametrize( + "slc, expected", + [ + ((1, 2, 3), (1, 2, 3)), + (1, (1, 2, 1)), + (range(1, 2, 3), (1, 2, 3)), + (slice(1, 2, 3), (1, 2, 3)), + (-1, (-1, None, 1)), + (range(10), (0, 10, 1)), + (slice(0, None, 1), (0, None, 1)), + ], +) +def test_slicelike_to_startstopstep(slc: SliceLike, expected: StartStopStep): + """Test converting SliceLike to StartStopStep.""" + assert slicelike_to_startstopstep(slc) == expected + + +@pytest.mark.parametrize( + "slc_len, expected", + [ + (((1, 2, 3), 3), 1), + ((1, 1), 1), + ((range(1, 2, 3), 3), 1), + ((slice(1, 2, 3), 3), 1), + ((-1, 1), 2), + ((range(10), 10), 10), + ((slice(0, None, 2), 10), 5), + ], +) +def test_slicelike_len_for(slc_len: tuple[SliceLike, int], expected: int): + """Test converting SliceLike to StartStopStep.""" + assert slicelike_len_for(*slc_len) == expected