diff --git a/src/ndv/__init__.py b/src/ndv/__init__.py index 7faa6ea..a699742 100644 --- a/src/ndv/__init__.py +++ b/src/ndv/__init__.py @@ -13,8 +13,8 @@ from . import data from .util import imshow -from .viewer._data_wrapper import DataWrapper -from .viewer._viewer import NDViewer +from .viewer2._data_wrapper import DataWrapper +from .viewer2._viewer import NDViewer __all__ = ["NDViewer", "DataWrapper", "imshow", "data"] @@ -23,7 +23,7 @@ # these may be used externally, but are not guaranteed to be available at runtime # they must be used inside a TYPE_CHECKING block - from .viewer._dims_slider import DimKey as DimKey - from .viewer._dims_slider import Index as Index - from .viewer._dims_slider import Indices as Indices - from .viewer._dims_slider import Sizes as Sizes + from .viewer2._dims_slider import DimKey as DimKey + from .viewer2._dims_slider import Index as Index + from .viewer2._dims_slider import Indices as Indices + from .viewer2._dims_slider import Sizes as Sizes diff --git a/src/ndv/_chunk_executor.py b/src/ndv/_chunk_executor.py new file mode 100644 index 0000000..5cb8501 --- /dev/null +++ b/src/ndv/_chunk_executor.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import math +from collections import deque +from concurrent.futures import Executor, Future, ThreadPoolExecutor +from itertools import product +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Literal, + Mapping, + NamedTuple, + Protocol, + Sequence, + SupportsIndex, + cast, +) + +import numpy as np + +if TYPE_CHECKING: + from types import EllipsisType + from typing import TypeAlias + + import numpy.typing as npt + + class SupportsDunderLT(Protocol): + def __lt__(self, other: Any) -> bool: ... + + class SupportsDunderGT(Protocol): + def __gt__(self, other: Any) -> bool: ... + + SupportsComparison: TypeAlias = SupportsDunderLT | SupportsDunderGT + +NULL = object() + + +class SupportsChunking(Protocol): + @property + def shape(self) -> Sequence[int]: ... + def __getitem__(self, idx: tuple[int | slice, ...]) -> npt.ArrayLike: ... + + +class ChunkResponse(NamedTuple): + # location in the original array + location: tuple[int | slice, ...] + # the data that was returned + data: np.ndarray + + @property + def offset(self) -> tuple[int, ...]: + return tuple(i.start if isinstance(i, slice) else i for i in self.location) + + +ChunkFuture = Future[ChunkResponse] + + +class Chunker: + def __init__(self, executor: Executor | None = None) -> None: + self._executor = executor or self._default_executor() + self._pending_futures: deque[ChunkFuture] = deque() + self._request_chunk = _get_chunk + + @classmethod + def _default_executor(cls) -> ThreadPoolExecutor: + return ThreadPoolExecutor(thread_name_prefix=cls.__name__) + + def is_idle(self) -> bool: + return all(f.done() for f in self._pending_futures) + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + self._executor.shutdown(wait=wait, cancel_futures=cancel_futures) + + def __enter__(self) -> Chunker: + return self + + def __exit__(self, *_: Any) -> Literal[False]: + self.shutdown(wait=True) + return False + + def request_chunks( + self, + data: SupportsChunking, + index: Mapping[int, int | slice] | tuple[int | slice, ...] 
| None = None, + # None implies no chunking + chunk_shape: int | None | Sequence[int | None] = None, + *, + sort_key: Callable[[tuple[int | slice, ...]], SupportsComparison] | None = NULL, # type: ignore + cancel_existing: bool = False, + ) -> list[ChunkFuture]: + """Request chunks from `data` based on the given `index` and `chunk_shape`. + + Parameters + ---------- + data : SupportsChunking + The data to request chunks from. Must have a `shape` attribute and support + indexing (`__getitem__`) with a tuple of int or slice. + index : Mapping[int, int | slice] | tuple[int | slice, ...] | None + A subarray to request. + If a Mapping, it should look like {dim: index} where `index` is a single + index or a slice and dim is the dimension to index. + If an tuple, it should be a regular tuple of integer or slice: + e.g. (6, 0, slice(None, None, None), slice(1, 10, None)) + If `None` (default), the full array is requested. + chunk_shape : int | tuple[int, ...] | None + The shape of each chunk. If a single int, the same size is used for all + dimensions. If `None`, no chunking is done. Note that chunk shape applies + to the data *prior* to indexing. Chunks will be aligned with the original + data, not the indexed data... so given an axis `0` with length 100, if you + request a slice from that index `index={0: slice(40,60)}` and provide a + `chunk_shape=50`, you will get two chunks: (40, 50) and (50, 60). + The intention is that chunk_shape should align with the chunk layout of the + original data, to optimize reading from disk or other sources, even when + reading a subset of the data that is not aligned with the chunks. + sort_key: Callable[[tuple[slice, ...]], SupportsComparison] | None + A function to sort the chunks before submitting them. This can be used to + prioritize chunks that are more likely to be needed first (such as those + within a certain distance of the current view). The function should take + a tuple of slices and return a value that can be compared with `<` and `>`. + If None, no sorting is done. + cancel_existing : bool + If True, cancel any existing pending futures before submitting new ones. + + Returns + ------- + list[Future[ChunkResponse]] + A list of futures that will contain the requested chunks when they are + available. Use `Future.add_done_callback` to register a callback to handle + the results. + """ + if cancel_existing: + for future in list(self._pending_futures): + future.cancel() + + if index is None: + index = tuple(slice(None) for _ in range(len(data.shape))) + + if isinstance(index, Mapping): + index = indexers_to_conventional_slice(index) + + # at this point, index is a tuple of int or slice + # e.g. (6, 0, slice(None, None, None), slice(1, 10, None)) + # now, determine the subchunk indices to request + if chunk_shape is None: + # TODO: check whether we need to cast this to something without integers. 
+ indices: Iterable[tuple[int | slice, ...]] = [index] + else: + indices = iter_chunk_aligned_slices(data.shape, chunk_shape, index) + if sort_key is not None: + if sort_key is NULL: + + def sort_key(x: tuple[int | slice, ...]) -> SupportsComparison: + return distance_from_coord(x, data.shape) + + indices = sorted(indices, key=sort_key) + + # submit the a request for each subchunk + futures = [] + for chunk_index in indices: + future = self._executor.submit(self._request_chunk, data, chunk_index) + self._pending_futures.append(future) + future.add_done_callback(self._pending_futures.remove) + futures.append(future) + return futures + + +def _get_chunk(data: SupportsChunking, index: tuple[int | slice, ...]) -> ChunkResponse: + chunk_data = _reduce_data_for_display(data[index], len(data.shape)) + # import time + + # time.sleep(0.05) + return ChunkResponse(location=index, data=chunk_data) + + +def indexers_to_conventional_slice( + indexers: Mapping[int, int | slice], ndim: int | None = None +) -> tuple[int | slice, ...]: + """Convert Mapping of {dim: index} to a conventional tuple of int or slice. + + `indexers` need not be ordered. If `ndim` is not provided, it is inferred + from the maximum key in `indexers`. + + Parameters + ---------- + indexers : Mapping[int, int | slice] + Mapping of {dim: index} where `index` is a single index or a slice. + ndim : int | None + Number of dimensions. If None, inferred from the maximum key in `indexers`. + + Examples + -------- + >>> indexers_to_conventional_slice({1: 0, 0: 6, 3: slice(1, 10, None)}) + (6, 0, slice(None, None, None), slice(1, 10, None)) + + """ + if not indexers: + return (slice(None),) + + if ndim is None: + ndim = max(indexers) + 1 + return tuple(indexers.get(k, slice(None)) for k in range(ndim)) + + +def _slice_indices(sl: SupportsIndex | slice, dim_size: int) -> tuple[int, int, int]: + """Convert slice to range arguments, handling single int as well. + + Examples + -------- + >>> _slice2range(3, 10) + (3, 4, 1) + >>> _slice2range(slice(1, 4), 10) + (1, 4, 1) + >>> _slice2range(slice(1, None), 10) + (1, 10, 1) + """ + if isinstance(sl, slice): + return sl.indices(dim_size) + return (sl.__index__(), sl.__index__() + 1, 1) + + +def iter_chunk_aligned_slices( + shape: Sequence[int], + chunks: int | Sequence[int | None], + slices: Sequence[int | slice | EllipsisType], +) -> Iterator[tuple[slice, ...]]: + """Yield chunk-aligned slices for a given shape and slices. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the array to slice. + chunks : int or tuple[int, ...] + The size of each chunk. If a single int, the same size is used for all + dimensions. + slices : tuple[int | slice | Ellipsis, ...] + The full slices to apply to the array. Ellipsis is supported to + represent multiple slices. + + Returns + ------- + Iterator[tuple[slice, ...]] + An iterator of chunk-aligned slices. + + Raises + ------ + ValueError + If the length of `chunks`, `shape`, and `slices` do not match, or any chunks + are zero. + IndexError + If more than one Ellipsis is present in `slices`. + + Examples + -------- + >>> list( + ... iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=(slice(1, 4), ...)) + ... ) + [ + (slice(1, 4, None), slice(0, 4, None)), + (slice(1, 4, None), slice(4, 6, None)), + ] + + >>> x = iter_chunk_aligned_slices( + ... shape=(10, 9), chunks=(4, 3), slices=(slice(3, 9), slice(1, None)) + ... 
) + >>> list(x) + [ + (slice(3, 4, None), slice(1, 3, None)), + (slice(3, 4, None), slice(3, 6, None)), + (slice(3, 4, None), slice(6, 9, None)), + (slice(4, 8, None), slice(1, 3, None)), + (slice(4, 8, None), slice(3, 6, None)), + (slice(4, 8, None), slice(6, 9, None)), + (slice(8, 9, None), slice(1, 3, None)), + (slice(8, 9, None), slice(3, 6, None)), + (slice(8, 9, None), slice(6, 9, None)), + ] + """ + # Make chunks same length as shape if single int + ndim = len(shape) + if isinstance(chunks, int): + chunks = (chunks,) * ndim + elif not len(chunks) == ndim: + raise ValueError("Length of `chunks` must match length of `shape`") + + if any(x == 0 for x in chunks): + raise ValueError("Chunk size must be greater than zero") + + # convert any `None` chunks to full size of the dimension + chunks = tuple(x if x is not None else shape[i] for i, x in enumerate(chunks)) + + if num_ellipsis := slices.count(Ellipsis): + if num_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Replace Ellipsis with multiple slices + el_idx = slices.index(Ellipsis) + n_remaining = ndim - len(slices) + 1 + slices = ( + tuple(slices[:el_idx]) + + (slice(None),) * n_remaining + + tuple(slices[el_idx + 1 :]) + ) + slices = cast(tuple[int | slice, ...], slices) # now we have no Ellipsis + if not ndim == len(slices): + # Fill in remaining dimensions with full slices + slices = slices + (slice(None),) * (ndim - len(slices)) + + # Create ranges for each dimension based on the slices provided + ranges = [_slice_indices(sl, dim) for sl, dim in zip(slices, shape)] + + # Generate indices for each dimension that align with chunks + aligned_ranges = ( + range(start - (start % chunk_size), stop, chunk_size) + for (start, stop, _), chunk_size in zip(ranges, chunks) + ) + + # Create all combinations of these aligned ranges + for indices in product(*aligned_ranges): + chunk_slices = [] + for idx, (start, stop, step), ch in zip(indices, ranges, chunks): + # Calculate the actual slice for each dimension + start = max(start, idx) + stop = min(stop, idx + ch) + if start >= stop: # Skip empty slices + break + chunk_slices.append(slice(start, stop, step)) + else: + # Only add this combination of slices if all dimensions are valid + yield tuple(chunk_slices) + + +def _slice_center(s: slice | int, dim_size: int) -> float: + """Calculate the center of a slice based on its start and stop attributes.""" + if isinstance(s, int): + return s + start = float(s.start) if s.start is not None else 0 + stop = float(s.stop) if s.stop is not None else dim_size + return (start + stop) / 2 + + +def distance_from_coord( + slice_tuple: Sequence[slice | int], + shape: Sequence[int], + coord: Iterable[float] = (), # defaults to center of shape +) -> float: + """Euclidean distance from the center of an nd slice to the center of shape.""" + if not coord: + coord = (dim / 2 for dim in shape) + slice_centers = (_slice_center(s, dim) for s, dim in zip(slice_tuple, shape)) + return math.hypot(*(sc - cc for sc, cc in zip(slice_centers, coord))) + + +def _reduce_data_for_display( + data: npt.ArrayLike, ndims: int, reductor: Callable[..., np.ndarray] = np.max +) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + + This also coerces 64-bit data to 32-bit data. 
+ """ + # TODO + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the first extra dims) + data = np.asarray(data).squeeze() + if extra_dims := data.ndim - ndims: + axis = tuple(range(extra_dims)) + data = reductor(data, axis=axis) + + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) + return data + + +# class DaskChunker: +# def __init__(self) -> None: +# try: +# import dask +# import dask.array as da +# from dask.distributed import Client +# except ImportError as e: +# raise ImportError("Dask is required for DaskChunker") from e +# self._dask = dask +# self._da = da +# self._client = Client() + +# def request_chunks( +# self, +# data: SupportsChunking, +# index: Mapping[int, int | slice] | IndexTuple | None = None, +# chunk_shape: int | tuple[int, ...] | None = None, # None implies no chunking +# *, +# cancel_existing: bool = False, +# ) -> list[Future[ChunkResponse]]: +# if isinstance(index, Mapping): +# index = indexers_to_conventional_slice(index) + +# if isinstance(data, self._da.Array): # type: ignore +# dask_data = data +# else: +# dask_data = self._da.from_array(data, chunks=chunk_shape) # type: ignore + +# subarray = dask_data[index] +# block_ranges = (range(x) for x in subarray.numblocks) +# for blk in product(*(block_ranges)): +# offset = tuple(sum(sizes[:x]) for sizes, x in zip(subarray.chunks, blk)) +# chunk = subarray.blocks[blk] +# future = self._client.compute(chunk) + +# @no_type_check +# def _set_result(_chunk=chunk, _future=future): +# _future.set_result( +# ChunkResponse(idx=index, data=_chunk.compute(), offset=offset) +# ) + +# futures.append() + +# return [data] diff --git a/src/ndv/_chunking.py b/src/ndv/_chunking.py new file mode 100644 index 0000000..f712747 --- /dev/null +++ b/src/ndv/_chunking.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import math +from concurrent.futures import Future, ThreadPoolExecutor +from itertools import product +from typing import ( + TYPE_CHECKING, + Any, + Deque, + Hashable, + Literal, + Mapping, + NamedTuple, + Sequence, + cast, +) + +import numpy as np +from rich import print + +if TYPE_CHECKING: + from collections import deque + from types import EllipsisType + from typing import Callable, Iterable, Iterator, TypeAlias + + from .viewer2._data_wrapper import DataWrapper + +# any hashable represent a single dimension in an ND array +DimKey: TypeAlias = Hashable +# any object that can be used to index a single dimension in an ND array +Index: TypeAlias = int | slice +# a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) +# this object is used frequently to query or set the currently displayed slice +Indices: TypeAlias = Mapping[DimKey, Index] +# mapping of dimension keys to the maximum value for that dimension +Sizes: TypeAlias = Mapping[DimKey, int] + + +class ChunkResponse(NamedTuple): + idx: tuple[int | slice, ...] # index that was requested + data: np.ndarray # the data that was returned + offset: tuple[int, ...] # offset of the data in the full array (derived from idx) + channel_index: int = -1 + + +# sentinel value +RequestFinished = ChunkResponse((), np.empty(0), ()) + + +class Chunker: + def __init__( + self, + data_wrapper: DataWrapper | None = None, + chunks: int | tuple[int, ...] 
| None = None, + on_ready: Callable[[ChunkResponse], Any] | None = None, + ) -> None: + self.chunks = chunks + self.data_wrapper: DataWrapper | None = data_wrapper + self.executor = ThreadPoolExecutor(max_workers=4) + self.pending_futures: deque[Future[ChunkResponse]] = Deque() + self.on_ready = on_ready + self.channel_axis: int | None = None + self._notification_sent = True + + def __del__(self) -> None: + self.shutdown() + + def shutdown(self, cancel_futures: bool = True, wait: bool = True) -> None: + self.executor.shutdown(cancel_futures=cancel_futures, wait=wait) + + def _request_chunk_sync( + self, idx: tuple[int | slice, ...], channel_axis: int | None, ndims: int + ) -> ChunkResponse: + # idx is guaranteed to have length equal to the number of dimensions + if channel_axis is not None: + channel_index = idx[channel_axis] + if isinstance(channel_index, slice): + channel_index = channel_index.start + else: + channel_index = -1 + + data = self.data_wrapper[idx] # type: ignore [index] + data = _reduce_data_for_display(data, ndims) + # FIXME: temporary + # this needs to be aware of nvisible dimensions + try: + offset = tuple(int(getattr(sl, "start", sl)) for sl in idx)[-3:] + offset = (idx[0].start, idx[2].start, idx[3].start) + except TypeError: + offset = (0, 0, 0) + # import time + + # time.sleep(0.05) + return ChunkResponse( + idx=idx, data=data, offset=offset, channel_index=channel_index + ) + + def request_index( + self, index: Indices, *, cancel_existing: bool = True, ndims: Literal[2, 3] = 2 + ) -> None: + if cancel_existing: + for future in list(self.pending_futures): + future.cancel() + + # TODO: see if we can get the channel_axis logic back to the viewer/request side + + if self.data_wrapper is None: + return + idx = self.data_wrapper.to_conventional(index) + chunks = self.chunks + multi_channel = isinstance(index.get(self.channel_axis), slice) + if chunks is None and not multi_channel: + subchunks: Iterable[tuple[int | slice, ...]] = [idx] + else: + shape = self.data_wrapper.data.shape + + # TODO + # we should *only* chunk along visualized axes ... 
+ + # we never chunk the channel axis + if isinstance(chunks, int): + _chunks = [chunks] * len(shape) + elif chunks is None: + # hack FIXME + _chunks = [100000] * len(shape) + else: + _chunks = list(chunks) + if self.channel_axis is not None: + _chunks[self.channel_axis] = 1 + + # TODO: allow the viewer to pass a center coord, to load chunks + # preferentially around that point + subchunks = sorted( + iter_chunk_aligned_slices(shape, _chunks, idx), + key=lambda x: distance_from_coord(x, shape), + ) + print("Requesting index:", idx) + print("subchunks", subchunks) + print() + self._notification_sent = False + for chunk_idx in subchunks: + future = self.executor.submit( + self._request_chunk_sync, chunk_idx, self.channel_axis, ndims + ) + self.pending_futures.append(future) + future.add_done_callback(self._on_chunk_ready) + + def _on_chunk_ready(self, future: Future[ChunkResponse]) -> None: + self.pending_futures.remove(future) + if future.cancelled(): + return + if err := future.exception(): + print(f"{type(err).__name__}: in chunk request: {err}") + return + if self.on_ready is not None: + self.on_ready(future.result()) + if not self.pending_futures and not self._notification_sent: + # Fix typing + self._notification_sent = True + self.on_ready(RequestFinished) + + +def _reduce_data_for_display( + data: np.ndarray, ndims: int, reductor: Callable[..., np.ndarray] = np.max +) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + + This also coerces 64-bit data to 32-bit data. + """ + # TODO + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the first extra dims) + data = data.squeeze() + if extra_dims := data.ndim - ndims: + axis = tuple(range(extra_dims)) + data = reductor(data, axis=axis) + + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) + return data + + +def _slice2range(sl: slice | int, dim_size: int) -> tuple[int, int]: + """Convert slice to range, handling single int as well. + + Examples + -------- + >>> _slice2range(3, 10) + (3, 4) + """ + if isinstance(sl, int): + return (sl, sl + 1) + start = 0 if sl.start is None else max(sl.start, 0) + stop = dim_size if sl.stop is None else min(sl.stop, dim_size) + return (start, stop) + + +def iter_chunk_aligned_slices( + shape: Sequence[int], + chunks: int | Sequence[int], + slices: Sequence[int | slice | EllipsisType], +) -> Iterator[tuple[slice, ...]]: + """Yield chunk-aligned slices for a given shape and slices. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the array to slice. + chunks : int or tuple[int, ...] + The size of each chunk. If a single int, the same size is used for all + dimensions. + slices : tuple[int | slice | Ellipsis, ...] + The full slices to apply to the array. Ellipsis is supported to + represent multiple slices. + + Returns + ------- + Iterator[tuple[slice, ...]] + An iterator of chunk-aligned slices. + + Raises + ------ + ValueError + If the length of `chunks`, `shape`, and `slices` do not match, or any chunks + are zero. + IndexError + If more than one Ellipsis is present in `slices`. + + Examples + -------- + >>> list( + ... 
iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=(slice(1, 4), ...)) + ... ) + [ + (slice(1, 4, None), slice(0, 4, None)), + (slice(1, 4, None), slice(4, 6, None)), + ] + + >>> x = iter_chunk_aligned_slices( + ... shape=(10, 9), chunks=(4, 3), slices=(slice(3, 9), slice(1, None)) + ... ) + >>> list(x) + [ + (slice(3, 4, None), slice(1, 3, None)), + (slice(3, 4, None), slice(3, 6, None)), + (slice(3, 4, None), slice(6, 9, None)), + (slice(4, 8, None), slice(1, 3, None)), + (slice(4, 8, None), slice(3, 6, None)), + (slice(4, 8, None), slice(6, 9, None)), + (slice(8, 9, None), slice(1, 3, None)), + (slice(8, 9, None), slice(3, 6, None)), + (slice(8, 9, None), slice(6, 9, None)), + ] + """ + # Make chunks same length as shape if single int + ndim = len(shape) + if isinstance(chunks, int): + chunks = (chunks,) * ndim + if any(x == 0 for x in chunks): + raise ValueError("Chunk size must be greater than zero") + + if num_ellipsis := slices.count(Ellipsis): + if num_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Replace Ellipsis with multiple slices + el_idx = slices.index(Ellipsis) + n_remaining = ndim - len(slices) + 1 + slices = ( + tuple(slices[:el_idx]) + + (slice(None),) * n_remaining + + tuple(slices[el_idx + 1 :]) + ) + slices = cast(tuple[int | slice, ...], slices) # now we have no Ellipsis + + if not (len(chunks) == ndim == len(slices)): + raise ValueError("Length of `chunks`, `shape`, and `slices` must match") + + # Create ranges for each dimension based on the slices provided + ranges = [_slice2range(sl, dim) for sl, dim in zip(slices, shape)] + + # Generate indices for each dimension that align with chunks + aligned_ranges = ( + range(start - (start % chunk_size), stop, chunk_size) + for (start, stop), chunk_size in zip(ranges, chunks) + ) + + # Create all combinations of these aligned ranges + for indices in product(*aligned_ranges): + chunk_slices = [] + for idx, (start, stop), ch in zip(indices, ranges, chunks): + # Calculate the actual slice for each dimension + start = max(start, idx) + stop = min(stop, idx + ch) + if start >= stop: # Skip empty slices + break + chunk_slices.append(slice(start, stop)) + else: + # Only add this combination of slices if all dimensions are valid + yield tuple(chunk_slices) + + +def slice_center(s: slice | int, dim_size: int) -> float: + """Calculate the center of a slice based on its start and stop attributes.""" + if isinstance(s, int): + return s + start = float(s.start) if s.start is not None else 0 + stop = float(s.stop) if s.stop is not None else dim_size + return (start + stop) / 2 + + +def distance_from_coord( + slice_tuple: tuple[slice | int, ...], + shape: tuple[int, ...], + coord: Iterable[float] = (), # defaults to center of shape +) -> float: + """Euclidean distance from the center of an nd slice to the center of shape.""" + if not coord: + coord = (dim / 2 for dim in shape) + slice_centers = (slice_center(s, dim) for s, dim in zip(slice_tuple, shape)) + return math.hypot(*(sc - cc for sc, cc in zip(slice_centers, coord))) + + +# def _axis_chunks(total_length: int, chunk_size: int) -> tuple[int, ...]: +# """Break `total_length` into chunks of `chunk_size` plus remainder. + +# Examples +# -------- +# >>> _axis_chunks(10, 3) +# (3, 3, 3, 1) +# """ +# sequence = (chunk_size,) * (total_length // chunk_size) +# if remainder := total_length % chunk_size: +# sequence += (remainder,) +# return sequence + + +# def _shape_chunks( +# shape: tuple[int, ...], chunks: int | tuple[int, ...] 
+# ) -> tuple[tuple[int, ...], ...]: +# """Break `shape` into chunks of `chunks` along each axis. + +# Examples +# -------- +# >>> _shape_chunks((10, 10, 10), 3) +# ((3, 3, 3, 1), (3, 3, 3, 1), (3, 3, 3, 1)) +# """ +# if isinstance(chunks, int): +# chunks = (chunks,) * len(shape) +# elif isinstance(chunks, Sequence): +# if len(chunks) != len(shape): +# raise ValueError("Length of `chunks` must match length of `shape`") +# else: +# raise TypeError("`chunks` must be an int or sequence of ints") +# return tuple(_axis_chunks(length, chunk) for length, chunk in zip(shape, chunks)) diff --git a/src/ndv/viewer2/__init__.py b/src/ndv/viewer2/__init__.py new file mode 100644 index 0000000..09c9470 --- /dev/null +++ b/src/ndv/viewer2/__init__.py @@ -0,0 +1 @@ +"""viewer source.""" diff --git a/src/ndv/viewer2/_backends/__init__.py b/src/ndv/viewer2/_backends/__init__.py new file mode 100644 index 0000000..6c50889 --- /dev/null +++ b/src/ndv/viewer2/_backends/__init__.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import importlib +import importlib.util +import os +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ndv.viewer2._backends.protocols import PCanvas + + +def get_canvas(backend: str | None = None) -> type[PCanvas]: + backend = backend or os.getenv("NDV_CANVAS_BACKEND", None) + if backend == "vispy" or (backend is None and "vispy" in sys.modules): + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if backend == "pygfx" or (backend is None and "pygfx" in sys.modules): + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + if backend is None: + if importlib.util.find_spec("vispy") is not None: + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if importlib.util.find_spec("pygfx") is not None: + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + raise RuntimeError("No canvas backend found") diff --git a/src/ndv/viewer2/_backends/_pygfx.py b/src/ndv/viewer2/_backends/_pygfx.py new file mode 100644 index 0000000..f8190ed --- /dev/null +++ b/src/ndv/viewer2/_backends/_pygfx.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import numpy as np +import pygfx +from qtpy.QtCore import QSize +from wgpu.gui.qt import QWgpuCanvas + +if TYPE_CHECKING: + import cmap + from pygfx.materials import ImageBasicMaterial + from pygfx.resources import Texture + from qtpy.QtWidgets import QWidget + + +class PyGFXImageHandle: + def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None: + self._image = image + self._render = render + self._grid = cast("Texture", image.geometry.grid) + self._material = cast("ImageBasicMaterial", image.material) + + @property + def data(self) -> np.ndarray: + return self._grid.data # type: ignore [no-any-return] + + @data.setter + def data(self, data: np.ndarray) -> None: + self._grid.data[:] = data + self._grid.update_range((0, 0, 0), self._grid.size) + + @property + def visible(self) -> bool: + return bool(self._image.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._image.visible = visible + self._render() + + @property + def clim(self) -> Any: + return self._material.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + self._material.clim = clims + self._render() + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + 
self._cmap = cmap + self._material.map = cmap.to_pygfx() + self._render() + + def remove(self) -> None: + if (par := self._image.parent) is not None: + par.remove(self._image) + + +class _QWgpuCanvas(QWgpuCanvas): + def sizeHint(self) -> QSize: + return QSize(512, 512) + + +class PyGFXViewerCanvas: + """pygfx-based canvas wrapper.""" + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + self._current_shape: tuple[int, ...] = () + self._last_state: dict[Literal[2, 3], Any] = {} + + self._canvas = _QWgpuCanvas(size=(512, 512)) + self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) + try: + # requires https://github.com/pygfx/pygfx/pull/752 + self._renderer.blend_mode = "additive" + except ValueError: + warnings.warn( + "This version of pygfx does not yet support additive blending.", + stacklevel=3, + ) + self._renderer.blend_mode = "weighted_depth" + + self._scene = pygfx.Scene() + self._camera: pygfx.Camera | None = None + self._ndim: Literal[2, 3] | None = None + + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim == self._ndim: + return + elif self._ndim is not None and self._camera is not None: + # remember the current state before switching to the new camera + self._last_state[self._ndim] = self._camera.get_state() + + self._ndim = ndim + if ndim == 3: + self._camera = cam = pygfx.PerspectiveCamera(0, 1) + cam.show_object(self._scene, up=(0, -1, 0), view_dir=(0, 0, 1)) + controller = pygfx.OrbitController(cam, register_events=self._renderer) + zoom = "zoom" + # FIXME: there is still an issue with rotational centration. + # the controller is not rotating around the middle of the volume... + # but I think it might actually be a pygfx issue... the critical state + # seems to be somewhere outside of the camera's get_state dict. + else: + self._camera = cam = pygfx.OrthographicCamera(512, 512) + cam.local.scale_y = -1 + cam.local.position = (256, 256, 0) + controller = pygfx.PanZoomController(cam, register_events=self._renderer) + zoom = "zoom_to_point" + + self._controller = controller + # increase zoom wheel gain + self._controller.controls.update({"wheel": (zoom, "push", -0.005)}) + + # restore the previous state if it exists + if state := self._last_state.get(ndim): + cam.set_state(state) + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + """Add a new Image node to the scene.""" + tex = pygfx.Texture(data, dim=2) + image = pygfx.Image( + pygfx.Geometry(grid=tex), + # depth_test=False for additive-like blending + pygfx.ImageBasicMaterial(depth_test=False), + ) + self._scene.add(image) + + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() + + # FIXME: I suspect there are more performant ways to refresh the canvas + # look into it. 
+ handle = PyGFXImageHandle(image, self.refresh) + if cmap is not None: + handle.cmap = cmap + return handle + + def add_volume( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + tex = pygfx.Texture(data, dim=3) + vol = pygfx.Volume( + pygfx.Geometry(grid=tex), + # depth_test=False for additive-like blending + pygfx.VolumeRayMaterial(interpolation="nearest", depth_test=False), + ) + self._scene.add(vol) + + if data is not None: + vol.local_position = [-0.5 * i for i in data.shape[::-1]] + self._current_shape, prev_shape = data.shape, self._current_shape + if len(prev_shape) != 3: + self.set_range() + + # FIXME: I suspect there are more performant ways to refresh the canvas + # look into it. + handle = PyGFXImageHandle(vol, self.refresh) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0.05, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + if not self._scene.children or self._camera is None: + return + + cam = self._camera + cam.show_object(self._scene) + + width, height, depth = np.ptp(self._scene.get_world_bounding_box(), axis=0) + if width < 0.01: + width = 1 + if height < 0.01: + height = 1 + cam.width = width + cam.height = height + cam.zoom = 1 - margin + self.refresh() + + def refresh(self) -> None: + self._canvas.update() + self._canvas.request_draw(self._animate) + + def _animate(self) -> None: + self._renderer.render(self._scene, self._camera) + + # def _on_mouse_move(self, event: SceneMouseEvent) -> None: + # """Mouse moved on the canvas, display the pixel value and position.""" + # images = [] + # # Get the images the mouse is over + # seen = set() + # while visual := self._canvas.visual_at(event.pos): + # if isinstance(visual, scene.visuals.Image): + # images.append(visual) + # visual.interactive = False + # seen.add(visual) + # for visual in seen: + # visual.interactive = True + # if not images: + # return + + # tform = images[0].get_transform("canvas", "visual") + # px, py, *_ = (int(x) for x in tform.map(event.pos)) + # text = f"[{py}, {px}]" + # for c, img in enumerate(images): + # with suppress(IndexError): + # text += f" c{c}: {img._data[py, px]}" + # self._set_info(text) diff --git a/src/ndv/viewer2/_backends/_vispy.py b/src/ndv/viewer2/_backends/_vispy.py new file mode 100644 index 0000000..1460fa8 --- /dev/null +++ b/src/ndv/viewer2/_backends/_vispy.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import warnings +from contextlib import suppress +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import numpy as np +import vispy +import vispy.gloo +import vispy.scene +import vispy.visuals +from superqt.utils import qthrottled +from vispy import scene +from vispy.util.quaternion import Quaternion + +if TYPE_CHECKING: + import cmap + import vispy.gloo.glir + import vispy.gloo.texture + from qtpy.QtWidgets import QWidget + from vispy.scene.events import SceneMouseEvent + +turn = np.sin(np.pi / 4) +DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) + + +class VispyImageHandle: + def __init__(self, visual: scene.Image | scene.Volume) -> None: + self._visual = visual + self._ndim = 2 if isinstance(visual, scene.Image) else 3 + + @property + def data(self) -> np.ndarray: + try: + return self._visual._data # type: ignore 
[no-any-return] + except AttributeError: + return self._visual._last_data # type: ignore [no-any-return] + + @data.setter + def data(self, data: np.ndarray) -> None: + if not data.ndim == self._ndim: + warnings.warn( + f"Got wrong number of dimensions ({data.ndim}) for vispy " + f"visual of type {type(self._visual)}.", + stacklevel=2, + ) + return + self._visual.set_data(data) + + def clear(self) -> None: + offset = (0,) * self.data.ndim + self.directly_set_texture_offset( + np.zeros(self.data.shape, dtype=self.data.dtype), offset + ) + + def directly_set_texture_offset(self, data: np.ndarray, offset: tuple) -> None: + """LOW-LEVEL: Set the texture data at offset directly. + + We are bypassing all data transformations and checks here, so data *must* be + the correct shape and dtype. + """ + if self._ndim == 3: + if data.ndim == 3: + data = data.reshape((*data.shape, 1)) + texture = cast("vispy.gloo.texture.Texture3D", self._visual._texture) + queue = cast("vispy.gloo.glir.GlirQueue", texture._glir) + queue.command("DATA", texture._id, offset, data) + + @property + def visible(self) -> bool: + return bool(self._visual.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._visual.visible = visible + + @property + def clim(self) -> Any: + return self._visual.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + with suppress(ZeroDivisionError): + self._visual.clim = clims + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap + self._visual.cmap = cmap.to_vispy() + + @property + def transform(self) -> np.ndarray: + raise NotImplementedError + + @transform.setter + def transform(self, transform: np.ndarray) -> None: + raise NotImplementedError + + def remove(self) -> None: + self._visual.parent = None + + +class VispyViewerCanvas: + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + self._canvas = scene.SceneCanvas() + self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) + self._current_shape: tuple[int, ...] = () + self._last_state: dict[Literal[2, 3], Any] = {} + + central_wdg: scene.Widget = self._canvas.central_widget + self._view: scene.ViewBox = central_wdg.add_view() + self._ndim: Literal[2, 3] | None = None + + @property + def _camera(self) -> vispy.scene.cameras.BaseCamera: + return self._view.camera + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim == self._ndim: + return + elif self._ndim is not None: + # remember the current state before switching to the new camera + self._last_state[self._ndim] = self._camera.get_state() + + self._ndim = ndim + if ndim == 3: + cam = scene.ArcballCamera(fov=0) + # this sets the initial view similar to what the panzoom view would have. 
+ cam._quaternion = DEFAULT_QUATERNION + else: + cam = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + + # restore the previous state if it exists + if state := self._last_state.get(ndim): + cam.set_state(state) + self._view.camera = cam + + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas.native) + + def refresh(self) -> None: + self._canvas.update() + + def add_image( + self, + data: np.ndarray | None = None, + cmap: cmap.Colormap | None = None, + offset: tuple[float, float] | None = None, # (Y, X) + ) -> VispyImageHandle: + """Add a new Image node to the scene.""" + img = scene.visuals.Image(data, parent=self._view.scene) + img.set_gl_state("additive", depth_test=False) + img.interactive = True + + if offset: + img.transform = scene.STTransform(translate=offset[::-1]) + + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() + handle = VispyImageHandle(img) + if cmap is not None: + handle.cmap = cmap + return handle + + def add_volume( + self, + data: np.ndarray | None = None, + cmap: cmap.Colormap | None = None, + offset: tuple[float, float, float] | None = None, # (Z, Y, X) + ) -> VispyImageHandle: + vol = scene.visuals.Volume( + data, + parent=self._view.scene, + interpolation="nearest", + texture_format="auto", + ) + vol.set_gl_state("additive", depth_test=False) + vol.interactive = True + if offset: + vol.transform = scene.STTransform(translate=offset[::-1]) + + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if len(prev_shape) != 3: + self.set_range() + handle = VispyImageHandle(vol) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0.01, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + if len(self._current_shape) >= 2: + if x is None: + x = (0, self._current_shape[-1]) + if y is None: + y = (0, self._current_shape[-2]) + if z is None and len(self._current_shape) == 3: + z = (0, self._current_shape[-3]) + is_3d = isinstance(self._camera, scene.ArcballCamera) + if is_3d: + self._camera._quaternion = DEFAULT_QUATERNION + self._view.camera.set_range(x=x, y=y, z=z, margin=margin) + if is_3d: + max_size = max(self._current_shape) + self._camera.scale_factor = max_size + 6 + + def _on_mouse_move(self, event: SceneMouseEvent) -> None: + """Mouse moved on the canvas, display the pixel value and position.""" + images = [] + # Get the images the mouse is over + # FIXME: this is narsty ... 
there must be a better way to do this + seen = set() + try: + while visual := self._canvas.visual_at(event.pos): + if isinstance(visual, scene.visuals.Image): + images.append(visual) + visual.interactive = False + seen.add(visual) + except Exception: + return + for visual in seen: + visual.interactive = True + if not images: + return + + tform = images[0].get_transform("canvas", "visual") + px, py, *_ = (int(x) for x in tform.map(event.pos)) + text = f"[{py}, {px}]" + for c, img in enumerate(reversed(images)): + with suppress(IndexError): + value = img._data[py, px] + if isinstance(value, (np.floating, float)): + value = f"{value:.2f}" + text += f" {c}: {value}" + self._set_info(text) diff --git a/src/ndv/viewer2/_backends/protocols.py b/src/ndv/viewer2/_backends/protocols.py new file mode 100644 index 0000000..f8895af --- /dev/null +++ b/src/ndv/viewer2/_backends/protocols.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol + +if TYPE_CHECKING: + import cmap + import numpy as np + from qtpy.QtWidgets import QWidget + + +class PImageHandle(Protocol): + @property + def data(self) -> np.ndarray: ... + @data.setter + def data(self, data: np.ndarray) -> None: ... + def directly_set_texture_offset(self, data: np.ndarray, offset: tuple) -> None: ... + @property + def visible(self) -> bool: ... + @visible.setter + def visible(self, visible: bool) -> None: ... + @property + def clim(self) -> Any: ... + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: ... + @property + def cmap(self) -> Any: ... + @cmap.setter + def cmap(self, cmap: Any) -> None: ... + def remove(self) -> None: ... + + +class PCanvas(Protocol): + def __init__(self, set_info: Callable[[str], None]) -> None: ... + def set_ndim(self, ndim: Literal[2, 3]) -> None: ... + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = ..., + ) -> None: ... + def refresh(self) -> None: ... + def qwidget(self) -> QWidget: ... + def add_image( + self, + data: np.ndarray | None = ..., + cmap: cmap.Colormap | None = ..., + ) -> PImageHandle: ... + def add_volume( + self, + data: np.ndarray | None = ..., + cmap: cmap.Colormap | None = ..., + ) -> PImageHandle: ... 
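Note (illustrative only, not part of this diff): the backend canvases above are chosen by `get_canvas()` in `viewer2/_backends/__init__.py`, which prefers an already-imported backend, otherwise falls back to whichever of vispy or pygfx is installed, and can be forced with the `NDV_CANVAS_BACKEND` environment variable. A minimal Python sketch of that selection, assuming one of the backends is installed:

    # hypothetical usage sketch -- not part of the patch itself
    import os

    # force the vispy backend; "pygfx" would select PyGFXViewerCanvas instead
    os.environ["NDV_CANVAS_BACKEND"] = "vispy"

    from ndv.viewer2._backends import get_canvas

    CanvasCls = get_canvas()  # -> VispyViewerCanvas
    # CanvasCls satisfies the PCanvas protocol defined above; instantiating it
    # requires a running Qt application, since qwidget() returns a native widget.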
diff --git a/src/ndv/viewer2/_components.py b/src/ndv/viewer2/_components.py new file mode 100644 index 0000000..68d0675 --- /dev/null +++ b/src/ndv/viewer2/_components.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from enum import Enum +from pathlib import Path + +from qtpy.QtCore import QSize +from qtpy.QtGui import QMovie +from qtpy.QtWidgets import QLabel, QPushButton, QWidget +from superqt import QIconifyIcon + +SPIN_GIF = str(Path(__file__).parent / "spin.gif") + + +class DimToggleButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + icn = QIconifyIcon("f7:view-2d", color="#333333") + icn.addKey("f7:view-3d", state=QIconifyIcon.State.On, color="white") + super().__init__(icn, "", parent) + self.setCheckable(True) + self.setChecked(True) + + +class QSpinner(QLabel): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + size = QSize(16, 16) + mov = QMovie(SPIN_GIF, parent=self) + self.setFixedSize(size) + mov.setScaledSize(size) + mov.setSpeed(150) + mov.start() + self.setMovie(mov) + self.hide() + + +class ChannelMode(str, Enum): + COMPOSITE = "composite" + MONO = "mono" + + def __str__(self) -> str: + return self.value + + +class ChannelModeButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self.setCheckable(True) + self.toggled.connect(self.next_mode) + + # set minimum width to the width of the larger string 'composite' + self.setMinimumWidth(92) # magic number :/ + + def next_mode(self) -> None: + if self.isChecked(): + self.setMode(ChannelMode.MONO) + else: + self.setMode(ChannelMode.COMPOSITE) + + def mode(self) -> ChannelMode: + return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE + + def setMode(self, mode: ChannelMode) -> None: + # we show the name of the next mode, not the current one + other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO + self.setText(str(other)) + self.setChecked(mode == ChannelMode.MONO) diff --git a/src/ndv/viewer2/_data_wrapper.py b/src/ndv/viewer2/_data_wrapper.py new file mode 100644 index 0000000..366e34f --- /dev/null +++ b/src/ndv/viewer2/_data_wrapper.py @@ -0,0 +1,418 @@ +"""In this module, we provide built-in support for many array types.""" + +from __future__ import annotations + +import logging +import sys +import warnings +from abc import abstractmethod +from contextlib import suppress +from typing import ( + TYPE_CHECKING, + ClassVar, + Container, + Generic, + Hashable, + Iterator, + Mapping, + Sequence, + TypeVar, +) + +import numpy as np + +if TYPE_CHECKING: + from pathlib import Path + from typing import Any, Protocol, TypeAlias, TypeGuard + + import dask.array as da + import numpy.typing as npt + import pyopencl.array as cl_array + import sparse + import tensorstore as ts + import torch + import xarray as xr + import zarr + from torch._tensor import Tensor + + from ._dims_slider import Index, Indices, Sizes + + _T_contra = TypeVar("_T_contra", contravariant=True) + + class SupportsIndexing(Protocol): + def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... + @property + def shape(self) -> tuple[int, ...]: ... + + class SupportsDunderLT(Protocol[_T_contra]): + def __lt__(self, other: _T_contra, /) -> bool: ... + + class SupportsDunderGT(Protocol[_T_contra]): + def __gt__(self, other: _T_contra, /) -> bool: ... 
+ + SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any] + + +ArrayT = TypeVar("ArrayT") +_T = TypeVar("_T", bound=type) + +# Global executor for slice requests +# _EXECUTOR = ThreadPoolExecutor(max_workers=2) + + +def _recurse_subclasses(cls: _T) -> Iterator[_T]: + for subclass in cls.__subclasses__(): + yield subclass + yield from _recurse_subclasses(subclass) + + +class DataWrapper(Generic[ArrayT]): + """Interface for wrapping different array-like data types. + + `DataWrapper.create` is a factory method that returns a DataWrapper instance + for the given data type. If your datastore type is not supported, you may implement + a new DataWrapper subclass to handle your data type. To do this, import and + subclass DataWrapper, and (minimally) implement the supports and isel methods. + Ensure that your class is imported before the DataWrapper.create method is called, + and it will be automatically detected and used to wrap your data. + """ + + # Order in which subclasses are checked for support. + # Lower numbers are checked first, and the first supporting subclass is used. + # Default is 50, and fallback to numpy-like duckarray is 100. + # Subclasses can override this to change the priority in which they are checked + PRIORITY: ClassVar[SupportsRichComparison] = 50 + # These names will be checked when looking for a channel axis + COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c") + # Maximum dimension size consider when guessing the channel axis + MAX_CHANNELS = 16 + + @classmethod + def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if isinstance(data, DataWrapper): + return data + # check subclasses for support + # This allows users to define their own DataWrapper subclasses which will + # be automatically detected (assuming they have been imported by this point) + for subclass in sorted(_recurse_subclasses(cls), key=lambda x: x.PRIORITY): + with suppress(Exception): + if subclass.supports(data): + logging.debug(f"Using {subclass.__name__} to wrap {type(data)}") + return subclass(data) + raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + + @classmethod + @abstractmethod + def supports(cls, obj: Any) -> bool: + """Return True if this wrapper can handle the given object. + + Any exceptions raised by this method will be suppressed, so it is safe to + directly import necessary dependencies without a try/except block. + """ + raise NotImplementedError + + def __init__(self, data: ArrayT) -> None: + self._data = data + self._name2index: dict[str, int] = {} + if names := self.dimension_names(): + self._name2index = {name: i for i, name in enumerate(names)} + + @property + def data(self) -> ArrayT: + return self._data + + # @abstractmethod + # def isel(self, indexers: Indices) -> np.ndarray: + # """Select a slice from a data store using (possibly) named indices. + + # This follows the xarray-style indexing, where indexers is a mapping of + # dimension names to indices or slices. Subclasses should implement this + # method to return a numpy array. + # """ + # raise NotImplementedError + + def shape(self) -> tuple[int, ...]: + return self._data.shape # type: ignore + + def __getitem__(self, index: tuple[int | slice, ...]) -> np.ndarray: + # reimplement in subclasses + return np.asarray(self._data[index]) # type: ignore [index] + + def chunks(self) -> tuple[int, ...] 
| int | None: + if chunks := getattr(self._data, "chunks", None): + if isinstance(chunks, Sequence) and all(isinstance(x, int) for x in chunks): + return tuple(chunks) + warnings.warn( + f"Unexpected chunks attribute: {chunks!r}. Ignoring.", stacklevel=2 + ) + return None + + def dimension_names(self) -> tuple[str, ...] | None: + """Return the names of the dimensions of the data.""" + return None + + def to_conventional(self, indexers: Indices) -> tuple[int | slice, ...]: + """Convert named indices to a tuple of integers and slices.""" + _indexers = {self._name2index.get(str(k), k): v for k, v in indexers.items()} + return tuple(_indexers.get(k, slice(None)) for k in range(len(self.shape()))) + + # def isel_async( + # self, indexers: list[Indices] + # ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + # """Asynchronous version of isel.""" + # return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) + + def guess_channel_axis(self) -> int | None: + """Return the (best guess) axis name for the channel dimension.""" + # for arrays with labeled dimensions, + # see if any of the dimensions are named "channel" + shape = self.shape() + if names := self.dimension_names(): + for ax, name in enumerate(names): + if ( + name.lower() in self.COMMON_CHANNEL_NAMES + and shape[ax] <= self.MAX_CHANNELS + ): + return ax + + # for shaped arrays, use the smallest dimension as the channel axis + with suppress(ValueError): + if (smallest_dim := min(shape)) <= self.MAX_CHANNELS: + return shape.index(smallest_dim) + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + raise NotImplementedError("save_as_zarr not implemented for this data type.") + + def sizes(self) -> Sizes: + """Return a mapping of {dimkey: size} for the data. + + The default implementation uses the shape attribute of the data, and + tries to find dimension names in the `dims`, `names`, or `labels` attributes. + (`dims` is used by xarray, `names` is used by torch, etc...). If no labels + are found, the dimensions are just named by their integer index. 
+ """ + shape = self.shape() + if (names := self.dimension_names()) and len(names) == len(shape): + return dict(zip(names, shape)) + return dict(enumerate(shape)) + + def summary_info(self) -> str: + """Return info label with information about the data.""" + package = getattr(self._data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" + + if sizes := self.sizes(): + # if all of the dimension keys are just integers, omit them from size_str + if all(isinstance(x, int) for x in sizes): + size_str = repr(tuple(sizes.values())) + # otherwise, include the keys in the size_str + else: + size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(self._data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(self._data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + return info + + +class XarrayWrapper(DataWrapper["xr.DataArray"]): + """Wrapper for xarray DataArray objects.""" + + def isel(self, indexers: Indices) -> np.ndarray: + return np.asarray(self._data.isel(indexers)) + + def sizes(self) -> Mapping[Hashable, int]: + return {k: int(v) for k, v in self._data.sizes.items()} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: + if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(save_loc) + + +class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): + """Wrapper for tensorstore.TensorStore objects.""" + + def __init__(self, data: Any) -> None: + super().__init__(data) + import tensorstore as ts + + self._ts = ts + + def sizes(self) -> Mapping[Hashable, int]: + return {dim.label: dim.size for dim in self._data.domain} + + def isel(self, indexers: Indices) -> np.ndarray: + result = ( + self._data[self._ts.d[tuple(indexers)][tuple(indexers.values())]] + .read() + .result() + ) + return np.asarray(result) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + +class ArrayLikeWrapper(DataWrapper, Generic[ArrayT]): + """Wrapper for numpy duck array-like objects.""" + + PRIORITY = 100 + + def isel(self, indexers: Indices) -> np.ndarray: + idx = [] + for k in range(len(self._data.shape)): + i = indexers.get(k, slice(None)) + if isinstance(i, int): + idx.extend([i, None]) + else: + idx.append(i) + + return self._asarray(self._data[tuple(idx)]) + + def _asarray(self, data: npt.ArrayLike) -> np.ndarray: + return np.asarray(data) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + or hasattr(obj, "__array__") + ) + and hasattr(obj, "__getitem__") + and hasattr(obj, "shape") + ): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + try: + import zarr + except ImportError: + raise ImportError("zarr is required to save this data type.") from None + + if isinstance(self._data, zarr.Array): + self._data.store = zarr.DirectoryStore(save_loc) + else: + zarr.save(str(save_loc), self._data) + + +class DaskWrapper(DataWrapper["da.Array"]): + """Wrapper for dask array objects.""" + + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) 
+ return np.asarray(self._data[idx].compute()) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(url=str(save_loc)) + + +class CLArrayWrapper(ArrayLikeWrapper["cl_array.Array"]): + """Wrapper for pyopencl array objects.""" + + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[cl_array.Array]: + if (cl_array := sys.modules.get("pyopencl.array")) and isinstance( + obj, cl_array.Array + ): + return True + return False + + def _asarray(self, data: cl_array.Array) -> np.ndarray: + return np.asarray(data.get()) + + +class SparseArrayWrapper(ArrayLikeWrapper["sparse.Array"]): + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[sparse.COO]: + if (sparse := sys.modules.get("sparse")) and isinstance(obj, sparse.COO): + return True + return False + + def _asarray(self, data: sparse.COO) -> np.ndarray: + return np.asarray(data.todense()) + + +class ZarrArrayWrapper(ArrayLikeWrapper["zarr.Array"]): + """Wrapper for zarr array objects.""" + + PRIORITY = 50 + + def __init__(self, data: Any) -> None: + super().__init__(data) + self._name2index: dict[Hashable, int] + if "_ARRAY_DIMENSIONS" in data.attrs: + self._name2index = { + name: i for i, name in enumerate(data.attrs["_ARRAY_DIMENSIONS"]) + } + else: + self._name2index = {i: i for i in range(data.ndim)} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[zarr.Array]: + if (zarr := sys.modules.get("zarr")) and isinstance(obj, zarr.Array): + return True + return False + + def sizes(self) -> Sizes: + return dict(zip(self._name2index, self.data.shape)) + + def isel(self, indexers: Indices) -> np.ndarray: + # convert possibly named indices to integer indices + real_indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + return super().isel(real_indexers) + + +class TorchTensorWrapper(DataWrapper["torch.Tensor"]): + """Wrapper for torch tensor objects.""" + + def __init__(self, data: Tensor) -> None: + super().__init__(data) + self._name2index: dict[Hashable, int] + if names := getattr(data, "names", None): + # names may be something like (None, None, None)... 
+ self._name2index = { + (i if name is None else name): i for i, name in enumerate(names) + } + else: + self._name2index = {i: i for i in range(data.ndim)} + + def sizes(self) -> Sizes: + return dict(zip(self._name2index, self.data.shape)) + + def isel(self, indexers: Indices) -> np.ndarray: + # convert possibly named indices to integer indices + real_indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + # convert to tuple of slices + idx = tuple(real_indexers.get(i, slice(None)) for i in range(self.data.ndim)) + return self.data[idx].numpy(force=True) # type: ignore [no-any-return] + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[torch.Tensor]: + if (torch := sys.modules.get("torch")) and isinstance(obj, torch.Tensor): + return True + return False diff --git a/src/ndv/viewer2/_dims_slider.py b/src/ndv/viewer2/_dims_slider.py new file mode 100644 index 0000000..69db980 --- /dev/null +++ b/src/ndv/viewer2/_dims_slider.py @@ -0,0 +1,529 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast +from warnings import warn + +from qtpy.QtCore import QPoint, QPointF, QSize, Qt, Signal +from qtpy.QtGui import QCursor, QResizeEvent +from qtpy.QtWidgets import ( + QDialog, + QDoubleSpinBox, + QFormLayout, + QFrame, + QHBoxLayout, + QLabel, + QPushButton, + QSizePolicy, + QSlider, + QSpinBox, + QVBoxLayout, + QWidget, +) +from superqt import QLabeledRangeSlider +from superqt.iconify import QIconifyIcon +from superqt.utils import signals_blocked + +if TYPE_CHECKING: + from typing import Hashable, Mapping, TypeAlias + + from qtpy.QtGui import QResizeEvent + + # any hashable represent a single dimension in an ND array + DimKey: TypeAlias = Hashable + # any object that can be used to index a single dimension in an ND array + Index: TypeAlias = int | slice + # a mapping from dimension keys to indices (eg. 
{"x": 0, "y": slice(5, 10)}) + # this object is used frequently to query or set the currently displayed slice + Indices: TypeAlias = Mapping[DimKey, Index] + # mapping of dimension keys to the maximum value for that dimension + Sizes: TypeAlias = Mapping[DimKey, int] + + +SS = """ +QSlider::groove:horizontal { + height: 15px; + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(128, 128, 128, 0.25), + stop:1 rgba(128, 128, 128, 0.1) + ); + border-radius: 3px; +} + +QSlider::handle:horizontal { + width: 38px; + background: #999999; + border-radius: 3px; +} + +QLabel { font-size: 12px; } + +QRangeSlider { qproperty-barColor: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(100, 80, 120, 0.2), + stop:1 rgba(100, 80, 120, 0.4) + )} + +SliderLabel { + font-size: 12px; + color: white; +} +""" + + +class QtPopup(QDialog): + """A generic popup window.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setModal(False) # if False, then clicking anywhere else closes it + self.setWindowFlags(Qt.WindowType.Popup | Qt.WindowType.FramelessWindowHint) + + self.frame = QFrame(self) + layout = QVBoxLayout(self) + layout.addWidget(self.frame) + layout.setContentsMargins(0, 0, 0, 0) + + def show_above_mouse(self, *args: Any) -> None: + """Show popup dialog above the mouse cursor position.""" + pos = QCursor().pos() # mouse position + szhint = self.sizeHint() + pos -= QPoint(szhint.width() // 2, szhint.height() + 14) + self.move(pos) + self.resize(self.sizeHint()) + self.show() + + +class PlayButton(QPushButton): + """Just a styled QPushButton that toggles between play and pause icons.""" + + fpsChanged = Signal(float) + + PLAY_ICON = "bi:play-fill" + PAUSE_ICON = "bi:pause-fill" + + def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.PLAY_ICON, color="#888888") + icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") + super().__init__(icn, "", parent) + self.spin = QDoubleSpinBox(self) + self.spin.setRange(0.5, 100) + self.spin.setValue(fps) + self.spin.valueChanged.connect(self.fpsChanged) + self.setCheckable(True) + self.setFixedSize(14, 18) + self.setIconSize(QSize(16, 16)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + self._popup = QtPopup(self) + form = QFormLayout(self._popup.frame) + form.setContentsMargins(6, 6, 6, 6) + form.addRow("FPS", self.spin) + + def mousePressEvent(self, e: Any) -> None: + if e and e.button() == Qt.MouseButton.RightButton: + self._show_fps_dialog(e.globalPosition()) + else: + super().mousePressEvent(e) + + def _show_fps_dialog(self, pos: QPointF) -> None: + self._popup.show_above_mouse() + + +class LockButton(QPushButton): + LOCK_ICON = "uis:unlock" + UNLOCK_ICON = "uis:lock" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.LOCK_ICON, color="#888888") + icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On, color="red") + super().__init__(icn, text, parent) + self.setCheckable(True) + self.setFixedSize(20, 20) + self.setIconSize(QSize(14, 14)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + +class DimsSlider(QWidget): + """A single slider in the DimsSliders widget. + + Provides a play/pause button that toggles animation of the slider value. + Has a QLabeledSlider for the actual value. + Adds a label for the maximum value (e.g. 
"3 / 10") + """ + + valueChanged = Signal(object, object) # where object is int | slice + + def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setStyleSheet(SS) + self._slice_mode = False + self._dim_key = dimension_key + + self._timer_id: int | None = None # timer for play button + self._play_btn = PlayButton(parent=self) + self._play_btn.fpsChanged.connect(self.set_fps) + self._play_btn.toggled.connect(self._toggle_animation) + + self._dim_key = dimension_key + self._dim_label = QLabel(str(dimension_key)) + self._dim_label.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred + ) + self._dim_label.setToolTip("Double-click to toggle slice mode") + + # note, this lock button only prevents the slider from updating programmatically + # using self.setValue, it doesn't prevent the user from changing the value. + self._lock_btn = LockButton(parent=self) + + self._pos_label = QSpinBox(self) + self._pos_label.valueChanged.connect(self._on_pos_label_edited) + self._pos_label.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) + self._pos_label.setAlignment(Qt.AlignmentFlag.AlignRight) + self._pos_label.setStyleSheet( + "border: none; padding: 0; margin: 0; background: transparent" + ) + self._out_of_label = QLabel(self) + + self._int_slider = QSlider(Qt.Orientation.Horizontal) + self._int_slider.rangeChanged.connect(self._on_range_changed) + self._int_slider.valueChanged.connect(self._on_int_value_changed) + + self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) + slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) + slc.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + slc.setVisible(False) + slc.rangeChanged.connect(self._on_range_changed) + slc.valueChanged.connect(self._on_slice_value_changed) + + self.installEventFilter(self) + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(2) + layout.addWidget(self._play_btn) + layout.addWidget(self._dim_label) + layout.addWidget(self._int_slider) + layout.addWidget(self._slice_slider) + layout.addWidget(self._pos_label) + layout.addWidget(self._out_of_label) + layout.addWidget(self._lock_btn) + self.setMinimumHeight(26) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + if isinstance(par := self.parent(), DimsSliders): + par.resizeEvent(None) + + def mouseDoubleClickEvent(self, a0: Any) -> None: + self._set_slice_mode(not self._slice_mode) + super().mouseDoubleClickEvent(a0) + + def containMaximum(self, max_val: int) -> None: + if max_val > self._int_slider.maximum(): + self._int_slider.setMaximum(max_val) + if max_val > self._slice_slider.maximum(): + self._slice_slider.setMaximum(max_val) + + def setMaximum(self, max_val: int) -> None: + self._int_slider.setMaximum(max_val) + self._slice_slider.setMaximum(max_val) + + def setMinimum(self, min_val: int) -> None: + self._int_slider.setMinimum(min_val) + self._slice_slider.setMinimum(min_val) + + def containMinimum(self, min_val: int) -> None: + if min_val < self._int_slider.minimum(): + self._int_slider.setMinimum(min_val) + if min_val < self._slice_slider.minimum(): + self._slice_slider.setMinimum(min_val) + + def setRange(self, min_val: int, max_val: int) -> None: + self._int_slider.setRange(min_val, max_val) + self._slice_slider.setRange(min_val, max_val) + + def value(self) -> Index: + if not self._slice_mode: + return self._int_slider.value() # type: ignore + start, *_, stop = cast("tuple[int, ...]", 
self._slice_slider.value()) + if start == stop: + return start + return slice(start, stop) + + def setValue(self, val: Index) -> None: + # variant of setValue that always updates the maximum + self._set_slice_mode(isinstance(val, slice)) + if self._lock_btn.isChecked(): + return + if isinstance(val, slice): + start = int(val.start) if val.start is not None else 0 + stop = ( + int(val.stop) if val.stop is not None else self._slice_slider.maximum() + ) + self._slice_slider.setValue((start, stop)) + else: + self._int_slider.setValue(val) + # self._slice_slider.setValue((val, val + 1)) + + def forceValue(self, val: Index) -> None: + """Set value and increase range if necessary.""" + if isinstance(val, slice): + if isinstance(val.start, int): + self.containMinimum(val.start) + if isinstance(val.stop, int): + self.containMaximum(val.stop) + else: + self.containMinimum(val) + self.containMaximum(val) + self.setValue(val) + + def _set_slice_mode(self, mode: bool = True) -> None: + if mode == self._slice_mode: + return + self._slice_mode = bool(mode) + self._slice_slider.setVisible(self._slice_mode) + self._int_slider.setVisible(not self._slice_mode) + # self._pos_label.setVisible(not self._slice_mode) + self.valueChanged.emit(self._dim_key, self.value()) + + def set_fps(self, fps: float) -> None: + self._play_btn.spin.setValue(fps) + self._toggle_animation(self._play_btn.isChecked()) + + def _toggle_animation(self, checked: bool) -> None: + if checked: + if self._timer_id is not None: + self.killTimer(self._timer_id) + interval = int(1000 / self._play_btn.spin.value()) + self._timer_id = self.startTimer(interval) + elif self._timer_id is not None: + self.killTimer(self._timer_id) + self._timer_id = None + + def timerEvent(self, event: Any) -> None: + """Handle timer event for play button, move to the next frame.""" + # TODO + # for now just increment the value by 1, but we should be able to + # take FPS into account better and skip additional frames if the timerEvent + # is delayed for some reason. 
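+        # (a possible refinement, not implemented here: track the time of the last
+        # tick and use `inc = max(1, round(elapsed_seconds * fps))` so that delayed
+        # timer events skip frames instead of falling behind)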
+ inc = 1 + if self._slice_mode: + val = cast(tuple[int, int], self._slice_slider.value()) + next_val = [v + inc for v in val] + if next_val[1] > self._slice_slider.maximum(): + # wrap around, without going below the min handle + next_val = [v - val[0] for v in val] + self._slice_slider.setValue(next_val) + else: + ival = self._int_slider.value() + ival = (ival + inc) % (self._int_slider.maximum() + 1) + self._int_slider.setValue(ival) + + def _on_pos_label_edited(self) -> None: + if self._slice_mode: + self._slice_slider.setValue( + (self._slice_slider.value()[0], self._pos_label.value()) + ) + else: + self._int_slider.setValue(self._pos_label.value()) + + def _on_range_changed(self, min: int, max: int) -> None: + self._out_of_label.setText(f"| {max}") + self._pos_label.setRange(min, max) + self.resizeEvent(None) + self.setVisible(min != max) + + def setVisible(self, visible: bool) -> None: + if self._has_no_range(): + visible = False + super().setVisible(visible) + + def _has_no_range(self) -> bool: + if self._slice_mode: + return bool(self._slice_slider.minimum() == self._slice_slider.maximum()) + return bool(self._int_slider.minimum() == self._int_slider.maximum()) + + def _on_int_value_changed(self, value: int) -> None: + self._pos_label.setValue(value) + if not self._slice_mode: + self.valueChanged.emit(self._dim_key, value) + + def _on_slice_value_changed(self, value: tuple[int, int]) -> None: + self._pos_label.setValue(int(value[1])) + with signals_blocked(self._int_slider): + self._int_slider.setValue(int(value[0])) + if self._slice_mode: + self.valueChanged.emit(self._dim_key, slice(value[0], value[1] + 1)) + + +class DimsSliders(QWidget): + """A Collection of DimsSlider widgets for each dimension in the data. + + Maintains the global current index and emits a signal when it changes. + """ + + valueChanged = Signal(dict) # dict is of type Indices + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._locks_visible: bool | Mapping[DimKey, bool] = False + self._sliders: dict[DimKey, DimsSlider] = {} + self._current_index: dict[DimKey, Index] = {} + self._invisible_dims: set[DimKey] = set() + + self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + def __contains__(self, key: DimKey) -> bool: + """Return True if the dimension key is present in the DimsSliders.""" + return key in self._sliders + + def slider(self, key: DimKey) -> DimsSlider: + """Return the DimsSlider widget for the given dimension key.""" + return self._sliders[key] + + def value(self) -> Indices: + """Return mapping of {dim_key -> current index} for each dimension.""" + return self._current_index.copy() + + def setValue(self, values: Indices) -> None: + """Set the current index for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int | slice] + Mapping of {dim_key -> index} for each dimension. If value is a slice, + the slider will be in slice mode. If the dimension is not present in the + DimsSliders, it will be added. + """ + if self._current_index == values: + return + with signals_blocked(self): + for dim, index in values.items(): + self.add_or_update_dimension(dim, index) + # FIXME: i don't know why this this is ever empty ... 
only happens on pyside6 + if val := self.value(): + self.valueChanged.emit(val) + + def minima(self) -> Sizes: + """Return mapping of {dim_key -> minimum value} for each dimension.""" + return {k: v._int_slider.minimum() for k, v in self._sliders.items()} + + def setMinima(self, values: Sizes) -> None: + """Set the minimum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> minimum value} for each dimension. + """ + for name, min_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMinimum(min_val) + + def maxima(self) -> Sizes: + """Return mapping of {dim_key -> maximum value} for each dimension.""" + return {k: v._int_slider.maximum() for k, v in self._sliders.items()} + + def setMaxima(self, values: Sizes) -> None: + """Set the maximum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> maximum value} for each dimension. + """ + for name, max_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMaximum(max_val) + + def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: + """Set the visibility of the lock buttons for all dimensions.""" + self._locks_visible = visible + for dim, slider in self._sliders.items(): + viz = visible if isinstance(visible, bool) else visible.get(dim, False) + slider._lock_btn.setVisible(viz) + + def add_dimension(self, key: DimKey, val: Index | None = None) -> None: + """Add a new dimension to the DimsSliders widget. + + Parameters + ---------- + key : Hashable + The name of the dimension. + val : int | slice, optional + The initial value for the dimension. If a slice, the slider will be in + slice mode. + """ + self._sliders[key] = slider = DimsSlider(dimension_key=key, parent=self) + if isinstance(self._locks_visible, dict) and key in self._locks_visible: + slider._lock_btn.setVisible(self._locks_visible[key]) + else: + slider._lock_btn.setVisible(bool(self._locks_visible)) + + val_int = val.start if isinstance(val, slice) else val + slider.setVisible(key not in self._invisible_dims) + if isinstance(val_int, int): + slider.setRange(val_int, val_int) + elif isinstance(val_int, slice): + slider.setRange(val_int.start or 0, val_int.stop or 1) + + val = val if val is not None else 0 + self._current_index[key] = val + slider.forceValue(val) + slider.valueChanged.connect(self._on_dim_slider_value_changed) + cast("QVBoxLayout", self.layout()).addWidget(slider) + + def set_dimension_visible(self, key: DimKey, visible: bool) -> None: + """Set the visibility of a dimension in the DimsSliders widget. + + Once a dimension is hidden, it will not be shown again until it is explicitly + made visible again with this method. 
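+
+        e.g. `sliders.set_dimension_visible("z", False)` hides the "z" slider and
+        drops "z" from the current index.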
+ """ + if visible: + self._invisible_dims.discard(key) + if key in self._sliders: + self._current_index[key] = self._sliders[key].value() + else: + self._invisible_dims.add(key) + self._current_index.pop(key, None) + if key in self._sliders: + self._sliders[key].setVisible(visible) + + def remove_dimension(self, key: DimKey) -> None: + """Remove a dimension from the DimsSliders widget.""" + try: + slider = self._sliders.pop(key) + except KeyError: + warn(f"Dimension {key} not found in DimsSliders", stacklevel=2) + return + cast("QVBoxLayout", self.layout()).removeWidget(slider) + slider.deleteLater() + + def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: + self._current_index[key] = value + self.valueChanged.emit(self.value()) + + def add_or_update_dimension(self, key: DimKey, value: Index) -> None: + """Add a dimension if it doesn't exist, otherwise update the value.""" + if key in self._sliders: + self._sliders[key].forceValue(value) + else: + self.add_dimension(key, value) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + # align all labels + if sliders := list(self._sliders.values()): + for lbl in ("_dim_label", "_pos_label", "_out_of_label"): + lbl_width = max(getattr(s, lbl).sizeHint().width() for s in sliders) + for s in sliders: + getattr(s, lbl).setFixedWidth(lbl_width) + + super().resizeEvent(a0) + + def sizeHint(self) -> QSize: + return super().sizeHint().boundedTo(QSize(9999, 0)) diff --git a/src/ndv/viewer2/_lut_control.py b/src/ndv/viewer2/_lut_control.py new file mode 100644 index 0000000..1b9c1b3 --- /dev/null +++ b/src/ndv/viewer2/_lut_control.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence, cast + +import numpy as np +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QPushButton, QWidget +from superqt import QLabeledRangeSlider +from superqt.cmap import QColormapComboBox +from superqt.utils import signals_blocked + +from ._dims_slider import SS + +if TYPE_CHECKING: + from typing import Iterable + + import cmap + + from ._backends.protocols import PImageHandle + + +class CmapCombo(QColormapComboBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") + self.setMinimumSize(120, 21) + # self.setStyleSheet("background-color: transparent;") + + def showPopup(self) -> None: + super().showPopup() + popup = self.findChild(QFrame) + popup.setMinimumWidth(self.width() + 100) + popup.move(popup.x(), popup.y() - self.height() - popup.height()) + + +class LutControl(QWidget): + def __init__( + self, + channel: Sequence[PImageHandle], + name: str = "", + parent: QWidget | None = None, + cmaplist: Iterable[Any] = (), + cmap: cmap.Colormap | None = None, + ) -> None: + super().__init__(parent) + self._channel = channel + self._name = name + + self._visible = QCheckBox(name) + self._visible.setChecked(True) + self._visible.toggled.connect(self._on_visible_changed) + + self._cmap = CmapCombo() + self._cmap.currentColormapChanged.connect(self._on_cmap_changed) + for handle in channel: + self._cmap.addColormap(handle.cmap) + for color in cmaplist: + self._cmap.addColormap(color) + if cmap is not None: + self._cmap.setCurrentColormap(cmap) + + self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + self._clims.setStyleSheet(SS) + self._clims.setHandleLabelPosition( + QLabeledRangeSlider.LabelPosition.LabelsOnHandle + ) + 
self._clims.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + self._clims.setRange(0, 2**8) + self._clims.valueChanged.connect(self._on_clims_changed) + + self._auto_clim = QPushButton("Auto") + self._auto_clim.setMaximumWidth(42) + self._auto_clim.setCheckable(True) + self._auto_clim.setChecked(True) + self._auto_clim.toggled.connect(self.update_autoscale) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._visible) + layout.addWidget(self._cmap) + layout.addWidget(self._clims) + layout.addWidget(self._auto_clim) + + self.update_autoscale() + + def autoscaleChecked(self) -> bool: + return cast("bool", self._auto_clim.isChecked()) + + def _on_clims_changed(self, clims: tuple[float, float]) -> None: + self._auto_clim.setChecked(False) + for handle in self._channel: + handle.clim = clims + + def _on_visible_changed(self, visible: bool) -> None: + for handle in self._channel: + handle.visible = visible + if visible: + self.update_autoscale() + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + for handle in self._channel: + handle.cmap = cmap + + def update_autoscale(self) -> None: + if ( + not self._auto_clim.isChecked() + or not self._visible.isChecked() + or not self._channel + ): + return + + # find the min and max values for the current channel + clims = [np.inf, -np.inf] + for handle in self._channel: + clims[0] = min(clims[0], np.nanmin(handle.data)) + clims[1] = max(clims[1], np.nanmax(handle.data)) + + mi, ma = tuple(int(x) for x in clims) + if mi != ma: + for handle in self._channel: + handle.clim = (mi, ma) + + # set the slider values to the new clims + with signals_blocked(self._clims): + self._clims.setMinimum(min(mi, self._clims.minimum())) + self._clims.setMaximum(max(ma, self._clims.maximum())) + self._clims.setValue((mi, ma)) + + +def _get_default_clim_from_data(data: np.ndarray) -> tuple[float, float]: + """Compute a reasonable clim from the min and max, taking nans into account. + + If there are no non-finite values (nan, inf, -inf) this is as fast as it can be. + Otherwise, this functions is about 3x slower. + """ + # Fast + min_value = data.min() + max_value = data.max() + + # Need more work? 
The nan-functions are slower + min_finite = np.isfinite(min_value) + max_finite = np.isfinite(max_value) + if not (min_finite and max_finite): + finite_data = data[np.isfinite(data)] + if finite_data.size: + min_value = finite_data.min() + max_value = finite_data.max() + else: + min_value = max_value = 0 # no finite values in the data + + return min_value, max_value diff --git a/src/ndv/viewer2/_octree.py b/src/ndv/viewer2/_octree.py new file mode 100644 index 0000000..e1af860 --- /dev/null +++ b/src/ndv/viewer2/_octree.py @@ -0,0 +1,105 @@ +from typing import Generic, Iterator, NamedTuple, TypeVar + +MAX_DEPTH = 8 + + +class Coord(NamedTuple): + x: float + y: float + z: float = 0 + + +class Bounds(NamedTuple): + x_min: float + x_max: float + y_min: float + y_max: float + z_min: float = 0 + z_max: float = 0 + + @property + def midpoint(self) -> Coord: + return Coord( + (self.x_min + self.x_max) / 2, + (self.y_min + self.y_max) / 2, + (self.z_min + self.z_max) / 2, + ) + + def isdisjoint(self, other: "Bounds") -> bool: + return ( + self.x_max < other.x_min + or self.x_min > other.x_max + or self.y_max < other.y_min + or self.y_min > other.y_max + or self.z_max < other.z_min + or self.z_min > other.z_max + ) + + def intersects(self, other: "Bounds") -> bool: + return not self.isdisjoint(other) + + def split(self) -> tuple["Bounds", ...]: + x_min, x_max, y_min, y_max, z_min, z_max = self + x_mid, y_mid, z_mid = self.midpoint + return ( + Bounds(x_min, x_mid, y_min, y_mid, z_min, z_mid), + Bounds(x_mid, x_max, y_min, y_mid, z_min, z_mid), + Bounds(x_min, x_mid, y_mid, y_max, z_min, z_mid), + Bounds(x_mid, x_max, y_mid, y_max, z_min, z_mid), + Bounds(x_min, x_mid, y_min, y_mid, z_mid, z_max), + Bounds(x_mid, x_max, y_min, y_mid, z_mid, z_max), + Bounds(x_min, x_mid, y_mid, y_max, z_mid, z_max), + Bounds(x_mid, x_max, y_mid, y_max, z_mid, z_max), + ) + + +T = TypeVar("T") + + +class OctreeNode(Generic[T]): + def __init__(self, bounds: Bounds, depth: int = 0) -> None: + # spatial bounds of this node (xmin, xmax, ymin, ymax, zmin, zmax) + self.bounds = bounds + self.children: list[OctreeNode] = [] # children of this node + self.depth = depth # depth of the node in the tree + self.data: T | None = None # placeholder for storing references to data chunks + + def is_leaf(self) -> bool: + return not self.children + + def split(self, max_depth: int = MAX_DEPTH) -> None: + if self.depth < max_depth: + self.children = [ + OctreeNode(bounds, self.depth + 1) for bounds in self.bounds.split() + ] + + def insert_data(self, chunk: T, chunk_bounds: Bounds) -> None: + """Insert data into the tree. + + If `self` is a leaf and the depth is less than the maximum depth, `self` is + split into 8 children. The data is then inserted into the first child that + intersects with the data bounds. + """ + if self.is_leaf(): + if self.depth < MAX_DEPTH: + self.split(MAX_DEPTH) + else: + self.data = chunk + return + for child in self.children: + if child.bounds.intersects(chunk_bounds): + child.insert_data(chunk, chunk_bounds) + break + + def query(self, view_bounds: Bounds) -> Iterator[T]: + """Query the tree for data chunks that intersect with the given bounds. + + Yields data chunks that intersect with the given bounds. 
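+
+        A minimal illustrative sketch (bounds and payload values are arbitrary)::
+
+            root = OctreeNode(Bounds(0, 100, 0, 100))
+            root.insert_data("chunk-a", Bounds(0, 10, 0, 10))
+            assert list(root.query(Bounds(0, 50, 0, 50))) == ["chunk-a"]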
+ """ + if self.bounds.intersects(view_bounds): + if self.is_leaf(): + if self.data is not None: + yield self.data + else: + for child in self.children: + yield from child.query(view_bounds) diff --git a/src/ndv/viewer2/_save_button.py b/src/ndv/viewer2/_save_button.py new file mode 100644 index 0000000..0ce4511 --- /dev/null +++ b/src/ndv/viewer2/_save_button.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget +from superqt.iconify import QIconifyIcon + +if TYPE_CHECKING: + from ._data_wrapper import DataWrapper + + +class SaveButton(QPushButton): + def __init__( + self, + data_wrapper: DataWrapper, + parent: QWidget | None = None, + ): + super().__init__(parent=parent) + self.setIcon(QIconifyIcon("mdi:content-save")) + self.clicked.connect(self._on_click) + + self._data_wrapper = data_wrapper + self._last_loc = str(Path.home()) + + def _on_click(self) -> None: + self._last_loc, _ = QFileDialog.getSaveFileName( + self, "Choose destination", str(self._last_loc), "" + ) + suffix = Path(self._last_loc).suffix + if suffix in (".zarr", ".ome.zarr", ""): + self._data_wrapper.save_as_zarr(self._last_loc) + else: + raise ValueError(f"Unsupported file format: {self._last_loc}") diff --git a/src/ndv/viewer2/_state.py b/src/ndv/viewer2/_state.py new file mode 100644 index 0000000..46d5d63 --- /dev/null +++ b/src/ndv/viewer2/_state.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal, Mapping, Protocol, Sequence + +import numpy as np + +if TYPE_CHECKING: + import cmap + +# either the name or the index of a dimension +DimKey = str | int +# position or slice along a specific dimension +Index = int | slice +# name or dimension index of a channel +# string is only supported for arrays with string-type coordinates along the channel dim +# None is a special value that means all channels +ChannelKey = int | str | None + +SLOTS = {"slots": True} if sys.version_info >= (3, 10) else {} + + +@dataclass(**SLOTS) +class ChannelDisplay: + visible: bool = True + cmap: cmap._colormap.ColorStopsLike = "gray" + clims: tuple[float, float] | None = None + gamma: float = 1.0 + # whether to autoscale + # if a tuple the first element is the lower quantile + # and the second element is the upper quantile + # if True or (0, 1) use (np.min(), np.max()) ... otherwise use np.quantile + autoscale: bool | tuple[float, float] = (0, 1) + + +@dataclass(**SLOTS) +class ViewerState: + # index of the currently displayed slice + # for example (-2, -1) for the standard 2D viewer + # if string, then name2index is used to convert to index + visualized_indices: tuple[DimKey, DimKey] | tuple[DimKey, DimKey, DimKey] = (-2, -1) + + # the currently displayed position/slice along each dimension + # missing indices are assumed to be slice(None) (or 0?) 
+ # if more than len(visualized_indices) have non-integer values, then + # reducers are used to reduce the data along the remaining dimensions + current_index: Mapping[DimKey, Index] = field(default_factory=dict) + + # functions to reduce data along axes remaining after slicing + reducers: Reducer | Mapping[DimKey, Reducer] = np.max + + # note: it is an error for channel_index to be in visualized_indices + channel_index: DimKey | None = None + + # settings for each channel along the channel dimension + # None is a special value that means all channels + # if channel_index is None, then luts[None] is used + luts: Mapping[ChannelKey, ChannelDisplay] = field(default_factory=dict) + # default colormap to use for channel [0, 1, 2, ...] + colormap_options: Sequence[cmap._colormap.ColorStopsLike] = ("gray",) + + @property + def ndim(self) -> Literal[2, 3]: + return 2 if len(self.visualized_indices) == 2 else 3 + + +class Reducer(Protocol): + def __call__( + self, data: np.ndarray, /, *, axis: int | tuple[int, ...] | None + ) -> np.ndarray | float: ... diff --git a/src/ndv/viewer2/_v2.py b/src/ndv/viewer2/_v2.py new file mode 100644 index 0000000..7cce272 --- /dev/null +++ b/src/ndv/viewer2/_v2.py @@ -0,0 +1,213 @@ +from typing import TYPE_CHECKING, Any, Mapping + +import numpy as np +from qtpy.QtWidgets import QVBoxLayout, QWidget +from superqt import ensure_main_thread +from superqt.utils import qthrottled + +from ndv._chunk_executor import Chunker, ChunkFuture +from ndv.viewer2._backends import get_canvas +from ndv.viewer2._dims_slider import DimsSliders +from ndv.viewer2._state import ViewerState + +if TYPE_CHECKING: + from ndv.viewer2._backends.protocols import PCanvas, PImageHandle + + +class NDViewer(QWidget): + def __init__(self, data: Any, *, parent: QWidget | None = None): + super().__init__(parent=parent) + self._state = ViewerState(visualized_indices=(0, 2, 3)) + self._chunker = Chunker() + self._channels: dict[int | None, PImageHandle] = {} + + self._canvas: PCanvas = get_canvas()(lambda x: None) + self._canvas.set_ndim(self._state.ndim) + + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders(self) + self._dims_sliders.valueChanged.connect( + qthrottled(self._request_data_for_index, 20, leading=True) + ) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addWidget(self._canvas.qwidget(), 1) + layout.addWidget(self._dims_sliders, 0) + + if data is not None: + self.set_data(data) + + def __del__(self) -> None: + self._chunker.shutdown(cancel_futures=True, wait=False) + + def set_data(self, data: Any) -> None: + self._data = data + self._dims_sliders.setMaxima( + {i: data.shape[i] - 1 for i in range(len(data.shape))} + ) + + def set_current_index(self, index: Mapping[int | str, int | slice]) -> None: + """Set the currentl displayed index.""" + self._state.current_index = index + self.refresh() + + def refresh(self) -> None: + self._request_data_for_index(self._state.current_index) + + def _norm_index(self, index: int | str) -> int: + """Remove string keys from index.""" + # TODO: this is a temporary solution + # the datawrapper __getitem__ should handle this + dim_names = () + if index in dim_names: + return dim_names.index(index) + elif isinstance(index, int): + return index + raise ValueError(f"Invalid index: {index}") + + def _request_data_for_index(self, index: Mapping[int | str, int | slice]) -> None: + ndim = len(self._data.shape) + + # determine chunk shape + # only visualized dimensions 
are chunked + chunk_size = 128 # TODO: pick bettter + chunk_shape: list[int | None] = [None] * ndim + visualized = [self._norm_index(dim) for dim in self._state.visualized_indices] + for dim in range(ndim): + if dim in visualized: + chunk_shape[dim] = chunk_size + + index = {self._norm_index(k): v for k, v in index.items()} + for v in visualized: + if isinstance(index.get(v), int): + del index[v] + + if not index: + return + + # clear existing handles + for handle in self._channels.values(): + handle.clear() + for future in self._chunker.request_chunks( + data=self._data, + index=index, + chunk_shape=chunk_shape, + cancel_existing=True, + ): + future.add_done_callback(self._draw_chunk) + + @ensure_main_thread # type: ignore + def _draw_chunk(self, future: ChunkFuture) -> None: + if future.cancelled(): + return + if future.exception(): + print("ERROR: ", future.exception()) + return + + chunk = future.result() + data = chunk.data + offset = chunk.offset + + if self._state.channel_index is None: + channel_index = None + else: + channel_index = offset[self._norm_index(self._state.channel_index)] + + visualized = [self._norm_index(dim) for dim in self._state.visualized_indices] + offset = tuple(offset[i] for i in visualized) + + if data.ndim == 2: + return + + if not (handle := self._channels.get(channel_index)): + full_shape = self._data.shape + texture_shape = tuple( + full_shape[self._norm_index(i)] for i in self._state.visualized_indices + ) + empty = np.empty(texture_shape, dtype=chunk.data.dtype) + self._channels[channel_index] = handle = self._canvas.add_volume(empty) + + try: + mi, ma = handle.clim + handle.clim = (min(mi, np.min(data)), max(ma, np.max(data))) + except Exception as e: + print("err in clim: ", e) + handle.clim = (0, 5000) + + print(">>draw:") + print(f" data: {data.shape} @ {offset}") + handle.directly_set_texture_offset(data, offset) + self._canvas.refresh() + + # # of the chunks will determine the order of the channels in the LUTS + # # (without additional logic to sort them by index, etc.) + # if (handles := self._channels.get(ch_key)) is None: + # handles = self._create_channel(ch_key) + + # if not handles: + # if data.ndim == 2: + # handles.append(self._canvas.add_image(data, cmap=handles.cmap)) + # elif data.ndim == 3: + # empty = np.empty((60, 256, 256), dtype=np.uint16) + # handles.append(self._canvas.add_volume(empty, cmap=handles.cmap)) + + # handles[0].set_data(data, chunk.offset) + # self._canvas.refresh() + + +# class NDViewer: +# def __init__(self, data: Any, state: ViewerState | None) -> None: +# self._state = state or ViewerState() +# if data is not None: +# self.set_data(data) + +# @property +# def data(self) -> Any: +# raise NotImplementedError + +# def set_data(self, data: Any) -> None: ... + +# @property +# def state(self) -> ViewerState: +# return self._state + +# def set_state(self, state: ViewerState) -> None: +# # validate... +# self._state = state + +# def set_visualized_indices(self, indices: tuple[DimKey, DimKey]) -> None: +# """Set which indices are visualized.""" +# if self._state.channel_index in indices: +# raise ValueError( +# f"channel index ({self._state.channel_index!r}) cannot be in visualized" +# f"indices: {indices}" +# ) +# self._state.visualized_indices = indices +# self.refresh() + +# def set_channel_index(self, index: DimKey | None) -> None: +# """Set the channel index.""" +# if index in self._state.visualized_indices: +# # consider alternatives to raising. +# # e.g. 
if len(visualized_indices) == 3, then we could pop index +# raise ValueError( +# f"channel index ({index!r}) cannot be in visualized indices: " +# f"{self._state.visualized_indices}" +# ) +# self._state.channel_index = index +# self.refresh() + +# def set_current_index(self, index: Mapping[DimKey, Index]) -> None: +# """Set the currentl displayed index.""" +# self._state.current_index = index +# self.refresh() + +# def refresh(self) -> None: +# """Refresh the viewer.""" +# index = self._state.current_index +# self._chunker.request_index(index) + +# @ensure_main_thread # type: ignore +# def _draw_chunk(self, chunk: ChunkResponse) -> None: ... diff --git a/src/ndv/viewer2/_viewer.py b/src/ndv/viewer2/_viewer.py new file mode 100644 index 0000000..a4735f9 --- /dev/null +++ b/src/ndv/viewer2/_viewer.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +from itertools import cycle +from typing import ( + TYPE_CHECKING, + Hashable, + Literal, + MutableSequence, + Sequence, + SupportsIndex, + cast, + overload, +) + +import cmap +import numpy as np +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread +from superqt.utils import qthrottled, signals_blocked + +from ndv._chunking import Chunker, ChunkResponse, RequestFinished +from ndv.viewer2._components import ( + ChannelMode, + ChannelModeButton, + DimToggleButton, + QSpinner, +) + +from ._backends import get_canvas +from ._backends.protocols import PImageHandle +from ._data_wrapper import DataWrapper +from ._dims_slider import DimsSliders +from ._lut_control import LutControl + +if TYPE_CHECKING: + from typing import Any, Iterable, TypeAlias + + from qtpy.QtGui import QCloseEvent + + from ._backends.protocols import PCanvas + from ._dims_slider import DimKey, Indices, Sizes + + ImgKey: TypeAlias = Hashable + # any mapping of dimensions to sizes + SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] + +MID_GRAY = "#888888" +GRAYS = cmap.Colormap("gray") +DEFAULT_COLORMAPS = [ + cmap.Colormap("green"), + cmap.Colormap("magenta"), + cmap.Colormap("cyan"), + cmap.Colormap("yellow"), + cmap.Colormap("red"), + cmap.Colormap("blue"), + cmap.Colormap("cubehelix"), + cmap.Colormap("gray"), +] +MONO_CHANNEL = -999999 + + +class Channel(MutableSequence[PImageHandle]): + def __init__(self, ch_key: int, cmap: cmap.Colormap = GRAYS) -> None: + self.ch_key = ch_key + self._handles: list[PImageHandle] = [] + self.cmap = cmap + + @overload + def __getitem__(self, i: int) -> PImageHandle: ... + @overload + def __getitem__(self, i: slice) -> list[PImageHandle]: ... + def __getitem__(self, i: int | slice) -> PImageHandle | list[PImageHandle]: + return self._handles[i] + + @overload + def __setitem__(self, i: SupportsIndex, value: PImageHandle) -> None: ... + @overload + def __setitem__(self, i: slice, value: Iterable[PImageHandle]) -> None: ... + def __setitem__( + self, i: SupportsIndex | slice, value: PImageHandle | Iterable[PImageHandle] + ) -> None: + self._handles[i] = value # type: ignore + + @overload + def __delitem__(self, i: int) -> None: ... + @overload + def __delitem__(self, i: slice) -> None: ... + def __delitem__(self, i: int | slice) -> None: + del self._handles[i] + + def __len__(self) -> int: + return len(self._handles) + + def insert(self, i: int, value: PImageHandle) -> None: + self._handles.insert(i, value) + + +class NDViewer(QWidget): + """A viewer for ND arrays. 
+ + This widget displays a single slice from an ND array (or a composite of slices in + different colormaps). The widget provides sliders to select the slice to display, + and buttons to control the display mode of the channels. + + An important concept in this widget is the "index". The index is a mapping of + dimensions to integers or slices that define the slice of the data to display. For + example, a numpy slice of `[0, 1, 5:10]` would be represented as + `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be named, e.g. + `{'t': 0, 'c': 1, 'z': slice(5, 10)}`. The index is used to select the data from + the datastore, and to determine the position of the sliders. + + The flow of data is as follows: + + - The user sets the data using the `set_data` method. This will set the number + and range of the sliders to the shape of the data, and display the first slice. + - The user can then use the sliders to select the slice to display. The current + slice is defined as a `Mapping` of `{dim -> int|slice}` and can be retrieved + with the `_dims_sliders.value()` method. To programmatically set the current + position, use the `setIndex` method. This will set the values of the sliders, + which in turn will trigger the display of the new slice via the + `_request_data_for_index` method. + - `_request_data_for_index` is an asynchronous method that retrieves the data for + the given index from the datastore (using `_isel`) and queues the + `_draw_chunk` method to be called when the data is ready. The logic + for extracting data from the datastore is defined in `_data_wrapper.py`, which + handles idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). + - `_draw_chunk` is called when the data is ready, and updates the image. + Note that if the slice is multidimensional, the data will be reduced to 2D using + max intensity projection (and double-clicking on any given dimension slider will + turn it into a range slider allowing a projection to be made over that dimension). + - The image is displayed on the canvas, which is an object that implements the + `PCanvas` protocol (mostly, it has an `add_image` method that returns a handle + to the added image that can be used to update the data and display). This + small abstraction allows for various backends to be used (e.g. vispy, pygfx, etc). + + Parameters + ---------- + data : Any + The data to display. This can be any duck-like ND array, including numpy, dask, + xarray, jax, tensorstore, zarr, etc. You can add support for new datastores by + subclassing `DataWrapper` and implementing the required methods. See + `DataWrapper` for more information. + parent : QWidget, optional + The parent widget of this widget. + channel_axis : Hashable, optional + The axis that represents the channels in the data. If not provided, this will + be guessed from the data. + channel_mode : ChannelMode, optional + The initial mode for displaying the channels. If not provided, this will be + set to ChannelMode.MONO. 
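+
+    Examples
+    --------
+    A minimal sketch (assumes a running QApplication)::
+
+        import numpy as np
+
+        viewer = NDViewer(np.random.rand(10, 2, 256, 256), channel_axis=1)
+        viewer.show()
+        viewer.set_current_index({0: 5, 1: 0})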
+ """ + + def __init__( + self, + data: DataWrapper | Any, + *, + colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, + parent: QWidget | None = None, + channel_axis: int | None = None, + channel_mode: ChannelMode | str = ChannelMode.MONO, + ): + super().__init__(parent=parent) + + # ATTRIBUTES ---------------------------------------------------- + + # mapping of key to a list of objects that control image nodes in the canvas + self._channels: dict[int, Channel] = {} + + # mapping of same keys to the LutControl objects control image display props + self._lut_ctrls: dict[int, LutControl] = {} + + # the set of dimensions we are currently visualizing (e.g. (-2, -1) for 2D) + # this is used to control which dimensions have sliders and the behavior + # of isel when selecting data from the datastore + self._visualized_dims: set[DimKey] = set() + + # the axis that represents the channels in the data + self._channel_axis = channel_axis + self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode + # colormaps that will be cycled through when displaying composite images + if colormaps is not None: + self._cmaps = [cmap.Colormap(c) for c in colormaps] + else: + self._cmaps = DEFAULT_COLORMAPS + self._cmap_cycle = cycle(self._cmaps) + + # number of dimensions to display + self._ndims: Literal[2, 3] = 2 + self._chunker = Chunker( + None, + # IMPORTANT + # chunking here will determine how non-visualized dims are reduced + # so chunkshape will need to change based on the set of visualized dims + chunks=(20, 100, 32, 32), + on_ready=self._draw_chunk, + ) + + # WIDGETS ---------------------------------------------------- + + # the button that controls the display mode of the channels + self._channel_mode_btn = ChannelModeButton(self) + self._channel_mode_btn.clicked.connect(self.set_channel_mode) + # button to reset the zoom of the canvas + self._set_range_btn = QPushButton( + QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self + ) + self._set_range_btn.clicked.connect(self._on_set_range_clicked) + + # button to change number of displayed dimensions + self._ndims_btn = DimToggleButton(self) + self._ndims_btn.clicked.connect(self._toggle_3d) + + # place to display dataset summary + self._data_info_label = QElidingLabel("", parent=self) + self._progress_spinner = QSpinner(self) + + # place to display arbitrary text + self._hover_info_label = QLabel("", self) + # the canvas that displays the images + self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText) + self._canvas.set_ndim(self._ndims) + + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders(self) + self._dims_sliders.valueChanged.connect( + qthrottled(self._request_data_for_index, 20, leading=True) + ) + + self._lut_drop = QCollapsible("LUTs", self) + self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) + self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY)) + lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) + lut_layout.setContentsMargins(0, 1, 0, 1) + lut_layout.setSpacing(0) + if ( + hasattr(self._lut_drop, "_content") + and (layout := self._lut_drop._content.layout()) is not None + ): + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # LAYOUT ----------------------------------------------------- + + self._btns = btns = QHBoxLayout() + btns.setContentsMargins(0, 0, 0, 0) + btns.setSpacing(0) + btns.addStretch() + btns.addWidget(self._channel_mode_btn) + 
btns.addWidget(self._ndims_btn) + btns.addWidget(self._set_range_btn) + + info = QHBoxLayout() + info.setContentsMargins(0, 0, 0, 2) + info.setSpacing(0) + info.addWidget(self._data_info_label) + info.addWidget(self._progress_spinner) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addLayout(info) + layout.addWidget(self._canvas.qwidget(), 1) + layout.addWidget(self._hover_info_label) + layout.addWidget(self._dims_sliders) + layout.addWidget(self._lut_drop) + layout.addLayout(btns) + + # SETUP ------------------------------------------------------ + + self.set_channel_mode(channel_mode) + if data is not None: + self.set_data(data) + + # ------------------- PUBLIC API ---------------------------- + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + + @property + def data_wrapper(self) -> DataWrapper: + """Return the DataWrapper object around the datastore.""" + return self._data_wrapper + + @property + def data(self) -> Any: + """Return the data backing the view.""" + return self._data_wrapper.data + + @data.setter + def data(self, data: Any) -> None: + """Set the data backing the view.""" + raise AttributeError("Cannot set data directly. Use `set_data` method.") + + def set_data( + self, data: DataWrapper | Any, *, initial_index: Indices | None = None + ) -> None: + """Set the datastore, and, optionally, the sizes of the data. + + Properties + ---------- + data : DataWrapper | Any + The data to display. This can be any duck-like ND array, including numpy, + dask, xarray, jax, tensorstore, zarr, etc. You can add support for new + datastores by subclassing `DataWrapper` and implementing the required + methods. If a `DataWrapper` instance is passed, it is used directly. + See `DataWrapper` for more information. + initial_index : Indices | None + The initial index to display. This is a mapping of dimensions to integers + or slices that define the slice of the data to display. If not provided, + the initial index will be set to the middle of the data. + """ + # store the data + self._clear_images() + + self._data_wrapper = DataWrapper.create(data) + self._chunker.data_wrapper = self._data_wrapper + if chunks := self._data_wrapper.chunks(): + # temp hack ... always group non-visible channels + chunks = list(chunks) + chunks[:-2] = (1000,) * len(chunks[:-2]) + self._chunker.chunks = tuple(chunks) + + # set channel axis + self._channel_axis = self._data_wrapper.guess_channel_axis() + self._chunker.channel_axis = self._channel_axis + + # update the dimensions we are visualizing + sizes = self._data_wrapper.sizes() + visualized_dims = list(sizes)[-self._ndims :] + self.set_visualized_dims(visualized_dims) + + # update the range of all the sliders to match the sizes we set above + with signals_blocked(self._dims_sliders): + self._update_slider_ranges() + + # redraw + if initial_index is None: + idx = {k: int(v // 2) for k, v in sizes.items() if k not in visualized_dims} + else: + if not isinstance(initial_index, dict): # pragma: no cover + raise TypeError("initial_index must be a dict") + idx = initial_index + self.set_current_index(idx) + + # update the data info label + self._data_info_label.setText(self._data_wrapper.summary_info()) + + def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: + """Set the dimensions that will be visualized. + + This dims will NOT have sliders associated with them. 
+ """ + self._visualized_dims = set(dims) + for d in self._dims_sliders._sliders: + self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) + for d in self._visualized_dims: + self._dims_sliders.set_dimension_visible(d, False) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions to display.""" + if ndim not in (2, 3): + raise ValueError("ndim must be 2 or 3") + + self._ndims = ndim + self._canvas.set_ndim(ndim) + + # set the visibility of the last non-channel dimension + sizes = list(self._data_wrapper.sizes()) + if self._channel_axis is not None: + sizes = [x for x in sizes if x != self._channel_axis] + if len(sizes) >= 3: + dim3 = sizes[-3] + self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) + + # clear image handles and redraw + if self._channels: + self._clear_images() + self._request_data_for_index(self._dims_sliders.value()) + + def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: + """Set the mode for displaying the channels. + + In "composite" mode, the channels are displayed as a composite image, using + self._channel_axis as the channel axis. In "grayscale" mode, each channel is + displayed separately. (If mode is None, the current value of the + channel_mode_picker button is used) + + Parameters + ---------- + mode : ChannelMode | str | None + The mode to set, must be one of 'composite' or 'mono'. + """ + # bool may happen when called from the button clicked signal + if mode is None or isinstance(mode, bool): + mode = self._channel_mode_btn.mode() + else: + mode = ChannelMode(mode) + self._channel_mode_btn.setMode(mode) + if mode == self._channel_mode: + return + + self._channel_mode = mode + self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle + if self._channel_axis is not None: + # set the visibility of the channel slider + self._dims_sliders.set_dimension_visible( + self._channel_axis, mode != ChannelMode.COMPOSITE + ) + + if self._channels: + self._clear_images() + self._request_data_for_index(self._dims_sliders.value()) + + def set_current_index(self, index: Indices | None = None) -> None: + """Set the index of the displayed image. + + `index` is a mapping of dimensions to integers or slices that define the slice + of the data to display. For example, a numpy slice of `[0, 1, 5:10]` would be + represented as `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be + named, e.g. `{'t': 0, 'c': 1, 'z': slice(5, 10)}` if the data has named + dimensions. + + Note, calling `.set_current_index()` with no arguments will force the widget + to redraw the current slice. + """ + self._dims_sliders.setValue(index or {}) + + # camelCase aliases + + dimsSliders = dims_sliders + setChannelMode = set_channel_mode + setData = set_data + setCurrentIndex = set_current_index + setVisualizedDims = set_visualized_dims + + # ------------------- PRIVATE METHODS ---------------------------- + + def _toggle_3d(self) -> None: + self.set_ndim(3 if self._ndims == 2 else 2) + + def _update_slider_ranges(self) -> None: + """Set the maximum values of the sliders. + + If `sizes` is not provided, sizes will be inferred from the datastore. 
+ """ + maxes = self._data_wrapper.sizes() + self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) + + # FIXME: this needs to be moved and made user-controlled + for dim in list(maxes.keys())[-self._ndims :]: + self._dims_sliders.set_dimension_visible(dim, False) + + def _on_set_range_clicked(self) -> None: + # using method to swallow the parameter passed by _set_range_btn.clicked + self._canvas.set_range() + + def _image_key(self, index: Indices) -> ImgKey: + """Return the key for image handle(s) corresponding to `index`.""" + if self._channel_mode == ChannelMode.COMPOSITE: + val = index.get(self._channel_axis, 0) + if isinstance(val, slice): + return (val.start, val.stop) + return val + return 0 + + def _request_data_for_index(self, index: Indices) -> None: + """Retrieve data for `index` from datastore and update canvas image(s). + + This is the first step in updating the displayed image, it is triggered by + the valueChanged signal from the sliders. + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. This method is *asynchronous*. It + makes a request for the new data slice and queues _on_data_future_done to be + called when the data is ready. + """ + print(f"\n--------\nrequesting index {index}", self._channel_axis) + if ( + self._channel_mode == ChannelMode.COMPOSITE + and self._channel_axis is not None + ): + index = {**index, self._channel_axis: slice(None)} + self._progress_spinner.show() + # TODO: don't request channels not being displayed + # TODO: don't request if the data is already in the cache + self._chunker.request_index(index, ndims=self._ndims) + + @ensure_main_thread # type: ignore + def _draw_chunk(self, chunk: ChunkResponse) -> None: + """Actually update the image handle(s) with the (sliced) data. + + By this point, data should be sliced from the underlying datastore. Any + dimensions remaining that are more than the number of visualized dimensions + (currently just 2D) will be reduced using max intensity projection (currently). + """ + if chunk is RequestFinished: # fix typing + self._progress_spinner.hide() + for lut in self._lut_ctrls.values(): + lut.update_autoscale() + return + + if self._channel_mode == ChannelMode.MONO: + ch_key = MONO_CHANNEL + else: + ch_key = chunk.channel_index + + data = chunk.data + if data.ndim == 2: + return + # TODO: Channel object creation could be moved. + # having it here is the laziest... but means that the order of arrival + # of the chunks will determine the order of the channels in the LUTS + # (without additional logic to sort them by index, etc.) 
+ if (handles := self._channels.get(ch_key)) is None: + handles = self._create_channel(ch_key) + + if not handles: + if data.ndim == 2: + handles.append(self._canvas.add_image(data, cmap=handles.cmap)) + elif data.ndim == 3: + empty = np.empty((60, 256, 256), dtype=np.uint16) + handles.append(self._canvas.add_volume(empty, cmap=handles.cmap)) + + handles[0].set_data(data, chunk.offset) + self._canvas.refresh() + + def _create_channel(self, ch_key: int) -> Channel: + # improve this + cmap = GRAYS if ch_key == MONO_CHANNEL else next(self._cmap_cycle) + + self._channels[ch_key] = channel = Channel(ch_key, cmap=cmap) + self._lut_ctrls[ch_key] = lut = LutControl( + channel, + f"Ch {ch_key}", + self, + cmaplist=self._cmaps + DEFAULT_COLORMAPS, + cmap=cmap, + ) + self._lut_drop.addWidget(lut) + return channel + + def _clear_images(self) -> None: + """Remove all images from the canvas.""" + for handles in self._channels.values(): + for handle in handles: + handle.remove() + self._channels.clear() + + # clear the current LutControls as well + for c in self._lut_ctrls.values(): + cast("QVBoxLayout", self.layout()).removeWidget(c) + c.deleteLater() + self._lut_ctrls.clear() + + def _is_idle(self) -> bool: + """Return True if no futures are running. Used for testing, and debugging.""" + return bool(self._chunker.pending_futures) + + def closeEvent(self, a0: QCloseEvent | None) -> None: + self._chunker.shutdown() + super().closeEvent(a0) diff --git a/src/ndv/viewer2/spin.gif b/src/ndv/viewer2/spin.gif new file mode 100644 index 0000000..f54d0d6 Binary files /dev/null and b/src/ndv/viewer2/spin.gif differ diff --git a/tests/test_chunker.py b/tests/test_chunker.py new file mode 100644 index 0000000..1165c01 --- /dev/null +++ b/tests/test_chunker.py @@ -0,0 +1,109 @@ +import numpy as np +import numpy.testing as npt + +from ndv._chunk_executor import Chunker, iter_chunk_aligned_slices + + +def test_iter_chunk_aligned_slices() -> None: + x = iter_chunk_aligned_slices( + shape=(10, 9), chunks=(4, 3), slices=np.index_exp[3:9, 1:None] + ) + assert list(x) == [ + (slice(3, 4, 1), slice(1, 3, 1)), + (slice(3, 4, 1), slice(3, 6, 1)), + (slice(3, 4, 1), slice(6, 9, 1)), + (slice(4, 8, 1), slice(1, 3, 1)), + (slice(4, 8, 1), slice(3, 6, 1)), + (slice(4, 8, 1), slice(6, 9, 1)), + (slice(8, 9, 1), slice(1, 3, 1)), + (slice(8, 9, 1), slice(3, 6, 1)), + (slice(8, 9, 1), slice(6, 9, 1)), + ] + + # this one tests that slices doesn't need to be the same length as shape + # ... 
diff --git a/tests/test_chunker.py b/tests/test_chunker.py
new file mode 100644
index 0000000..1165c01
--- /dev/null
+++ b/tests/test_chunker.py
@@ -0,0 +1,109 @@
+import numpy as np
+import numpy.testing as npt
+
+from ndv._chunk_executor import Chunker, iter_chunk_aligned_slices
+
+
+def test_iter_chunk_aligned_slices() -> None:
+    x = iter_chunk_aligned_slices(
+        shape=(10, 9), chunks=(4, 3), slices=np.index_exp[3:9, 1:None]
+    )
+    assert list(x) == [
+        (slice(3, 4, 1), slice(1, 3, 1)),
+        (slice(3, 4, 1), slice(3, 6, 1)),
+        (slice(3, 4, 1), slice(6, 9, 1)),
+        (slice(4, 8, 1), slice(1, 3, 1)),
+        (slice(4, 8, 1), slice(3, 6, 1)),
+        (slice(4, 8, 1), slice(6, 9, 1)),
+        (slice(8, 9, 1), slice(1, 3, 1)),
+        (slice(8, 9, 1), slice(3, 6, 1)),
+        (slice(8, 9, 1), slice(6, 9, 1)),
+    ]
+
+    # this one tests that `slices` doesn't need to be the same length as `shape`
+    # (an Ellipsis is effectively added at the end)
+    y = iter_chunk_aligned_slices(shape=(6, 6), chunks=4, slices=np.index_exp[1:4])
+    assert list(y) == [
+        (slice(1, 4, 1), slice(0, 4, 1)),
+        (slice(1, 4, 1), slice(4, 6, 1)),
+    ]
+
+    # this tests ellipsis in the middle
+    z = iter_chunk_aligned_slices(
+        shape=(3, 3, 3), chunks=2, slices=np.index_exp[1, ..., :2]
+    )
+    assert list(z) == [
+        (slice(1, 2, 1), slice(0, 2, 1), slice(0, 2, 1)),
+        (slice(1, 2, 1), slice(2, 3, 1), slice(0, 2, 1)),
+    ]
+
+
+def test_chunker() -> None:
+    data = np.random.rand(100, 100).astype(np.float32)
+
+    with Chunker() as chunker:
+        futures = chunker.request_chunks(data)
+
+        assert len(futures) == 1
+        npt.assert_array_equal(data, futures[0].result().data)
+
+    data2 = np.random.rand(30, 30, 30).astype(np.float32)
+    # test that the data is correctly chunked with weird chunk shapes
+    with Chunker() as chunker:
+        futures = chunker.request_chunks(
+            data2, index={0: 0}, chunk_shape=(None, 17, 12)
+        )
+
+        new = np.empty_like(data2[0])
+        for future in futures:
+            result = future.result()
+            new[result.location[1:]] = result.data
+        npt.assert_array_equal(new, data2[0])
+
+
+# # this test is provided as an example of using dask to accomplish a similar thing
+# # this library should try to retain support for using dask instead of the internal
+# # chunker ... but it's nice not to have to depend on dask otherwise.
+# def test_dask_chunker() -> None:
+#     try:
+#         import dask.array as da
+#         from dask.distributed import Client
+#     except ImportError:
+#         pytest.skip("Dask not installed")
+
+#     from itertools import product
+
+#     data = np.random.rand(100, 100).astype(np.float32)
+#     dask_data_chunked = da.from_array(data, chunks=(25, 20))  # type: ignore
+#     chunk_sizes = dask_data_chunked.chunks
+
+#     with Client() as client:  # type: ignore [no-untyped-call]
+#         for idx in product(*(range(x) for x in dask_data_chunked.numblocks)):
+#             # Calculate the start indices (offsets) for the chunk.
+#             # THIS is the main thing we'd need for visualization purposes.
+#             # wish there was an easier way to get this from the chunk_result alone
+#             offset = tuple(sum(sizes[:x]) for sizes, x in zip(chunk_sizes, idx))
+
+#             chunk = dask_data_chunked.blocks[idx]
+
+#             future = client.compute(chunk)
+#             chunk_result = future.result()
+
+#             # Test that the data is correctly chunked and equal to the original data
+#             sub_idx = tuple(slice(x, x + y) for x, y in zip(offset, chunk_result.shape))
+#             expected = data[sub_idx]
+#             npt.assert_array_equal(expected, chunk_result)
+
+
+# def test_dask_map_blocks() -> None:
+#     if TYPE_CHECKING:
+#         import dask.array
+#     else:
+#         dask = pytest.importorskip("dask")
+
+#     dask_array = dask.array.random.randint(100, size=(100, 100), chunks=(25, 20))
+#     block_ranges = (range(x) for x in dask_array.numblocks)
+#     for block_id in product(*(block_ranges)):
+#         _offset = tuple(sum(sizes[:x]) for sizes, x in zip(dask_array.chunks, block_id))
+#         _chunk = dask_array.blocks[block_id]
+#         print(_offset, _chunk)
diff --git a/x.py b/x.py
new file mode 100644
index 0000000..cebe52b
--- /dev/null
+++ b/x.py
@@ -0,0 +1,12 @@
+import sys
+
+from qtpy.QtWidgets import QApplication
+
+import ndv
+from ndv.viewer2._v2 import NDViewer
+
+data = ndv.data.cells3d()
+app = QApplication(sys.argv)
+viewer = NDViewer(data)
+viewer.show()
+app.exec()
diff --git a/xx.py b/xx.py
new file mode 100644
index 0000000..5baef72
--- /dev/null
+++ b/xx.py
@@ -0,0 +1,33 @@
+import numpy as np
+from rich import print
+from vispy import app, io, scene
+
+from ndv._chunking import iter_chunk_aligned_slices
+
+vol1 = np.load(io.load_data_file("volume/stent.npz"))["arr_0"].astype(np.uint16)
+
+canvas = scene.SceneCanvas(keys="interactive", size=(800, 600), show=True)
+view = canvas.central_widget.add_view()
+print("--------- create vol")
+volume1 = scene.Volume(
+    np.empty_like(vol1), parent=view.scene, texture_format="auto", clim=(0, 1200)
+)
+print("--------- create cam")
+view.camera = scene.cameras.ArcballCamera(parent=view.scene, name="Arcball")
+
+# Generate new data to update a subset of the volume
+
+
+slices = iter_chunk_aligned_slices(
+    vol1.shape, chunks=(32, 32, 32), slices=(slice(None), slice(None), slice(None))
+)
+
+for sl in list(slices)[::1]:
+    offset = (x.start for x in sl)
+    chunk = vol1[sl]
+    # Update the texture with the new data at the calculated offset
+    print("--------- update vol")
+    volume1._texture._set_data(chunk, offset=tuple(offset))
+canvas.update()
+print("--------- run app")
+app.run()
diff --git a/y.py b/y.py
new file mode 100644
index 0000000..2855628
--- /dev/null
+++ b/y.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+import ndv
+from ndv._chunking import Chunker
+
+data = np.random.rand(10, 3, 8, 5, 128, 128)
+wrapper = ndv.DataWrapper.create(data)
+slicer = Chunker(wrapper, chunks=(5, 1, 2, 2, 64, 34))
+
+index = {0: 2, 1: 2, 2: 0, 3: 4}
+idx = wrapper.to_conventional(index)
+print(idx)
+print(wrapper[idx].shape)
+
+slicer.request_index(index)
+# slicer.shutdown()
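xx.py uploads chunks to the GPU texture in whatever order `iter_chunk_aligned_slices` yields them. The `sort_key` parameter documented on `Chunker.request_chunks` is intended for exactly this kind of prioritization, for example submitting the chunks closest to the centre of the volume first. A sketch of that idea follows; the distance function and the ordering heuristic are illustrative, only the `sort_key` signature comes from the docstring.

```python
import numpy as np

from ndv._chunk_executor import Chunker

data = np.random.rand(64, 128, 128).astype(np.float32)
center = tuple(s / 2 for s in data.shape)


def distance_from_center(location: tuple) -> float:
    # location is a tuple of chunk-aligned slices into the original array;
    # rank each chunk by the squared distance of its midpoint from the centre
    mids = ((sl.start + sl.stop) / 2 for sl in location)
    return sum((m - c) ** 2 for m, c in zip(mids, center))


with Chunker() as chunker:
    futures = chunker.request_chunks(
        data, chunk_shape=32, sort_key=distance_from_center
    )
    # central chunks are submitted first, so they tend to complete first
    for fut in futures:
        resp = fut.result()
        print(resp.offset, resp.data.shape)
```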