Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read the label map in nnUNet service #247

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 57 additions & 12 deletions services/nnunet/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import os
import subprocess
import json

from pathlib import Path

import logging
import SimpleITK as sitk

from platipy.backend import app, DataObject, celery # pylint: disable=unused-import
from platipy.backend import app, DataObject, celery # pylint: disable=unused-import

logger = logging.getLogger(__name__)

Expand All @@ -32,12 +33,13 @@
"clean_sup_slices": False,
}


def clean_sup_slices(mask):
lssif = sitk.LabelShapeStatisticsImageFilter()
max_slice_size = 0
sizes = {}
for z in range(mask.GetSize()[2]-1, -1, -1):
lssif.Execute(sitk.ConnectedComponent(mask[:,:,z]))
for z in range(mask.GetSize()[2] - 1, -1, -1):
lssif.Execute(sitk.ConnectedComponent(mask[:, :, z]))
if len(lssif.GetLabels()) == 0:
continue

Expand All @@ -48,13 +50,40 @@ def clean_sup_slices(mask):

sizes[z] = phys_size
for z in sizes:
if sizes[z] > max_slice_size/2:
mask[:,:,z+1:mask.GetSize()[2]] = 0
if sizes[z] > max_slice_size / 2:
mask[:, :, z + 1 : mask.GetSize()[2]] = 0
break

return mask


def get_structure_names(task):
# Look up structure names if we can find them dataset.json file
if "nnUNet_raw_data_base" not in os.environ:
logger.info("nnUNet_raw_data_base not set")
return {}

raw_path = Path(os.environ["nnUNet_raw_data_base"])
task_path = raw_path.joinpath("nnUNet_raw_data", task)
dataset_file = task_path.joinpath("dataset.json")

logger.info("Attempting to read %s", dataset_file)

if not dataset_file.exists():
logger.info("dataset.json file does not exist for %s", dataset_file)
return {}

dataset = {}
with open(dataset_file, "r") as f:
dataset = json.load(f)

if "labels" not in dataset:
logger.info("Something went wrong reading dataset.json file")
return {}

return dataset["labels"]


@app.register("nnUNet Service", default_settings=NNUNET_SETTINGS_DEFAULTS)
def nnunet_service(data_objects, working_dir, settings):
"""
Expand All @@ -73,8 +102,10 @@ def nnunet_service(data_objects, working_dir, settings):
output_path = Path(working_dir).joinpath("output")
output_path.mkdir()

for data_object in data_objects:
labels = get_structure_names(settings["task"])
logger.info("Read labels: %s", labels)

for data_object in data_objects:
# Create a symbolic link for each image to auto-segment using the nnUNet
do_path = Path(data_object.path)
io_path = input_path.joinpath(f"{settings['task']}_0000.nii.gz")
Expand Down Expand Up @@ -109,13 +140,28 @@ def nnunet_service(data_objects, working_dir, settings):
subprocess.call(command)

for op in output_path.glob("*.nii.gz"):
label_map = sitk.ReadImage(str(op))

label_map_arr = sitk.GetArrayFromImage(label_map)
label_count = label_map_arr.max()

for label_id in range(1, label_count + 1):
mask = label_map == label_id

label_name = f"Structure_{label_id}"
if str(label_id) in labels:
label_name = labels[str(label_id)]

if settings["clean_sup_slices"]:
mask = sitk.ReadImage(str(op))
mask = clean_sup_slices(mask)
sitk.WriteImage(mask, str(op))
if settings["clean_sup_slices"]:
mask = clean_sup_slices(mask)

output_data_object = DataObject(type="FILE", path=str(op), parent=data_object)
mask_file = output_path.joinpath(f"{label_name}.nii.gz")

sitk.WriteImage(mask, str(mask_file))

output_data_object = DataObject(
type="FILE", path=str(mask_file), parent=data_object
)
output_objects.append(output_data_object)

os.remove(io_path)
Expand All @@ -126,7 +172,6 @@ def nnunet_service(data_objects, working_dir, settings):


if __name__ == "__main__":

# Run app by calling "python service.py" from the command line

DICOM_LISTENER_PORT = 7777
Expand Down
Loading