Skip to content

Commit

Permalink
test memory reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Sep 26, 2023
1 parent 4117af4 commit 51b0fa3
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
67 changes: 54 additions & 13 deletions src/sparcscore/pipeline/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def call_as_shard(self):
This function is intended for internal use by the :class:`ShardedSegmentation` helper class. In most cases it is not relevant to the creation of custom segmentation workflows.
"""
self.log(f"Beginning Sharding of Shard with the slicing {self.window}")
self.log(f"Beginning Segmentation of Shard with the slicing {self.window}")

with h5py.File(self.input_path, "r") as hf:
hdf_input = hf.get("channels")
Expand All @@ -136,14 +136,11 @@ def call_as_shard(self):
y = y2 - y1

#initialize directory and load data
self.log(f"Generating a memory mapped temp array with the dimensions {(c, x, y)}")
input_image = tempmmap.array(shape = (c, x, y), dtype = float)
input_image = hdf_input[:, self.window[0], self.window[1]]
self.log(f"Generating a memory mapped temp array with the dimensions {(2, x, y)}")
input_image = tempmmap.array(shape = (2, x, y), dtype = np.uint16)
input_image = hdf_input[:2, self.window[0], self.window[1]]
self.log(f"Input image loaded and mapped to memory.")

if input_image.dtype != float:
input_image = input_image.astype(float)

#perform check to see if any input pixels are not 0, if so perform segmentation, else return array of zeros.
if sc_any(input_image):
try:
Expand All @@ -161,8 +158,6 @@ def call_as_shard(self):
del input_image
gc.collect()

shutil.rmtree(TEMP_DIR_NAME)

def save_segmentation(self, channels, labels, classes):
"""Saves the results of a segmentation at the end of the process.
Expand All @@ -178,22 +173,23 @@ def save_segmentation(self, channels, labels, classes):
# size (C, H, W) is expected
# dims are expanded in case (H, W) is passed

channels = (np.expand_dims(channels, axis=0) if len(channels.shape) == 2 else channels)
labels = np.expand_dims(labels, axis=0) if len(labels.shape) == 2 else labels

map_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)
hf = h5py.File(map_path, "w")
hf = h5py.File(map_path, "a")

hf.create_dataset(
"labels",
data=labels,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
)
#also save channels
hf.create_dataset(
"channels",
data=channels,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
)

hf.close()

# save classes
Expand Down Expand Up @@ -380,6 +376,15 @@ def save_image(self, array, save_name="", cmap="magma", **kwargs):

def get_output(self):
    """Return the path to this segmentation's default output HDF5 file.

    Joins the working directory with the class-level default output file
    name; performs no I/O itself.
    """
    output_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)
    return output_path

def process(self, input_image):
    """Entry point for a plain (non-sharded) segmentation run.

    Persists the raw input image into the output container and flags the
    run so that zarr output is written as well.

    Args:
        input_image: image array handed to ``save_input_image``; its exact
            shape/dtype contract is not visible here — see that method.
    """
    # enable zarr output for this run before anything is saved
    self.save_zarr = True
    self.save_input_image(input_image)

    # NOTE(review): "Finished segmentation" is logged here although this
    # method only saved the input image — confirm segmentation actually
    # happened upstream, and note the near-duplicate log line below.
    self.log("Finished segmentation")

    #make sure to cleanup temp directories
    # NOTE(review): the comment above announces cleanup but no cleanup code
    # follows (the base class's shutil.rmtree(TEMP_DIR_NAME) was removed in
    # this commit) — verify temp directories are removed elsewhere.
    self.log("=== finished segmentation === ")


class ShardedSegmentation(Segmentation):
Expand All @@ -406,7 +411,7 @@ def __init__(self, *args, **kwargs):
)

self.save_zarr = False

def save_input_image(self, input_image):

output = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)
Expand All @@ -424,6 +429,43 @@ def save_input_image(self, input_image):

self.log("Input image added to .h5. Provides data source for reading shard information.")

def save_segmentation(self, channels, labels, classes):
    """Saves the results of a segmentation at the end of the process.

    For the sharded segmentation no channels are written here because they
    have already been saved to the output file by ``save_input_image``.

    Args:
        channels (np.array): accepted for interface compatibility with the
            base class but not written out (already persisted).
        labels (np.array): Numpy array of shape ``(height, width)``. Labels are all data which are saved as integer values. These are mostly segmentation maps with integer values corresponding to the labels of cells.
        classes (list(int)): List of all classes in the labels array, which have passed the filtering step. All classes contained in this list will be extracted.
    """
    self.log("saving segmentation")

    # size (C, H, W) is expected
    # dims are expanded in case (H, W) is passed
    labels = np.expand_dims(labels, axis=0) if len(labels.shape) == 2 else labels

    map_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

    # open in append mode: the output file already holds the input channels
    # written by save_input_image, and mode "w" would truncate them away.
    # (This commit made the same "w" -> "a" fix in the base class.)
    hf = h5py.File(map_path, "a")

    hf.create_dataset(
        "labels",
        data=labels,
        chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
    )

    hf.close()

    # save classes that survived filtering, one integer id per line
    filtered_path = os.path.join(self.directory, self.DEFAULT_FILTER_FILE)

    to_write = "\n".join([str(i) for i in list(classes)])
    with open(filtered_path, "w") as myfile:
        myfile.write(to_write)

    self.log("=== finished segmentation ===")
    self.save_segmentation_zarr(labels = labels)

def initialize_shard_list(self, sharding_plan):
_shard_list = []

Expand Down Expand Up @@ -1055,7 +1097,6 @@ def process(self):
# initialize temp object to write segmentations too
self._initialize_tempmmap_array()


segmentation_list = self.initialize_shard_list(indexes, input_path=input_path)

# make more verbose output for troubleshooting and timing purposes.
Expand Down
4 changes: 4 additions & 0 deletions src/sparcscore/pipeline/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ def process(self, input_image):
"nucleus_segmentation": tempmmap.array(shape = downsampled_image_size, dtype = np.uint16),
"cytosol_segmentation": tempmmap.array(shape = downsampled_image_size, dtype = np.uint16),
}
self.log("Created memory mapped temp arrays to store")

# could add a normalization step here if so desired
#perform downsampling after saving input image to ensure that we have a duplicate preserving the original dimensions
Expand All @@ -996,9 +997,12 @@ def process(self, input_image):
self.log(f"input image size: {input_image.shape}")

input_image = input_image[:2, :, :] #only get the first 2 channels for segmentation (does not use excess space on the GPU this way)
gc.collect()

self.log(f"input image size after removing excess channels: {input_image.shape}")
input_image = np.pad(input_image, ((0, 0), pad_x, pad_y))
_size_padding = input_image.shape

self.log(f"Performing image padding to ensure that image is compatible with downsample kernel size. Original image was {_size}, padded image is {_size_padding}")
input_image = downsample_img(input_image, N= N)
self.log(f"Downsampled image size {input_image.shape}")
Expand Down

0 comments on commit 51b0fa3

Please sign in to comment.