Skip to content

Commit

Permalink
slice: ensure operators work with all scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
JoepVanlier committed Dec 5, 2024
1 parent 009fa10 commit c86418d
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 8 deletions.
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
* Added improved printing of calibrations performed with `Pylake`.
* Improved error message that includes the name of the model when trying to access a model that was not added in an [`FdFit`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.FdFit.html) using angular brackets.

#### Bug fixes

* Ensure that operators such as (e.g. `+`, `-`, `/`) work on [`Slice`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.channel.Slice.html) with all types that are convertible to scalars. Previously these failed with zero dimensional numpy arrays and other convertible objects.

## v1.5.3 | 2024-10-29

#### Bug fixes
Expand Down
38 changes: 31 additions & 7 deletions lumicks/pylake/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .detail.plotting import _annotate
from .detail.timeindex import to_seconds, to_timestamp
from .detail.utilities import downsample
from .detail.utilities import downsample, convert_to_scalar
from .nb_widgets.range_selector import SliceRangeSelectorWidget


Expand Down Expand Up @@ -72,8 +72,29 @@ def _apply_mask(self, mask):
return self._with_data_source(self._src._apply_mask(mask))

def _unpack_other(self, other):
if np.isscalar(other):
return other
"""Extract raw data from `other`
Extracts the raw data from the other object for use in an arithmetic operation. If it is
convertible to a scalar, this object will be converted to a flat scalar value. If it is a
Slice, the function will verify that the timestamps of both slices are the same. If so,
the data will be extracted and returned.
Parameters
----------
other : array_like | Slice
The object to extract data from.
Raises
------
TypeError
If the object is not scalar-like or a slice.
NotImplementedError
If the object is a TimeTag object.
RuntimeError
If the timestamps of the two slices are not the same.
"""
if (scalar := convert_to_scalar(other)) is not None:
return scalar

if not isinstance(other, Slice):
raise TypeError("Trying to perform operation with incompatible types.")
Expand All @@ -88,7 +109,10 @@ def _unpack_other(self, other):

def _generate_labels(self, lhs, operator, rhs, keep_unit):
def get_label(item, key):
return item.labels.get(key, "") if not np.isscalar(item) else str(item)
try:
return item.labels.get(key, "")
except AttributeError:
return str(item) # scalar value case

labels = {"title": f"({get_label(lhs, 'title')} {operator} {get_label(rhs, 'title')})"}
if keep_unit:
Expand All @@ -110,14 +134,14 @@ def __neg__(self):
def __add__(self, other):
return Slice(
self._src._with_data(self.data + self._unpack_other(other)),
calibration=self._calibration if np.isscalar(other) else None,
calibration=self._calibration if convert_to_scalar(other) is not None else None,
labels=self._generate_labels(self, "+", other, keep_unit=True),
)

def __sub__(self, other):
return Slice(
self._src._with_data(self.data - self._unpack_other(other)),
calibration=self._calibration if np.isscalar(other) else None,
calibration=self._calibration if convert_to_scalar(other) is not None else None,
labels=self._generate_labels(self, "-", other, keep_unit=True),
)

Expand Down Expand Up @@ -160,7 +184,7 @@ def __rpow__(self, other):
def __radd__(self, other):
return Slice(
self._src._with_data(self.data + self._unpack_other(other)),
calibration=self._calibration if np.isscalar(other) else None,
calibration=self._calibration if convert_to_scalar(other) is not None else None,
labels=self._generate_labels(other, "+", self, keep_unit=True),
)

Expand Down
12 changes: 12 additions & 0 deletions lumicks/pylake/detail/utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import numbers
import contextlib

import numpy as np
Expand Down Expand Up @@ -213,3 +214,14 @@ def temp_seed(seed):
yield
finally:
np.random.seed(None)


def convert_to_scalar(value):
"""Converts to a numeric scalar if possible, otherwise returns None"""
try:
value = np.asarray(value).item()
except ValueError: # Can only convert array of size 1 to Python scalar
return None

if isinstance(value, numbers.Number):
return value
7 changes: 6 additions & 1 deletion lumicks/pylake/tests/test_channels/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from lumicks.pylake.channel import Slice, TimeTags, Continuous, TimeSeries
from lumicks.pylake.calibration import ForceCalibrationList
from lumicks.pylake.detail.value import ValueMixin

start = 1 + int(1e18)
calibration = ForceCalibrationList(
Expand Down Expand Up @@ -64,6 +65,10 @@ def test_operator(first_slice, second_slice, operation):
[
(slice_continuous_1, 2.0),
(slice_timeseries_1, 2.0),
(slice_continuous_1, np.array(2.0)),
(slice_timeseries_1, np.array(2.0)),
(slice_continuous_1, ValueMixin(2.0)),
(slice_timeseries_1, ValueMixin(2.0)),
],
)
def test_operations_scalar(slice1, scalar):
Expand Down Expand Up @@ -181,7 +186,7 @@ def test_negation(channel_slice):

def test_negation_timetags_not_implemented():
with pytest.raises(NotImplementedError):
negated_timetags = -timetags
_ = -timetags


def test_labels_slices():
Expand Down
49 changes: 49 additions & 0 deletions lumicks/pylake/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import matplotlib as mpl
from numpy.testing import assert_array_equal

from lumicks.pylake.channel import Slice, Continuous
from lumicks.pylake.detail.value import ValueMixin
from lumicks.pylake.detail.confocal import timestamp_mean
from lumicks.pylake.detail.utilities import *
from lumicks.pylake.detail.utilities import (
Expand Down Expand Up @@ -338,3 +340,50 @@ def example_method(self, argument=5):
assert test.example_method(6) == 6
assert calls == 3
assert len(test._cache) == 3


class NonConvertible:
def __init__(self, value):
self.value = value


@pytest.mark.parametrize(
"test_value, ref",
[
(
Slice(
Continuous(np.array([5.0]), int(1e6), 100),
labels={"y": "y", "title": "title", "x": "x"},
),
5.0,
),
(np.array(1.0), 1.0),
(np.array(1), 1),
(np.array(159291604090635630000), 159291604090635630000),
(159291604090635630000, 159291604090635630000),
([1], 1),
(np.array([[[1]]]), 1),
(np.array([[[1.0]]]), 1.0),
(ValueMixin(1.0), 1.0),
(ValueMixin(1), 1),
],
)
def test_convert_to_scalar_valid(test_value, ref):
assert (value := convert_to_scalar(test_value)) == ref
assert isinstance(value, type(ref))


@pytest.mark.parametrize(
"test_value",
(
Slice(Continuous(np.arange(100), int(1e6), 100), labels={"y": "y", "title": "t", "x": "x"}),
"str",
[1, 1], # Not a single scalar
ValueMixin(["1.0"]), # Not a numeric value
ValueMixin(["string"]),
ValueMixin([1, 1]), # Not a scalar
NonConvertible(1),
),
)
def test_convert_to_scalar_invalid(test_value):
assert convert_to_scalar(test_value) is None

0 comments on commit c86418d

Please sign in to comment.