From 6161ff7d01d530e07f3519ca5121e5d5fd2b9562 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 9 Jun 2024 20:54:44 -0400 Subject: [PATCH 01/12] wip --- src/ndv/_chunking.py | 218 ++++++++++++++++++++++++++++++++ src/ndv/viewer/_data_wrapper.py | 18 ++- src/ndv/viewer/_viewer.py | 50 +++++--- y.py | 16 +++ z.py | 38 ++++++ 5 files changed, 320 insertions(+), 20 deletions(-) create mode 100644 src/ndv/_chunking.py create mode 100644 y.py create mode 100644 z.py diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py new file mode 100644 index 0000000..9b87929 --- /dev/null +++ b/src/ndv/_chunking.py @@ -0,0 +1,218 @@ +from concurrent.futures import Future, ThreadPoolExecutor +from itertools import product +from types import EllipsisType +from typing import ( + Deque, + Hashable, + Iterable, + Iterator, + Mapping, + NamedTuple, + Sequence, + TypeAlias, + cast, +) + +import cmap +import numpy as np +from attr import dataclass + +from .viewer._data_wrapper import DataWrapper + +# any hashable represent a single dimension in an ND array +DimKey: TypeAlias = Hashable +# any object that can be used to index a single dimension in an ND array +Index: TypeAlias = int | slice +# a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) +# this object is used frequently to query or set the currently displayed slice +Indices: TypeAlias = Mapping[DimKey, Index] +# mapping of dimension keys to the maximum value for that dimension +Sizes: TypeAlias = Mapping[DimKey, int] + + +@dataclass +class ChannelSetting: + visible: bool = True + colormap: cmap.Colormap | str = "gray" + clims: tuple[float, float] | None = None + gamma: float = 1 + auto_clim: bool = False + + +class Response(NamedTuple): + idx: tuple[int | slice, ...] + data: np.ndarray + + +class Slicer: + def __init__( + self, + data_wrapper: DataWrapper | None = None, + chunks: int | tuple[int, ...] | None = None, + ) -> None: + self.chunks = chunks + self.data_wrapper: DataWrapper | None = data_wrapper + self.executor = ThreadPoolExecutor() + self.pending_futures: Deque[Future[Response]] = Deque() + + def __del__(self) -> None: + self.executor.shutdown(cancel_futures=True, wait=True) + + def shutdown(self) -> None: + self.executor.shutdown(wait=True) + + def _request_chunk_sync(self, idx: tuple[int | slice, ...]) -> Response: + if self.data_wrapper is None: + raise ValueError("No data wrapper set") + data = self.data_wrapper[idx] + return Response(idx=idx, data=data) + + def request_index(self, index: Indices) -> None: + if self.data_wrapper is None: + return + idx = self.data_wrapper.to_conventional(index) + if self.chunks is None: + subchunks: Iterable[tuple[int | slice, ...]] = [idx] + else: + shape = self.data_wrapper.data.shape + subchunks = iter_chunk_aligned_slices(shape, self.chunks, idx) + for chunk_idx in subchunks: + future = self.executor.submit(self._request_chunk_sync, chunk_idx) + self.pending_futures.append(future) + future.add_done_callback(self._on_chunk_ready) + + def _on_chunk_ready(self, future: Future[Response]) -> None: + chunk = future.result() + # process the chunk data + print(chunk.idx, chunk.data.shape) + self.pending_futures.remove(future) + + +def _axis_chunks(total_length: int, chunk_size: int) -> tuple[int, ...]: + """Break `total_length` into chunks of `chunk_size` plus remainder. 
+ + Examples + -------- + >>> _axis_chunks(10, 3) + (3, 3, 3, 1) + """ + sequence = (chunk_size,) * (total_length // chunk_size) + if remainder := total_length % chunk_size: + sequence += (remainder,) + return sequence + + +def _shape_chunks( + shape: tuple[int, ...], chunks: int | tuple[int, ...] +) -> tuple[tuple[int, ...], ...]: + """Break `shape` into chunks of `chunks` along each axis. + + Examples + -------- + >>> _shape_chunks((10, 10, 10), 3) + ((3, 3, 3, 1), (3, 3, 3, 1), (3, 3, 3, 1)) + """ + if isinstance(chunks, int): + chunks = (chunks,) * len(shape) + elif isinstance(chunks, Sequence): + if len(chunks) != len(shape): + raise ValueError("Length of `chunks` must match length of `shape`") + else: + raise TypeError("`chunks` must be an int or sequence of ints") + return tuple(_axis_chunks(length, chunk) for length, chunk in zip(shape, chunks)) + + +def _slice2range(sl: slice | int, dim_size: int) -> range: + """Convert slice to range, handling single int as well. + + Examples + -------- + >>> _slice2range(3, 10) + range(3, 4) + """ + if isinstance(sl, int): + return range(sl, sl + 1) + start = 0 if sl.start is None else max(sl.start, 0) + stop = dim_size if sl.stop is None else min(sl.stop, dim_size) + return range(start, stop) + + +def iter_chunk_aligned_slices( + shape: tuple[int, ...], + chunks: int | tuple[int, ...], + slices: tuple[int | slice | EllipsisType, ...], +) -> Iterator[tuple[slice, ...]]: + """Yield chunk-aligned slices for a given shape and slices. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the array to slice. + chunks : int or tuple[int, ...] + The size of each chunk. If a single int, the same size is used for all + dimensions. + slices : tuple[int | slice | Ellipsis, ...] + The full slices to apply to the array. Ellipsis is supported to + represent multiple slices. 
+ + Examples + -------- + >>> list(iter_chunk_aligned_slices((6, 6), 4, (slice(1, 4), ...))) + [ + (slice(1, 4, None), slice(0, 4, None)), + (slice(1, 4, None), slice(4, 6, None)), + ] + + >>> list(iter_chunk_aligned_slices((10, 9), (4, 3), (slice(3, 9), slice(1, None)))) + [ + (slice(3, 4, None), slice(1, 3, None)), + (slice(3, 4, None), slice(3, 6, None)), + (slice(3, 4, None), slice(6, 9, None)), + (slice(4, 8, None), slice(1, 3, None)), + (slice(4, 8, None), slice(3, 6, None)), + (slice(4, 8, None), slice(6, 9, None)), + (slice(8, 9, None), slice(1, 3, None)), + (slice(8, 9, None), slice(3, 6, None)), + (slice(8, 9, None), slice(6, 9, None)), + ] + """ + # Make chunks same length as shape if single int + ndim = len(shape) + if isinstance(chunks, int): + chunks = (chunks,) * ndim + if any(x == 0 for x in chunks): + raise ValueError("Chunk size must be greater than zero") + + if any(isinstance(sl, EllipsisType) for sl in slices): + # Replace Ellipsis with multiple slices + if slices.count(Ellipsis) > 1: + raise ValueError("Only one Ellipsis is allowed") + el_idx = slices.index(Ellipsis) + n_remaining = ndim - len(slices) + 1 + slices = slices[:el_idx] + (slice(None),) * n_remaining + slices[el_idx + 1 :] + + if not (len(chunks) == ndim == len(slices)): + raise ValueError("Length of `chunks`, `shape`, and `slices` must match") + + # Create ranges for each dimension based on the slices provided + slices = cast(tuple[int | slice, ...], slices) + ranges = [_slice2range(sl, dim) for sl, dim in zip(slices, shape)] + + # Generate indices for each dimension that align with chunks + aligned_ranges = ( + range(r.start - (r.start % ch), r.stop, ch) for r, ch in zip(ranges, chunks) + ) + + # Create all combinations of these aligned ranges + for indices in product(*aligned_ranges): + chunk_slices = [] + for idx, rng, ch in zip(indices, ranges, chunks): + # Calculate the actual slice for each dimension + start = max(rng.start, idx) + stop = min(rng.stop, idx + ch) + if start >= stop: # Skip empty slices + break + chunk_slices.append(slice(start, stop)) + else: + # Only add this combination of slices if all dimensions are valid + yield tuple(chunk_slices) diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py index 002bbc5..6b49ea4 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/viewer/_data_wrapper.py @@ -129,6 +129,13 @@ def isel(self, indexers: Indices) -> np.ndarray: """ raise NotImplementedError + def __getitem__(self, index: tuple[int | slice, ...]) -> np.ndarray: + return self._data[index] + + def to_conventional(self, indexers: Indices) -> tuple[int | slice, ...]: + """Convert named indices to a tuple of integers and slices.""" + return tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape))) + def isel_async( self, indexers: list[Indices] ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: @@ -245,8 +252,15 @@ class ArrayLikeWrapper(DataWrapper, Generic[ArrayT]): PRIORITY = 100 def isel(self, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) - return self._asarray(self._data[idx]) + idx = [] + for k in range(len(self._data.shape)): + i = indexers.get(k, slice(None)) + if isinstance(i, int): + idx.extend([i, None]) + else: + idx.append(i) + + return self._asarray(self._data[tuple(idx)]) def _asarray(self, data: npt.ArrayLike) -> np.ndarray: return np.asarray(data) diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 473f0f5..abfdb3f 100644 --- 
a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -273,7 +273,7 @@ def set_data( # redraw if initial_index is None: - idx = {k: int(v // 2) for k, v in sizes.items()} + idx = {k: int(v // 2) for k, v in sizes.items() if k not in visualized_dims} else: if not isinstance(initial_index, dict): # pragma: no cover raise TypeError("initial_index must be a dict") @@ -400,34 +400,48 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 + def _build_requests(self, index: Indices) -> list[Indices]: + # receives an unordered mapping of dimension keys to int | slice + # for example {1: 0, 2: 128, 3: 128, 0: 38} + # returns a list of indices to request from the datastore that takes + # into account the channel axis, channel mode, and whether 2d or 3d mode. + sizes = self._data_wrapper.sizes() + if self._channel_mode == ChannelMode.COMPOSITE and self._channel_axis in sizes: + indices: list[Indices] = [ + {**index, self._channel_axis: i} + for i in range(sizes[self._channel_axis]) + ] + else: + indices = [index] + # don't request any dimensions that are not visualized + for idx in indices: + for k in self._visualized_dims: + idx.pop(k, None) + + return indices + def _update_data_for_index(self, index: Indices) -> None: """Retrieve data for `index` from datastore and update canvas image(s). + This is the first step in updating the displayed image, it is triggered by + the valueChanged signal from the sliders. + This will pull the data from the datastore using the given index, and update the image handle(s) with the new data. This method is *asynchronous*. It makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. """ - if ( - self._channel_axis is not None - and self._channel_mode == ChannelMode.COMPOSITE - and self._channel_axis in (sizes := self._data_wrapper.sizes()) - ): - indices: list[Indices] = [ - {**index, self._channel_axis: i} - for i in range(sizes[self._channel_axis]) - ] - else: - indices = [index] - + print("\n\n-----------------") + print("update_data_for_index", index) + print("visualized_dims", self._visualized_dims) + print("sizes", self._data_wrapper.sizes()) + print("channel_axis", self._channel_axis) + print("channel_mode", self._channel_mode) + + indices = self._build_requests(index) if self._last_future: self._last_future.cancel() - # don't request any dimensions that are not visualized - indices = [ - {k: v for k, v in idx.items() if k not in self._visualized_dims} - for idx in indices - ] try: self._last_future = f = self._data_wrapper.isel_async(indices) except Exception as e: diff --git a/y.py b/y.py new file mode 100644 index 0000000..a5eb6b9 --- /dev/null +++ b/y.py @@ -0,0 +1,16 @@ +import numpy as np + +import ndv +from ndv._chunking import Slicer + +data = np.random.rand(10, 3, 8, 5, 128, 128) +wrapper = ndv.DataWrapper.create(data) +slicer = Slicer(wrapper, chunks=(5, 1, 2, 2, 64, 34)) + +index = {0: 2, 1: 2, 2: 0, 3: 4} +idx = wrapper.to_conventional(index) +print(idx) +print(wrapper[idx].shape) + +slicer.request_index(index) +# slicer.shutdown() diff --git a/z.py b/z.py new file mode 100644 index 0000000..28ee931 --- /dev/null +++ b/z.py @@ -0,0 +1,38 @@ +import random + +import dask.array as da +from dask.distributed import Client, as_completed + + +# Function to load a chunk +def load_chunk(chunk): + # Simulate loading time + import time + + t = random.random() * 5 + print(t) + time.sleep(t) + return chunk + + +if __name__ == "__main__": + # Set up Dask Client + client = Client() + # 
Create a Dask array (simulate chunked storage) + x = da.random.random((10, 10), chunks=(5, 5)) + + # Submit tasks directly to the scheduler and get futures + futures = [] + for i in range(x.numblocks[0]): + for j in range(x.numblocks[1]): + chunk = x.blocks[i, j] + future = client.submit(load_chunk, chunk) + futures.append(future) + + # Monitor progress using as_completed + for future in as_completed(futures): + result = future.result() + print("Chunk ready:", result.shape) + + # Close the client + client.close() From 6284b5ab3b2450b83cddfabcc310dcffa7e4e98d Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 9 Jun 2024 23:03:42 -0400 Subject: [PATCH 02/12] hard-coded chunking --- src/ndv/_chunking.py | 42 +++++++++++++++++++++---- src/ndv/viewer/_backends/_vispy.py | 9 +++++- src/ndv/viewer/_viewer.py | 49 ++++++++++++++++++++---------- 3 files changed, 77 insertions(+), 23 deletions(-) diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index 9b87929..ea8a538 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -1,7 +1,9 @@ from concurrent.futures import Future, ThreadPoolExecutor +from functools import partial from itertools import product from types import EllipsisType from typing import ( + Callable, Deque, Hashable, Iterable, @@ -52,7 +54,7 @@ def __init__( ) -> None: self.chunks = chunks self.data_wrapper: DataWrapper | None = data_wrapper - self.executor = ThreadPoolExecutor() + self.executor = ThreadPoolExecutor(max_workers=4) self.pending_futures: Deque[Future[Response]] = Deque() def __del__(self) -> None: @@ -67,7 +69,7 @@ def _request_chunk_sync(self, idx: tuple[int | slice, ...]) -> Response: data = self.data_wrapper[idx] return Response(idx=idx, data=data) - def request_index(self, index: Indices) -> None: + def request_index(self, index: Indices, func: Callable) -> None: if self.data_wrapper is None: return idx = self.data_wrapper.to_conventional(index) @@ -75,16 +77,21 @@ def request_index(self, index: Indices) -> None: subchunks: Iterable[tuple[int | slice, ...]] = [idx] else: shape = self.data_wrapper.data.shape - subchunks = iter_chunk_aligned_slices(shape, self.chunks, idx) + subchunks = sorted( + iter_chunk_aligned_slices(shape, self.chunks, idx), + key=lambda x: center_distance_key(x, shape), + ) for chunk_idx in subchunks: future = self.executor.submit(self._request_chunk_sync, chunk_idx) self.pending_futures.append(future) - future.add_done_callback(self._on_chunk_ready) + future.add_done_callback(partial(self._on_chunk_ready, func)) - def _on_chunk_ready(self, future: Future[Response]) -> None: + def _on_chunk_ready(self, func: Callable, future: Future[Response]) -> None: chunk = future.result() # process the chunk data - print(chunk.idx, chunk.data.shape) + + # print(start, chunk.data.squeeze().shape) + func(chunk) self.pending_futures.remove(future) @@ -216,3 +223,26 @@ def iter_chunk_aligned_slices( else: # Only add this combination of slices if all dimensions are valid yield tuple(chunk_slices) + + +def slice_center(s, dim_size): + """Calculate the center of a slice based on its start and stop attributes.""" + # For integer slices, center is the integer itself. + if isinstance(s, int): + return s + # For slice objects, calculate the middle point. 
+ start = s.start if s.start is not None else 0 + stop = s.stop if s.stop is not None else dim_size + return (start + stop) / 2 + + +def center_distance_key(slice_tuple, shape): + """Calculate the Euclidean distance from the center of the slices to the center of the shape.""" + shape_center = [dim / 2 for dim in shape] + slice_centers = [slice_center(s, dim) for s, dim in zip(slice_tuple, shape)] + + # Calculate Euclidean distance from the slice centers to the shape center + distance = np.sqrt( + sum((sc - cc) ** 2 for sc, cc in zip(slice_centers, shape_center)) + ) + return distance diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 7ed78d7..8a535ce 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -132,12 +132,19 @@ def refresh(self) -> None: self._canvas.update() def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + self, + data: np.ndarray | None = None, + cmap: cmap.Colormap | None = None, + offset: tuple[int, ...] = (), ) -> VispyImageHandle: """Add a new Image node to the scene.""" img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True + + if offset: + img.transform = scene.STTransform(translate=offset[::-1]) + if data is not None: self._current_shape, prev_shape = data.shape, self._current_shape if not prev_shape: diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index abfdb3f..9efcabe 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -258,6 +258,9 @@ def set_data( """ # store the data self._data_wrapper = DataWrapper.create(data) + from ndv._chunking import Slicer + + self._slicer = Slicer(self._data_wrapper, chunks=(1, 1, 64, 47)) # set channel axis self._channel_axis = self._data_wrapper.guess_channel_axis() @@ -438,17 +441,20 @@ def _update_data_for_index(self, index: Indices) -> None: print("channel_axis", self._channel_axis) print("channel_mode", self._channel_mode) - indices = self._build_requests(index) - if self._last_future: - self._last_future.cancel() + self._data_wrapper.to_conventional(index) + self._slicer.request_index(index, self._on_data_slice_ready) + + # indices = self._build_requests(index) + # if self._last_future: + # self._last_future.cancel() - try: - self._last_future = f = self._data_wrapper.isel_async(indices) - except Exception as e: - raise type(e)(f"Failed to index data with {index}: {e}") from e + # try: + # self._last_future = f = self._data_wrapper.isel_async(indices) + # except Exception as e: + # raise type(e)(f"Failed to index data with {index}: {e}") from e - self._progress_spinner.show() - f.add_done_callback(self._on_data_slice_ready) + # self._progress_spinner.show() + # f.add_done_callback(self._on_data_slice_ready) def closeEvent(self, a0: QCloseEvent | None) -> None: if self._last_future is not None: @@ -464,6 +470,9 @@ def _on_data_slice_ready( Connected to the future returned by _isel. 
""" + offset = tuple(int(getattr(sl, "start", sl)) for sl in future.idx)[-2:] + self._update_canvas_data(future.data, offset) + return # NOTE: removing the reference to the last future here is important # because the future has a reference to this widget in its _done_callbacks # which will prevent the widget from being garbage collected if the future @@ -476,20 +485,25 @@ def _on_data_slice_ready( self._update_canvas_data(datum, idx) self._canvas.refresh() - def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: + def _update_canvas_data(self, data: np.ndarray, offset: list[int]) -> None: + # def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: """Actually update the image handle(s) with the (sliced) data. By this point, data should be sliced from the underlying datastore. Any dimensions remaining that are more than the number of visualized dimensions (currently just 2D) will be reduced using max intensity projection (currently). """ - imkey = self._image_key(index) + # imkey = self._image_key(index) + imkey = 0 + print(offset) datum = self._reduce_data_for_display(data) - if handles := self._img_handles[imkey]: + if handles := self._img_handles[offset]: for handle in handles: + print("updating handle") handle.data = datum - if ctrl := self._lut_ctrls.get(imkey, None): - ctrl.update_autoscale() + handle.clim = (0, 45000) + # if ctrl := self._lut_ctrls.get(imkey, None): + # ctrl.update_autoscale() else: cm = ( next(self._cmap_cycle) @@ -497,11 +511,14 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: else GRAYS ) if datum.ndim == 2: - handles.append(self._canvas.add_image(datum, cmap=cm)) + handle = self._canvas.add_image(datum, cmap=cm, offset=offset) + handle.clim = (0, 45000) + handles.append(handle) elif datum.ndim == 3: handles.append(self._canvas.add_volume(datum, cmap=cm)) if imkey not in self._lut_ctrls: - ch_index = index.get(self._channel_axis, 0) + # ch_index = index.get(self._channel_axis, 0) + ch_index = 0 self._lut_ctrls[imkey] = c = LutControl( f"Ch {ch_index}", handles, From b68667e92b873d028f2745a040587562efc7cd18 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 10 Jun 2024 11:37:23 -0400 Subject: [PATCH 03/12] working, but very messy --- src/ndv/_chunking.py | 249 +++++++++++++++++++---------- src/ndv/viewer/_backends/_vispy.py | 1 + src/ndv/viewer/_viewer.py | 179 +++++++++------------ 3 files changed, 238 insertions(+), 191 deletions(-) diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index ea8a538..7e091b7 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -1,25 +1,28 @@ +from __future__ import annotations + +import math from concurrent.futures import Future, ThreadPoolExecutor -from functools import partial from itertools import product from types import EllipsisType from typing import ( - Callable, + TYPE_CHECKING, + Any, Deque, Hashable, - Iterable, - Iterator, Mapping, NamedTuple, Sequence, - TypeAlias, cast, ) -import cmap import numpy as np -from attr import dataclass +from rich import print + +if TYPE_CHECKING: + from collections import deque + from typing import Callable, Iterable, Iterator, TypeAlias -from .viewer._data_wrapper import DataWrapper + from .viewer._data_wrapper import DataWrapper # any hashable represent a single dimension in an ND array DimKey: TypeAlias = Hashable @@ -32,101 +35,177 @@ Sizes: TypeAlias = Mapping[DimKey, int] -@dataclass -class ChannelSetting: - visible: bool = True - colormap: cmap.Colormap | str = "gray" - clims: tuple[float, float] | 
None = None - gamma: float = 1 - auto_clim: bool = False +class ChunkResponse(NamedTuple): + idx: tuple[int | slice, ...] # index that was requested + data: np.ndarray # the data that was returned + offset: tuple[int, ...] # offset of the data in the full array (derived from idx) + channel_index: int = -1 -class Response(NamedTuple): - idx: tuple[int | slice, ...] - data: np.ndarray +RequestFinished = object() -class Slicer: +class Chunker: def __init__( self, data_wrapper: DataWrapper | None = None, chunks: int | tuple[int, ...] | None = None, + on_ready: Callable[[ChunkResponse], Any] | None = None, ) -> None: self.chunks = chunks self.data_wrapper: DataWrapper | None = data_wrapper self.executor = ThreadPoolExecutor(max_workers=4) - self.pending_futures: Deque[Future[Response]] = Deque() + self.pending_futures: deque[Future[ChunkResponse]] = Deque() + self.on_ready = on_ready + self.channel_axis: int | None = None def __del__(self) -> None: - self.executor.shutdown(cancel_futures=True, wait=True) + self.shutdown() def shutdown(self) -> None: - self.executor.shutdown(wait=True) + self.executor.shutdown(cancel_futures=True, wait=True) - def _request_chunk_sync(self, idx: tuple[int | slice, ...]) -> Response: - if self.data_wrapper is None: - raise ValueError("No data wrapper set") - data = self.data_wrapper[idx] - return Response(idx=idx, data=data) + def _request_chunk_sync( + self, idx: tuple[int | slice, ...], channel_axis: int | None + ) -> ChunkResponse: + # idx is guaranteed to have length equal to the number of dimensions + if channel_axis is not None: + channel_index = idx[channel_axis] + if isinstance(channel_index, slice): + channel_index = channel_index.start + else: + channel_index = -1 + + data = self.data_wrapper[idx] # type: ignore [index] + data = _reduce_data_for_display(data, 2) + # FIXME: temporary + # this needs to be aware of nvisible dimensions + try: + offset = tuple(int(getattr(sl, "start", sl)) for sl in idx)[-2:] + except TypeError: + offset = (0, 0) + + import time + + time.sleep(0.02) + return ChunkResponse( + idx=idx, data=data, offset=offset, channel_index=channel_index + ) + + def request_index(self, index: Indices, cancel_existing: bool = True) -> None: + if cancel_existing: + for future in list(self.pending_futures): + future.cancel() - def request_index(self, index: Indices, func: Callable) -> None: if self.data_wrapper is None: return idx = self.data_wrapper.to_conventional(index) - if self.chunks is None: + + if (chunks := self.chunks) is None: subchunks: Iterable[tuple[int | slice, ...]] = [idx] else: shape = self.data_wrapper.data.shape + + # we never chunk the channel axis + if isinstance(chunks, int): + _chunks = [chunks] * len(shape) + else: + _chunks = list(chunks) + if self.channel_axis is not None: + _chunks[self.channel_axis] = 1 + + # TODO: allow the viewer to pass a center coord, to load chunks + # preferentially around that point subchunks = sorted( - iter_chunk_aligned_slices(shape, self.chunks, idx), - key=lambda x: center_distance_key(x, shape), + iter_chunk_aligned_slices(shape, _chunks, idx), + key=lambda x: distance_from_coord(x, shape), ) + # print("Requesting index:", idx) + # print("subchunks", subchunks) + # print() for chunk_idx in subchunks: - future = self.executor.submit(self._request_chunk_sync, chunk_idx) + future = self.executor.submit( + self._request_chunk_sync, chunk_idx, self.channel_axis + ) self.pending_futures.append(future) - future.add_done_callback(partial(self._on_chunk_ready, func)) - - def _on_chunk_ready(self, 
func: Callable, future: Future[Response]) -> None: - chunk = future.result() - # process the chunk data + future.add_done_callback(self._on_chunk_ready) - # print(start, chunk.data.squeeze().shape) - func(chunk) + def _on_chunk_ready(self, future: Future[ChunkResponse]) -> None: self.pending_futures.remove(future) + if future.cancelled(): + return + if err := future.exception(): + print(f"{type(err).__name__}: in chunk request: {err}") + return + if self.on_ready is not None: + self.on_ready(future.result()) + if not self.pending_futures: + # FIXME: this emits multiple times sometimes + # Fix typing + self.on_ready(RequestFinished) -def _axis_chunks(total_length: int, chunk_size: int) -> tuple[int, ...]: - """Break `total_length` into chunks of `chunk_size` plus remainder. - - Examples - -------- - >>> _axis_chunks(10, 3) - (3, 3, 3, 1) - """ - sequence = (chunk_size,) * (total_length // chunk_size) - if remainder := total_length % chunk_size: - sequence += (remainder,) - return sequence - +def _reduce_data_for_display( + data: np.ndarray, ndims: int, reductor: Callable[..., np.ndarray] = np.max +) -> np.ndarray: + """Reduce the number of dimensions in the data for display. -def _shape_chunks( - shape: tuple[int, ...], chunks: int | tuple[int, ...] -) -> tuple[tuple[int, ...], ...]: - """Break `shape` into chunks of `chunks` along each axis. + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. - Examples - -------- - >>> _shape_chunks((10, 10, 10), 3) - ((3, 3, 3, 1), (3, 3, 3, 1), (3, 3, 3, 1)) + This also coerces 64-bit data to 32-bit data. """ - if isinstance(chunks, int): - chunks = (chunks,) * len(shape) - elif isinstance(chunks, Sequence): - if len(chunks) != len(shape): - raise ValueError("Length of `chunks` must match length of `shape`") - else: - raise TypeError("`chunks` must be an int or sequence of ints") - return tuple(_axis_chunks(length, chunk) for length, chunk in zip(shape, chunks)) + # TODO + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the smallest dims) + data = data.squeeze() + if extra_dims := data.ndim - ndims: + shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) + data = reductor(data, axis=smallest_dims) + + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) + return data + + +# def _axis_chunks(total_length: int, chunk_size: int) -> tuple[int, ...]: +# """Break `total_length` into chunks of `chunk_size` plus remainder. + +# Examples +# -------- +# >>> _axis_chunks(10, 3) +# (3, 3, 3, 1) +# """ +# sequence = (chunk_size,) * (total_length // chunk_size) +# if remainder := total_length % chunk_size: +# sequence += (remainder,) +# return sequence + + +# def _shape_chunks( +# shape: tuple[int, ...], chunks: int | tuple[int, ...] +# ) -> tuple[tuple[int, ...], ...]: +# """Break `shape` into chunks of `chunks` along each axis. 
+ +# Examples +# -------- +# >>> _shape_chunks((10, 10, 10), 3) +# ((3, 3, 3, 1), (3, 3, 3, 1), (3, 3, 3, 1)) +# """ +# if isinstance(chunks, int): +# chunks = (chunks,) * len(shape) +# elif isinstance(chunks, Sequence): +# if len(chunks) != len(shape): +# raise ValueError("Length of `chunks` must match length of `shape`") +# else: +# raise TypeError("`chunks` must be an int or sequence of ints") +# return tuple(_axis_chunks(length, chunk) for length, chunk in zip(shape, chunks)) def _slice2range(sl: slice | int, dim_size: int) -> range: @@ -145,8 +224,8 @@ def _slice2range(sl: slice | int, dim_size: int) -> range: def iter_chunk_aligned_slices( - shape: tuple[int, ...], - chunks: int | tuple[int, ...], + shape: Sequence[int], + chunks: Sequence[int], slices: tuple[int | slice | EllipsisType, ...], ) -> Iterator[tuple[slice, ...]]: """Yield chunk-aligned slices for a given shape and slices. @@ -185,8 +264,6 @@ def iter_chunk_aligned_slices( """ # Make chunks same length as shape if single int ndim = len(shape) - if isinstance(chunks, int): - chunks = (chunks,) * ndim if any(x == 0 for x in chunks): raise ValueError("Chunk size must be greater than zero") @@ -225,24 +302,22 @@ def iter_chunk_aligned_slices( yield tuple(chunk_slices) -def slice_center(s, dim_size): +def slice_center(s: slice | int, dim_size: int) -> float: """Calculate the center of a slice based on its start and stop attributes.""" - # For integer slices, center is the integer itself. if isinstance(s, int): return s - # For slice objects, calculate the middle point. - start = s.start if s.start is not None else 0 - stop = s.stop if s.stop is not None else dim_size + start = float(s.start) if s.start is not None else 0 + stop = float(s.stop) if s.stop is not None else dim_size return (start + stop) / 2 -def center_distance_key(slice_tuple, shape): - """Calculate the Euclidean distance from the center of the slices to the center of the shape.""" - shape_center = [dim / 2 for dim in shape] - slice_centers = [slice_center(s, dim) for s, dim in zip(slice_tuple, shape)] - - # Calculate Euclidean distance from the slice centers to the shape center - distance = np.sqrt( - sum((sc - cc) ** 2 for sc, cc in zip(slice_centers, shape_center)) - ) - return distance +def distance_from_coord( + slice_tuple: tuple[slice | int, ...], + shape: tuple[int, ...], + coord: Iterable[float] = (), # defaults to center of shape +) -> float: + """Euclidean distance from the center of an nd slice to the center of shape.""" + if not coord: + coord = (dim / 2 for dim in shape) + slice_centers = (slice_center(s, dim) for s, dim in zip(slice_tuple, shape)) + return math.hypot(*(sc - cc for sc, cc in zip(slice_centers, coord))) diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 8a535ce..329ff42 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -192,6 +192,7 @@ def set_range( is_3d = isinstance(self._camera, scene.ArcballCamera) if is_3d: self._camera._quaternion = DEFAULT_QUATERNION + print("Setting range", x, y, z, margin) self._view.camera.set_range(x=x, y=y, z=z, margin=margin) if is_3d: max_size = max(self._current_shape) diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 9efcabe..537970d 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -2,14 +2,14 @@ from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal, Sequence, cast import 
cmap -import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread from superqt.utils import qthrottled, signals_blocked +from ndv._chunking import Chunker, ChunkResponse, RequestFinished from ndv.viewer._components import ( ChannelMode, ChannelModeButton, @@ -20,16 +20,16 @@ from ._backends import get_canvas from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders -from ._lut_control import LutControl if TYPE_CHECKING: from concurrent.futures import Future - from typing import Any, Callable, Hashable, Iterable, Sequence, TypeAlias + from typing import Any, Hashable, Iterable, TypeAlias from qtpy.QtGui import QCloseEvent from ._backends._protocols import PCanvas, PImageHandle from ._dims_slider import DimKey, Indices, Sizes + from ._lut_control import LutControl ImgKey: TypeAlias = Hashable # any mapping of dimensions to sizes @@ -119,7 +119,10 @@ def __init__( # ATTRIBUTES ---------------------------------------------------- # mapping of key to a list of objects that control image nodes in the canvas - self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) + self._img_handles: defaultdict[int, dict[tuple, PImageHandle]] = defaultdict( + dict + ) + # mapping of same keys to the LutControl objects control image display props self._lut_ctrls: dict[ImgKey, LutControl] = {} # the set of dimensions we are currently visualizing (e.g. XY) @@ -140,6 +143,14 @@ def __init__( # number of dimensions to display self._ndims: Literal[2, 3] = 2 + self._chunker = Chunker( + None, + # IMPORTANT + # chunking here will determine how non-visualized dims are reduced + # so chunkshape will need to change based on the set of visualized dims + chunks=(1000, 1000, 64, 32), + on_ready=self._on_data_slice_ready, + ) # WIDGETS ---------------------------------------------------- @@ -258,12 +269,11 @@ def set_data( """ # store the data self._data_wrapper = DataWrapper.create(data) - from ndv._chunking import Slicer - - self._slicer = Slicer(self._data_wrapper, chunks=(1, 1, 64, 47)) + self._chunker.data_wrapper = self._data_wrapper # set channel axis self._channel_axis = self._data_wrapper.guess_channel_axis() + self._chunker.channel_axis = self._channel_axis # update the dimensions we are visualizing sizes = self._data_wrapper.sizes() @@ -403,26 +413,6 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 - def _build_requests(self, index: Indices) -> list[Indices]: - # receives an unordered mapping of dimension keys to int | slice - # for example {1: 0, 2: 128, 3: 128, 0: 38} - # returns a list of indices to request from the datastore that takes - # into account the channel axis, channel mode, and whether 2d or 3d mode. - sizes = self._data_wrapper.sizes() - if self._channel_mode == ChannelMode.COMPOSITE and self._channel_axis in sizes: - indices: list[Indices] = [ - {**index, self._channel_axis: i} - for i in range(sizes[self._channel_axis]) - ] - else: - indices = [index] - # don't request any dimensions that are not visualized - for idx in indices: - for k in self._visualized_dims: - idx.pop(k, None) - - return indices - def _update_data_for_index(self, index: Indices) -> None: """Retrieve data for `index` from datastore and update canvas image(s). @@ -434,15 +424,9 @@ def _update_data_for_index(self, index: Indices) -> None: makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. 
""" - print("\n\n-----------------") - print("update_data_for_index", index) - print("visualized_dims", self._visualized_dims) - print("sizes", self._data_wrapper.sizes()) - print("channel_axis", self._channel_axis) - print("channel_mode", self._channel_mode) - - self._data_wrapper.to_conventional(index) - self._slicer.request_index(index, self._on_data_slice_ready) + print(f"\n--------\nrequesting index {index}") + self._progress_spinner.show() + self._chunker.request_index(index) # indices = self._build_requests(index) # if self._last_future: @@ -457,21 +441,16 @@ def _update_data_for_index(self, index: Indices) -> None: # f.add_done_callback(self._on_data_slice_ready) def closeEvent(self, a0: QCloseEvent | None) -> None: - if self._last_future is not None: - self._last_future.cancel() - self._last_future = None + self._chunker.shutdown() super().closeEvent(a0) @ensure_main_thread # type: ignore - def _on_data_slice_ready( - self, future: Future[Iterable[tuple[Indices, np.ndarray]]] - ) -> None: + def _on_data_slice_ready(self, response: ChunkResponse) -> None: """Update the displayed image for the given index. Connected to the future returned by _isel. """ - offset = tuple(int(getattr(sl, "start", sl)) for sl in future.idx)[-2:] - self._update_canvas_data(future.data, offset) + self._update_canvas_data(response) return # NOTE: removing the reference to the last future here is important # because the future has a reference to this widget in its _done_callbacks @@ -485,7 +464,7 @@ def _on_data_slice_ready( self._update_canvas_data(datum, idx) self._canvas.refresh() - def _update_canvas_data(self, data: np.ndarray, offset: list[int]) -> None: + def _update_canvas_data(self, response: ChunkResponse) -> None: # def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: """Actually update the image handle(s) with the (sliced) data. @@ -493,73 +472,65 @@ def _update_canvas_data(self, data: np.ndarray, offset: list[int]) -> None: dimensions remaining that are more than the number of visualized dimensions (currently just 2D) will be reduced using max intensity projection (currently). 
""" - # imkey = self._image_key(index) - imkey = 0 - print(offset) - datum = self._reduce_data_for_display(data) - if handles := self._img_handles[offset]: - for handle in handles: - print("updating handle") - handle.data = datum - handle.clim = (0, 45000) - # if ctrl := self._lut_ctrls.get(imkey, None): - # ctrl.update_autoscale() + if response is RequestFinished: # fix typing + self._progress_spinner.hide() + return + + if self._channel_mode == ChannelMode.MONO: + channel_key = -1 else: - cm = ( - next(self._cmap_cycle) - if self._channel_mode == ChannelMode.COMPOSITE - else GRAYS - ) + channel_key = response.channel_index + offset = response.offset + datum = response.data + + if ( + channel_handles := self._img_handles[channel_key] + ) and offset in channel_handles: + handle = channel_handles[offset] + handle.data = datum + else: + cm = DEFAULT_COLORMAPS[channel_key] if datum.ndim == 2: handle = self._canvas.add_image(datum, cmap=cm, offset=offset) handle.clim = (0, 45000) - handles.append(handle) - elif datum.ndim == 3: - handles.append(self._canvas.add_volume(datum, cmap=cm)) - if imkey not in self._lut_ctrls: - # ch_index = index.get(self._channel_axis, 0) - ch_index = 0 - self._lut_ctrls[imkey] = c = LutControl( - f"Ch {ch_index}", - handles, - self, - cmaplist=self._cmaps + DEFAULT_COLORMAPS, - ) - self._lut_drop.addWidget(c) - - def _reduce_data_for_display( - self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max - ) -> np.ndarray: - """Reduce the number of dimensions in the data for display. - - This function takes a data array and reduces the number of dimensions to - the max allowed for display. The default behavior is to reduce the smallest - dimensions, using np.max. This can be improved in the future. - - This also coerces 64-bit data to 32-bit data. 
- """ - # TODO - # - allow dimensions to control how they are reduced (as opposed to just max) - # - for better way to determine which dims need to be reduced (currently just - # the smallest dims) - data = data.squeeze() - visualized_dims = self._ndims - if extra_dims := data.ndim - visualized_dims: - shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) - smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) - data = reductor(data, axis=smallest_dims) - - if data.dtype.itemsize > 4: # More than 32 bits - if np.issubdtype(data.dtype, np.integer): - data = data.astype(np.int32) + channel_handles[offset] = handle else: - data = data.astype(np.float32) - return data + raise NotImplementedError("Volume rendering not yet supported") + self._canvas.refresh() + + # if handles := self._img_handles[offset]: + # for handle in handles: + # handle.data = datum + # handle.clim = (0, 45000) + # # if ctrl := self._lut_ctrls.get(imkey, None): + # # ctrl.update_autoscale() + # else: + # cm = ( + # next(self._cmap_cycle) + # if self._channel_mode == ChannelMode.COMPOSITE + # else GRAYS + # ) + # if datum.ndim == 2: + # handle = self._canvas.add_image(datum, cmap=cm, offset=offset) + # handle.clim = (0, 45000) + # handles.append(handle) + # elif datum.ndim == 3: + # handles.append(self._canvas.add_volume(datum, cmap=cm)) + # if imkey not in self._lut_ctrls: + # # ch_index = index.get(self._channel_axis, 0) + # ch_index = 0 + # self._lut_ctrls[imkey] = c = LutControl( + # f"Ch {ch_index}", + # handles, + # self, + # cmaplist=self._cmaps + DEFAULT_COLORMAPS, + # ) + # self._lut_drop.addWidget(c) def _clear_images(self) -> None: """Remove all images from the canvas.""" for handles in self._img_handles.values(): - for handle in handles: + for handle in handles.values(): handle.remove() self._img_handles.clear() From c226bef0295681cd0bf66a48a3f8d124b4947652 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 10 Jun 2024 13:44:25 -0400 Subject: [PATCH 04/12] wip on zarr --- src/ndv/_chunking.py | 138 ++++++++++++++++++-------------- src/ndv/viewer/_data_wrapper.py | 14 +++- src/ndv/viewer/_viewer.py | 78 ++++++------------ 3 files changed, 117 insertions(+), 113 deletions(-) diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index 7e091b7..832c7f8 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -3,7 +3,6 @@ import math from concurrent.futures import Future, ThreadPoolExecutor from itertools import product -from types import EllipsisType from typing import ( TYPE_CHECKING, Any, @@ -20,6 +19,7 @@ if TYPE_CHECKING: from collections import deque + from types import EllipsisType from typing import Callable, Iterable, Iterator, TypeAlias from .viewer._data_wrapper import DataWrapper @@ -42,7 +42,8 @@ class ChunkResponse(NamedTuple): channel_index: int = -1 -RequestFinished = object() +# sentinel value +RequestFinished = ChunkResponse((), np.empty(0), ()) class Chunker: @@ -58,12 +59,13 @@ def __init__( self.pending_futures: deque[Future[ChunkResponse]] = Deque() self.on_ready = on_ready self.channel_axis: int | None = None + self._notification_sent = True def __del__(self) -> None: self.shutdown() - def shutdown(self) -> None: - self.executor.shutdown(cancel_futures=True, wait=True) + def shutdown(self, cancel_futures: bool = True, wait: bool = True) -> None: + self.executor.shutdown(cancel_futures=cancel_futures, wait=wait) def _request_chunk_sync( self, idx: tuple[int | slice, ...], channel_axis: int | None @@ -85,9 +87,6 @@ def _request_chunk_sync( except TypeError: offset = (0, 0) 
- import time - - time.sleep(0.02) return ChunkResponse( idx=idx, data=data, offset=offset, channel_index=channel_index ) @@ -106,6 +105,9 @@ def request_index(self, index: Indices, cancel_existing: bool = True) -> None: else: shape = self.data_wrapper.data.shape + # TODO + # we should *only* chunk along visualized axes ... + # we never chunk the channel axis if isinstance(chunks, int): _chunks = [chunks] * len(shape) @@ -120,9 +122,10 @@ def request_index(self, index: Indices, cancel_existing: bool = True) -> None: iter_chunk_aligned_slices(shape, _chunks, idx), key=lambda x: distance_from_coord(x, shape), ) - # print("Requesting index:", idx) - # print("subchunks", subchunks) - # print() + print("Requesting index:", idx) + print("subchunks", subchunks) + print() + self._notification_sent = False for chunk_idx in subchunks: future = self.executor.submit( self._request_chunk_sync, chunk_idx, self.channel_axis @@ -139,9 +142,9 @@ def _on_chunk_ready(self, future: Future[ChunkResponse]) -> None: return if self.on_ready is not None: self.on_ready(future.result()) - if not self.pending_futures: - # FIXME: this emits multiple times sometimes + if not self.pending_futures and not self._notification_sent: # Fix typing + self._notification_sent = True self.on_ready(RequestFinished) @@ -159,12 +162,11 @@ def _reduce_data_for_display( # TODO # - allow dimensions to control how they are reduced (as opposed to just max) # - for better way to determine which dims need to be reduced (currently just - # the smallest dims) + # the first extra dims) data = data.squeeze() if extra_dims := data.ndim - ndims: - shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) - smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) - data = reductor(data, axis=smallest_dims) + axis = tuple(range(extra_dims)) + data = reductor(data, axis=axis) if data.dtype.itemsize > 4: # More than 32 bits if np.issubdtype(data.dtype, np.integer): @@ -174,40 +176,6 @@ def _reduce_data_for_display( return data -# def _axis_chunks(total_length: int, chunk_size: int) -> tuple[int, ...]: -# """Break `total_length` into chunks of `chunk_size` plus remainder. - -# Examples -# -------- -# >>> _axis_chunks(10, 3) -# (3, 3, 3, 1) -# """ -# sequence = (chunk_size,) * (total_length // chunk_size) -# if remainder := total_length % chunk_size: -# sequence += (remainder,) -# return sequence - - -# def _shape_chunks( -# shape: tuple[int, ...], chunks: int | tuple[int, ...] -# ) -> tuple[tuple[int, ...], ...]: -# """Break `shape` into chunks of `chunks` along each axis. - -# Examples -# -------- -# >>> _shape_chunks((10, 10, 10), 3) -# ((3, 3, 3, 1), (3, 3, 3, 1), (3, 3, 3, 1)) -# """ -# if isinstance(chunks, int): -# chunks = (chunks,) * len(shape) -# elif isinstance(chunks, Sequence): -# if len(chunks) != len(shape): -# raise ValueError("Length of `chunks` must match length of `shape`") -# else: -# raise TypeError("`chunks` must be an int or sequence of ints") -# return tuple(_axis_chunks(length, chunk) for length, chunk in zip(shape, chunks)) - - def _slice2range(sl: slice | int, dim_size: int) -> range: """Convert slice to range, handling single int as well. @@ -226,7 +194,7 @@ def _slice2range(sl: slice | int, dim_size: int) -> range: def iter_chunk_aligned_slices( shape: Sequence[int], chunks: Sequence[int], - slices: tuple[int | slice | EllipsisType, ...], + slices: Sequence[int | slice | EllipsisType], ) -> Iterator[tuple[slice, ...]]: """Yield chunk-aligned slices for a given shape and slices. 
@@ -241,15 +209,31 @@ def iter_chunk_aligned_slices( The full slices to apply to the array. Ellipsis is supported to represent multiple slices. + Returns + ------- + Iterator[tuple[slice, ...]] + An iterator of chunk-aligned slices. + + Raises + ------ + ValueError + If the length of `chunks`, `shape`, and `slices` do not match, or any chunks + are zero. + IndexError + If more than one Ellipsis is present in `slices`. + Examples -------- - >>> list(iter_chunk_aligned_slices((6, 6), 4, (slice(1, 4), ...))) + >>> list(iter_chunk_aligned_slices(shape=(6, 6), chunks=4, (slice(1, 4), ...))) [ (slice(1, 4, None), slice(0, 4, None)), (slice(1, 4, None), slice(4, 6, None)), ] - >>> list(iter_chunk_aligned_slices((10, 9), (4, 3), (slice(3, 9), slice(1, None)))) + >>> x = iter_chunk_aligned_slices( + ... shape=(10, 9), chunks=(4, 3), slices=(slice(3, 9), slice(1, None)) + ... ) + >>> list(x) [ (slice(3, 4, None), slice(1, 3, None)), (slice(3, 4, None), slice(3, 6, None)), @@ -267,19 +251,23 @@ def iter_chunk_aligned_slices( if any(x == 0 for x in chunks): raise ValueError("Chunk size must be greater than zero") - if any(isinstance(sl, EllipsisType) for sl in slices): + if num_ellipsis := slices.count(Ellipsis): + if num_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") # Replace Ellipsis with multiple slices - if slices.count(Ellipsis) > 1: - raise ValueError("Only one Ellipsis is allowed") el_idx = slices.index(Ellipsis) n_remaining = ndim - len(slices) + 1 - slices = slices[:el_idx] + (slice(None),) * n_remaining + slices[el_idx + 1 :] + slices = ( + tuple(slices[:el_idx]) + + (slice(None),) * n_remaining + + tuple(slices[el_idx + 1 :]) + ) + slices = cast(tuple[int | slice, ...], slices) # now we have no Ellipsis if not (len(chunks) == ndim == len(slices)): raise ValueError("Length of `chunks`, `shape`, and `slices` must match") # Create ranges for each dimension based on the slices provided - slices = cast(tuple[int | slice, ...], slices) ranges = [_slice2range(sl, dim) for sl, dim in zip(slices, shape)] # Generate indices for each dimension that align with chunks @@ -321,3 +309,37 @@ def distance_from_coord( coord = (dim / 2 for dim in shape) slice_centers = (slice_center(s, dim) for s, dim in zip(slice_tuple, shape)) return math.hypot(*(sc - cc for sc, cc in zip(slice_centers, coord))) + + +# def _axis_chunks(total_length: int, chunk_size: int) -> tuple[int, ...]: +# """Break `total_length` into chunks of `chunk_size` plus remainder. + +# Examples +# -------- +# >>> _axis_chunks(10, 3) +# (3, 3, 3, 1) +# """ +# sequence = (chunk_size,) * (total_length // chunk_size) +# if remainder := total_length % chunk_size: +# sequence += (remainder,) +# return sequence + + +# def _shape_chunks( +# shape: tuple[int, ...], chunks: int | tuple[int, ...] +# ) -> tuple[tuple[int, ...], ...]: +# """Break `shape` into chunks of `chunks` along each axis. 
+ +# Examples +# -------- +# >>> _shape_chunks((10, 10, 10), 3) +# ((3, 3, 3, 1), (3, 3, 3, 1), (3, 3, 3, 1)) +# """ +# if isinstance(chunks, int): +# chunks = (chunks,) * len(shape) +# elif isinstance(chunks, Sequence): +# if len(chunks) != len(shape): +# raise ValueError("Length of `chunks` must match length of `shape`") +# else: +# raise TypeError("`chunks` must be an int or sequence of ints") +# return tuple(_axis_chunks(length, chunk) for length, chunk in zip(shape, chunks)) diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py index 6b49ea4..967b276 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/viewer/_data_wrapper.py @@ -132,8 +132,16 @@ def isel(self, indexers: Indices) -> np.ndarray: def __getitem__(self, index: tuple[int | slice, ...]) -> np.ndarray: return self._data[index] + def chunks(self) -> tuple[int, ...] | int | None: + if hasattr(self._data, "chunks"): + return self._data.chunks + return None + def to_conventional(self, indexers: Indices) -> tuple[int | slice, ...]: """Convert named indices to a tuple of integers and slices.""" + if hasattr(self, "_name2index"): + indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + return tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape))) def isel_async( @@ -142,14 +150,14 @@ def isel_async( """Asynchronous version of isel.""" return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) - def guess_channel_axis(self) -> Hashable | None: + def guess_channel_axis(self) -> int | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, # see if any of the dimensions are named "channel" - for dimkey, val in self.sizes().items(): + for i, (dimkey, val) in enumerate(self.sizes().items()): if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: if val <= self.MAX_CHANNELS: - return dimkey + return i # for shaped arrays, use the smallest dimension as the channel axis shape = getattr(self._data, "shape", None) diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 537970d..3355e2b 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -73,13 +73,13 @@ class NDViewer(QWidget): with the `_dims_sliders.value()` method. To programmatically set the current position, use the `setIndex` method. This will set the values of the sliders, which in turn will trigger the display of the new slice via the - `_update_data_for_index` method. - - `_update_data_for_index` is an asynchronous method that retrieves the data for + `_request_data_for_index` method. + - `_request_data_for_index` is an asynchronous method that retrieves the data for the given index from the datastore (using `_isel`) and queues the - `_on_data_slice_ready` method to be called when the data is ready. The logic + `_draw_chunk` method to be called when the data is ready. The logic for extracting data from the datastore is defined in `_data_wrapper.py`, which handles idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). - - `_on_data_slice_ready` is called when the data is ready, and updates the image. + - `_draw_chunk` is called when the data is ready, and updates the image. Note that if the slice is multidimensional, the data will be reduced to 2D using max intensity projection (and double-clicking on any given dimension slider will turn it into a range slider allowing a projection to be made over that dimension). 
@@ -138,7 +138,7 @@ def __init__( else: self._cmaps = DEFAULT_COLORMAPS self._cmap_cycle = cycle(self._cmaps) - # the last future that was created by _update_data_for_index + # the last future that was created by _request_data_for_index self._last_future: Future | None = None # number of dimensions to display @@ -148,8 +148,8 @@ def __init__( # IMPORTANT # chunking here will determine how non-visualized dims are reduced # so chunkshape will need to change based on the set of visualized dims - chunks=(1000, 1000, 64, 32), - on_ready=self._on_data_slice_ready, + chunks=32, + on_ready=self._draw_chunk, ) # WIDGETS ---------------------------------------------------- @@ -180,7 +180,7 @@ def __init__( # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders(self) self._dims_sliders.valueChanged.connect( - qthrottled(self._update_data_for_index, 20, leading=True) + qthrottled(self._request_data_for_index, 20, leading=True) ) self._lut_drop = QCollapsible("LUTs", self) @@ -270,6 +270,12 @@ def set_data( # store the data self._data_wrapper = DataWrapper.create(data) self._chunker.data_wrapper = self._data_wrapper + if chunks := self._data_wrapper.chunks(): + # temp hack ... always group non-visible channels + chunks = list(chunks) + chunks[:-2] = (1000,) * len(chunks[:-2]) + print(chunks) + self._chunker.chunks = tuple(chunks) # set channel axis self._channel_axis = self._data_wrapper.guess_channel_axis() @@ -325,7 +331,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: # clear image handles and redraw if self._img_handles: self._clear_images() - self._update_data_for_index(self._dims_sliders.value()) + self._request_data_for_index(self._dims_sliders.value()) def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: """Set the mode for displaying the channels. @@ -359,7 +365,7 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: if self._img_handles: self._clear_images() - self._update_data_for_index(self._dims_sliders.value()) + self._request_data_for_index(self._dims_sliders.value()) def set_current_index(self, index: Indices | None = None) -> None: """Set the index of the displayed image. @@ -413,7 +419,7 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 - def _update_data_for_index(self, index: Indices) -> None: + def _request_data_for_index(self, index: Indices) -> None: """Retrieve data for `index` from datastore and update canvas image(s). This is the first step in updating the displayed image, it is triggered by @@ -428,44 +434,8 @@ def _update_data_for_index(self, index: Indices) -> None: self._progress_spinner.show() self._chunker.request_index(index) - # indices = self._build_requests(index) - # if self._last_future: - # self._last_future.cancel() - - # try: - # self._last_future = f = self._data_wrapper.isel_async(indices) - # except Exception as e: - # raise type(e)(f"Failed to index data with {index}: {e}") from e - - # self._progress_spinner.show() - # f.add_done_callback(self._on_data_slice_ready) - - def closeEvent(self, a0: QCloseEvent | None) -> None: - self._chunker.shutdown() - super().closeEvent(a0) - @ensure_main_thread # type: ignore - def _on_data_slice_ready(self, response: ChunkResponse) -> None: - """Update the displayed image for the given index. - - Connected to the future returned by _isel. 
- """ - self._update_canvas_data(response) - return - # NOTE: removing the reference to the last future here is important - # because the future has a reference to this widget in its _done_callbacks - # which will prevent the widget from being garbage collected if the future - self._last_future = None - self._progress_spinner.hide() - if future.cancelled(): - return - - for idx, datum in future.result(): - self._update_canvas_data(datum, idx) - self._canvas.refresh() - - def _update_canvas_data(self, response: ChunkResponse) -> None: - # def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: + def _draw_chunk(self, response: ChunkResponse) -> None: """Actually update the image handle(s) with the (sliced) data. By this point, data should be sliced from the underlying datastore. Any @@ -473,6 +443,7 @@ def _update_canvas_data(self, response: ChunkResponse) -> None: (currently just 2D) will be reduced using max intensity projection (currently). """ if response is RequestFinished: # fix typing + print(">>>>>>>>>>>>> RequestFinished") self._progress_spinner.hide() return @@ -482,17 +453,16 @@ def _update_canvas_data(self, response: ChunkResponse) -> None: channel_key = response.channel_index offset = response.offset datum = response.data - if ( channel_handles := self._img_handles[channel_key] ) and offset in channel_handles: handle = channel_handles[offset] handle.data = datum else: - cm = DEFAULT_COLORMAPS[channel_key] + cm = DEFAULT_COLORMAPS[channel_key] # TODO if datum.ndim == 2: handle = self._canvas.add_image(datum, cmap=cm, offset=offset) - handle.clim = (0, 45000) + handle.clim = (0, 100) channel_handles[offset] = handle else: raise NotImplementedError("Volume rendering not yet supported") @@ -542,4 +512,8 @@ def _clear_images(self) -> None: def _is_idle(self) -> bool: """Return True if no futures are running. 
Used for testing, and debugging.""" - return self._last_future is None + return bool(self._chunker.pending_futures) + + def closeEvent(self, a0: QCloseEvent | None) -> None: + self._chunker.shutdown() + super().closeEvent(a0) From 792db652e7f68ffc74556af34f971ddb20157939 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 10 Jun 2024 16:24:39 -0400 Subject: [PATCH 05/12] realized snag --- src/ndv/_chunking.py | 25 ++-- src/ndv/viewer/_backends/_protocols.py | 10 +- src/ndv/viewer/_backends/_vispy.py | 10 +- src/ndv/viewer/_data_wrapper.py | 104 ++++++++-------- src/ndv/viewer/_lut_control.py | 23 ++-- src/ndv/viewer/_viewer.py | 160 ++++++++++++++----------- 6 files changed, 195 insertions(+), 137 deletions(-) diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index 832c7f8..59fc5b0 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -8,6 +8,7 @@ Any, Deque, Hashable, + Literal, Mapping, NamedTuple, Sequence, @@ -68,7 +69,7 @@ def shutdown(self, cancel_futures: bool = True, wait: bool = True) -> None: self.executor.shutdown(cancel_futures=cancel_futures, wait=wait) def _request_chunk_sync( - self, idx: tuple[int | slice, ...], channel_axis: int | None + self, idx: tuple[int | slice, ...], channel_axis: int | None, ndims: int ) -> ChunkResponse: # idx is guaranteed to have length equal to the number of dimensions if channel_axis is not None: @@ -79,28 +80,33 @@ def _request_chunk_sync( channel_index = -1 data = self.data_wrapper[idx] # type: ignore [index] - data = _reduce_data_for_display(data, 2) + data = _reduce_data_for_display(data, ndims) # FIXME: temporary # this needs to be aware of nvisible dimensions try: - offset = tuple(int(getattr(sl, "start", sl)) for sl in idx)[-2:] + offset = tuple(int(getattr(sl, "start", sl)) for sl in idx)[-3:] except TypeError: - offset = (0, 0) + offset = (0, 0, 0) return ChunkResponse( idx=idx, data=data, offset=offset, channel_index=channel_index ) - def request_index(self, index: Indices, cancel_existing: bool = True) -> None: + def request_index( + self, index: Indices, *, cancel_existing: bool = True, ndims: Literal[2, 3] = 2 + ) -> None: if cancel_existing: for future in list(self.pending_futures): future.cancel() + # TODO: see if we can get the channel_axis logic back to the viewer/request side + if self.data_wrapper is None: return idx = self.data_wrapper.to_conventional(index) - - if (chunks := self.chunks) is None: + chunks = self.chunks + multi_channel = isinstance(index.get(self.channel_axis), slice) + if chunks is None and not multi_channel: subchunks: Iterable[tuple[int | slice, ...]] = [idx] else: shape = self.data_wrapper.data.shape @@ -111,6 +117,9 @@ def request_index(self, index: Indices, cancel_existing: bool = True) -> None: # we never chunk the channel axis if isinstance(chunks, int): _chunks = [chunks] * len(shape) + elif chunks is None: + # hack FIXME + _chunks = [100000] * len(shape) else: _chunks = list(chunks) if self.channel_axis is not None: @@ -128,7 +137,7 @@ def request_index(self, index: Indices, cancel_existing: bool = True) -> None: self._notification_sent = False for chunk_idx in subchunks: future = self.executor.submit( - self._request_chunk_sync, chunk_idx, self.channel_axis + self._request_chunk_sync, chunk_idx, self.channel_axis, ndims ) self.pending_futures.append(future) future.add_done_callback(self._on_chunk_ready) diff --git a/src/ndv/viewer/_backends/_protocols.py b/src/ndv/viewer/_backends/_protocols.py index 413038d..f45a655 100644 --- a/src/ndv/viewer/_backends/_protocols.py +++ 
b/src/ndv/viewer/_backends/_protocols.py @@ -41,8 +41,14 @@ def set_range( def refresh(self) -> None: ... def qwidget(self) -> QWidget: ... def add_image( - self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + self, + data: np.ndarray | None = ..., + cmap: cmap.Colormap | None = ..., + offset: tuple[float, float] | None = None, # (Y, X) ) -> PImageHandle: ... def add_volume( - self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + self, + data: np.ndarray | None = ..., + cmap: cmap.Colormap | None = ..., + offset: tuple[float, float, float] | None = ..., # (Z, Y, X) ) -> PImageHandle: ... diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 329ff42..f8b04ed 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -135,7 +135,7 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None, - offset: tuple[int, ...] = (), + offset: tuple[float, float] | None = None, # (Y, X) ) -> VispyImageHandle: """Add a new Image node to the scene.""" img = scene.visuals.Image(data, parent=self._view.scene) @@ -155,13 +155,19 @@ def add_image( return handle def add_volume( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + self, + data: np.ndarray | None = None, + cmap: cmap.Colormap | None = None, + offset: tuple[float, float, float] | None = None, # (Z, Y, X) ) -> VispyImageHandle: vol = scene.visuals.Volume( data, parent=self._view.scene, interpolation="nearest" ) vol.set_gl_state("additive", depth_test=False) vol.interactive = True + if offset: + vol.transform = scene.STTransform(translate=offset[::-1]) + if data is not None: self._current_shape, prev_shape = data.shape, self._current_shape if len(prev_shape) != 3: diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py index 967b276..366e34f 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/viewer/_data_wrapper.py @@ -4,8 +4,8 @@ import logging import sys +import warnings from abc import abstractmethod -from concurrent.futures import Future, ThreadPoolExecutor from contextlib import suppress from typing import ( TYPE_CHECKING, @@ -13,7 +13,6 @@ Container, Generic, Hashable, - Iterable, Iterator, Mapping, Sequence, @@ -58,7 +57,7 @@ def __gt__(self, other: _T_contra, /) -> bool: ... _T = TypeVar("_T", bound=type) # Global executor for slice requests -_EXECUTOR = ThreadPoolExecutor(max_workers=2) +# _EXECUTOR = ThreadPoolExecutor(max_workers=2) def _recurse_subclasses(cls: _T) -> Iterator[_T]: @@ -102,13 +101,6 @@ def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: return subclass(data) raise NotImplementedError(f"Don't know how to wrap type {type(data)}") - def __init__(self, data: ArrayT) -> None: - self._data = data - - @property - def data(self) -> ArrayT: - return self._data - @classmethod @abstractmethod def supports(cls, obj: Any) -> bool: @@ -119,53 +111,74 @@ def supports(cls, obj: Any) -> bool: """ raise NotImplementedError - @abstractmethod - def isel(self, indexers: Indices) -> np.ndarray: - """Select a slice from a data store using (possibly) named indices. + def __init__(self, data: ArrayT) -> None: + self._data = data + self._name2index: dict[str, int] = {} + if names := self.dimension_names(): + self._name2index = {name: i for i, name in enumerate(names)} - This follows the xarray-style indexing, where indexers is a mapping of - dimension names to indices or slices. Subclasses should implement this - method to return a numpy array. 
- """ - raise NotImplementedError + @property + def data(self) -> ArrayT: + return self._data + + # @abstractmethod + # def isel(self, indexers: Indices) -> np.ndarray: + # """Select a slice from a data store using (possibly) named indices. + + # This follows the xarray-style indexing, where indexers is a mapping of + # dimension names to indices or slices. Subclasses should implement this + # method to return a numpy array. + # """ + # raise NotImplementedError + + def shape(self) -> tuple[int, ...]: + return self._data.shape # type: ignore def __getitem__(self, index: tuple[int | slice, ...]) -> np.ndarray: - return self._data[index] + # reimplement in subclasses + return np.asarray(self._data[index]) # type: ignore [index] def chunks(self) -> tuple[int, ...] | int | None: - if hasattr(self._data, "chunks"): - return self._data.chunks + if chunks := getattr(self._data, "chunks", None): + if isinstance(chunks, Sequence) and all(isinstance(x, int) for x in chunks): + return tuple(chunks) + warnings.warn( + f"Unexpected chunks attribute: {chunks!r}. Ignoring.", stacklevel=2 + ) + return None + + def dimension_names(self) -> tuple[str, ...] | None: + """Return the names of the dimensions of the data.""" return None def to_conventional(self, indexers: Indices) -> tuple[int | slice, ...]: """Convert named indices to a tuple of integers and slices.""" - if hasattr(self, "_name2index"): - indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} - - return tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape))) + _indexers = {self._name2index.get(str(k), k): v for k, v in indexers.items()} + return tuple(_indexers.get(k, slice(None)) for k in range(len(self.shape()))) - def isel_async( - self, indexers: list[Indices] - ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: - """Asynchronous version of isel.""" - return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) + # def isel_async( + # self, indexers: list[Indices] + # ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + # """Asynchronous version of isel.""" + # return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) def guess_channel_axis(self) -> int | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, # see if any of the dimensions are named "channel" - for i, (dimkey, val) in enumerate(self.sizes().items()): - if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: - if val <= self.MAX_CHANNELS: - return i + shape = self.shape() + if names := self.dimension_names(): + for ax, name in enumerate(names): + if ( + name.lower() in self.COMMON_CHANNEL_NAMES + and shape[ax] <= self.MAX_CHANNELS + ): + return ax # for shaped arrays, use the smallest dimension as the channel axis - shape = getattr(self._data, "shape", None) - if isinstance(shape, Sequence): - with suppress(ValueError): - smallest_dim = min(shape) - if smallest_dim <= self.MAX_CHANNELS: - return shape.index(smallest_dim) + with suppress(ValueError): + if (smallest_dim := min(shape)) <= self.MAX_CHANNELS: + return shape.index(smallest_dim) return None def save_as_zarr(self, save_loc: str | Path) -> None: @@ -179,13 +192,10 @@ def sizes(self) -> Sizes: (`dims` is used by xarray, `names` is used by torch, etc...). If no labels are found, the dimensions are just named by their integer index. 
""" - shape = getattr(self._data, "shape", None) - if not isinstance(shape, Sequence) or not all( - isinstance(x, int) for x in shape - ): - raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") - dims = range(len(shape)) - return {dim: int(size) for dim, size in zip(dims, shape)} + shape = self.shape() + if (names := self.dimension_names()) and len(names) == len(shape): + return dict(zip(names, shape)) + return dict(enumerate(shape)) def summary_info(self) -> str: """Return info label with information about the data.""" diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py index b20909e..b42005a 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/viewer/_lut_control.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Mapping, cast import numpy as np from qtpy.QtCore import Qt @@ -35,13 +35,14 @@ def showPopup(self) -> None: class LutControl(QWidget): def __init__( self, + channel: Mapping[Any, PImageHandle], name: str = "", - handles: Iterable[PImageHandle] = (), parent: QWidget | None = None, cmaplist: Iterable[Any] = (), + cmap: cmap.Colormap | None = None, ) -> None: super().__init__(parent) - self._handles = handles + self._channel = channel self._name = name self._visible = QCheckBox(name) @@ -50,10 +51,12 @@ def __init__( self._cmap = CmapCombo() self._cmap.currentColormapChanged.connect(self._on_cmap_changed) - for handle in handles: + for handle in channel.values(): self._cmap.addColormap(handle.cmap) for color in cmaplist: self._cmap.addColormap(color) + if cmap is not None: + self._cmap.setCurrentColormap(cmap) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) self._clims.setStyleSheet(SS) @@ -84,36 +87,36 @@ def autoscaleChecked(self) -> bool: def _on_clims_changed(self, clims: tuple[float, float]) -> None: self._auto_clim.setChecked(False) - for handle in self._handles: + for handle in self._channel.values(): handle.clim = clims def _on_visible_changed(self, visible: bool) -> None: - for handle in self._handles: + for handle in self._channel.values(): handle.visible = visible if visible: self.update_autoscale() def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: - for handle in self._handles: + for handle in self._channel.values(): handle.cmap = cmap def update_autoscale(self) -> None: if ( not self._auto_clim.isChecked() or not self._visible.isChecked() - or not self._handles + or not self._channel ): return # find the min and max values for the current channel clims = [np.inf, -np.inf] - for handle in self._handles: + for handle in self._channel.values(): clims[0] = min(clims[0], np.nanmin(handle.data)) clims[1] = max(clims[1], np.nanmax(handle.data)) mi, ma = tuple(int(x) for x in clims) if mi != ma: - for handle in self._handles: + for handle in self._channel.values(): handle.clim = (mi, ma) # set the slider values to the new clims diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 3355e2b..3bce41f 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -1,8 +1,7 @@ from __future__ import annotations -from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Literal, Sequence, cast +from typing import TYPE_CHECKING, Iterator, Literal, Mapping, Sequence, cast import cmap from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget @@ -18,18 +17,19 @@ ) from ._backends import get_canvas +from ._backends._protocols import 
PImageHandle from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders +from ._lut_control import LutControl if TYPE_CHECKING: - from concurrent.futures import Future from typing import Any, Hashable, Iterable, TypeAlias + import numpy as np from qtpy.QtGui import QCloseEvent - from ._backends._protocols import PCanvas, PImageHandle + from ._backends._protocols import PCanvas from ._dims_slider import DimKey, Indices, Sizes - from ._lut_control import LutControl ImgKey: TypeAlias = Hashable # any mapping of dimensions to sizes @@ -47,7 +47,31 @@ cmap.Colormap("cubehelix"), cmap.Colormap("gray"), ] -ALL_CHANNELS = slice(None) +MONO_CHANNEL = -999999 + + +class Channel(Mapping[tuple, PImageHandle]): + def __init__( + self, ch_key: int, canvas: PCanvas, cmap: cmap.Colormap = GRAYS + ) -> None: + self.ch_key = ch_key + self._handles: dict[Any, PImageHandle] = {} + self.cmap = cmap + + def __getitem__(self, key: tuple) -> PImageHandle: + return self._handles[key] + + def __setitem__(self, key: tuple, value: PImageHandle) -> None: + self._handles[key] = value + + def __iter__(self) -> Iterator[tuple]: + yield from self._handles + + def __len__(self) -> int: + return len(self._handles) + + def __contains__(self, key: object) -> bool: + return key in self._handles class NDViewer(QWidget): @@ -111,7 +135,7 @@ def __init__( *, colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, parent: QWidget | None = None, - channel_axis: DimKey | None = None, + channel_axis: int | None = None, channel_mode: ChannelMode | str = ChannelMode.MONO, ): super().__init__(parent=parent) @@ -119,16 +143,16 @@ def __init__( # ATTRIBUTES ---------------------------------------------------- # mapping of key to a list of objects that control image nodes in the canvas - self._img_handles: defaultdict[int, dict[tuple, PImageHandle]] = defaultdict( - dict - ) + self._channels: dict[int, Channel] = {} # mapping of same keys to the LutControl objects control image display props - self._lut_ctrls: dict[ImgKey, LutControl] = {} - # the set of dimensions we are currently visualizing (e.g. XY) + self._lut_ctrls: dict[int, LutControl] = {} + + # the set of dimensions we are currently visualizing (e.g. (-2, -1) for 2D) # this is used to control which dimensions have sliders and the behavior # of isel when selecting data from the datastore self._visualized_dims: set[DimKey] = set() + # the axis that represents the channels in the data self._channel_axis = channel_axis self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode @@ -138,8 +162,6 @@ def __init__( else: self._cmaps = DEFAULT_COLORMAPS self._cmap_cycle = cycle(self._cmaps) - # the last future that was created by _request_data_for_index - self._last_future: Future | None = None # number of dimensions to display self._ndims: Literal[2, 3] = 2 @@ -268,13 +290,14 @@ def set_data( the initial index will be set to the middle of the data. """ # store the data + self._clear_images() + self._data_wrapper = DataWrapper.create(data) self._chunker.data_wrapper = self._data_wrapper if chunks := self._data_wrapper.chunks(): # temp hack ... 
always group non-visible channels chunks = list(chunks) chunks[:-2] = (1000,) * len(chunks[:-2]) - print(chunks) self._chunker.chunks = tuple(chunks) # set channel axis @@ -298,6 +321,7 @@ def set_data( raise TypeError("initial_index must be a dict") idx = initial_index self.set_current_index(idx) + # update the data info label self._data_info_label.setText(self._data_wrapper.summary_info()) @@ -329,7 +353,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) # clear image handles and redraw - if self._img_handles: + if self._channels: self._clear_images() self._request_data_for_index(self._dims_sliders.value()) @@ -363,7 +387,7 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: self._channel_axis, mode != ChannelMode.COMPOSITE ) - if self._img_handles: + if self._channels: self._clear_images() self._request_data_for_index(self._dims_sliders.value()) @@ -430,79 +454,79 @@ def _request_data_for_index(self, index: Indices) -> None: makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. """ - print(f"\n--------\nrequesting index {index}") + print(f"\n--------\nrequesting index {index}", self._channel_axis) + if ( + self._channel_mode == ChannelMode.COMPOSITE + and self._channel_axis is not None + ): + index = {**index, self._channel_axis: slice(None)} self._progress_spinner.show() - self._chunker.request_index(index) + # TODO: don't request channels not being displayed + # TODO: don't request if the data is already in the cache + self._chunker.request_index(index, ndims=self._ndims) @ensure_main_thread # type: ignore - def _draw_chunk(self, response: ChunkResponse) -> None: + def _draw_chunk(self, chunk: ChunkResponse) -> None: """Actually update the image handle(s) with the (sliced) data. By this point, data should be sliced from the underlying datastore. Any dimensions remaining that are more than the number of visualized dimensions (currently just 2D) will be reduced using max intensity projection (currently). """ - if response is RequestFinished: # fix typing - print(">>>>>>>>>>>>> RequestFinished") + if chunk is RequestFinished: # fix typing self._progress_spinner.hide() + for lut in self._lut_ctrls.values(): + lut.update_autoscale() return if self._channel_mode == ChannelMode.MONO: - channel_key = -1 + ch_key = MONO_CHANNEL else: - channel_key = response.channel_index - offset = response.offset - datum = response.data - if ( - channel_handles := self._img_handles[channel_key] - ) and offset in channel_handles: - handle = channel_handles[offset] - handle.data = datum + ch_key = chunk.channel_index + + # TODO: Channel object creation could be moved. + # having it here is the laziest... but means that the order of arrival + # of the chunks will determine the order of the channels in the LUTS + # (without additional logic to sort them by index, etc.) 
+ if (channel := self._channels.get(ch_key)) is None: + channel = self._create_channel(ch_key) + + data = chunk.data + if (offset := chunk.offset) in channel: + channel[offset].data = data else: - cm = DEFAULT_COLORMAPS[channel_key] # TODO - if datum.ndim == 2: - handle = self._canvas.add_image(datum, cmap=cm, offset=offset) - handle.clim = (0, 100) - channel_handles[offset] = handle - else: - raise NotImplementedError("Volume rendering not yet supported") + print(f"{data.ndim=}") + if data.ndim == 2: + _offset2 = (offset[-2], offset[-1]) if offset else None + handle = self._canvas.add_image(data, offset=_offset2) + elif data.ndim == 3: + _offset3 = (offset[-3], offset[-2], offset[-1]) if offset else None + handle = self._canvas.add_volume(data, offset=_offset3) + handle.cmap = channel.cmap + channel[offset] = handle self._canvas.refresh() - # if handles := self._img_handles[offset]: - # for handle in handles: - # handle.data = datum - # handle.clim = (0, 45000) - # # if ctrl := self._lut_ctrls.get(imkey, None): - # # ctrl.update_autoscale() - # else: - # cm = ( - # next(self._cmap_cycle) - # if self._channel_mode == ChannelMode.COMPOSITE - # else GRAYS - # ) - # if datum.ndim == 2: - # handle = self._canvas.add_image(datum, cmap=cm, offset=offset) - # handle.clim = (0, 45000) - # handles.append(handle) - # elif datum.ndim == 3: - # handles.append(self._canvas.add_volume(datum, cmap=cm)) - # if imkey not in self._lut_ctrls: - # # ch_index = index.get(self._channel_axis, 0) - # ch_index = 0 - # self._lut_ctrls[imkey] = c = LutControl( - # f"Ch {ch_index}", - # handles, - # self, - # cmaplist=self._cmaps + DEFAULT_COLORMAPS, - # ) - # self._lut_drop.addWidget(c) + def _create_channel(self, ch_key: int) -> Channel: + # improve this + cmap = GRAYS if ch_key == MONO_CHANNEL else next(self._cmap_cycle) + + self._channels[ch_key] = channel = Channel(ch_key, self._canvas, cmap=cmap) + self._lut_ctrls[ch_key] = lut = LutControl( + channel, + f"Ch {ch_key}", + self, + cmaplist=self._cmaps + DEFAULT_COLORMAPS, + cmap=cmap, + ) + self._lut_drop.addWidget(lut) + return channel def _clear_images(self) -> None: """Remove all images from the canvas.""" - for handles in self._img_handles.values(): + for handles in self._channels.values(): for handle in handles.values(): handle.remove() - self._img_handles.clear() + self._channels.clear() # clear the current LutControls as well for c in self._lut_ctrls.values(): From aeb0f4d4ba21dc1104e0a6253a23a7c6b000a892 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 20:24:51 +0000 Subject: [PATCH 06/12] style(pre-commit.ci): auto fixes [...] --- src/ndv/viewer/_viewer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 3bce41f..1a28f82 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -25,7 +25,6 @@ if TYPE_CHECKING: from typing import Any, Hashable, Iterable, TypeAlias - import numpy as np from qtpy.QtGui import QCloseEvent from ._backends._protocols import PCanvas From 351e7ad86e4f5c5c7771f02ab3f415d32bea64f6 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 10 Jun 2024 20:43:34 -0400 Subject: [PATCH 07/12] 3d chunks! 
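The gist of this patch: the volume texture is allocated once (initially empty) and
chunk-aligned sub-blocks are streamed into it at an offset as they arrive, instead of
re-uploading the full volume on every update. A rough CPU-side sketch of the
bookkeeping follows; `fill_chunks` and `canvas` are illustrative names only (not part
of ndv), and the real path writes each block into the vispy texture via
handle.set_data(chunk, offset) -- see xx.py below for the GPU version.

    import numpy as np

    from ndv._chunking import iter_chunk_aligned_slices

    def fill_chunks(data: np.ndarray, chunk_shape: tuple[int, ...]) -> np.ndarray:
        # numpy stand-in for the pre-allocated (initially empty) volume texture
        canvas = np.zeros_like(data)
        full_index = tuple(slice(None) for _ in data.shape)
        for sl in iter_chunk_aligned_slices(data.shape, chunk_shape, full_index):
            chunk = data[sl]                     # one chunk-aligned block
            offset = tuple(s.start for s in sl)  # origin handed to handle.set_data
            dest = tuple(slice(o, o + n) for o, n in zip(offset, chunk.shape))
            canvas[dest] = chunk                 # GPU path: handle.set_data(chunk, offset)
        return canvas

Because the slices are chunk-aligned, each (offset, chunk.shape) pair maps to a
contiguous block of the destination, which is what makes the incremental texture
upload cheap.
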
--- src/ndv/_chunking.py | 3 + src/ndv/viewer/_backends/_protocols.py | 3 +- src/ndv/viewer/_backends/_vispy.py | 10 ++- src/ndv/viewer/_lut_control.py | 40 +++++++++--- src/ndv/viewer/_viewer.py | 84 ++++++++++++++++---------- xx.py | 33 ++++++++++ 6 files changed, 130 insertions(+), 43 deletions(-) create mode 100644 xx.py diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index 59fc5b0..bbaa8e3 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -85,9 +85,12 @@ def _request_chunk_sync( # this needs to be aware of nvisible dimensions try: offset = tuple(int(getattr(sl, "start", sl)) for sl in idx)[-3:] + offset = (idx[0].start, idx[2].start, idx[3].start) except TypeError: offset = (0, 0, 0) + import time + time.sleep(0.05) return ChunkResponse( idx=idx, data=data, offset=offset, channel_index=channel_index ) diff --git a/src/ndv/viewer/_backends/_protocols.py b/src/ndv/viewer/_backends/_protocols.py index f45a655..de65877 100644 --- a/src/ndv/viewer/_backends/_protocols.py +++ b/src/ndv/viewer/_backends/_protocols.py @@ -13,6 +13,7 @@ class PImageHandle(Protocol): def data(self) -> np.ndarray: ... @data.setter def data(self, data: np.ndarray) -> None: ... + def set_data(self, data: np.ndarray, offset: tuple) -> None: ... @property def visible(self) -> bool: ... @visible.setter @@ -44,11 +45,9 @@ def add_image( self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ..., - offset: tuple[float, float] | None = None, # (Y, X) ) -> PImageHandle: ... def add_volume( self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ..., - offset: tuple[float, float, float] | None = ..., # (Z, Y, X) ) -> PImageHandle: ... diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index f8b04ed..897f42f 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -44,6 +44,10 @@ def data(self, data: np.ndarray) -> None: return self._visual.set_data(data) + def set_data(self, data: np.ndarray, offset: tuple) -> None: + print("Setting data", data.shape, offset) + self._visual._texture._set_data(data, offset=offset) + @property def visible(self) -> bool: return bool(self._visual.visible) @@ -161,7 +165,11 @@ def add_volume( offset: tuple[float, float, float] | None = None, # (Z, Y, X) ) -> VispyImageHandle: vol = scene.visuals.Volume( - data, parent=self._view.scene, interpolation="nearest" + data, + parent=self._view.scene, + interpolation="nearest", + texture_format="auto", + clim=(0, 40000), ) vol.set_gl_state("additive", depth_test=False) vol.interactive = True diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py index b42005a..416a449 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/viewer/_lut_control.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Mapping, cast +from typing import TYPE_CHECKING, Any, Sequence, cast import numpy as np from qtpy.QtCore import Qt @@ -35,7 +35,7 @@ def showPopup(self) -> None: class LutControl(QWidget): def __init__( self, - channel: Mapping[Any, PImageHandle], + channel: Sequence[PImageHandle], name: str = "", parent: QWidget | None = None, cmaplist: Iterable[Any] = (), @@ -51,7 +51,7 @@ def __init__( self._cmap = CmapCombo() self._cmap.currentColormapChanged.connect(self._on_cmap_changed) - for handle in channel.values(): + for handle in channel: self._cmap.addColormap(handle.cmap) for color in cmaplist: self._cmap.addColormap(color) @@ -87,17 +87,17 @@ def autoscaleChecked(self) -> bool: def 
_on_clims_changed(self, clims: tuple[float, float]) -> None: self._auto_clim.setChecked(False) - for handle in self._channel.values(): + for handle in self._channel: handle.clim = clims def _on_visible_changed(self, visible: bool) -> None: - for handle in self._channel.values(): + for handle in self._channel: handle.visible = visible if visible: self.update_autoscale() def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: - for handle in self._channel.values(): + for handle in self._channel: handle.cmap = cmap def update_autoscale(self) -> None: @@ -110,13 +110,13 @@ def update_autoscale(self) -> None: # find the min and max values for the current channel clims = [np.inf, -np.inf] - for handle in self._channel.values(): + for handle in self._channel: clims[0] = min(clims[0], np.nanmin(handle.data)) clims[1] = max(clims[1], np.nanmax(handle.data)) mi, ma = tuple(int(x) for x in clims) if mi != ma: - for handle in self._channel.values(): + for handle in self._channel: handle.clim = (mi, ma) # set the slider values to the new clims @@ -124,3 +124,27 @@ def update_autoscale(self) -> None: self._clims.setMinimum(min(mi, self._clims.minimum())) self._clims.setMaximum(max(ma, self._clims.maximum())) self._clims.setValue((mi, ma)) + + +def _get_default_clim_from_data(data: np.ndarray) -> tuple[float, float]: + """Compute a reasonable clim from the min and max, taking nans into account. + + If there are no non-finite values (nan, inf, -inf) this is as fast as it can be. + Otherwise, this functions is about 3x slower. + """ + # Fast + min_value = data.min() + max_value = data.max() + + # Need more work? The nan-functions are slower + min_finite = np.isfinite(min_value) + max_finite = np.isfinite(max_value) + if not (min_finite and max_finite): + finite_data = data[np.isfinite(data)] + if finite_data.size: + min_value = finite_data.min() + max_value = finite_data.max() + else: + min_value = max_value = 0 # no finite values in the data + + return min_value, max_value diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 1a28f82..84b34f1 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -1,9 +1,19 @@ from __future__ import annotations from itertools import cycle -from typing import TYPE_CHECKING, Iterator, Literal, Mapping, Sequence, cast +from typing import ( + TYPE_CHECKING, + Hashable, + Literal, + MutableSequence, + Sequence, + SupportsIndex, + cast, + overload, +) import cmap +import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread from superqt.utils import qthrottled, signals_blocked @@ -23,7 +33,7 @@ from ._lut_control import LutControl if TYPE_CHECKING: - from typing import Any, Hashable, Iterable, TypeAlias + from typing import Any, Iterable, TypeAlias from qtpy.QtGui import QCloseEvent @@ -49,28 +59,40 @@ MONO_CHANNEL = -999999 -class Channel(Mapping[tuple, PImageHandle]): - def __init__( - self, ch_key: int, canvas: PCanvas, cmap: cmap.Colormap = GRAYS - ) -> None: +class Channel(MutableSequence[PImageHandle]): + def __init__(self, ch_key: int, cmap: cmap.Colormap = GRAYS) -> None: self.ch_key = ch_key - self._handles: dict[Any, PImageHandle] = {} + self._handles: list[PImageHandle] = [] self.cmap = cmap - def __getitem__(self, key: tuple) -> PImageHandle: - return self._handles[key] - - def __setitem__(self, key: tuple, value: PImageHandle) -> None: - self._handles[key] = value + @overload + def __getitem__(self, i: int) 
-> PImageHandle: ... + @overload + def __getitem__(self, i: slice) -> list[PImageHandle]: ... + def __getitem__(self, i: int | slice) -> PImageHandle | list[PImageHandle]: + return self._handles[i] + + @overload + def __setitem__(self, i: SupportsIndex, value: PImageHandle) -> None: ... + @overload + def __setitem__(self, i: slice, value: Iterable[PImageHandle]) -> None: ... + def __setitem__( + self, i: SupportsIndex | slice, value: PImageHandle | Iterable[PImageHandle] + ) -> None: + self._handles[i] = value # type: ignore - def __iter__(self) -> Iterator[tuple]: - yield from self._handles + @overload + def __delitem__(self, i: int) -> None: ... + @overload + def __delitem__(self, i: slice) -> None: ... + def __delitem__(self, i: int | slice) -> None: + del self._handles[i] def __len__(self) -> int: return len(self._handles) - def __contains__(self, key: object) -> bool: - return key in self._handles + def insert(self, i: int, value: PImageHandle) -> None: + self._handles.insert(i, value) class NDViewer(QWidget): @@ -169,7 +191,7 @@ def __init__( # IMPORTANT # chunking here will determine how non-visualized dims are reduced # so chunkshape will need to change based on the set of visualized dims - chunks=32, + chunks=(20, 100, 32, 32), on_ready=self._draw_chunk, ) @@ -483,33 +505,31 @@ def _draw_chunk(self, chunk: ChunkResponse) -> None: else: ch_key = chunk.channel_index + data = chunk.data + if data.ndim == 2: + return # TODO: Channel object creation could be moved. # having it here is the laziest... but means that the order of arrival # of the chunks will determine the order of the channels in the LUTS # (without additional logic to sort them by index, etc.) - if (channel := self._channels.get(ch_key)) is None: - channel = self._create_channel(ch_key) + if (handles := self._channels.get(ch_key)) is None: + handles = self._create_channel(ch_key) - data = chunk.data - if (offset := chunk.offset) in channel: - channel[offset].data = data - else: - print(f"{data.ndim=}") + if not handles: if data.ndim == 2: - _offset2 = (offset[-2], offset[-1]) if offset else None - handle = self._canvas.add_image(data, offset=_offset2) + handles.append(self._canvas.add_image(data, cmap=handles.cmap)) elif data.ndim == 3: - _offset3 = (offset[-3], offset[-2], offset[-1]) if offset else None - handle = self._canvas.add_volume(data, offset=_offset3) - handle.cmap = channel.cmap - channel[offset] = handle + empty = np.empty((60, 256, 256), dtype=np.uint16) + handles.append(self._canvas.add_volume(empty, cmap=handles.cmap)) + + handles[0].set_data(data, chunk.offset) self._canvas.refresh() def _create_channel(self, ch_key: int) -> Channel: # improve this cmap = GRAYS if ch_key == MONO_CHANNEL else next(self._cmap_cycle) - self._channels[ch_key] = channel = Channel(ch_key, self._canvas, cmap=cmap) + self._channels[ch_key] = channel = Channel(ch_key, cmap=cmap) self._lut_ctrls[ch_key] = lut = LutControl( channel, f"Ch {ch_key}", @@ -523,7 +543,7 @@ def _create_channel(self, ch_key: int) -> Channel: def _clear_images(self) -> None: """Remove all images from the canvas.""" for handles in self._channels.values(): - for handle in handles.values(): + for handle in handles: handle.remove() self._channels.clear() diff --git a/xx.py b/xx.py new file mode 100644 index 0000000..5baef72 --- /dev/null +++ b/xx.py @@ -0,0 +1,33 @@ +import numpy as np +from rich import print +from vispy import app, io, scene + +from ndv._chunking import iter_chunk_aligned_slices + +vol1 = 
np.load(io.load_data_file("volume/stent.npz"))["arr_0"].astype(np.uint16) + +canvas = scene.SceneCanvas(keys="interactive", size=(800, 600), show=True) +view = canvas.central_widget.add_view() +print("--------- create vol") +volume1 = scene.Volume( + np.empty_like(vol1), parent=view.scene, texture_format="auto", clim=(0, 1200) +) +print("--------- create cam") +view.camera = scene.cameras.ArcballCamera(parent=view.scene, name="Arcball") + +# Generate new data to update a subset of the volume + + +slices = iter_chunk_aligned_slices( + vol1.shape, chunks=(32, 32, 32), slices=(slice(None), slice(None), slice(None)) +) + +for slice in list(slices)[::1]: + offset = (x.start for x in slice) + chunk = vol1[slice] + # Update the texture with the new data at the calculated offset + print("--------- update vol") + volume1._texture._set_data(chunk, offset=tuple(offset)) +canvas.update() +print("--------- run app") +app.run() From cbc0107e4b5e135ee3f8a52c24c77c20ab369d64 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 11 Jun 2024 08:15:33 -0400 Subject: [PATCH 08/12] notes on model --- src/ndv/_chunking.py | 4 +- src/ndv/viewer/_state.py | 123 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 src/ndv/viewer/_state.py diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index bbaa8e3..e16944d 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -88,9 +88,9 @@ def _request_chunk_sync( offset = (idx[0].start, idx[2].start, idx[3].start) except TypeError: offset = (0, 0, 0) - import time + # import time - time.sleep(0.05) + # time.sleep(0.05) return ChunkResponse( idx=idx, data=data, offset=offset, channel_index=channel_index ) diff --git a/src/ndv/viewer/_state.py b/src/ndv/viewer/_state.py new file mode 100644 index 0000000..5a80162 --- /dev/null +++ b/src/ndv/viewer/_state.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Mapping, Protocol, Sequence + +import numpy as np + +if TYPE_CHECKING: + import cmap + +# either the name or the index of a dimension +DimKey = str | int +# position or slice along a specific dimension +Index = int | slice +# name or dimension index of a channel +# string is only supported for arrays with string-type coordinates along the channel dim +# None is a special value that means all channels +ChannelKey = int | str | None + +SLOTS = {"slots": True} if sys.version_info >= (3, 10) else {} + + +@dataclass(**SLOTS) +class ChannelDisplay: + visible: bool = True + cmap: cmap._colormap.ColorStopsLike = "gray" + clims: tuple[float, float] | None = None + gamma: float = 1.0 + # whether to autoscale + # if a tuple the first element is the lower quantile + # and the second element is the upper quantile + # if True or (0, 1) use (np.min(), np.max()) ... otherwise use np.quantile + autoscale: bool | tuple[float, float] = (0, 1) + + +@dataclass(**SLOTS) +class ViewerState: + # index of the currently displayed slice + # for example (-2, -1) for the standard 2D viewer + # if string, then name2index is used to convert to index + visualized_indices: tuple[DimKey, DimKey] | tuple[DimKey, DimKey, DimKey] = (-2, -1) + + # the currently displayed position/slice along each dimension + # missing indices are assumed to be slice(None) (or 0?) 
+ # if more than len(visualized_indices) have non-integer values, then + # reducers are used to reduce the data along the remaining dimensions + current_index: Mapping[DimKey, Index] = field(default_factory=dict) + + # functions to reduce data along axes remaining after slicing + reducers: Reducer | Mapping[DimKey, Reducer] = np.max + + # note: it is an error for channel_index to be in visualized_indices + channel_index: DimKey | None = None + + # settings for each channel along the channel dimension + # None is a special value that means all channels + # if channel_index is None, then luts[None] is used + luts: Mapping[ChannelKey, ChannelDisplay] = field(default_factory=dict) + # default colormap to use for channel [0, 1, 2, ...] + colormap_options: Sequence[cmap._colormap.ColorStopsLike] = ("gray",) + + +class Reducer(Protocol): + def __call__( + self, data: np.ndarray, /, *, axis: int | tuple[int, ...] | None + ) -> np.ndarray | float: ... + + +class NDViewer: + def __init__(self, data: Any, state: ViewerState | None) -> None: + self._state = state or ViewerState() + if data is not None: + self.set_data(data) + + @property + def data(self) -> Any: + raise NotImplementedError + + def set_data(self, data: Any) -> None: ... + + @property + def state(self) -> ViewerState: + return self._state + + def set_state(self, state: ViewerState) -> None: + # validate... + self._state = state + + def set_visualized_indices(self, indices: tuple[DimKey, DimKey]) -> None: + """Set which indices are visualized.""" + if self._state.channel_index in indices: + raise ValueError( + f"channel index ({self._state.channel_index!r}) cannot be in visualized" + f"indices: {indices}" + ) + self._state.visualized_indices = indices + self.refresh() + + def set_channel_index(self, index: DimKey | None) -> None: + """Set the channel index.""" + if index in self._state.visualized_indices: + # consider alternatives to raising. + # e.g. 
if len(visualized_indices) == 3, then we could pop index + raise ValueError( + f"channel index ({index!r}) cannot be in visualized indices: " + f"{self._state.visualized_indices}" + ) + self._state.channel_index = index + self.refresh() + + def set_current_index(self, index: Mapping[DimKey, Index]) -> None: + """Set the currentl displayed index.""" + self._state.current_index = index + self.refresh() + + def refresh(self) -> None: + """Refresh the viewer.""" + index = self._state.current_index + self._chunker.request_index(index) + + @ensure_main_thread # type: ignore + def _draw_chunk(self, chunk: ChunkResponse) -> None: From dee420f031edbfbef2a6fdc843ade79115338026 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 11 Jun 2024 17:28:42 -0400 Subject: [PATCH 09/12] working on new models --- src/ndv/_chunk_executor.py | 425 ++++++++++++++++++ src/ndv/_chunking.py | 25 +- src/ndv/viewer/_backends/__init__.py | 2 +- src/ndv/viewer/_backends/_vispy.py | 7 +- .../_backends/{_protocols.py => protocols.py} | 0 src/ndv/viewer/_lut_control.py | 2 +- src/ndv/viewer/_state.py | 62 +-- src/ndv/viewer/_v2.py | 193 ++++++++ src/ndv/viewer/_viewer.py | 4 +- tests/test_chunker.py | 109 +++++ y.py | 4 +- z.py | 45 +- 12 files changed, 767 insertions(+), 111 deletions(-) create mode 100644 src/ndv/_chunk_executor.py rename src/ndv/viewer/_backends/{_protocols.py => protocols.py} (100%) create mode 100644 src/ndv/viewer/_v2.py create mode 100644 tests/test_chunker.py diff --git a/src/ndv/_chunk_executor.py b/src/ndv/_chunk_executor.py new file mode 100644 index 0000000..09500aa --- /dev/null +++ b/src/ndv/_chunk_executor.py @@ -0,0 +1,425 @@ +from __future__ import annotations + +import math +from collections import deque +from concurrent.futures import Executor, Future, ThreadPoolExecutor +from itertools import product +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Literal, + Mapping, + NamedTuple, + Protocol, + Sequence, + SupportsIndex, + cast, +) + +import numpy as np + +if TYPE_CHECKING: + from types import EllipsisType + from typing import TypeAlias + + import numpy.typing as npt + + class SupportsDunderLT(Protocol): + def __lt__(self, other: Any) -> bool: ... + + class SupportsDunderGT(Protocol): + def __gt__(self, other: Any) -> bool: ... + + SupportsComparison: TypeAlias = SupportsDunderLT | SupportsDunderGT + +NULL = object() + + +class SupportsChunking(Protocol): + @property + def shape(self) -> Sequence[int]: ... + def __getitem__(self, idx: tuple[int | slice, ...]) -> npt.ArrayLike: ... + + +class ChunkResponse(NamedTuple): + # location in the original array + location: tuple[int | slice, ...] 
+ # the data that was returned + data: np.ndarray + + @property + def offset(self) -> tuple[int, ...]: + return tuple(i.start if isinstance(i, slice) else i for i in self.location) + + +ChunkFuture = Future[ChunkResponse] + + +class Chunker: + def __init__(self, executor: Executor | None = None) -> None: + self._executor = executor or self._default_executor() + self._pending_futures: deque[ChunkFuture] = deque() + self._request_chunk = _get_chunk + + @classmethod + def _default_executor(cls) -> ThreadPoolExecutor: + return ThreadPoolExecutor(thread_name_prefix=cls.__name__) + + def is_idle(self) -> bool: + return all(f.done() for f in self._pending_futures) + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + self._executor.shutdown(wait=wait, cancel_futures=cancel_futures) + + def __enter__(self) -> Chunker: + return self + + def __exit__(self, *_: Any) -> Literal[False]: + self.shutdown(wait=True) + return False + + def request_chunks( + self, + data: SupportsChunking, + index: Mapping[int, int | slice] | tuple[int | slice, ...] | None = None, + # None implies no chunking + chunk_shape: int | None | Sequence[int | None] = None, + *, + sort_key: Callable[[tuple[int | slice, ...]], SupportsComparison] | None = NULL, # type: ignore + cancel_existing: bool = False, + ) -> list[ChunkFuture]: + """Request chunks from `data` based on the given `index` and `chunk_shape`. + + Parameters + ---------- + data : SupportsChunking + The data to request chunks from. Must have a `shape` attribute and support + indexing (`__getitem__`) with a tuple of int or slice. + index : Mapping[int, int | slice] | tuple[int | slice, ...] | None + A subarray to request. + If a Mapping, it should look like {dim: index} where `index` is a single + index or a slice and dim is the dimension to index. + If an tuple, it should be a regular tuple of integer or slice: + e.g. (6, 0, slice(None, None, None), slice(1, 10, None)) + If `None` (default), the full array is requested. + chunk_shape : int | tuple[int, ...] | None + The shape of each chunk. If a single int, the same size is used for all + dimensions. If `None`, no chunking is done. Note that chunk shape applies + to the data *prior* to indexing. Chunks will be aligned with the original + data, not the indexed data... so given an axis `0` with length 100, if you + request a slice from that index `index={0: slice(40,60)}` and provide a + `chunk_shape=50`, you will get two chunks: (40, 50) and (50, 60). + The intention is that chunk_shape should align with the chunk layout of the + original data, to optimize reading from disk or other sources, even when + reading a subset of the data that is not aligned with the chunks. + sort_key: Callable[[tuple[slice, ...]], SupportsComparison] | None + A function to sort the chunks before submitting them. This can be used to + prioritize chunks that are more likely to be needed first (such as those + within a certain distance of the current view). The function should take + a tuple of slices and return a value that can be compared with `<` and `>`. + If None, no sorting is done. + cancel_existing : bool + If True, cancel any existing pending futures before submitting new ones. + + Returns + ------- + list[Future[ChunkResponse]] + A list of futures that will contain the requested chunks when they are + available. Use `Future.add_done_callback` to register a callback to handle + the results. 
+ """ + if cancel_existing: + for future in list(self._pending_futures): + future.cancel() + + if index is None: + index = tuple(slice(None) for _ in range(len(data.shape))) + + if isinstance(index, Mapping): + index = indexers_to_conventional_slice(index) + + # at this point, index is a tuple of int or slice + # e.g. (6, 0, slice(None, None, None), slice(1, 10, None)) + # now, determine the subchunk indices to request + if chunk_shape is None: + # TODO: check whether we need to cast this to something without integers. + indices: Iterable[tuple[int | slice, ...]] = [index] + else: + indices = iter_chunk_aligned_slices(data.shape, chunk_shape, index) + if sort_key is not None: + if sort_key is NULL: + + def sort_key(x: tuple[int | slice, ...]) -> SupportsComparison: + return distance_from_coord(x, data.shape) + + indices = sorted(indices, key=sort_key) + + # submit the a request for each subchunk + futures = [] + for chunk_index in indices: + future = self._executor.submit(self._request_chunk, data, chunk_index) + self._pending_futures.append(future) + future.add_done_callback(self._pending_futures.remove) + futures.append(future) + return futures + + +def _get_chunk(data: SupportsChunking, index: tuple[int | slice, ...]) -> ChunkResponse: + chunk_data = _reduce_data_for_display(data[index], len(data.shape)) + # import time + + # time.sleep(0.05) + return ChunkResponse(location=index, data=chunk_data) + + +def indexers_to_conventional_slice( + indexers: Mapping[int, int | slice], ndim: int | None = None +) -> tuple[int | slice, ...]: + """Convert Mapping of {dim: index} to a conventional tuple of int or slice. + + `indexers` need not be ordered. If `ndim` is not provided, it is inferred + from the maximum key in `indexers`. + + Parameters + ---------- + indexers : Mapping[int, int | slice] + Mapping of {dim: index} where `index` is a single index or a slice. + ndim : int | None + Number of dimensions. If None, inferred from the maximum key in `indexers`. + + Examples + -------- + >>> indexers_to_conventional_slice({1: 0, 0: 6, 3: slice(1, 10, None)}) + (6, 0, slice(None, None, None), slice(1, 10, None)) + + """ + if not indexers: + return (slice(None),) + + if ndim is None: + ndim = max(indexers) + 1 + return tuple(indexers.get(k, slice(None)) for k in range(ndim)) + + +def _slice2range(sl: SupportsIndex | slice, dim_size: int) -> tuple[int, int]: + """Convert slice to range, handling single int as well. + + Examples + -------- + >>> _slice2range(3, 10) + (3, 4) + """ + if not isinstance(sl, slice): + idx = sl.__index__() + return (idx, idx + 1) + start = 0 if sl.start is None else max(sl.start, 0) + stop = dim_size if sl.stop is None else min(sl.stop, dim_size) + return (start, stop) + + +def iter_chunk_aligned_slices( + shape: Sequence[int], + chunks: int | Sequence[int | None], + slices: Sequence[int | slice | EllipsisType], +) -> Iterator[tuple[slice, ...]]: + """Yield chunk-aligned slices for a given shape and slices. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the array to slice. + chunks : int or tuple[int, ...] + The size of each chunk. If a single int, the same size is used for all + dimensions. + slices : tuple[int | slice | Ellipsis, ...] + The full slices to apply to the array. Ellipsis is supported to + represent multiple slices. + + Returns + ------- + Iterator[tuple[slice, ...]] + An iterator of chunk-aligned slices. + + Raises + ------ + ValueError + If the length of `chunks`, `shape`, and `slices` do not match, or any chunks + are zero. 
+ IndexError + If more than one Ellipsis is present in `slices`. + + Examples + -------- + >>> list( + ... iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=(slice(1, 4), ...)) + ... ) + [ + (slice(1, 4, None), slice(0, 4, None)), + (slice(1, 4, None), slice(4, 6, None)), + ] + + >>> x = iter_chunk_aligned_slices( + ... shape=(10, 9), chunks=(4, 3), slices=(slice(3, 9), slice(1, None)) + ... ) + >>> list(x) + [ + (slice(3, 4, None), slice(1, 3, None)), + (slice(3, 4, None), slice(3, 6, None)), + (slice(3, 4, None), slice(6, 9, None)), + (slice(4, 8, None), slice(1, 3, None)), + (slice(4, 8, None), slice(3, 6, None)), + (slice(4, 8, None), slice(6, 9, None)), + (slice(8, 9, None), slice(1, 3, None)), + (slice(8, 9, None), slice(3, 6, None)), + (slice(8, 9, None), slice(6, 9, None)), + ] + """ + # Make chunks same length as shape if single int + ndim = len(shape) + if isinstance(chunks, int): + chunks = (chunks,) * ndim + elif not len(chunks) == ndim: + raise ValueError("Length of `chunks` must match length of `shape`") + + if any(x == 0 for x in chunks): + raise ValueError("Chunk size must be greater than zero") + + # convert any `None` chunks to full size of the dimension + chunks = tuple(x if x is not None else shape[i] for i, x in enumerate(chunks)) + + if num_ellipsis := slices.count(Ellipsis): + if num_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Replace Ellipsis with multiple slices + el_idx = slices.index(Ellipsis) + n_remaining = ndim - len(slices) + 1 + slices = ( + tuple(slices[:el_idx]) + + (slice(None),) * n_remaining + + tuple(slices[el_idx + 1 :]) + ) + slices = cast(tuple[int | slice, ...], slices) # now we have no Ellipsis + if not ndim == len(slices): + # Fill in remaining dimensions with full slices + slices = slices + (slice(None),) * (ndim - len(slices)) + + # Create ranges for each dimension based on the slices provided + ranges = [_slice2range(sl, dim) for sl, dim in zip(slices, shape)] + + # Generate indices for each dimension that align with chunks + aligned_ranges = ( + range(start - (start % chunk_size), stop, chunk_size) + for (start, stop), chunk_size in zip(ranges, chunks) + ) + + # Create all combinations of these aligned ranges + for indices in product(*aligned_ranges): + chunk_slices = [] + for idx, (start, stop), ch in zip(indices, ranges, chunks): + # Calculate the actual slice for each dimension + start = max(start, idx) + stop = min(stop, idx + ch) + if start >= stop: # Skip empty slices + break + chunk_slices.append(slice(start, stop)) + else: + # Only add this combination of slices if all dimensions are valid + yield tuple(chunk_slices) + + +def _slice_center(s: slice | int, dim_size: int) -> float: + """Calculate the center of a slice based on its start and stop attributes.""" + if isinstance(s, int): + return s + start = float(s.start) if s.start is not None else 0 + stop = float(s.stop) if s.stop is not None else dim_size + return (start + stop) / 2 + + +def distance_from_coord( + slice_tuple: Sequence[slice | int], + shape: Sequence[int], + coord: Iterable[float] = (), # defaults to center of shape +) -> float: + """Euclidean distance from the center of an nd slice to the center of shape.""" + if not coord: + coord = (dim / 2 for dim in shape) + slice_centers = (_slice_center(s, dim) for s, dim in zip(slice_tuple, shape)) + return math.hypot(*(sc - cc for sc, cc in zip(slice_centers, coord))) + + +def _reduce_data_for_display( + data: npt.ArrayLike, ndims: int, reductor: Callable[..., np.ndarray] = 
np.max +) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + + This also coerces 64-bit data to 32-bit data. + """ + # TODO + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the first extra dims) + data = np.asarray(data).squeeze() + if extra_dims := data.ndim - ndims: + axis = tuple(range(extra_dims)) + data = reductor(data, axis=axis) + + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) + return data + + +# class DaskChunker: +# def __init__(self) -> None: +# try: +# import dask +# import dask.array as da +# from dask.distributed import Client +# except ImportError as e: +# raise ImportError("Dask is required for DaskChunker") from e +# self._dask = dask +# self._da = da +# self._client = Client() + +# def request_chunks( +# self, +# data: SupportsChunking, +# index: Mapping[int, int | slice] | IndexTuple | None = None, +# chunk_shape: int | tuple[int, ...] | None = None, # None implies no chunking +# *, +# cancel_existing: bool = False, +# ) -> list[Future[ChunkResponse]]: +# if isinstance(index, Mapping): +# index = indexers_to_conventional_slice(index) + +# if isinstance(data, self._da.Array): # type: ignore +# dask_data = data +# else: +# dask_data = self._da.from_array(data, chunks=chunk_shape) # type: ignore + +# subarray = dask_data[index] +# block_ranges = (range(x) for x in subarray.numblocks) +# for blk in product(*(block_ranges)): +# offset = tuple(sum(sizes[:x]) for sizes, x in zip(subarray.chunks, blk)) +# chunk = subarray.blocks[blk] +# future = self._client.compute(chunk) + +# @no_type_check +# def _set_result(_chunk=chunk, _future=future): +# _future.set_result( +# ChunkResponse(idx=index, data=_chunk.compute(), offset=offset) +# ) + +# futures.append() + +# return [data] diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index e16944d..4137f07 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -188,24 +188,24 @@ def _reduce_data_for_display( return data -def _slice2range(sl: slice | int, dim_size: int) -> range: +def _slice2range(sl: slice | int, dim_size: int) -> tuple[int, int]: """Convert slice to range, handling single int as well. Examples -------- >>> _slice2range(3, 10) - range(3, 4) + (3, 4) """ if isinstance(sl, int): - return range(sl, sl + 1) + return (sl, sl + 1) start = 0 if sl.start is None else max(sl.start, 0) stop = dim_size if sl.stop is None else min(sl.stop, dim_size) - return range(start, stop) + return (start, stop) def iter_chunk_aligned_slices( shape: Sequence[int], - chunks: Sequence[int], + chunks: int | Sequence[int], slices: Sequence[int | slice | EllipsisType], ) -> Iterator[tuple[slice, ...]]: """Yield chunk-aligned slices for a given shape and slices. @@ -236,7 +236,9 @@ def iter_chunk_aligned_slices( Examples -------- - >>> list(iter_chunk_aligned_slices(shape=(6, 6), chunks=4, (slice(1, 4), ...))) + >>> list( + ... iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=(slice(1, 4), ...)) + ... 
) [ (slice(1, 4, None), slice(0, 4, None)), (slice(1, 4, None), slice(4, 6, None)), @@ -260,6 +262,8 @@ def iter_chunk_aligned_slices( """ # Make chunks same length as shape if single int ndim = len(shape) + if isinstance(chunks, int): + chunks = (chunks,) * ndim if any(x == 0 for x in chunks): raise ValueError("Chunk size must be greater than zero") @@ -284,16 +288,17 @@ def iter_chunk_aligned_slices( # Generate indices for each dimension that align with chunks aligned_ranges = ( - range(r.start - (r.start % ch), r.stop, ch) for r, ch in zip(ranges, chunks) + range(start - (start % chunk_size), stop, chunk_size) + for (start, stop), chunk_size in zip(ranges, chunks) ) # Create all combinations of these aligned ranges for indices in product(*aligned_ranges): chunk_slices = [] - for idx, rng, ch in zip(indices, ranges, chunks): + for idx, (start, stop), ch in zip(indices, ranges, chunks): # Calculate the actual slice for each dimension - start = max(rng.start, idx) - stop = min(rng.stop, idx + ch) + start = max(start, idx) + stop = min(stop, idx + ch) if start >= stop: # Skip empty slices break chunk_slices.append(slice(start, stop)) diff --git a/src/ndv/viewer/_backends/__init__.py b/src/ndv/viewer/_backends/__init__.py index 310c2be..2ebf248 100644 --- a/src/ndv/viewer/_backends/__init__.py +++ b/src/ndv/viewer/_backends/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ndv.viewer._backends._protocols import PCanvas + from ndv.viewer._backends.protocols import PCanvas def get_canvas(backend: str | None = None) -> type[PCanvas]: diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 897f42f..d4adfd3 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -44,8 +44,11 @@ def data(self, data: np.ndarray) -> None: return self._visual.set_data(data) + def clear(self) -> None: + offset = (0,) * self.data.ndim + self.set_data(np.zeros(self.data.shape, dtype=self.data.dtype), offset) + def set_data(self, data: np.ndarray, offset: tuple) -> None: - print("Setting data", data.shape, offset) self._visual._texture._set_data(data, offset=offset) @property @@ -169,7 +172,6 @@ def add_volume( parent=self._view.scene, interpolation="nearest", texture_format="auto", - clim=(0, 40000), ) vol.set_gl_state("additive", depth_test=False) vol.interactive = True @@ -206,7 +208,6 @@ def set_range( is_3d = isinstance(self._camera, scene.ArcballCamera) if is_3d: self._camera._quaternion = DEFAULT_QUATERNION - print("Setting range", x, y, z, margin) self._view.camera.set_range(x=x, y=y, z=z, margin=margin) if is_3d: max_size = max(self._current_shape) diff --git a/src/ndv/viewer/_backends/_protocols.py b/src/ndv/viewer/_backends/protocols.py similarity index 100% rename from src/ndv/viewer/_backends/_protocols.py rename to src/ndv/viewer/_backends/protocols.py diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py index 416a449..1b9c1b3 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/viewer/_lut_control.py @@ -16,7 +16,7 @@ import cmap - from ._backends._protocols import PImageHandle + from ._backends.protocols import PImageHandle class CmapCombo(QColormapComboBox): diff --git a/src/ndv/viewer/_state.py b/src/ndv/viewer/_state.py index 5a80162..46d5d63 100644 --- a/src/ndv/viewer/_state.py +++ b/src/ndv/viewer/_state.py @@ -2,7 +2,7 @@ import sys from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Mapping, Protocol, Sequence +from typing import 
TYPE_CHECKING, Literal, Mapping, Protocol, Sequence import numpy as np @@ -60,64 +60,12 @@ class ViewerState: # default colormap to use for channel [0, 1, 2, ...] colormap_options: Sequence[cmap._colormap.ColorStopsLike] = ("gray",) + @property + def ndim(self) -> Literal[2, 3]: + return 2 if len(self.visualized_indices) == 2 else 3 + class Reducer(Protocol): def __call__( self, data: np.ndarray, /, *, axis: int | tuple[int, ...] | None ) -> np.ndarray | float: ... - - -class NDViewer: - def __init__(self, data: Any, state: ViewerState | None) -> None: - self._state = state or ViewerState() - if data is not None: - self.set_data(data) - - @property - def data(self) -> Any: - raise NotImplementedError - - def set_data(self, data: Any) -> None: ... - - @property - def state(self) -> ViewerState: - return self._state - - def set_state(self, state: ViewerState) -> None: - # validate... - self._state = state - - def set_visualized_indices(self, indices: tuple[DimKey, DimKey]) -> None: - """Set which indices are visualized.""" - if self._state.channel_index in indices: - raise ValueError( - f"channel index ({self._state.channel_index!r}) cannot be in visualized" - f"indices: {indices}" - ) - self._state.visualized_indices = indices - self.refresh() - - def set_channel_index(self, index: DimKey | None) -> None: - """Set the channel index.""" - if index in self._state.visualized_indices: - # consider alternatives to raising. - # e.g. if len(visualized_indices) == 3, then we could pop index - raise ValueError( - f"channel index ({index!r}) cannot be in visualized indices: " - f"{self._state.visualized_indices}" - ) - self._state.channel_index = index - self.refresh() - - def set_current_index(self, index: Mapping[DimKey, Index]) -> None: - """Set the currentl displayed index.""" - self._state.current_index = index - self.refresh() - - def refresh(self) -> None: - """Refresh the viewer.""" - index = self._state.current_index - self._chunker.request_index(index) - - @ensure_main_thread # type: ignore - def _draw_chunk(self, chunk: ChunkResponse) -> None: diff --git a/src/ndv/viewer/_v2.py b/src/ndv/viewer/_v2.py new file mode 100644 index 0000000..8ad43cc --- /dev/null +++ b/src/ndv/viewer/_v2.py @@ -0,0 +1,193 @@ +from typing import TYPE_CHECKING, Any, Mapping + +import numpy as np +from qtpy.QtWidgets import QVBoxLayout, QWidget +from superqt import ensure_main_thread + +from ndv._chunk_executor import Chunker, ChunkFuture +from ndv.viewer._backends import get_canvas +from ndv.viewer._state import ViewerState + +if TYPE_CHECKING: + from ndv.viewer._backends.protocols import PCanvas, PImageHandle + + +class NDViewer(QWidget): + def __init__(self, data: Any, *, parent: QWidget | None = None): + super().__init__(parent=parent) + self._state = ViewerState(visualized_indices=(0, 2, 3)) + self._chunker = Chunker() + self._channels: dict[int | None, PImageHandle] = {} + + self._canvas: PCanvas = get_canvas()(lambda x: None) + self._canvas.set_ndim(self._state.ndim) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addWidget(self._canvas.qwidget(), 1) + + if data is not None: + self.set_data(data) + + def __del__(self) -> None: + self._chunker.shutdown(cancel_futures=True, wait=False) + + def set_data(self, data: Any) -> None: + self._data = data + + def set_current_index(self, index: Mapping[int | str, int | slice]) -> None: + """Set the currentl displayed index.""" + self._state.current_index = index + self.refresh() + + def refresh(self) -> None: + 
self._request_data_for_index(self._state.current_index) + + def _norm_index(self, index: int | str) -> int: + """Remove string keys from index.""" + # TODO: this is a temporary solution + # the datawrapper __getitem__ should handle this + dim_names = () + if index in dim_names: + return dim_names.index(index) + elif isinstance(index, int): + return index + raise ValueError(f"Invalid index: {index}") + + def _request_data_for_index(self, index: Mapping[int | str, int | slice]) -> None: + ndim = len(self._data.shape) + + # determine chunk shape + # only visualized dimensions are chunked + chunk_size = 64 # TODO: pick bettter + chunk_shape: list[int | None] = [None] * ndim + visualized = [self._norm_index(dim) for dim in self._state.visualized_indices] + for dim in range(ndim): + if dim in visualized: + chunk_shape[dim] = chunk_size + + index = {self._norm_index(k): v for k, v in index.items()} + print("--------") + print("chunk shape", chunk_shape) + print("index", index) + + # clear existing handles + for handle in self._channels.values(): + handle.clear() + + for future in self._chunker.request_chunks( + data=self._data, + index=index, + chunk_shape=chunk_shape, + cancel_existing=True, + ): + future.add_done_callback(self._draw_chunk) + + @ensure_main_thread # type: ignore + def _draw_chunk(self, future: ChunkFuture) -> None: + if future.cancelled(): + return + if future.exception(): + print("ERROR: ", future.exception()) + return + + chunk = future.result() + data = chunk.data + offset = chunk.offset + + if self._state.channel_index is None: + channel_index = None + else: + channel_index = offset[self._norm_index(self._state.channel_index)] + + visualized = [self._norm_index(dim) for dim in self._state.visualized_indices] + offset = tuple(offset[i] for i in visualized) + + if data.ndim == 2: + return + + if not (handle := self._channels.get(channel_index)): + full_shape = self._data.shape + texture_shape = tuple( + full_shape[self._norm_index(i)] for i in self._state.visualized_indices + ) + empty = np.empty(texture_shape, dtype=chunk.data.dtype) + self._channels[channel_index] = handle = self._canvas.add_volume(empty) + + mi, ma = handle.clim + handle.clim = (min(mi, np.min(data)), max(ma, np.max(data))) + handle.set_data(data, offset) + self._canvas.refresh() + print("drawn chunk") + + # # of the chunks will determine the order of the channels in the LUTS + # # (without additional logic to sort them by index, etc.) + # if (handles := self._channels.get(ch_key)) is None: + # handles = self._create_channel(ch_key) + + # if not handles: + # if data.ndim == 2: + # handles.append(self._canvas.add_image(data, cmap=handles.cmap)) + # elif data.ndim == 3: + # empty = np.empty((60, 256, 256), dtype=np.uint16) + # handles.append(self._canvas.add_volume(empty, cmap=handles.cmap)) + + # handles[0].set_data(data, chunk.offset) + # self._canvas.refresh() + + +# class NDViewer: +# def __init__(self, data: Any, state: ViewerState | None) -> None: +# self._state = state or ViewerState() +# if data is not None: +# self.set_data(data) + +# @property +# def data(self) -> Any: +# raise NotImplementedError + +# def set_data(self, data: Any) -> None: ... + +# @property +# def state(self) -> ViewerState: +# return self._state + +# def set_state(self, state: ViewerState) -> None: +# # validate... 
+# self._state = state + +# def set_visualized_indices(self, indices: tuple[DimKey, DimKey]) -> None: +# """Set which indices are visualized.""" +# if self._state.channel_index in indices: +# raise ValueError( +# f"channel index ({self._state.channel_index!r}) cannot be in visualized" +# f"indices: {indices}" +# ) +# self._state.visualized_indices = indices +# self.refresh() + +# def set_channel_index(self, index: DimKey | None) -> None: +# """Set the channel index.""" +# if index in self._state.visualized_indices: +# # consider alternatives to raising. +# # e.g. if len(visualized_indices) == 3, then we could pop index +# raise ValueError( +# f"channel index ({index!r}) cannot be in visualized indices: " +# f"{self._state.visualized_indices}" +# ) +# self._state.channel_index = index +# self.refresh() + +# def set_current_index(self, index: Mapping[DimKey, Index]) -> None: +# """Set the currentl displayed index.""" +# self._state.current_index = index +# self.refresh() + +# def refresh(self) -> None: +# """Refresh the viewer.""" +# index = self._state.current_index +# self._chunker.request_index(index) + +# @ensure_main_thread # type: ignore +# def _draw_chunk(self, chunk: ChunkResponse) -> None: ... diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 84b34f1..29565cb 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -27,7 +27,7 @@ ) from ._backends import get_canvas -from ._backends._protocols import PImageHandle +from ._backends.protocols import PImageHandle from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders from ._lut_control import LutControl @@ -37,7 +37,7 @@ from qtpy.QtGui import QCloseEvent - from ._backends._protocols import PCanvas + from ._backends.protocols import PCanvas from ._dims_slider import DimKey, Indices, Sizes ImgKey: TypeAlias = Hashable diff --git a/tests/test_chunker.py b/tests/test_chunker.py new file mode 100644 index 0000000..546e967 --- /dev/null +++ b/tests/test_chunker.py @@ -0,0 +1,109 @@ +import numpy as np +import numpy.testing as npt + +from ndv._chunk_executor import Chunker, iter_chunk_aligned_slices + + +def test_iter_chunk_aligned_slices() -> None: + x = iter_chunk_aligned_slices( + shape=(10, 9), chunks=(4, 3), slices=np.index_exp[3:9, 1:None] + ) + assert list(x) == [ + (slice(3, 4, None), slice(1, 3, None)), + (slice(3, 4, None), slice(3, 6, None)), + (slice(3, 4, None), slice(6, 9, None)), + (slice(4, 8, None), slice(1, 3, None)), + (slice(4, 8, None), slice(3, 6, None)), + (slice(4, 8, None), slice(6, 9, None)), + (slice(8, 9, None), slice(1, 3, None)), + (slice(8, 9, None), slice(3, 6, None)), + (slice(8, 9, None), slice(6, 9, None)), + ] + + # this one tests that slices doesn't need to be the same length as shape + # ... 
is added at the end + y = iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=np.index_exp[1:4]) + assert list(y) == [ + (slice(1, 4, None), slice(0, 4, None)), + (slice(1, 4, None), slice(4, 6, None)), + ] + + # this tests ellipsis in the middle + z = iter_chunk_aligned_slices( + shape=(3, 3, 3), chunks=2, slices=np.index_exp[1, ..., :2] + ) + assert list(z) == [ + (slice(1, 2, None), slice(0, 2, None), slice(0, 2, None)), + (slice(1, 2, None), slice(2, 3, None), slice(0, 2, None)), + ] + + +def test_chunker() -> None: + data = np.random.rand(100, 100).astype(np.float32) + + with Chunker() as chunker: + futures = chunker.request_chunks(data) + + assert len(futures) == 1 + npt.assert_array_equal(data, futures[0].result().data) + + data2 = np.random.rand(30, 30, 30).astype(np.float32) + # test that the data is correctly chunked with weird chunk shapes + with Chunker() as chunker: + futures = chunker.request_chunks( + data2, index={0: 0}, chunk_shape=(None, 17, 12) + ) + + new = np.empty_like(data2[0]) + for future in futures: + result = future.result() + new[result.array_location[1:]] = result.data + npt.assert_array_equal(new, data2[0]) + + +# # this test is provided as an example of using dask to accomplish a similar thing +# # this library should try to retain support for using dask instead of the internal +# # chunker ... but it's nice not to have to depend on dask otherwise. +# def test_dask_chunker() -> None: +# try: +# import dask.array as da +# from dask.distributed import Client +# except ImportError: +# pytest.skip("Dask not installed") + +# from itertools import product + +# data = np.random.rand(100, 100).astype(np.float32) +# dask_data_chunked = da.from_array(data, chunks=(25, 20)) # type: ignore +# chunk_sizes = dask_data_chunked.chunks + +# with Client() as client: # type: ignore [no-untyped-call] +# for idx in product(*(range(x) for x in dask_data_chunked.numblocks)): +# # Calculate the start indices (offsets) for the chunk +# # THIS is the main thing we'd need for visualization purposes. 
+# # wish there was an easier way to get this from the chunk_result alone +# offset = tuple(sum(sizes[:x]) for sizes, x in zip(chunk_sizes, idx)) + +# chunk = dask_data_chunked.blocks[idx] + +# future = client.compute(chunk) +# chunk_result = future.result() + +# # Test that the data is correctly chunked and equal to the original data +# sub_idx = tuple(slice(x, x + y) for x, y in zip(offset, chunk_result.shape)) +# expected = data[sub_idx] +# npt.assert_array_equal(expected, chunk_result) + + +# def test_dask_map_blocks() -> None: +# if TYPE_CHECKING: +# import dask.array +# else: +# dask = pytest.importorskip("dask") + +# dask_array = dask.array.random.randint(100, size=(100, 100), chunks=(25, 20)) +# block_ranges = (range(x) for x in dask_array.numblocks) +# for block_id in product(*(block_ranges)): +# _offset = tuple(sum(sizes[:x]) for sizes, x in zip(dask_array.chunks, block_id)) +# _chunk = dask_array.blocks[block_id] +# print(_offset, _chunk) diff --git a/y.py b/y.py index a5eb6b9..2855628 100644 --- a/y.py +++ b/y.py @@ -1,11 +1,11 @@ import numpy as np import ndv -from ndv._chunking import Slicer +from ndv._chunking import Chunker data = np.random.rand(10, 3, 8, 5, 128, 128) wrapper = ndv.DataWrapper.create(data) -slicer = Slicer(wrapper, chunks=(5, 1, 2, 2, 64, 34)) +slicer = Chunker(wrapper, chunks=(5, 1, 2, 2, 64, 34)) index = {0: 2, 1: 2, 2: 0, 3: 4} idx = wrapper.to_conventional(index) diff --git a/z.py b/z.py index 28ee931..579842b 100644 --- a/z.py +++ b/z.py @@ -1,38 +1,13 @@ -import random +import sys -import dask.array as da -from dask.distributed import Client, as_completed +from qtpy.QtWidgets import QApplication +import ndv +from ndv.viewer._v2 import NDViewer -# Function to load a chunk -def load_chunk(chunk): - # Simulate loading time - import time - - t = random.random() * 5 - print(t) - time.sleep(t) - return chunk - - -if __name__ == "__main__": - # Set up Dask Client - client = Client() - # Create a Dask array (simulate chunked storage) - x = da.random.random((10, 10), chunks=(5, 5)) - - # Submit tasks directly to the scheduler and get futures - futures = [] - for i in range(x.numblocks[0]): - for j in range(x.numblocks[1]): - chunk = x.blocks[i, j] - future = client.submit(load_chunk, chunk) - futures.append(future) - - # Monitor progress using as_completed - for future in as_completed(futures): - result = future.result() - print("Chunk ready:", result.shape) - - # Close the client - client.close() +data = ndv.data.cells3d() +app = QApplication(sys.argv) +viewer = NDViewer(data) +viewer.show() +viewer.set_current_index({1: 0}) +# app.exec() From a1b9ebe1a7729f9341ad879d359a5c016d3ba2d5 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 12 Jun 2024 11:05:21 -0400 Subject: [PATCH 10/12] wip --- src/ndv/viewer/_backends/_vispy.py | 25 +++++++++++++++---- src/ndv/viewer/_backends/protocols.py | 2 +- src/ndv/viewer/_v2.py | 35 +++++++++++++++++++++------ z.py => x.py | 3 +-- 4 files changed, 49 insertions(+), 16 deletions(-) rename z.py => x.py (80%) diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index d4adfd3..77a0750 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -6,6 +6,7 @@ import numpy as np import vispy +import vispy.gloo import vispy.scene import vispy.visuals from superqt.utils import qthrottled @@ -14,6 +15,8 @@ if TYPE_CHECKING: import cmap + import vispy.gloo.glir + import vispy.gloo.texture from qtpy.QtWidgets import QWidget from vispy.scene.events import 
SceneMouseEvent @@ -22,9 +25,9 @@ class VispyImageHandle: - def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: + def __init__(self, visual: scene.Image | scene.Volume) -> None: self._visual = visual - self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3 + self._ndim = 2 if isinstance(visual, scene.Image) else 3 @property def data(self) -> np.ndarray: @@ -46,10 +49,22 @@ def data(self, data: np.ndarray) -> None: def clear(self) -> None: offset = (0,) * self.data.ndim - self.set_data(np.zeros(self.data.shape, dtype=self.data.dtype), offset) + self.directly_set_texture_offset( + np.zeros(self.data.shape, dtype=self.data.dtype), offset + ) + + def directly_set_texture_offset(self, data: np.ndarray, offset: tuple) -> None: + """LOW-LEVEL: Set the texture data at offset directly. - def set_data(self, data: np.ndarray, offset: tuple) -> None: - self._visual._texture._set_data(data, offset=offset) + We are bypassing all data transformations and checks here, so data *must* be + the correct shape and dtype. + """ + if self._ndim == 3: + if data.ndim == 3: + data = data[..., :] # add channel axis + texture = cast("vispy.gloo.texture.Texture3D", self._visual._texture) + queue = cast("vispy.gloo.glir.GlirQueue", texture._glir) + queue.command("DATA", texture._id, offset, data) @property def visible(self) -> bool: diff --git a/src/ndv/viewer/_backends/protocols.py b/src/ndv/viewer/_backends/protocols.py index de65877..f8895af 100644 --- a/src/ndv/viewer/_backends/protocols.py +++ b/src/ndv/viewer/_backends/protocols.py @@ -13,7 +13,7 @@ class PImageHandle(Protocol): def data(self) -> np.ndarray: ... @data.setter def data(self, data: np.ndarray) -> None: ... - def set_data(self, data: np.ndarray, offset: tuple) -> None: ... + def directly_set_texture_offset(self, data: np.ndarray, offset: tuple) -> None: ... @property def visible(self) -> bool: ... 
@visible.setter diff --git a/src/ndv/viewer/_v2.py b/src/ndv/viewer/_v2.py index 8ad43cc..4bf7c41 100644 --- a/src/ndv/viewer/_v2.py +++ b/src/ndv/viewer/_v2.py @@ -3,9 +3,11 @@ import numpy as np from qtpy.QtWidgets import QVBoxLayout, QWidget from superqt import ensure_main_thread +from superqt.utils import qthrottled from ndv._chunk_executor import Chunker, ChunkFuture from ndv.viewer._backends import get_canvas +from ndv.viewer._dims_slider import DimsSliders from ndv.viewer._state import ViewerState if TYPE_CHECKING: @@ -22,10 +24,17 @@ def __init__(self, data: Any, *, parent: QWidget | None = None): self._canvas: PCanvas = get_canvas()(lambda x: None) self._canvas.set_ndim(self._state.ndim) + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders(self) + self._dims_sliders.valueChanged.connect( + qthrottled(self._request_data_for_index, 20, leading=True) + ) + layout = QVBoxLayout(self) layout.setSpacing(2) layout.setContentsMargins(6, 6, 6, 6) layout.addWidget(self._canvas.qwidget(), 1) + layout.addWidget(self._dims_sliders, 0) if data is not None: self.set_data(data) @@ -35,6 +44,9 @@ def __del__(self) -> None: def set_data(self, data: Any) -> None: self._data = data + self._dims_sliders.setMaxima( + {i: data.shape[i] - 1 for i in range(len(data.shape))} + ) def set_current_index(self, index: Mapping[int | str, int | slice]) -> None: """Set the currentl displayed index.""" @@ -68,14 +80,17 @@ def _request_data_for_index(self, index: Mapping[int | str, int | slice]) -> Non chunk_shape[dim] = chunk_size index = {self._norm_index(k): v for k, v in index.items()} - print("--------") - print("chunk shape", chunk_shape) - print("index", index) + for v in visualized: + if isinstance(index.get(v), int): + del index[v] + + if not index: + return + print("requesting data for index", index, chunk_shape) # clear existing handles for handle in self._channels.values(): handle.clear() - for future in self._chunker.request_chunks( data=self._data, index=index, @@ -115,11 +130,15 @@ def _draw_chunk(self, future: ChunkFuture) -> None: empty = np.empty(texture_shape, dtype=chunk.data.dtype) self._channels[channel_index] = handle = self._canvas.add_volume(empty) - mi, ma = handle.clim - handle.clim = (min(mi, np.min(data)), max(ma, np.max(data))) - handle.set_data(data, offset) + try: + mi, ma = handle.clim + handle.clim = (min(mi, np.min(data)), max(ma, np.max(data))) + except Exception as e: + print("err in clim: ", e) + handle.clim = (0, 5000) + + handle.directly_set_texture_offset(data, offset) self._canvas.refresh() - print("drawn chunk") # # of the chunks will determine the order of the channels in the LUTS # # (without additional logic to sort them by index, etc.) 
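
The `_draw_chunk` path above writes each arriving chunk into a preallocated texture at its offset instead of re-uploading the whole volume. A minimal NumPy-only sketch of that reassembly idea, assuming the `iter_chunk_aligned_slices` helper from this branch's `ndv._chunk_executor` (the buffer stands in for the GPU texture; chunks may arrive in any order):

    import numpy as np
    from ndv._chunk_executor import iter_chunk_aligned_slices

    data = np.random.rand(30, 40).astype(np.float32)
    buffer = np.empty_like(data)  # stands in for the texture allocated by add_volume

    for slc in iter_chunk_aligned_slices(
        shape=data.shape, chunks=16, slices=np.index_exp[:, :]
    ):
        chunk = data[slc]                     # what a worker thread would load
        offset = tuple(s.start for s in slc)  # where the chunk lands in the buffer
        target = tuple(slice(o, o + n) for o, n in zip(offset, chunk.shape))
        buffer[target] = chunk                # analogue of directly_set_texture_offset

    np.testing.assert_array_equal(buffer, data)
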
diff --git a/z.py b/x.py similarity index 80% rename from z.py rename to x.py index 579842b..812e058 100644 --- a/z.py +++ b/x.py @@ -9,5 +9,4 @@ app = QApplication(sys.argv) viewer = NDViewer(data) viewer.show() -viewer.set_current_index({1: 0}) -# app.exec() +app.exec() From 22e57b279aae3778979eb214da346f728a13cb01 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 12 Jun 2024 15:11:35 -0400 Subject: [PATCH 11/12] viewer2 viewer 1 --- src/ndv/__init__.py | 12 +- src/ndv/_chunk_executor.py | 27 +- src/ndv/_chunking.py | 2 +- src/ndv/viewer/_backends/__init__.py | 2 +- src/ndv/viewer/_backends/_protocols.py | 48 ++ src/ndv/viewer/_backends/_vispy.py | 48 +- src/ndv/viewer/_data_wrapper.py | 118 ++-- src/ndv/viewer/_lut_control.py | 49 +- src/ndv/viewer/_viewer.py | 279 ++++----- src/ndv/viewer2/__init__.py | 1 + src/ndv/viewer2/_backends/__init__.py | 36 ++ src/ndv/viewer2/_backends/_pygfx.py | 235 ++++++++ src/ndv/viewer2/_backends/_vispy.py | 259 ++++++++ .../_backends/protocols.py | 0 src/ndv/viewer2/_components.py | 66 ++ src/ndv/viewer2/_data_wrapper.py | 418 +++++++++++++ src/ndv/viewer2/_dims_slider.py | 529 +++++++++++++++++ src/ndv/viewer2/_lut_control.py | 150 +++++ src/ndv/viewer2/_octree.py | 105 ++++ src/ndv/viewer2/_save_button.py | 34 ++ src/ndv/{viewer => viewer2}/_state.py | 0 src/ndv/{viewer => viewer2}/_v2.py | 13 +- src/ndv/viewer2/_viewer.py | 562 ++++++++++++++++++ src/ndv/viewer2/spin.gif | Bin 0 -> 2384 bytes tests/test_chunker.py | 28 +- x.py | 2 +- 26 files changed, 2676 insertions(+), 347 deletions(-) create mode 100644 src/ndv/viewer/_backends/_protocols.py create mode 100644 src/ndv/viewer2/__init__.py create mode 100644 src/ndv/viewer2/_backends/__init__.py create mode 100644 src/ndv/viewer2/_backends/_pygfx.py create mode 100644 src/ndv/viewer2/_backends/_vispy.py rename src/ndv/{viewer => viewer2}/_backends/protocols.py (100%) create mode 100644 src/ndv/viewer2/_components.py create mode 100644 src/ndv/viewer2/_data_wrapper.py create mode 100644 src/ndv/viewer2/_dims_slider.py create mode 100644 src/ndv/viewer2/_lut_control.py create mode 100644 src/ndv/viewer2/_octree.py create mode 100644 src/ndv/viewer2/_save_button.py rename src/ndv/{viewer => viewer2}/_state.py (100%) rename src/ndv/{viewer => viewer2}/_v2.py (95%) create mode 100644 src/ndv/viewer2/_viewer.py create mode 100644 src/ndv/viewer2/spin.gif diff --git a/src/ndv/__init__.py b/src/ndv/__init__.py index 7faa6ea..a699742 100644 --- a/src/ndv/__init__.py +++ b/src/ndv/__init__.py @@ -13,8 +13,8 @@ from . 
import data from .util import imshow -from .viewer._data_wrapper import DataWrapper -from .viewer._viewer import NDViewer +from .viewer2._data_wrapper import DataWrapper +from .viewer2._viewer import NDViewer __all__ = ["NDViewer", "DataWrapper", "imshow", "data"] @@ -23,7 +23,7 @@ # these may be used externally, but are not guaranteed to be available at runtime # they must be used inside a TYPE_CHECKING block - from .viewer._dims_slider import DimKey as DimKey - from .viewer._dims_slider import Index as Index - from .viewer._dims_slider import Indices as Indices - from .viewer._dims_slider import Sizes as Sizes + from .viewer2._dims_slider import DimKey as DimKey + from .viewer2._dims_slider import Index as Index + from .viewer2._dims_slider import Indices as Indices + from .viewer2._dims_slider import Sizes as Sizes diff --git a/src/ndv/_chunk_executor.py b/src/ndv/_chunk_executor.py index 09500aa..5cb8501 100644 --- a/src/ndv/_chunk_executor.py +++ b/src/ndv/_chunk_executor.py @@ -204,20 +204,21 @@ def indexers_to_conventional_slice( return tuple(indexers.get(k, slice(None)) for k in range(ndim)) -def _slice2range(sl: SupportsIndex | slice, dim_size: int) -> tuple[int, int]: - """Convert slice to range, handling single int as well. +def _slice_indices(sl: SupportsIndex | slice, dim_size: int) -> tuple[int, int, int]: + """Convert slice to range arguments, handling single int as well. Examples -------- >>> _slice2range(3, 10) - (3, 4) + (3, 4, 1) + >>> _slice2range(slice(1, 4), 10) + (1, 4, 1) + >>> _slice2range(slice(1, None), 10) + (1, 10, 1) """ - if not isinstance(sl, slice): - idx = sl.__index__() - return (idx, idx + 1) - start = 0 if sl.start is None else max(sl.start, 0) - stop = dim_size if sl.stop is None else min(sl.stop, dim_size) - return (start, stop) + if isinstance(sl, slice): + return sl.indices(dim_size) + return (sl.__index__(), sl.__index__() + 1, 1) def iter_chunk_aligned_slices( @@ -307,24 +308,24 @@ def iter_chunk_aligned_slices( slices = slices + (slice(None),) * (ndim - len(slices)) # Create ranges for each dimension based on the slices provided - ranges = [_slice2range(sl, dim) for sl, dim in zip(slices, shape)] + ranges = [_slice_indices(sl, dim) for sl, dim in zip(slices, shape)] # Generate indices for each dimension that align with chunks aligned_ranges = ( range(start - (start % chunk_size), stop, chunk_size) - for (start, stop), chunk_size in zip(ranges, chunks) + for (start, stop, _), chunk_size in zip(ranges, chunks) ) # Create all combinations of these aligned ranges for indices in product(*aligned_ranges): chunk_slices = [] - for idx, (start, stop), ch in zip(indices, ranges, chunks): + for idx, (start, stop, step), ch in zip(indices, ranges, chunks): # Calculate the actual slice for each dimension start = max(start, idx) stop = min(stop, idx + ch) if start >= stop: # Skip empty slices break - chunk_slices.append(slice(start, stop)) + chunk_slices.append(slice(start, stop, step)) else: # Only add this combination of slices if all dimensions are valid yield tuple(chunk_slices) diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py index 4137f07..f712747 100644 --- a/src/ndv/_chunking.py +++ b/src/ndv/_chunking.py @@ -23,7 +23,7 @@ from types import EllipsisType from typing import Callable, Iterable, Iterator, TypeAlias - from .viewer._data_wrapper import DataWrapper + from .viewer2._data_wrapper import DataWrapper # any hashable represent a single dimension in an ND array DimKey: TypeAlias = Hashable diff --git a/src/ndv/viewer/_backends/__init__.py 
b/src/ndv/viewer/_backends/__init__.py index 2ebf248..310c2be 100644 --- a/src/ndv/viewer/_backends/__init__.py +++ b/src/ndv/viewer/_backends/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ndv.viewer._backends.protocols import PCanvas + from ndv.viewer._backends._protocols import PCanvas def get_canvas(backend: str | None = None) -> type[PCanvas]: diff --git a/src/ndv/viewer/_backends/_protocols.py b/src/ndv/viewer/_backends/_protocols.py new file mode 100644 index 0000000..413038d --- /dev/null +++ b/src/ndv/viewer/_backends/_protocols.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol + +if TYPE_CHECKING: + import cmap + import numpy as np + from qtpy.QtWidgets import QWidget + + +class PImageHandle(Protocol): + @property + def data(self) -> np.ndarray: ... + @data.setter + def data(self, data: np.ndarray) -> None: ... + @property + def visible(self) -> bool: ... + @visible.setter + def visible(self, visible: bool) -> None: ... + @property + def clim(self) -> Any: ... + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: ... + @property + def cmap(self) -> Any: ... + @cmap.setter + def cmap(self, cmap: Any) -> None: ... + def remove(self) -> None: ... + + +class PCanvas(Protocol): + def __init__(self, set_info: Callable[[str], None]) -> None: ... + def set_ndim(self, ndim: Literal[2, 3]) -> None: ... + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = ..., + ) -> None: ... + def refresh(self) -> None: ... + def qwidget(self) -> QWidget: ... + def add_image( + self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + ) -> PImageHandle: ... + def add_volume( + self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + ) -> PImageHandle: ... diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 77a0750..7ed78d7 100644 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -6,7 +6,6 @@ import numpy as np import vispy -import vispy.gloo import vispy.scene import vispy.visuals from superqt.utils import qthrottled @@ -15,8 +14,6 @@ if TYPE_CHECKING: import cmap - import vispy.gloo.glir - import vispy.gloo.texture from qtpy.QtWidgets import QWidget from vispy.scene.events import SceneMouseEvent @@ -25,9 +22,9 @@ class VispyImageHandle: - def __init__(self, visual: scene.Image | scene.Volume) -> None: + def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: self._visual = visual - self._ndim = 2 if isinstance(visual, scene.Image) else 3 + self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3 @property def data(self) -> np.ndarray: @@ -47,25 +44,6 @@ def data(self, data: np.ndarray) -> None: return self._visual.set_data(data) - def clear(self) -> None: - offset = (0,) * self.data.ndim - self.directly_set_texture_offset( - np.zeros(self.data.shape, dtype=self.data.dtype), offset - ) - - def directly_set_texture_offset(self, data: np.ndarray, offset: tuple) -> None: - """LOW-LEVEL: Set the texture data at offset directly. - - We are bypassing all data transformations and checks here, so data *must* be - the correct shape and dtype. 
- """ - if self._ndim == 3: - if data.ndim == 3: - data = data[..., :] # add channel axis - texture = cast("vispy.gloo.texture.Texture3D", self._visual._texture) - queue = cast("vispy.gloo.glir.GlirQueue", texture._glir) - queue.command("DATA", texture._id, offset, data) - @property def visible(self) -> bool: return bool(self._visual.visible) @@ -154,19 +132,12 @@ def refresh(self) -> None: self._canvas.update() def add_image( - self, - data: np.ndarray | None = None, - cmap: cmap.Colormap | None = None, - offset: tuple[float, float] | None = None, # (Y, X) + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: """Add a new Image node to the scene.""" img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True - - if offset: - img.transform = scene.STTransform(translate=offset[::-1]) - if data is not None: self._current_shape, prev_shape = data.shape, self._current_shape if not prev_shape: @@ -177,22 +148,13 @@ def add_image( return handle def add_volume( - self, - data: np.ndarray | None = None, - cmap: cmap.Colormap | None = None, - offset: tuple[float, float, float] | None = None, # (Z, Y, X) + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: vol = scene.visuals.Volume( - data, - parent=self._view.scene, - interpolation="nearest", - texture_format="auto", + data, parent=self._view.scene, interpolation="nearest" ) vol.set_gl_state("additive", depth_test=False) vol.interactive = True - if offset: - vol.transform = scene.STTransform(translate=offset[::-1]) - if data is not None: self._current_shape, prev_shape = data.shape, self._current_shape if len(prev_shape) != 3: diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py index 366e34f..002bbc5 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/viewer/_data_wrapper.py @@ -4,8 +4,8 @@ import logging import sys -import warnings from abc import abstractmethod +from concurrent.futures import Future, ThreadPoolExecutor from contextlib import suppress from typing import ( TYPE_CHECKING, @@ -13,6 +13,7 @@ Container, Generic, Hashable, + Iterable, Iterator, Mapping, Sequence, @@ -57,7 +58,7 @@ def __gt__(self, other: _T_contra, /) -> bool: ... _T = TypeVar("_T", bound=type) # Global executor for slice requests -# _EXECUTOR = ThreadPoolExecutor(max_workers=2) +_EXECUTOR = ThreadPoolExecutor(max_workers=2) def _recurse_subclasses(cls: _T) -> Iterator[_T]: @@ -101,6 +102,13 @@ def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: return subclass(data) raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + def __init__(self, data: ArrayT) -> None: + self._data = data + + @property + def data(self) -> ArrayT: + return self._data + @classmethod @abstractmethod def supports(cls, obj: Any) -> bool: @@ -111,74 +119,38 @@ def supports(cls, obj: Any) -> bool: """ raise NotImplementedError - def __init__(self, data: ArrayT) -> None: - self._data = data - self._name2index: dict[str, int] = {} - if names := self.dimension_names(): - self._name2index = {name: i for i, name in enumerate(names)} - - @property - def data(self) -> ArrayT: - return self._data - - # @abstractmethod - # def isel(self, indexers: Indices) -> np.ndarray: - # """Select a slice from a data store using (possibly) named indices. - - # This follows the xarray-style indexing, where indexers is a mapping of - # dimension names to indices or slices. 
Subclasses should implement this - # method to return a numpy array. - # """ - # raise NotImplementedError - - def shape(self) -> tuple[int, ...]: - return self._data.shape # type: ignore - - def __getitem__(self, index: tuple[int | slice, ...]) -> np.ndarray: - # reimplement in subclasses - return np.asarray(self._data[index]) # type: ignore [index] - - def chunks(self) -> tuple[int, ...] | int | None: - if chunks := getattr(self._data, "chunks", None): - if isinstance(chunks, Sequence) and all(isinstance(x, int) for x in chunks): - return tuple(chunks) - warnings.warn( - f"Unexpected chunks attribute: {chunks!r}. Ignoring.", stacklevel=2 - ) - return None - - def dimension_names(self) -> tuple[str, ...] | None: - """Return the names of the dimensions of the data.""" - return None + @abstractmethod + def isel(self, indexers: Indices) -> np.ndarray: + """Select a slice from a data store using (possibly) named indices. - def to_conventional(self, indexers: Indices) -> tuple[int | slice, ...]: - """Convert named indices to a tuple of integers and slices.""" - _indexers = {self._name2index.get(str(k), k): v for k, v in indexers.items()} - return tuple(_indexers.get(k, slice(None)) for k in range(len(self.shape()))) + This follows the xarray-style indexing, where indexers is a mapping of + dimension names to indices or slices. Subclasses should implement this + method to return a numpy array. + """ + raise NotImplementedError - # def isel_async( - # self, indexers: list[Indices] - # ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: - # """Asynchronous version of isel.""" - # return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) + def isel_async( + self, indexers: list[Indices] + ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + """Asynchronous version of isel.""" + return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) - def guess_channel_axis(self) -> int | None: + def guess_channel_axis(self) -> Hashable | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, # see if any of the dimensions are named "channel" - shape = self.shape() - if names := self.dimension_names(): - for ax, name in enumerate(names): - if ( - name.lower() in self.COMMON_CHANNEL_NAMES - and shape[ax] <= self.MAX_CHANNELS - ): - return ax + for dimkey, val in self.sizes().items(): + if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: + if val <= self.MAX_CHANNELS: + return dimkey # for shaped arrays, use the smallest dimension as the channel axis - with suppress(ValueError): - if (smallest_dim := min(shape)) <= self.MAX_CHANNELS: - return shape.index(smallest_dim) + shape = getattr(self._data, "shape", None) + if isinstance(shape, Sequence): + with suppress(ValueError): + smallest_dim = min(shape) + if smallest_dim <= self.MAX_CHANNELS: + return shape.index(smallest_dim) return None def save_as_zarr(self, save_loc: str | Path) -> None: @@ -192,10 +164,13 @@ def sizes(self) -> Sizes: (`dims` is used by xarray, `names` is used by torch, etc...). If no labels are found, the dimensions are just named by their integer index. 
""" - shape = self.shape() - if (names := self.dimension_names()) and len(names) == len(shape): - return dict(zip(names, shape)) - return dict(enumerate(shape)) + shape = getattr(self._data, "shape", None) + if not isinstance(shape, Sequence) or not all( + isinstance(x, int) for x in shape + ): + raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") + dims = range(len(shape)) + return {dim: int(size) for dim, size in zip(dims, shape)} def summary_info(self) -> str: """Return info label with information about the data.""" @@ -270,15 +245,8 @@ class ArrayLikeWrapper(DataWrapper, Generic[ArrayT]): PRIORITY = 100 def isel(self, indexers: Indices) -> np.ndarray: - idx = [] - for k in range(len(self._data.shape)): - i = indexers.get(k, slice(None)) - if isinstance(i, int): - idx.extend([i, None]) - else: - idx.append(i) - - return self._asarray(self._data[tuple(idx)]) + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return self._asarray(self._data[idx]) def _asarray(self, data: npt.ArrayLike) -> np.ndarray: return np.asarray(data) diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py index 1b9c1b3..b20909e 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/viewer/_lut_control.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np from qtpy.QtCore import Qt @@ -16,7 +16,7 @@ import cmap - from ._backends.protocols import PImageHandle + from ._backends._protocols import PImageHandle class CmapCombo(QColormapComboBox): @@ -35,14 +35,13 @@ def showPopup(self) -> None: class LutControl(QWidget): def __init__( self, - channel: Sequence[PImageHandle], name: str = "", + handles: Iterable[PImageHandle] = (), parent: QWidget | None = None, cmaplist: Iterable[Any] = (), - cmap: cmap.Colormap | None = None, ) -> None: super().__init__(parent) - self._channel = channel + self._handles = handles self._name = name self._visible = QCheckBox(name) @@ -51,12 +50,10 @@ def __init__( self._cmap = CmapCombo() self._cmap.currentColormapChanged.connect(self._on_cmap_changed) - for handle in channel: + for handle in handles: self._cmap.addColormap(handle.cmap) for color in cmaplist: self._cmap.addColormap(color) - if cmap is not None: - self._cmap.setCurrentColormap(cmap) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) self._clims.setStyleSheet(SS) @@ -87,36 +84,36 @@ def autoscaleChecked(self) -> bool: def _on_clims_changed(self, clims: tuple[float, float]) -> None: self._auto_clim.setChecked(False) - for handle in self._channel: + for handle in self._handles: handle.clim = clims def _on_visible_changed(self, visible: bool) -> None: - for handle in self._channel: + for handle in self._handles: handle.visible = visible if visible: self.update_autoscale() def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: - for handle in self._channel: + for handle in self._handles: handle.cmap = cmap def update_autoscale(self) -> None: if ( not self._auto_clim.isChecked() or not self._visible.isChecked() - or not self._channel + or not self._handles ): return # find the min and max values for the current channel clims = [np.inf, -np.inf] - for handle in self._channel: + for handle in self._handles: clims[0] = min(clims[0], np.nanmin(handle.data)) clims[1] = max(clims[1], np.nanmax(handle.data)) mi, ma = tuple(int(x) for x in clims) if mi != ma: - for handle in self._channel: + for handle in self._handles: 
handle.clim = (mi, ma) # set the slider values to the new clims @@ -124,27 +121,3 @@ def update_autoscale(self) -> None: self._clims.setMinimum(min(mi, self._clims.minimum())) self._clims.setMaximum(max(ma, self._clims.maximum())) self._clims.setValue((mi, ma)) - - -def _get_default_clim_from_data(data: np.ndarray) -> tuple[float, float]: - """Compute a reasonable clim from the min and max, taking nans into account. - - If there are no non-finite values (nan, inf, -inf) this is as fast as it can be. - Otherwise, this functions is about 3x slower. - """ - # Fast - min_value = data.min() - max_value = data.max() - - # Need more work? The nan-functions are slower - min_finite = np.isfinite(min_value) - max_finite = np.isfinite(max_value) - if not (min_finite and max_finite): - finite_data = data[np.isfinite(data)] - if finite_data.size: - min_value = finite_data.min() - max_value = finite_data.max() - else: - min_value = max_value = 0 # no finite values in the data - - return min_value, max_value diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 29565cb..473f0f5 100644 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -1,16 +1,8 @@ from __future__ import annotations +from collections import defaultdict from itertools import cycle -from typing import ( - TYPE_CHECKING, - Hashable, - Literal, - MutableSequence, - Sequence, - SupportsIndex, - cast, - overload, -) +from typing import TYPE_CHECKING, Literal, cast import cmap import numpy as np @@ -18,7 +10,6 @@ from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread from superqt.utils import qthrottled, signals_blocked -from ndv._chunking import Chunker, ChunkResponse, RequestFinished from ndv.viewer._components import ( ChannelMode, ChannelModeButton, @@ -27,17 +18,17 @@ ) from ._backends import get_canvas -from ._backends.protocols import PImageHandle from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders from ._lut_control import LutControl if TYPE_CHECKING: - from typing import Any, Iterable, TypeAlias + from concurrent.futures import Future + from typing import Any, Callable, Hashable, Iterable, Sequence, TypeAlias from qtpy.QtGui import QCloseEvent - from ._backends.protocols import PCanvas + from ._backends._protocols import PCanvas, PImageHandle from ._dims_slider import DimKey, Indices, Sizes ImgKey: TypeAlias = Hashable @@ -56,43 +47,7 @@ cmap.Colormap("cubehelix"), cmap.Colormap("gray"), ] -MONO_CHANNEL = -999999 - - -class Channel(MutableSequence[PImageHandle]): - def __init__(self, ch_key: int, cmap: cmap.Colormap = GRAYS) -> None: - self.ch_key = ch_key - self._handles: list[PImageHandle] = [] - self.cmap = cmap - - @overload - def __getitem__(self, i: int) -> PImageHandle: ... - @overload - def __getitem__(self, i: slice) -> list[PImageHandle]: ... - def __getitem__(self, i: int | slice) -> PImageHandle | list[PImageHandle]: - return self._handles[i] - - @overload - def __setitem__(self, i: SupportsIndex, value: PImageHandle) -> None: ... - @overload - def __setitem__(self, i: slice, value: Iterable[PImageHandle]) -> None: ... - def __setitem__( - self, i: SupportsIndex | slice, value: PImageHandle | Iterable[PImageHandle] - ) -> None: - self._handles[i] = value # type: ignore - - @overload - def __delitem__(self, i: int) -> None: ... - @overload - def __delitem__(self, i: slice) -> None: ... 
- def __delitem__(self, i: int | slice) -> None: - del self._handles[i] - - def __len__(self) -> int: - return len(self._handles) - - def insert(self, i: int, value: PImageHandle) -> None: - self._handles.insert(i, value) +ALL_CHANNELS = slice(None) class NDViewer(QWidget): @@ -118,13 +73,13 @@ class NDViewer(QWidget): with the `_dims_sliders.value()` method. To programmatically set the current position, use the `setIndex` method. This will set the values of the sliders, which in turn will trigger the display of the new slice via the - `_request_data_for_index` method. - - `_request_data_for_index` is an asynchronous method that retrieves the data for + `_update_data_for_index` method. + - `_update_data_for_index` is an asynchronous method that retrieves the data for the given index from the datastore (using `_isel`) and queues the - `_draw_chunk` method to be called when the data is ready. The logic + `_on_data_slice_ready` method to be called when the data is ready. The logic for extracting data from the datastore is defined in `_data_wrapper.py`, which handles idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). - - `_draw_chunk` is called when the data is ready, and updates the image. + - `_on_data_slice_ready` is called when the data is ready, and updates the image. Note that if the slice is multidimensional, the data will be reduced to 2D using max intensity projection (and double-clicking on any given dimension slider will turn it into a range slider allowing a projection to be made over that dimension). @@ -156,7 +111,7 @@ def __init__( *, colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, parent: QWidget | None = None, - channel_axis: int | None = None, + channel_axis: DimKey | None = None, channel_mode: ChannelMode | str = ChannelMode.MONO, ): super().__init__(parent=parent) @@ -164,16 +119,13 @@ def __init__( # ATTRIBUTES ---------------------------------------------------- # mapping of key to a list of objects that control image nodes in the canvas - self._channels: dict[int, Channel] = {} - + self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) # mapping of same keys to the LutControl objects control image display props - self._lut_ctrls: dict[int, LutControl] = {} - - # the set of dimensions we are currently visualizing (e.g. (-2, -1) for 2D) + self._lut_ctrls: dict[ImgKey, LutControl] = {} + # the set of dimensions we are currently visualizing (e.g. 
XY) # this is used to control which dimensions have sliders and the behavior # of isel when selecting data from the datastore self._visualized_dims: set[DimKey] = set() - # the axis that represents the channels in the data self._channel_axis = channel_axis self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode @@ -183,17 +135,11 @@ def __init__( else: self._cmaps = DEFAULT_COLORMAPS self._cmap_cycle = cycle(self._cmaps) + # the last future that was created by _update_data_for_index + self._last_future: Future | None = None # number of dimensions to display self._ndims: Literal[2, 3] = 2 - self._chunker = Chunker( - None, - # IMPORTANT - # chunking here will determine how non-visualized dims are reduced - # so chunkshape will need to change based on the set of visualized dims - chunks=(20, 100, 32, 32), - on_ready=self._draw_chunk, - ) # WIDGETS ---------------------------------------------------- @@ -223,7 +169,7 @@ def __init__( # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders(self) self._dims_sliders.valueChanged.connect( - qthrottled(self._request_data_for_index, 20, leading=True) + qthrottled(self._update_data_for_index, 20, leading=True) ) self._lut_drop = QCollapsible("LUTs", self) @@ -311,19 +257,10 @@ def set_data( the initial index will be set to the middle of the data. """ # store the data - self._clear_images() - self._data_wrapper = DataWrapper.create(data) - self._chunker.data_wrapper = self._data_wrapper - if chunks := self._data_wrapper.chunks(): - # temp hack ... always group non-visible channels - chunks = list(chunks) - chunks[:-2] = (1000,) * len(chunks[:-2]) - self._chunker.chunks = tuple(chunks) # set channel axis self._channel_axis = self._data_wrapper.guess_channel_axis() - self._chunker.channel_axis = self._channel_axis # update the dimensions we are visualizing sizes = self._data_wrapper.sizes() @@ -336,13 +273,12 @@ def set_data( # redraw if initial_index is None: - idx = {k: int(v // 2) for k, v in sizes.items() if k not in visualized_dims} + idx = {k: int(v // 2) for k, v in sizes.items()} else: if not isinstance(initial_index, dict): # pragma: no cover raise TypeError("initial_index must be a dict") idx = initial_index self.set_current_index(idx) - # update the data info label self._data_info_label.setText(self._data_wrapper.summary_info()) @@ -374,9 +310,9 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) # clear image handles and redraw - if self._channels: + if self._img_handles: self._clear_images() - self._request_data_for_index(self._dims_sliders.value()) + self._update_data_for_index(self._dims_sliders.value()) def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: """Set the mode for displaying the channels. @@ -408,9 +344,9 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: self._channel_axis, mode != ChannelMode.COMPOSITE ) - if self._channels: + if self._img_handles: self._clear_images() - self._request_data_for_index(self._dims_sliders.value()) + self._update_data_for_index(self._dims_sliders.value()) def set_current_index(self, index: Indices | None = None) -> None: """Set the index of the displayed image. 
@@ -464,88 +400,137 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 - def _request_data_for_index(self, index: Indices) -> None: + def _update_data_for_index(self, index: Indices) -> None: """Retrieve data for `index` from datastore and update canvas image(s). - This is the first step in updating the displayed image, it is triggered by - the valueChanged signal from the sliders. - This will pull the data from the datastore using the given index, and update the image handle(s) with the new data. This method is *asynchronous*. It makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. """ - print(f"\n--------\nrequesting index {index}", self._channel_axis) if ( - self._channel_mode == ChannelMode.COMPOSITE - and self._channel_axis is not None + self._channel_axis is not None + and self._channel_mode == ChannelMode.COMPOSITE + and self._channel_axis in (sizes := self._data_wrapper.sizes()) ): - index = {**index, self._channel_axis: slice(None)} + indices: list[Indices] = [ + {**index, self._channel_axis: i} + for i in range(sizes[self._channel_axis]) + ] + else: + indices = [index] + + if self._last_future: + self._last_future.cancel() + + # don't request any dimensions that are not visualized + indices = [ + {k: v for k, v in idx.items() if k not in self._visualized_dims} + for idx in indices + ] + try: + self._last_future = f = self._data_wrapper.isel_async(indices) + except Exception as e: + raise type(e)(f"Failed to index data with {index}: {e}") from e + self._progress_spinner.show() - # TODO: don't request channels not being displayed - # TODO: don't request if the data is already in the cache - self._chunker.request_index(index, ndims=self._ndims) + f.add_done_callback(self._on_data_slice_ready) + + def closeEvent(self, a0: QCloseEvent | None) -> None: + if self._last_future is not None: + self._last_future.cancel() + self._last_future = None + super().closeEvent(a0) @ensure_main_thread # type: ignore - def _draw_chunk(self, chunk: ChunkResponse) -> None: + def _on_data_slice_ready( + self, future: Future[Iterable[tuple[Indices, np.ndarray]]] + ) -> None: + """Update the displayed image for the given index. + + Connected to the future returned by _isel. + """ + # NOTE: removing the reference to the last future here is important + # because the future has a reference to this widget in its _done_callbacks + # which will prevent the widget from being garbage collected if the future + self._last_future = None + self._progress_spinner.hide() + if future.cancelled(): + return + + for idx, datum in future.result(): + self._update_canvas_data(datum, idx) + self._canvas.refresh() + + def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: """Actually update the image handle(s) with the (sliced) data. By this point, data should be sliced from the underlying datastore. Any dimensions remaining that are more than the number of visualized dimensions (currently just 2D) will be reduced using max intensity projection (currently). 
""" - if chunk is RequestFinished: # fix typing - self._progress_spinner.hide() - for lut in self._lut_ctrls.values(): - lut.update_autoscale() - return - - if self._channel_mode == ChannelMode.MONO: - ch_key = MONO_CHANNEL + imkey = self._image_key(index) + datum = self._reduce_data_for_display(data) + if handles := self._img_handles[imkey]: + for handle in handles: + handle.data = datum + if ctrl := self._lut_ctrls.get(imkey, None): + ctrl.update_autoscale() else: - ch_key = chunk.channel_index - - data = chunk.data - if data.ndim == 2: - return - # TODO: Channel object creation could be moved. - # having it here is the laziest... but means that the order of arrival - # of the chunks will determine the order of the channels in the LUTS - # (without additional logic to sort them by index, etc.) - if (handles := self._channels.get(ch_key)) is None: - handles = self._create_channel(ch_key) - - if not handles: - if data.ndim == 2: - handles.append(self._canvas.add_image(data, cmap=handles.cmap)) - elif data.ndim == 3: - empty = np.empty((60, 256, 256), dtype=np.uint16) - handles.append(self._canvas.add_volume(empty, cmap=handles.cmap)) - - handles[0].set_data(data, chunk.offset) - self._canvas.refresh() - - def _create_channel(self, ch_key: int) -> Channel: - # improve this - cmap = GRAYS if ch_key == MONO_CHANNEL else next(self._cmap_cycle) - - self._channels[ch_key] = channel = Channel(ch_key, cmap=cmap) - self._lut_ctrls[ch_key] = lut = LutControl( - channel, - f"Ch {ch_key}", - self, - cmaplist=self._cmaps + DEFAULT_COLORMAPS, - cmap=cmap, - ) - self._lut_drop.addWidget(lut) - return channel + cm = ( + next(self._cmap_cycle) + if self._channel_mode == ChannelMode.COMPOSITE + else GRAYS + ) + if datum.ndim == 2: + handles.append(self._canvas.add_image(datum, cmap=cm)) + elif datum.ndim == 3: + handles.append(self._canvas.add_volume(datum, cmap=cm)) + if imkey not in self._lut_ctrls: + ch_index = index.get(self._channel_axis, 0) + self._lut_ctrls[imkey] = c = LutControl( + f"Ch {ch_index}", + handles, + self, + cmaplist=self._cmaps + DEFAULT_COLORMAPS, + ) + self._lut_drop.addWidget(c) + + def _reduce_data_for_display( + self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max + ) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + + This also coerces 64-bit data to 32-bit data. 
+ """ + # TODO + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the smallest dims) + data = data.squeeze() + visualized_dims = self._ndims + if extra_dims := data.ndim - visualized_dims: + shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) + data = reductor(data, axis=smallest_dims) + + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) + return data def _clear_images(self) -> None: """Remove all images from the canvas.""" - for handles in self._channels.values(): + for handles in self._img_handles.values(): for handle in handles: handle.remove() - self._channels.clear() + self._img_handles.clear() # clear the current LutControls as well for c in self._lut_ctrls.values(): @@ -555,8 +540,4 @@ def _clear_images(self) -> None: def _is_idle(self) -> bool: """Return True if no futures are running. Used for testing, and debugging.""" - return bool(self._chunker.pending_futures) - - def closeEvent(self, a0: QCloseEvent | None) -> None: - self._chunker.shutdown() - super().closeEvent(a0) + return self._last_future is None diff --git a/src/ndv/viewer2/__init__.py b/src/ndv/viewer2/__init__.py new file mode 100644 index 0000000..09c9470 --- /dev/null +++ b/src/ndv/viewer2/__init__.py @@ -0,0 +1 @@ +"""viewer source.""" diff --git a/src/ndv/viewer2/_backends/__init__.py b/src/ndv/viewer2/_backends/__init__.py new file mode 100644 index 0000000..6c50889 --- /dev/null +++ b/src/ndv/viewer2/_backends/__init__.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import importlib +import importlib.util +import os +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ndv.viewer2._backends.protocols import PCanvas + + +def get_canvas(backend: str | None = None) -> type[PCanvas]: + backend = backend or os.getenv("NDV_CANVAS_BACKEND", None) + if backend == "vispy" or (backend is None and "vispy" in sys.modules): + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if backend == "pygfx" or (backend is None and "pygfx" in sys.modules): + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + if backend is None: + if importlib.util.find_spec("vispy") is not None: + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if importlib.util.find_spec("pygfx") is not None: + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + raise RuntimeError("No canvas backend found") diff --git a/src/ndv/viewer2/_backends/_pygfx.py b/src/ndv/viewer2/_backends/_pygfx.py new file mode 100644 index 0000000..f8190ed --- /dev/null +++ b/src/ndv/viewer2/_backends/_pygfx.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import numpy as np +import pygfx +from qtpy.QtCore import QSize +from wgpu.gui.qt import QWgpuCanvas + +if TYPE_CHECKING: + import cmap + from pygfx.materials import ImageBasicMaterial + from pygfx.resources import Texture + from qtpy.QtWidgets import QWidget + + +class PyGFXImageHandle: + def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None: + self._image = image + self._render = render + self._grid = cast("Texture", image.geometry.grid) + self._material = cast("ImageBasicMaterial", image.material) + + @property + def 
data(self) -> np.ndarray: + return self._grid.data # type: ignore [no-any-return] + + @data.setter + def data(self, data: np.ndarray) -> None: + self._grid.data[:] = data + self._grid.update_range((0, 0, 0), self._grid.size) + + @property + def visible(self) -> bool: + return bool(self._image.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._image.visible = visible + self._render() + + @property + def clim(self) -> Any: + return self._material.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + self._material.clim = clims + self._render() + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap + self._material.map = cmap.to_pygfx() + self._render() + + def remove(self) -> None: + if (par := self._image.parent) is not None: + par.remove(self._image) + + +class _QWgpuCanvas(QWgpuCanvas): + def sizeHint(self) -> QSize: + return QSize(512, 512) + + +class PyGFXViewerCanvas: + """pygfx-based canvas wrapper.""" + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + self._current_shape: tuple[int, ...] = () + self._last_state: dict[Literal[2, 3], Any] = {} + + self._canvas = _QWgpuCanvas(size=(512, 512)) + self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) + try: + # requires https://github.com/pygfx/pygfx/pull/752 + self._renderer.blend_mode = "additive" + except ValueError: + warnings.warn( + "This version of pygfx does not yet support additive blending.", + stacklevel=3, + ) + self._renderer.blend_mode = "weighted_depth" + + self._scene = pygfx.Scene() + self._camera: pygfx.Camera | None = None + self._ndim: Literal[2, 3] | None = None + + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim == self._ndim: + return + elif self._ndim is not None and self._camera is not None: + # remember the current state before switching to the new camera + self._last_state[self._ndim] = self._camera.get_state() + + self._ndim = ndim + if ndim == 3: + self._camera = cam = pygfx.PerspectiveCamera(0, 1) + cam.show_object(self._scene, up=(0, -1, 0), view_dir=(0, 0, 1)) + controller = pygfx.OrbitController(cam, register_events=self._renderer) + zoom = "zoom" + # FIXME: there is still an issue with rotational centration. + # the controller is not rotating around the middle of the volume... + # but I think it might actually be a pygfx issue... the critical state + # seems to be somewhere outside of the camera's get_state dict. 
+ else: + self._camera = cam = pygfx.OrthographicCamera(512, 512) + cam.local.scale_y = -1 + cam.local.position = (256, 256, 0) + controller = pygfx.PanZoomController(cam, register_events=self._renderer) + zoom = "zoom_to_point" + + self._controller = controller + # increase zoom wheel gain + self._controller.controls.update({"wheel": (zoom, "push", -0.005)}) + + # restore the previous state if it exists + if state := self._last_state.get(ndim): + cam.set_state(state) + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + """Add a new Image node to the scene.""" + tex = pygfx.Texture(data, dim=2) + image = pygfx.Image( + pygfx.Geometry(grid=tex), + # depth_test=False for additive-like blending + pygfx.ImageBasicMaterial(depth_test=False), + ) + self._scene.add(image) + + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() + + # FIXME: I suspect there are more performant ways to refresh the canvas + # look into it. + handle = PyGFXImageHandle(image, self.refresh) + if cmap is not None: + handle.cmap = cmap + return handle + + def add_volume( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + tex = pygfx.Texture(data, dim=3) + vol = pygfx.Volume( + pygfx.Geometry(grid=tex), + # depth_test=False for additive-like blending + pygfx.VolumeRayMaterial(interpolation="nearest", depth_test=False), + ) + self._scene.add(vol) + + if data is not None: + vol.local_position = [-0.5 * i for i in data.shape[::-1]] + self._current_shape, prev_shape = data.shape, self._current_shape + if len(prev_shape) != 3: + self.set_range() + + # FIXME: I suspect there are more performant ways to refresh the canvas + # look into it. + handle = PyGFXImageHandle(vol, self.refresh) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0.05, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. 
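+
+        Examples
+        --------
+        Illustrative only; assumes a live ``PyGFXViewerCanvas`` that already has
+        data added to its scene::
+
+            canvas.set_range()  # frame the camera around the full data extent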
+ """ + if not self._scene.children or self._camera is None: + return + + cam = self._camera + cam.show_object(self._scene) + + width, height, depth = np.ptp(self._scene.get_world_bounding_box(), axis=0) + if width < 0.01: + width = 1 + if height < 0.01: + height = 1 + cam.width = width + cam.height = height + cam.zoom = 1 - margin + self.refresh() + + def refresh(self) -> None: + self._canvas.update() + self._canvas.request_draw(self._animate) + + def _animate(self) -> None: + self._renderer.render(self._scene, self._camera) + + # def _on_mouse_move(self, event: SceneMouseEvent) -> None: + # """Mouse moved on the canvas, display the pixel value and position.""" + # images = [] + # # Get the images the mouse is over + # seen = set() + # while visual := self._canvas.visual_at(event.pos): + # if isinstance(visual, scene.visuals.Image): + # images.append(visual) + # visual.interactive = False + # seen.add(visual) + # for visual in seen: + # visual.interactive = True + # if not images: + # return + + # tform = images[0].get_transform("canvas", "visual") + # px, py, *_ = (int(x) for x in tform.map(event.pos)) + # text = f"[{py}, {px}]" + # for c, img in enumerate(images): + # with suppress(IndexError): + # text += f" c{c}: {img._data[py, px]}" + # self._set_info(text) diff --git a/src/ndv/viewer2/_backends/_vispy.py b/src/ndv/viewer2/_backends/_vispy.py new file mode 100644 index 0000000..1460fa8 --- /dev/null +++ b/src/ndv/viewer2/_backends/_vispy.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import warnings +from contextlib import suppress +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import numpy as np +import vispy +import vispy.gloo +import vispy.scene +import vispy.visuals +from superqt.utils import qthrottled +from vispy import scene +from vispy.util.quaternion import Quaternion + +if TYPE_CHECKING: + import cmap + import vispy.gloo.glir + import vispy.gloo.texture + from qtpy.QtWidgets import QWidget + from vispy.scene.events import SceneMouseEvent + +turn = np.sin(np.pi / 4) +DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) + + +class VispyImageHandle: + def __init__(self, visual: scene.Image | scene.Volume) -> None: + self._visual = visual + self._ndim = 2 if isinstance(visual, scene.Image) else 3 + + @property + def data(self) -> np.ndarray: + try: + return self._visual._data # type: ignore [no-any-return] + except AttributeError: + return self._visual._last_data # type: ignore [no-any-return] + + @data.setter + def data(self, data: np.ndarray) -> None: + if not data.ndim == self._ndim: + warnings.warn( + f"Got wrong number of dimensions ({data.ndim}) for vispy " + f"visual of type {type(self._visual)}.", + stacklevel=2, + ) + return + self._visual.set_data(data) + + def clear(self) -> None: + offset = (0,) * self.data.ndim + self.directly_set_texture_offset( + np.zeros(self.data.shape, dtype=self.data.dtype), offset + ) + + def directly_set_texture_offset(self, data: np.ndarray, offset: tuple) -> None: + """LOW-LEVEL: Set the texture data at offset directly. + + We are bypassing all data transformations and checks here, so data *must* be + the correct shape and dtype. 
+ """ + if self._ndim == 3: + if data.ndim == 3: + data = data.reshape((*data.shape, 1)) + texture = cast("vispy.gloo.texture.Texture3D", self._visual._texture) + queue = cast("vispy.gloo.glir.GlirQueue", texture._glir) + queue.command("DATA", texture._id, offset, data) + + @property + def visible(self) -> bool: + return bool(self._visual.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._visual.visible = visible + + @property + def clim(self) -> Any: + return self._visual.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + with suppress(ZeroDivisionError): + self._visual.clim = clims + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap + self._visual.cmap = cmap.to_vispy() + + @property + def transform(self) -> np.ndarray: + raise NotImplementedError + + @transform.setter + def transform(self, transform: np.ndarray) -> None: + raise NotImplementedError + + def remove(self) -> None: + self._visual.parent = None + + +class VispyViewerCanvas: + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + self._canvas = scene.SceneCanvas() + self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) + self._current_shape: tuple[int, ...] = () + self._last_state: dict[Literal[2, 3], Any] = {} + + central_wdg: scene.Widget = self._canvas.central_widget + self._view: scene.ViewBox = central_wdg.add_view() + self._ndim: Literal[2, 3] | None = None + + @property + def _camera(self) -> vispy.scene.cameras.BaseCamera: + return self._view.camera + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim == self._ndim: + return + elif self._ndim is not None: + # remember the current state before switching to the new camera + self._last_state[self._ndim] = self._camera.get_state() + + self._ndim = ndim + if ndim == 3: + cam = scene.ArcballCamera(fov=0) + # this sets the initial view similar to what the panzoom view would have. 
+ cam._quaternion = DEFAULT_QUATERNION + else: + cam = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + + # restore the previous state if it exists + if state := self._last_state.get(ndim): + cam.set_state(state) + self._view.camera = cam + + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas.native) + + def refresh(self) -> None: + self._canvas.update() + + def add_image( + self, + data: np.ndarray | None = None, + cmap: cmap.Colormap | None = None, + offset: tuple[float, float] | None = None, # (Y, X) + ) -> VispyImageHandle: + """Add a new Image node to the scene.""" + img = scene.visuals.Image(data, parent=self._view.scene) + img.set_gl_state("additive", depth_test=False) + img.interactive = True + + if offset: + img.transform = scene.STTransform(translate=offset[::-1]) + + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() + handle = VispyImageHandle(img) + if cmap is not None: + handle.cmap = cmap + return handle + + def add_volume( + self, + data: np.ndarray | None = None, + cmap: cmap.Colormap | None = None, + offset: tuple[float, float, float] | None = None, # (Z, Y, X) + ) -> VispyImageHandle: + vol = scene.visuals.Volume( + data, + parent=self._view.scene, + interpolation="nearest", + texture_format="auto", + ) + vol.set_gl_state("additive", depth_test=False) + vol.interactive = True + if offset: + vol.transform = scene.STTransform(translate=offset[::-1]) + + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if len(prev_shape) != 3: + self.set_range() + handle = VispyImageHandle(vol) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0.01, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + if len(self._current_shape) >= 2: + if x is None: + x = (0, self._current_shape[-1]) + if y is None: + y = (0, self._current_shape[-2]) + if z is None and len(self._current_shape) == 3: + z = (0, self._current_shape[-3]) + is_3d = isinstance(self._camera, scene.ArcballCamera) + if is_3d: + self._camera._quaternion = DEFAULT_QUATERNION + self._view.camera.set_range(x=x, y=y, z=z, margin=margin) + if is_3d: + max_size = max(self._current_shape) + self._camera.scale_factor = max_size + 6 + + def _on_mouse_move(self, event: SceneMouseEvent) -> None: + """Mouse moved on the canvas, display the pixel value and position.""" + images = [] + # Get the images the mouse is over + # FIXME: this is narsty ... 
there must be a better way to do this + seen = set() + try: + while visual := self._canvas.visual_at(event.pos): + if isinstance(visual, scene.visuals.Image): + images.append(visual) + visual.interactive = False + seen.add(visual) + except Exception: + return + for visual in seen: + visual.interactive = True + if not images: + return + + tform = images[0].get_transform("canvas", "visual") + px, py, *_ = (int(x) for x in tform.map(event.pos)) + text = f"[{py}, {px}]" + for c, img in enumerate(reversed(images)): + with suppress(IndexError): + value = img._data[py, px] + if isinstance(value, (np.floating, float)): + value = f"{value:.2f}" + text += f" {c}: {value}" + self._set_info(text) diff --git a/src/ndv/viewer/_backends/protocols.py b/src/ndv/viewer2/_backends/protocols.py similarity index 100% rename from src/ndv/viewer/_backends/protocols.py rename to src/ndv/viewer2/_backends/protocols.py diff --git a/src/ndv/viewer2/_components.py b/src/ndv/viewer2/_components.py new file mode 100644 index 0000000..68d0675 --- /dev/null +++ b/src/ndv/viewer2/_components.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from enum import Enum +from pathlib import Path + +from qtpy.QtCore import QSize +from qtpy.QtGui import QMovie +from qtpy.QtWidgets import QLabel, QPushButton, QWidget +from superqt import QIconifyIcon + +SPIN_GIF = str(Path(__file__).parent / "spin.gif") + + +class DimToggleButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + icn = QIconifyIcon("f7:view-2d", color="#333333") + icn.addKey("f7:view-3d", state=QIconifyIcon.State.On, color="white") + super().__init__(icn, "", parent) + self.setCheckable(True) + self.setChecked(True) + + +class QSpinner(QLabel): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + size = QSize(16, 16) + mov = QMovie(SPIN_GIF, parent=self) + self.setFixedSize(size) + mov.setScaledSize(size) + mov.setSpeed(150) + mov.start() + self.setMovie(mov) + self.hide() + + +class ChannelMode(str, Enum): + COMPOSITE = "composite" + MONO = "mono" + + def __str__(self) -> str: + return self.value + + +class ChannelModeButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self.setCheckable(True) + self.toggled.connect(self.next_mode) + + # set minimum width to the width of the larger string 'composite' + self.setMinimumWidth(92) # magic number :/ + + def next_mode(self) -> None: + if self.isChecked(): + self.setMode(ChannelMode.MONO) + else: + self.setMode(ChannelMode.COMPOSITE) + + def mode(self) -> ChannelMode: + return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE + + def setMode(self, mode: ChannelMode) -> None: + # we show the name of the next mode, not the current one + other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO + self.setText(str(other)) + self.setChecked(mode == ChannelMode.MONO) diff --git a/src/ndv/viewer2/_data_wrapper.py b/src/ndv/viewer2/_data_wrapper.py new file mode 100644 index 0000000..366e34f --- /dev/null +++ b/src/ndv/viewer2/_data_wrapper.py @@ -0,0 +1,418 @@ +"""In this module, we provide built-in support for many array types.""" + +from __future__ import annotations + +import logging +import sys +import warnings +from abc import abstractmethod +from contextlib import suppress +from typing import ( + TYPE_CHECKING, + ClassVar, + Container, + Generic, + Hashable, + Iterator, + Mapping, + Sequence, + TypeVar, +) + +import numpy as np + +if TYPE_CHECKING: + from pathlib import Path + from 
typing import Any, Protocol, TypeAlias, TypeGuard + + import dask.array as da + import numpy.typing as npt + import pyopencl.array as cl_array + import sparse + import tensorstore as ts + import torch + import xarray as xr + import zarr + from torch._tensor import Tensor + + from ._dims_slider import Index, Indices, Sizes + + _T_contra = TypeVar("_T_contra", contravariant=True) + + class SupportsIndexing(Protocol): + def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... + @property + def shape(self) -> tuple[int, ...]: ... + + class SupportsDunderLT(Protocol[_T_contra]): + def __lt__(self, other: _T_contra, /) -> bool: ... + + class SupportsDunderGT(Protocol[_T_contra]): + def __gt__(self, other: _T_contra, /) -> bool: ... + + SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any] + + +ArrayT = TypeVar("ArrayT") +_T = TypeVar("_T", bound=type) + +# Global executor for slice requests +# _EXECUTOR = ThreadPoolExecutor(max_workers=2) + + +def _recurse_subclasses(cls: _T) -> Iterator[_T]: + for subclass in cls.__subclasses__(): + yield subclass + yield from _recurse_subclasses(subclass) + + +class DataWrapper(Generic[ArrayT]): + """Interface for wrapping different array-like data types. + + `DataWrapper.create` is a factory method that returns a DataWrapper instance + for the given data type. If your datastore type is not supported, you may implement + a new DataWrapper subclass to handle your data type. To do this, import and + subclass DataWrapper, and (minimally) implement the supports and isel methods. + Ensure that your class is imported before the DataWrapper.create method is called, + and it will be automatically detected and used to wrap your data. + """ + + # Order in which subclasses are checked for support. + # Lower numbers are checked first, and the first supporting subclass is used. + # Default is 50, and fallback to numpy-like duckarray is 100. + # Subclasses can override this to change the priority in which they are checked + PRIORITY: ClassVar[SupportsRichComparison] = 50 + # These names will be checked when looking for a channel axis + COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c") + # Maximum dimension size consider when guessing the channel axis + MAX_CHANNELS = 16 + + @classmethod + def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if isinstance(data, DataWrapper): + return data + # check subclasses for support + # This allows users to define their own DataWrapper subclasses which will + # be automatically detected (assuming they have been imported by this point) + for subclass in sorted(_recurse_subclasses(cls), key=lambda x: x.PRIORITY): + with suppress(Exception): + if subclass.supports(data): + logging.debug(f"Using {subclass.__name__} to wrap {type(data)}") + return subclass(data) + raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + + @classmethod + @abstractmethod + def supports(cls, obj: Any) -> bool: + """Return True if this wrapper can handle the given object. + + Any exceptions raised by this method will be suppressed, so it is safe to + directly import necessary dependencies without a try/except block. 
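+
+        Examples
+        --------
+        A minimal sketch for a hypothetical array type ``MyArray`` from a
+        hypothetical ``my_array_lib`` package::
+
+            @classmethod
+            def supports(cls, obj: Any) -> bool:
+                import my_array_lib  # safe: create() suppresses exceptions
+                return isinstance(obj, my_array_lib.MyArray)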
+ """ + raise NotImplementedError + + def __init__(self, data: ArrayT) -> None: + self._data = data + self._name2index: dict[str, int] = {} + if names := self.dimension_names(): + self._name2index = {name: i for i, name in enumerate(names)} + + @property + def data(self) -> ArrayT: + return self._data + + # @abstractmethod + # def isel(self, indexers: Indices) -> np.ndarray: + # """Select a slice from a data store using (possibly) named indices. + + # This follows the xarray-style indexing, where indexers is a mapping of + # dimension names to indices or slices. Subclasses should implement this + # method to return a numpy array. + # """ + # raise NotImplementedError + + def shape(self) -> tuple[int, ...]: + return self._data.shape # type: ignore + + def __getitem__(self, index: tuple[int | slice, ...]) -> np.ndarray: + # reimplement in subclasses + return np.asarray(self._data[index]) # type: ignore [index] + + def chunks(self) -> tuple[int, ...] | int | None: + if chunks := getattr(self._data, "chunks", None): + if isinstance(chunks, Sequence) and all(isinstance(x, int) for x in chunks): + return tuple(chunks) + warnings.warn( + f"Unexpected chunks attribute: {chunks!r}. Ignoring.", stacklevel=2 + ) + return None + + def dimension_names(self) -> tuple[str, ...] | None: + """Return the names of the dimensions of the data.""" + return None + + def to_conventional(self, indexers: Indices) -> tuple[int | slice, ...]: + """Convert named indices to a tuple of integers and slices.""" + _indexers = {self._name2index.get(str(k), k): v for k, v in indexers.items()} + return tuple(_indexers.get(k, slice(None)) for k in range(len(self.shape()))) + + # def isel_async( + # self, indexers: list[Indices] + # ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + # """Asynchronous version of isel.""" + # return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) + + def guess_channel_axis(self) -> int | None: + """Return the (best guess) axis name for the channel dimension.""" + # for arrays with labeled dimensions, + # see if any of the dimensions are named "channel" + shape = self.shape() + if names := self.dimension_names(): + for ax, name in enumerate(names): + if ( + name.lower() in self.COMMON_CHANNEL_NAMES + and shape[ax] <= self.MAX_CHANNELS + ): + return ax + + # for shaped arrays, use the smallest dimension as the channel axis + with suppress(ValueError): + if (smallest_dim := min(shape)) <= self.MAX_CHANNELS: + return shape.index(smallest_dim) + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + raise NotImplementedError("save_as_zarr not implemented for this data type.") + + def sizes(self) -> Sizes: + """Return a mapping of {dimkey: size} for the data. + + The default implementation uses the shape attribute of the data, and + tries to find dimension names in the `dims`, `names`, or `labels` attributes. + (`dims` is used by xarray, `names` is used by torch, etc...). If no labels + are found, the dimensions are just named by their integer index. 
+ """ + shape = self.shape() + if (names := self.dimension_names()) and len(names) == len(shape): + return dict(zip(names, shape)) + return dict(enumerate(shape)) + + def summary_info(self) -> str: + """Return info label with information about the data.""" + package = getattr(self._data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" + + if sizes := self.sizes(): + # if all of the dimension keys are just integers, omit them from size_str + if all(isinstance(x, int) for x in sizes): + size_str = repr(tuple(sizes.values())) + # otherwise, include the keys in the size_str + else: + size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(self._data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(self._data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + return info + + +class XarrayWrapper(DataWrapper["xr.DataArray"]): + """Wrapper for xarray DataArray objects.""" + + def isel(self, indexers: Indices) -> np.ndarray: + return np.asarray(self._data.isel(indexers)) + + def sizes(self) -> Mapping[Hashable, int]: + return {k: int(v) for k, v in self._data.sizes.items()} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: + if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(save_loc) + + +class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): + """Wrapper for tensorstore.TensorStore objects.""" + + def __init__(self, data: Any) -> None: + super().__init__(data) + import tensorstore as ts + + self._ts = ts + + def sizes(self) -> Mapping[Hashable, int]: + return {dim.label: dim.size for dim in self._data.domain} + + def isel(self, indexers: Indices) -> np.ndarray: + result = ( + self._data[self._ts.d[tuple(indexers)][tuple(indexers.values())]] + .read() + .result() + ) + return np.asarray(result) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + +class ArrayLikeWrapper(DataWrapper, Generic[ArrayT]): + """Wrapper for numpy duck array-like objects.""" + + PRIORITY = 100 + + def isel(self, indexers: Indices) -> np.ndarray: + idx = [] + for k in range(len(self._data.shape)): + i = indexers.get(k, slice(None)) + if isinstance(i, int): + idx.extend([i, None]) + else: + idx.append(i) + + return self._asarray(self._data[tuple(idx)]) + + def _asarray(self, data: npt.ArrayLike) -> np.ndarray: + return np.asarray(data) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + or hasattr(obj, "__array__") + ) + and hasattr(obj, "__getitem__") + and hasattr(obj, "shape") + ): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + try: + import zarr + except ImportError: + raise ImportError("zarr is required to save this data type.") from None + + if isinstance(self._data, zarr.Array): + self._data.store = zarr.DirectoryStore(save_loc) + else: + zarr.save(str(save_loc), self._data) + + +class DaskWrapper(DataWrapper["da.Array"]): + """Wrapper for dask array objects.""" + + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) 
+ return np.asarray(self._data[idx].compute()) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(url=str(save_loc)) + + +class CLArrayWrapper(ArrayLikeWrapper["cl_array.Array"]): + """Wrapper for pyopencl array objects.""" + + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[cl_array.Array]: + if (cl_array := sys.modules.get("pyopencl.array")) and isinstance( + obj, cl_array.Array + ): + return True + return False + + def _asarray(self, data: cl_array.Array) -> np.ndarray: + return np.asarray(data.get()) + + +class SparseArrayWrapper(ArrayLikeWrapper["sparse.Array"]): + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[sparse.COO]: + if (sparse := sys.modules.get("sparse")) and isinstance(obj, sparse.COO): + return True + return False + + def _asarray(self, data: sparse.COO) -> np.ndarray: + return np.asarray(data.todense()) + + +class ZarrArrayWrapper(ArrayLikeWrapper["zarr.Array"]): + """Wrapper for zarr array objects.""" + + PRIORITY = 50 + + def __init__(self, data: Any) -> None: + super().__init__(data) + self._name2index: dict[Hashable, int] + if "_ARRAY_DIMENSIONS" in data.attrs: + self._name2index = { + name: i for i, name in enumerate(data.attrs["_ARRAY_DIMENSIONS"]) + } + else: + self._name2index = {i: i for i in range(data.ndim)} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[zarr.Array]: + if (zarr := sys.modules.get("zarr")) and isinstance(obj, zarr.Array): + return True + return False + + def sizes(self) -> Sizes: + return dict(zip(self._name2index, self.data.shape)) + + def isel(self, indexers: Indices) -> np.ndarray: + # convert possibly named indices to integer indices + real_indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + return super().isel(real_indexers) + + +class TorchTensorWrapper(DataWrapper["torch.Tensor"]): + """Wrapper for torch tensor objects.""" + + def __init__(self, data: Tensor) -> None: + super().__init__(data) + self._name2index: dict[Hashable, int] + if names := getattr(data, "names", None): + # names may be something like (None, None, None)... 
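+            # map each named dim to its position; unnamed dims keep their integer index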
+ self._name2index = { + (i if name is None else name): i for i, name in enumerate(names) + } + else: + self._name2index = {i: i for i in range(data.ndim)} + + def sizes(self) -> Sizes: + return dict(zip(self._name2index, self.data.shape)) + + def isel(self, indexers: Indices) -> np.ndarray: + # convert possibly named indices to integer indices + real_indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + # convert to tuple of slices + idx = tuple(real_indexers.get(i, slice(None)) for i in range(self.data.ndim)) + return self.data[idx].numpy(force=True) # type: ignore [no-any-return] + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[torch.Tensor]: + if (torch := sys.modules.get("torch")) and isinstance(obj, torch.Tensor): + return True + return False diff --git a/src/ndv/viewer2/_dims_slider.py b/src/ndv/viewer2/_dims_slider.py new file mode 100644 index 0000000..69db980 --- /dev/null +++ b/src/ndv/viewer2/_dims_slider.py @@ -0,0 +1,529 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast +from warnings import warn + +from qtpy.QtCore import QPoint, QPointF, QSize, Qt, Signal +from qtpy.QtGui import QCursor, QResizeEvent +from qtpy.QtWidgets import ( + QDialog, + QDoubleSpinBox, + QFormLayout, + QFrame, + QHBoxLayout, + QLabel, + QPushButton, + QSizePolicy, + QSlider, + QSpinBox, + QVBoxLayout, + QWidget, +) +from superqt import QLabeledRangeSlider +from superqt.iconify import QIconifyIcon +from superqt.utils import signals_blocked + +if TYPE_CHECKING: + from typing import Hashable, Mapping, TypeAlias + + from qtpy.QtGui import QResizeEvent + + # any hashable represent a single dimension in an ND array + DimKey: TypeAlias = Hashable + # any object that can be used to index a single dimension in an ND array + Index: TypeAlias = int | slice + # a mapping from dimension keys to indices (eg. 
{"x": 0, "y": slice(5, 10)}) + # this object is used frequently to query or set the currently displayed slice + Indices: TypeAlias = Mapping[DimKey, Index] + # mapping of dimension keys to the maximum value for that dimension + Sizes: TypeAlias = Mapping[DimKey, int] + + +SS = """ +QSlider::groove:horizontal { + height: 15px; + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(128, 128, 128, 0.25), + stop:1 rgba(128, 128, 128, 0.1) + ); + border-radius: 3px; +} + +QSlider::handle:horizontal { + width: 38px; + background: #999999; + border-radius: 3px; +} + +QLabel { font-size: 12px; } + +QRangeSlider { qproperty-barColor: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(100, 80, 120, 0.2), + stop:1 rgba(100, 80, 120, 0.4) + )} + +SliderLabel { + font-size: 12px; + color: white; +} +""" + + +class QtPopup(QDialog): + """A generic popup window.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setModal(False) # if False, then clicking anywhere else closes it + self.setWindowFlags(Qt.WindowType.Popup | Qt.WindowType.FramelessWindowHint) + + self.frame = QFrame(self) + layout = QVBoxLayout(self) + layout.addWidget(self.frame) + layout.setContentsMargins(0, 0, 0, 0) + + def show_above_mouse(self, *args: Any) -> None: + """Show popup dialog above the mouse cursor position.""" + pos = QCursor().pos() # mouse position + szhint = self.sizeHint() + pos -= QPoint(szhint.width() // 2, szhint.height() + 14) + self.move(pos) + self.resize(self.sizeHint()) + self.show() + + +class PlayButton(QPushButton): + """Just a styled QPushButton that toggles between play and pause icons.""" + + fpsChanged = Signal(float) + + PLAY_ICON = "bi:play-fill" + PAUSE_ICON = "bi:pause-fill" + + def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.PLAY_ICON, color="#888888") + icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") + super().__init__(icn, "", parent) + self.spin = QDoubleSpinBox(self) + self.spin.setRange(0.5, 100) + self.spin.setValue(fps) + self.spin.valueChanged.connect(self.fpsChanged) + self.setCheckable(True) + self.setFixedSize(14, 18) + self.setIconSize(QSize(16, 16)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + self._popup = QtPopup(self) + form = QFormLayout(self._popup.frame) + form.setContentsMargins(6, 6, 6, 6) + form.addRow("FPS", self.spin) + + def mousePressEvent(self, e: Any) -> None: + if e and e.button() == Qt.MouseButton.RightButton: + self._show_fps_dialog(e.globalPosition()) + else: + super().mousePressEvent(e) + + def _show_fps_dialog(self, pos: QPointF) -> None: + self._popup.show_above_mouse() + + +class LockButton(QPushButton): + LOCK_ICON = "uis:unlock" + UNLOCK_ICON = "uis:lock" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.LOCK_ICON, color="#888888") + icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On, color="red") + super().__init__(icn, text, parent) + self.setCheckable(True) + self.setFixedSize(20, 20) + self.setIconSize(QSize(14, 14)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + +class DimsSlider(QWidget): + """A single slider in the DimsSliders widget. + + Provides a play/pause button that toggles animation of the slider value. + Has a QLabeledSlider for the actual value. + Adds a label for the maximum value (e.g. 
"3 / 10") + """ + + valueChanged = Signal(object, object) # where object is int | slice + + def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setStyleSheet(SS) + self._slice_mode = False + self._dim_key = dimension_key + + self._timer_id: int | None = None # timer for play button + self._play_btn = PlayButton(parent=self) + self._play_btn.fpsChanged.connect(self.set_fps) + self._play_btn.toggled.connect(self._toggle_animation) + + self._dim_key = dimension_key + self._dim_label = QLabel(str(dimension_key)) + self._dim_label.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred + ) + self._dim_label.setToolTip("Double-click to toggle slice mode") + + # note, this lock button only prevents the slider from updating programmatically + # using self.setValue, it doesn't prevent the user from changing the value. + self._lock_btn = LockButton(parent=self) + + self._pos_label = QSpinBox(self) + self._pos_label.valueChanged.connect(self._on_pos_label_edited) + self._pos_label.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) + self._pos_label.setAlignment(Qt.AlignmentFlag.AlignRight) + self._pos_label.setStyleSheet( + "border: none; padding: 0; margin: 0; background: transparent" + ) + self._out_of_label = QLabel(self) + + self._int_slider = QSlider(Qt.Orientation.Horizontal) + self._int_slider.rangeChanged.connect(self._on_range_changed) + self._int_slider.valueChanged.connect(self._on_int_value_changed) + + self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) + slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) + slc.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + slc.setVisible(False) + slc.rangeChanged.connect(self._on_range_changed) + slc.valueChanged.connect(self._on_slice_value_changed) + + self.installEventFilter(self) + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(2) + layout.addWidget(self._play_btn) + layout.addWidget(self._dim_label) + layout.addWidget(self._int_slider) + layout.addWidget(self._slice_slider) + layout.addWidget(self._pos_label) + layout.addWidget(self._out_of_label) + layout.addWidget(self._lock_btn) + self.setMinimumHeight(26) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + if isinstance(par := self.parent(), DimsSliders): + par.resizeEvent(None) + + def mouseDoubleClickEvent(self, a0: Any) -> None: + self._set_slice_mode(not self._slice_mode) + super().mouseDoubleClickEvent(a0) + + def containMaximum(self, max_val: int) -> None: + if max_val > self._int_slider.maximum(): + self._int_slider.setMaximum(max_val) + if max_val > self._slice_slider.maximum(): + self._slice_slider.setMaximum(max_val) + + def setMaximum(self, max_val: int) -> None: + self._int_slider.setMaximum(max_val) + self._slice_slider.setMaximum(max_val) + + def setMinimum(self, min_val: int) -> None: + self._int_slider.setMinimum(min_val) + self._slice_slider.setMinimum(min_val) + + def containMinimum(self, min_val: int) -> None: + if min_val < self._int_slider.minimum(): + self._int_slider.setMinimum(min_val) + if min_val < self._slice_slider.minimum(): + self._slice_slider.setMinimum(min_val) + + def setRange(self, min_val: int, max_val: int) -> None: + self._int_slider.setRange(min_val, max_val) + self._slice_slider.setRange(min_val, max_val) + + def value(self) -> Index: + if not self._slice_mode: + return self._int_slider.value() # type: ignore + start, *_, stop = cast("tuple[int, ...]", 
self._slice_slider.value()) + if start == stop: + return start + return slice(start, stop) + + def setValue(self, val: Index) -> None: + # variant of setValue that always updates the maximum + self._set_slice_mode(isinstance(val, slice)) + if self._lock_btn.isChecked(): + return + if isinstance(val, slice): + start = int(val.start) if val.start is not None else 0 + stop = ( + int(val.stop) if val.stop is not None else self._slice_slider.maximum() + ) + self._slice_slider.setValue((start, stop)) + else: + self._int_slider.setValue(val) + # self._slice_slider.setValue((val, val + 1)) + + def forceValue(self, val: Index) -> None: + """Set value and increase range if necessary.""" + if isinstance(val, slice): + if isinstance(val.start, int): + self.containMinimum(val.start) + if isinstance(val.stop, int): + self.containMaximum(val.stop) + else: + self.containMinimum(val) + self.containMaximum(val) + self.setValue(val) + + def _set_slice_mode(self, mode: bool = True) -> None: + if mode == self._slice_mode: + return + self._slice_mode = bool(mode) + self._slice_slider.setVisible(self._slice_mode) + self._int_slider.setVisible(not self._slice_mode) + # self._pos_label.setVisible(not self._slice_mode) + self.valueChanged.emit(self._dim_key, self.value()) + + def set_fps(self, fps: float) -> None: + self._play_btn.spin.setValue(fps) + self._toggle_animation(self._play_btn.isChecked()) + + def _toggle_animation(self, checked: bool) -> None: + if checked: + if self._timer_id is not None: + self.killTimer(self._timer_id) + interval = int(1000 / self._play_btn.spin.value()) + self._timer_id = self.startTimer(interval) + elif self._timer_id is not None: + self.killTimer(self._timer_id) + self._timer_id = None + + def timerEvent(self, event: Any) -> None: + """Handle timer event for play button, move to the next frame.""" + # TODO + # for now just increment the value by 1, but we should be able to + # take FPS into account better and skip additional frames if the timerEvent + # is delayed for some reason. 
+ inc = 1 + if self._slice_mode: + val = cast(tuple[int, int], self._slice_slider.value()) + next_val = [v + inc for v in val] + if next_val[1] > self._slice_slider.maximum(): + # wrap around, without going below the min handle + next_val = [v - val[0] for v in val] + self._slice_slider.setValue(next_val) + else: + ival = self._int_slider.value() + ival = (ival + inc) % (self._int_slider.maximum() + 1) + self._int_slider.setValue(ival) + + def _on_pos_label_edited(self) -> None: + if self._slice_mode: + self._slice_slider.setValue( + (self._slice_slider.value()[0], self._pos_label.value()) + ) + else: + self._int_slider.setValue(self._pos_label.value()) + + def _on_range_changed(self, min: int, max: int) -> None: + self._out_of_label.setText(f"| {max}") + self._pos_label.setRange(min, max) + self.resizeEvent(None) + self.setVisible(min != max) + + def setVisible(self, visible: bool) -> None: + if self._has_no_range(): + visible = False + super().setVisible(visible) + + def _has_no_range(self) -> bool: + if self._slice_mode: + return bool(self._slice_slider.minimum() == self._slice_slider.maximum()) + return bool(self._int_slider.minimum() == self._int_slider.maximum()) + + def _on_int_value_changed(self, value: int) -> None: + self._pos_label.setValue(value) + if not self._slice_mode: + self.valueChanged.emit(self._dim_key, value) + + def _on_slice_value_changed(self, value: tuple[int, int]) -> None: + self._pos_label.setValue(int(value[1])) + with signals_blocked(self._int_slider): + self._int_slider.setValue(int(value[0])) + if self._slice_mode: + self.valueChanged.emit(self._dim_key, slice(value[0], value[1] + 1)) + + +class DimsSliders(QWidget): + """A Collection of DimsSlider widgets for each dimension in the data. + + Maintains the global current index and emits a signal when it changes. + """ + + valueChanged = Signal(dict) # dict is of type Indices + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._locks_visible: bool | Mapping[DimKey, bool] = False + self._sliders: dict[DimKey, DimsSlider] = {} + self._current_index: dict[DimKey, Index] = {} + self._invisible_dims: set[DimKey] = set() + + self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + def __contains__(self, key: DimKey) -> bool: + """Return True if the dimension key is present in the DimsSliders.""" + return key in self._sliders + + def slider(self, key: DimKey) -> DimsSlider: + """Return the DimsSlider widget for the given dimension key.""" + return self._sliders[key] + + def value(self) -> Indices: + """Return mapping of {dim_key -> current index} for each dimension.""" + return self._current_index.copy() + + def setValue(self, values: Indices) -> None: + """Set the current index for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int | slice] + Mapping of {dim_key -> index} for each dimension. If value is a slice, + the slider will be in slice mode. If the dimension is not present in the + DimsSliders, it will be added. + """ + if self._current_index == values: + return + with signals_blocked(self): + for dim, index in values.items(): + self.add_or_update_dimension(dim, index) + # FIXME: i don't know why this this is ever empty ... 
only happens on pyside6 + if val := self.value(): + self.valueChanged.emit(val) + + def minima(self) -> Sizes: + """Return mapping of {dim_key -> minimum value} for each dimension.""" + return {k: v._int_slider.minimum() for k, v in self._sliders.items()} + + def setMinima(self, values: Sizes) -> None: + """Set the minimum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> minimum value} for each dimension. + """ + for name, min_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMinimum(min_val) + + def maxima(self) -> Sizes: + """Return mapping of {dim_key -> maximum value} for each dimension.""" + return {k: v._int_slider.maximum() for k, v in self._sliders.items()} + + def setMaxima(self, values: Sizes) -> None: + """Set the maximum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> maximum value} for each dimension. + """ + for name, max_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMaximum(max_val) + + def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: + """Set the visibility of the lock buttons for all dimensions.""" + self._locks_visible = visible + for dim, slider in self._sliders.items(): + viz = visible if isinstance(visible, bool) else visible.get(dim, False) + slider._lock_btn.setVisible(viz) + + def add_dimension(self, key: DimKey, val: Index | None = None) -> None: + """Add a new dimension to the DimsSliders widget. + + Parameters + ---------- + key : Hashable + The name of the dimension. + val : int | slice, optional + The initial value for the dimension. If a slice, the slider will be in + slice mode. + """ + self._sliders[key] = slider = DimsSlider(dimension_key=key, parent=self) + if isinstance(self._locks_visible, dict) and key in self._locks_visible: + slider._lock_btn.setVisible(self._locks_visible[key]) + else: + slider._lock_btn.setVisible(bool(self._locks_visible)) + + val_int = val.start if isinstance(val, slice) else val + slider.setVisible(key not in self._invisible_dims) + if isinstance(val_int, int): + slider.setRange(val_int, val_int) + elif isinstance(val_int, slice): + slider.setRange(val_int.start or 0, val_int.stop or 1) + + val = val if val is not None else 0 + self._current_index[key] = val + slider.forceValue(val) + slider.valueChanged.connect(self._on_dim_slider_value_changed) + cast("QVBoxLayout", self.layout()).addWidget(slider) + + def set_dimension_visible(self, key: DimKey, visible: bool) -> None: + """Set the visibility of a dimension in the DimsSliders widget. + + Once a dimension is hidden, it will not be shown again until it is explicitly + made visible again with this method. 
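+
+        Examples
+        --------
+        Illustrative, assuming ``sliders`` already has a slider for a ``"z"`` dimension::
+
+            sliders.set_dimension_visible("z", False)  # hide the z slider
+            sliders.set_dimension_visible("z", True)   # show it again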
+ """ + if visible: + self._invisible_dims.discard(key) + if key in self._sliders: + self._current_index[key] = self._sliders[key].value() + else: + self._invisible_dims.add(key) + self._current_index.pop(key, None) + if key in self._sliders: + self._sliders[key].setVisible(visible) + + def remove_dimension(self, key: DimKey) -> None: + """Remove a dimension from the DimsSliders widget.""" + try: + slider = self._sliders.pop(key) + except KeyError: + warn(f"Dimension {key} not found in DimsSliders", stacklevel=2) + return + cast("QVBoxLayout", self.layout()).removeWidget(slider) + slider.deleteLater() + + def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: + self._current_index[key] = value + self.valueChanged.emit(self.value()) + + def add_or_update_dimension(self, key: DimKey, value: Index) -> None: + """Add a dimension if it doesn't exist, otherwise update the value.""" + if key in self._sliders: + self._sliders[key].forceValue(value) + else: + self.add_dimension(key, value) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + # align all labels + if sliders := list(self._sliders.values()): + for lbl in ("_dim_label", "_pos_label", "_out_of_label"): + lbl_width = max(getattr(s, lbl).sizeHint().width() for s in sliders) + for s in sliders: + getattr(s, lbl).setFixedWidth(lbl_width) + + super().resizeEvent(a0) + + def sizeHint(self) -> QSize: + return super().sizeHint().boundedTo(QSize(9999, 0)) diff --git a/src/ndv/viewer2/_lut_control.py b/src/ndv/viewer2/_lut_control.py new file mode 100644 index 0000000..1b9c1b3 --- /dev/null +++ b/src/ndv/viewer2/_lut_control.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence, cast + +import numpy as np +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QPushButton, QWidget +from superqt import QLabeledRangeSlider +from superqt.cmap import QColormapComboBox +from superqt.utils import signals_blocked + +from ._dims_slider import SS + +if TYPE_CHECKING: + from typing import Iterable + + import cmap + + from ._backends.protocols import PImageHandle + + +class CmapCombo(QColormapComboBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") + self.setMinimumSize(120, 21) + # self.setStyleSheet("background-color: transparent;") + + def showPopup(self) -> None: + super().showPopup() + popup = self.findChild(QFrame) + popup.setMinimumWidth(self.width() + 100) + popup.move(popup.x(), popup.y() - self.height() - popup.height()) + + +class LutControl(QWidget): + def __init__( + self, + channel: Sequence[PImageHandle], + name: str = "", + parent: QWidget | None = None, + cmaplist: Iterable[Any] = (), + cmap: cmap.Colormap | None = None, + ) -> None: + super().__init__(parent) + self._channel = channel + self._name = name + + self._visible = QCheckBox(name) + self._visible.setChecked(True) + self._visible.toggled.connect(self._on_visible_changed) + + self._cmap = CmapCombo() + self._cmap.currentColormapChanged.connect(self._on_cmap_changed) + for handle in channel: + self._cmap.addColormap(handle.cmap) + for color in cmaplist: + self._cmap.addColormap(color) + if cmap is not None: + self._cmap.setCurrentColormap(cmap) + + self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + self._clims.setStyleSheet(SS) + self._clims.setHandleLabelPosition( + QLabeledRangeSlider.LabelPosition.LabelsOnHandle + ) + 
self._clims.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + self._clims.setRange(0, 2**8) + self._clims.valueChanged.connect(self._on_clims_changed) + + self._auto_clim = QPushButton("Auto") + self._auto_clim.setMaximumWidth(42) + self._auto_clim.setCheckable(True) + self._auto_clim.setChecked(True) + self._auto_clim.toggled.connect(self.update_autoscale) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._visible) + layout.addWidget(self._cmap) + layout.addWidget(self._clims) + layout.addWidget(self._auto_clim) + + self.update_autoscale() + + def autoscaleChecked(self) -> bool: + return cast("bool", self._auto_clim.isChecked()) + + def _on_clims_changed(self, clims: tuple[float, float]) -> None: + self._auto_clim.setChecked(False) + for handle in self._channel: + handle.clim = clims + + def _on_visible_changed(self, visible: bool) -> None: + for handle in self._channel: + handle.visible = visible + if visible: + self.update_autoscale() + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + for handle in self._channel: + handle.cmap = cmap + + def update_autoscale(self) -> None: + if ( + not self._auto_clim.isChecked() + or not self._visible.isChecked() + or not self._channel + ): + return + + # find the min and max values for the current channel + clims = [np.inf, -np.inf] + for handle in self._channel: + clims[0] = min(clims[0], np.nanmin(handle.data)) + clims[1] = max(clims[1], np.nanmax(handle.data)) + + mi, ma = tuple(int(x) for x in clims) + if mi != ma: + for handle in self._channel: + handle.clim = (mi, ma) + + # set the slider values to the new clims + with signals_blocked(self._clims): + self._clims.setMinimum(min(mi, self._clims.minimum())) + self._clims.setMaximum(max(ma, self._clims.maximum())) + self._clims.setValue((mi, ma)) + + +def _get_default_clim_from_data(data: np.ndarray) -> tuple[float, float]: + """Compute a reasonable clim from the min and max, taking nans into account. + + If there are no non-finite values (nan, inf, -inf) this is as fast as it can be. + Otherwise, this functions is about 3x slower. + """ + # Fast + min_value = data.min() + max_value = data.max() + + # Need more work? 
The nan-functions are slower + min_finite = np.isfinite(min_value) + max_finite = np.isfinite(max_value) + if not (min_finite and max_finite): + finite_data = data[np.isfinite(data)] + if finite_data.size: + min_value = finite_data.min() + max_value = finite_data.max() + else: + min_value = max_value = 0 # no finite values in the data + + return min_value, max_value diff --git a/src/ndv/viewer2/_octree.py b/src/ndv/viewer2/_octree.py new file mode 100644 index 0000000..db73aad --- /dev/null +++ b/src/ndv/viewer2/_octree.py @@ -0,0 +1,105 @@ +from typing import Any, Generic, Iterator, NamedTuple, TypeVar + +MAX_DEPTH = 8 + + +class Coord(NamedTuple): + x: float + y: float + z: float = 0 + + +class Bounds(NamedTuple): + x_min: float + x_max: float + y_min: float + y_max: float + z_min: float = 0 + z_max: float = 0 + + @property + def midpoint(self) -> Coord: + return Coord( + (self.x_min + self.x_max) / 2, + (self.y_min + self.y_max) / 2, + (self.z_min + self.z_max) / 2, + ) + + def isdisjoint(self, other: "Bounds") -> bool: + return ( + self.x_max < other.x_min + or self.x_min > other.x_max + or self.y_max < other.y_min + or self.y_min > other.y_max + or self.z_max < other.z_min + or self.z_min > other.z_max + ) + + def intersects(self, other: "Bounds") -> bool: + return not self.isdisjoint(other) + + def split(self) -> tuple["Bounds", ...]: + x_min, x_max, y_min, y_max, z_min, z_max = self + x_mid, y_mid, z_mid = self.midpoint + return ( + Bounds(x_min, x_mid, y_min, y_mid, z_min, z_mid), + Bounds(x_mid, x_max, y_min, y_mid, z_min, z_mid), + Bounds(x_min, x_mid, y_mid, y_max, z_min, z_mid), + Bounds(x_mid, x_max, y_mid, y_max, z_min, z_mid), + Bounds(x_min, x_mid, y_min, y_mid, z_mid, z_max), + Bounds(x_mid, x_max, y_min, y_mid, z_mid, z_max), + Bounds(x_min, x_mid, y_mid, y_max, z_mid, z_max), + Bounds(x_mid, x_max, y_mid, y_max, z_mid, z_max), + ) + + +T = TypeVar("T") + + +class OctreeNode(Generic[T]): + def __init__(self, bounds: Bounds, depth: int = 0) -> None: + # spatial bounds of this node (xmin, xmax, ymin, ymax, zmin, zmax) + self.bounds = bounds + self.children: list[OctreeNode] = [] # children of this node + self.depth = depth # depth of the node in the tree + self.data: T | None = None # placeholder for storing references to data chunks + + def is_leaf(self) -> bool: + return not self.children + + def split(self, max_depth: int = MAX_DEPTH) -> None: + if self.depth < max_depth: + self.children = [ + OctreeNode(bounds, self.depth + 1) for bounds in self.bounds.split() + ] + + def insert_data(self, chunk: T, chunk_bounds: Bounds) -> None: + """Insert data into the tree. + + If `self` is a leaf and the depth is less than the maximum depth, `self` is + split into 8 children. The data is then inserted into the first child that + intersects with the data bounds. + """ + if self.is_leaf(): + if self.depth < MAX_DEPTH: + self.split(MAX_DEPTH) + else: + self.data = chunk + return + for child in self.children: + if child.bounds.intersects(chunk_bounds): + child.insert_data(chunk, chunk_bounds) + break + + def query(self, view_bounds: Bounds) -> Iterator[T]: + """Query the tree for data chunks that intersect with the given bounds. + + Yields data chunks that intersect with the given bounds. 
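+
+        Examples
+        --------
+        A small sketch (chunk payloads here are just strings)::
+
+            root = OctreeNode(Bounds(0, 100, 0, 100))
+            root.insert_data("chunk-a", Bounds(0, 10, 0, 10))
+            assert list(root.query(Bounds(0, 50, 0, 50))) == ["chunk-a"]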
+ """ + if self.bounds.intersects(view_bounds): + if self.is_leaf(): + if self.data is not None: + yield self.data + else: + for child in self.children: + yield from child.query(view_bounds) diff --git a/src/ndv/viewer2/_save_button.py b/src/ndv/viewer2/_save_button.py new file mode 100644 index 0000000..0ce4511 --- /dev/null +++ b/src/ndv/viewer2/_save_button.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget +from superqt.iconify import QIconifyIcon + +if TYPE_CHECKING: + from ._data_wrapper import DataWrapper + + +class SaveButton(QPushButton): + def __init__( + self, + data_wrapper: DataWrapper, + parent: QWidget | None = None, + ): + super().__init__(parent=parent) + self.setIcon(QIconifyIcon("mdi:content-save")) + self.clicked.connect(self._on_click) + + self._data_wrapper = data_wrapper + self._last_loc = str(Path.home()) + + def _on_click(self) -> None: + self._last_loc, _ = QFileDialog.getSaveFileName( + self, "Choose destination", str(self._last_loc), "" + ) + suffix = Path(self._last_loc).suffix + if suffix in (".zarr", ".ome.zarr", ""): + self._data_wrapper.save_as_zarr(self._last_loc) + else: + raise ValueError(f"Unsupported file format: {self._last_loc}") diff --git a/src/ndv/viewer/_state.py b/src/ndv/viewer2/_state.py similarity index 100% rename from src/ndv/viewer/_state.py rename to src/ndv/viewer2/_state.py diff --git a/src/ndv/viewer/_v2.py b/src/ndv/viewer2/_v2.py similarity index 95% rename from src/ndv/viewer/_v2.py rename to src/ndv/viewer2/_v2.py index 4bf7c41..7cce272 100644 --- a/src/ndv/viewer/_v2.py +++ b/src/ndv/viewer2/_v2.py @@ -6,12 +6,12 @@ from superqt.utils import qthrottled from ndv._chunk_executor import Chunker, ChunkFuture -from ndv.viewer._backends import get_canvas -from ndv.viewer._dims_slider import DimsSliders -from ndv.viewer._state import ViewerState +from ndv.viewer2._backends import get_canvas +from ndv.viewer2._dims_slider import DimsSliders +from ndv.viewer2._state import ViewerState if TYPE_CHECKING: - from ndv.viewer._backends.protocols import PCanvas, PImageHandle + from ndv.viewer2._backends.protocols import PCanvas, PImageHandle class NDViewer(QWidget): @@ -72,7 +72,7 @@ def _request_data_for_index(self, index: Mapping[int | str, int | slice]) -> Non # determine chunk shape # only visualized dimensions are chunked - chunk_size = 64 # TODO: pick bettter + chunk_size = 128 # TODO: pick bettter chunk_shape: list[int | None] = [None] * ndim visualized = [self._norm_index(dim) for dim in self._state.visualized_indices] for dim in range(ndim): @@ -87,7 +87,6 @@ def _request_data_for_index(self, index: Mapping[int | str, int | slice]) -> Non if not index: return - print("requesting data for index", index, chunk_shape) # clear existing handles for handle in self._channels.values(): handle.clear() @@ -137,6 +136,8 @@ def _draw_chunk(self, future: ChunkFuture) -> None: print("err in clim: ", e) handle.clim = (0, 5000) + print(">>draw:") + print(f" data: {data.shape} @ {offset}") handle.directly_set_texture_offset(data, offset) self._canvas.refresh() diff --git a/src/ndv/viewer2/_viewer.py b/src/ndv/viewer2/_viewer.py new file mode 100644 index 0000000..a4735f9 --- /dev/null +++ b/src/ndv/viewer2/_viewer.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +from itertools import cycle +from typing import ( + TYPE_CHECKING, + Hashable, + Literal, + MutableSequence, + Sequence, + SupportsIndex, + cast, + 
overload, +) + +import cmap +import numpy as np +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread +from superqt.utils import qthrottled, signals_blocked + +from ndv._chunking import Chunker, ChunkResponse, RequestFinished +from ndv.viewer2._components import ( + ChannelMode, + ChannelModeButton, + DimToggleButton, + QSpinner, +) + +from ._backends import get_canvas +from ._backends.protocols import PImageHandle +from ._data_wrapper import DataWrapper +from ._dims_slider import DimsSliders +from ._lut_control import LutControl + +if TYPE_CHECKING: + from typing import Any, Iterable, TypeAlias + + from qtpy.QtGui import QCloseEvent + + from ._backends.protocols import PCanvas + from ._dims_slider import DimKey, Indices, Sizes + + ImgKey: TypeAlias = Hashable + # any mapping of dimensions to sizes + SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] + +MID_GRAY = "#888888" +GRAYS = cmap.Colormap("gray") +DEFAULT_COLORMAPS = [ + cmap.Colormap("green"), + cmap.Colormap("magenta"), + cmap.Colormap("cyan"), + cmap.Colormap("yellow"), + cmap.Colormap("red"), + cmap.Colormap("blue"), + cmap.Colormap("cubehelix"), + cmap.Colormap("gray"), +] +MONO_CHANNEL = -999999 + + +class Channel(MutableSequence[PImageHandle]): + def __init__(self, ch_key: int, cmap: cmap.Colormap = GRAYS) -> None: + self.ch_key = ch_key + self._handles: list[PImageHandle] = [] + self.cmap = cmap + + @overload + def __getitem__(self, i: int) -> PImageHandle: ... + @overload + def __getitem__(self, i: slice) -> list[PImageHandle]: ... + def __getitem__(self, i: int | slice) -> PImageHandle | list[PImageHandle]: + return self._handles[i] + + @overload + def __setitem__(self, i: SupportsIndex, value: PImageHandle) -> None: ... + @overload + def __setitem__(self, i: slice, value: Iterable[PImageHandle]) -> None: ... + def __setitem__( + self, i: SupportsIndex | slice, value: PImageHandle | Iterable[PImageHandle] + ) -> None: + self._handles[i] = value # type: ignore + + @overload + def __delitem__(self, i: int) -> None: ... + @overload + def __delitem__(self, i: slice) -> None: ... + def __delitem__(self, i: int | slice) -> None: + del self._handles[i] + + def __len__(self) -> int: + return len(self._handles) + + def insert(self, i: int, value: PImageHandle) -> None: + self._handles.insert(i, value) + + +class NDViewer(QWidget): + """A viewer for ND arrays. + + This widget displays a single slice from an ND array (or a composite of slices in + different colormaps). The widget provides sliders to select the slice to display, + and buttons to control the display mode of the channels. + + An important concept in this widget is the "index". The index is a mapping of + dimensions to integers or slices that define the slice of the data to display. For + example, a numpy slice of `[0, 1, 5:10]` would be represented as + `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be named, e.g. + `{'t': 0, 'c': 1, 'z': slice(5, 10)}`. The index is used to select the data from + the datastore, and to determine the position of the sliders. + + The flow of data is as follows: + + - The user sets the data using the `set_data` method. This will set the number + and range of the sliders to the shape of the data, and display the first slice. + - The user can then use the sliders to select the slice to display. 
The current + slice is defined as a `Mapping` of `{dim -> int|slice}` and can be retrieved + with the `_dims_sliders.value()` method. To programmatically set the current + position, use the `setIndex` method. This will set the values of the sliders, + which in turn will trigger the display of the new slice via the + `_request_data_for_index` method. + - `_request_data_for_index` is an asynchronous method that retrieves the data for + the given index from the datastore (using `_isel`) and queues the + `_draw_chunk` method to be called when the data is ready. The logic + for extracting data from the datastore is defined in `_data_wrapper.py`, which + handles idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). + - `_draw_chunk` is called when the data is ready, and updates the image. + Note that if the slice is multidimensional, the data will be reduced to 2D using + max intensity projection (and double-clicking on any given dimension slider will + turn it into a range slider allowing a projection to be made over that dimension). + - The image is displayed on the canvas, which is an object that implements the + `PCanvas` protocol (mostly, it has an `add_image` method that returns a handle + to the added image that can be used to update the data and display). This + small abstraction allows for various backends to be used (e.g. vispy, pygfx, etc). + + Parameters + ---------- + data : Any + The data to display. This can be any duck-like ND array, including numpy, dask, + xarray, jax, tensorstore, zarr, etc. You can add support for new datastores by + subclassing `DataWrapper` and implementing the required methods. See + `DataWrapper` for more information. + parent : QWidget, optional + The parent widget of this widget. + channel_axis : Hashable, optional + The axis that represents the channels in the data. If not provided, this will + be guessed from the data. + channel_mode : ChannelMode, optional + The initial mode for displaying the channels. If not provided, this will be + set to ChannelMode.MONO. + """ + + def __init__( + self, + data: DataWrapper | Any, + *, + colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, + parent: QWidget | None = None, + channel_axis: int | None = None, + channel_mode: ChannelMode | str = ChannelMode.MONO, + ): + super().__init__(parent=parent) + + # ATTRIBUTES ---------------------------------------------------- + + # mapping of key to a list of objects that control image nodes in the canvas + self._channels: dict[int, Channel] = {} + + # mapping of same keys to the LutControl objects control image display props + self._lut_ctrls: dict[int, LutControl] = {} + + # the set of dimensions we are currently visualizing (e.g. 
(-2, -1) for 2D) + # this is used to control which dimensions have sliders and the behavior + # of isel when selecting data from the datastore + self._visualized_dims: set[DimKey] = set() + + # the axis that represents the channels in the data + self._channel_axis = channel_axis + self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode + # colormaps that will be cycled through when displaying composite images + if colormaps is not None: + self._cmaps = [cmap.Colormap(c) for c in colormaps] + else: + self._cmaps = DEFAULT_COLORMAPS + self._cmap_cycle = cycle(self._cmaps) + + # number of dimensions to display + self._ndims: Literal[2, 3] = 2 + self._chunker = Chunker( + None, + # IMPORTANT + # chunking here will determine how non-visualized dims are reduced + # so chunkshape will need to change based on the set of visualized dims + chunks=(20, 100, 32, 32), + on_ready=self._draw_chunk, + ) + + # WIDGETS ---------------------------------------------------- + + # the button that controls the display mode of the channels + self._channel_mode_btn = ChannelModeButton(self) + self._channel_mode_btn.clicked.connect(self.set_channel_mode) + # button to reset the zoom of the canvas + self._set_range_btn = QPushButton( + QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self + ) + self._set_range_btn.clicked.connect(self._on_set_range_clicked) + + # button to change number of displayed dimensions + self._ndims_btn = DimToggleButton(self) + self._ndims_btn.clicked.connect(self._toggle_3d) + + # place to display dataset summary + self._data_info_label = QElidingLabel("", parent=self) + self._progress_spinner = QSpinner(self) + + # place to display arbitrary text + self._hover_info_label = QLabel("", self) + # the canvas that displays the images + self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText) + self._canvas.set_ndim(self._ndims) + + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders(self) + self._dims_sliders.valueChanged.connect( + qthrottled(self._request_data_for_index, 20, leading=True) + ) + + self._lut_drop = QCollapsible("LUTs", self) + self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) + self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY)) + lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) + lut_layout.setContentsMargins(0, 1, 0, 1) + lut_layout.setSpacing(0) + if ( + hasattr(self._lut_drop, "_content") + and (layout := self._lut_drop._content.layout()) is not None + ): + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # LAYOUT ----------------------------------------------------- + + self._btns = btns = QHBoxLayout() + btns.setContentsMargins(0, 0, 0, 0) + btns.setSpacing(0) + btns.addStretch() + btns.addWidget(self._channel_mode_btn) + btns.addWidget(self._ndims_btn) + btns.addWidget(self._set_range_btn) + + info = QHBoxLayout() + info.setContentsMargins(0, 0, 0, 2) + info.setSpacing(0) + info.addWidget(self._data_info_label) + info.addWidget(self._progress_spinner) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addLayout(info) + layout.addWidget(self._canvas.qwidget(), 1) + layout.addWidget(self._hover_info_label) + layout.addWidget(self._dims_sliders) + layout.addWidget(self._lut_drop) + layout.addLayout(btns) + + # SETUP ------------------------------------------------------ + + self.set_channel_mode(channel_mode) + if data is not None: + self.set_data(data) 
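
For orientation, a minimal usage sketch of the widget defined in this file. The `ndv.data.cells3d()` call and the Qt boilerplate mirror the `x.py` script further down in this patch series; the import path, the `"composite"` string, and the integer-keyed index are assumptions based on this module's location and the `set_channel_mode` / `set_current_index` docstrings below.

    import sys

    from qtpy.QtWidgets import QApplication

    import ndv
    from ndv.viewer2._viewer import NDViewer  # path of the module added here (assumed importable)

    app = QApplication(sys.argv)

    data = ndv.data.cells3d()  # demo volume, as used in x.py
    viewer = NDViewer(data, channel_mode="composite")
    viewer.show()

    # an "index" maps dimension (position or name) -> int or slice;
    # here: jump to plane 30 along the first axis
    viewer.set_current_index({0: 30})

    app.exec()
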
+ + # ------------------- PUBLIC API ---------------------------- + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + + @property + def data_wrapper(self) -> DataWrapper: + """Return the DataWrapper object around the datastore.""" + return self._data_wrapper + + @property + def data(self) -> Any: + """Return the data backing the view.""" + return self._data_wrapper.data + + @data.setter + def data(self, data: Any) -> None: + """Set the data backing the view.""" + raise AttributeError("Cannot set data directly. Use `set_data` method.") + + def set_data( + self, data: DataWrapper | Any, *, initial_index: Indices | None = None + ) -> None: + """Set the datastore, and, optionally, the sizes of the data. + + Properties + ---------- + data : DataWrapper | Any + The data to display. This can be any duck-like ND array, including numpy, + dask, xarray, jax, tensorstore, zarr, etc. You can add support for new + datastores by subclassing `DataWrapper` and implementing the required + methods. If a `DataWrapper` instance is passed, it is used directly. + See `DataWrapper` for more information. + initial_index : Indices | None + The initial index to display. This is a mapping of dimensions to integers + or slices that define the slice of the data to display. If not provided, + the initial index will be set to the middle of the data. + """ + # store the data + self._clear_images() + + self._data_wrapper = DataWrapper.create(data) + self._chunker.data_wrapper = self._data_wrapper + if chunks := self._data_wrapper.chunks(): + # temp hack ... always group non-visible channels + chunks = list(chunks) + chunks[:-2] = (1000,) * len(chunks[:-2]) + self._chunker.chunks = tuple(chunks) + + # set channel axis + self._channel_axis = self._data_wrapper.guess_channel_axis() + self._chunker.channel_axis = self._channel_axis + + # update the dimensions we are visualizing + sizes = self._data_wrapper.sizes() + visualized_dims = list(sizes)[-self._ndims :] + self.set_visualized_dims(visualized_dims) + + # update the range of all the sliders to match the sizes we set above + with signals_blocked(self._dims_sliders): + self._update_slider_ranges() + + # redraw + if initial_index is None: + idx = {k: int(v // 2) for k, v in sizes.items() if k not in visualized_dims} + else: + if not isinstance(initial_index, dict): # pragma: no cover + raise TypeError("initial_index must be a dict") + idx = initial_index + self.set_current_index(idx) + + # update the data info label + self._data_info_label.setText(self._data_wrapper.summary_info()) + + def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: + """Set the dimensions that will be visualized. + + This dims will NOT have sliders associated with them. 
+ """ + self._visualized_dims = set(dims) + for d in self._dims_sliders._sliders: + self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) + for d in self._visualized_dims: + self._dims_sliders.set_dimension_visible(d, False) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions to display.""" + if ndim not in (2, 3): + raise ValueError("ndim must be 2 or 3") + + self._ndims = ndim + self._canvas.set_ndim(ndim) + + # set the visibility of the last non-channel dimension + sizes = list(self._data_wrapper.sizes()) + if self._channel_axis is not None: + sizes = [x for x in sizes if x != self._channel_axis] + if len(sizes) >= 3: + dim3 = sizes[-3] + self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) + + # clear image handles and redraw + if self._channels: + self._clear_images() + self._request_data_for_index(self._dims_sliders.value()) + + def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: + """Set the mode for displaying the channels. + + In "composite" mode, the channels are displayed as a composite image, using + self._channel_axis as the channel axis. In "grayscale" mode, each channel is + displayed separately. (If mode is None, the current value of the + channel_mode_picker button is used) + + Parameters + ---------- + mode : ChannelMode | str | None + The mode to set, must be one of 'composite' or 'mono'. + """ + # bool may happen when called from the button clicked signal + if mode is None or isinstance(mode, bool): + mode = self._channel_mode_btn.mode() + else: + mode = ChannelMode(mode) + self._channel_mode_btn.setMode(mode) + if mode == self._channel_mode: + return + + self._channel_mode = mode + self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle + if self._channel_axis is not None: + # set the visibility of the channel slider + self._dims_sliders.set_dimension_visible( + self._channel_axis, mode != ChannelMode.COMPOSITE + ) + + if self._channels: + self._clear_images() + self._request_data_for_index(self._dims_sliders.value()) + + def set_current_index(self, index: Indices | None = None) -> None: + """Set the index of the displayed image. + + `index` is a mapping of dimensions to integers or slices that define the slice + of the data to display. For example, a numpy slice of `[0, 1, 5:10]` would be + represented as `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be + named, e.g. `{'t': 0, 'c': 1, 'z': slice(5, 10)}` if the data has named + dimensions. + + Note, calling `.set_current_index()` with no arguments will force the widget + to redraw the current slice. + """ + self._dims_sliders.setValue(index or {}) + + # camelCase aliases + + dimsSliders = dims_sliders + setChannelMode = set_channel_mode + setData = set_data + setCurrentIndex = set_current_index + setVisualizedDims = set_visualized_dims + + # ------------------- PRIVATE METHODS ---------------------------- + + def _toggle_3d(self) -> None: + self.set_ndim(3 if self._ndims == 2 else 2) + + def _update_slider_ranges(self) -> None: + """Set the maximum values of the sliders. + + If `sizes` is not provided, sizes will be inferred from the datastore. 
+ """ + maxes = self._data_wrapper.sizes() + self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) + + # FIXME: this needs to be moved and made user-controlled + for dim in list(maxes.keys())[-self._ndims :]: + self._dims_sliders.set_dimension_visible(dim, False) + + def _on_set_range_clicked(self) -> None: + # using method to swallow the parameter passed by _set_range_btn.clicked + self._canvas.set_range() + + def _image_key(self, index: Indices) -> ImgKey: + """Return the key for image handle(s) corresponding to `index`.""" + if self._channel_mode == ChannelMode.COMPOSITE: + val = index.get(self._channel_axis, 0) + if isinstance(val, slice): + return (val.start, val.stop) + return val + return 0 + + def _request_data_for_index(self, index: Indices) -> None: + """Retrieve data for `index` from datastore and update canvas image(s). + + This is the first step in updating the displayed image, it is triggered by + the valueChanged signal from the sliders. + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. This method is *asynchronous*. It + makes a request for the new data slice and queues _on_data_future_done to be + called when the data is ready. + """ + print(f"\n--------\nrequesting index {index}", self._channel_axis) + if ( + self._channel_mode == ChannelMode.COMPOSITE + and self._channel_axis is not None + ): + index = {**index, self._channel_axis: slice(None)} + self._progress_spinner.show() + # TODO: don't request channels not being displayed + # TODO: don't request if the data is already in the cache + self._chunker.request_index(index, ndims=self._ndims) + + @ensure_main_thread # type: ignore + def _draw_chunk(self, chunk: ChunkResponse) -> None: + """Actually update the image handle(s) with the (sliced) data. + + By this point, data should be sliced from the underlying datastore. Any + dimensions remaining that are more than the number of visualized dimensions + (currently just 2D) will be reduced using max intensity projection (currently). + """ + if chunk is RequestFinished: # fix typing + self._progress_spinner.hide() + for lut in self._lut_ctrls.values(): + lut.update_autoscale() + return + + if self._channel_mode == ChannelMode.MONO: + ch_key = MONO_CHANNEL + else: + ch_key = chunk.channel_index + + data = chunk.data + if data.ndim == 2: + return + # TODO: Channel object creation could be moved. + # having it here is the laziest... but means that the order of arrival + # of the chunks will determine the order of the channels in the LUTS + # (without additional logic to sort them by index, etc.) 
+ if (handles := self._channels.get(ch_key)) is None: + handles = self._create_channel(ch_key) + + if not handles: + if data.ndim == 2: + handles.append(self._canvas.add_image(data, cmap=handles.cmap)) + elif data.ndim == 3: + empty = np.empty((60, 256, 256), dtype=np.uint16) + handles.append(self._canvas.add_volume(empty, cmap=handles.cmap)) + + handles[0].set_data(data, chunk.offset) + self._canvas.refresh() + + def _create_channel(self, ch_key: int) -> Channel: + # improve this + cmap = GRAYS if ch_key == MONO_CHANNEL else next(self._cmap_cycle) + + self._channels[ch_key] = channel = Channel(ch_key, cmap=cmap) + self._lut_ctrls[ch_key] = lut = LutControl( + channel, + f"Ch {ch_key}", + self, + cmaplist=self._cmaps + DEFAULT_COLORMAPS, + cmap=cmap, + ) + self._lut_drop.addWidget(lut) + return channel + + def _clear_images(self) -> None: + """Remove all images from the canvas.""" + for handles in self._channels.values(): + for handle in handles: + handle.remove() + self._channels.clear() + + # clear the current LutControls as well + for c in self._lut_ctrls.values(): + cast("QVBoxLayout", self.layout()).removeWidget(c) + c.deleteLater() + self._lut_ctrls.clear() + + def _is_idle(self) -> bool: + """Return True if no futures are running. Used for testing, and debugging.""" + return bool(self._chunker.pending_futures) + + def closeEvent(self, a0: QCloseEvent | None) -> None: + self._chunker.shutdown() + super().closeEvent(a0) diff --git a/src/ndv/viewer2/spin.gif b/src/ndv/viewer2/spin.gif new file mode 100644 index 0000000000000000000000000000000000000000..f54d0d6695e5793cd1dd2fd5ba5ac7b3d3e90f22 GIT binary patch literal 2384 zcmaJ>dpuNm8$WWXkXjN;bPSeU=7wP=Qo08CC!cEEQ>$B`pPxbV&8dU<&P0N^lbLZ?JUk^&5AgknN6aiWwB z!QqZNDhYukg;;X2G)gQMOm{JYCzgw4Jh22MMfxByU^t7-7wJ}}GQz_NG?7fs60so~ znTUWfkbFLefcB*Xk>H0P7DvKhC>T7JLcx=RNF+R!O2Og$rgO<+cA^jx$)|HUv$@oH zb9H7AN?>L(B;_YT9I8|-1gDB7@aNY;niubLE@yr%6gIZr#@;SBf~FW{Pz6Up=VEj8GJJExc^aK zZ_mRA-T(Rdr+?qS*LC;D&W=0nZMR!*-Ta~DM)UQi#%m2%wO1}*x_F`feBHUTwP#MB zI(ed|`unQN;}yrs%l`FUX^H0Ok;8`$78eyB_-Dbl`TO&7bM|FtWoD%BP1}?D&F)?5 z6qQnuoRlb+NfRVukx&r7lh2FeLLBxE*7kqIZri$L^QM^SjbBGaG8qvY=;2|Z>uDju zR0^3C6d15>Ezy4s0q^IF^YQk=!kY@|>EZ6S+STPNXD3I51GviGZl$fwm)2G*ESFoD zf3a+-*%DI|8W|esFMy8+hxZ1C#=t_r;iGFXjS4~>z8vC!$$U8v)NMH#yyqZQ z9G^S2Go+A=uYv#%_*8+9ArJucH41IV?^SINbM}TsZ2oiCkuqV}ud5@DXM0>Z9H*_T zICp37 z>U>UONsU{bv-tuSx?eCbR?A&BxCqzHun*(tj;e2W2=Hgi0-k8n3%r58G+( zKCf2^>9hUw&>gk*z*3kaLZNLKuV}ksus6ZdEu(2!WZu?in~baWbCkQYRA+_C2EO0< zi^^LLJyqXdui=^QJuq&*J>9tSr-_u|ko%|hEVkHRA312yU>RA@?ZtM+I?W+5f=NCV z60Fd=AwAmP57>BJYR&9sCMs85&R)3bsHURY!|_B+U(gtHcss@YYFV41iH!;FK7*%^ z)Z{rv-O$iY9|)38j7YC;arM*8!8U`j&0%yvKSZJJ>mQ^!zexYNzi{G}tzl*bwuEkF z7pLJ0+?&$;LvgBy<@mRGX%fGZW#8UnGmdQ~m>QKWFdkvlYMP5}dZfOw1JMm3bHL33 zJ-`g^51qC@rdD{B`YVntn^FDqKqubJkvBzkZO_An8kPgiK;ZlTu?^)=M( z?9e1pFpD2V+p!d;Ee;J&Id1&hbz>=|q1XD6G5Z)6k2YV4Y5jw}T57j^{9XEv9Fv`2 zx6&~6fuiUlLu3(Sk#*smJ^?dz(0JOEk7xq5GNb5GY41k{Bel%+z@eIB$wJ3H4do_e5BJxH#0wQqUUtdcd0h0#bR4v;Y7A literal 0 HcmV?d00001 diff --git a/tests/test_chunker.py b/tests/test_chunker.py index 546e967..1165c01 100644 --- a/tests/test_chunker.py +++ b/tests/test_chunker.py @@ -9,23 +9,23 @@ def test_iter_chunk_aligned_slices() -> None: shape=(10, 9), chunks=(4, 3), slices=np.index_exp[3:9, 1:None] ) assert list(x) == [ - (slice(3, 4, None), slice(1, 3, None)), - (slice(3, 4, None), slice(3, 6, None)), - (slice(3, 4, None), slice(6, 9, None)), - (slice(4, 8, None), slice(1, 3, None)), - (slice(4, 8, None), slice(3, 6, None)), - (slice(4, 8, None), slice(6, 9, None)), - (slice(8, 9, 
None), slice(1, 3, None)), - (slice(8, 9, None), slice(3, 6, None)), - (slice(8, 9, None), slice(6, 9, None)), + (slice(3, 4, 1), slice(1, 3, 1)), + (slice(3, 4, 1), slice(3, 6, 1)), + (slice(3, 4, 1), slice(6, 9, 1)), + (slice(4, 8, 1), slice(1, 3, 1)), + (slice(4, 8, 1), slice(3, 6, 1)), + (slice(4, 8, 1), slice(6, 9, 1)), + (slice(8, 9, 1), slice(1, 3, 1)), + (slice(8, 9, 1), slice(3, 6, 1)), + (slice(8, 9, 1), slice(6, 9, 1)), ] # this one tests that slices doesn't need to be the same length as shape # ... is added at the end y = iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=np.index_exp[1:4]) assert list(y) == [ - (slice(1, 4, None), slice(0, 4, None)), - (slice(1, 4, None), slice(4, 6, None)), + (slice(1, 4, 1), slice(0, 4, 1)), + (slice(1, 4, 1), slice(4, 6, 1)), ] # this tests ellipsis in the middle @@ -33,8 +33,8 @@ def test_iter_chunk_aligned_slices() -> None: shape=(3, 3, 3), chunks=2, slices=np.index_exp[1, ..., :2] ) assert list(z) == [ - (slice(1, 2, None), slice(0, 2, None), slice(0, 2, None)), - (slice(1, 2, None), slice(2, 3, None), slice(0, 2, None)), + (slice(1, 2, 1), slice(0, 2, 1), slice(0, 2, 1)), + (slice(1, 2, 1), slice(2, 3, 1), slice(0, 2, 1)), ] @@ -57,7 +57,7 @@ def test_chunker() -> None: new = np.empty_like(data2[0]) for future in futures: result = future.result() - new[result.array_location[1:]] = result.data + new[result.location[1:]] = result.data npt.assert_array_equal(new, data2[0]) diff --git a/x.py b/x.py index 812e058..cebe52b 100644 --- a/x.py +++ b/x.py @@ -3,7 +3,7 @@ from qtpy.QtWidgets import QApplication import ndv -from ndv.viewer._v2 import NDViewer +from ndv.viewer2._v2 import NDViewer data = ndv.data.cells3d() app = QApplication(sys.argv) From 15c93c836a557683263302e54fb078dc93e14fab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 19:15:29 +0000 Subject: [PATCH 12/12] style(pre-commit.ci): auto fixes [...] --- src/ndv/viewer2/_octree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ndv/viewer2/_octree.py b/src/ndv/viewer2/_octree.py index db73aad..e1af860 100644 --- a/src/ndv/viewer2/_octree.py +++ b/src/ndv/viewer2/_octree.py @@ -1,4 +1,4 @@ -from typing import Any, Generic, Iterator, NamedTuple, TypeVar +from typing import Generic, Iterator, NamedTuple, TypeVar MAX_DEPTH = 8
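
A worked example of the chunk-aligned slicing exercised by `test_iter_chunk_aligned_slices` above: the sketch below reassembles a requested region chunk by chunk. It assumes `iter_chunk_aligned_slices` is importable from `ndv._chunking`, where the first patch in this series defines it.

    import numpy as np

    from ndv._chunking import iter_chunk_aligned_slices  # assumed import location

    data = np.arange(90).reshape(10, 9)
    request = np.index_exp[3:9, 1:]  # same request as in the test above
    out = np.zeros_like(data)

    # each yielded tuple of slices stays inside a single (4, 3) chunk, so every
    # sub-request could be fetched (or submitted to an executor) independently
    for sub in iter_chunk_aligned_slices(shape=data.shape, chunks=(4, 3), slices=request):
        out[sub] = data[sub]

    # together the sub-slices cover exactly the requested region
    assert np.array_equal(out[3:9, 1:], data[3:9, 1:])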