From 0d17807f24d349d9429fed2c69928242c51265a4 Mon Sep 17 00:00:00 2001 From: Gabriel Selzer Date: Wed, 20 Nov 2024 23:17:28 -0600 Subject: [PATCH 1/4] Port stats model --- src/ndv/models/__init__.py | 9 +- src/ndv/models/_stats_model.py | 138 ++++++++++++++++++++++++++ src/ndv/views/protocols.py | 172 ++++++++++++++++++++++++++++++++- tests/test_stats_model.py | 55 +++++++++++ 4 files changed, 371 insertions(+), 3 deletions(-) create mode 100644 src/ndv/models/_stats_model.py create mode 100644 tests/test_stats_model.py diff --git a/src/ndv/models/__init__.py b/src/ndv/models/__init__.py index 801e689..9afa018 100644 --- a/src/ndv/models/__init__.py +++ b/src/ndv/models/__init__.py @@ -3,6 +3,13 @@ from ._array_display_model import ArrayDisplayModel from ._data_display_model import DataDisplayModel from ._lut_model import LUTModel +from ._stats_model import StatsModel from .data_wrappers._data_wrapper import DataWrapper -__all__ = ["ArrayDisplayModel", "LUTModel", "DataDisplayModel", "DataWrapper"] +__all__ = [ + "ArrayDisplayModel", + "LUTModel", + "DataDisplayModel", + "DataWrapper", + "StatsModel", +] diff --git a/src/ndv/models/_stats_model.py b/src/ndv/models/_stats_model.py new file mode 100644 index 0000000..7de259f --- /dev/null +++ b/src/ndv/models/_stats_model.py @@ -0,0 +1,138 @@ +"""Model protocols for data display.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Annotated, cast + +import numpy as np +from pydantic import ( + GetCoreSchemaHandler, + GetJsonSchemaHandler, + computed_field, + model_validator, +) +from pydantic_core import core_schema + +from ndv.models._base_model import NDVModel + +if TYPE_CHECKING: + from typing import Any + + from pydantic.json_schema import JsonSchemaValue + + +# copied from https://github.com/tlambert03/microsim +class _NumpyNdarrayPydanticAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + def validate_from_any(value: Any) -> np.ndarray: + try: + return np.asarray(value) + except Exception as e: + raise ValueError(f"Cannot cast {value} to numpy.ndarray: {e}") from e + + from_any_schema = core_schema.chain_schema( + [ + core_schema.any_schema(), + core_schema.no_info_plain_validator_function(validate_from_any), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_any_schema, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(np.ndarray), + from_any_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.tolist() + ), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + # Use the same schema that would be used for arrays + return handler(core_schema.list_schema(core_schema.any_schema())) + + +NumpyNdarray = Annotated[np.ndarray, _NumpyNdarrayPydanticAnnotation] + + +class StatsModel(NDVModel): + """Representation of the statistics of a dataset. + + A model that computes and caches statistical properties of a dataset, + including standard deviation, average, and histogram. + + Those interested in statistics should listen to the events.data and events.bins + signals emitted by this object. + + TODO can we only have the data signal? + + Parameters + ---------- + data : np.ndarray | None + The dataset. + bins : int + Number of bins to use for histogram computation. Defaults to 256. + average : float + The average (mean) value of data. + standard_deviation : float + The standard deviation of data. + histogram : tuple[Sequence[int], Sequence[float]] + A 2-tuple of sequences. + + The first sequence contains (n) integers, where index i is the number of data + points in the ith bin. + + The second sequence contains (n+1) floats. The ith bin spans the domain + between the values at index i (inclusive) and index i+1 (exclusive). + """ + + data: NumpyNdarray | None = None + bins: int = 256 + + @model_validator(mode="before") + def validate_data(cls, input: dict[str, Any], *args: Any) -> dict[str, Any]: + """Delete computed fields when data changes.""" + # Recompute computed stats when bins/data changes + if "data" in input: + for field in ["average", "standard_deviation", "histogram"]: + if field in input: + del input[field] + return input + + @computed_field # type: ignore[prop-decorator] + @cached_property + def standard_deviation(self) -> float: + """Computes the standard deviation of the dataset.""" + if self.data is None: + return float("nan") + return float(np.std(self.data)) + + @computed_field # type: ignore[prop-decorator] + @cached_property + def average(self) -> float: + """Computes the average of the dataset.""" + if self.data is None: + return float("nan") + return float(np.mean(self.data)) + + @computed_field # type: ignore[prop-decorator] + @cached_property + def histogram(self) -> tuple[Sequence[int], Sequence[float]]: + """Computes the histogram of the dataset.""" + if self.data is None: + return ([], []) + return cast( + tuple[Sequence[int], Sequence[float]], + np.histogram(self.data, bins=self.bins), + ) diff --git a/src/ndv/views/protocols.py b/src/ndv/views/protocols.py index 8909acd..3e3cc89 100644 --- a/src/ndv/views/protocols.py +++ b/src/ndv/views/protocols.py @@ -3,12 +3,14 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Literal, Protocol +import cmap +from psygnal import Signal + if TYPE_CHECKING: from collections.abc import Container, Hashable, Mapping, Sequence - import cmap import numpy as np - from psygnal import Signal, SignalInstance + from psygnal import SignalInstance from qtpy.QtCore import Qt from qtpy.QtWidgets import QWidget @@ -28,6 +30,172 @@ def setClims(self, clims: tuple[float, float]) -> None: ... def setLutVisible(self, visible: bool) -> None: ... +class LutView(Protocol): + """An (interactive) view of a LookUp Table (LUT).""" + + cmapChanged: Signal = Signal(cmap.Colormap) + gammaChanged: Signal = Signal(float) + climsChanged: Signal = Signal(tuple[float, float]) + autoscaleChanged: Signal = Signal(object) + + def set_visibility(self, visible: bool) -> None: + """Defines whether this view is visible. + + Properties + ---------- + visible : bool + True iff the view should be visible. + """ + ... + + def set_cmap(self, lut: cmap.Colormap) -> None: + """Defines the colormap backing the view. + + Properties + ---------- + lut : cmap.Colormap + The object mapping scalar values to RGB(A) colors. + """ + ... + + def set_gamma(self, gamma: float) -> None: + """Defines the exponent used for gamma correction. + + Properties + ---------- + gamma : float + The exponent used for gamma correction + """ + ... + + def set_clims(self, clims: tuple[float, float]) -> None: + """Defines the input clims. + + The contrast limits (clims) are the input values mapped to the minimum and + maximum (respectively) of the LUT. + + Properties + ---------- + clims : tuple[float, float] + The clims + """ + ... + + def set_autoscale(self, autoscale: bool | tuple[float, float]) -> None: + """Defines whether autoscale has been enabled. + + Autoscale defines whether the contrast limits (clims) are adjusted when the + data changes. + + Properties + ---------- + autoscale : bool | tuple[float, float] + If a boolean, true iff clims automatically changed on dataset alteration. + If a tuple, indicated that clims automatically changed. Values denote + the fraction of the dataset located below and above the lower and + upper clims, respectively. + """ + ... + + def view(self) -> Any: + """The native object that can be displayed.""" + ... + + +class StatsView(Protocol): + """A view of the statistics of a dataset.""" + + def set_histogram( + self, values: Sequence[float], bin_edges: Sequence[float] + ) -> None: + """Defines the distribution of the dataset. + + Properties + ---------- + values : Sequence[int] + A length (n) sequence of values representing clustered counts of data + points. values[i] defines the number of data points falling between + bin_edges[i] and bin_edges[i+1]. + bin_edges : Sequence[float] + A length (n+1) sequence of values defining the intervals partitioning + all data points. Must be non-decreasing. + """ + ... + + def set_std_dev(self, std_dev: float) -> None: + """Defines the standard deviation of the dataset. + + Properties + ---------- + std_dev : float + The standard deviation. + """ + ... + + def set_average(self, avg: float) -> None: + """Defines the average value of the dataset. + + Properties + ---------- + std_dev : float + The average value of the dataset. + """ + ... + + def view(self) -> Any: + """The native object that can be displayed.""" + ... + + +class HistogramView(StatsView, LutView): + """A histogram-based view for LookUp Table (LUT) adjustment.""" + + def set_domain(self, bounds: tuple[float, float] | None) -> None: + """Sets the domain of the view. + + Properties + ---------- + bounds : tuple[float, float] | None + If a tuple, sets the displayed extremes of the x axis to the passed + values. If None, sets them to the extent of the data instead. + """ + ... + + def set_range(self, bounds: tuple[float, float] | None) -> None: + """Sets the range of the view. + + Properties + ---------- + bounds : tuple[float, float] | None + If a tuple, sets the displayed extremes of the y axis to the passed + values. If None, sets them to the extent of the data instead. + """ + ... + + def set_vertical(self, vertical: bool) -> None: + """Sets the axis of the domain. + + Properties + ---------- + vertical : bool + If true, views the domain along the y axis and the range along the x + axis. If false, views the domain along the x axis and the range along + the y axis. + """ + ... + + def set_range_log(self, enabled: bool) -> None: + """Sets the axis scale of the range. + + Properties + ---------- + enabled : bool + If true, the range will be displayed with a logarithmic (base 10) + scale. If false, the range will be displayed with a linear scale. + """ + ... + + class CanvasElement(Protocol): """Protocol defining an interactive element on the Canvas.""" diff --git a/tests/test_stats_model.py b/tests/test_stats_model.py new file mode 100644 index 0000000..c068e7b --- /dev/null +++ b/tests/test_stats_model.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from math import isnan +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from ndv.models import StatsModel + +if TYPE_CHECKING: + from pytestqt.qtbot import QtBot + + +EPSILON = 1e-6 + + +@pytest.fixture +def data() -> np.ndarray: + # Seeded random Generator + gen = np.random.default_rng(0xDEADBEEF) + # Average - 1.000104 + # Std. Dev. - 10.003385 + data = gen.normal(1, 10, (1000, 1000)) + return data + + +def test_empty_stats_model() -> None: + model = StatsModel() + assert None is model.data + assert isnan(model.average) + assert isnan(model.standard_deviation) + assert ([], []) == model.histogram + assert model.bins == 256 + + +def test_stats_model(qtbot: QtBot, data: np.ndarray) -> None: + model = StatsModel() + with qtbot.wait_signal(model.events.data): + model.data = data + assert np.all(model.data == data) + # Basic regression tests + assert abs(model.average - 1.000104) < 1e-6 + assert abs(model.standard_deviation - 10.003385) < 1e-6 + assert 256 == model.bins + values, edges = model.histogram + assert len(values) == 256 + assert np.all(values >= 0) + assert np.all(values <= data.size) + assert len(edges) == 257 + assert edges[0] == np.min(data) + assert edges[256] == np.max(data) + # Assert bins changed emits a signal + with qtbot.wait_signal(model.events.bins): + model.bins = 128 From 1978c59f1f618288635c366918b7acb9a04074ef Mon Sep 17 00:00:00 2001 From: Gabriel Selzer Date: Thu, 21 Nov 2024 17:58:20 -0600 Subject: [PATCH 2/4] Port histogram model/view --- mvc_histogram.py | 114 +++++ src/ndv/views/_vispy/_vispy.py | 746 ++++++++++++++++++++++++++++- src/ndv/views/protocols.py | 104 ++-- tests/test_vispy_histogram_view.py | 337 +++++++++++++ 4 files changed, 1242 insertions(+), 59 deletions(-) create mode 100644 mvc_histogram.py create mode 100644 tests/test_vispy_histogram_view.py diff --git a/mvc_histogram.py b/mvc_histogram.py new file mode 100644 index 0000000..37918b7 --- /dev/null +++ b/mvc_histogram.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from qtpy.QtCore import QTimer +from qtpy.QtWidgets import ( + QApplication, + QPushButton, + QVBoxLayout, + QWidget, +) + +from ndv.models import LUTModel, StatsModel +from ndv.views._vispy._vispy import VispyHistogramView + +if TYPE_CHECKING: + from typing import Any + + from ndv.views.protocols import PHistogramView + + +class Controller: + """A (Qt) wrapper around another HistogramView with some additional controls.""" + + def __init__( + self, + stats_model: StatsModel | None = None, + lut_model: LUTModel | None = None, + view: PHistogramView | None = None, + ) -> None: + self._wdg = QWidget() + if stats_model is None: + stats_model = StatsModel() + if lut_model is None: + lut_model = LUTModel() + if view is None: + view = VispyHistogramView() + self._stats = stats_model + self._lut = lut_model + self._view = view + + # A HistogramView is both a StatsView and a LUTView + # StatModel <-> StatsView + self._stats.events.data.connect(self._set_data) + self._stats.events.bins.connect(self._set_data) + # LutModel <-> LutView + self._lut.events.clims.connect(self._set_model_clims) + self._view.climsChanged.connect(self._set_view_clims) + self._lut.events.gamma.connect(self._set_model_gamma) + self._view.gammaChanged.connect(self._set_view_gamma) + + # Vertical box + self._vert = QPushButton("Vertical") + self._vert.setCheckable(True) + self._vert.toggled.connect(self._view.setVertical) + + # Log box + self._log = QPushButton("Logarithmic") + self._log.setCheckable(True) + self._log.toggled.connect(self._view.setRangeLog) + + # Data updates + self._data_btn = QPushButton("Change Data") + self._data_btn.setCheckable(True) + self._data_btn.toggled.connect( + lambda toggle: self.timer.blockSignals(not toggle) + ) + + def _update_data() -> None: + """Replaces the displayed data.""" + self._stats.data = np.random.normal(10, 10, 10000) + + self.timer = QTimer() + self.timer.setInterval(10) + self.timer.blockSignals(True) + self.timer.timeout.connect(_update_data) + self.timer.start() + + # Layout + self._layout = QVBoxLayout(self._wdg) + self._layout.addWidget(self._view.view()) + self._layout.addWidget(self._vert) + self._layout.addWidget(self._log) + self._layout.addWidget(self._data_btn) + + def _set_data(self) -> None: + values, bin_edges = self._stats.histogram + self._view.setHistogram(values, bin_edges) + + def _set_model_clims(self) -> None: + clims = self._lut.clims + self._view.setClims(clims) + + def _set_view_clims(self, clims: tuple[float, float]) -> None: + self._lut.clims = clims + + def _set_model_gamma(self) -> None: + gamma = self._lut.gamma + self._view.setGamma(gamma) + + def _set_view_gamma(self, gamma: float) -> None: + self._lut.gamma = gamma + + def view(self) -> Any: + """Returns an object that can be displayed by the active backend.""" + return self._wdg + + +app = QApplication.instance() or QApplication([]) + +widget = Controller() +widget.view().show() +app.exec() diff --git a/src/ndv/views/_vispy/_vispy.py b/src/ndv/views/_vispy/_vispy.py index b122fc4..3d49a19 100755 --- a/src/ndv/views/_vispy/_vispy.py +++ b/src/ndv/views/_vispy/_vispy.py @@ -2,7 +2,8 @@ import warnings from contextlib import suppress -from typing import TYPE_CHECKING, Any, Literal, cast +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Literal, TypedDict, Unpack, cast from weakref import WeakKeyDictionary import cmap @@ -10,17 +11,19 @@ import vispy import vispy.scene import vispy.visuals +from psygnal import Signal from qtpy.QtCore import Qt from vispy import scene from vispy.color import Color from vispy.util.quaternion import Quaternion -from ndv.views.protocols import PCanvas +from ndv.views.protocols import PCanvas, PHistogramView if TYPE_CHECKING: from collections.abc import Sequence from typing import Callable + import numpy.typing as npt from qtpy.QtWidgets import QWidget from ndv.views.protocols import CanvasElement @@ -596,3 +599,742 @@ def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: if (handle := self._elements.get(vis)) is not None: elements.append(handle) return elements + + +class Grabbable(Enum): + NONE = auto() + LEFT_CLIM = auto() + RIGHT_CLIM = auto() + GAMMA = auto() + + +if TYPE_CHECKING: + # just here cause vispy has poor type hints + from collections.abc import Sequence + + from vispy.app.canvas import MouseEvent + + class Grid(scene.Grid): + def add_view( + self, + row: int | None = None, + col: int | None = None, + row_span: int = 1, + col_span: int = 1, + **kwargs: Any, + ) -> scene.ViewBox: + super().add_view(...) + + def add_widget( + self, + widget: None | scene.Widget = None, + row: int | None = None, + col: int | None = None, + row_span: int = 1, + col_span: int = 1, + **kwargs: Any, + ) -> scene.Widget: + super().add_widget(...) + + class WidgetKwargs(TypedDict, total=False): + pos: tuple[float, float] + size: tuple[float, float] + border_color: str + border_width: float + bgcolor: str + padding: float + margin: float + + class TextVisualKwargs(TypedDict, total=False): + text: str + color: str + bold: bool + italic: bool + face: str + font_size: float + pos: tuple[float, float] | tuple[float, float, float] + rotation: float + method: Literal["cpu", "gpu"] + depth_test: bool + + class AxisWidgetKwargs(TypedDict, total=False): + orientation: Literal["left", "bottom"] + tick_direction: tuple[int, int] + axis_color: str + tick_color: str + text_color: str + minor_tick_length: float + major_tick_length: float + tick_width: float + tick_label_margin: float + tick_font_size: float + axis_width: float + axis_label: str + axis_label_margin: float + axis_font_size: float + font_size: float # overrides tick_font_size and axis_font_size + + +__all__ = ["PlotWidget"] + + +DEFAULT_AXIS_KWARGS: AxisWidgetKwargs = { + "text_color": "w", + "axis_color": "w", + "tick_color": "w", + "tick_width": 1, + "tick_font_size": 8, + "tick_label_margin": 12, + "axis_label_margin": 50, + "minor_tick_length": 2, + "major_tick_length": 5, + "axis_width": 1, + "axis_font_size": 10, +} + + +class Component(str, Enum): + PAD_LEFT = "pad_left" + PAD_RIGHT = "pad_right" + PAD_BOTTOM = "pad_bottom" + TITLE = "title" + CBAR_TOP = "cbar_top" + CBAR_LEFT = "cbar_left" + CBAR_RIGHT = "cbar_right" + CBAR_BOTTOM = "cbar_bottom" + YAXIS = "yaxis" + XAXIS = "xaxis" + XLABEL = "xlabel" + YLABEL = "ylabel" + + def __str__(self) -> str: + return self.value + + +class PlotWidget(scene.Widget): + """Widget to facilitate plotting. + + Parameters + ---------- + fg_color : str + The default color for the plot. + xlabel : str + The x-axis label. + ylabel : str + The y-axis label. + title : str + The title of the plot. + lock_axis : {'x', 'y', None} + Prevent panning and zooming along a particular axis. + **widget_kwargs : dict + Keyword arguments to pass to the parent class. + """ + + def __init__( + self, + fg_color: str = "k", + xlabel: str = "", + ylabel: str = "", + title: str = "", + lock_axis: Literal["x", "y", None] = None, + **widget_kwargs: Unpack[WidgetKwargs], + ) -> None: + self._fg_color = fg_color + self._visuals: list[scene.VisualNode] = [] + super().__init__(**widget_kwargs) + self.unfreeze() + self.grid = cast("Grid", self.add_grid(spacing=0, margin=10)) + + title_kwargs: TextVisualKwargs = {"font_size": 14, "color": "w"} + label_kwargs: TextVisualKwargs = {"font_size": 10, "color": "w"} + self._title = scene.Label(str(title), **title_kwargs) + self._xlabel = scene.Label(str(xlabel), **label_kwargs) + self._ylabel = scene.Label(str(ylabel), rotation=-90, **label_kwargs) + + axis_kwargs: AxisWidgetKwargs = DEFAULT_AXIS_KWARGS + self.yaxis = scene.AxisWidget(orientation="left", **axis_kwargs) + self.xaxis = scene.AxisWidget(orientation="bottom", **axis_kwargs) + + # 2D Plot layout: + # + # c0 c1 c2 c3 c4 c5 c6 + # +----------+-------+-------+-------+---------+---------+-----------+ + # r0 | | | title | | | + # | +-----------------------+---------+---------+ | + # r1 | | | cbar | | | + # |----------+-------+-------+-------+---------+---------+ ----------| + # r2 | pad_left | cbar | ylabel| yaxis | view | cbar | pad_right | + # |----------+-------+-------+-------+---------+---------+ ----------| + # r3 | | | xaxis | | | + # | +-----------------------+---------+---------+ | + # r4 | | | xlabel | | | + # | +-----------------------+---------+---------+ | + # r5 | | | cbar | | | + # |---------+------------------------+---------+---------+-----------| + # r6 | | pad_bottom | | + # +---------+------------------------+---------+---------+-----------+ + + self._grid_wdgs: dict[Component, scene.Widget] = {} + for name, row, col, widget in [ + (Component.PAD_LEFT, 2, 0, None), + (Component.PAD_RIGHT, 2, 6, None), + (Component.PAD_BOTTOM, 6, 4, None), + (Component.TITLE, 0, 4, self._title), + (Component.CBAR_TOP, 1, 4, None), + (Component.CBAR_LEFT, 2, 1, None), + (Component.CBAR_RIGHT, 2, 5, None), + (Component.CBAR_BOTTOM, 5, 4, None), + (Component.YAXIS, 2, 3, self.yaxis), + (Component.XAXIS, 3, 4, self.xaxis), + (Component.XLABEL, 4, 4, self._xlabel), + (Component.YLABEL, 2, 2, self._ylabel), + ]: + self._grid_wdgs[name] = wdg = self.grid.add_widget(widget, row=row, col=col) + # If we don't set max size, they will expand to fill the entire grid + # occluding pretty much everything else. + if str(name).startswith(("cbar", "pad")): + if name in { + Component.PAD_LEFT, + Component.PAD_RIGHT, + Component.CBAR_LEFT, + Component.CBAR_RIGHT, + }: + wdg.width_max = 2 + else: + wdg.height_max = 2 + + # The main view into which plots are added + self._view = self.grid.add_view(row=2, col=4) + + # NOTE: this is a mess of hardcoded values... not sure whether they will work + # cross-platform. Note that `width_max` and `height_max` of 2 is actually + # *less* visible than 0 for some reason. They should also be extracted into + # some sort of `hide/show` logic for each component + self._grid_wdgs[Component.YAXIS].width_max = 30 # otherwise it takes too much + self._grid_wdgs[Component.PAD_LEFT].width_max = 20 # otherwise you get clipping + self._grid_wdgs[Component.XAXIS].height_max = 20 # otherwise it takes too much + self.ylabel = ylabel + self.xlabel = xlabel + self.title = title + + # VIEWBOX (this has to go last, see vispy #1748) + self.camera = self._view.camera = PanZoom1DCamera(lock_axis) + # this has to come after camera is set + self.xaxis.link_view(self._view) + self.yaxis.link_view(self._view) + self.freeze() + + @property + def title(self) -> str: + """The title label.""" + return self._title.text # type: ignore [no-any-return] + + @title.setter + def title(self, text: str) -> None: + """Set the title of the plot.""" + self._title.text = text + wdg = self._grid_wdgs[Component.TITLE] + wdg.height_min = wdg.height_max = 30 if text else 2 + + @property + def xlabel(self) -> str: + """The x-axis label.""" + return self._xlabel.text # type: ignore [no-any-return] + + @xlabel.setter + def xlabel(self, text: str) -> None: + """Set the x-axis label.""" + self._xlabel.text = text + wdg = self._grid_wdgs[Component.XLABEL] + wdg.height_min = wdg.height_max = 40 if text else 2 + + @property + def ylabel(self) -> str: + """The y-axis label.""" + return self._ylabel.text # type: ignore [no-any-return] + + @ylabel.setter + def ylabel(self, text: str) -> None: + """Set the x-axis label.""" + self._ylabel.text = text + wdg = self._grid_wdgs[Component.YLABEL] + wdg.width_min = wdg.width_max = 20 if text else 2 + + def lock_axis(self, axis: Literal["x", "y", None]) -> None: + """Prevent panning and zooming along a particular axis.""" + self.camera._axis = axis + # self.camera.set_range() + + +class PanZoom1DCamera(scene.cameras.PanZoomCamera): + """Camera that allows panning and zooming along one axis only. + + Parameters + ---------- + axis : {'x', 'y', None} + The axis along which to allow panning and zooming. + *args : tuple + Positional arguments to pass to the parent class. + **kwargs : dict + Keyword arguments to pass to the parent class. + """ + + def __init__( + self, axis: Literal["x", "y", None] = None, *args: Any, **kwargs: Any + ) -> None: + self._axis: Literal["x", "y", None] = axis + super().__init__(*args, **kwargs) + + @property + def axis_index(self) -> Literal[0, 1, None]: + """Return the index of the axis along which to pan and zoom.""" + if self._axis in ("x", 0): + return 0 + elif self._axis in ("y", 1): + return 1 + return None + + def zoom( + self, + factor: float | tuple[float, float], + center: tuple[float, ...] | None = None, + ) -> None: + """Zoom the camera by `factor` around `center`.""" + if self.axis_index is None: + super().zoom(factor, center=center) + return + + if isinstance(factor, (float, int)): + factor = (factor, factor) + _factor = list(factor) + _factor[self.axis_index] = 1 + super().zoom(_factor, center=center) + + def pan(self, pan: Sequence[float]) -> None: + """Pan the camera by `pan`.""" + if self.axis_index is None: + super().pan(pan) + return + _pan = list(pan) + _pan[self.axis_index] = 0 + super().pan(*_pan) + + def set_range( + self, + x: tuple | None = None, + y: tuple | None = None, + z: tuple | None = None, + margin: float = 0, # overriding to create a different default from super() + ) -> None: + """Reset the camera view to the specified range.""" + super().set_range(x, y, z, margin) + + +# TODO: Move much of this logic to _qt +class VispyHistogramView(PHistogramView): + """A HistogramView on a VisPy SceneCanvas.""" + + visibleChanged = Signal(bool) + autoscaleChanged = Signal(bool) + cmapChanged = Signal(cmap.Colormap) + climsChanged = Signal(tuple) + gammaChanged = Signal(float) + + def __init__(self) -> None: + # ------------ data and state ------------ # + + self._values: Sequence[float] | np.ndarray | None = None + self._bin_edges: Sequence[float] | np.ndarray | None = None + self._clims: tuple[float, float] | None = None + self._gamma: float = 1 + + # the currently grabbed object + self._grabbed: Grabbable = Grabbable.NONE + # whether the y-axis is logarithmic + self._log_y: bool = False + # whether the histogram is vertical + self._vertical: bool = False + # The values of the left and right edges on the canvas (respectively) + self._domain: tuple[float, float] | None = None + # The values of the bottom and top edges on the canvas (respectively) + self._range: tuple[float, float] | None = None + + # ------------ VisPy Canvas ------------ # + + self._canvas = scene.SceneCanvas() + self._canvas.unfreeze() + self._canvas.on_mouse_press = self.on_mouse_press + self._canvas.on_mouse_move = self.on_mouse_move + self._canvas.on_mouse_release = self.on_mouse_release + self._canvas.freeze() + + ## -- Visuals -- ## + + # NB We directly use scene.Mesh, instead of scene.Histogram, + # so that we can control the calculation of the histogram ourselves + self._hist_mesh = scene.Mesh(color="red") + + # The Lut Line visualizes both the clims (vertical line segments connecting the + # first two and last two points, respectively) and the gamma curve + # (the polyline between all remaining points) + self._lut_line = scene.LinePlot( + data=(0), # Dummy value to prevent resizing errors + color="k", + connect="strip", + symbol=None, + line_kind="-", + width=1.5, + marker_size=10.0, + edge_color="k", + face_color="b", + edge_width=1.0, + ) + self._lut_line.visible = False + self._lut_line.order = -1 + + # The gamma handle appears halfway between the clims + self._gamma_handle_pos: np.ndarray = np.ndarray((1, 2)) + self._gamma_handle = scene.Markers( + pos=self._gamma_handle_pos, + size=6, + edge_width=0, + ) + self._gamma_handle.visible = False + self._gamma_handle.order = -2 + + # One transform to rule them all! + self._handle_transform = scene.transforms.STTransform() + self._lut_line.transform = self._handle_transform + self._gamma_handle.transform = self._handle_transform + + ## -- Plot -- ## + self.plot = PlotWidget() + self.plot.lock_axis("y") + self._canvas.central_widget.add_widget(self.plot) + self.node_tform = self.plot.node_transform(self.plot._view.scene) + + self.plot._view.add(self._hist_mesh) + self.plot._view.add(self._lut_line) + self.plot._view.add(self._gamma_handle) + + # ------------- StatsView Protocol methods ------------- # + + def setHistogram(self, values: Sequence[float], bin_edges: Sequence[float]) -> None: + """Set the histogram values and bin edges. + + These inputs follow the same format as the return value of numpy.histogram. + """ + self._values = values + self._bin_edges = bin_edges + self._update_histogram() + if self._clims is None: + self.setClims((self._bin_edges[0], self._bin_edges[-1])) + self._resize() + + def setStdDev(self, std_dev: float) -> None: + # Nothing to do. + # TODO: maybe show text somewhere + pass + + def setAverage(self, average: float) -> None: + # Nothing to do + # TODO: maybe show text somewhere + pass + + def view(self) -> Any: + return self._canvas.native + + # ------------- LutView Protocol methods ------------- # + + def setName(self, name: str) -> None: + # Nothing to do + # TODO: maybe show text somewhere + pass + + def setLutVisible(self, visible: bool) -> None: + if self._hist_mesh is None: + return # pragma: no cover + self._hist_mesh.visible = visible + self._lut_line.visible = visible + self._gamma_handle.visible = visible + + def setColormap(self, lut: cmap.Colormap) -> None: + if self._hist_mesh is not None: + self._hist_mesh.color = lut.color_stops[-1].color.hex + + def setGamma(self, gamma: float) -> None: + if gamma < 0: + raise ValueError("gamma must be non-negative!") + self._gamma = gamma + self._update_lut_lines() + + def setClims(self, clims: tuple[float, float] | None) -> None: + # FIXME + if clims is None: + return + if clims[1] < clims[0]: + clims = (clims[1], clims[0]) + self._clims = clims + self._update_lut_lines() + + def setAutoScale(self, autoscale: bool | tuple[float, float]) -> None: + # Nothing to do (yet) + pass + + # ------------- HistogramView Protocol methods ------------- # + + def setDomain(self, bounds: tuple[float, float] | None) -> None: + if bounds is not None: + if bounds[0] is None or bounds[1] is None: + # TODO: Sensible defaults? + raise ValueError("Domain min/max cannot be None!") + if bounds[0] > bounds[1]: + bounds = (bounds[1], bounds[0]) + self._domain = bounds + self._resize() + + def setRange(self, bounds: tuple[float, float] | None) -> None: + if bounds is not None: + if bounds[0] is None or bounds[1] is None: + # TODO: Sensible defaults? + raise ValueError("Range min/max cannot be None!") + if bounds[0] > bounds[1]: + bounds = (bounds[1], bounds[0]) + self._range = bounds + self._resize() + + def setVertical(self, vertical: bool) -> None: + self._vertical = vertical + self._update_histogram() + self.plot.lock_axis("x" if vertical else "y") + # When vertical, smaller values should appear at the top of the canvas + self.plot.camera.flip = [False, vertical, False] + self._update_lut_lines() + self._resize() + + def setRangeLog(self, enabled: bool) -> None: + if enabled != self._log_y: + self._log_y = enabled + self._update_histogram() + self._update_lut_lines() + self._resize() + + # ------------- Private methods ------------- # + + def _update_histogram(self) -> None: + """ + Updates the displayed histogram with current View parameters. + + NB: Much of this code is graciously borrowed from: + + https://github.com/vispy/vispy/blob/af847424425d4ce51f144a4d1c75ab4033fe39be/vispy/visuals/histogram.py#L28 + """ + if self._values is None or self._bin_edges is None: + return # pragma: no cover + values = self._values + if self._log_y: + # Replace zero values with 1 (which will be log10(1) = 0) + values = np.where(values == 0, 1, values) + values = np.log10(values) + + verts, faces = _hist_counts_to_mesh(values, self._bin_edges, self._vertical) + self._hist_mesh.set_data(vertices=verts, faces=faces) + + # FIXME: This should be called internally upon set_data, right? + # Looks like https://github.com/vispy/vispy/issues/1899 + self._hist_mesh._bounds_changed() + + def _update_lut_lines(self, npoints: int = 256) -> None: + if self._clims is None or self._gamma is None: + return # pragma: no cover + + # 2 additional points for each of the two vertical clims lines + X = np.empty(npoints + 4) + Y = np.empty(npoints + 4) + if self._vertical: + # clims lines + X[0:2], Y[0:2] = (1, 0.5), self._clims[0] + X[-2:], Y[-2:] = (0.5, 0), self._clims[1] + # gamma line + X[2:-2] = np.linspace(0, 1, npoints) ** self._gamma + Y[2:-2] = np.linspace(self._clims[0], self._clims[1], npoints) + midpoint = np.array([(2**-self._gamma, np.mean(self._clims))]) + else: + # clims lines + X[0:2], Y[0:2] = self._clims[0], (1, 0.5) + X[-2:], Y[-2:] = self._clims[1], (0.5, 0) + # gamma line + X[2:-2] = np.linspace(self._clims[0], self._clims[1], npoints) + Y[2:-2] = np.linspace(0, 1, npoints) ** self._gamma + midpoint = np.array([(np.mean(self._clims), 2**-self._gamma)]) + + # TODO: Move to self.edit_cmap + color = np.linspace(0.2, 0.8, npoints + 4).repeat(4).reshape(-1, 4) + c1, c2 = [0.4] * 4, [0.7] * 4 + color[0:3] = [c1, c2, c1] + color[-3:] = [c1, c2, c1] + + self._lut_line.set_data((X, Y), marker_size=0, color=color) + self._lut_line.visible = True + + self._gamma_handle_pos[:] = midpoint[0] + self._gamma_handle.set_data(pos=self._gamma_handle_pos) + self._gamma_handle.visible = True + + # FIXME: These should be called internally upon set_data, right? + # Looks like https://github.com/vispy/vispy/issues/1899 + self._lut_line._bounds_changed() + for v in self._lut_line._subvisuals: + v._bounds_changed() + self._gamma_handle._bounds_changed() + + def on_mouse_press(self, event: MouseEvent) -> None: + if event.pos is None: + return # pragma: no cover + # check whether the user grabbed a node + self._grabbed = self._find_nearby_node(event) + if self._grabbed != Grabbable.NONE: + # disconnect the pan/zoom mouse events until handle is dropped + self.plot.camera.interactive = False + + def on_mouse_release(self, event: MouseEvent) -> None: + self._grabbed = Grabbable.NONE + self.plot.camera.interactive = True + + def on_mouse_move(self, event: MouseEvent) -> None: + """Called whenever mouse moves over canvas.""" + if event.pos is None: + return # pragma: no cover + if self._clims is None: + return # pragma: no cover + + if self._grabbed in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: + newlims = list(self._clims) + if self._vertical: + c = self._to_plot_coords(event.pos)[1] + else: + c = self._to_plot_coords(event.pos)[0] + if self._grabbed is Grabbable.LEFT_CLIM: + newlims[0] = min(newlims[1], c) + elif self._grabbed is Grabbable.RIGHT_CLIM: + newlims[1] = max(newlims[0], c) + self.climsChanged.emit(newlims) + return + elif self._grabbed is Grabbable.GAMMA: + y0, y1 = ( + self.plot.xaxis.axis.domain + if self._vertical + else self.plot.yaxis.axis.domain + ) + y = self._to_plot_coords(event.pos)[0 if self._vertical else 1] + if y < np.maximum(y0, 0) or y > y1: + return + self.gammaChanged.emit(-np.log2(y / y1)) + return + + # TODO: try to remove the Qt aspect here so that we can use + # this for Jupyter as well + self._canvas.native.unsetCursor() + + nearby = self._find_nearby_node(event) + + if nearby in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: + if self._vertical: + cursor = Qt.CursorShape.SplitVCursor + else: + cursor = Qt.CursorShape.SplitHCursor + self._canvas.native.setCursor(cursor) + elif nearby is Grabbable.GAMMA: + if self._vertical: + cursor = Qt.CursorShape.SplitHCursor + else: + cursor = Qt.CursorShape.SplitVCursor + self._canvas.native.setCursor(cursor) + else: + x, y = self._to_plot_coords(event.pos) + x1, x2 = self.plot.xaxis.axis.domain + y1, y2 = self.plot.yaxis.axis.domain + if (x1 < x <= x2) and (y1 <= y <= y2): + self._canvas.native.setCursor(Qt.CursorShape.SizeAllCursor) + + def _find_nearby_node(self, event: MouseEvent, tolerance: int = 5) -> Grabbable: + """Describes whether the event is near a clim.""" + click_x, click_y = event.pos + + # NB Computations are performed in canvas-space + # for easier tolerance computation. + plot_to_canvas = self.node_tform.imap + gamma_to_plot = self._handle_transform.map + + if self._clims is not None: + if self._vertical: + click = click_y + right = plot_to_canvas([0, self._clims[1]])[1] + left = plot_to_canvas([0, self._clims[0]])[1] + else: + click = click_x + right = plot_to_canvas([self._clims[1], 0])[0] + left = plot_to_canvas([self._clims[0], 0])[0] + + # Right bound always selected on overlap + if bool(abs(right - click) < tolerance): + return Grabbable.RIGHT_CLIM + if bool(abs(left - click) < tolerance): + return Grabbable.LEFT_CLIM + + if self._gamma_handle_pos is not None: + gx, gy = plot_to_canvas(gamma_to_plot(self._gamma_handle_pos[0]))[:2] + if bool(abs(gx - click_x) < tolerance and abs(gy - click_y) < tolerance): + return Grabbable.GAMMA + + return Grabbable.NONE + + def _to_plot_coords(self, pos: Sequence[float]) -> tuple[float, float]: + """Return the plot coordinates of the given position.""" + x, y = self.node_tform.map(pos)[:2] + return x, y + + def _resize(self) -> None: + self.plot.camera.set_range( + x=self._range if self._vertical else self._domain, + y=self._domain if self._vertical else self._range, + # FIXME: Bitten by https://github.com/vispy/vispy/issues/1483 + # It's pretty visible in logarithmic mode + margin=1e-30, + ) + if self._vertical: + scale = 0.98 * self.plot.xaxis.axis.domain[1] + self._handle_transform.scale = (scale, 1) + else: + scale = 0.98 * self.plot.yaxis.axis.domain[1] + self._handle_transform.scale = (1, scale) + + +def _hist_counts_to_mesh( + values: Sequence[float] | npt.NDArray, + bin_edges: Sequence[float] | npt.NDArray, + vertical: bool = False, +) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.uint32]]: + """Convert histogram counts to mesh vertices and faces for plotting.""" + n_edges = len(bin_edges) + X, Y = (1, 0) if vertical else (0, 1) + + # 4-5 + # | | + # 1-2/7-8 + # |/| | | + # 0-3-6-9 + # construct vertices + vertices = np.zeros((3 * n_edges - 2, 3), np.float32) + vertices[:, X] = np.repeat(bin_edges, 3)[1:-1] + vertices[1::3, Y] = values + vertices[2::3, Y] = values + vertices[vertices == float("-inf")] = 0 + + # construct triangles + faces = np.zeros((2 * n_edges - 2, 3), np.uint32) + offsets = 3 * np.arange(n_edges - 1, dtype=np.uint32)[:, np.newaxis] + faces[::2] = np.array([0, 2, 1]) + offsets + faces[1::2] = np.array([2, 0, 3]) + offsets + + return vertices, faces diff --git a/src/ndv/views/protocols.py b/src/ndv/views/protocols.py index 3e3cc89..3ba27c7 100644 --- a/src/ndv/views/protocols.py +++ b/src/ndv/views/protocols.py @@ -3,14 +3,12 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Literal, Protocol -import cmap -from psygnal import Signal - if TYPE_CHECKING: from collections.abc import Container, Hashable, Mapping, Sequence + import cmap import numpy as np - from psygnal import SignalInstance + from psygnal import Signal, SignalInstance from qtpy.QtCore import Qt from qtpy.QtWidgets import QWidget @@ -22,78 +20,76 @@ class PLutView(Protocol): autoscaleChanged: Signal cmapChanged: Signal climsChanged: Signal + gammaChanged: Signal - def setName(self, name: str) -> None: ... - def setAutoScale(self, auto: bool) -> None: ... - def setColormap(self, cmap: cmap.Colormap) -> None: ... - def setClims(self, clims: tuple[float, float]) -> None: ... - def setLutVisible(self, visible: bool) -> None: ... - - -class LutView(Protocol): - """An (interactive) view of a LookUp Table (LUT).""" - - cmapChanged: Signal = Signal(cmap.Colormap) - gammaChanged: Signal = Signal(float) - climsChanged: Signal = Signal(tuple[float, float]) - autoscaleChanged: Signal = Signal(object) - - def set_visibility(self, visible: bool) -> None: - """Defines whether this view is visible. + def setName(self, name: str) -> None: + """Defines the name of the view. Properties ---------- - visible : bool - True iff the view should be visible. + name : str + The name (label) of the LUT """ ... - def set_cmap(self, lut: cmap.Colormap) -> None: - """Defines the colormap backing the view. + def setAutoScale(self, auto: bool | tuple) -> None: + """Defines whether autoscale has been enabled. + + Autoscale defines whether the contrast limits (clims) are adjusted when the + data changes. Properties ---------- - lut : cmap.Colormap - The object mapping scalar values to RGB(A) colors. + autoscale : bool | tuple[float, float] + If a boolean, true iff clims automatically changed on dataset alteration. + If a tuple, indicated that clims automatically changed. Values denote + the fraction of the dataset located below and above the lower and + upper clims, respectively. """ ... - def set_gamma(self, gamma: float) -> None: - """Defines the exponent used for gamma correction. + def setColormap(self, cmap: cmap.Colormap) -> None: + """Defines the colormap backing the view. Properties ---------- - gamma : float - The exponent used for gamma correction + lut : cmap.Colormap + The object mapping scalar values to RGB(A) colors. """ ... - def set_clims(self, clims: tuple[float, float]) -> None: + def setClims(self, clims: tuple[float, float] | None) -> None: """Defines the input clims. The contrast limits (clims) are the input values mapped to the minimum and maximum (respectively) of the LUT. + TODO: What does None imply? Autoscale? + Properties ---------- - clims : tuple[float, float] + clims : tuple[float, float] | None The clims """ ... - def set_autoscale(self, autoscale: bool | tuple[float, float]) -> None: - """Defines whether autoscale has been enabled. + def setGamma(self, gamma: float) -> None: + """Defines the input gamma. - Autoscale defines whether the contrast limits (clims) are adjusted when the - data changes. + properties + ---------- + gamma : float + The gamma + """ + ... + + def setLutVisible(self, visible: bool) -> None: + """Defines whether this view is visible. Properties ---------- - autoscale : bool | tuple[float, float] - If a boolean, true iff clims automatically changed on dataset alteration. - If a tuple, indicated that clims automatically changed. Values denote - the fraction of the dataset located below and above the lower and - upper clims, respectively. + visible : bool + True iff the view should be visible. """ ... @@ -102,12 +98,10 @@ def view(self) -> Any: ... -class StatsView(Protocol): +class PStatsView(Protocol): """A view of the statistics of a dataset.""" - def set_histogram( - self, values: Sequence[float], bin_edges: Sequence[float] - ) -> None: + def setHistogram(self, values: Sequence[float], bin_edges: Sequence[float]) -> None: """Defines the distribution of the dataset. Properties @@ -122,7 +116,7 @@ def set_histogram( """ ... - def set_std_dev(self, std_dev: float) -> None: + def setStdDev(self, std_dev: float) -> None: """Defines the standard deviation of the dataset. Properties @@ -132,7 +126,7 @@ def set_std_dev(self, std_dev: float) -> None: """ ... - def set_average(self, avg: float) -> None: + def setAverage(self, avg: float) -> None: """Defines the average value of the dataset. Properties @@ -142,15 +136,11 @@ def set_average(self, avg: float) -> None: """ ... - def view(self) -> Any: - """The native object that can be displayed.""" - ... - -class HistogramView(StatsView, LutView): +class PHistogramView(PStatsView, PLutView): """A histogram-based view for LookUp Table (LUT) adjustment.""" - def set_domain(self, bounds: tuple[float, float] | None) -> None: + def setDomain(self, bounds: tuple[float, float] | None) -> None: """Sets the domain of the view. Properties @@ -161,7 +151,7 @@ def set_domain(self, bounds: tuple[float, float] | None) -> None: """ ... - def set_range(self, bounds: tuple[float, float] | None) -> None: + def setRange(self, bounds: tuple[float, float] | None) -> None: """Sets the range of the view. Properties @@ -172,7 +162,7 @@ def set_range(self, bounds: tuple[float, float] | None) -> None: """ ... - def set_vertical(self, vertical: bool) -> None: + def setVertical(self, vertical: bool) -> None: """Sets the axis of the domain. Properties @@ -184,7 +174,7 @@ def set_vertical(self, vertical: bool) -> None: """ ... - def set_range_log(self, enabled: bool) -> None: + def setRangeLog(self, enabled: bool) -> None: """Sets the axis scale of the range. Properties diff --git a/tests/test_vispy_histogram_view.py b/tests/test_vispy_histogram_view.py new file mode 100644 index 0000000..59ecba1 --- /dev/null +++ b/tests/test_vispy_histogram_view.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import cmap +import numpy as np +import pytest +from qtpy.QtWidgets import QHBoxLayout, QWidget +from vispy.app.canvas import MouseEvent +from vispy.color import Color + +from ndv.views._vispy._vispy import Grabbable, VispyHistogramView + +if TYPE_CHECKING: + from pytestqt.qtbot import QtBot + +# Accounts for differences between 32-bit and 64-bit floats +EPSILON = 1e-6 +# FIXME: Why do plot checks need a larger epsilon? +PLOT_EPSILON = 1e-4 + + +@pytest.fixture +def data() -> np.ndarray: + gen = np.random.default_rng(seed=0xDEADBEEF) + return gen.normal(10, 10, 10000).astype(np.float64) + + +@pytest.fixture +def view(qtbot: QtBot, data: np.ndarray) -> VispyHistogramView: + # Create view + view = VispyHistogramView() + view._canvas.size = (100, 100) + # FIXME: Why does `qtbot.add_widget(view.view())` not work? + wdg = QWidget() + layout = QHBoxLayout(wdg) + layout.addWidget(view.view()) + qtbot.add_widget(wdg) + # Set initial data + values, bin_edges = np.histogram(data) + view.setHistogram(values, bin_edges) + + return view + + +def test_plot(view: VispyHistogramView) -> None: + plot = view.plot + + assert plot.title == "" + plot.title = "foo" + assert plot._title.text == "foo" + + assert plot.xlabel == "" + plot.xlabel = "bar" + assert plot._xlabel.text == "bar" + + assert plot.ylabel == "" + plot.ylabel = "baz" + assert plot._ylabel.text == "baz" + + # Test axis lock - pan + _domain = plot.xaxis.axis.domain + _range = plot.yaxis.axis.domain + plot.camera.pan([20, 20]) + assert np.all(np.isclose(_domain, [x - 20 for x in plot.xaxis.axis.domain])) + assert np.all(np.isclose(_range, plot.yaxis.axis.domain)) + + # Test axis lock - zoom + _domain = plot.xaxis.axis.domain + _range = plot.yaxis.axis.domain + plot.camera.zoom(0.5) + dx = (_domain[1] - _domain[0]) / 4 + assert np.all( + np.isclose([_domain[0] + dx, _domain[1] - dx], plot.xaxis.axis.domain) + ) + assert np.all(np.isclose(_range, plot.yaxis.axis.domain)) + + +def test_clims(data: np.ndarray, view: VispyHistogramView) -> None: + # on startup, clims should be at the extent of the data + clims = np.min(data), np.max(data) + assert view._clims is not None + assert clims[0] == view._clims[0] + assert clims[1] == view._clims[1] + assert abs(clims[0] - view._lut_line._line.pos[0, 0]) <= EPSILON + assert abs(clims[1] - view._lut_line._line.pos[-1, 0]) <= EPSILON + # set clims, assert a change + clims = 9, 11 + view.setClims(clims) + assert clims[0] == view._clims[0] + assert clims[1] == view._clims[1] + assert abs(clims[0] - view._lut_line._line.pos[0, 0]) <= EPSILON + assert abs(clims[1] - view._lut_line._line.pos[-1, 0]) <= EPSILON + # set clims backwards - ensure the view flips them + clims = 5, 3 + view.setClims(clims) + assert clims[1] == view._clims[0] + assert clims[0] == view._clims[1] + assert abs(clims[1] - view._lut_line._line.pos[0, 0]) <= EPSILON + assert abs(clims[0] - view._lut_line._line.pos[-1, 0]) <= EPSILON + + +def test_gamma(data: np.ndarray, view: VispyHistogramView) -> None: + # on startup, gamma should be 1 + assert 1 == view._gamma + gx, gy = (np.max(data) + np.min(data)) / 2, 0.5**view._gamma + assert abs(gx - view._gamma_handle_pos[0, 0]) <= EPSILON + assert abs(gy - view._gamma_handle_pos[0, 1]) <= EPSILON + # set gamma, assert a change + g = 2 + view.setGamma(g) + assert g == view._gamma + gx, gy = (np.max(data) + np.min(data)) / 2, 0.5**view._gamma + assert abs(gx - view._gamma_handle_pos[0, 0]) <= EPSILON + assert abs(gy - view._gamma_handle_pos[0, 1]) <= EPSILON + # set invalid gammas, assert no change + with pytest.raises(ValueError): + view.setGamma(-1) + + +def test_cmap(view: VispyHistogramView) -> None: + # By default, histogram is red + assert view._hist_mesh.color == Color("red") + # Set cmap, assert a change + view.setColormap(cmap.Colormap("blue")) + assert view._hist_mesh.color == Color("blue") + + +def test_visibility(view: VispyHistogramView) -> None: + # By default, everything is visible + assert view._hist_mesh.visible + assert view._lut_line.visible + assert view._gamma_handle.visible + # Visible = False + view.setLutVisible(False) + assert not view._hist_mesh.visible + assert not view._lut_line.visible + assert not view._gamma_handle.visible + # Visible = True + view.setLutVisible(True) + assert view._hist_mesh.visible + assert view._lut_line.visible + assert view._gamma_handle.visible + + +def test_domain(data: np.ndarray, view: VispyHistogramView) -> None: + def assert_extent(min_x: float, max_x: float) -> None: + domain = view.plot.xaxis.axis.domain + assert abs(min_x - domain[0]) <= PLOT_EPSILON + assert abs(max_x - domain[1]) <= PLOT_EPSILON + min_y, max_y = 0, np.max(np.histogram(data)[0]) + range = view.plot.yaxis.axis.domain # noqa: A001 + assert abs(min_y - range[0]) <= PLOT_EPSILON + assert abs(max_y - range[1]) <= PLOT_EPSILON + + # By default, the view should be around the histogram + assert_extent(np.min(data), np.max(data)) + # Set the domain, request a change + new_domain = (10, 12) + view.setDomain(new_domain) + assert_extent(*new_domain) + # Set the domain to None, assert going back + new_domain = None + view.setDomain(new_domain) + assert_extent(np.min(data), np.max(data)) + # Assert None value in tuple raises ValueError + with pytest.raises(ValueError): + view.setDomain((None, 12)) + # Set the domain with min>max, ensure values flipped + new_domain = (12, 10) + view.setDomain(new_domain) + assert_extent(10, 12) + + +def test_range(data: np.ndarray, view: VispyHistogramView) -> None: + # FIXME: Why do we need a larger epsilon? + _EPSILON = 1e-4 + + def assert_extent(min_y: float, max_y: float) -> None: + min_x, max_x = np.min(data), np.max(data) + domain = view.plot.xaxis.axis.domain + assert abs(min_x - domain[0]) <= _EPSILON + assert abs(max_x - domain[1]) <= _EPSILON + range = view.plot.yaxis.axis.domain # noqa: A001 + assert abs(min_y - range[0]) <= _EPSILON + assert abs(max_y - range[1]) <= _EPSILON + + # By default, the view should be around the histogram + assert_extent(0, np.max(np.histogram(data)[0])) + # Set the range, request a change + new_range = (10, 12) + view.setRange(new_range) + assert_extent(*new_range) + # Set the range to None, assert going back + new_range = None + view.setRange(new_range) + assert_extent(0, np.max(np.histogram(data)[0])) + # Assert None value in tuple raises ValueError + with pytest.raises(ValueError): + view.setRange((None, 12)) + # Set the range with min>max, ensure values flipped + new_range = (12, 10) + view.setRange(new_range) + assert_extent(10, 12) + + +def test_vertical(view: VispyHistogramView) -> None: + # Start out Horizontal + assert not view._vertical + domain_before = view.plot.xaxis.axis.domain + range_before = view.plot.yaxis.axis.domain + # Toggle vertical, assert domain <-> range + view.setVertical(True) + assert view._vertical + domain_after = view.plot.xaxis.axis.domain + # NB vertical mode inverts y axis + range_after = view.plot.yaxis.axis.domain[::-1] + assert abs(domain_before[0] - range_after[0]) <= PLOT_EPSILON + assert abs(domain_before[1] - range_after[1]) <= PLOT_EPSILON + assert abs(range_before[0] - domain_after[0]) <= PLOT_EPSILON + assert abs(range_before[1] - domain_after[1]) <= PLOT_EPSILON + # Toggle vertical again, assert domain <-> range again + view.setVertical(False) + assert not view._vertical + domain_after = view.plot.xaxis.axis.domain + range_after = view.plot.yaxis.axis.domain + assert abs(domain_before[0] - domain_after[0]) <= PLOT_EPSILON + assert abs(domain_before[1] - domain_after[1]) <= PLOT_EPSILON + assert abs(range_before[0] - range_after[0]) <= PLOT_EPSILON + assert abs(range_before[1] - range_after[1]) <= PLOT_EPSILON + + +def test_log(view: VispyHistogramView) -> None: + # Start out linear + assert not view._log_y + linear_range = view.plot.yaxis.axis.domain[1] + linear_hist = view._hist_mesh.bounds(1)[1] + # lut line, gamma markers controlled by scale + linear_line_scale = view._handle_transform.scale[1] + + # Toggle log, assert range shrinks + view.setRangeLog(True) + assert view._log_y + log_range = view.plot.yaxis.axis.domain[1] + log_hist = view._hist_mesh.bounds(1)[1] + log_line_scale = view._handle_transform.scale[1] + assert abs(math.log10(linear_range) - log_range) <= EPSILON + assert abs(math.log10(linear_hist) - log_hist) <= EPSILON + # NB This final check isn't so simple because of margins, scale checks, + # etc - so need a larger epsilon. + assert abs(math.log10(linear_line_scale) - log_line_scale) <= 0.1 + + # Toggle log, assert range reverts + view.setRangeLog(False) + assert not view._log_y + revert_range = view.plot.yaxis.axis.domain[1] + revert_hist = view._hist_mesh.bounds(1)[1] + revert_line_scale = view._handle_transform.scale[1] + assert abs(linear_range - revert_range) <= EPSILON + assert abs(linear_hist - revert_hist) <= EPSILON + assert abs(linear_line_scale - revert_line_scale) <= EPSILON + + +# @pytest.mark.skipif(sys.platform != "darwin", reason="the mouse event is tricky") +def test_move_clim(qtbot: QtBot, view: VispyHistogramView) -> None: + # Set clims within the viewbox + view.setDomain((0, 100)) + view.setClims((10, 90)) + # Click on the left clim + press_pos = view.node_tform.imap([10])[:2] + event = MouseEvent("mouse_press", pos=press_pos, button=1) + view.on_mouse_press(event) + assert view._grabbed == Grabbable.LEFT_CLIM + assert not view.plot.camera.interactive + # Move it to 50 + move_pos = view.node_tform.imap([50])[:2] + event = MouseEvent("mouse_move", pos=move_pos, button=1) + with qtbot.waitSignal(view.climsChanged): + view.on_mouse_move(event) + assert view._grabbed == Grabbable.LEFT_CLIM + assert not view.plot.camera.interactive + # Release mouse + release_pos = move_pos + event = MouseEvent("mouse_release", pos=release_pos, button=1) + view.on_mouse_release(event) + assert view._grabbed == Grabbable.NONE + assert view.plot.camera.interactive + + # Move both clims to 50 + view.setClims((50, 50)) + # Ensure clicking and moving at 50 moves the right clim + press_pos = view.node_tform.imap([50])[:2] + event = MouseEvent("mouse_press", pos=press_pos, button=1) + view.on_mouse_press(event) + assert view._grabbed == Grabbable.RIGHT_CLIM + assert not view.plot.camera.interactive + # Move it to 70 + move_pos = view.node_tform.imap([70])[:2] + event = MouseEvent("mouse_move", pos=move_pos, button=1) + with qtbot.waitSignal(view.climsChanged): + view.on_mouse_move(event) + assert view._grabbed == Grabbable.RIGHT_CLIM + assert not view.plot.camera.interactive + # Release mouse + release_pos = move_pos + event = MouseEvent("mouse_release", pos=release_pos, button=1) + view.on_mouse_release(event) + assert view._grabbed == Grabbable.NONE + assert view.plot.camera.interactive + + +def test_move_gamma(qtbot: QtBot, view: VispyHistogramView) -> None: + # Set clims outside the viewbox + # NB the canvas is small in this test, so we have to put the clims + # far away or they'll be grabbed over the gamma + view.setDomain((0, 100)) + view.setClims((-9950, 10050)) + # Click on the gamma handle + press_pos = view.node_tform.imap(view._handle_transform.map([50, 0.5]))[:2] + event = MouseEvent("mouse_press", pos=press_pos, button=1) + view.on_mouse_press(event) + assert view._grabbed == Grabbable.GAMMA + assert not view.plot.camera.interactive + # Move it to 50 + move_pos = view.node_tform.imap(view._handle_transform.map([50, 0.75]))[:2] + event = MouseEvent("mouse_move", pos=move_pos, button=1) + with qtbot.waitSignal(view.gammaChanged): + view.on_mouse_move(event) + assert view._grabbed == Grabbable.GAMMA + assert not view.plot.camera.interactive + # Release mouse + release_pos = move_pos + event = MouseEvent("mouse_release", pos=release_pos, button=1) + view.on_mouse_release(event) + assert view._grabbed == Grabbable.NONE + assert view.plot.camera.interactive From 555f6db8b798416a3c3acfe7acee7564d69965d6 Mon Sep 17 00:00:00 2001 From: Gabriel Selzer Date: Fri, 22 Nov 2024 11:02:00 -0600 Subject: [PATCH 3/4] Create QtHistogramView Wraps the VispyHistogramView up with some Qt widgets for control. Derived from the histogram example --- mvc_histogram.py | 16 +---- src/ndv/controller/_controller.py | 4 +- src/ndv/views/__init__.py | 26 +++++++- src/ndv/views/_qt/qt_view.py | 101 +++++++++++++++++++++++++++++- src/ndv/views/_vispy/_vispy.py | 5 +- 5 files changed, 132 insertions(+), 20 deletions(-) diff --git a/mvc_histogram.py b/mvc_histogram.py index 37918b7..e65396c 100644 --- a/mvc_histogram.py +++ b/mvc_histogram.py @@ -12,7 +12,7 @@ ) from ndv.models import LUTModel, StatsModel -from ndv.views._vispy._vispy import VispyHistogramView +from ndv.views import get_histogram_backend if TYPE_CHECKING: from typing import Any @@ -35,7 +35,7 @@ def __init__( if lut_model is None: lut_model = LUTModel() if view is None: - view = VispyHistogramView() + view = get_histogram_backend() self._stats = stats_model self._lut = lut_model self._view = view @@ -50,16 +50,6 @@ def __init__( self._lut.events.gamma.connect(self._set_model_gamma) self._view.gammaChanged.connect(self._set_view_gamma) - # Vertical box - self._vert = QPushButton("Vertical") - self._vert.setCheckable(True) - self._vert.toggled.connect(self._view.setVertical) - - # Log box - self._log = QPushButton("Logarithmic") - self._log.setCheckable(True) - self._log.toggled.connect(self._view.setRangeLog) - # Data updates self._data_btn = QPushButton("Change Data") self._data_btn.setCheckable(True) @@ -80,8 +70,6 @@ def _update_data() -> None: # Layout self._layout = QVBoxLayout(self._wdg) self._layout.addWidget(self._view.view()) - self._layout.addWidget(self._vert) - self._layout.addWidget(self._log) self._layout.addWidget(self._data_btn) def _set_data(self) -> None: diff --git a/src/ndv/controller/_controller.py b/src/ndv/controller/_controller.py index 7ce2f40..57e0344 100644 --- a/src/ndv/controller/_controller.py +++ b/src/ndv/controller/_controller.py @@ -16,7 +16,9 @@ class ViewerController: """The controller mostly manages the connection between the model and the view.""" def __init__( - self, view: PView | None = None, data: DataDisplayModel | None = None + self, + view: PView | None = None, + data: DataDisplayModel | None = None, ) -> None: if data is None: data = DataDisplayModel() diff --git a/src/ndv/views/__init__.py b/src/ndv/views/__init__.py index d6e94a8..25255a8 100644 --- a/src/ndv/views/__init__.py +++ b/src/ndv/views/__init__.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from ndv.views.protocols import PCanvas - from .protocols import PView + from .protocols import PHistogramView, PView def get_view_backend() -> PView: @@ -25,6 +25,14 @@ def get_view_backend() -> PView: raise RuntimeError("Could not determine the appropriate viewer backend") +def get_histogram_backend(backend: str | None = None) -> PHistogramView: + if _is_running_in_qapp(): + from ._qt.qt_view import QHistogramView + + return QHistogramView() + raise RuntimeError("Could not determine the appropriate histogram backend") + + def _is_running_in_notebook() -> bool: if IPython := sys.modules.get("IPython"): if shell := IPython.get_ipython(): @@ -64,3 +72,19 @@ def get_canvas_class(backend: str | None = None) -> type[PCanvas]: return PyGFXViewerCanvas raise RuntimeError("No canvas backend found") + + +def get_histogram_class(backend: str | None = None) -> type[PHistogramView]: + backend = backend or os.getenv("NDV_CANVAS_BACKEND", None) + if backend == "vispy" or (backend is None and "vispy" in sys.modules): + from ndv.views._vispy._vispy import VispyHistogramView + + return VispyHistogramView + + if backend is None: + if importlib.util.find_spec("vispy") is not None: + from ndv.views._vispy._vispy import VispyHistogramView + + return VispyHistogramView + + raise RuntimeError("No histogram backend found") diff --git a/src/ndv/views/_qt/qt_view.py b/src/ndv/views/_qt/qt_view.py index 75814ed..e2f424a 100644 --- a/src/ndv/views/_qt/qt_view.py +++ b/src/ndv/views/_qt/qt_view.py @@ -18,7 +18,7 @@ from superqt.utils import signals_blocked from ndv._types import AxisKey -from ndv.views import get_canvas_class +from ndv.views import get_canvas_class, get_histogram_class from ndv.views._qt._dims_slider import SS from ndv.views.protocols import PImageHandle @@ -212,3 +212,102 @@ def refresh(self) -> None: def set_visible_axes(self, axes: Sequence[Hashable]) -> None: """Set the visible axes.""" self._visible_axes.setText(", ".join(map(str, axes))) + + +class QHistogramView(QWidget): + """A Qt wrapper around a 'backend' Histogram View. + + Parameters + ---------- + parent: QWidget | None + If a widget, set as the parent of this widget + """ + + visibleChanged = Signal() + autoscaleChanged = Signal() + cmapChanged = Signal(cmap.Colormap) + climsChanged = Signal(tuple) + gammaChanged = Signal(float) + + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self._backend = get_histogram_class()() + self._backend.visibleChanged.connect(self.visibleChanged.emit) + self._backend.autoscaleChanged.connect(self.autoscaleChanged.emit) + self._backend.cmapChanged.connect(self.cmapChanged.emit) + self._backend.climsChanged.connect(self.climsChanged.emit) + self._backend.gammaChanged.connect(self.gammaChanged.emit) + + # Vertical box + self._vert = QPushButton("Vertical") + self._vert.setCheckable(True) + self._vert.toggled.connect(self._backend.setVertical) + + # Log box + self._log = QPushButton("Logarithmic") + self._log.setCheckable(True) + self._log.toggled.connect(self._backend.setRangeLog) + + # Layout + self._layout = QVBoxLayout(self) + # FIXME: Add to protocol? + self._layout.addWidget(self._backend.view()) + self._layout.addWidget(self._vert) + self._layout.addWidget(self._log) + + # ------------- StatsView Protocol methods ------------- # + + def setHistogram(self, values: Sequence[float], bin_edges: Sequence[float]) -> None: + """Set the histogram values and bin edges. + + These inputs follow the same format as the return value of numpy.histogram. + """ + self._backend.setHistogram(values, bin_edges) + + def setStdDev(self, std_dev: float) -> None: + self._backend.setStdDev(std_dev) + + def setAverage(self, average: float) -> None: + self._backend.setStdDev(average) + + def view(self) -> Any: + return self + + # ------------- LutView Protocol methods ------------- # + + def setName(self, name: str) -> None: + # TODO: maybe show text somewhere + self._backend.setName(name) + pass + + def setLutVisible(self, visible: bool) -> None: + self._backend.setLutVisible(visible) + + def setColormap(self, lut: cmap.Colormap) -> None: + # TODO: Maybe some controls would be nice here? + self._backend.setColormap(lut) + + def setGamma(self, gamma: float) -> None: + self._backend.setGamma(gamma) + + def setClims(self, clims: tuple[float, float] | None) -> None: + self._backend.setClims(clims) + + def setAutoScale(self, autoscale: bool | tuple[float, float]) -> None: + self._backend.setAutoScale(autoscale) + + # ------------- HistogramView Protocol methods ------------- # + + def setDomain(self, bounds: tuple[float, float] | None) -> None: + self._backend.setDomain(bounds) + + def setRange(self, bounds: tuple[float, float] | None) -> None: + self._backend.setRange(bounds) + + def setVertical(self, vertical: bool) -> None: + self._vert.setChecked(vertical) + self._backend.setVertical(vertical) + + def setRangeLog(self, enabled: bool) -> None: + self._log.setChecked(enabled) + self._backend.setRangeLog(enabled) diff --git a/src/ndv/views/_vispy/_vispy.py b/src/ndv/views/_vispy/_vispy.py index 3d49a19..71318b0 100755 --- a/src/ndv/views/_vispy/_vispy.py +++ b/src/ndv/views/_vispy/_vispy.py @@ -1209,15 +1209,14 @@ def on_mouse_move(self, event: MouseEvent) -> None: return # pragma: no cover if self._grabbed in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: - newlims = list(self._clims) if self._vertical: c = self._to_plot_coords(event.pos)[1] else: c = self._to_plot_coords(event.pos)[0] if self._grabbed is Grabbable.LEFT_CLIM: - newlims[0] = min(newlims[1], c) + newlims = (min(self._clims[1], c), self._clims[1]) elif self._grabbed is Grabbable.RIGHT_CLIM: - newlims[1] = max(newlims[0], c) + newlims = (self._clims[0], max(self._clims[0], c)) self.climsChanged.emit(newlims) return elif self._grabbed is Grabbable.GAMMA: From 114c36b802d0fc105239f64a38c948fe1226c72a Mon Sep 17 00:00:00 2001 From: Gabriel Selzer Date: Fri, 22 Nov 2024 17:11:52 -0600 Subject: [PATCH 4/4] Create JupyterHistogramView --- mvc_histogram.ipynb | 112 +++++++++++++++++++ src/ndv/views/__init__.py | 19 +++- src/ndv/views/_jupyter/jupyter_view.py | 143 ++++++++++++++++++++++++- src/ndv/views/_qt/qt_view.py | 21 +++- src/ndv/views/_vispy/_vispy.py | 27 ++--- src/ndv/views/protocols.py | 14 +++ 6 files changed, 321 insertions(+), 15 deletions(-) create mode 100644 mvc_histogram.ipynb diff --git a/mvc_histogram.ipynb b/mvc_histogram.ipynb new file mode 100644 index 0000000..8dff188 --- /dev/null +++ b/mvc_histogram.ipynb @@ -0,0 +1,112 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5b3d3ba7", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "from ndv.data import cells3d\n", + "from ndv.models import LUTModel, StatsModel\n", + "from ndv.views import get_histogram_backend\n", + "from ndv.views.protocols import PHistogramView\n", + "\n", + "\n", + "# TODO: Put this somewhere else.\n", + "class Controller:\n", + " \"\"\"A (Qt) wrapper around another HistogramView with some additional controls.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " stats_model: StatsModel | None = None,\n", + " lut_model: LUTModel | None = None,\n", + " view: PHistogramView | None = None,\n", + " ) -> None:\n", + " if stats_model is None:\n", + " stats_model = StatsModel()\n", + " if lut_model is None:\n", + " lut_model = LUTModel()\n", + " if view is None:\n", + " view = get_histogram_backend()\n", + " self._stats = stats_model\n", + " self._lut = lut_model\n", + " self._view = view\n", + "\n", + " # A HistogramView is both a StatsView and a LUTView\n", + " # StatModel <-> StatsView\n", + " self._stats.events.data.connect(self._set_data)\n", + " self._stats.events.bins.connect(self._set_data)\n", + " # LutModel <-> LutView\n", + " self._lut.events.clims.connect(self._set_model_clims)\n", + " self._view.climsChanged.connect(self._set_view_clims)\n", + " self._lut.events.gamma.connect(self._set_model_gamma)\n", + " self._view.gammaChanged.connect(self._set_view_gamma)\n", + "\n", + " def _set_data(self) -> None:\n", + " values, bin_edges = self._stats.histogram\n", + " self._view.setHistogram(values, bin_edges)\n", + "\n", + " def _set_model_clims(self) -> None:\n", + " clims = self._lut.clims\n", + " self._view.setClims(clims)\n", + "\n", + " def _set_view_clims(self, clims: tuple[float, float]) -> None:\n", + " self._lut.clims = clims\n", + "\n", + " def _set_model_gamma(self) -> None:\n", + " gamma = self._lut.gamma\n", + " self._view.setGamma(gamma)\n", + "\n", + " def _set_view_gamma(self, gamma: float) -> None:\n", + " self._lut.gamma = gamma\n", + "\n", + " def view(self) -> Any:\n", + " \"\"\"Returns an object that can be displayed by the active backend.\"\"\"\n", + " return self._view\n", + "\n", + "\n", + "viewer = Controller()\n", + "viewer._stats.data = cells3d()\n", + "\n", + "viewer.view().show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "595ad5fd", + "metadata": {}, + "outputs": [], + "source": [ + "# Change the data\n", + "from numpy.random import normal\n", + "\n", + "viewer._stats.data = normal(30000, 10000, 10000000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/ndv/views/__init__.py b/src/ndv/views/__init__.py index 25255a8..777dbc9 100644 --- a/src/ndv/views/__init__.py +++ b/src/ndv/views/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ndv.views.protocols import PCanvas + from ndv.views.protocols import PCanvas, PCursor from .protocols import PHistogramView, PView @@ -26,6 +26,10 @@ def get_view_backend() -> PView: def get_histogram_backend(backend: str | None = None) -> PHistogramView: + if _is_running_in_notebook(): + from ._jupyter.jupyter_view import JupyterHistogramView + + return JupyterHistogramView() if _is_running_in_qapp(): from ._qt.qt_view import QHistogramView @@ -88,3 +92,16 @@ def get_histogram_class(backend: str | None = None) -> type[PHistogramView]: return VispyHistogramView raise RuntimeError("No histogram backend found") + + +def get_cursor_class(backend: str | None = None) -> type[PCursor]: + if _is_running_in_notebook(): + from ._jupyter.jupyter_view import JupyterCursor + + return JupyterCursor + elif _is_running_in_qapp(): + from ._qt.qt_view import QCursor + + return QCursor + + raise RuntimeError("Could not determine the appropriate viewer backend") diff --git a/src/ndv/views/_jupyter/jupyter_view.py b/src/ndv/views/_jupyter/jupyter_view.py index 56e87b6..18b2db0 100644 --- a/src/ndv/views/_jupyter/jupyter_view.py +++ b/src/ndv/views/_jupyter/jupyter_view.py @@ -6,7 +6,8 @@ import ipywidgets as widgets from psygnal import Signal -from ndv.views import get_canvas_class +from ndv.views import get_canvas_class, get_histogram_class +from ndv.views.protocols import CursorType if TYPE_CHECKING: from collections.abc import Container, Hashable, Mapping, Sequence @@ -166,3 +167,143 @@ def show(self) -> None: def refresh(self) -> None: """Refresh the viewer.""" self._canvas.refresh() + + +class JupyterHistogramView: + """A Jupyter wrapper around a 'backend' Histogram View.""" + + visibleChanged = Signal() + autoscaleChanged = Signal() + cmapChanged = Signal(cmap.Colormap) + climsChanged = Signal(tuple) + gammaChanged = Signal(float) + + def __init__(self) -> None: + super().__init__() + self._backend = get_histogram_class()() + self._backend.visibleChanged.connect(self.visibleChanged.emit) + self._backend.autoscaleChanged.connect(self.autoscaleChanged.emit) + self._backend.cmapChanged.connect(self.cmapChanged.emit) + self._backend.climsChanged.connect(self.climsChanged.emit) + self._backend.gammaChanged.connect(self.gammaChanged.emit) + self._vert = widgets.ToggleButton( + value=False, + description="Vertical", + button_style="", # 'success', 'info', 'warning', 'danger' or '' + # TODO: Workshop tooltip + tooltip="If enabled, histogram domain will be displayed along the vertical axis", + ) + self._vert.observe(self._on_vertical_changed, names="value") + + self._log = widgets.ToggleButton( + value=False, + description="Logarithmic Range", + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Display the base-10 logarithm of each bin height", + ) + self._log.observe(self._on_log_changed, names="value") + # `qwidget` is obviously a misnomer here. it works, because vispy is smart + # enough to return a widget that ipywidgets can display in the appropriate + # context, but we should be managing that more explicitly ourselves. + self.layout = widgets.VBox([self._backend.view(), self._vert, self._log]) + + def show(self) -> None: + """Show the viewer.""" + from IPython.display import display + + display(self.layout) # type: ignore [no-untyped-call] + + def refresh(self) -> None: + self._backend.refresh() + + # ------------- StatsView Protocol methods ------------- # + + def setHistogram(self, values: Sequence[float], bin_edges: Sequence[float]) -> None: + """Set the histogram values and bin edges. + + These inputs follow the same format as the return value of numpy.histogram. + """ + self._backend.setHistogram(values, bin_edges) + self._backend.refresh() + + def setStdDev(self, std_dev: float) -> None: + self._backend.setStdDev(std_dev) + self._backend.refresh() + + def setAverage(self, average: float) -> None: + self._backend.setStdDev(average) + self._backend.refresh() + + def view(self) -> Any: + return self + + # ------------- LutView Protocol methods ------------- # + + def setName(self, name: str) -> None: + # TODO: maybe show text somewhere + self._backend.setName(name) + self._backend.refresh() + pass + + def setLutVisible(self, visible: bool) -> None: + self._backend.setLutVisible(visible) + self._backend.refresh() + + def setColormap(self, lut: cmap.Colormap) -> None: + # TODO: Maybe some controls would be nice here? + self._backend.setColormap(lut) + self._backend.refresh() + + def setGamma(self, gamma: float) -> None: + self._backend.setGamma(gamma) + self._backend.refresh() + + def setClims(self, clims: tuple[float, float] | None) -> None: + self._backend.setClims(clims) + self._backend.refresh() + + def setAutoScale(self, autoscale: bool | tuple[float, float]) -> None: + self._backend.setAutoScale(autoscale) + self._backend.refresh() + + # ------------- HistogramView Protocol methods ------------- # + + def setDomain(self, bounds: tuple[float, float] | None) -> None: + self._backend.setDomain(bounds) + self._backend.refresh() + + def setRange(self, bounds: tuple[float, float] | None) -> None: + self._backend.setRange(bounds) + self._backend.refresh() + + def setVertical(self, vertical: bool) -> None: + self._vert.value = vertical + self._backend.setVertical(vertical) + self._backend.refresh() + + def setRangeLog(self, enabled: bool) -> None: + self._log.value = enabled + self._backend.setRangeLog(enabled) + self._backend.refresh() + + def _on_vertical_changed(self, change: dict[str, Any]) -> None: + self.setVertical(self._vert.value) + + def _on_log_changed(self, change: dict[str, Any]) -> None: + self.setRangeLog(self._log.value) + + +class JupyterCursor: + def __init__(self, native: Any) -> None: + # FIXME + self._native = native + + def set(self, type: CursorType) -> None: + if type is CursorType.DEFAULT: + self._native.cursor = "default" + elif type is CursorType.V_ARROW: + self._native.cursor = "ns-resize" + elif type is CursorType.H_ARROW: + self._native.cursor = "ew-resize" + elif type is CursorType.ALL_ARROW: + self._native.cursor = "move" diff --git a/src/ndv/views/_qt/qt_view.py b/src/ndv/views/_qt/qt_view.py index e2f424a..566bb60 100644 --- a/src/ndv/views/_qt/qt_view.py +++ b/src/ndv/views/_qt/qt_view.py @@ -20,7 +20,7 @@ from ndv._types import AxisKey from ndv.views import get_canvas_class, get_histogram_class from ndv.views._qt._dims_slider import SS -from ndv.views.protocols import PImageHandle +from ndv.views.protocols import CursorType, PImageHandle class CmapCombo(QColormapComboBox): @@ -255,6 +255,9 @@ def __init__(self, parent: QWidget | None = None): self._layout.addWidget(self._vert) self._layout.addWidget(self._log) + def refresh(self) -> None: + self._backend.refresh() + # ------------- StatsView Protocol methods ------------- # def setHistogram(self, values: Sequence[float], bin_edges: Sequence[float]) -> None: @@ -311,3 +314,19 @@ def setVertical(self, vertical: bool) -> None: def setRangeLog(self, enabled: bool) -> None: self._log.setChecked(enabled) self._backend.setRangeLog(enabled) + + +class QCursor: + def __init__(self, native: Any) -> None: + # FIXME + self._native = native + + def set(self, type: CursorType) -> None: + if type is CursorType.V_ARROW: + self._native.setCursor(Qt.CursorShape.SplitVCursor) + elif type is CursorType.H_ARROW: + self._native.setCursor(Qt.CursorShape.SplitHCursor) + elif type is CursorType.ALL_ARROW: + self._native.setCursor(Qt.CursorShape.SizeAllCursor) + else: + self._native.unsetCursor() diff --git a/src/ndv/views/_vispy/_vispy.py b/src/ndv/views/_vispy/_vispy.py index 71318b0..09598f9 100755 --- a/src/ndv/views/_vispy/_vispy.py +++ b/src/ndv/views/_vispy/_vispy.py @@ -17,7 +17,8 @@ from vispy.color import Color from vispy.util.quaternion import Quaternion -from ndv.views.protocols import PCanvas, PHistogramView +from ndv.views import get_cursor_class +from ndv.views.protocols import CursorType, PCanvas, PHistogramView if TYPE_CHECKING: from collections.abc import Sequence @@ -968,6 +969,8 @@ def __init__(self) -> None: self._canvas.on_mouse_release = self.on_mouse_release self._canvas.freeze() + self._cursor = get_cursor_class()(self._canvas.native) + ## -- Visuals -- ## # NB We directly use scene.Mesh, instead of scene.Histogram, @@ -1017,6 +1020,9 @@ def __init__(self) -> None: self.plot._view.add(self._lut_line) self.plot._view.add(self._gamma_handle) + def refresh(self) -> None: + self._canvas.update() + # ------------- StatsView Protocol methods ------------- # def setHistogram(self, values: Sequence[float], bin_edges: Sequence[float]) -> None: @@ -1231,30 +1237,27 @@ def on_mouse_move(self, event: MouseEvent) -> None: self.gammaChanged.emit(-np.log2(y / y1)) return - # TODO: try to remove the Qt aspect here so that we can use - # this for Jupyter as well - self._canvas.native.unsetCursor() - nearby = self._find_nearby_node(event) if nearby in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: if self._vertical: - cursor = Qt.CursorShape.SplitVCursor + cursor_type = CursorType.V_ARROW else: - cursor = Qt.CursorShape.SplitHCursor - self._canvas.native.setCursor(cursor) + cursor_type = CursorType.H_ARROW elif nearby is Grabbable.GAMMA: if self._vertical: - cursor = Qt.CursorShape.SplitHCursor + cursor_type = CursorType.H_ARROW else: - cursor = Qt.CursorShape.SplitVCursor - self._canvas.native.setCursor(cursor) + cursor_type = CursorType.V_ARROW else: x, y = self._to_plot_coords(event.pos) x1, x2 = self.plot.xaxis.axis.domain y1, y2 = self.plot.yaxis.axis.domain if (x1 < x <= x2) and (y1 <= y <= y2): - self._canvas.native.setCursor(Qt.CursorShape.SizeAllCursor) + cursor_type = CursorType.ALL_ARROW + else: + cursor_type = CursorType.DEFAULT + self._cursor.set(cursor_type) def _find_nearby_node(self, event: MouseEvent, tolerance: int = 5) -> Grabbable: """Describes whether the event is near a clim.""" diff --git a/src/ndv/views/protocols.py b/src/ndv/views/protocols.py index 3ba27c7..b7ae9d8 100644 --- a/src/ndv/views/protocols.py +++ b/src/ndv/views/protocols.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Sequence +from enum import Enum, auto from typing import TYPE_CHECKING, Any, Literal, Protocol if TYPE_CHECKING: @@ -140,6 +141,7 @@ def setAverage(self, avg: float) -> None: class PHistogramView(PStatsView, PLutView): """A histogram-based view for LookUp Table (LUT) adjustment.""" + def refresh(self) -> None: ... def setDomain(self, bounds: tuple[float, float] | None) -> None: """Sets the domain of the view. @@ -310,3 +312,15 @@ def add_roi( color: cmap.Color | None = None, border_color: cmap.Color | None = None, ) -> PRoiHandle: ... + + +class CursorType(Enum): + DEFAULT = auto() + V_ARROW = auto() + H_ARROW = auto() + ALL_ARROW = auto() + + +class PCursor(Protocol): + def __init__(self, native: Any) -> None: ... + def set(self, type: CursorType) -> None: ...