From 285493c81d20bda483be9d5b46c9f990b4ccbd29 Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Thu, 5 Dec 2024 10:05:58 +0100 Subject: [PATCH] slice: ensure operators work with all scalars --- changelog.md | 4 ++ lumicks/pylake/channel.py | 17 ++++--- lumicks/pylake/detail/utilities.py | 12 +++++ .../tests/test_channels/test_arithmetic.py | 2 + lumicks/pylake/tests/test_utilities.py | 47 +++++++++++++++++++ 5 files changed, 75 insertions(+), 7 deletions(-) diff --git a/changelog.md b/changelog.md index d499de5fc..7afdc2688 100644 --- a/changelog.md +++ b/changelog.md @@ -22,6 +22,10 @@ * Added improved printing of calibrations performed with `Pylake`. * Improved error message that includes the name of the model when trying to access a model that was not added in an [`FdFit`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.FdFit.html) using angular brackets. +#### Bug fixes + +* Ensure that operators such as (e.g. `+`, `-`, `/`) work on [`Slice`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.channel.Slice.html) with all types that are convertible to scalars. Previously these failed with zero dimensional numpy arrays and other convertible objects. + ## v1.5.3 | 2024-10-29 #### Bug fixes diff --git a/lumicks/pylake/channel.py b/lumicks/pylake/channel.py index 83b63d679..433e9428f 100644 --- a/lumicks/pylake/channel.py +++ b/lumicks/pylake/channel.py @@ -8,7 +8,7 @@ from .detail.plotting import _annotate from .detail.timeindex import to_seconds, to_timestamp -from .detail.utilities import downsample +from .detail.utilities import downsample, convert_to_scalar from .nb_widgets.range_selector import SliceRangeSelectorWidget @@ -72,8 +72,8 @@ def _apply_mask(self, mask): return self._with_data_source(self._src._apply_mask(mask)) def _unpack_other(self, other): - if np.isscalar(other): - return other + if scalar := convert_to_scalar(other): + return scalar if not isinstance(other, Slice): raise TypeError("Trying to perform operation with incompatible types.") @@ -88,7 +88,10 @@ def _unpack_other(self, other): def _generate_labels(self, lhs, operator, rhs, keep_unit): def get_label(item, key): - return item.labels.get(key, "") if not np.isscalar(item) else str(item) + try: + return item.labels.get(key, "") + except AttributeError: + return str(item) # scalar value case labels = {"title": f"({get_label(lhs, 'title')} {operator} {get_label(rhs, 'title')})"} if keep_unit: @@ -110,14 +113,14 @@ def __neg__(self): def __add__(self, other): return Slice( self._src._with_data(self.data + self._unpack_other(other)), - calibration=self._calibration if np.isscalar(other) else None, + calibration=self._calibration if convert_to_scalar(other) is not None else None, labels=self._generate_labels(self, "+", other, keep_unit=True), ) def __sub__(self, other): return Slice( self._src._with_data(self.data - self._unpack_other(other)), - calibration=self._calibration if np.isscalar(other) else None, + calibration=self._calibration if convert_to_scalar(other) is not None else None, labels=self._generate_labels(self, "-", other, keep_unit=True), ) @@ -160,7 +163,7 @@ def __rpow__(self, other): def __radd__(self, other): return Slice( self._src._with_data(self.data + self._unpack_other(other)), - calibration=self._calibration if np.isscalar(other) else None, + calibration=self._calibration if convert_to_scalar(other) is not None else None, labels=self._generate_labels(other, "+", self, keep_unit=True), ) diff --git a/lumicks/pylake/detail/utilities.py b/lumicks/pylake/detail/utilities.py index 5816a7dcb..96a876761 100644 --- a/lumicks/pylake/detail/utilities.py +++ b/lumicks/pylake/detail/utilities.py @@ -1,4 +1,5 @@ import math +import numbers import contextlib import numpy as np @@ -213,3 +214,14 @@ def temp_seed(seed): yield finally: np.random.seed(None) + + +def convert_to_scalar(value): + """Converts to a numeric scalar if possible, otherwise returns None""" + try: + value = np.asarray(value).item() + except ValueError: # Can only convert array of size 1 to Python scalar + return None + + if isinstance(value, numbers.Number): + return value diff --git a/lumicks/pylake/tests/test_channels/test_arithmetic.py b/lumicks/pylake/tests/test_channels/test_arithmetic.py index 73576b23b..38ae02314 100644 --- a/lumicks/pylake/tests/test_channels/test_arithmetic.py +++ b/lumicks/pylake/tests/test_channels/test_arithmetic.py @@ -64,6 +64,8 @@ def test_operator(first_slice, second_slice, operation): [ (slice_continuous_1, 2.0), (slice_timeseries_1, 2.0), + (slice_continuous_1, np.array(2.0)), + (slice_timeseries_1, np.array(2.0)), ], ) def test_operations_scalar(slice1, scalar): diff --git a/lumicks/pylake/tests/test_utilities.py b/lumicks/pylake/tests/test_utilities.py index 5b90eae81..8f3229361 100644 --- a/lumicks/pylake/tests/test_utilities.py +++ b/lumicks/pylake/tests/test_utilities.py @@ -5,6 +5,8 @@ import matplotlib as mpl from numpy.testing import assert_array_equal +from lumicks.pylake.channel import Slice, Continuous +from lumicks.pylake.detail.value import ValueMixin from lumicks.pylake.detail.confocal import timestamp_mean from lumicks.pylake.detail.utilities import * from lumicks.pylake.detail.utilities import ( @@ -338,3 +340,48 @@ def example_method(self, argument=5): assert test.example_method(6) == 6 assert calls == 3 assert len(test._cache) == 3 + + +class Parameter(ValueMixin): + def __init__(self, value, description): + super().__init__(value) + self.description = description + + +@pytest.mark.parametrize( + "test_value, ref", + [ + ( + Slice( + Continuous(np.array([5.0]), int(1e6), 100), + labels={"y": "y", "title": "title", "x": "x"}, + ), + 5.0, + ), + (np.array(1.0), 1.0), + (np.array(1), 1), + (np.array([[[1]]]), 1), + (np.array([[[1.0]]]), 1.0), + ([1], 1), + (Parameter(1.0, "the parameter"), 1.0), + (Parameter(1, "the parameter"), 1), + ], +) +def test_convert_to_scalar_valid(test_value, ref): + assert (value := convert_to_scalar(test_value)) == ref + assert isinstance(value, type(ref)) + + +@pytest.mark.parametrize( + "test_value", + ( + Slice(Continuous(np.arange(100), int(1e6), 100), labels={"y": "y", "title": "t", "x": "x"}), + "str", + [1, 1], + Parameter(["1.0"], "bad string"), + Parameter(["string"], "bad string"), + Parameter([1, 1], "not a scalar"), + ), +) +def test_convert_to_scalar_invalid(test_value): + assert convert_to_scalar(test_value) is None