Skip to content

Commit

Permalink
Create SliceLike utils
Browse files Browse the repository at this point in the history
  • Loading branch information
harveydevereux committed Oct 22, 2024
1 parent eb931ad commit bad0bd5
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 39 deletions.
12 changes: 2 additions & 10 deletions janus_core/helpers/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if TYPE_CHECKING:
from janus_core.helpers.janus_types import SliceLike
from janus_core.helpers.utils import slicelike_len_for


# pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -81,16 +82,7 @@ def atom_count(self, n_atoms: int):
if self.atoms:
if isinstance(self.atoms, list):
return len(self.atoms)
if isinstance(self.atoms, int):
return 1

start = self.atoms.start
stop = self.atoms.stop
step = self.atoms.step
start = start if start is None else 0
stop = stop if stop is None else n_atoms
step = step if step is None else 1
return len(range(start, stop, step))
return slicelike_len_for(self.n_atoms)
return 0


Expand Down
31 changes: 3 additions & 28 deletions janus_core/helpers/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,8 @@
MaybeSequence,
PathLike,
SliceLike,
StartStopStep,
)


def _process_index(index: SliceLike) -> StartStopStep:
"""
Standarize `SliceLike`s into tuple of `start`, `stop`, `step`.
Parameters
----------
index : SliceLike
`SliceLike` to standardize.
Returns
-------
StartStopStep
Standardized `SliceLike` as `start`, `stop`, `step` triplet.
"""
if isinstance(index, int):
if index == -1:
return (index, None, 1)
return (index, index + 1, 1)

if isinstance(index, (slice, range)):
return (index.start, index.stop, index.step)

return index
from janus_core.helpers.utils import slicelike_to_startstopstep


def compute_rdf(
Expand Down Expand Up @@ -93,7 +68,7 @@ def compute_rdf(
If `by_elements` is true returns a `dict` of RDF by element pairs.
Otherwise returns RDF of total system filtered by elements.
"""
index = _process_index(index)
index = slicelike_to_startstopstep(index)

if not isinstance(data, Sequence):
data = [data]
Expand Down Expand Up @@ -234,7 +209,7 @@ def compute_vaf(
)

# Extract requested data
index = _process_index(index)
index = slicelike_to_startstopstep(index)
data = data[slice(*index)]

if use_velocities:
Expand Down
52 changes: 52 additions & 0 deletions janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
MaybeSequence,
PathLike,
Properties,
SliceLike,
StartStopStep,
)
from janus_core.helpers.mlip_calculators import choose_calculator

Expand Down Expand Up @@ -771,3 +773,53 @@ def check_calculator(calc: Calculator, attribute: str) -> None:
raise NotImplementedError(
f"The attached calculator does not currently support {attribute}"
)


def slicelike_to_startstopstep(index: SliceLike) -> StartStopStep:
"""
Standarize `SliceLike`s into tuple of `start`, `stop`, `step`.
Parameters
----------
index : SliceLike
`SliceLike` to standardize.
Returns
-------
StartStopStep
Standardized `SliceLike` as `start`, `stop`, `step` triplet.
"""
if isinstance(index, int):
if index == -1:
return (index, None, 1)
return (index, index + 1, 1)

if isinstance(index, (slice, range)):
return (index.start, index.stop, index.step)

return index


def slicelike_len_for(slc: SliceLike, sliceable_length: int) -> int:
"""
Calculate the length of a SliceLike applied to a sliceable of a given length.
Parameters
----------
slc : SliceLike
The applied SliceLike.
sliceable_length : int
The length of the sliceable object.
Returns
-------
int
Length of the result of applying slc.
"""
start, stop, step = slicelike_to_startstopstep(slc)
if stop is None:
stop = sliceable_length
# start = start if start is None else 0
# stop = stop if stop is None else sliceable_length
# step = step if step is None else 1
return len(range(start, stop, step))
42 changes: 41 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
import pytest

from janus_core.cli.utils import dict_paths_to_strs, dict_remove_hyphens
from janus_core.helpers.janus_types import SliceLike, StartStopStep
from janus_core.helpers.mlip_calculators import choose_calculator
from janus_core.helpers.utils import none_to_dict, output_structs
from janus_core.helpers.utils import (
none_to_dict,
output_structs,
slicelike_len_for,
slicelike_to_startstopstep,
)

DATA_PATH = Path(__file__).parent / "data/NaCl.cif"
MODEL_PATH = Path(__file__).parent / "models/mace_mp_small.model"
Expand Down Expand Up @@ -154,3 +160,37 @@ def test_none_to_dict(dicts_in):
assert dicts[2] == dicts_in[2]
assert dicts[3] == dicts_in[3]
assert dicts[4] == {}


@pytest.mark.parametrize(
"slc, expected",
[
((1, 2, 3), (1, 2, 3)),
(1, (1, 2, 1)),
(range(1, 2, 3), (1, 2, 3)),
(slice(1, 2, 3), (1, 2, 3)),
(-1, (-1, None, 1)),
(range(10), (0, 10, 1)),
(slice(0, None, 1), (0, None, 1)),
],
)
def test_slicelike_to_startstopstep(slc: SliceLike, expected: StartStopStep):
"""Test converting SliceLike to StartStopStep."""
assert slicelike_to_startstopstep(slc) == expected


@pytest.mark.parametrize(
"slc_len, expected",
[
(((1, 2, 3), 3), 1),
((1, 1), 1),
((range(1, 2, 3), 3), 1),
((slice(1, 2, 3), 3), 1),
((-1, 1), 2),
((range(10), 10), 10),
((slice(0, None, 2), 10), 5),
],
)
def test_slicelike_len_for(slc_len: tuple[SliceLike, int], expected: int):
"""Test converting SliceLike to StartStopStep."""
assert slicelike_len_for(*slc_len) == expected

0 comments on commit bad0bd5

Please sign in to comment.