From f3d6a99b54d8c2d6a6124d894fc3b7379c1c1de2 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Fri, 15 Dec 2023 09:05:11 +1100 Subject: [PATCH] Read the label map in nnUNet service --- services/nnunet/service.py | 69 +++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/services/nnunet/service.py b/services/nnunet/service.py index af101cff..03a00e4d 100644 --- a/services/nnunet/service.py +++ b/services/nnunet/service.py @@ -14,13 +14,14 @@ import os import subprocess +import json from pathlib import Path import logging import SimpleITK as sitk -from platipy.backend import app, DataObject, celery # pylint: disable=unused-import +from platipy.backend import app, DataObject, celery # pylint: disable=unused-import logger = logging.getLogger(__name__) @@ -32,12 +33,13 @@ "clean_sup_slices": False, } + def clean_sup_slices(mask): lssif = sitk.LabelShapeStatisticsImageFilter() max_slice_size = 0 sizes = {} - for z in range(mask.GetSize()[2]-1, -1, -1): - lssif.Execute(sitk.ConnectedComponent(mask[:,:,z])) + for z in range(mask.GetSize()[2] - 1, -1, -1): + lssif.Execute(sitk.ConnectedComponent(mask[:, :, z])) if len(lssif.GetLabels()) == 0: continue @@ -48,13 +50,40 @@ def clean_sup_slices(mask): sizes[z] = phys_size for z in sizes: - if sizes[z] > max_slice_size/2: - mask[:,:,z+1:mask.GetSize()[2]] = 0 + if sizes[z] > max_slice_size / 2: + mask[:, :, z + 1 : mask.GetSize()[2]] = 0 break return mask +def get_structure_names(task): + # Look up structure names if we can find them dataset.json file + if "nnUNet_raw_data_base" not in os.environ: + logger.info("nnUNet_raw_data_base not set") + return {} + + raw_path = Path(os.environ["nnUNet_raw_data_base"]) + task_path = raw_path.joinpath("nnUNet_raw_data", task) + dataset_file = task_path.joinpath("dataset.json") + + logger.info("Attempting to read %s", dataset_file) + + if not dataset_file.exists(): + logger.info("dataset.json file does not exist for %s", dataset_file) + return {} + + dataset = {} + with open(dataset_file, "r") as f: + dataset = json.load(f) + + if "labels" not in dataset: + logger.info("Something went wrong reading dataset.json file") + return {} + + return dataset["labels"] + + @app.register("nnUNet Service", default_settings=NNUNET_SETTINGS_DEFAULTS) def nnunet_service(data_objects, working_dir, settings): """ @@ -73,8 +102,10 @@ def nnunet_service(data_objects, working_dir, settings): output_path = Path(working_dir).joinpath("output") output_path.mkdir() - for data_object in data_objects: + labels = get_structure_names(settings["task"]) + logger.info("Read labels: %s", labels) + for data_object in data_objects: # Create a symbolic link for each image to auto-segment using the nnUNet do_path = Path(data_object.path) io_path = input_path.joinpath(f"{settings['task']}_0000.nii.gz") @@ -109,13 +140,28 @@ def nnunet_service(data_objects, working_dir, settings): subprocess.call(command) for op in output_path.glob("*.nii.gz"): + label_map = sitk.ReadImage(str(op)) + + label_map_arr = sitk.GetArrayFromImage(label_map) + label_count = label_map_arr.max() + + for label_id in range(1, label_count + 1): + mask = label_map == label_id + + label_name = f"Structure_{label_id}" + if str(label_id) in labels: + label_name = labels[str(label_id)] - if settings["clean_sup_slices"]: - mask = sitk.ReadImage(str(op)) - mask = clean_sup_slices(mask) - sitk.WriteImage(mask, str(op)) + if settings["clean_sup_slices"]: + mask = clean_sup_slices(mask) - output_data_object = DataObject(type="FILE", path=str(op), parent=data_object) + mask_file = output_path.joinpath(f"{label_name}.nii.gz") + + sitk.WriteImage(mask, str(mask_file)) + + output_data_object = DataObject( + type="FILE", path=str(mask_file), parent=data_object + ) output_objects.append(output_data_object) os.remove(io_path) @@ -126,7 +172,6 @@ def nnunet_service(data_objects, working_dir, settings): if __name__ == "__main__": - # Run app by calling "python service.py" from the command line DICOM_LISTENER_PORT = 7777