Skip to content

Commit

Permalink
update sharding to run with overlap and add downsampling segmentation…
Browse files Browse the repository at this point in the history
… workflow
  • Loading branch information
sophiamaedler committed Sep 25, 2023
1 parent 2446b74 commit b95800f
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 8 deletions.
29 changes: 27 additions & 2 deletions src/sparcscore/pipeline/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,28 @@ def calculate_sharding_plan(self, image_size):
upper_y = (y + 1) * shard_size[0]
upper_x = (x + 1) * shard_size[1]

#add px overlap to each shard
lower_y = lower_y - self.config["overlap_px"]
lower_x = lower_x - self.config["overlap_px"]
upper_y = upper_y + self.config["overlap_px"]
upper_x = upper_x + self.config["overlap_px"]

#make sure that each limit stays within the slides
if lower_y < 0:
lower_y = 0
if lower_x < 0:
lower_x = 0

if last_row:
upper_y = image_size[0]

if last_column:
upper_x = image_size[1]

shard = (slice(lower_y, upper_y), slice(lower_x, upper_x))
print(shard)
_sharding_plan.append(shard)

return _sharding_plan

def resolve_sharding(self, sharding_plan):
Expand Down Expand Up @@ -510,6 +524,7 @@ def resolve_sharding(self, sharding_plan):

filtered_classes_combined = []
edge_classes_combined = []

for i, window in enumerate(sharding_plan):
self.log(f"Stitching tile {i}")

Expand All @@ -532,6 +547,17 @@ def resolve_sharding(self, sharding_plan):
shifted_map, edge_labels = shift_labels(
local_hdf_labels, class_id_shift, return_shifted_labels=True
)

orig_input = hdf_labels[:, window[0], window[1]]
shifted_map = np.where((orig_input != 0) & (shifted_map == 0), orig_input, shifted_map)
#since shards are computed with overlap, there may already exist segmentations in the selected area that we wish to keep
# if orig_input has a value that is not 0 (i.e. background) and the new map would replace this with 0 then we should keep the original value, in all other cases we should overwrite the values with the
# new ones from the second shard
# this will result in cell ids that are missing in the file but does not matter as all cell ids will be unique
# any ids that were on the shard edges will be removed

#potential issue: this does not check if we create a cytosol without a matching nucleus — but this should have been implemented in Altana's segmentation method (TODO confirm name/ownership)
# for other segmentation methods this could cause issues

hdf_labels[:, window[0], window[1]] = shifted_map

Expand Down Expand Up @@ -592,7 +618,6 @@ def resolve_sharding(self, sharding_plan):

self.log("resolved sharding plan.")


#add segmentation results to ome.zarr
self.save_zarr = True

Expand All @@ -607,7 +632,7 @@ def resolve_sharding(self, sharding_plan):

# Add section here that cleans up the results from the tiles and deletes them to save memory
self.log("Deleting intermediate tile results to free up storage space")
shutil.rmtree(self.shard_directory)
shutil.rmtree(self.shard_directory, ignore_errors=True)

def process(self, input_image):
self.save_zarr = False
Expand Down
4 changes: 3 additions & 1 deletion src/sparcscore/pipeline/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
from shapely.geometry import Polygon
from rasterio.features import rasterize
import xarray as xr

def _read_napari_csv(path):
# read csv table
Expand All @@ -25,4 +26,5 @@ def _read_napari_csv(path):
def _generate_mask_polygon(poly, outshape):
    """Rasterize *poly* onto a canvas of shape *outshape* and return it as a boolean mask.

    Parameters
    ----------
    poly : iterable of shapely geometries (as accepted by ``rasterio.features.rasterize``)
    outshape : 2-tuple ``(rows, cols)`` giving the output raster dimensions

    Returns
    -------
    numpy.ndarray of dtype bool with shape ``outshape``; True where a polygon covers the pixel.
    """
    rows, cols = outshape
    mask = rasterize(poly, out_shape=(rows, cols))
    return mask.astype("bool")

Loading

0 comments on commit b95800f

Please sign in to comment.