def _updateFrameValues(self, frame_values, placement, axes, new_axes, new_dims):
    """
    Store the frame values supplied with a tile, growing the frame value
    array when frame axes are added or lengthened.

    :param frame_values: a dictionary of axis name to the value of that axis
        for the tile being added.
    :param placement: a dictionary of axis name to the index at which the
        tile is located along that axis.
    :param axes: the ordered axis names of the image, including the tile.
    :param new_axes: a dictionary of axis name to position in ``axes`` for
        axes that are new to the image.  It is accepted for compatibility
        with existing callers but is not consulted; the dimensions to add
        are derived from the change in frame axes, because positions in
        ``axes`` do not necessarily match positions in the frame axes and
        a new image axis need not be a frame axis.
    :param new_dims: a dictionary of axis name to the length of that axis
        once the tile has been added.
    """
    previous_frame_axes = list(self.frameAxes) if self.frameAxes is not None else []
    frame_axes = [
        a for a in axes
        if a in frame_values or a in previous_frame_axes
    ]
    # Assign the private attribute directly, as the public setter has
    # additional side effects.
    self._frameAxes = frame_axes
    # One dimension per frame axis plus a trailing dimension holding one
    # value per frame axis.  This is a tuple so that it can be compared
    # with an array shape.
    frames_shape = (*(new_dims[a] for a in frame_axes), len(frame_axes))

    values = self.frameValues
    if values is None:
        values = np.empty(frames_shape, dtype=object)
    elif tuple(values.shape) != frames_shape:
        # Positions, in frame axis order, of the frame axes that the
        # existing array does not have yet.  These are ascending, so each
        # insertion lands at its final position.
        added = [
            i for i, a in enumerate(frame_axes)
            if a not in previous_frame_axes
        ]
        for i in added:
            values = np.expand_dims(values, axis=i)
        # Never pad by a negative amount; an array that is already larger
        # than required along an axis is left as it is.
        frame_padding = [
            (0, max(0, s - values.shape[i]))
            for i, s in enumerate(frames_shape)
        ]
        # The trailing dimension grows by inserting columns, not by padding
        frame_padding[-1] = (0, 0)
        values = np.pad(values, frame_padding)
        for i in added:
            values = np.insert(values, i, 0, axis=values.ndim - 1)
    self._frameValues = values

    current_frame_slice = tuple(placement.get(a) for a in frame_axes)
    for i, k in enumerate(frame_axes):
        values[(*current_frame_slice, i)] = frame_values.get(k)


def _resizeImage(self, arr, new_shape, new_axes, chunking):
    """
    Grow an image array to a new shape, adding dimensions for new axes.

    :param arr: the existing array.
    :param new_shape: the shape the array should have afterwards.
    :param new_axes: a dictionary of axis name to position for the axes
        that need to be added to the array.  If empty, the array is resized
        in place.
    :param chunking: the chunk shape used when a new array has to be
        created because axes were added.
    :returns: the array with the new shape.  This is ``arr`` itself unless
        axes were added, in which case it is a newly created array.
    """
    new_shape = tuple(new_shape)
    if new_shape == tuple(arr.shape):
        return arr
    if not new_axes:
        arr.resize(*new_shape)
        return arr
    # Insert in ascending order so each new dimension lands at its final
    # position regardless of the dictionary's ordering.
    for i in sorted(new_axes.values()):
        arr = np.expand_dims(arr, axis=i)
    arr = np.pad(
        arr,
        [(0, s - arr.shape[i]) for i, s in enumerate(new_shape)],
    )
    new_arr = zarr.empty(new_shape, chunks=chunking, dtype=arr.dtype)
    new_arr[:] = arr[:]
    return new_arr
