Skip to content

Commit

Permalink
first masking roi implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Oct 8, 2024
1 parent 5034bf6 commit 68fd201
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 84 deletions.
42 changes: 21 additions & 21 deletions docs/notebooks/image.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@
"array = image.array\n",
"print(f\"{array}\")\n",
"\n",
"dask_array = image.dask_array\n",
"dask_array = image.on_disk_dask_array\n",
"dask_array"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note, directly accessing the `.array` or `.dask_array` attributes will load the image as stored in the file.\n",
"Note, directly accessing the `.on_disk_array` or `.on_disk_dask_array` attributes will load the image as stored in the file.\n",
"\n",
"Since in principle the images can have different axes order. A safer way to access the image data is to use the `.get_data()` method, which will return the image data in canonical order (TCZYX)."
"Since in principle the images can have different axes order. A safer way to access the image data is to use the `.array` method, which will return the image data in canonical order (TCZYX)."
]
},
{
Expand All @@ -91,7 +91,7 @@
"metadata": {},
"outputs": [],
"source": [
"image_numpy = image.get_data(c=0, x=slice(0, 250), y=slice(0, 250), preserve_dimensions=False, mode=\"numpy\")\n",
"image_numpy = image.array(c=0, x=slice(0, 250), y=slice(0, 250), preserve_dimensions=False, mode=\"numpy\")\n",
"\n",
"print(f\"{image_numpy.shape=}\")"
]
Expand All @@ -115,7 +115,7 @@
"roi = roi_table.get_roi(\"FOV_1\")\n",
"print(f\"{roi=}\")\n",
"\n",
"image_roi_1 = image.get_data_from_roi(roi=roi, c=0, preserve_dimensions=True, mode=\"dask\")\n",
"image_roi_1 = image.array_from_roi(roi=roi, c=0, preserve_dimensions=True, mode=\"dask\")\n",
"image_roi_1"
]
},
Expand All @@ -138,10 +138,10 @@
"print(f\"{image_2.pixel_size=}\")\n",
"\n",
"# Get roi for higher resolution image\n",
"image_1_roi_1 = image.get_data_from_roi(roi=roi, c=0, preserve_dimensions=False)\n",
"image_1_roi_1 = image.array_from_roi(roi=roi, c=0, preserve_dimensions=False)\n",
"\n",
"# Get roi for lower resolution image\n",
"image_2_roi_1 = image_2.get_data_from_roi(roi=roi, c=0, preserve_dimensions=False)\n",
"image_2_roi_1 = image_2.array_from_roi(roi=roi, c=0, preserve_dimensions=False)\n",
"\n",
"# Plot the two images side by side\n",
"fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
Expand All @@ -156,7 +156,7 @@
"source": [
"# Writing Images\n",
"\n",
"Similarly to the `.get_data()` we can use the `.set_data()` (or `set_data_roi`) method to write part of an image to disk."
"Similarly to the `.array()` we can use the `.set_array()` (or `set_array_from_roi`) method to write part of an image to disk."
]
},
{
Expand All @@ -168,25 +168,25 @@
"import numpy as np\n",
"\n",
"# Get a small slice of the image\n",
"small_slice = image.get_data(x=slice(1000, 2000), y=slice(1000, 2000))\n",
"small_slice = image.array(x=slice(1000, 2000), y=slice(1000, 2000))\n",
"\n",
"# Set the sample slice to zeros\n",
"zeros_slice = np.zeros_like(small_slice)\n",
"image.set_data(patch=zeros_slice, x=slice(1000, 2000), y=slice(1000, 2000))\n",
"image.set_array(patch=zeros_slice, x=slice(1000, 2000), y=slice(1000, 2000))\n",
"\n",
"\n",
"# Load the image from disk and show the edited image\n",
"nuclei = ngff_image.label.get(\"nuclei\")\n",
"fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
"axs[0].imshow(image.array[0, 0], cmap=\"gray\")\n",
"axs[1].imshow(nuclei.array[0])\n",
"axs[0].imshow(image.on_disk_array[0, 0], cmap=\"gray\")\n",
"axs[1].imshow(nuclei.on_disk_array[0])\n",
"for ax in axs:\n",
" ax.axis(\"off\")\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Add back the original slice to the image\n",
"image.set_data(patch=small_slice, x=slice(1000, 2000), y=slice(1000, 2000))"
"image.set_array(patch=small_slice, x=slice(1000, 2000), y=slice(1000, 2000))"
]
},
{
Expand All @@ -207,13 +207,13 @@
"# Create a a new label object and set it to a simple segmentation\n",
"new_label = ngff_image.label.derive(\"new_label\", overwrite=True)\n",
"\n",
"simple_segmentation = image.array[0] > 100\n",
"new_label.array[...] = simple_segmentation\n",
"simple_segmentation = image.on_disk_array[0] > 100\n",
"new_label.on_disk_array[...] = simple_segmentation\n",
"\n",
"# make a subplot with two image show side by side\n",
"fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
"axs[0].imshow(image.array[0, 0], cmap=\"gray\")\n",
"axs[1].imshow(new_label.array[0], cmap=\"gray\")\n",
"axs[0].imshow(image.on_disk_array[0, 0], cmap=\"gray\")\n",
"axs[1].imshow(new_label.on_disk_array[0], cmap=\"gray\")\n",
"for ax in axs:\n",
" ax.axis(\"off\")\n",
"plt.tight_layout()\n",
Expand All @@ -240,12 +240,12 @@
"label_0 = ngff_image.label.get(\"new_label\", path=\"0\")\n",
"label_2 = ngff_image.label.get(\"new_label\", path=\"2\")\n",
"\n",
"label_before_consolidation = label_2.array[...]\n",
"label_before_consolidation = label_2.on_disk_array[...]\n",
"\n",
"# Consolidate the label\n",
"label_0.consolidate()\n",
"\n",
"label_after_consolidation = label_2.array[...]\n",
"label_after_consolidation = label_2.on_disk_array[...]\n",
"\n",
"\n",
"# make a subplot with two image show side by side\n",
Expand Down Expand Up @@ -285,7 +285,7 @@
"# Create a table with random features for each nuclei in each ROI\n",
"list_of_records = []\n",
"for roi in roi_table.list_rois:\n",
" nuclei_in_roi = nuclei.get_data_from_roi(roi, mode='numpy')\n",
" nuclei_in_roi = nuclei.array_from_roi(roi, mode='numpy')\n",
" for nuclei_id in np.unique(nuclei_in_roi)[1:]:\n",
" list_of_records.append(\n",
" {\"label\": nuclei_id,\n",
Expand Down Expand Up @@ -333,7 +333,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
27 changes: 17 additions & 10 deletions docs/notebooks/processing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@
"# - set the data in the MIP image\n",
"for roi in roi_table.list_rois:\n",
" print(f\" - Processing ROI {roi.infos.get(\"field_index\")}\")\n",
" patch = source_image.get_data_from_roi(roi)\n",
" patch = source_image.array_from_roi(roi)\n",
" mip_patch = patch.max(axis=1, keepdims=True)\n",
" mip_image.set_data_from_roi(patch=mip_patch, roi=roi)\n",
" mip_image.set_array_from_roi(patch=mip_patch, roi=roi)\n",
" \n",
"print(\"MIP image saved\")\n",
"\n",
"plt.figure(figsize=(5, 5))\n",
"plt.title(\"Mip\")\n",
"plt.imshow(mip_image.array[0, 0, :, :], cmap=\"gray\")\n",
"plt.imshow(mip_image.on_disk_array[0, 0, :, :], cmap=\"gray\")\n",
"plt.axis('off')\n",
"plt.tight_layout()\n",
"plt.show()\n"
Expand All @@ -124,12 +124,12 @@
"# Get the MIP image at a lower resolution\n",
"mip_image_2 = mip_ngff.get_image(path=\"2\")\n",
"\n",
"image_before_consolidation = mip_image_2.get_data(c=0, z=0)\n",
"image_before_consolidation = mip_image_2.array(c=0, z=0)\n",
"\n",
"# Consolidate the pyramid\n",
"mip_image.consolidate()\n",
"\n",
"image_after_consolidation = mip_image_2.get_data(c=0, z=0)\n",
"image_after_consolidation = mip_image_2.array(c=0, z=0)\n",
"\n",
"fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n",
"axs[0].set_title(\"Before consolidation\")\n",
Expand Down Expand Up @@ -256,28 +256,35 @@
"max_label = 0\n",
"for roi in roi_table.list_rois:\n",
" print(f\" - Processing ROI {roi.infos.get(\"field_index\")}\")\n",
" patch = source_image.get_data_from_roi(roi, c=dapi_idx)\n",
" patch = source_image.array_from_roi(roi, c=dapi_idx)\n",
" segmentation = otsu_threshold_segmentation(patch, max_label)\n",
"\n",
" # Add the max label of the previous segmentation to avoid overlapping labels\n",
" max_label = segmentation.max()\n",
"\n",
" nuclei_image.set_data_from_roi(patch=segmentation, roi=roi)\n",
" nuclei_image.set_array_from_roi(patch=segmentation, roi=roi)\n",
"\n",
"# Consolidate the segmentation image\n",
"nuclei_image.consolidate()\n",
"\n",
"print(\"Segmentation image saved\")\n",
"fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n",
"axs[0].set_title(\"MIP\")\n",
"axs[0].imshow(source_image.array[0, 0], cmap=\"gray\")\n",
"axs[0].imshow(source_image.on_disk_array[0, 0], cmap=\"gray\")\n",
"axs[1].set_title(\"Nuclei segmentation\")\n",
"axs[1].imshow(nuclei_image.array[0], cmap=rand_cmap, interpolation='nearest')\n",
"axs[1].imshow(nuclei_image.on_disk_array[0], cmap=rand_cmap, interpolation='nearest')\n",
"for ax in axs:\n",
" ax.axis('off')\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -296,7 +303,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
"aiohttp",
"dask[array]",
"dask[distributed]",
"dask-image",
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
Expand Down
51 changes: 51 additions & 0 deletions src/ngio/core/image_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""A module to handle OME-NGFF images stored in Zarr format."""

from typing import Literal

import dask.array as da
import numpy as np

from ngio._common_types import ArrayLike
from ngio.core.image_like_handler import ImageLike
from ngio.core.roi import WorldCooROI
from ngio.io import StoreOrGroup
from ngio.ngff_meta.fractal_image_meta import ImageMeta, PixelSize

Expand All @@ -22,6 +29,7 @@ def __init__(
highest_resolution: bool = False,
strict: bool = True,
cache: bool = True,
label_group=None,
) -> None:
"""Initialize the the Image Object.
Expand All @@ -47,6 +55,7 @@ def __init__(
meta_mode="image",
cache=cache,
)
self._label_group = label_group

@property
def metadata(self) -> ImageMeta:
Expand All @@ -65,3 +74,45 @@ def get_channel_idx(
) -> int:
"""Return the index of the channel."""
return self.metadata.get_channel_idx(label=label, wavelength_id=wavelength_id)

def masked_array(
self,
roi: WorldCooROI,
t: int | slice | None = None,
c: int | slice | None = None,
mask_mode: Literal["bbox", "mask"] = "bbox",
mode: Literal["numpy"] = "numpy",
preserve_dimensions: bool = False,
) -> ArrayLike:
"""Return the image data from a region of interest (ROI).
Args:
roi (WorldCooROI): The region of interest.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
mask_mode (str): Masking mode
mode (str): The mode to return the data.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
data_pipe = self._build_roi_pipe(
roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions
)

if mask_mode == "bbox":
return self._get_pipe(data_pipe=data_pipe, mode=mode)

label = self._label_group.get(
roi.infos["label_name"], pixel_size=self.pixel_size
)

mask = label.mask(
roi,
t=t,
mode=mode,
)
array = self._get_pipe(data_pipe=data_pipe, mode=mode)
if mode == "numpy":
return_array = np.where(mask, array, 0)
else:
return_array = da.where(mask, array, 0)
return return_array
Loading

0 comments on commit 68fd201

Please sign in to comment.