diff --git a/changelog.md b/changelog.md index d499de5fc..7afdc2688 100644 --- a/changelog.md +++ b/changelog.md @@ -22,6 +22,10 @@ * Added improved printing of calibrations performed with `Pylake`. * Improved error message that includes the name of the model when trying to access a model that was not added in an [`FdFit`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.FdFit.html) using angular brackets. +#### Bug fixes + +* Ensure that operators such as (e.g. `+`, `-`, `/`) work on [`Slice`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.channel.Slice.html) with all types that are convertible to scalars. Previously these failed with zero dimensional numpy arrays and other convertible objects. + ## v1.5.3 | 2024-10-29 #### Bug fixes diff --git a/lumicks/pylake/channel.py b/lumicks/pylake/channel.py index 83b63d679..6fa17cd12 100644 --- a/lumicks/pylake/channel.py +++ b/lumicks/pylake/channel.py @@ -8,7 +8,7 @@ from .detail.plotting import _annotate from .detail.timeindex import to_seconds, to_timestamp -from .detail.utilities import downsample +from .detail.utilities import downsample, convert_to_scalar from .nb_widgets.range_selector import SliceRangeSelectorWidget @@ -72,8 +72,29 @@ def _apply_mask(self, mask): return self._with_data_source(self._src._apply_mask(mask)) def _unpack_other(self, other): - if np.isscalar(other): - return other + """Extract raw data from `other` + + Extracts the raw data from the other object for use in an arithmetic operation. If it is + convertible to a scalar, this object will be converted to a flat scalar value. If it is a + Slice, the function will verify that the timestamps of both slices are the same. If so, + the data will be extracted and returned. + + Parameters + ---------- + other : array_like | Slice + The object to extract data from. + + Raises + ------ + TypeError + If the object is not scalar-like or a slice. + NotImplementedError + If the object is a TimeTag object. + RuntimeError + If the timestamps of the two slices are not the same. + """ + if (scalar := convert_to_scalar(other)) is not None: + return scalar if not isinstance(other, Slice): raise TypeError("Trying to perform operation with incompatible types.") @@ -88,7 +109,10 @@ def _unpack_other(self, other): def _generate_labels(self, lhs, operator, rhs, keep_unit): def get_label(item, key): - return item.labels.get(key, "") if not np.isscalar(item) else str(item) + try: + return item.labels.get(key, "") + except AttributeError: + return str(item) # scalar value case labels = {"title": f"({get_label(lhs, 'title')} {operator} {get_label(rhs, 'title')})"} if keep_unit: @@ -110,14 +134,14 @@ def __neg__(self): def __add__(self, other): return Slice( self._src._with_data(self.data + self._unpack_other(other)), - calibration=self._calibration if np.isscalar(other) else None, + calibration=self._calibration if convert_to_scalar(other) is not None else None, labels=self._generate_labels(self, "+", other, keep_unit=True), ) def __sub__(self, other): return Slice( self._src._with_data(self.data - self._unpack_other(other)), - calibration=self._calibration if np.isscalar(other) else None, + calibration=self._calibration if convert_to_scalar(other) is not None else None, labels=self._generate_labels(self, "-", other, keep_unit=True), ) @@ -160,7 +184,7 @@ def __rpow__(self, other): def __radd__(self, other): return Slice( self._src._with_data(self.data + self._unpack_other(other)), - calibration=self._calibration if np.isscalar(other) else None, + calibration=self._calibration if convert_to_scalar(other) is not None else None, labels=self._generate_labels(other, "+", self, keep_unit=True), ) diff --git a/lumicks/pylake/detail/utilities.py b/lumicks/pylake/detail/utilities.py index 5816a7dcb..96a876761 100644 --- a/lumicks/pylake/detail/utilities.py +++ b/lumicks/pylake/detail/utilities.py @@ -1,4 +1,5 @@ import math +import numbers import contextlib import numpy as np @@ -213,3 +214,14 @@ def temp_seed(seed): yield finally: np.random.seed(None) + + +def convert_to_scalar(value): + """Converts to a numeric scalar if possible, otherwise returns None""" + try: + value = np.asarray(value).item() + except ValueError: # Can only convert array of size 1 to Python scalar + return None + + if isinstance(value, numbers.Number): + return value diff --git a/lumicks/pylake/tests/test_channels/test_arithmetic.py b/lumicks/pylake/tests/test_channels/test_arithmetic.py index 73576b23b..95d1c8b23 100644 --- a/lumicks/pylake/tests/test_channels/test_arithmetic.py +++ b/lumicks/pylake/tests/test_channels/test_arithmetic.py @@ -1,8 +1,11 @@ +import warnings + import numpy as np import pytest from lumicks.pylake.channel import Slice, TimeTags, Continuous, TimeSeries from lumicks.pylake.calibration import ForceCalibrationList +from lumicks.pylake.detail.value import ValueMixin start = 1 + int(1e18) calibration = ForceCalibrationList( @@ -64,6 +67,16 @@ def test_operator(first_slice, second_slice, operation): [ (slice_continuous_1, 2.0), (slice_timeseries_1, 2.0), + (slice_continuous_1, 0), + (slice_timeseries_1, 0), + (slice_continuous_1, 0.0), + (slice_timeseries_1, 0.0), + (slice_continuous_1, np.array(2.0)), + (slice_timeseries_1, np.array(2.0)), + (slice_continuous_1, np.array(0)), + (slice_timeseries_1, np.array(0)), + (slice_continuous_1, ValueMixin(2.0)), + (slice_timeseries_1, ValueMixin(2.0)), ], ) def test_operations_scalar(slice1, scalar): @@ -85,7 +98,9 @@ def test_operator(current_slice, scalar_val, operation, preserve_calibration): assert not getattr(current_slice, operation)(scalar_val).calibration for operator, preserve_calibration, *_ in operators: - test_operator(slice1, scalar, operator, preserve_calibration) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="divide by zero encountered in divide") + test_operator(slice1, scalar, operator, preserve_calibration) slice_continuous_different_timestamps = Slice( @@ -181,7 +196,7 @@ def test_negation(channel_slice): def test_negation_timetags_not_implemented(): with pytest.raises(NotImplementedError): - negated_timetags = -timetags + _ = -timetags def test_labels_slices(): diff --git a/lumicks/pylake/tests/test_utilities.py b/lumicks/pylake/tests/test_utilities.py index 5b90eae81..edc490604 100644 --- a/lumicks/pylake/tests/test_utilities.py +++ b/lumicks/pylake/tests/test_utilities.py @@ -5,6 +5,8 @@ import matplotlib as mpl from numpy.testing import assert_array_equal +from lumicks.pylake.channel import Slice, Continuous +from lumicks.pylake.detail.value import ValueMixin from lumicks.pylake.detail.confocal import timestamp_mean from lumicks.pylake.detail.utilities import * from lumicks.pylake.detail.utilities import ( @@ -338,3 +340,50 @@ def example_method(self, argument=5): assert test.example_method(6) == 6 assert calls == 3 assert len(test._cache) == 3 + + +class NonConvertible: + def __init__(self, value): + self.value = value + + +@pytest.mark.parametrize( + "test_value, ref", + [ + ( + Slice( + Continuous(np.array([5.0]), int(1e6), 100), + labels={"y": "y", "title": "title", "x": "x"}, + ), + 5.0, + ), + (np.array(1.0), 1.0), + (np.array(1), 1), + (np.array(159291604090635630000), 159291604090635630000), + (159291604090635630000, 159291604090635630000), + ([1], 1), + (np.array([[[1]]]), 1), + (np.array([[[1.0]]]), 1.0), + (ValueMixin(1.0), 1.0), + (ValueMixin(1), 1), + ], +) +def test_convert_to_scalar_valid(test_value, ref): + assert (value := convert_to_scalar(test_value)) == ref + assert isinstance(value, type(ref)) + + +@pytest.mark.parametrize( + "test_value", + ( + Slice(Continuous(np.arange(100), int(1e6), 100), labels={"y": "y", "title": "t", "x": "x"}), + "str", + [1, 1], # Not a single scalar + ValueMixin(["1.0"]), # Not a numeric value + ValueMixin(["string"]), + ValueMixin([1, 1]), # Not a scalar + NonConvertible(1), + ), +) +def test_convert_to_scalar_invalid(test_value): + assert convert_to_scalar(test_value) is None