Skip to content

Commit

Permalink
slice: ensure operators work with all scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
JoepVanlier committed Dec 5, 2024
1 parent 009fa10 commit 285493c
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 7 deletions.
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
* Added improved printing of calibrations performed with `Pylake`.
* Improved error message that includes the name of the model when trying to access a model that was not added in an [`FdFit`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.FdFit.html) using angular brackets.

#### Bug fixes

* Ensure that operators such as (e.g. `+`, `-`, `/`) work on [`Slice`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.channel.Slice.html) with all types that are convertible to scalars. Previously these failed with zero dimensional numpy arrays and other convertible objects.

## v1.5.3 | 2024-10-29

#### Bug fixes
Expand Down
17 changes: 10 additions & 7 deletions lumicks/pylake/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .detail.plotting import _annotate
from .detail.timeindex import to_seconds, to_timestamp
from .detail.utilities import downsample
from .detail.utilities import downsample, convert_to_scalar
from .nb_widgets.range_selector import SliceRangeSelectorWidget


Expand Down Expand Up @@ -72,8 +72,8 @@ def _apply_mask(self, mask):
return self._with_data_source(self._src._apply_mask(mask))

def _unpack_other(self, other):
if np.isscalar(other):
return other
if scalar := convert_to_scalar(other):
return scalar

if not isinstance(other, Slice):
raise TypeError("Trying to perform operation with incompatible types.")
Expand All @@ -88,7 +88,10 @@ def _unpack_other(self, other):

def _generate_labels(self, lhs, operator, rhs, keep_unit):
def get_label(item, key):
return item.labels.get(key, "") if not np.isscalar(item) else str(item)
try:
return item.labels.get(key, "")
except AttributeError:
return str(item) # scalar value case

labels = {"title": f"({get_label(lhs, 'title')} {operator} {get_label(rhs, 'title')})"}
if keep_unit:
Expand All @@ -110,14 +113,14 @@ def __neg__(self):
def __add__(self, other):
return Slice(
self._src._with_data(self.data + self._unpack_other(other)),
calibration=self._calibration if np.isscalar(other) else None,
calibration=self._calibration if convert_to_scalar(other) is not None else None,
labels=self._generate_labels(self, "+", other, keep_unit=True),
)

def __sub__(self, other):
return Slice(
self._src._with_data(self.data - self._unpack_other(other)),
calibration=self._calibration if np.isscalar(other) else None,
calibration=self._calibration if convert_to_scalar(other) is not None else None,
labels=self._generate_labels(self, "-", other, keep_unit=True),
)

Expand Down Expand Up @@ -160,7 +163,7 @@ def __rpow__(self, other):
def __radd__(self, other):
return Slice(
self._src._with_data(self.data + self._unpack_other(other)),
calibration=self._calibration if np.isscalar(other) else None,
calibration=self._calibration if convert_to_scalar(other) is not None else None,
labels=self._generate_labels(other, "+", self, keep_unit=True),
)

Expand Down
12 changes: 12 additions & 0 deletions lumicks/pylake/detail/utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import numbers
import contextlib

import numpy as np
Expand Down Expand Up @@ -213,3 +214,14 @@ def temp_seed(seed):
yield
finally:
np.random.seed(None)


def convert_to_scalar(value):
"""Converts to a numeric scalar if possible, otherwise returns None"""
try:
value = np.asarray(value).item()
except ValueError: # Can only convert array of size 1 to Python scalar
return None

if isinstance(value, numbers.Number):
return value
2 changes: 2 additions & 0 deletions lumicks/pylake/tests/test_channels/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def test_operator(first_slice, second_slice, operation):
[
(slice_continuous_1, 2.0),
(slice_timeseries_1, 2.0),
(slice_continuous_1, np.array(2.0)),
(slice_timeseries_1, np.array(2.0)),
],
)
def test_operations_scalar(slice1, scalar):
Expand Down
47 changes: 47 additions & 0 deletions lumicks/pylake/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import matplotlib as mpl
from numpy.testing import assert_array_equal

from lumicks.pylake.channel import Slice, Continuous
from lumicks.pylake.detail.value import ValueMixin
from lumicks.pylake.detail.confocal import timestamp_mean
from lumicks.pylake.detail.utilities import *
from lumicks.pylake.detail.utilities import (
Expand Down Expand Up @@ -338,3 +340,48 @@ def example_method(self, argument=5):
assert test.example_method(6) == 6
assert calls == 3
assert len(test._cache) == 3


class Parameter(ValueMixin):
def __init__(self, value, description):
super().__init__(value)
self.description = description


@pytest.mark.parametrize(
"test_value, ref",
[
(
Slice(
Continuous(np.array([5.0]), int(1e6), 100),
labels={"y": "y", "title": "title", "x": "x"},
),
5.0,
),
(np.array(1.0), 1.0),
(np.array(1), 1),
(np.array([[[1]]]), 1),
(np.array([[[1.0]]]), 1.0),
([1], 1),
(Parameter(1.0, "the parameter"), 1.0),
(Parameter(1, "the parameter"), 1),
],
)
def test_convert_to_scalar_valid(test_value, ref):
assert (value := convert_to_scalar(test_value)) == ref
assert isinstance(value, type(ref))


@pytest.mark.parametrize(
"test_value",
(
Slice(Continuous(np.arange(100), int(1e6), 100), labels={"y": "y", "title": "t", "x": "x"}),
"str",
[1, 1],
Parameter(["1.0"], "bad string"),
Parameter(["string"], "bad string"),
Parameter([1, 1], "not a scalar"),
),
)
def test_convert_to_scalar_invalid(test_value):
assert convert_to_scalar(test_value) is None

0 comments on commit 285493c

Please sign in to comment.