self.frameValues = np.empty(frames_shape, dtype=object) - elif self.frameValues.shape != frames_shape: - self.frameValues = np.pad( - self.frameValues, - [(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)], - ) - current_frame_slice = tuple(placement.get(a) for a in self.frameAxes) - for i, k in enumerate(self.frameAxes): - self.frameValues[(*current_frame_slice, i)] = frame_values.get(k) + updateMetadata = True + self._updateFrameValues(frame_values, placement, axes, new_axes, new_dims) current_arrays = dict(self._zarr.arrays()) if store_path == '0': @@ -728,16 +768,18 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs): ]) else: arr = current_arrays[store_path] - new_shape = tuple(max(v, arr.shape[i]) for i, v in enumerate(new_dims.values())) - if new_shape != arr.shape: - arr.resize(*new_shape) - if arr.chunks[-1] != new_dims.get('s'): - # rechunk if length of samples axis changes - chunking = tuple([ - self._tileSize if a in ['x', 'y'] else - new_dims.get('s') if a == 's' else 1 - for a in axes - ]) + new_shape = tuple( + max(v, arr.shape[old_axes[k]] if k in old_axes else 0) + for k, v in new_dims.items() + ) + if arr.chunks[-1] != new_dims.get('s') or len(new_axes): + # rechunk if length of samples axis changed or any new axis added + chunking = tuple([ + self._tileSize if a in ['x', 'y'] else + new_dims.get('s') if a == 's' else 1 + for a in axes + ]) + arr = self._resizeImage(arr, new_shape, new_axes, chunking) if mask is not None: arr[placement_slices] = np.where(mask, tile, arr[placement_slices]) @@ -767,6 +809,8 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs): self._levels = None self.levels = int(max(1, math.ceil(math.log(max( self.sizeX / self.tileWidth, self.sizeY / self.tileHeight)) / math.log(2)) + 1)) + if updateMetadata: + self._writeInternalMetadata() def addAssociatedImage(self, image, imageKey=None): """ @@ -1002,6 +1046,7 @@ def frameAxes(self): def frameAxes(self, axes): 
self._checkEditable() self._frameAxes = axes + self._writeInternalMetadata() @property def frameUnits(self): @@ -1034,6 +1079,7 @@ def frameValues(self, a): err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.' raise ValueError(err) self._frameValues = a + self._writeInternalMetadata() def _generateDownsampledLevels(self, resample_method): self._checkEditable() diff --git a/test/test_sink.py b/test/test_sink.py index 37db5f1a3..6d2194974 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -600,10 +600,6 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path): )) expected_metadata = get_expected_metadata(axis_spec, frame_shape) - sink.frameAxes = list(axis_spec.keys()) - sink.frameUnits = { - k: v['units'] for k, v in axis_spec.items() - } frame_values_shape = [ *[len(v['values']) for v in axis_spec.values()], len(axis_spec), @@ -633,7 +629,11 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path): index += 1 if not use_add_tile_args: + sink.frameAxes = list(axis_spec.keys()) sink.frameValues = frame_values + sink.frameUnits = { + k: v['units'] for k, v in axis_spec.items() + } compare_metadata(dict(sink.getMetadata()), expected_metadata) sink.write(output_file) @@ -686,10 +686,6 @@ def testFrameValues(use_add_tile_args, tmp_path): ) expected_metadata = get_expected_metadata(axis_spec, frame_shape) - sink.frameAxes = list(axis_spec.keys()) - sink.frameUnits = { - k: v['units'] for k, v in axis_spec.items() - } frame_values_shape = [ *[len(v['values']) for v in axis_spec.values()], len(axis_spec), @@ -727,7 +723,11 @@ def testFrameValues(use_add_tile_args, tmp_path): index += 1 if not use_add_tile_args: + sink.frameAxes = list(axis_spec.keys()) sink.frameValues = frame_values + sink.frameUnits = { + k: v['units'] for k, v in axis_spec.items() + } compare_metadata(dict(sink.getMetadata()), expected_metadata) sink.write(output_file) @@ -783,19 +783,81 @@ def testSubprocess(tmp_path): subprocess.run([sys.executable, '-c', """import 
@pytest.mark.parametrize('axes_order', ['tzd', 'tdz', 'dzt', 'dtz', 'ztd', 'zdt'])
def testAddAxes(tmp_path, axes_order):
    """
    Add tiles that introduce new axes one call at a time and check that the
    pixel data and the per-frame values end up on the expected frames for
    every ordering of the t, z, and d axes.
    """
    sink = large_image_source_zarr.new()
    kwarg_groups = [
        dict(t=0, t_value='sunday'),
        dict(
            t=5, t_value='friday',
            z=1, z_value=0.1,
            axes=axes_order.replace('d', '') + 'yxs',
        ),
        dict(
            t=6, t_value='saturday',
            z=2, z_value=0.2,
            d=1, d_value=100,
            axes=axes_order + 'yxs',
        ),
    ]
    for kwarg_group in kwarg_groups:
        sink.addTile(
            np.ones((4, 4, 4)),
            x=1020, y=1020,
            **kwarg_group,
        )

    metadata = sink.getMetadata()
    t_values = metadata['ValueT']['values']
    z_values = metadata['ValueZ']['values']
    d_values = metadata['ValueD']['values']
    t_stride = metadata['IndexStride']['IndexT']
    z_stride = metadata['IndexStride']['IndexZ']

    frames = metadata.get('frames', [])
    # The tiles span t indices 0-6, z indices 0-2, and d indices 0-1.
    # Checking the count also keeps the loop below from passing vacuously
    # when no frames are reported.
    expected_frame_count = 7 * 3 * 2
    assert len(frames) == expected_frame_count
    expected_filled_frames = [
        # first and last frame are known, middle frame depends on axis ordering
        0, z_stride + t_stride * 5, expected_frame_count - 1,
    ]
    for frame in frames:
        frame_index = frame.get('Frame')
        sample = sink.getRegion(
            region=dict(left=1020, top=1020, width=1, height=1),
            format='numpy',
            frame=frame_index,
        )[0]
        frame_values = dict(
            t_value=t_values[frame_index],
            z_value=z_values[frame_index],
            d_value=d_values[frame_index],
        )
        kwarg_group = {}
        if frame_index in expected_filled_frames:
            kwarg_group = kwarg_groups[expected_filled_frames.index(frame_index)]
            assert (sample == 1).all()
        else:
            assert (sample == 0).all()

        # Values that were never supplied for a frame are expected to be 0
        for k, v in frame_values.items():
            assert v == kwarg_group.get(k, 0)